diff --git a/tools/cexenum/callbacks.md b/tools/cexenum/callbacks.md new file mode 100644 index 00000000..ca3a3ec7 --- /dev/null +++ b/tools/cexenum/callbacks.md @@ -0,0 +1,46 @@ +# Callback JSON Protocol + +The `--callback` option allows controlling the counter-example enumeration using a JSON based protocol. + +With `--callback '-'`, callback events are output to stdout and callback responses are read from stdin. + +With `--callback ''` the given command line is launched as subprocess which recieves callback events on stdin and produces callback responses on stdout. + +Each event and response consists of a single line containing a JSON object. + +## Callback Events + +A callback event is emitted whenever: + +* a step is entered for the first time (`"event": "step"`), +* a new counter-example trace was found (`"event": "trace"`) or +* no counter-examples remain for the current step (`"event": "unsat"`). + +Every callback event includes the current step (`"step": `) and a list of enabled named blocked patterns (`"enabled": [...]`, see below). + +A trace event also includes a path to the `.yw` file of the counter-example trace. The `.aiw` file can be found by replacing the file extension. + +After an event is emitted, the enumeration script waits for one or multiple callback responses, continuing when a callback response includes the `"action"` key. + +## Callback Responses + +### Actions + +An action hands control back to the enumeration script. The following three actions are available: + +* `{"action": "search"}`: Search for a further counter-example in the current time step. +* `{"action": "advance"}`: Advance to the next time step, abandoning any not-yet-enumerated counter-examples of the current time step. +* `{"action": "next"}`: Search for the next counter-example, automatically advancing to the next time-step when necessary. + +Note that an `{"action": "search"}` response to an `"unsat"` event remains at the same time step, which can be used to disable a named blocked pattern, which can make further counter-examples available. + +In the interactive callback mode (`--callback '-'`) it is possible to use single letter shortcuts `s`, `a` and `n` instead of the JSON syntax `{"action": "search"}`, `{"action": "advance"}` and `{"action": "next"}` respectively. + +### Blocking Patterns + +A specific input pattern given as an `.yw` or `.aiw` file can be blocked. +During enumeration, only counter-examples that differ from in at least one non-`x` bit from every blocked pattern are considered. + +A `.yw` pattern can be blocked using `{"block_yw": ""}` and a `.aiw` pattern using `{"block_aiw": ""}`. Additionally a pattern can be given a name by using `{"block_...": "", "name": ""}`, in which case it will be possible to disable and reenable the blocked pattern. By default a newly added named blocked pattern is enabled. + +A named pattern can be disable with the `{"disable": ""}` response and re-enabled with the `{"enable": ""}` response. diff --git a/tools/cexenum/cexenum.py b/tools/cexenum/cexenum.py index 3ac88916..55f11b19 100755 --- a/tools/cexenum/cexenum.py +++ b/tools/cexenum/cexenum.py @@ -3,16 +3,25 @@ import asyncio import json +import threading import traceback import argparse import shutil import shlex import os +import sys +import urllib.parse from pathlib import Path -from typing import Any, Awaitable, Literal +from typing import Any, Awaitable, Iterable, Literal + +try: + import readline # type: ignore # noqa +except ImportError: + pass import yosys_mau.task_loop.job_server as job from yosys_mau import task_loop as tl +from yosys_mau.stable_set import StableSet libexec = Path(__file__).parent.resolve() / "libexec" @@ -21,6 +30,10 @@ os.environb[b"PATH"] = bytes(libexec) + b":" + os.environb[b"PATH"] +def safe_smtlib_id(name: str) -> str: + return f"|{urllib.parse.quote_plus(name).replace('+', ' ')}|" + + def arg_parser(): parser = argparse.ArgumentParser( prog="cexenum", usage="%(prog)s [options] " @@ -66,6 +79,15 @@ def arg_parser(): default="-s yices --unroll", ) + parser.add_argument( + "--callback", + metavar='"..."', + type=shlex.split, + help="command that will receive enumerated traces on stdin and can control " + "the enumeration via stdout (pass '-' to handle callbacks interactively)", + default="", + ) + parser.add_argument("--debug", action="store_true", help="enable debug logging") parser.add_argument( "--debug-events", action="store_true", help="enable debug event logging" @@ -83,7 +105,7 @@ def arg_parser(): return parser -def lines(*args): +def lines(*args: Any): return "".join(f"{line}\n" for line in args) @@ -98,6 +120,8 @@ class App: enum_depth: int sim: bool + callback: list[str] + smtbmc_options: list[str] work_dir: Path @@ -148,7 +172,7 @@ def error_handler(err: BaseException): tl.current_task().set_error_handler(None, error_handler) -async def batch(*args): +async def batch(*args: Awaitable[Any]): result = None for arg in args: result = await arg @@ -200,7 +224,7 @@ def __init__(self): "hierarchy -simcheck", "flatten", "setundef -undriven -anyseq", - "setattr -set keep 1 w:\*", + "setattr -set keep 1 w:\\*", "delete -output", "opt -full", "techmap", @@ -266,6 +290,7 @@ def __init__(self, trace_name: str, aig_model: tl.Task): App.cache_dir / "design_aiger.ywa", min_yw, cwd=App.trace_dir_min, + options=("--skip-x", "--present-only"), ) aiw2yw[tl.LogContext].scope = f"aiw2yw[{stem}]" aiw2yw.depends_on(aigcexmin) @@ -286,6 +311,18 @@ def relative_to(target: Path, cwd: Path) -> Path: prefix = Path("") target = target.resolve() cwd = cwd.resolve() + + ok = False + for limit in (Path.cwd(), App.work_dir): + limit = Path.cwd().resolve() + try: + target.relative_to(limit) + ok = True + except ValueError: + pass + if not ok: + return target + while True: try: return prefix / (target.relative_to(cwd)) @@ -299,16 +336,18 @@ def relative_to(target: Path, cwd: Path) -> Path: class YosysWitness(tl.process.Process): def __init__( self, - mode: Literal["yw2aiw"] | Literal["aiw2yw"], + mode: Literal["yw2aiw", "aiw2yw"], input: Path, mapfile: Path, output: Path, cwd: Path, + options: Iterable[str] = (), ): super().__init__( [ "yosys-witness", mode, + *(options or []), str(relative_to(input, cwd)), str(relative_to(mapfile, cwd)), str(relative_to(output, cwd)), @@ -381,87 +420,467 @@ def __init__(self, design_il: Path, input_yw: Path, output_fst: Path, cwd: Path) self.log_output() +class Callback(tl.TaskGroup): + recover_from_errors = False + + def __init__(self, enumeration: Enumeration): + super().__init__() + self[tl.LogContext].scope = "callback" + self.search_next = False + self.enumeration = enumeration + self.tmp_counter = 0 + + async def step_callback(self, step: int) -> Literal["advance", "search"]: + with self.as_current_task(): + return await self._callback(step=step) + + async def unsat_callback(self, step: int) -> Literal["advance", "search"]: + with self.as_current_task(): + return await self._callback(step=step, unsat=True) + + async def trace_callback( + self, step: int, path: Path + ) -> Literal["advance", "search"]: + with self.as_current_task(): + return await self._callback(step=step, trace_path=path) + + async def _callback( + self, + step: int, + trace_path: Path | None = None, + unsat: bool = False, + ) -> Literal["advance", "search"]: + if not self.is_active(): + if unsat: + return "advance" + return "search" + if trace_path is None and self.search_next: + if unsat: + return "advance" + return "search" + self.search_next = False + + info = dict(step=step, enabled=sorted(self.enumeration.active_assumptions)) + + if trace_path: + self.callback_write( + {**info, "event": "trace", "trace_path": str(trace_path)} + ) + elif unsat: + self.callback_write({**info, "event": "unsat"}) + else: + self.callback_write({**info, "event": "step"}) + + while True: + try: + response: dict[Any, Any] | Any = await self.callback_read() + if not isinstance(response, (Exception, dict)): + raise ValueError( + f"expected JSON object, got: {json.dumps(response)}" + ) + if isinstance(response, Exception): + raise response + + did_something = False + + if "block_yw" in response: + did_something = True + block_yw: Any = response["block_yw"] + + if not isinstance(block_yw, str): + raise ValueError( + "'block_yw' must be a string containing a file path, " + f"got: {json.dumps(block_yw)}" + ) + name: Any = response.get("name") + if name is not None and not isinstance(name, str): + raise ValueError( + "optional 'name' must be a string when present, " + "got: {json.dumps(name)}" + ) + self.enumeration.block_trace(Path(block_yw), name=name) + + if "block_aiw" in response: + did_something = True + block_aiw: Any = response["block_aiw"] + if not isinstance(block_aiw, str): + raise ValueError( + "'block_yw' must be a string containing a file path, " + f"got: {json.dumps(block_aiw)}" + ) + name: Any = response.get("name") + if name is not None and not isinstance(name, str): + raise ValueError( + "optional 'name' must be a string when present, " + "got: {json.dumps(name)}" + ) + + tmpdir = App.work_subdir / "tmp" + tmpdir.mkdir(exist_ok=True) + self.tmp_counter += 1 + block_yw = tmpdir / f"callback_{self.tmp_counter}.yw" + + aiw2yw = YosysWitness( + "aiw2yw", + Path(block_aiw), + App.cache_dir / "design_aiger.ywa", + Path(block_yw), + cwd=tmpdir, + options=("--skip-x", "--present-only"), + ) + aiw2yw[tl.LogContext].scope = f"aiw2yw[callback_{self.tmp_counter}]" + + await aiw2yw.finished + self.enumeration.block_trace(Path(block_yw), name=name) + + if "disable" in response: + did_something = True + name = response["disable"] + if not isinstance(name, str): + raise ValueError( + "'disable' must be a string representing an assumption, " + f"got: {json.dumps(name)}" + ) + self.enumeration.disable_assumption(name) + if "enable" in response: + did_something = True + name = response["enable"] + if not isinstance(name, str): + raise ValueError( + "'disable' must be a string representing an assumption, " + f"got: {json.dumps(name)}" + ) + self.enumeration.enable_assumption(name) + action: Any = response.get("action") + if action == "next": + did_something = True + self.search_next = True + action = "search" + if action in ("search", "advance"): + return action + if not did_something: + raise ValueError( + f"could not interpret callback response: {response}" + ) + except Exception as e: + tl.log_exception(e, raise_error=False) + if not self.recover_from_errors: + raise + + def is_active(self) -> bool: + return False + + def callback_write(self, data: Any): + return + + async def callback_read(self) -> Any: + raise NotImplementedError("must be implemented in Callback subclass") + + +class InteractiveCallback(Callback): + recover_from_errors = True + + interactive_shortcuts = { + "n": '{"action": "next"}', + "s": '{"action": "search"}', + "a": '{"action": "advance"}', + } + + def __init__(self, enumeration: Enumeration): + super().__init__(enumeration) + self.__eof_reached = False + + def is_active(self) -> bool: + return not self.__eof_reached + + def callback_write(self, data: Any): + print(f"callback: {json.dumps(data)}") + + async def callback_read(self) -> Any: + future: asyncio.Future[Any] = asyncio.Future() + + loop = asyncio.get_event_loop() + + def blocking_read(): + try: + try: + result = "" + while not result: + result = input( + "callback> " if sys.stdout.isatty() else "" + ).strip() + result = self.interactive_shortcuts.get(result, result) + result_data = json.loads(result) + except EOFError: + print() + self.__eof_reached = True + result_data = dict(action="next") + loop.call_soon_threadsafe(lambda: future.set_result(result_data)) + except Exception as exc: + exception = exc + loop.call_soon_threadsafe(lambda: future.set_exception(exception)) + + thread = threading.Thread(target=blocking_read, daemon=True) + thread.start() + + return await future + + +class ProcessCallback(Callback): + def __init__(self, enumeration: Enumeration, command: list[str]): + super().__init__(enumeration) + self[tl.LogContext].scope = "callback" + self.__eof_reached = False + self._command = command + self._lines: list[asyncio.Future[str]] = [asyncio.Future()] + + async def on_prepare(self) -> None: + self.process = tl.Process(self._command, cwd=Path.cwd(), interact=True) + self.process.use_lease = False + + future_line: asyncio.Future[str] = self._lines[-1] + + def stdout_handler(event: tl.process.StdoutEvent): + nonlocal future_line + future_line.set_result(event.output) + future_line = asyncio.Future() + self._lines.append(future_line) + + self.process.sync_handle_events(tl.process.StdoutEvent, stdout_handler) + + def stderr_handler(event: tl.process.StderrEvent): + tl.log(event.output) + + self.process.sync_handle_events(tl.process.StderrEvent, stderr_handler) + + def exit_handler(event: tl.process.ExitEvent): + self._lines[-1].set_exception(EOFError("callback process exited")) + + self.process.sync_handle_events(tl.process.ExitEvent, exit_handler) + + def is_active(self) -> bool: + return not self.__eof_reached + + def callback_write(self, data: Any): + if not self.process.is_finished: + self.process.write(json.dumps(data) + "\n") + + async def callback_read(self) -> Any: + future_line = self._lines.pop(0) + try: + data = json.loads(await future_line) + tl.log_debug(f"callback action: {data}") + return data + except EOFError: + self._lines.insert(0, future_line) + self.__eof_reached = True + return dict(action="next") + + class Enumeration(tl.Task): + callback_mode: Literal["off", "interactive", "process"] + callback_auto_search: bool = False + def __init__(self, aig_model: tl.Task): self.aig_model = aig_model + self._pending_blocks: list[tuple[str | None, Path]] = [] + self.named_assumptions: StableSet[str] = StableSet() + self.active_assumptions: StableSet[str] = StableSet() + super().__init__() + async def on_prepare(self) -> None: + if App.callback: + if App.callback == ["-"]: + self.callback_task = InteractiveCallback(self) + else: + self.callback_task = ProcessCallback(self, App.callback) + else: + self.callback_task = Callback(self) + async def on_run(self) -> None: - smtbmc = Smtbmc(App.work_dir / "model" / "design_smt2.smt2") + self.smtbmc = smtbmc = Smtbmc(App.work_dir / "model" / "design_smt2.smt2") + self._push_level = 0 await smtbmc.ping() - pred = None - - i = 0 + i = -1 limit = App.depth first_failure = None - while i <= limit: - tl.log(f"Checking assumptions in step {i}..") - presat_checked = await batch( - smtbmc.bmc_step(i, initial=i == 0, assertions=None, pred=pred), - smtbmc.check(), - ) - if presat_checked != "sat": - if first_failure is None: - tl.log_error("Assumptions are not satisfiable") - else: - tl.log("No further counter-examples are reachable") - return - - tl.log(f"Checking assertions in step {i}..") - checked = await batch( - smtbmc.push(), - smtbmc.assertions(i, False), - smtbmc.check(), - ) - pred = i - if checked != "unsat": + checked = "skip" + + counter = 0 + + while i <= limit or limit < 0: + if checked != "skip": + checked = await self._search_counter_example(i) + + if checked == "unsat": + if i >= 0: + action = await self.callback_task.unsat_callback(i) + if action == "search": + continue + checked = "skip" + if checked == "skip": + checked = "unsat" + i += 1 + if i > limit and limit >= 0: + break + action = await self.callback_task.step_callback(i) + + pending = batch( + self._top(), + smtbmc.bmc_step( + i, initial=i == 0, assertions=None, pred=i - 1 if i else None + ), + ) + if action == "advance": + tl.log(f"Skipping step {i}") + await batch(pending, smtbmc.assertions(i)) + checked = "skip" + continue + assert action == "search" + + tl.log(f"Checking assumptions in step {i}..") + presat_checked = await batch( + pending, + smtbmc.check(), + ) + if presat_checked != "sat": + if first_failure is None: + tl.log_error("Assumptions are not satisfiable") + else: + tl.log("No further counter-examples are reachable") + smtbmc.close_stdin() + return + + tl.log(f"Checking assertions in step {i}..") + counter = 0 + continue + elif checked == "sat": if first_failure is None: first_failure = i - limit = i + App.enum_depth + if App.enum_depth < 0: + limit = -1 + else: + limit = i + App.enum_depth tl.log("BMC failed! Enumerating counter-examples..") - counter = 0 - assert checked == "sat" path = App.trace_dir_full / f"trace{i}_{counter}.yw" + await smtbmc.incremental_command(cmd="write_yw_trace", path=str(path)) + tl.log(f"Written counter-example to {path.name}") - while checked == "sat": - await smtbmc.incremental_command( - cmd="write_yw_trace", path=str(path) - ) - tl.log(f"Written counter-example to {path.name}") - - minimize = MinimizeTrace(path.name, self.aig_model) - minimize.depends_on(self.aig_model) - - await minimize.aiw2yw.finished - - min_path = App.trace_dir_min / f"trace{i}_{counter}.yw" - - checked = await batch( - smtbmc.incremental_command( - cmd="read_yw_trace", - name="last", - path=str(min_path), - skip_x=True, - ), - smtbmc.assert_( - ["not", ["and", *(["yw", "last", k] for k in range(i + 1))]] - ), - smtbmc.check(), - ) + minimize = MinimizeTrace(path.name, self.aig_model) + minimize.depends_on(self.aig_model) + + await minimize.aiw2yw.finished - counter += 1 - path = App.trace_dir_full / f"trace{i}_{counter}.yw" + min_path = App.trace_dir_min / f"trace{i}_{counter}.yw" - await batch(smtbmc.pop(), smtbmc.assertions(i)) + action = await self.callback_task.trace_callback(i, min_path) + if action == "advance": + tl.log("Skipping remaining counter-examples for this step") + checked = "skip" + continue + assert action == "search" - i += 1 + self.block_trace(min_path) + counter += 1 + else: + tl.log_error(f"Unexpected solver result: {checked!r}") smtbmc.close_stdin() + def block_trace(self, path: Path, name: str | None = None): + if name is not None: + if name in self.named_assumptions: + raise ValueError(f"an assumption with name {name} was already defined") + self.named_assumptions.add(name) + self.active_assumptions.add(name) + + self._pending_blocks.append((name, path.absolute())) + + def enable_assumption(self, name: str): + if name not in self.named_assumptions: + raise ValueError(f"unknown assumption {name!r}") + self.active_assumptions.add(name) + + def disable_assumption(self, name: str): + if name not in self.named_assumptions: + raise ValueError(f"unknown assumption {name!r}") + self.active_assumptions.discard(name) + + def _top(self) -> Awaitable[Any]: + return batch(*(self._pop() for _ in range(self._push_level))) + + def _pop(self) -> Awaitable[Any]: + self._push_level -= 1 + tl.log_debug(f"pop to {self._push_level}") + return self.smtbmc.pop() + + def _push(self) -> Awaitable[Any]: + self._push_level += 1 + tl.log_debug(f"push to {self._push_level}") + return self.smtbmc.push() + + def _search_counter_example(self, step: int) -> Awaitable[Any]: + smtbmc = self.smtbmc + pending_blocks, self._pending_blocks = self._pending_blocks, [] + + pending = self._top() + + for name, block_path in pending_blocks: + result = smtbmc.incremental_command( + cmd="read_yw_trace", + name="last", + path=str(block_path), + skip_x=True, + ) + + async def check_yw_trace_len(): + last_step = (await result).get("last_step", step) + if last_step > step: + tl.log_warning( + f"Ignoring future time steps " + f"{step + 1} to {last_step} of " + f"{relative_to(block_path, Path.cwd())}" + ) + return last_step + + expr = [ + "not", + ["and", *(["yw", "last", k] for k in range(step + 1))], + ] + + if name is not None: + name_id = safe_smtlib_id(f"cexenum trace {name}") + pending = batch( + pending, smtbmc.smtlib(f"(declare-const {name_id} Bool)") + ) + expr = ["=", expr, ["smtlib", name_id, "Bool"]] + + pending = batch( + pending, + check_yw_trace_len(), + smtbmc.assert_(expr), + ) + + pending = batch( + pending, + self._push(), + smtbmc.assertions(step, False), + ) + + for name in self.active_assumptions: + name_id = safe_smtlib_id(f"cexenum trace {name}") + pending = batch(pending, smtbmc.assert_(["smtlib", name_id, "Bool"])) + + return batch( + pending, + smtbmc.check(), + ) + class Smtbmc(tl.process.Process): def __init__(self, smt2_model: Path): @@ -470,6 +889,7 @@ def __init__(self, smt2_model: Path): [ "yosys-smtbmc", "--incremental", + "--noprogress", *App.smtbmc_options, str(smt2_model), ], @@ -477,11 +897,15 @@ def __init__(self, smt2_model: Path): ) self.name = "smtbmc" - self.expected_results = [] + self.expected_results: list[asyncio.Future[Any]] = [] async def on_run(self) -> None: - def output_handler(event: tl.process.StderrEvent): - result = json.loads(event.output) + def output_handler(event: tl.process.StdoutEvent): + line = event.output.strip() + if line.startswith("{"): + result = json.loads(event.output) + else: + result = dict(msg=line) tl.log_debug(f"smtbmc > {result!r}") if "err" in result: exception = tl.logging.LoggedError( @@ -501,11 +925,11 @@ def output_handler(event: tl.process.StderrEvent): def ping(self) -> Awaitable[None]: return self.incremental_command(cmd="ping") - def incremental_command(self, **command: dict[Any]) -> Awaitable[Any]: + def incremental_command(self, **command: Any) -> Awaitable[Any]: tl.log_debug(f"smtbmc < {command!r}") self.write(json.dumps(command)) self.write("\n") - result = asyncio.Future() + result: asyncio.Future[Any] = asyncio.Future() self.expected_results.append(result) return result @@ -522,6 +946,9 @@ def pop(self) -> Awaitable[None]: def check(self) -> Awaitable[str]: return self.incremental_command(cmd="check") + def smtlib(self, command: str) -> Awaitable[str]: + return self.incremental_command(cmd="smtlib", command=command) + def assert_antecedent(self, expr: Any) -> Awaitable[None]: return self.incremental_command(cmd="assert_antecedent", expr=expr) @@ -565,7 +992,7 @@ def bmc_step( assertions: bool | None = True, pred: int | None = None, ) -> Awaitable[None]: - futures = [] + futures: list[Awaitable[None]] = [] futures.append(self.new_step(step)) futures.append(self.hierarchy(step)) futures.append(self.assumptions(step))