Skip to content

Commit ba7283f

Browse files
refactor(prover): from oop to procedural
1 parent 2d73407 commit ba7283f

File tree

2 files changed

+74
-59
lines changed

2 files changed

+74
-59
lines changed

src/propositional_logic_prover/modules/prover.py

Lines changed: 68 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -5,65 +5,80 @@
55
from modules.parser import WFF, Negation
66

77

8-
class Prover:
9-
def __init__(self) -> None:
10-
self.initial_clauses: set[Clause] = set()
11-
self.derivation_map: dict[Clause, tuple[Clause, Clause]] = dict()
12-
self.resolved: set[tuple[Clause, Clause]] = set()
8+
def prove_by_refutation(
9+
kb: WFF, query: WFF
10+
) -> tuple[bool, dict[Clause, tuple[Clause, Clause]]]:
11+
clauses = convert_to_clauses(kb)
12+
query_negated = Negation(None, query)
13+
clauses_query = convert_to_clauses(query_negated)
14+
initial_clauses = clauses.union(clauses_query)
15+
derivation = search_refutation(initial_clauses)
16+
return frozenset() in derivation.keys(), derivation
1317

14-
def add_to_derivation(
15-
self, resolvent: Clause, clause_tuple: tuple[Clause, Clause]
16-
) -> None:
17-
# Check to avoid self derivation loops
18-
if resolvent in self.derivation_map.keys():
19-
return
2018

21-
# Check to avoid derivation of initial clauses
22-
if resolvent in self.initial_clauses:
23-
return
19+
def verify_knowledge_base_consistency(kb: WFF) -> bool:
20+
clauses = convert_to_clauses(kb)
21+
derivation = search_refutation(clauses)
22+
return not frozenset() in derivation.keys()
2423

25-
self.derivation_map[resolvent] = clause_tuple
2624

27-
def resolve(self, clause1: Clause, clause2: Clause) -> set[Clause]:
28-
resolvents: set[Clause] = set()
29-
for literal in clause1:
30-
neg_literal = negate_literal(literal)
31-
if neg_literal in clause2:
32-
clause_diff1 = clause1.difference({literal})
33-
clause_diff2 = clause2.difference({neg_literal})
34-
resolvent = clause_diff1.union(clause_diff2)
35-
resolvents.add(resolvent)
36-
self.add_to_derivation(resolvent, (clause1, clause2))
37-
return resolvents
25+
def search_refutation(
26+
initial_clauses: set[Clause],
27+
) -> dict[Clause, tuple[Clause, Clause]]:
28+
clauses = initial_clauses
29+
resolved: set[tuple[Clause, Clause]] = set()
30+
derivation: dict[Clause, tuple[Clause, Clause]] = dict()
31+
clauses_to_filter = initial_clauses
32+
length = 0
33+
while length != len(clauses):
34+
if frozenset() in clauses:
35+
break
3836

39-
def resolve_all(self, clauses: set[Clause]) -> set[Clause]:
40-
resolvents: set[Clause] = set()
41-
for clause_tuple in combinations(clauses, 2):
42-
if clause_tuple in self.resolved:
43-
continue
44-
self.resolved.add(clause_tuple)
45-
resolvents = resolvents.union(self.resolve(*clause_tuple))
46-
if frozenset() in resolvents:
47-
break
48-
return resolvents
37+
length = len(clauses)
38+
pairs = [
39+
pair for pair in combinations(clauses, 2) if pair not in resolved
40+
]
41+
pairs_resolvents = resolve_clause_pairs(pairs)
42+
clauses = clauses.union(*pairs_resolvents)
43+
derivation.update(
44+
map_resolvent_to_parent(pairs_resolvents, pairs, clauses_to_filter)
45+
)
46+
clauses_to_filter = clauses_to_filter.union(*pairs_resolvents)
47+
return derivation
4948

50-
def search_refutation(self, clauses: set[Clause]) -> bool:
51-
length = 0
52-
while length != len(clauses):
53-
if frozenset() in clauses:
54-
return True
5549

56-
length = len(clauses)
57-
clauses = clauses.union(self.resolve_all(clauses))
58-
return False
50+
def map_resolvent_to_parent(
51+
pairs_resolvents: list[set[Clause]],
52+
pairs: list[tuple[Clause, Clause]],
53+
clauses_to_filter: set[Clause],
54+
) -> dict[Clause, tuple[Clause, Clause]]:
55+
return {
56+
resolvent: pair
57+
for resolvents, pair in zip(pairs_resolvents, pairs)
58+
for resolvent in resolvents
59+
if not resolvent in clauses_to_filter
60+
}
5961

60-
def prove_by_refutation(self, kb: WFF, query: WFF) -> bool:
61-
clauses = convert_to_clauses(kb)
62-
is_kb_unsatisfiable = self.search_refutation(clauses)
63-
if is_kb_unsatisfiable:
64-
raise ValueError("Knowledge base must be consistent")
6562

66-
query_negated = Negation(None, query)
67-
clauses_query = convert_to_clauses(query_negated)
68-
self.initial_clauses = clauses.union(clauses_query)
69-
return self.search_refutation(self.initial_clauses)
63+
def resolve_clause_pairs(
64+
pairs: list[tuple[Clause, Clause]]
65+
) -> list[set[Clause]]:
66+
resolvents_list: list[set[Clause]] = []
67+
for pair in pairs:
68+
resolvents = resolve(*pair)
69+
resolvents_list.append(resolvents)
70+
if frozenset() in resolvents:
71+
break
72+
return resolvents_list
73+
74+
75+
def resolve(clause1: Clause, clause2: Clause) -> set[Clause]:
76+
resolvents: set[Clause] = set()
77+
for literal in clause1:
78+
neg_literal = negate_literal(literal)
79+
if neg_literal in clause2:
80+
clause_diff1 = clause1.difference({literal})
81+
clause_diff2 = clause2.difference({neg_literal})
82+
resolvent = clause_diff1.union(clause_diff2)
83+
resolvents.add(resolvent)
84+
return resolvents

tests/prover_test.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
from modules import parser
2-
from modules.prover import Prover
2+
from modules.prover import prove_by_refutation
33

44

55
def test_query_is_valid() -> None:
66
kb = parser.wff_from_str("A")
77
query = parser.wff_from_str("A")
88

9-
is_valid = Prover().prove_by_refutation(kb, query)
9+
is_valid, _ = prove_by_refutation(kb, query)
1010

1111
assert is_valid
1212

@@ -15,7 +15,7 @@ def test_query_not_valid() -> None:
1515
kb = parser.wff_from_str("A")
1616
query = parser.wff_from_str("B")
1717

18-
is_valid = Prover().prove_by_refutation(kb, query)
18+
is_valid, _ = prove_by_refutation(kb, query)
1919

2020
assert not is_valid
2121

@@ -24,7 +24,7 @@ def test_a_does_not_entails_not_a() -> None:
2424
kb = parser.wff_from_str("A")
2525
query = parser.wff_from_str("~A")
2626

27-
is_valid = Prover().prove_by_refutation(kb, query)
27+
is_valid, _ = prove_by_refutation(kb, query)
2828

2929
assert not is_valid
3030

@@ -33,7 +33,7 @@ def test_a_or_b_does_not_entails_a() -> None:
3333
kb = parser.wff_from_str("(A|B)")
3434
query = parser.wff_from_str("A")
3535

36-
is_valid = Prover().prove_by_refutation(kb, query)
36+
is_valid, _ = prove_by_refutation(kb, query)
3737

3838
assert not is_valid
3939

@@ -42,6 +42,6 @@ def test_study_and_practice_implies_graduate_is_valid() -> None:
4242
kb = parser.wff_from_str("(((STUDY&PRACTICE)=>PASS)&(PASS=>GRADUATE))")
4343
query = parser.wff_from_str("((STUDY&PRACTICE)=>GRADUATE)")
4444

45-
is_valid = Prover().prove_by_refutation(kb, query)
45+
is_valid, _ = prove_by_refutation(kb, query)
4646

4747
assert is_valid

0 commit comments

Comments
 (0)