数独暴力猜解

思路

  1. 验证已填格子是否合法,非法则返回 false
  2. 选一个空格子,如果没有空格子则返回 true
  3. 对空格子进行填数,依次尝试 1-9
  4. 递归猜解剩余格子(即从步骤 1 重新开始)
  5. 步骤 4 猜解成功则返回 true,否则回到步骤 3 尝试下一个填数
  6. 1-9 均返回 false,则返回 false

实现

  1. 约定
    • 使用 9 x 9 的二维数组存储所有格子
    • 每个格子维护一个表示可填数的 mask,采用 bitmap 方式存储:从低位起第 k 位(k = 0..8)为 1 表示可填数字 k+1,因此 0b111_111_111 表示 1-9 均可填,0b000_000_000 表示无可填数
    • 为其中一个格子填数时,更新其所对应的行、列、区块上的每个格子的 mask
    • 当有格子的 mask 为 0b000_000_000 时,数独不可解
    • 当同一行、列或区块内有两个已确定(mask 只有一位为 1)的格子取相同数字时,数独不可解
    • 因涉及递归回溯,所以每个格子实际存储为一个 mask 栈:每次猜数前压栈复制所有格子当前的 mask,猜错时出栈恢复
  2. 代码

    // Demo: parses a hard-coded puzzle, solves it and prints the grid before and after.
    public class SudoKu {
        /**
         * Renders a 9x9 grid as text: a column-header line followed by one line per row.
         *
         * @param supplier returns the digit shown at (row, column); the resolver supplies 0 for a
         *                 cell that is not yet determined
         * @return the formatted grid, every line terminated by {@code '\n'}
         */
        public static String sudokuString(BiFunction<Integer, Integer, Integer> supplier) {
            StringBuilder str = new StringBuilder("   _0_1_2_3_4_5_6_7_8\n");
            for (int r = 0; r < 9; r++) {
                str.append(r).append(" | ");
                for (int c = 0; c < 9; c++) {
                    int v = supplier.apply(r, c);
                    str.append(v).append(" ");
                }
                str.append('\n');
            }
            return str.toString();
        }

        /**
         * Parses one puzzle row written as exactly 9 characters {@code '0'}-{@code '9'}.
         *
         * @param row the row text; {@code '0'} marks an empty cell
         * @return the 9 cell values
         * @throws IllegalArgumentException if the row is not 9 ASCII digits
         */
        private static int[] parseRow(String row) {
            if (row.length() != 9) {
                throw new IllegalArgumentException("row must have 9 digits: " + row);
            }
            int[] values = new int[9];
            for (int c = 0; c < 9; c++) {
                char ch = row.charAt(c);
                if (ch < '0' || ch > '9') {
                    throw new IllegalArgumentException("row must contain only digits: " + row);
                }
                values[c] = ch - '0';
            }
            return values;
        }

        public static void main(String[] args) {
            // Rows are strings rather than int literals: in Java an int literal with a leading zero
            // (e.g. 010000000) is octal, and leading zeros would also be lost.
            String[] rows = { "821007900", "007000000", "400003000", "908040000", "000000001",
                    "374201000", "000160040", "060000000", "709008600" };

            int[][] cells = new int[rows.length][];
            for (int rowIndex = 0; rowIndex < rows.length; rowIndex++) {
                cells[rowIndex] = parseRow(rows[rowIndex]);
            }

            SudoKuResolver rc = new SudoKuResolver(cells);

            System.out.printf("Input >>\n%s", sudokuString(rc::getCellValue));
            System.out.printf("resolving...\n");
            long s = System.currentTimeMillis();
            boolean resolved = rc.resolve();
            long e = System.currentTimeMillis();
            long cost = e - s;
            System.out.printf("finished, cost: %s\n" , cost);

            System.out.printf("Output >>\n%s", resolved ? sudokuString(rc::getCellValue) : "不可解");
        }
    }
    
    // Solver: depth-first backtracking over per-cell candidate bitmasks.
    class SudoKuResolver {
        /** For a row/column index, the first index of the 3x3 block that contains it. */
        public final static int[] MIN_INDEX = new int[] { 0, 0, 0, 3, 3, 3, 6, 6, 6 };
        /** For a row/column index, the last index of the 3x3 block that contains it. */
        public final static int[] MAX_INDEX = new int[] { 2, 2, 2, 5, 5, 5, 8, 8, 8 };

        /** Candidate mask with bits 0..8 set: every digit 1-9 is still allowed. */
        private static final int ALL_DIGITS = 0b111_111_111;

        private static final Logger LOGGER = LoggerFactory.getLogger(SudoKuResolver.class);

        private final CellMaskBucket[][] cellBuckets;

        /** Current backtracking depth; selects the live slot of every cell's mask stack. */
        private int depth;

        /**
         * Creates a solver for the given puzzle.
         *
         * @param cells a 9x9 grid; 0 marks an empty cell, 1-9 a given digit
         * @throws IllegalArgumentException if {@code cells} is not 9x9 or holds a value outside 0-9
         */
        public SudoKuResolver(int[][] cells) {
            if (cells == null || cells.length != 9) {
                throw new IllegalArgumentException("cells must be a 9x9 array");
            }
            this.cellBuckets = new CellMaskBucket[9][9];
            for (int i = 0; i < 9; i++) {
                if (cells[i] == null || cells[i].length != 9) {
                    throw new IllegalArgumentException("row " + i + " must have exactly 9 cells");
                }
                for (int j = 0; j < 9; j++) {
                    int v = cells[i][j];
                    // Without this check toMask would set bits outside the 9-digit range.
                    if (v < 0 || v > 9) {
                        throw new IllegalArgumentException(
                                "cell (" + i + ", " + j + ") must be in 0..9 but was " + v);
                    }
                    CellMaskBucket bucket = new CellMaskBucket(i, j);
                    // Givens get a single bit; empty cells (0) start with every digit allowed.
                    bucket.setMask(toMask(v));
                    cellBuckets[i][j] = bucket;
                }
            }
        }

        /**
         * Solves the puzzle by depth-first search with backtracking.
         *
         * <p>On success the solution is readable through {@link #getCellValue(int, int)}: the mask
         * stack is intentionally left at the depth where the solution was found.
         *
         * @return {@code true} if a solution was found, {@code false} if the grid is unsolvable
         */
        public boolean resolve() {
            // Step 1: prune the branch if determined cells already conflict.
            if (!isValid()) {
                return false;
            }

            // Step 2: no cell with several candidates left means every cell holds one digit.
            CellMaskBucket cell = pickCell();
            if (cell == null) {
                return true;
            }

            // Steps 3-5: try each remaining candidate; the call stack records the guesses.
            int m = cell.getMask();
            if (LOGGER.isDebugEnabled()) {
                int[] trys = computeTrys(m);
                LOGGER.debug("{}: try [ {} ]", cell, Strings.join(",", trys));
            }

            while (m != 0) {
                int zc = Integer.numberOfTrailingZeros(m);
                int tv = 1 << zc;
                push();

                setCellMask(cell, tv);

                if (LOGGER.isDebugEnabled()) {
                    LOGGER.debug("{}: try {}\n{}", cell, zc + 1, SudoKu.sudokuString(this::getCellValue));
                }

                if (resolve()) {
                    return true;
                }
                // The guess failed: drop this candidate and restore the masks saved by push().
                m &= ~tv;
                pop();

                if (LOGGER.isDebugEnabled()) {
                    LOGGER.debug("{}: not {}", cell, zc + 1);
                }
            }
            // Step 6: every candidate failed.
            return false;
        }

        /**
         * Opens a new stack level by copying every cell's current mask one slot up.
         *
         * @throws ArrayIndexOutOfBoundsException if the 81-slot mask stack is already full
         */
        void push() {
            if (++depth > 80) {
                throw new ArrayIndexOutOfBoundsException(depth--);
            }
            for (int i = 0; i < 9; i++) {
                for (int j = 0; j < 9; j++) {
                    cellBuckets[i][j].bucket[depth] = cellBuckets[i][j].bucket[depth - 1];
                }
            }
        }

        /**
         * Discards the current stack level, restoring the masks of the previous one.
         *
         * @throws ArrayIndexOutOfBoundsException if the stack is already at depth 0
         */
        void pop() {
            if (--depth < 0) {
                throw new ArrayIndexOutOfBoundsException(depth++);
            }
        }

        /**
         * Pins {@code bucket} to {@code mask} and removes those bits from every row/column/block peer
         * that still has more than one candidate. Peers that already hold a single digit are left
         * untouched; a clash with them is detected later by {@link #isValid()}.
         */
        void setCellMask(CellMaskBucket bucket, int mask) {
            bucket.setMask(mask);
            rcbTravel(bucket.row, bucket.col, (i, j) -> {
                int om = cellBuckets[i][j].getMask();
                if (Integer.bitCount(om) > 1) {
                    this.cellBuckets[i][j].setMask(om & (~mask));
                }
            });
        }

        /**
         * Calls {@code act} once for each peer of cell (r, c): the rest of its row, the rest of its
         * column, and the block cells sharing neither its row nor its column (those were already
         * visited).
         */
        void rcbTravel(int r, int c, BiConsumer<Integer, Integer> act) {
            // row
            for (int i = 0; i < 9; i++) {
                if (i == c) {
                    continue;
                }
                act.accept(r, i);
            }
            // column
            for (int i = 0; i < 9; i++) {
                if (i == r) {
                    continue;
                }
                act.accept(i, c);
            }
            // block
            for (int i = MIN_INDEX[r]; i <= MAX_INDEX[r]; i++) {
                for (int j = MIN_INDEX[c]; j <= MAX_INDEX[c]; j++) {
                    if (i == r || j == c) {
                        continue;
                    }
                    act.accept(i, j);
                }
            }
        }

        /**
         * Returns the first cell, in row-major order, that still has more than one candidate, or
         * {@code null} if every cell is determined.
         */
        CellMaskBucket pickCell() {
            for (int i = 0; i < 9; i++) {
                for (int j = 0; j < 9; j++) {
                    CellMaskBucket cs = cellBuckets[i][j];
                    if (Integer.bitCount(cs.getMask()) > 1) {
                        return cs;
                    }
                }
            }
            return null;
        }

        /**
         * Checks the current stack level for contradictions.
         *
         * @return {@code false} if any cell has no candidate left, or if two determined cells in the
         *         same row, column or block hold the same digit; {@code true} otherwise
         */
        boolean isValid() {
            for (int i = 0; i < 9; i++) {
                // Digits already fixed in row i, column i and block i respectively.
                int rowSeen = 0;
                int colSeen = 0;
                int blockSeen = 0;
                for (int j = 0; j < 9; j++) {
                    // row i (this pass over all rows also visits every cell once)
                    int rm = cellBuckets[i][j].getMask();
                    if (rm == 0) {
                        // No candidate left: this branch can never be completed.
                        return false;
                    }
                    if (Integer.bitCount(rm) == 1) {
                        if ((rowSeen & rm) != 0) {
                            return false;
                        }
                        rowSeen |= rm;
                    }

                    // column i
                    int cm = cellBuckets[j][i].getMask();
                    if (Integer.bitCount(cm) == 1) {
                        if ((colSeen & cm) != 0) {
                            return false;
                        }
                        colSeen |= cm;
                    }

                    // block i: top-left corner ((i / 3) * 3, (i % 3) * 3), offset (j / 3, j % 3)
                    int r = (i / 3) * 3 + (j / 3);
                    int c = (i % 3) * 3 + (j % 3);
                    int bm = cellBuckets[r][c].getMask();
                    if (Integer.bitCount(bm) == 1) {
                        if ((blockSeen & bm) != 0) {
                            return false;
                        }
                        blockSeen |= bm;
                    }
                }
            }
            return true;
        }

        /**
         * Returns the digit (1-9) of cell (r, c) at the current depth, or 0 if it is not determined.
         */
        public int getCellValue(int r, int c) {
            int m = this.cellBuckets[r][c].getMask();
            return toDecimal(m);
        }

        /** One cell's candidate masks, one slot per backtracking depth. */
        class CellMaskBucket {

            private final int row, col;
            private final int[] bucket = new int[81];

            CellMaskBucket(int row, int col) {
                this.row = row;
                this.col = col;
            }

            /** Sets this cell's mask at the resolver's current depth. */
            void setMask(int v) {
                bucket[depth] = v;
            }

            /** Returns this cell's mask at the resolver's current depth. */
            int getMask() {
                return bucket[depth];
            }

            @Override
            public String toString() {
                return "(" + row + ", " + col + ")";
            }
        }

        /** Converts a single-bit mask to its digit (bit k -> k + 1); any other mask yields 0. */
        static int toDecimal(int mask) {
            if (Integer.bitCount(mask) == 1) {
                return Integer.numberOfTrailingZeros(mask << 1);
            }
            return 0;
        }

        /** Converts a digit 1-9 to its single-bit mask; 0 (empty cell) maps to all nine bits. */
        static int toMask(int dec) {
            return dec == 0 ? ALL_DIGITS : 1 << (dec - 1);
        }

        /** Lists the digits encoded in {@code m}, in ascending order. */
        static int[] computeTrys(int m) {
            int[] r = new int[Integer.bitCount(m)];
            for (int i = 0; i < r.length; i++) {
                r[i] = Integer.numberOfTrailingZeros(m) + 1;
                m &= m - 1; // clear the lowest set bit
            }
            return r;
        }
    }

猜你喜欢

转载自 https://www.cnblogs.com/realhyx/p/10354368.html