-
Notifications
You must be signed in to change notification settings - Fork 0
/
smarter_sudoku.py
302 lines (233 loc) · 9.48 KB
/
smarter_sudoku.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
from collections import defaultdict
from itertools import combinations
from math import ceil
# 1-based row and column indices of the 9x9 board.
rows = range(1, 10)
cols = range(1, 10)
# All nine 3x3 boxes as (box_row, box_col) pairs, enumerated column by column.
boxes = [
    (box_row, box_col)
    for box_col in range(1, 4)
    for box_row in range(1, 4)
]
def box_position(row, col):
    """Return ``(box_row, box_col)`` of the 3x3 box holding 1-based cell ``(row, col)``."""
    return tuple(ceil(axis / 3) for axis in (row, col))
def box_cell_axes(box_idx):
    """Return the three 1-based row (or column) indices covered by box index ``box_idx``.

    Box index 1 covers 1-3, index 2 covers 4-6, index 3 covers 7-9.
    """
    first = 3 * box_idx - 2
    return range(first, first + 3)
def box_cells(box_position):
    """List the nine ``(row, col)`` cells of the box at ``(box_row, box_col)``.

    Cells are ordered row by row, left to right within each row.
    """
    box_row, box_col = box_position
    cells = []
    for row in box_cell_axes(box_row):
        cells.extend((row, col) for col in box_cell_axes(box_col))
    return cells
def affected_positions(row, col):
    """Yield every cell that shares a row, column or box with ``(row, col)``.

    The cell itself is never yielded and no cell is yielded twice. The order is:
    the rest of the column, then the rest of the row, then the box cells that
    share neither the row nor the column (those were already covered).
    """
    # Same column, every other row.
    yield from ((other_row, col) for other_row in rows if other_row != row)
    # Same row, every other column.
    yield from ((row, other_col) for other_col in cols if other_col != col)
    # Box cells not already produced by the two sweeps above.
    for peer_row, peer_col in box_cells(box_position(row, col)):
        if peer_row != row and peer_col != col:
            yield (peer_row, peer_col)
class Place(object):
    """Solver action: put ``value`` into the cell at (``row``, ``col``)."""

    def __init__(self, row, col, value):
        self.row, self.col, self.value = row, col, value

    def __repr__(self):
        return 'Place' + repr((self.row, self.col, self.value))

    def perform(self, grid):
        """Apply this placement by calling ``grid.place(row, col, value)``."""
        target = (self.row, self.col, self.value)
        grid.place(*target)
class Remove(object):
    """Solver action: drop candidate ``value`` from the cell at (``row``, ``col``)."""

    def __init__(self, row, col, value):
        self.row, self.col, self.value = row, col, value

    def __repr__(self):
        return 'Remove' + repr((self.row, self.col, self.value))

    def perform(self, grid):
        """Apply this elimination by calling ``grid.remove(row, col, value)``."""
        target = (self.row, self.col, self.value)
        grid.remove(*target)
class SudokuGrid(object):
    """A 9x9 Sudoku board solved by repeatedly applying deduction strategies.

    State:
      * ``placed_grid`` maps ``(row, col)`` to the digit placed there.
      * ``grid`` maps every ``(row, col)`` to its remaining candidate digits;
        a placed cell's candidate set is emptied.

    Each ``find_*`` strategy returns a single ``Place``/``Remove`` action or
    ``None``; ``find`` applies actions one at a time until no strategy makes
    progress.
    """

    def __init__(self):
        self.placed_grid = {}
        self.grid = {(row, col): set(range(1, 10))
                     for row in rows for col in cols}
        # Tried in this order on every step of ``find``; the first strategy
        # that returns an action wins and the scan restarts from the top.
        self.strategies = [
            self.find_only_one_digit_in_cell,
            self.find_only_one_place_in_row,
            self.find_only_one_place_in_col,
            self.find_only_one_place_in_box,
            self.find_digits_in_one_box_row,
            self.find_digits_in_one_box_col,
            self.find_subsets_in_row,
            self.find_subsets_in_col,
            self.find_subsets_in_box,
        ]

    def place(self, row, col, value):
        """Place ``value`` at ``(row, col)`` and eliminate it from all peers.

        Raises:
            ValueError: if a peer already holds ``value``, or ``value`` is the
                only remaining candidate of a peer. The board is printed
                before raising. The board may be left partially updated,
                because the cell is filled before the peers are checked.
        """
        self.placed_grid[(row, col)] = value
        self.grid[(row, col)] = set()
        for aff_pos in affected_positions(row, col):
            if value == self.placed_grid.get(aff_pos):
                print(self.formatted_grid())
                raise ValueError(f'want to put {value} at {row, col} but already at {aff_pos}')
            if {value} == self.grid[aff_pos]:
                print(self.formatted_grid())
                raise ValueError(f'want to put {value} at {row, col} but its the only option for {aff_pos}')
            self.grid[aff_pos].discard(value)

    def remove(self, row, col, value):
        """Drop ``value`` from the candidates of ``(row, col)``.

        Raises KeyError if ``value`` is not currently a candidate there.
        """
        self.grid[(row, col)].remove(value)

    # Unit views. Each returns ``(row, col, candidates)`` triples whose
    # candidate sets are the live sets stored in ``self.grid`` (not copies).

    def row(self, row):
        """Cells of row ``row``, left to right."""
        return [(row, col, self.grid[row, col]) for col in cols]

    def col(self, col):
        """Cells of column ``col``, top to bottom."""
        return [(row, col, self.grid[row, col]) for row in rows]

    def box(self, box):
        """Cells of the box at ``(box_row, box_col)``, row by row."""
        return [(row, col, self.grid[row, col]) for row, col in box_cells(box)]

    def remaining_values(self, cells):
        """Union of the candidate digits across ``cells``."""
        return {value for _, _, values in cells for value in values}

    def locations_for(self, value, cells):
        """Set of ``(row, col)`` among ``cells`` that still allow ``value``."""
        return {(row, col) for row, col, values in cells if value in values}

    def find_only_one_digit_in_cell(self):
        """Naked single: a cell with exactly one candidate left."""
        for (row, col), values in self.grid.items():
            if len(values) == 1:
                (value,) = values
                return Place(row, col, value)
        return None

    def _find_only_one_place_in(self, cells):
        """Hidden single: a digit that is a candidate in exactly one of ``cells``.

        ``cells`` must be one complete unit (row, column or box), so the digit
        has to go in that cell. Digits already placed in the unit have been
        discarded from its candidates and therefore never match.
        """
        occurrences = defaultdict(list)
        for row, col, values in cells:
            for value in values:
                occurrences[value].append((row, col))
        for value, positions in occurrences.items():
            if len(positions) == 1:
                ((row, col),) = positions
                return Place(row, col, value)
        return None

    def find_only_one_place_in_row(self):
        """Hidden single within a row."""
        for row in rows:
            if action := self._find_only_one_place_in(self.row(row)):
                return action
        return None

    def find_only_one_place_in_col(self):
        """Hidden single within a column."""
        for col in cols:
            if action := self._find_only_one_place_in(self.col(col)):
                return action
        return None

    def find_only_one_place_in_box(self):
        """Hidden single within a box.

        This previously looked for cells with a single candidate, which only
        duplicated ``find_only_one_digit_in_cell``.
        """
        for box in boxes:
            if action := self._find_only_one_place_in(self.box(box)):
                return action
        return None

    def find_digits_in_one_box_row(self):
        """Pointing pair/triple along a row.

        If every candidate cell for a digit inside a box lies in one row, the
        box's copy of that digit must be in that row, so the digit can be
        removed from the row's cells outside the box.
        """
        for box in boxes:
            cells = self.box(box)
            for digit in self.remaining_values(cells):
                digit_rows = {row for row, _ in self.locations_for(digit, cells)}
                if len(digit_rows) == 1:
                    (row,) = digit_rows
                    for _, col, values in self.row(row):
                        if box_position(row, col) != box and digit in values:
                            return Remove(row, col, digit)
        return None

    def find_digits_in_one_box_col(self):
        """Pointing pair/triple along a column (column analogue of the row version)."""
        for box in boxes:
            cells = self.box(box)
            for digit in self.remaining_values(cells):
                digit_cols = {col for _, col in self.locations_for(digit, cells)}
                if len(digit_cols) == 1:
                    (col,) = digit_cols
                    for row, _, values in self.col(col):
                        if box_position(row, col) != box and digit in values:
                            return Remove(row, col, digit)
        return None

    def find_subsets_in_cells(self, cells):
        """Find a naked or hidden pair in one unit and return one elimination.

        For every pair of candidate digits ``a``/``b`` in ``cells``:
          * Naked pair: two cells whose candidates are exactly ``{a, b}``.
            Then ``a`` and ``b`` can be removed from every other cell of the
            unit.
          * Hidden pair: ``a`` and ``b`` are each possible in exactly the
            same two cells. Then every other candidate can be removed from
            those two cells.

        Returns the first resulting ``Remove``, or ``None``.
        """
        remaining_values = self.remaining_values(cells)
        if len(remaining_values) < 2:
            return None
        for fst_val, snd_val in combinations(remaining_values, 2):
            pair = {fst_val, snd_val}
            fst_locations = self.locations_for(fst_val, cells)
            snd_locations = self.locations_for(snd_val, cells)
            # Cells that allow both digits.
            common_locations = fst_locations & snd_locations
            if len(common_locations) < 2:
                continue
            for fst_loc, snd_loc in combinations(common_locations, 2):
                pair_locs = {fst_loc, snd_loc}
                # Naked pair: exclude only the two pair cells, so that other
                # cells that also hold both digits (plus more) still lose them.
                if self.grid[fst_loc] == self.grid[snd_loc] == pair:
                    for other_row, other_col in fst_locations - pair_locs:
                        return Remove(other_row, other_col, fst_val)
                    for other_row, other_col in snd_locations - pair_locs:
                        return Remove(other_row, other_col, snd_val)
                # Hidden pair: both digits are confined to exactly these two
                # cells (each has 2 locations and they share at least 2).
                if len(fst_locations) == len(snd_locations) == 2:
                    for loc in (fst_loc, snd_loc):
                        for val in self.grid[loc] - pair:
                            loc_row, loc_col = loc
                            return Remove(loc_row, loc_col, val)
        return None

    def find_subsets_in_row(self):
        """Naked/hidden pair elimination within each row."""
        for row in rows:
            if subsets := self.find_subsets_in_cells(self.row(row)):
                return subsets
        return None

    def find_subsets_in_col(self):
        """Naked/hidden pair elimination within each column."""
        for col in cols:
            if subsets := self.find_subsets_in_cells(self.col(col)):
                return subsets
        return None

    def find_subsets_in_box(self):
        """Naked/hidden pair elimination within each box."""
        for box in boxes:
            if subsets := self.find_subsets_in_cells(self.box(box)):
                return subsets
        return None

    def find(self):
        """Apply strategies until none of them produces an action.

        Each round performs only the first action found, then restarts from
        the first strategy. The result is left in ``placed_grid``/``grid``.
        """
        while True:
            for strategy in self.strategies:
                if result := strategy():
                    result.perform(self)
                    break
            else:
                return

    def load(self, grid):
        """Place the givens from ``grid``, an iterable of rows.

        Row ``i`` and character ``j`` (both 1-based) map to cell ``(i, j)``.
        A space, ``'.'``, ``'0'`` or ``0`` marks an empty cell. Any other entry
        is converted with ``int`` and placed. Rows shorter than nine entries
        leave their remaining cells empty.
        """
        for row_num, row in enumerate(grid, 1):
            for col_num, cell in enumerate(row, 1):
                # Tuple membership uses ==, so int rows keep working.
                if cell in (' ', '.', '0', 0):
                    continue
                self.place(row_num, col_num, int(cell))

    def formatted_grid(self):
        """Render the placed digits as text.

        Empty cells are spaces, with '|' between box columns and a line of 11
        dashes between box rows.
        """
        lines = []
        for row in range(1, 10):
            if row in (4, 7):
                lines.append('-' * 11)
            rowstr = ''
            for col in range(1, 10):
                if col in (4, 7):
                    rowstr += '|'
                rowstr += str(self.placed_grid.get((row, col), ' '))
            lines.append(rowstr)
        return '\n'.join(lines)
# Sample puzzles: one string per row, where ' ' marks an empty cell.
# NOTE(review): most rows below are shorter than nine characters, so runs of
# blanks may have been collapsed when this file was copied. Digits after a lost
# blank shift left into the wrong column. Confirm against the original puzzles.
HARD_TEST_GRID = [
    ' 8 1 ',
    '7 9 5 ',
    ' 2 4 ',
    '9 ',
    '6 1 34 ',
    ' 5 31 ',
    ' 2 ',
    ' 1 6 ',
    '53 64 9'
]

EXPERT_TEST_GRID = [
    '9 7 5',
    ' 1 28 ',
    ' 6 ',
    ' 4 ',
    ' 7 9 ',
    ' 4 536 1',
    ' 8 7 ',
    ' 3 ',
    ' 25 1 9',
]

# Solve each sample as far as the strategies allow and print the board.
for puzzle in (HARD_TEST_GRID, EXPERT_TEST_GRID):
    s = SudokuGrid()
    s.load(puzzle)
    s.find()
    print(s.formatted_grid())