|
| 1 | +""" |
| 2 | +Please do not modify this file! It is published at https://norvig.com/sudoku.html with |
| 3 | +only minimal changes to work with modern versions of Python. If you have improvements, |
| 4 | +please make them in a separate file. |
| 5 | +""" |
| 6 | +import random |
| 7 | +import time |
| 8 | + |
| 9 | + |
| 10 | +def cross(items_a, items_b): |
| 11 | + "Cross product of elements in A and elements in B." |
| 12 | + return [a + b for a in items_a for b in items_b] |
| 13 | + |
| 14 | + |
| 15 | +digits = "123456789" |
| 16 | +rows = "ABCDEFGHI" |
| 17 | +cols = digits |
| 18 | +squares = cross(rows, cols) |
| 19 | +unitlist = ( |
| 20 | + [cross(rows, c) for c in cols] |
| 21 | + + [cross(r, cols) for r in rows] |
| 22 | + + [cross(rs, cs) for rs in ("ABC", "DEF", "GHI") for cs in ("123", "456", "789")] |
| 23 | +) |
| 24 | +units = {s: [u for u in unitlist if s in u] for s in squares} |
| 25 | +peers = {s: set(sum(units[s], [])) - {s} for s in squares} |
| 26 | + |
| 27 | + |
| 28 | +def test(): |
| 29 | + "A set of unit tests." |
| 30 | + assert len(squares) == 81 |
| 31 | + assert len(unitlist) == 27 |
| 32 | + assert all(len(units[s]) == 3 for s in squares) |
| 33 | + assert all(len(peers[s]) == 20 for s in squares) |
| 34 | + assert units["C2"] == [ |
| 35 | + ["A2", "B2", "C2", "D2", "E2", "F2", "G2", "H2", "I2"], |
| 36 | + ["C1", "C2", "C3", "C4", "C5", "C6", "C7", "C8", "C9"], |
| 37 | + ["A1", "A2", "A3", "B1", "B2", "B3", "C1", "C2", "C3"], |
| 38 | + ] |
| 39 | + # fmt: off |
| 40 | + assert peers["C2"] == { |
| 41 | + "A2", "B2", "D2", "E2", "F2", "G2", "H2", "I2", "C1", "C3", |
| 42 | + "C4", "C5", "C6", "C7", "C8", "C9", "A1", "A3", "B1", "B3" |
| 43 | + } |
| 44 | + # fmt: on |
| 45 | + print("All tests pass.") |
| 46 | + |
| 47 | + |
| 48 | +def parse_grid(grid): |
| 49 | + """Convert grid to a dict of possible values, {square: digits}, or |
| 50 | + return False if a contradiction is detected.""" |
| 51 | + ## To start, every square can be any digit; then assign values from the grid. |
| 52 | + values = {s: digits for s in squares} |
| 53 | + for s, d in grid_values(grid).items(): |
| 54 | + if d in digits and not assign(values, s, d): |
| 55 | + return False ## (Fail if we can't assign d to square s.) |
| 56 | + return values |
| 57 | + |
| 58 | + |
| 59 | +def grid_values(grid): |
| 60 | + "Convert grid into a dict of {square: char} with '0' or '.' for empties." |
| 61 | + chars = [c for c in grid if c in digits or c in "0."] |
| 62 | + assert len(chars) == 81 |
| 63 | + return dict(zip(squares, chars)) |
| 64 | + |
| 65 | + |
| 66 | +def assign(values, s, d): |
| 67 | + """Eliminate all the other values (except d) from values[s] and propagate. |
| 68 | + Return values, except return False if a contradiction is detected.""" |
| 69 | + other_values = values[s].replace(d, "") |
| 70 | + if all(eliminate(values, s, d2) for d2 in other_values): |
| 71 | + return values |
| 72 | + else: |
| 73 | + return False |
| 74 | + |
| 75 | + |
| 76 | +def eliminate(values, s, d): |
| 77 | + """Eliminate d from values[s]; propagate when values or places <= 2. |
| 78 | + Return values, except return False if a contradiction is detected.""" |
| 79 | + if d not in values[s]: |
| 80 | + return values ## Already eliminated |
| 81 | + values[s] = values[s].replace(d, "") |
| 82 | + ## (1) If a square s is reduced to one value d2, then eliminate d2 from the peers. |
| 83 | + if len(values[s]) == 0: |
| 84 | + return False ## Contradiction: removed last value |
| 85 | + elif len(values[s]) == 1: |
| 86 | + d2 = values[s] |
| 87 | + if not all(eliminate(values, s2, d2) for s2 in peers[s]): |
| 88 | + return False |
| 89 | + ## (2) If a unit u is reduced to only one place for a value d, then put it there. |
| 90 | + for u in units[s]: |
| 91 | + dplaces = [s for s in u if d in values[s]] |
| 92 | + if len(dplaces) == 0: |
| 93 | + return False ## Contradiction: no place for this value |
| 94 | + elif len(dplaces) == 1: |
| 95 | + # d can only be in one place in unit; assign it there |
| 96 | + if not assign(values, dplaces[0], d): |
| 97 | + return False |
| 98 | + return values |
| 99 | + |
| 100 | + |
| 101 | +def display(values): |
| 102 | + "Display these values as a 2-D grid." |
| 103 | + width = 1 + max(len(values[s]) for s in squares) |
| 104 | + line = "+".join(["-" * (width * 3)] * 3) |
| 105 | + for r in rows: |
| 106 | + print( |
| 107 | + "".join( |
| 108 | + values[r + c].center(width) + ("|" if c in "36" else "") for c in cols |
| 109 | + ) |
| 110 | + ) |
| 111 | + if r in "CF": |
| 112 | + print(line) |
| 113 | + print() |
| 114 | + |
| 115 | + |
| 116 | +def solve(grid): |
| 117 | + return search(parse_grid(grid)) |
| 118 | + |
| 119 | + |
| 120 | +def some(seq): |
| 121 | + "Return some element of seq that is true." |
| 122 | + for e in seq: |
| 123 | + if e: |
| 124 | + return e |
| 125 | + return False |
| 126 | + |
| 127 | + |
| 128 | +def search(values): |
| 129 | + "Using depth-first search and propagation, try all possible values." |
| 130 | + if values is False: |
| 131 | + return False ## Failed earlier |
| 132 | + if all(len(values[s]) == 1 for s in squares): |
| 133 | + return values ## Solved! |
| 134 | + ## Chose the unfilled square s with the fewest possibilities |
| 135 | + n, s = min((len(values[s]), s) for s in squares if len(values[s]) > 1) |
| 136 | + return some(search(assign(values.copy(), s, d)) for d in values[s]) |
| 137 | + |
| 138 | + |
| 139 | +def solve_all(grids, name="", showif=0.0): |
| 140 | + """Attempt to solve a sequence of grids. Report results. |
| 141 | + When showif is a number of seconds, display puzzles that take longer. |
| 142 | + When showif is None, don't display any puzzles.""" |
| 143 | + |
| 144 | + def time_solve(grid): |
| 145 | + start = time.monotonic() |
| 146 | + values = solve(grid) |
| 147 | + t = time.monotonic() - start |
| 148 | + ## Display puzzles that take long enough |
| 149 | + if showif is not None and t > showif: |
| 150 | + display(grid_values(grid)) |
| 151 | + if values: |
| 152 | + display(values) |
| 153 | + print("(%.5f seconds)\n" % t) |
| 154 | + return (t, solved(values)) |
| 155 | + |
| 156 | + times, results = zip(*[time_solve(grid) for grid in grids]) |
| 157 | + if (n := len(grids)) > 1: |
| 158 | + print( |
| 159 | + "Solved %d of %d %s puzzles (avg %.2f secs (%d Hz), max %.2f secs)." |
| 160 | + % (sum(results), n, name, sum(times) / n, n / sum(times), max(times)) |
| 161 | + ) |
| 162 | + |
| 163 | + |
| 164 | +def solved(values): |
| 165 | + "A puzzle is solved if each unit is a permutation of the digits 1 to 9." |
| 166 | + |
| 167 | + def unitsolved(unit): |
| 168 | + return {values[s] for s in unit} == set(digits) |
| 169 | + |
| 170 | + return values is not False and all(unitsolved(unit) for unit in unitlist) |
| 171 | + |
| 172 | + |
| 173 | +def from_file(filename, sep="\n"): |
| 174 | + "Parse a file into a list of strings, separated by sep." |
| 175 | + return open(filename).read().strip().split(sep) # noqa: SIM115 |
| 176 | + |
| 177 | + |
| 178 | +def random_puzzle(assignments=17): |
| 179 | + """Make a random puzzle with N or more assignments. Restart on contradictions. |
| 180 | + Note the resulting puzzle is not guaranteed to be solvable, but empirically |
| 181 | + about 99.8% of them are solvable. Some have multiple solutions.""" |
| 182 | + values = {s: digits for s in squares} |
| 183 | + for s in shuffled(squares): |
| 184 | + if not assign(values, s, random.choice(values[s])): |
| 185 | + break |
| 186 | + ds = [values[s] for s in squares if len(values[s]) == 1] |
| 187 | + if len(ds) >= assignments and len(set(ds)) >= 8: |
| 188 | + return "".join(values[s] if len(values[s]) == 1 else "." for s in squares) |
| 189 | + return random_puzzle(assignments) ## Give up and make a new puzzle |
| 190 | + |
| 191 | + |
| 192 | +def shuffled(seq): |
| 193 | + "Return a randomly shuffled copy of the input sequence." |
| 194 | + seq = list(seq) |
| 195 | + random.shuffle(seq) |
| 196 | + return seq |
| 197 | + |
| 198 | + |
| 199 | +grid1 = ( |
| 200 | + "003020600900305001001806400008102900700000008006708200002609500800203009005010300" |
| 201 | +) |
| 202 | +grid2 = ( |
| 203 | + "4.....8.5.3..........7......2.....6.....8.4......1.......6.3.7.5..2.....1.4......" |
| 204 | +) |
| 205 | +hard1 = ( |
| 206 | + ".....6....59.....82....8....45........3........6..3.54...325..6.................." |
| 207 | +) |
| 208 | + |
| 209 | +if __name__ == "__main__": |
| 210 | + test() |
| 211 | + # solve_all(from_file("easy50.txt", '========'), "easy", None) |
| 212 | + # solve_all(from_file("top95.txt"), "hard", None) |
| 213 | + # solve_all(from_file("hardest.txt"), "hardest", None) |
| 214 | + solve_all([random_puzzle() for _ in range(99)], "random", 100.0) |
| 215 | + for puzzle in (grid1, grid2): # , hard1): # Takes 22 sec to solve on my M1 Mac. |
| 216 | + display(parse_grid(puzzle)) |
| 217 | + start = time.monotonic() |
| 218 | + solve(puzzle) |
| 219 | + t = time.monotonic() - start |
| 220 | + print("Solved: %.5f sec" % t) |
0 commit comments