diff --git a/cpmpy/solvers/pysat.py b/cpmpy/solvers/pysat.py index def9e6804..6781abdd7 100644 --- a/cpmpy/solvers/pysat.py +++ b/cpmpy/solvers/pysat.py @@ -427,10 +427,8 @@ def _post_constraint(self, cpm_expr): self.pysat_solver.append_formula(cnf) elif isinstance(a1, Comparison) and a1.args[0].name == "wsum": # implied pseudo-boolean comparison (a0->wsum(ws,bvs)<>val) # implied sum comparison (a0->wsum([w,bvs])<>val or a0->(w*bv<>val)) - cnf = self._pysat_pseudoboolean(a1) - # implication of conjunction is conjunction of individual implications - antecedent = [self.solver_var(~a0)] - self.pysat_solver.append_formula([antecedent+c for c in cnf]) + cnf = self._pysat_pseudoboolean(a1, conditional=a0) + self.pysat_solver.append_formula(cnf) else: raise NotSupportedError(f"Implication: {cpm_expr} not supported by CPM_pysat") @@ -548,7 +546,7 @@ def _pysat_cardinality(self, cpm_expr, reified=False): else: raise ValueError(f"PySAT: Expected Comparison to be either <=, ==, or >=, but was {cpm_expr.name}") - def _pysat_pseudoboolean(self, cpm_expr): + def _pysat_pseudoboolean(self, cpm_expr, conditional=None): """Convert CPMpy comparison of `wsum` (over Boolean variables) into PySAT list of clauses.""" if self._pb is None: raise ImportError("The model contains a PB constraint, for which PySAT needs an additional dependency (PBLib). To install it, run `pip install pypblib`.") @@ -563,6 +561,8 @@ def _pysat_pseudoboolean(self, cpm_expr): lits = self.solver_vars(lhs.args[1]) pysat_args = {"weights": lhs.args[0], "lits": lits, "bound": rhs, "vpool":self.pysat_vpool } + if conditional is not None: + pysat_args["conditionals"] = [self.solver_var(conditional)] if cpm_expr.name == "<=": return self._pb.PBEnc.atmost(**pysat_args).clauses diff --git a/tests/test_pysat_wsum.py b/tests/test_pysat_wsum.py index 1de62dde3..f0f634684 100644 --- a/tests/test_pysat_wsum.py +++ b/tests/test_pysat_wsum.py @@ -93,6 +93,41 @@ def test_encode_pb_oob(self): ## check all types of linear constraints are handled for expression in expressions: Model(expression).solve("pysat") + + def test_encode_pb_reified(self): + # Instantiate the solver with an empty model to get access to vpool and solver_var + solver = CPM_pysat(cp.Model()) + + sel_var = cp.boolvar(name="s") + # Define the pseudo-Boolean constraint (the consequent) + cons1 = sum(self.bv*[2,1,1]) >= 2 + + + # Convert the conditional constraint using _pysat_pseudoboolean + pb_clauses = solver._pysat_pseudoboolean(cons1, conditional=sel_var) + + # Check the result + self.assertEqual(str(pb_clauses), "[[5], [-5, 6], [3, 6], [-5, 3, 7], [2, 6], [-5, 2, 7], [3, 2, 7], [-5, 3, 2, 8], [1, 9], [-7, 9], [1, -7, 10], [-10, -4]]") + + # trivially unsat constraint + cons2 = sum(self.bv*[2,1,1]) >= 10 + + pb_clauses = solver._pysat_pseudoboolean(cons2, conditional=sel_var) + + # this testcase depends on the PR to PBLIB: https://github.com/rjungbeck/pypblib/pull/6, otherwise will be [[3, -4], [1, -4], [2, -4], [-4]] + self.assertEqual(str(pb_clauses), "[[-4]]") + + # trivially sat constraint + cons3 = sum(self.bv*[2,1,1]) >= 0 + + pb_clauses = solver._pysat_pseudoboolean(cons3, conditional=sel_var) + + # no clauses expected + self.assertEqual(str(pb_clauses), "[]") + + + + if __name__ == '__main__': unittest.main()