Sudoku Solver

This post presents a method for solving 9x9 sudoku puzzles using recursion and backtracking.

A board consists of a 9x9 grid of cells. Each cell stores its value and a boolean `fixed` flag, which records whether the value is fixed (given in the puzzle) or modifiable. A cell value of 0 indicates that the cell is empty. A legal value for a cell is one that does not already occur in the cell's row, column, or 3x3 sub-square.

We start at the first cell and visit the cells in row-major order. If the current cell is not fixed, we put a legal value in it and recursively try to put legal values in the following cells; fixed cells are simply skipped. If an assumed value is wrong, then at some point in the recursion we are bound to reach a cell for which no legal value exists. At that point we return failure and backtrack to the most recent cell that still has untried legal values, put the next legal value there, and recurse again into the following cells. If we manage to put a legal value in the last cell, every cell holds a legal value, so the puzzle is solved and we return success.

The code below uses this naive method to solve a 9x9 puzzle. To test its correctness, the `main` function is written to produce the answer to Problem 96 of Project Euler (yes, that's how I normally do a quick correctness test). It reads the 50 puzzles from standard input and prints the sum of the 3-digit numbers formed by the top-left three cells of each solution. The code favours simplicity over compactness.

#include <cstdio>
#include <iostream>
using namespace std;


struct sudoku{
    // Cell values, indexed [row][column]; 0 marks an empty cell.
    int A[9][9] = {};
    // true for cells whose value was given in the puzzle; the solver never changes them.
    bool FIXED[9][9] = {};

    // Returns true if x does not occur anywhere in row i.
    // (j is unused; it keeps the signature uniform with the other checks.)
    bool check_row(int x, int i, int j){
        for(int jj=0; jj<9; jj++){
            if( A[i][jj] == x )
                return false;
        }
        return true;
    }

    // Returns true if x does not occur anywhere in column j. (i is unused.)
    bool check_col(int x, int i, int j){
        for(int ii=0; ii<9; ii++){
            if ( A[ii][j] == x )
                return false;
        }
        return true;
    }

    // Returns true if x does not occur in the 3x3 sub-square containing (i, j).
    bool check_sqr(int x, int i, int j){
        const int si = (i/3) * 3;   // top row of the sub-square
        const int sj = (j/3) * 3;   // left column of the sub-square
        for(int ii=si; ii<si+3; ii++){
            for(int jj=sj; jj<sj+3; jj++){
                if ( A[ii][jj] == x )
                    return false;
            }
        }
        return true;
    }

    // x is legal at (i, j) if it is absent from the cell's row, column and
    // sub-square. Note that the cell (i, j) itself is included in all three scans.
    bool check_all(int x, int i, int j){
        return check_row(x, i, j) && check_col(x, i, j) && check_sqr(x, i, j);
    }

    // Returns true if every fixed clue is a digit 1..9 that does not repeat in its
    // row, column or sub-square. try_cell() never re-checks fixed cells, so without
    // this an inconsistent puzzle could be reported as solved.
    bool givens_consistent(){
        for(int i=0; i<9; i++){
            for(int j=0; j<9; j++){
                if ( !FIXED[i][j] )
                    continue;
                const int v = A[i][j];
                if ( v < 1 || v > 9 )
                    return false;
                A[i][j] = 0;            // hide the clue so it does not match itself
                const bool ok = check_all(v, i, j);
                A[i][j] = v;
                if ( !ok )
                    return false;
            }
        }
        return true;
    }

    // Fills the cells from (i, j) onward, in row-major order, by backtracking.
    // Returns true once every cell through (8, 8) holds a legal value. On failure
    // the non-fixed cells it filled are cleared back to 0 and false is returned.
    bool try_cell(int i, int j){
        const bool last_cell = (i == 8 && j == 8);
        // Next cell in row-major order. After (8, 8) this wraps to (0, 0), but
        // that cell is never visited because last_cell short-circuits first.
        const int jn = (j+1) % 9;
        const int in = (jn == 0) ? (i+1) % 9 : i;

        // Given clues are kept as-is; just move on.
        if ( FIXED[i][j] )
            return last_cell || try_cell(in, jn);

        // Clear any leftover value so the cell cannot conflict with itself.
        A[i][j] = 0;

        // Brute force: try every digit that is currently legal here.
        for(int x = 1; x <= 9; x++){
            if( check_all(x, i, j) ){
                A[i][j] = x;
                if ( last_cell || try_cell(in, jn) )
                    return true;
            }
        }
        // No digit works: clear the cell before backtracking so earlier cells
        // are not blocked by a stale value.
        A[i][j] = 0;
        return false;
    }

    // Solves the loaded puzzle in place and prints the result.
    // Returns true on success. Returns false (after printing "Unable to solve")
    // if the clues are inconsistent or no completion exists.
    bool solve(){
        // Start from the clues only, so re-solving a board (or one with leftover
        // values in non-fixed cells) does not fail against stale entries.
        for(int i=0; i<9; i++)
            for(int j=0; j<9; j++)
                if ( !FIXED[i][j] )
                    A[i][j] = 0;

        const bool solved = givens_consistent() && try_cell(0, 0);
        if ( solved ){
            printf("Successfully Solved\n");
            print();
        }
        else{
            printf("Unable to solve\n");
        }
        return solved;
    }

    // Prints the grid as 9 lines of 9 space-separated digits.
    void print(){
        for(int i=0; i<9; i++){
            for(int j=0; j<9; j++){
                printf("%d ", A[i][j]);
            }
            printf("\n");
        }
    }

    // Parses one puzzle row from S into row i. Digits '1'..'9' become fixed clues;
    // '0' or '.' mark an empty cell. Characters after the ninth are ignored.
    // Returns false if S has fewer than 9 characters or contains any other
    // character (row i may then be partially written).
    bool load_row(int i, const char* S){
        for(int j=0; j<9; j++){
            const char c = S[j];
            if ( c >= '1' && c <= '9' ){
                A[i][j] = c - '0';
                FIXED[i][j] = true;
            }
            else if ( c == '0' || c == '.' ){
                A[i][j] = 0;
                FIXED[i][j] = false;
            }
            else{
                return false;   // includes '\0': the row is shorter than 9 cells
            }
        }
        return true;
    }

    // Reads 9 whitespace-separated rows of 9 cells from stdin.
    // Returns false on end of input or a malformed row.
    bool input(){
        char S[20];
        for(int i=0; i<9; i++){
            // %19s bounds the read to the buffer; plain %s could overflow S.
            if ( scanf("%19s", S) != 1 || !load_row(i, S) )
                return false;
        }
        return true;
    }
};

int main(){
    sudoku S;
    char grid[50];
    int TC = 50, tc, sum=0;
    // Each of the (up to) TC puzzles is expected to be preceded by a header made
    // of a word and a number (Project Euler 96's file uses e.g. "Grid 01"). The
    // answer is the sum of the 3-digit numbers formed by the top-left three
    // cells of every solution.
    for(int t = 1; t <= TC; t++){
        // Stop at end of input instead of printing an uninitialized tc and
        // re-solving the previous grid; %49s keeps the read inside grid[50].
        if ( scanf("%49s %d", grid, &tc) != 2 )
            break;
        printf("%s %d\n", grid, tc);
        S.input();
        S.solve();
        sum += S.A[0][0]*100 + S.A[0][1]*10 + S.A[0][2];
    }
    printf("Total sum = %d\n", sum);
    return 0;
}