Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 38 additions & 6 deletions Domain/Puzzles/Lits/LitsSolver.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import collections
import collections
from typing import Iterable, Set

from ortools.sat.python import cp_model
Expand All @@ -9,9 +9,12 @@
from Domain.Puzzles.GameSolver import GameSolver
from Domain.Puzzles.Lits.LitsGridBuilder import LitsGridBuilder
from Domain.Puzzles.Lits.LitsType import LitsType
from Utils.ShapeGenerator import ShapeGenerator


class LitsSolver(GameSolver):
empty = 0

def __init__(self, grid: Grid):
self._grid = grid
self._regions = self._grid.get_regions()
Expand All @@ -29,13 +32,42 @@ def get_solution(self) -> Grid:
self._add_constraints()

solver = cp_model.CpSolver()
status = solver.Solve(self._model)

if status not in (cp_model.FEASIBLE, cp_model.OPTIMAL):
return Grid.empty()
while solver.Solve(self._model) in [cp_model.OPTIMAL, cp_model.FEASIBLE]:
current_solution = Grid([[solver.Value(self._grid_vars.value(i, j)) for j in range(self.columns_number)] for i in range(self.rows_number)])

bool_matrix = [[1 if cell != 0 else 0 for cell in row] for row in current_solution.matrix]
bool_grid = Grid(bool_matrix)

components_shapes = bool_grid.get_all_shapes(1)
components = [set(shape) for shape in components_shapes]

if len(components) <= 1:
self.previous_solution = current_solution
return self.previous_solution

self._add_connectivity_constraints(components)

return Grid.empty()

def _add_connectivity_constraints(self, components):
components.sort(key=len, reverse=True)
for component in components[1:]:
literals = []
for position in component:
is_unshaded = self._model.NewBoolVar(f"conn_unshade_{position.r}_{position.c}")
self._model.Add(self._grid_vars[position] == 0).OnlyEnforceIf(is_unshaded)
self._model.Add(self._grid_vars[position] != 0).OnlyEnforceIf(is_unshaded.Not())
literals.append(is_unshaded)

neighbors = [p for p in ShapeGenerator.around_shape(component) if p in self._grid]
for position in neighbors:
is_shaded = self._model.NewBoolVar(f"conn_shade_{position.r}_{position.c}")
self._model.Add(self._grid_vars[position] != 0).OnlyEnforceIf(is_shaded)
self._model.Add(self._grid_vars[position] == 0).OnlyEnforceIf(is_shaded.Not())
literals.append(is_shaded)

self.previous_solution = Grid([[solver.Value(self._grid_vars.value(i, j)) for j in range(self.columns_number)] for i in range(self.rows_number)])
return self.previous_solution
self._model.AddBoolOr(literals)

def get_other_solution(self):
if self.previous_solution is None:
Expand Down
30 changes: 22 additions & 8 deletions Domain/Puzzles/Lits/tests/LitsSolver_test.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import unittest
import unittest

from parameterized import parameterized

from Domain.Board.Grid import Grid
from Domain.Board.RegionsGrid import RegionsGrid

Check warning on line 6 in Domain/Puzzles/Lits/tests/LitsSolver_test.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Unused imports

Unused import statement `from Domain.Board.RegionsGrid import RegionsGrid`
from Domain.Puzzles.Lits.LitsSolver import LitsSolver
from Domain.Puzzles.Lits.LitsType import LitsType

Check warning on line 8 in Domain/Puzzles/Lits/tests/LitsSolver_test.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Unused imports

Unused import statement `from Domain.Puzzles.Lits.LitsType import LitsType`

_ = 0
_ = LitsSolver.empty

class LitsSolverTest(unittest.TestCase):
def test_get_solution_region_too_small(self):
Expand Down Expand Up @@ -38,6 +38,14 @@
solution = lits_solver.get_solution()
self.assertEqual(Grid.empty(), solution)

def check_connectivity(self, solution: Grid):
if solution.is_empty():
return
# Create boolean grid where shaded cells are 1, unshaded are 0
bool_matrix = [[1 if cell != 0 else 0 for cell in row] for row in solution.matrix]
bool_grid = Grid(bool_matrix)
self.assertTrue(bool_grid.are_cells_connected(1), "Shaded cells must be connected")

def test_solution_6x6_normal(self):
input_grid = Grid([
[1, 1, 2, 2, 3, 3],
Expand All @@ -58,6 +66,7 @@
lits_solver = LitsSolver(input_grid)
solution = lits_solver.get_solution()
self.assertEqual(expected_solution, solution)
self.check_connectivity(solution)
other_solution = lits_solver.get_other_solution()
self.assertEqual(Grid.empty(), other_solution)

Expand All @@ -81,6 +90,7 @@
lits_solver = LitsSolver(input_grid)
solution = lits_solver.get_solution()
self.assertEqual(expected_solution, solution)
self.check_connectivity(solution)
other_solution = lits_solver.get_other_solution()
self.assertEqual(Grid.empty(), other_solution)

Expand Down Expand Up @@ -108,6 +118,7 @@
lits_solver = LitsSolver(input_grid)
solution = lits_solver.get_solution()
self.assertEqual(expected_solution, solution)
self.check_connectivity(solution)
other_solution = lits_solver.get_other_solution()
self.assertEqual(Grid.empty(), other_solution)

Expand Down Expand Up @@ -139,6 +150,7 @@
lits_solver = LitsSolver(input_grid)
solution = lits_solver.get_solution()
self.assertEqual(expected_solution, solution)
self.check_connectivity(solution)
other_solution = lits_solver.get_other_solution()
self.assertEqual(Grid.empty(), other_solution)

Expand Down Expand Up @@ -180,6 +192,7 @@
lits_solver = LitsSolver(input_grid)
solution = lits_solver.get_solution()
self.assertEqual(expected_solution, solution)
self.check_connectivity(solution)
other_solution = lits_solver.get_other_solution()
self.assertEqual(Grid.empty(), other_solution)

Expand Down Expand Up @@ -231,10 +244,10 @@
lits_solver = LitsSolver(input_grid)
solution = lits_solver.get_solution()
self.assertEqual(expected_solution, solution)
self.check_connectivity(solution)
other_solution = lits_solver.get_other_solution()
self.assertEqual(Grid.empty(), other_solution)

@unittest.skip("temporarily disabled - fails intermittently") # todo reactive test
def test_solution_20x20_hard_path(self):
input_grid = Grid([
[1, 1, 1, 2, 3, 4, 4, 4, 5, 5, 6, 7, 8, 8, 8, 8, 8, 9, 9, 9],
Expand Down Expand Up @@ -266,10 +279,10 @@
[2, _, _, 3, 3, _, _, 2, 1, 1, _, 4, 4, 1, 1, 1, _, 1, 1, 1],
[2, 1, _, 3, _, _, 4, 2, _, _, _, 2, _, _, _, 4, _, _, _, _],
[_, 1, 1, 1, _, 4, 4, _, 4, _, _, 2, _, 1, _, 4, 4, _, _, 2],
[3, _, _, _, _, 4, _, 4, 4, 1, _, 2, _, 1, _, _, 4, 3, _, 2],
[3, 3, 4, _, 1, 1, 1, 4, _, 1, _, 2, 1, 1, _, _, _, 3, 3, 2],
[3, _, 4, 4, _, _, 1, _, 1, 1, _, _, _, _, _, _, _, 3, _, 2],
[1, 1, _, 4, 3, _, 3, 3, 3, _, 3, 3, 3, 2, 2, 2, 2, 1, 1, 1],
[3, _, 4, _, _, 4, _, 4, 4, 1, _, 2, _, 1, _, _, 4, 3, _, 2],
[3, 3, 4, 4, 1, 1, 1, 4, _, 1, _, 2, 1, 1, _, _, _, 3, 3, 2],
[3, _, _, 4, _, _, 1, _, 1, 1, _, _, _, _, _, _, _, 3, _, 2],
[1, 1, _, _, 3, _, 3, 3, 3, _, 3, 3, 3, 2, 2, 2, 2, 1, 1, 1],
[_, 1, _, _, 3, 3, _, 3, _, _, _, 3, _, _, _, _, _, 1, _, _],
[_, 1, _, _, 3, _, _, 1, 1, _, _, 4, 4, 1, 1, 4, 4, _, _, _],
[_, 4, 4, _, 4, 4, _, 1, _, 1, 4, 4, _, 1, _, _, 4, 4, _, 1],
Expand All @@ -278,11 +291,12 @@
[1, _, _, _, _, 1, 3, 3, 3, _, _, 3, _, 1, _, _, 1, 1, _, 3],
[1, 1, _, _, 3, _, _, 3, _, 3, _, 1, _, 1, 1, 1, _, 1, 3, 3],
[_, 4, 4, 3, 3, _, _, 1, 3, 3, _, 1, _, 3, _, _, _, 1, _, 3],
[4, 4, _, _, 3, 1, 1, 1, _, 3, 1, 1, 3, 3, 3, 2, 2, 2, 2, _]
[4, 4, _, _, 3, 1, 1, 1, _, 3, 1, 1, 3, 3, 3, 2, 2, 2, 2, _],
])
lits_solver = LitsSolver(input_grid)
solution = lits_solver.get_solution()
self.assertEqual(expected_solution, solution)
self.check_connectivity(solution)
other_solution = lits_solver.get_other_solution()
self.assertEqual(Grid.empty(), other_solution)

Expand Down
Loading