Skip to content

Commit 1cd6243

Browse files
committed
Add extensible prover implementation
1 parent 803850d commit 1cd6243

File tree

7 files changed

+464
-0
lines changed

7 files changed

+464
-0
lines changed

pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@ kriscv-asm = "kriscv.devtools:kriscv_asm"
1717
[tool.poetry.plugins.kdist]
1818
riscv-semantics = "kriscv.kdist.plugin"
1919

20+
[tool.poetry.plugins.kprovex]
21+
riscv = "kriscv.symtools:KRiscVPlugin"
22+
2023
[tool.poetry.dependencies]
2124
python = "^3.10"
2225
kframework = "7.1.257"

src/kriscv/kprovex/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from ._kprovex import KProveX, create_prover

src/kriscv/kprovex/_default.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
from __future__ import annotations
2+
3+
from typing import TYPE_CHECKING
4+
5+
if TYPE_CHECKING:
6+
from pathlib import Path
7+
from typing import Final
8+
9+
from pyk.kast import KInner
10+
from pyk.proof.reachability import APRProof
11+
12+
from .api import Config, Init, Show
13+
14+
15+
def init_from_claims(config: Config, spec_file: Path, claim_id: str) -> APRProof:
16+
from pyk.ktool.claim_loader import ClaimLoader
17+
from pyk.proof.reachability import APRProof
18+
19+
spec_module, claim_label = claim_id.split('.', 1)
20+
include_dirs = config.dist.source_dirs
21+
22+
claims = ClaimLoader(config.kprove).load_claims(
23+
spec_file=spec_file,
24+
spec_module_name=spec_module,
25+
claim_labels=[claim_label],
26+
include_dirs=include_dirs,
27+
)
28+
(claim,) = claims
29+
30+
proof = APRProof.from_claim(
31+
config.kprove.definition,
32+
claim=claim,
33+
logs={},
34+
proof_dir=config.proof_dir,
35+
)
36+
return proof
37+
38+
39+
def show_pretty_term(config: Config, term: KInner) -> str:
40+
from pyk.konvert import kast_to_kore
41+
from pyk.kore.tools import kore_print
42+
43+
kore = kast_to_kore(config.definition, term)
44+
text = kore_print(kore, definition_dir=config.dist.haskell_dir)
45+
return text
46+
47+
48+
# Check signatures
49+
_default_init: Final[Init] = init_from_claims
50+
_default_show: Final[Show] = show_pretty_term

src/kriscv/kprovex/_kprovex.py

Lines changed: 227 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,227 @@
1+
from __future__ import annotations
2+
3+
from dataclasses import dataclass
4+
from functools import cached_property
5+
from pathlib import Path
6+
from typing import TYPE_CHECKING, final
7+
8+
from pyk.proof.reachability import APRProof, APRProver
9+
10+
from .api import Config
11+
12+
if TYPE_CHECKING:
13+
from pyk.kcfg import KCFG
14+
from pyk.proof import ProofStatus
15+
from pyk.proof.show import APRProofNodePrinter
16+
from pyk.utils import BugReport
17+
18+
from .api import Init, Plugin, Show
19+
20+
21+
def create_prover(plugin_id: str, proof_dir: str | Path, *, bug_report: BugReport | None = None) -> KProveX:
22+
from ._loader import PLUGINS
23+
24+
if plugin_id not in PLUGINS:
25+
raise ValueError(f'Unknown plugin: {plugin_id}')
26+
27+
plugin = PLUGINS[plugin_id]
28+
proof_dir = Path(proof_dir)
29+
30+
return KProveX(
31+
plugin=plugin,
32+
proof_dir=proof_dir,
33+
bug_report=bug_report,
34+
)
35+
36+
37+
@final
38+
@dataclass
39+
class KProveX:
40+
plugin: Plugin
41+
proof_dir: Path
42+
bug_report: BugReport | None
43+
44+
def __init__(
45+
self,
46+
plugin: Plugin,
47+
proof_dir: Path,
48+
*,
49+
bug_report: BugReport | None = None,
50+
):
51+
self.plugin = plugin
52+
self.proof_dir = proof_dir
53+
self.bug_report = bug_report
54+
55+
proof_dir.mkdir(parents=True, exist_ok=True)
56+
57+
@cached_property
58+
def config(self) -> Config:
59+
return Config(
60+
dist=self.plugin.dist(),
61+
proof_dir=self.proof_dir,
62+
bug_report=self.bug_report,
63+
)
64+
65+
def init_proof(
66+
self,
67+
spec_file: str | Path,
68+
claim_id: str,
69+
*,
70+
init_id: str | None = None,
71+
exist_ok: bool = False,
72+
) -> str:
73+
spec_file = Path(spec_file)
74+
init = self._load_init(init_id=init_id)
75+
proof = init(config=self.config, spec_file=spec_file, claim_id=claim_id)
76+
if not exist_ok and APRProof.proof_data_exists(proof.id, self.proof_dir):
77+
raise ValueError(f'Proof with id already exists: {proof.id}')
78+
79+
proof.write_proof_data()
80+
return proof.id
81+
82+
def list_proofs(self) -> list[str]:
83+
raise ValueError('TODO')
84+
85+
def list_nodes(self, proof_id: str) -> list[int]:
86+
proof = self._load_proof(proof_id)
87+
return [node.id for node in proof.kcfg.nodes]
88+
89+
def advance_proof(
90+
self,
91+
proof_id: str,
92+
*,
93+
max_depth: int | None = None,
94+
max_iterations: int | None = None,
95+
) -> ProofStatus:
96+
proof = self._load_proof(proof_id)
97+
98+
with self.config.explore(id=proof_id) as kcfg_explore:
99+
prover = APRProver(
100+
kcfg_explore=kcfg_explore,
101+
execute_depth=max_depth,
102+
)
103+
prover.advance_proof(proof, max_iterations=max_iterations)
104+
105+
return proof.status
106+
107+
def show_proof(
108+
self,
109+
proof_id: str,
110+
*,
111+
show_id: str | None = None,
112+
truncate: bool = False,
113+
) -> str:
114+
from pyk.proof.show import APRProofShow
115+
116+
proof = self._load_proof(proof_id)
117+
node_printer = self._proof_node_printer(proof, show_id=show_id, full_printer=True)
118+
proof_show = APRProofShow(self.config.kprove, node_printer=node_printer)
119+
lines = proof_show.show(proof)
120+
if truncate:
121+
lines = [_truncate(line, 120) for line in lines]
122+
return '\n'.join(lines)
123+
124+
def view_proof(
125+
self,
126+
proof_id: str,
127+
*,
128+
show_id: str | None = None,
129+
) -> None:
130+
from pyk.proof.tui import APRProofViewer
131+
132+
proof = self._load_proof(proof_id)
133+
node_printer = self._proof_node_printer(proof, show_id=show_id, full_printer=False)
134+
viewer = APRProofViewer(proof, self.config.kprove, node_printer=node_printer)
135+
viewer.run()
136+
137+
def prune_node(self, proof_id: str, node_id: str) -> list[int]:
138+
proof = self._load_proof(proof_id)
139+
res = proof.prune(node_id)
140+
proof.write_proof_data()
141+
return res
142+
143+
def show_node(
144+
self,
145+
proof_id: str,
146+
node_id: str,
147+
*,
148+
show_id: str | None = None,
149+
truncate: bool = False,
150+
) -> str:
151+
proof = self._load_proof(proof_id)
152+
node_printer = self._proof_node_printer(proof, show_id=show_id, full_printer=True)
153+
kcfg = proof.kcfg
154+
node = kcfg.node(node_id)
155+
lines = node_printer.print_node(kcfg, node)
156+
if truncate:
157+
lines = [_truncate(line, 120) for line in lines]
158+
return '\n'.join(lines)
159+
160+
# Private helpers
161+
162+
def _load_proof(self, proof_id: str) -> APRProof:
163+
return APRProof.read_proof_data(proof_dir=self.proof_dir, id=proof_id)
164+
165+
def _load_init(self, *, init_id: str | None) -> Init:
166+
if init_id is None:
167+
from . import _default
168+
169+
return _default.init_from_claims
170+
171+
inits = self.plugin.inits()
172+
if init_id not in inits:
173+
raise ValueError(f'Unknown init function: {init_id}')
174+
175+
return inits[init_id]
176+
177+
def _load_show(self, *, show_id: str | None) -> Show:
178+
if show_id is None:
179+
from . import _default
180+
181+
return _default.show_pretty_term
182+
183+
shows = self.plugin.shows()
184+
if show_id not in shows:
185+
raise ValueError(f'Unknown show function: {show_id}')
186+
187+
return shows[show_id]
188+
189+
def _proof_node_printer(
190+
self,
191+
proof: APRProof,
192+
*,
193+
show_id: str | None = None,
194+
full_printer: bool = False,
195+
minimize: bool = False,
196+
) -> APRProofNodePrinter:
197+
from pyk.kast.manip import minimize_term
198+
from pyk.proof.show import APRProofNodePrinter
199+
200+
show = self._load_show(show_id=show_id)
201+
config = self.config
202+
203+
class _NodePrinter(APRProofNodePrinter):
204+
def print_node(self, kcfg: KCFG, node: KCFG.Node) -> list[str]:
205+
attrs = self.node_attrs(kcfg, node)
206+
attr_str = ' (' + ', '.join(attrs) + ')' if attrs else ''
207+
node_strs = [f'{node.id}{attr_str}']
208+
if self.full_printer:
209+
kast = node.cterm.kast
210+
if self.minimize:
211+
kast = minimize_term(kast)
212+
show_res = show(config, kast)
213+
node_strs.extend(' ' + line for line in show_res.split('\n'))
214+
return node_strs
215+
216+
return _NodePrinter(
217+
proof=proof,
218+
kprint=None, # type: ignore [arg-type]
219+
full_printer=full_printer,
220+
minimize=minimize,
221+
)
222+
223+
224+
def _truncate(line: str, n: int) -> str:
225+
if len(line) <= n:
226+
return line
227+
return line[: n - 3] + '...'

src/kriscv/kprovex/_loader.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
from __future__ import annotations
2+
3+
import importlib
4+
import logging
5+
import re
6+
from typing import TYPE_CHECKING
7+
8+
from pyk.utils import FrozenDict
9+
10+
from .api import Plugin
11+
12+
if TYPE_CHECKING:
13+
from importlib.metadata import EntryPoint
14+
from typing import Final
15+
16+
17+
_LOGGER: Final = logging.getLogger(__name__)
18+
19+
20+
def _load_plugins() -> FrozenDict[str, Plugin]:
21+
entry_points = importlib.metadata.entry_points(group='kprovex')
22+
plugins: FrozenDict[str, Plugin] = FrozenDict(
23+
(entry_point.name, plugin) for entry_point in entry_points if (plugin := _load_plugin(entry_point)) is not None
24+
)
25+
return plugins
26+
27+
28+
def _load_plugin(entry_point: EntryPoint) -> Plugin | None:
29+
if not _valid_id(entry_point.name):
30+
_LOGGER.warning(f'Invalid entry point name, skipping: {entry_point.name}')
31+
return None
32+
33+
_LOGGER.info(f'Loading entry point: {entry_point.name}')
34+
try:
35+
module_name, class_name = entry_point.value.split(':')
36+
except ValueError:
37+
_LOGGER.error(f'Invalid entry point value: {entry_point.value}', exc_info=True)
38+
return None
39+
40+
try:
41+
_LOGGER.info(f'Importing module: {module_name}')
42+
module = importlib.import_module(module_name)
43+
except Exception:
44+
_LOGGER.error(f'Module {module_name} cannot be imported', exc_info=True)
45+
return None
46+
47+
try:
48+
_LOGGER.info(f'Loading plugin: {class_name}')
49+
cls = getattr(module, class_name)
50+
except AttributeError:
51+
_LOGGER.error(f'Class {class_name} not found in module {module_name}', exc_info=True)
52+
return None
53+
54+
if not issubclass(cls, Plugin):
55+
_LOGGER.error(f'Class {class_name} is not a Plugin', exc_info=True)
56+
return None
57+
58+
try:
59+
_LOGGER.info(f'Instantiating plugin: {class_name}')
60+
plugin = cls()
61+
except TypeError:
62+
_LOGGER.error(f'Cannot instantiate plugin {class_name}', exc_info=True)
63+
return None
64+
65+
return plugin
66+
67+
68+
_ID_PATTERN = re.compile('[a-z0-9]+(-[a-z0-9]+)*')
69+
70+
71+
def _valid_id(s: str) -> bool:
72+
return _ID_PATTERN.fullmatch(s) is not None
73+
74+
75+
PLUGINS: Final = _load_plugins()

0 commit comments

Comments
 (0)