diff --git a/README.md b/README.md index a7df06c..f3f02a3 100644 --- a/README.md +++ b/README.md @@ -174,14 +174,14 @@ infermesh generate \ # Generate — from a JSONL file, results to another JSONL file # Each input line: {"prompt": "..."} or {"messages": [...]} or {"responses_input": "..."} -# Output includes an _index field so interrupted runs can be resumed. +# Output includes an _index field; a checkpoint file results.checkpoint.sqlite is kept. infermesh generate \ --model openai/gpt-4.1-mini \ --api-base https://api.openai.com/v1 \ --input-jsonl prompts.jsonl \ --output-jsonl results.jsonl -# Resume an interrupted run — skips already-completed rows and appends new ones +# Resume an interrupted run — reads results.checkpoint.sqlite, skips settled rows, appends the rest infermesh generate \ --model openai/gpt-4.1-mini \ --api-base https://api.openai.com/v1 \ @@ -189,6 +189,14 @@ infermesh generate \ --output-jsonl results.jsonl \ --resume +# Custom mapper — transform raw source records before sending to the model +# The mapper receives each record as a dict; must return {"input": ..., "metadata": ...} +infermesh generate \ + --model openai/gpt-4.1-mini \ + --input-jsonl dataset.jsonl \ + --output-jsonl results.jsonl \ + --mapper mypackage.prompts:build_prompt + # Create embeddings infermesh embed \ --model text-embedding-3-small \ diff --git a/docs/guide.md b/docs/guide.md index 21d6013..bc0ac3a 100644 --- a/docs/guide.md +++ b/docs/guide.md @@ -92,7 +92,7 @@ with open("results.jsonl", "w") as out, \ The callback receives: | Argument | Type | Notes | -|---|---|---| +| --- | --- | --- | | `index` | `int` | Position in `input_batch` (global item index, not micro-batch index) | | `result` | `GenerationResult \| EmbeddingResult \| TranscriptionResult \| None` | `None` on failure | | `error` | `BaseException \| None` | `None` on success | @@ -192,9 +192,20 @@ Input rows for `infermesh generate` may contain any of the following fields: ### Resuming an Interrupted Run -If a long batch is interrupted (Ctrl-C, OOM, network loss), the output file -contains all rows completed so far. Re-run with `--resume` to skip them and -append only the remaining rows: +Every file-backed run writes a checkpoint file alongside the output: + +``` +results.jsonl ← your output (human-readable) +results.checkpoint.sqlite ← checkpoint file (resume state) +``` + +By default the checkpoint stays beside the output for portability and +discoverability. If you want the checkpoint on local scratch instead, pass +`--checkpoint-dir DIR` or set `INFERMESH_CHECKPOINT_DIR=DIR` before the run. +When you resume later, reuse the same checkpoint-dir setting. + +If a long batch is interrupted (Ctrl-C, OOM, network loss), re-run with +`--resume` to skip settled items and append only the remaining rows: ```bash # First attempt — interrupted partway through @@ -204,7 +215,7 @@ infermesh generate \ --input-jsonl prompts.jsonl \ --output-jsonl results.jsonl -# Resume — reads results.jsonl, skips completed _index values, appends the rest +# Resume — reads results.checkpoint.sqlite, skips settled items, appends the rest infermesh generate \ --model openai/gpt-4.1-mini \ --api-base https://api.openai.com/v1 \ @@ -213,9 +224,64 @@ infermesh generate \ --resume ``` -Results are written to disk one row at a time as each request completes, so -a crash only loses the requests that were in-flight at that moment. -`--resume` requires `--output-jsonl`. 
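+For example, to keep the checkpoint on local scratch (the directory below is
+illustrative), pass the same `--checkpoint-dir` on the first run and again
+when resuming:
+
+```bash
+infermesh generate \
+  --model openai/gpt-4.1-mini \
+  --api-base https://api.openai.com/v1 \
+  --input-jsonl prompts.jsonl \
+  --output-jsonl results.jsonl \
+  --checkpoint-dir /scratch/checkpoints
+
+# Later, after an interruption: same checkpoint-dir, plus --resume
+infermesh generate \
+  --model openai/gpt-4.1-mini \
+  --api-base https://api.openai.com/v1 \
+  --input-jsonl prompts.jsonl \
+  --output-jsonl results.jsonl \
+  --checkpoint-dir /scratch/checkpoints \
+  --resume
+```
+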
+Each source row is tracked by its content fingerprint plus its occurrence
+count, so duplicate rows are resumed independently. Re-ordering the input file
+before resuming is safe, and resumed rows keep the original `_index` values
+from the first run. Removing rows, adding rows, or deduplicating the input
+before resuming is not supported. Results are written to disk one row at a
+time as each request completes, so a crash only loses the requests that were
+in-flight at that moment.
+
+The workflow keeps a rolling in-flight window, so each settled row immediately
+admits the next pending row until the source is exhausted. Output rows are
+written in completion order, not input order; use the `_index` field to
+re-sort after the run if needed. Row-level generation failures become
+per-item `error` rows and do not abort their siblings, but setup and workflow
+failures still stop the command.
+
+`--resume` requires `--output-jsonl` and the matching checkpoint file from a
+previous file-backed run. When any of the following holds, infermesh fails
+fast instead of guessing:
+
+- the checkpoint file is missing
+- the input and output paths are the same file
+- the output file is missing settled `_index` rows recorded in the checkpoint
+- the current input does not match the original row occurrences
+
+### Custom Input Mapping with `--mapper`
+
+Use `--mapper` to transform raw source records before they are sent to the
+model. This lets you drive generation from any record format without
+preprocessing the source file.
+
+```bash
+infermesh generate \
+  --model openai/gpt-4.1-mini \
+  --input-jsonl dataset.jsonl \
+  --output-jsonl results.jsonl \
+  --mapper mypackage.prompts:build_prompt
+```
+
+The mapper is imported as `package.module:function`. The function receives
+each raw source record as a `dict` and must return a `dict` with at least an
+`"input"` key:
+
+```python
+# mypackage/prompts.py
+def build_prompt(record: dict) -> dict:
+    return {
+        "input": f"Classify the following text:\n\n{record['body']}",
+        "metadata": {"doc_id": record["id"]},
+    }
+```
+
+| Return key | Required | Notes |
+|---|---|---|
+| `"input"` | Yes | Passed directly to the generation endpoint |
+| `"metadata"` | No | Copied into the output row under `"metadata"` when it is a JSON-serializable dict |
+
+Extra keys beyond `"input"` and `"metadata"` are ignored. Mapper failures
+become per-item `error` rows — they do not abort the run. If you later resume
+a file-backed run, infermesh requires the same mapper implementation that
+wrote the original checkpoint file.
 
 ## Generate Text
 
diff --git a/src/infermesh/_cli_support.py b/src/infermesh/_cli_support.py
index b9a3e2a..0649460 100644
--- a/src/infermesh/_cli_support.py
+++ b/src/infermesh/_cli_support.py
@@ -228,6 +228,10 @@ def _validate_cli_deployments_toml(
 ) -> None:
     """Reject plaintext secrets in CLI-loaded router deployment config."""
 
+    if "deployments" not in loaded:
+        raise ValueError(
+            f"TOML file {deployments_toml_path!r} is missing a [deployments] table."
+ ) for name, config in loaded["deployments"].items(): forbidden_path = _find_forbidden_secret_path( config, @@ -347,3 +351,33 @@ def _maybe_parse_json(value: str) -> Any: return json.loads(value) except json.JSONDecodeError: return None + + +def _build_generation_record( + orig_idx: int, + result: Any, + error: BaseException | None, + *, + parse_json: bool, +) -> dict[str, Any]: + """Convert one generation result into its JSONL output shape.""" + + if result is None: + return { + "_index": orig_idx, + "output_text": None, + "output_parsed": None, + "token_usage": None, + "request_id": None, + "finish_reason": None, + "error": str(error) if error else "unknown error", + } + return { + "_index": orig_idx, + "output_text": result.output_text, + "output_parsed": _maybe_parse_json(result.output_text) if parse_json else None, + "token_usage": _token_usage_to_dict(result), + "request_id": result.request_id, + "finish_reason": result.finish_reason, + "error": None, + } diff --git a/src/infermesh/_client_runtime.py b/src/infermesh/_client_runtime.py index e1053bd..dae73da 100644 --- a/src/infermesh/_client_runtime.py +++ b/src/infermesh/_client_runtime.py @@ -8,8 +8,9 @@ import threading import time import weakref +from collections.abc import Coroutine from dataclasses import dataclass -from typing import Any, Literal, cast +from typing import Any, Literal, TypeVar, cast from urllib.parse import urlparse from infermesh._utils import ( @@ -36,6 +37,7 @@ "cost-based-routing", "usage-based-routing-v2", ] +T = TypeVar("T") @dataclass(slots=True) @@ -197,6 +199,11 @@ def _initialize_runtime_state( self._deployments = self._coerce_deployments(deployments) self._closed = False + def _run_sync(self, coroutine: Coroutine[Any, Any, T]) -> T: + """Run a coroutine on the client-owned background event loop.""" + + return self._sync_runner.run(coroutine) + async def _dispatch_with_controls( self, *, diff --git a/src/infermesh/_workflow/__init__.py b/src/infermesh/_workflow/__init__.py new file mode 100644 index 0000000..89a23dd --- /dev/null +++ b/src/infermesh/_workflow/__init__.py @@ -0,0 +1,5 @@ +"""Internal workflow package.""" + +from .engine import run_generate_workflow + +__all__ = ["run_generate_workflow"] diff --git a/src/infermesh/_workflow/checkpoint.py b/src/infermesh/_workflow/checkpoint.py new file mode 100644 index 0000000..c26d8fc --- /dev/null +++ b/src/infermesh/_workflow/checkpoint.py @@ -0,0 +1,450 @@ +"""Checkpoint storage and persistence helpers for the workflow engine.""" + +from __future__ import annotations + +import hashlib +import json +import queue +import sqlite3 +import tempfile +import threading +from dataclasses import dataclass +from pathlib import Path +from typing import IO, Any + +from .models import CheckpointKey, _CheckpointItem +from .source import _iter_source_rows_with_keys + +_SCHEMA_VERSION = 1 +_RUN_METADATA_SINGLETON = 1 +_ITEM_INSERT_BATCH_SIZE = 1000 +_CHECKPOINT_PATH_HASH_LENGTH = 8 + +_PENDING_STATUS = 0 +_SUCCESS_STATUS = 1 +_ERROR_STATUS = 2 +_SETTLED_STATUSES = frozenset({_SUCCESS_STATUS, _ERROR_STATUS}) +_STATUS_NAMES = { + _PENDING_STATUS: "pending", + _SUCCESS_STATUS: "success", + _ERROR_STATUS: "error", +} +_STATUS_VALUES = {name: value for value, name in _STATUS_NAMES.items()} + + +def _checkpoint_path_for( + output_jsonl: str, *, checkpoint_dir: str | None = None +) -> Path: + """Derive the checkpoint file path from the output JSONL path.""" + + path = Path(output_jsonl) + checkpoint_stem = path.stem if path.suffix == ".jsonl" else path.name + if 
checkpoint_dir is None: + return path.with_name(checkpoint_stem + ".checkpoint.sqlite") + + override_dir = Path(checkpoint_dir).expanduser() + override_dir.mkdir(parents=True, exist_ok=True) + resolved_output_path = path.expanduser().resolve(strict=False) + path_hash = hashlib.sha256(str(resolved_output_path).encode("utf-8")).hexdigest()[ + :_CHECKPOINT_PATH_HASH_LENGTH + ] + return override_dir / f"{checkpoint_stem}.{path_hash}.checkpoint.sqlite" + + +def _configure_checkpoint_journal_mode(connection: sqlite3.Connection) -> str: + """Configure the checkpoint DB for portable rollback journaling.""" + + persist_mode = connection.execute("PRAGMA journal_mode=PERSIST").fetchone() + journal_mode = str(persist_mode[0]).lower() if persist_mode is not None else "" + if journal_mode == "persist": + return journal_mode + + delete_mode = connection.execute("PRAGMA journal_mode=DELETE").fetchone() + journal_mode = str(delete_mode[0]).lower() if delete_mode is not None else "" + if journal_mode == "delete": + return journal_mode + + raise RuntimeError( + "Checkpoint DB could not be configured for rollback journaling. " + f"SQLite reported journal_mode={journal_mode!r}." + ) + + +def _connect_checkpoint_db(checkpoint_path: Path) -> sqlite3.Connection: + """Open a read-write checkpoint database connection.""" + + connection = sqlite3.connect(checkpoint_path) + _configure_checkpoint_journal_mode(connection) + connection.execute("PRAGMA synchronous=FULL") + connection.execute("PRAGMA busy_timeout=5000") + return connection + + +def _connect_checkpoint_db_read_only(checkpoint_path: Path) -> sqlite3.Connection: + """Open a read-only checkpoint connection for resume validation.""" + + connection = sqlite3.connect( + checkpoint_path.expanduser().resolve(strict=False).as_uri() + "?mode=ro", + uri=True, + ) + connection.execute("PRAGMA query_only=ON") + return connection + + +def _initialize_checkpoint_db( + connection: sqlite3.Connection, mapping_fingerprint: str +) -> None: + """Create the checkpoint schema and write the run metadata row.""" + + connection.executescript( + """ + CREATE TABLE run_metadata ( + singleton INTEGER PRIMARY KEY CHECK (singleton = 1), + schema_version INTEGER NOT NULL, + mapping_fingerprint TEXT NOT NULL + ); + + CREATE TABLE items ( + record_fingerprint BLOB NOT NULL, + occurrence INTEGER NOT NULL, + output_index INTEGER NOT NULL, + status INTEGER NOT NULL, + error TEXT, + PRIMARY KEY (record_fingerprint, occurrence) + ); + + CREATE INDEX idx_items_status_output_index + ON items(status, output_index); + """ + ) + connection.execute( + """ + INSERT INTO run_metadata (singleton, schema_version, mapping_fingerprint) + VALUES (?, ?, ?) + """, + (_RUN_METADATA_SINGLETON, _SCHEMA_VERSION, mapping_fingerprint), + ) + + +def _load_run_metadata(connection: sqlite3.Connection) -> tuple[int, str]: + """Load the singleton run metadata row.""" + + row = connection.execute( + """ + SELECT schema_version, mapping_fingerprint + FROM run_metadata + WHERE singleton = ? 
+ """, + (_RUN_METADATA_SINGLETON,), + ).fetchone() + if row is None: + raise ValueError("Checkpoint file is invalid: missing run metadata.") + return int(row[0]), str(row[1]) + + +def _insert_pending_checkpoint_items( + connection: sqlite3.Connection, + *, + prompt: str | None, + input_jsonl: str | None, +) -> None: + """Insert one pending checkpoint item per source row.""" + + batch: list[tuple[bytes, int, int, int, None]] = [] + insert_sql = """ + INSERT INTO items ( + record_fingerprint, + occurrence, + output_index, + status, + error + ) + VALUES (?, ?, ?, ?, ?) + """ + for source_row, checkpoint_key in _iter_source_rows_with_keys( + prompt=prompt, + input_jsonl=input_jsonl, + ): + batch.append( + ( + checkpoint_key.record_fingerprint, + checkpoint_key.occurrence, + source_row.source_index, + _PENDING_STATUS, + None, + ) + ) + if len(batch) >= _ITEM_INSERT_BATCH_SIZE: + connection.executemany(insert_sql, batch) + batch.clear() + if batch: + connection.executemany(insert_sql, batch) + + +def _bootstrap_checkpoint( + *, + prompt: str | None, + input_jsonl: str | None, + checkpoint_path: Path, + mapping_fingerprint: str, +) -> None: + """Create the checkpoint DB and bootstrap one pending row per source item.""" + + connection = _connect_checkpoint_db(checkpoint_path) + try: + _initialize_checkpoint_db(connection, mapping_fingerprint) + _insert_pending_checkpoint_items( + connection, + prompt=prompt, + input_jsonl=input_jsonl, + ) + connection.commit() + finally: + connection.close() + + +def _stage_fresh_workflow_files( + *, + prompt: str | None, + input_jsonl: str | None, + output_path: Path, + checkpoint_path: Path, + mapping_fingerprint: str, +) -> None: + """Stage fresh workflow artifacts and replace existing ones after bootstrap.""" + + staged_output_path: Path | None = None + staged_checkpoint_path: Path | None = None + try: + with tempfile.NamedTemporaryFile( + "wb", + dir=checkpoint_path.parent, + prefix=f".{checkpoint_path.name}.", + suffix=".tmp", + delete=False, + ) as file_handle: + staged_checkpoint_path = Path(file_handle.name) + _bootstrap_checkpoint( + prompt=prompt, + input_jsonl=input_jsonl, + checkpoint_path=staged_checkpoint_path, + mapping_fingerprint=mapping_fingerprint, + ) + + with tempfile.NamedTemporaryFile( + "w", + encoding="utf-8", + dir=output_path.parent, + prefix=f".{output_path.name}.", + suffix=".tmp", + delete=False, + ) as file_handle: + staged_output_path = Path(file_handle.name) + + # Replace the visible artifacts only after checkpoint bootstrap has + # succeeded, so old run files survive bootstrap failures intact. + staged_checkpoint_path.replace(checkpoint_path) + staged_output_path.replace(output_path) + finally: + if staged_checkpoint_path is not None and staged_checkpoint_path.exists(): + staged_checkpoint_path.unlink() + if staged_output_path is not None and staged_output_path.exists(): + staged_output_path.unlink() + + +def _load_checkpoint_item( + connection: sqlite3.Connection, checkpoint_key: CheckpointKey +) -> _CheckpointItem | None: + """Load one checkpoint item by its occurrence-aware key.""" + + row = connection.execute( + """ + SELECT output_index, status, error + FROM items + WHERE record_fingerprint = ? AND occurrence = ? 
+ """, + checkpoint_key.sql_params(), + ).fetchone() + if row is None: + return None + error_value = row[2] + if error_value is not None and not isinstance(error_value, str): + raise ValueError("Checkpoint file is invalid: item error column must be text.") + return _CheckpointItem( + output_index=int(row[0]), + status=int(row[1]), + error=error_value, + ) + + +def _mark_checkpoint_item_settled( + connection: sqlite3.Connection, + checkpoint_key: CheckpointKey, + *, + status: int, + error: str | None, +) -> None: + """Update one checkpoint item from pending to a terminal state.""" + + cursor = connection.execute( + """ + UPDATE items + SET status = ?, error = ? + WHERE record_fingerprint = ? AND occurrence = ? + """, + (status, error, *checkpoint_key.sql_params()), + ) + if cursor.rowcount != 1: + raise RuntimeError("Checkpoint item update failed for settled workflow row.") + connection.commit() + + +@dataclass +class _PersistenceRequest: + """One settled row that must be durably written by the sink thread.""" + + record: dict[str, Any] + checkpoint_key: CheckpointKey + status: int + error: str | None + done: threading.Event + failure: BaseException | None = None + + +@dataclass +class _PersistenceShutdown: + """Signal the sink thread to flush and stop.""" + + done: threading.Event + + +class _FileBackedPersistenceSink: + """Serialize output/checkpoint writes onto one dedicated thread.""" + + def __init__(self, *, output_path: Path, checkpoint_path: Path) -> None: + self._output_path = output_path + self._checkpoint_path = checkpoint_path + self._queue: queue.Queue[_PersistenceRequest | _PersistenceShutdown] = ( + queue.Queue() + ) + self._started = threading.Event() + self._failure: BaseException | None = None + self._closed = False + self._thread = threading.Thread( + target=self._run, + name="infermesh-generate-persistence", + daemon=True, + ) + self._thread.start() + self._started.wait() + self._raise_if_failed() + + def write_record( + self, + record: dict[str, Any], + checkpoint_key: CheckpointKey, + *, + status: int, + error: str | None, + ) -> None: + """Persist one settled record and checkpoint update.""" + + if self._closed: + raise RuntimeError("Cannot write to a closed persistence sink.") + self._raise_if_failed() + request = _PersistenceRequest( + record=record, + checkpoint_key=checkpoint_key, + status=status, + error=error, + done=threading.Event(), + ) + self._queue.put(request) + self._wait_for_event(request.done) + if request.failure is not None: + raise request.failure + self._raise_if_failed() + + def close(self) -> None: + """Stop the sink thread and re-raise any background failure.""" + + if self._closed: + self._raise_if_failed() + return + self._closed = True + if self._thread.is_alive(): + shutdown = _PersistenceShutdown(done=threading.Event()) + self._queue.put(shutdown) + self._wait_for_event(shutdown.done) + self._thread.join() + self._raise_if_failed() + + def _run(self) -> None: + out_file: IO[str] | None = None + connection: sqlite3.Connection | None = None + try: + out_file = open(self._output_path, "a", encoding="utf-8") # noqa: SIM115 + connection = _connect_checkpoint_db(self._checkpoint_path) + self._started.set() + while True: + item = self._queue.get() + if isinstance(item, _PersistenceShutdown): + item.done.set() + return + try: + # Write the user-facing output first, then mark the + # checkpoint item settled. A crash between the two can cause + # duplicate work on resume, but not silent row loss. 
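+                    # Duplicate rows written by such a crash share an
+                    # _index value, so they can be detected and dropped
+                    # downstream.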
+ out_file.write(json.dumps(item.record) + "\n") + out_file.flush() + _mark_checkpoint_item_settled( + connection, + item.checkpoint_key, + status=item.status, + error=item.error, + ) + except BaseException as exc: # noqa: BLE001 + self._set_failure(exc) + item.failure = exc + item.done.set() + self._fail_pending_items(exc) + return + item.done.set() + except BaseException as exc: # noqa: BLE001 + self._set_failure(exc) + self._fail_pending_items(exc) + finally: + self._started.set() + if connection is not None: + connection.close() + if out_file is not None: + out_file.close() + + def _wait_for_event(self, event: threading.Event) -> None: + while not event.wait(timeout=0.1): + if not self._thread.is_alive(): + break + if event.is_set(): + return + self._raise_if_failed() + raise RuntimeError("Persistence sink stopped before acknowledging a write.") + + def _set_failure(self, exc: BaseException) -> None: + if self._failure is None: + self._failure = exc + + def _raise_if_failed(self) -> None: + if self._failure is not None: + raise self._failure + + def _fail_pending_items(self, exc: BaseException) -> None: + while True: + try: + item = self._queue.get_nowait() + except queue.Empty: + return + if isinstance(item, _PersistenceShutdown): + item.done.set() + continue + item.failure = exc + item.done.set() diff --git a/src/infermesh/_workflow/engine.py b/src/infermesh/_workflow/engine.py new file mode 100644 index 0000000..955ab9b --- /dev/null +++ b/src/infermesh/_workflow/engine.py @@ -0,0 +1,276 @@ +"""Generate workflow engine orchestration.""" + +from __future__ import annotations + +import asyncio +import sys +from collections.abc import Callable +from concurrent.futures import ThreadPoolExecutor +from typing import TYPE_CHECKING, Any + +from infermesh._batch_utils import cancel_tasks +from infermesh._cli_support import _build_generation_record, _write_jsonl + +from .checkpoint import _ERROR_STATUS, _SUCCESS_STATUS, _FileBackedPersistenceSink +from .models import CheckpointKey, _PreparedWorkItem, _SourceExhausted, _WorkItem +from .prepare import Preparer +from .runtime import _prepare_generate_run_resources + +if TYPE_CHECKING: + from infermesh.client import LMClient + from infermesh.types import EndpointType + + +def _write_item_result( + persistence_sink: _FileBackedPersistenceSink | None, + output_index: int, + checkpoint_key: CheckpointKey, + result: Any, + error: BaseException | None, + *, + metadata: dict[str, Any] | None, + parse_json: bool, +) -> dict[str, Any]: + """Build one settled item and optionally persist it through the sink.""" + + record = _build_generation_record( + output_index, + result, + error, + parse_json=parse_json, + ) + if metadata is not None: + record["metadata"] = metadata + if persistence_sink is not None: + persistence_sink.write_record( + record, + checkpoint_key, + status=_ERROR_STATUS if error else _SUCCESS_STATUS, + error=str(error) if error else None, + ) + return record + + +async def _agenerate_work_item( + client: LMClient, + item: _WorkItem, + *, + endpoint: EndpointType, +) -> tuple[Any, Exception | None]: + """Run one workflow item and return its settled outcome.""" + + try: + result = await client.agenerate(item.mapped_input, endpoint=endpoint) + except asyncio.CancelledError: + raise + except Exception as exc: + return None, exc + return result, None + + +def _emit_settled_work_item( + persistence_sink: _FileBackedPersistenceSink | None, + item: _WorkItem, + result: Any, + error: BaseException | None, + *, + parse_json: bool, + on_progress: 
Callable[[], Any] | None, +) -> None: + """Persist or print one settled workflow item, then tick progress.""" + + record = _write_item_result( + persistence_sink, + item.output_index, + item.checkpoint_key, + result, + error, + metadata=item.metadata, + parse_json=parse_json, + ) + if persistence_sink is None: + _write_jsonl([record], None) + if on_progress is not None: + on_progress() + + +def _emit_immediate_error( + persistence_sink: _FileBackedPersistenceSink | None, + prepared: _PreparedWorkItem, + *, + parse_json: bool, + on_progress: Callable[[], Any] | None, +) -> None: + """Write one per-item error row without aborting the run, then invoke progress.""" + + record = _write_item_result( + persistence_sink, + prepared.output_index, + prepared.checkpoint_key, + None, + prepared.immediate_error, + metadata=None, + parse_json=parse_json, + ) + if persistence_sink is None: + _write_jsonl([record], None) + if on_progress is not None: + on_progress() + + +async def _arun_generate_source_rows( + client: LMClient, + *, + preparer: Preparer, + preparer_executor: ThreadPoolExecutor, + resume: bool, + persistence_sink: _FileBackedPersistenceSink | None, + window_size: int, + endpoint: EndpointType, + parse_json: bool, + on_progress: Callable[[], Any] | None, +) -> None: + """Stream mapped rows through a rolling in-flight generation window.""" + + if window_size < 1: + raise ValueError("window_size must be a positive integer.") + + loop = asyncio.get_running_loop() + active_tasks: dict[asyncio.Task[tuple[Any, Exception | None]], _WorkItem] = {} + source_exhausted = False + any_work = False + + async def fill_window() -> None: + nonlocal source_exhausted + nonlocal any_work + + while len(active_tasks) < window_size and not source_exhausted: + prepared = await loop.run_in_executor( + preparer_executor, preparer.next_prepared + ) + if isinstance(prepared, _SourceExhausted): + source_exhausted = True + return + any_work = True + if prepared.immediate_error is not None: + _emit_immediate_error( + persistence_sink, + prepared, + parse_json=parse_json, + on_progress=on_progress, + ) + continue + + work_item = prepared.work_item + assert work_item is not None + task = asyncio.create_task( + _agenerate_work_item(client, work_item, endpoint=endpoint) + ) + active_tasks[task] = work_item + + try: + await fill_window() + while active_tasks: + done, _ = await asyncio.wait( + active_tasks, + return_when=asyncio.FIRST_COMPLETED, + ) + for task in done: + item = active_tasks.pop(task) + result, error = task.result() + _emit_settled_work_item( + persistence_sink, + item, + result, + error, + parse_json=parse_json, + on_progress=on_progress, + ) + await fill_window() + except BaseException: + await cancel_tasks(list(active_tasks)) + raise + finally: + await loop.run_in_executor(preparer_executor, preparer.close) + + if resume and not any_work: + sys.stderr.write("Nothing to do — all rows already completed.\n") + + +def _run_generate_source_rows( + client: LMClient, + *, + preparer: Preparer, + resume: bool, + persistence_sink: _FileBackedPersistenceSink | None, + window_size: int, + endpoint: EndpointType, + parse_json: bool, + on_progress: Callable[[], Any] | None, +) -> None: + """Run the rolling generate scheduler on the client's background loop.""" + + with ThreadPoolExecutor( + max_workers=1, + thread_name_prefix="infermesh-generate-prep", + ) as preparer_executor: + client._run_sync( + _arun_generate_source_rows( + client, + preparer=preparer, + preparer_executor=preparer_executor, + resume=resume, + 
persistence_sink=persistence_sink, + window_size=window_size, + endpoint=endpoint, + parse_json=parse_json, + on_progress=on_progress, + ) + ) + + +def run_generate_workflow( + client: LMClient, + *, + prompt: str | None, + input_jsonl: str | None, + output_jsonl: str | None, + checkpoint_dir: str | None, + mapper_spec: str | None, + resume: bool, + endpoint: EndpointType, + window_size: int, + parse_json: bool, + on_progress: Callable[[], Any] | None = None, + on_status: Callable[[str], Any] | None = None, +) -> None: + """Run the generate workflow engine.""" + + if resume and output_jsonl is None: + raise ValueError( + "--resume requires --output-jsonl because resumed runs need a " + "checkpoint file." + ) + + run = _prepare_generate_run_resources( + prompt=prompt, + input_jsonl=input_jsonl, + output_jsonl=output_jsonl, + checkpoint_dir=checkpoint_dir, + mapper_spec=mapper_spec, + resume=resume, + on_status=on_status, + ) + try: + _run_generate_source_rows( + client, + preparer=run.preparer, + resume=resume, + persistence_sink=run.persistence_sink, + window_size=window_size, + endpoint=endpoint, + parse_json=parse_json, + on_progress=on_progress, + ) + finally: + run.close() diff --git a/src/infermesh/_workflow/mapping.py b/src/infermesh/_workflow/mapping.py new file mode 100644 index 0000000..8b38235 --- /dev/null +++ b/src/infermesh/_workflow/mapping.py @@ -0,0 +1,130 @@ +"""Mapper loading and mapping strategy helpers for the workflow engine.""" + +from __future__ import annotations + +import contextlib +import hashlib +import importlib +import inspect +import json +from collections.abc import Callable +from typing import Any, cast + +# Encodes built-in field-extraction semantics. Bump the literal to invalidate +# existing checkpoints whenever the built-in mapping logic changes. 
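+# A resumed run whose checkpoint was written under the old literal then fails
+# the mapping-fingerprint check instead of silently remapping rows.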
+_BUILTIN_MAPPING_FINGERPRINT = hashlib.sha256( + b"infermesh.generate.builtin_mapping.v1" +).hexdigest() + + +def _load_mapper(mapper_spec: str) -> Callable[[dict[str, Any]], Any]: + r"""Load a mapper function from a ``\"package.module:function\"`` spec.""" + + module_path, sep, func_name = mapper_spec.rpartition(":") + if not sep or not module_path or not func_name: + raise ValueError( + f"--mapper must be 'package.module:function', got {mapper_spec!r}" + ) + module = importlib.import_module(module_path) + func = getattr(module, func_name, None) + if func is None: + raise ValueError(f"--mapper: {module_path!r} has no attribute {func_name!r}") + if not callable(func): + raise ValueError( + f"--mapper: {mapper_spec!r} resolved to a non-callable {type(func).__name__!r}" + ) + return cast(Callable[[dict[str, Any]], Any], func) + + +def _apply_mapper_or_builtin( + raw_record: dict[str, Any], mapper: Callable[[dict[str, Any]], Any] | None +) -> tuple[Any, dict[str, Any] | None] | Exception: + """Apply the mapper (or built-in field extraction) to ``raw_record``.""" + + if mapper is not None: + try: + result = mapper(raw_record) + except Exception as exc: # noqa: BLE001 + return exc + if not isinstance(result, dict): + return ValueError( + f"Mapper must return a dict, got {type(result).__name__!r}" + ) + if "input" not in result: + return KeyError("Mapper return value is missing required key 'input'") + return result["input"], result.get("metadata") + + missing = object() + for key in ("responses_input", "messages", "prompt"): + input_data = raw_record.get(key, missing) + if input_data is not missing and input_data is not None: + return input_data, None + return ValueError( + "Generation rows require 'prompt', 'messages', or 'responses_input'." + ) + + +def _validate_metadata(metadata: Any) -> dict[str, Any] | None | Exception: + """Validate mapper metadata before it reaches the sink.""" + + if metadata is None: + return None + if not isinstance(metadata, dict): + return TypeError("Mapper 'metadata' must be a dict when provided.") + try: + json.dumps(metadata) + except TypeError as exc: + return TypeError(f"Mapper 'metadata' must be JSON serializable: {exc}") + return metadata + + +def _compute_mapper_implementation_fingerprint( + mapper: Callable[[dict[str, Any]], Any], +) -> str: + """Return a stable fingerprint for the mapper implementation.""" + + module = inspect.getmodule(mapper) + if module is not None: + with contextlib.suppress(OSError, TypeError): + return hashlib.sha256(inspect.getsource(module).encode("utf-8")).hexdigest() + + with contextlib.suppress(OSError, TypeError): + return hashlib.sha256(inspect.getsource(mapper).encode("utf-8")).hexdigest() + + code = getattr(mapper, "__code__", None) + fallback_payload = repr( + { + "co_code": getattr(code, "co_code", None), + "co_consts": getattr(code, "co_consts", None), + "co_names": getattr(code, "co_names", None), + "defaults": getattr(mapper, "__defaults__", None), + "kwdefaults": getattr(mapper, "__kwdefaults__", None), + } + ) + return hashlib.sha256(fallback_payload.encode("utf-8")).hexdigest() + + +def _compute_mapping_fingerprint( + *, mapper_spec: str | None, mapper: Callable[[dict[str, Any]], Any] | None +) -> str: + """Return the fingerprint that ties a run to its mapping strategy.""" + + if mapper is None: + return _BUILTIN_MAPPING_FINGERPRINT + + module_name = getattr(mapper, "__module__", type(mapper).__module__) + qualname = getattr(mapper, "__qualname__", type(mapper).__qualname__) + payload = json.dumps( + { + 
"mapper_spec": mapper_spec, + "module_name": module_name, + "qualname": qualname, + "implementation_fingerprint": _compute_mapper_implementation_fingerprint( + mapper + ), + }, + sort_keys=True, + separators=(",", ":"), + ensure_ascii=False, + ) + return hashlib.sha256(payload.encode("utf-8")).hexdigest() diff --git a/src/infermesh/_workflow/models.py b/src/infermesh/_workflow/models.py new file mode 100644 index 0000000..79ba38b --- /dev/null +++ b/src/infermesh/_workflow/models.py @@ -0,0 +1,74 @@ +"""Workflow-internal data models.""" + +from __future__ import annotations + +from dataclasses import dataclass +from pathlib import Path +from typing import Any + + +@dataclass(frozen=True, slots=True) +class CheckpointKey: + """One logical workflow item in checkpoint storage.""" + + record_fingerprint: bytes + occurrence: int + + def sql_params(self) -> tuple[bytes, int]: + """Return the key in the shape expected by SQLite parameter binding.""" + + return (self.record_fingerprint, self.occurrence) + + +@dataclass(slots=True) +class _WorkItem: + """One unit of work in the generate workflow.""" + + output_index: int + checkpoint_key: CheckpointKey + mapped_input: Any + metadata: dict[str, Any] | None + + +@dataclass(frozen=True, slots=True) +class _PreparedWorkItem: + """One source row after resume/mapping validation.""" + + output_index: int + checkpoint_key: CheckpointKey + work_item: _WorkItem | None + immediate_error: BaseException | None = None + + +@dataclass(frozen=True, slots=True) +class _SourceExhausted: + """Sentinel returned when a blocking preparer is out of rows.""" + + +_SOURCE_EXHAUSTED = _SourceExhausted() + + +@dataclass(slots=True) +class _SourceRow: + """One source row and its parse outcome.""" + + source_index: int + raw_line: str + raw_record: dict[str, Any] | None + error: Exception | None + + +@dataclass(frozen=True, slots=True) +class _CheckpointItem: + """One logical checkpoint item loaded from SQLite.""" + + output_index: int + status: int + error: str | None + + +@dataclass(frozen=True, slots=True) +class _ResumePlan: + """Ephemeral planner DB that drives resumed file-backed source reads.""" + + planner_path: Path diff --git a/src/infermesh/_workflow/prepare.py b/src/infermesh/_workflow/prepare.py new file mode 100644 index 0000000..8a656de --- /dev/null +++ b/src/infermesh/_workflow/prepare.py @@ -0,0 +1,259 @@ +"""Work-item preparation helpers for the generate workflow.""" + +from __future__ import annotations + +import sqlite3 +from collections.abc import Callable, Generator +from pathlib import Path +from typing import Any, Protocol, cast + +from .checkpoint import ( + _SETTLED_STATUSES, + _connect_checkpoint_db_read_only, + _load_checkpoint_item, +) +from .mapping import _apply_mapper_or_builtin, _validate_metadata +from .models import ( + _SOURCE_EXHAUSTED, + CheckpointKey, + _PreparedWorkItem, + _ResumePlan, + _SourceExhausted, + _SourceRow, + _WorkItem, +) +from .resume import ResumePlanner +from .source import _iter_source_rows_with_keys + + +class Preparer(Protocol): + """Small blocking interface used by the rolling scheduler.""" + + def next_prepared(self) -> _PreparedWorkItem | _SourceExhausted: + """Return the next schedulable item or the exhaustion sentinel.""" + + def close(self) -> None: + """Release preparer resources.""" + + +class SequentialPreparer: + """Prepare source rows sequentially on one blocking worker thread.""" + + def __init__( + self, + *, + prompt: str | None, + input_jsonl: str | None, + resume: bool, + checkpoint_path: Path | None, 
+ mapper: Callable[[dict[str, Any]], Any] | None, + ) -> None: + self._prompt = prompt + self._input_jsonl = input_jsonl + self._resume = resume + self._checkpoint_path = checkpoint_path + self._mapper = mapper + self._source_rows: ( + Generator[tuple[_SourceRow, CheckpointKey], None, None] | None + ) = None + self._checkpoint_connection: sqlite3.Connection | None = None + + def next_prepared(self) -> _PreparedWorkItem | _SourceExhausted: + """Return the next schedulable work item or source exhaustion sentinel.""" + + self._ensure_open() + assert self._source_rows is not None + for source_row, checkpoint_key in self._source_rows: + prepared = _prepare_generate_work_item( + source_row=source_row, + checkpoint_key=checkpoint_key, + resume=self._resume, + checkpoint_connection=self._checkpoint_connection, + mapper=self._mapper, + ) + if prepared is not None: + return prepared + return _SOURCE_EXHAUSTED + + def close(self) -> None: + """Release any blocking-thread resources held by the preparer.""" + + if self._source_rows is not None: + self._source_rows.close() + self._source_rows = None + if self._checkpoint_connection is not None: + self._checkpoint_connection.close() + self._checkpoint_connection = None + + def _ensure_open(self) -> None: + if self._source_rows is None: + self._source_rows = cast( + Generator[tuple[_SourceRow, CheckpointKey], None, None], + _iter_source_rows_with_keys( + prompt=self._prompt, + input_jsonl=self._input_jsonl, + ), + ) + if not self._resume or self._checkpoint_connection is not None: + return + if self._checkpoint_path is None: + raise RuntimeError("Resume path requires a checkpoint file path.") + self._checkpoint_connection = _connect_checkpoint_db_read_only( + self._checkpoint_path + ) + + +class PlannedResumePreparer: + """Prepare only pending rows from a precomputed resume plan.""" + + def __init__( + self, + *, + input_jsonl: str, + resume_plan: _ResumePlan, + mapper: Callable[[dict[str, Any]], Any] | None, + ) -> None: + self._input_jsonl = input_jsonl + self._resume_plan = resume_plan + self._mapper = mapper + self._planned_rows: ( + Generator[tuple[_SourceRow, int, CheckpointKey], None, None] | None + ) = None + + def next_prepared(self) -> _PreparedWorkItem | _SourceExhausted: + """Return the next pending row from the precomputed resume plan.""" + + self._ensure_open() + assert self._planned_rows is not None + for source_row, output_index, checkpoint_key in self._planned_rows: + return _prepare_mapped_work_item( + source_row=source_row, + output_index=output_index, + checkpoint_key=checkpoint_key, + mapper=self._mapper, + ) + return _SOURCE_EXHAUSTED + + def close(self) -> None: + """Release any blocking-thread resources held by the preparer.""" + + if self._planned_rows is not None: + self._planned_rows.close() + self._planned_rows = None + + def _ensure_open(self) -> None: + if self._planned_rows is not None: + return + self._planned_rows = cast( + Generator[tuple[_SourceRow, int, CheckpointKey], None, None], + ResumePlanner.iter_rows( + self._resume_plan, + input_jsonl=self._input_jsonl, + ), + ) + + +def _prepare_generate_work_item( + *, + source_row: _SourceRow, + checkpoint_key: CheckpointKey, + resume: bool, + checkpoint_connection: sqlite3.Connection | None, + mapper: Callable[[dict[str, Any]], Any] | None, +) -> _PreparedWorkItem | None: + """Convert one source row into schedulable work or an immediate error.""" + + output_index = _generate_output_index_for_resume_row( + resume=resume, + checkpoint_connection=checkpoint_connection, + 
checkpoint_key=checkpoint_key, + source_index=source_row.source_index, + ) + if output_index is None: + return None + return _prepare_mapped_work_item( + source_row=source_row, + output_index=output_index, + checkpoint_key=checkpoint_key, + mapper=mapper, + ) + + +def _prepare_mapped_work_item( + *, + source_row: _SourceRow, + output_index: int, + checkpoint_key: CheckpointKey, + mapper: Callable[[dict[str, Any]], Any] | None, +) -> _PreparedWorkItem: + """Prepare one already-selected row for mapping/generation.""" + + if source_row.error is not None: + return _PreparedWorkItem( + output_index=output_index, + checkpoint_key=checkpoint_key, + work_item=None, + immediate_error=source_row.error, + ) + + raw_record = source_row.raw_record + if raw_record is None: + raise RuntimeError( + "Invariant violated: source_row.raw_record is None after error check." + ) + + mapping_result = _apply_mapper_or_builtin(raw_record, mapper) + if isinstance(mapping_result, Exception): + return _PreparedWorkItem( + output_index=output_index, + checkpoint_key=checkpoint_key, + work_item=None, + immediate_error=mapping_result, + ) + + mapped_input, metadata = mapping_result + metadata_result = _validate_metadata(metadata) + if isinstance(metadata_result, Exception): + return _PreparedWorkItem( + output_index=output_index, + checkpoint_key=checkpoint_key, + work_item=None, + immediate_error=metadata_result, + ) + + return _PreparedWorkItem( + output_index=output_index, + checkpoint_key=checkpoint_key, + work_item=_WorkItem( + output_index=output_index, + checkpoint_key=checkpoint_key, + mapped_input=mapped_input, + metadata=metadata_result, + ), + ) + + +def _generate_output_index_for_resume_row( + *, + resume: bool, + checkpoint_connection: sqlite3.Connection | None, + checkpoint_key: CheckpointKey, + source_index: int, +) -> int | None: + """Return the output index for this row, or ``None`` if it is settled.""" + + if not resume: + return source_index + if checkpoint_connection is None: + raise RuntimeError("Resume path requires an open checkpoint connection.") + checkpoint_item = _load_checkpoint_item(checkpoint_connection, checkpoint_key) + if checkpoint_item is None: + raise ValueError( + "Resume source does not match the checkpoint file. Added, removed, or " + "modified row occurrences are not supported." + ) + if checkpoint_item.status in _SETTLED_STATUSES: + # Returning ``None`` lets sequential resume skip settled rows without + # emitting duplicate output or consuming a scheduler slot for them. 
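+        # Pending items instead fall through below to the checkpoint-recorded
+        # output_index, which is how resumed rows keep their original _index
+        # values.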
+ return None + return checkpoint_item.output_index diff --git a/src/infermesh/_workflow/resume.py b/src/infermesh/_workflow/resume.py new file mode 100644 index 0000000..3d55a4e --- /dev/null +++ b/src/infermesh/_workflow/resume.py @@ -0,0 +1,533 @@ +"""Resume validation and planning helpers for the workflow engine.""" + +from __future__ import annotations + +import json +import os +import sqlite3 +import tempfile +from collections.abc import Callable, Iterator +from pathlib import Path +from typing import Any, cast + +from .checkpoint import ( + _ERROR_STATUS, + _ITEM_INSERT_BATCH_SIZE, + _PENDING_STATUS, + _SCHEMA_VERSION, + _SUCCESS_STATUS, + _connect_checkpoint_db_read_only, + _load_run_metadata, +) +from .models import CheckpointKey, _ResumePlan, _SourceRow +from .source import ( + _compute_source_row_fingerprint, + _iter_binary_source_rows_with_offsets, + _iter_source_rows, + _load_source_row_at_offset, +) + +_STATUS_LOG_INTERVAL = 100_000 +_RESUME_SOURCE_MISMATCH_ERROR = ( + "Resume source does not match the checkpoint file. Added, removed, or " + "modified row occurrences are not supported." +) + + +class OutputIndexBitmap: + """Compact presence bitmap for observed output rows.""" + + def __init__(self) -> None: + self._bits = bytearray() + + def add(self, output_index: int) -> None: + """Mark one observed output index.""" + + if output_index < 0: + raise ValueError("Output rows must not use negative _index values.") + byte_index = output_index // 8 + if byte_index >= len(self._bits): + self._bits.extend(b"\x00" * (byte_index + 1 - len(self._bits))) + self._bits[byte_index] |= 1 << (output_index % 8) + + def contains(self, output_index: int) -> bool: + """Return whether the bitmap contains ``output_index``.""" + + if output_index < 0: + return False + byte_index = output_index // 8 + if byte_index >= len(self._bits): + return False + return bool(self._bits[byte_index] & (1 << (output_index % 8))) + + @classmethod + def load( + cls, + output_path: Path, + *, + on_status: Callable[[str], Any] | None = None, + ) -> OutputIndexBitmap: + """Load a bitmap of observed output indices from the output artifact.""" + + bitmap = cls() + with output_path.open(encoding="utf-8") as file_handle: + for line_number, line in enumerate(file_handle, start=1): + stripped = line.strip() + if not stripped: + continue + try: + row = json.loads(stripped) + except json.JSONDecodeError: + continue + if not isinstance(row, dict): + continue + output_index = row.get("_index") + if isinstance(output_index, int): + bitmap.add(output_index) + if on_status is not None and line_number % _STATUS_LOG_INTERVAL == 0: + on_status(f"Resume: scanned {line_number:,} output rows...") + return bitmap + + +class ResumeValidator: + """Validate resume state and optionally build a file-backed resume plan.""" + + def __init__( + self, + *, + output_path: Path, + checkpoint_path: Path, + mapping_fingerprint: str, + prompt: str | None, + input_jsonl: str | None, + on_status: Callable[[str], Any] | None = None, + ) -> None: + self._output_path = output_path + self._checkpoint_path = checkpoint_path + self._mapping_fingerprint = mapping_fingerprint + self._prompt = prompt + self._input_jsonl = input_jsonl + self._on_status = on_status + + def validate(self) -> _ResumePlan | None: + """Validate the checkpoint and return a resume plan when needed.""" + + if not self._checkpoint_path.exists(): + raise ValueError( + f"--resume requires checkpoint file {self._checkpoint_path}. " + "Start a fresh file-backed run first." 
+ ) + if not self._output_path.exists(): + raise ValueError( + f"--resume requires output file {self._output_path} because " + f"checkpoint file {self._checkpoint_path} already exists." + ) + + connection = _connect_checkpoint_db_read_only(self._checkpoint_path) + try: + if self._on_status is not None: + self._on_status("Resume: validating checkpoint file...") + self._validate_mapping_fingerprint(connection) + self._validate_output_rows(connection) + if self._prompt is None and self._input_jsonl is not None: + # Only file-backed resume needs the planner; prompt runs have no + # seekable source file and can validate inline. + return ResumePlanner( + checkpoint_connection=connection, + input_jsonl=self._input_jsonl, + on_status=self._on_status, + ).build() + self._validate_source(connection) + return None + finally: + connection.close() + + def _validate_mapping_fingerprint(self, connection: sqlite3.Connection) -> None: + schema_version, checkpoint_mapping_fingerprint = _load_run_metadata(connection) + if schema_version != _SCHEMA_VERSION: + raise ValueError( + "Checkpoint file uses an unsupported schema version. Restart the run " + "without --resume." + ) + if checkpoint_mapping_fingerprint != self._mapping_fingerprint: + raise ValueError( + "Resume mapping does not match the checkpoint file. Use the original " + "mapper implementation or restart without --resume." + ) + + def _validate_output_rows(self, connection: sqlite3.Connection) -> None: + if self._on_status is not None: + self._on_status("Resume: validating output artifact...") + # A bitmap keeps this validation compact even when _index values are + # sparse or the run has already settled a very large number of rows. + output_indices = OutputIndexBitmap.load( + self._output_path, + on_status=self._on_status, + ) + missing_indices: list[int] = [] + for row in connection.execute( + """ + SELECT output_index + FROM items + WHERE status IN (?, ?) + ORDER BY output_index + """, + (_SUCCESS_STATUS, _ERROR_STATUS), + ): + output_index = int(row[0]) + if not output_indices.contains(output_index): + missing_indices.append(output_index) + if len(missing_indices) >= 11: + break + if missing_indices: + missing_text = ", ".join(str(index) for index in missing_indices[:10]) + if len(missing_indices) > 10: + missing_text += ", ..." + raise ValueError( + "Output file is missing settled checkpoint rows for _index values " + f"{missing_text}. Restore the output artifact or restart the run " + "without --resume." 
+ ) + + def _validate_source(self, connection: sqlite3.Connection) -> None: + if self._on_status is not None: + self._on_status("Resume: validating input source...") + + remaining_counts = self._load_checkpoint_fingerprint_counts(connection) + for seen_count, source_row in enumerate( + _iter_source_rows(prompt=self._prompt, input_jsonl=self._input_jsonl), + start=1, + ): + fingerprint = _compute_source_row_fingerprint(source_row) + remaining = remaining_counts.get(fingerprint) + if remaining is None: + raise ValueError(_RESUME_SOURCE_MISMATCH_ERROR) + if remaining == 1: + del remaining_counts[fingerprint] + else: + remaining_counts[fingerprint] = remaining - 1 + if self._on_status is not None and seen_count % _STATUS_LOG_INTERVAL == 0: + self._on_status(f"Resume: scanned {seen_count:,} source rows...") + + if remaining_counts: + raise ValueError(_RESUME_SOURCE_MISMATCH_ERROR) + + @staticmethod + def _load_checkpoint_fingerprint_counts( + connection: sqlite3.Connection, + ) -> dict[bytes, int]: + return { + cast(bytes, row[0]): int(row[1]) + for row in connection.execute( + """ + SELECT record_fingerprint, COUNT(*) + FROM items + GROUP BY record_fingerprint + """ + ) + } + + +class ResumePlanner: + """Own the temporary SQLite database used to plan resumed file-backed runs.""" + + def __init__( + self, + *, + checkpoint_connection: sqlite3.Connection, + input_jsonl: str, + on_status: Callable[[str], Any] | None = None, + ) -> None: + self._checkpoint_connection = checkpoint_connection + self._input_jsonl = input_jsonl + self._on_status = on_status + + def build(self) -> _ResumePlan: + """Build the ephemeral planner DB for a resumed file-backed workflow.""" + + planner_path = self._create_path() + planner_connection: sqlite3.Connection | None = None + cleanup_planner_path = False + try: + planner_connection = self._connect_planner_db(planner_path) + self._initialize_db(planner_connection) + self._copy_checkpoint_items(planner_connection) + if self._on_status is not None: + self._on_status("Resume: building resume plan...") + self._index_source_rows(planner_connection) + self._materialize_source_items(planner_connection) + if self._on_status is not None: + self._on_status("Resume: locating pending rows...") + self._validate_source_plan(planner_connection) + self._materialize_pending_work(planner_connection) + planner_connection.commit() + return _ResumePlan(planner_path=planner_path) + except BaseException: + cleanup_planner_path = True + raise + finally: + if planner_connection is not None: + planner_connection.close() + if cleanup_planner_path: + planner_path.unlink(missing_ok=True) + + @staticmethod + def iter_rows( + resume_plan: _ResumePlan, *, input_jsonl: str + ) -> Iterator[tuple[_SourceRow, int, CheckpointKey]]: + """Yield pending source rows in source order using the built plan.""" + + planner_connection = sqlite3.connect(resume_plan.planner_path) + source_file = open(input_jsonl, "rb") # noqa: SIM115 + try: + for row in planner_connection.execute( + """ + SELECT + source_order, + output_index, + byte_offset, + record_fingerprint, + occurrence + FROM pending_work + ORDER BY source_order + """ + ): + source_order, output_index, byte_offset, fingerprint, occurrence = row + yield ( + _load_source_row_at_offset( + source_file, + offset=int(byte_offset), + source_index=int(source_order), + ), + int(output_index), + CheckpointKey(bytes(fingerprint), int(occurrence)), + ) + finally: + source_file.close() + planner_connection.close() + + @staticmethod + def cleanup(resume_plan: _ResumePlan | 
None) -> None: + """Remove the ephemeral planner DB if one exists.""" + + if resume_plan is not None: + resume_plan.planner_path.unlink(missing_ok=True) + + @staticmethod + def _temp_dir() -> Path: + """Return the directory used for ephemeral resume planner databases.""" + + return Path(os.getenv("TMPDIR") or tempfile.gettempdir()) + + @classmethod + def _create_path(cls) -> Path: + planner_dir = cls._temp_dir() + planner_dir.mkdir(parents=True, exist_ok=True) + with tempfile.NamedTemporaryFile( + dir=planner_dir, + prefix=".infermesh-resume-plan.", + suffix=".sqlite", + delete=False, + ) as file_handle: + return Path(file_handle.name) + + @staticmethod + def _connect_planner_db(planner_path: Path) -> sqlite3.Connection: + connection = sqlite3.connect(planner_path) + # The planner DB is disposable scratch rebuilt on every resume, so we + # optimize for planning throughput rather than planner durability. + connection.execute("PRAGMA journal_mode=MEMORY") + connection.execute("PRAGMA synchronous=OFF") + connection.execute("PRAGMA temp_store=MEMORY") + return connection + + @staticmethod + def _initialize_db(connection: sqlite3.Connection) -> None: + connection.executescript( + """ + CREATE TABLE checkpoint_items ( + record_fingerprint BLOB NOT NULL, + occurrence INTEGER NOT NULL, + output_index INTEGER NOT NULL, + status INTEGER NOT NULL, + PRIMARY KEY (record_fingerprint, occurrence) + ); + + CREATE TABLE source_rows ( + source_order INTEGER PRIMARY KEY, + byte_offset INTEGER NOT NULL, + record_fingerprint BLOB NOT NULL + ); + """ + ) + + def _copy_checkpoint_items(self, planner_connection: sqlite3.Connection) -> None: + batch: list[tuple[bytes, int, int, int]] = [] + for row in self._checkpoint_connection.execute( + """ + SELECT record_fingerprint, occurrence, output_index, status + FROM items + """ + ): + batch.append((bytes(row[0]), int(row[1]), int(row[2]), int(row[3]))) + if len(batch) >= _ITEM_INSERT_BATCH_SIZE: + planner_connection.executemany( + """ + INSERT INTO checkpoint_items ( + record_fingerprint, + occurrence, + output_index, + status + ) + VALUES (?, ?, ?, ?) + """, + batch, + ) + batch.clear() + if batch: + planner_connection.executemany( + """ + INSERT INTO checkpoint_items ( + record_fingerprint, + occurrence, + output_index, + status + ) + VALUES (?, ?, ?, ?) + """, + batch, + ) + + def _index_source_rows(self, planner_connection: sqlite3.Connection) -> None: + batch: list[tuple[int, int, bytes]] = [] + for seen_count, (source_row, byte_offset) in enumerate( + _iter_binary_source_rows_with_offsets(self._input_jsonl), + start=1, + ): + fingerprint = _compute_source_row_fingerprint(source_row) + batch.append((source_row.source_index, byte_offset, fingerprint)) + if len(batch) >= _ITEM_INSERT_BATCH_SIZE: + planner_connection.executemany( + """ + INSERT INTO source_rows ( + source_order, + byte_offset, + record_fingerprint + ) + VALUES (?, ?, ?) + """, + batch, + ) + batch.clear() + if self._on_status is not None and seen_count % _STATUS_LOG_INTERVAL == 0: + self._on_status(f"Resume: indexed {seen_count:,} source rows...") + if batch: + planner_connection.executemany( + """ + INSERT INTO source_rows (source_order, byte_offset, record_fingerprint) + VALUES (?, ?, ?) + """, + batch, + ) + + @staticmethod + def _materialize_source_items(planner_connection: sqlite3.Connection) -> None: + # Derive duplicate occurrences on disk so million-row resumes do not + # need a Python fingerprint->count map just to align with checkpoint + # occurrence keys. 
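+        # The window function numbers each fingerprint's duplicates 0, 1, 2,
+        # ... in source order, mirroring the occurrence component of the
+        # checkpoint keys.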
+ planner_connection.executescript( + """ + CREATE TABLE source_items AS + SELECT + source_order, + byte_offset, + record_fingerprint, + row_number() OVER ( + PARTITION BY record_fingerprint + ORDER BY source_order + ) - 1 AS occurrence + FROM source_rows; + + DROP TABLE source_rows; + + CREATE INDEX idx_source_items_key + ON source_items(record_fingerprint, occurrence); + + CREATE INDEX idx_checkpoint_items_status_output_index + ON checkpoint_items(status, output_index); + """ + ) + + @staticmethod + def _validate_source_plan(planner_connection: sqlite3.Connection) -> None: + # These two anti-joins enforce exact source/checkpoint equivalence + # without re-running row-by-row checkpoint lookups in Python. + source_extra = planner_connection.execute( + """ + SELECT 1 + FROM source_items AS source + LEFT JOIN checkpoint_items AS checkpoint + USING (record_fingerprint, occurrence) + WHERE checkpoint.output_index IS NULL + LIMIT 1 + """ + ).fetchone() + if source_extra is not None: + raise ValueError(_RESUME_SOURCE_MISMATCH_ERROR) + + checkpoint_extra = planner_connection.execute( + """ + SELECT 1 + FROM checkpoint_items AS checkpoint + LEFT JOIN source_items AS source + USING (record_fingerprint, occurrence) + WHERE source.source_order IS NULL + LIMIT 1 + """ + ).fetchone() + if checkpoint_extra is not None: + raise ValueError(_RESUME_SOURCE_MISMATCH_ERROR) + + @staticmethod + def _materialize_pending_work(planner_connection: sqlite3.Connection) -> None: + # Materializing pending rows once lets the scheduler jump straight to + # unfinished work instead of rewalking the settled prefix on resume. + planner_connection.executescript( + f""" + CREATE TABLE pending_work AS + SELECT + source.source_order, + source.byte_offset, + checkpoint.output_index, + checkpoint.record_fingerprint, + checkpoint.occurrence + FROM source_items AS source + INNER JOIN checkpoint_items AS checkpoint + USING (record_fingerprint, occurrence) + WHERE checkpoint.status = {_PENDING_STATUS}; + + CREATE INDEX idx_pending_work_source_order + ON pending_work(source_order); + """ + ) + + +def validate_resume_checkpoint( + output_path: Path, + checkpoint_path: Path, + *, + mapping_fingerprint: str, + prompt: str | None, + input_jsonl: str | None, + on_status: Callable[[str], Any] | None = None, +) -> _ResumePlan | None: + """Validate the resume checkpoint and optionally build a resume plan.""" + + return ResumeValidator( + output_path=output_path, + checkpoint_path=checkpoint_path, + mapping_fingerprint=mapping_fingerprint, + prompt=prompt, + input_jsonl=input_jsonl, + on_status=on_status, + ).validate() diff --git a/src/infermesh/_workflow/runtime.py b/src/infermesh/_workflow/runtime.py new file mode 100644 index 0000000..d11f88a --- /dev/null +++ b/src/infermesh/_workflow/runtime.py @@ -0,0 +1,181 @@ +"""Run setup and cleanup helpers for the generate workflow.""" + +from __future__ import annotations + +from dataclasses import dataclass +from pathlib import Path +from typing import TYPE_CHECKING, Any + +from .checkpoint import ( + _checkpoint_path_for, + _FileBackedPersistenceSink, + _stage_fresh_workflow_files, +) +from .mapping import _compute_mapping_fingerprint, _load_mapper +from .models import _ResumePlan +from .prepare import PlannedResumePreparer, Preparer, SequentialPreparer +from .resume import ResumePlanner, validate_resume_checkpoint +from .source import _materialize_stdin_source, _validate_distinct_input_output_paths + +if TYPE_CHECKING: + from collections.abc import Callable + + +@dataclass(slots=True) +class 
_GenerateRunResources: + """Normalized state for one generate-workflow invocation.""" + + staged_stdin_path: Path | None + persistence_sink: _FileBackedPersistenceSink | None + resume_plan: _ResumePlan | None + preparer: Preparer + + def close(self) -> None: + """Release non-scheduler resources and re-raise the first cleanup error.""" + + _cleanup_generate_run_resources( + persistence_sink=self.persistence_sink, + resume_plan=self.resume_plan, + staged_stdin_path=self.staged_stdin_path, + ) + + +def _cleanup_generate_run_resources( + *, + persistence_sink: _FileBackedPersistenceSink | None, + resume_plan: _ResumePlan | None, + staged_stdin_path: Path | None, +) -> None: + """Release setup-owned resources and re-raise the first cleanup error.""" + + cleanup_error: BaseException | None = None + + try: + if persistence_sink is not None: + persistence_sink.close() + except BaseException as exc: # noqa: BLE001 + cleanup_error = exc + + try: + ResumePlanner.cleanup(resume_plan) + except BaseException as exc: # noqa: BLE001 + if cleanup_error is None: + cleanup_error = exc + + try: + if staged_stdin_path is not None: + staged_stdin_path.unlink(missing_ok=True) + except BaseException as exc: # noqa: BLE001 + if cleanup_error is None: + cleanup_error = exc + + if cleanup_error is not None: + raise cleanup_error + + +def _prepare_generate_run_resources( + *, + prompt: str | None, + input_jsonl: str | None, + output_jsonl: str | None, + checkpoint_dir: str | None, + mapper_spec: str | None, + resume: bool, + on_status: Callable[[str], Any] | None = None, +) -> _GenerateRunResources: + """Build the normalized runtime resources for one generate run.""" + + mapper = _load_mapper(mapper_spec) if mapper_spec else None + mapping_fingerprint = _compute_mapping_fingerprint( + mapper_spec=mapper_spec, + mapper=mapper, + ) + + effective_input_jsonl = input_jsonl + staged_stdin_path: Path | None = None + output_path = Path(output_jsonl) if output_jsonl else None + checkpoint_path = ( + _checkpoint_path_for(output_jsonl, checkpoint_dir=checkpoint_dir) + if output_jsonl + else None + ) + persistence_sink: _FileBackedPersistenceSink | None = None + resume_plan: _ResumePlan | None = None + + try: + if output_jsonl and prompt is None and input_jsonl is None: + # File-backed runs need a replayable source so bootstrap/resume can + # scan it again later; stdout-only runs can keep streaming stdin. + staged_stdin_path = _materialize_stdin_source() + effective_input_jsonl = str(staged_stdin_path) + + _validate_distinct_input_output_paths( + input_jsonl=effective_input_jsonl, + output_jsonl=output_jsonl, + ) + + if output_path is not None and not resume: + assert checkpoint_path is not None + if on_status is not None: + on_status("Preparing fresh workflow artifacts...") + # Bootstrap into temp artifacts first so source/bootstrap failures + # never clobber an existing output/checkpoint pair. 
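+            # Concretely: if staging fails partway through, a previous
+            # run's output/checkpoint pair is left untouched on disk, so a
+            # later --resume against that pair is still possible.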
+ _stage_fresh_workflow_files( + prompt=prompt, + input_jsonl=effective_input_jsonl, + output_path=output_path, + checkpoint_path=checkpoint_path, + mapping_fingerprint=mapping_fingerprint, + ) + + if resume and output_path is not None: + assert checkpoint_path is not None + resume_plan = validate_resume_checkpoint( + output_path, + checkpoint_path, + mapping_fingerprint=mapping_fingerprint, + prompt=prompt, + input_jsonl=effective_input_jsonl, + on_status=on_status, + ) + + if output_path is not None: + assert checkpoint_path is not None + if on_status is not None: + on_status("Opening output and checkpoint files...") + persistence_sink = _FileBackedPersistenceSink( + output_path=output_path, + checkpoint_path=checkpoint_path, + ) + + if resume_plan is not None: + assert effective_input_jsonl is not None + preparer: Preparer = PlannedResumePreparer( + input_jsonl=effective_input_jsonl, + resume_plan=resume_plan, + mapper=mapper, + ) + else: + preparer = SequentialPreparer( + prompt=prompt, + input_jsonl=effective_input_jsonl, + resume=resume, + checkpoint_path=checkpoint_path, + mapper=mapper, + ) + + return _GenerateRunResources( + staged_stdin_path=staged_stdin_path, + persistence_sink=persistence_sink, + resume_plan=resume_plan, + preparer=preparer, + ) + except BaseException: + # Setup failed before the scheduler took ownership, so reuse the shared + # cleanup helper here instead of fabricating a dummy resource wrapper. + _cleanup_generate_run_resources( + persistence_sink=persistence_sink, + resume_plan=resume_plan, + staged_stdin_path=staged_stdin_path, + ) + raise diff --git a/src/infermesh/_workflow/source.py b/src/infermesh/_workflow/source.py new file mode 100644 index 0000000..eca7d0f --- /dev/null +++ b/src/infermesh/_workflow/source.py @@ -0,0 +1,193 @@ +"""Source parsing and fingerprinting helpers for the workflow engine.""" + +from __future__ import annotations + +import contextlib +import hashlib +import json +import sys +import tempfile +from collections.abc import Iterator +from pathlib import Path +from typing import IO + +from .models import CheckpointKey, _SourceRow + + +def _compute_record_fingerprint(raw_record: dict[str, object]) -> bytes: + """Return a stable SHA-256 digest of the canonical JSON representation.""" + + canonical = json.dumps( + raw_record, sort_keys=True, separators=(",", ":"), ensure_ascii=False + ) + return hashlib.sha256(canonical.encode("utf-8")).digest() + + +def _compute_parse_error_fingerprint(raw_line: str) -> bytes: + """Return a stable fingerprint for one malformed JSONL source line.""" + + return hashlib.sha256(f"__parse_error__{raw_line}".encode()).digest() + + +def _parse_source_line(*, source_index: int, stripped: str) -> _SourceRow: + """Parse one non-empty source line into a workflow source row.""" + + try: + record = json.loads(stripped) + except json.JSONDecodeError as exc: + return _SourceRow( + source_index=source_index, + raw_line=stripped, + raw_record=None, + error=exc, + ) + if not isinstance(record, dict): + return _SourceRow( + source_index=source_index, + raw_line=stripped, + raw_record=None, + error=ValueError("Generation rows must be JSON objects."), + ) + return _SourceRow( + source_index=source_index, + raw_line=stripped, + raw_record=record, + error=None, + ) + + +def _iter_source_rows( + *, prompt: str | None, input_jsonl: str | None +) -> Iterator[_SourceRow]: + """Yield source rows one line at a time.""" + + if prompt is not None: + yield _SourceRow( + source_index=0, + raw_line=prompt, + raw_record={"prompt": 
prompt}, + error=None, + ) + return + + ctx = ( + open(input_jsonl, encoding="utf-8") # noqa: SIM115 + if input_jsonl is not None + else contextlib.nullcontext(sys.stdin) + ) + with ctx as source: + index = 0 + for raw_line in source: + stripped = raw_line.strip() + if not stripped: + continue + yield _parse_source_line(source_index=index, stripped=stripped) + index += 1 + + +def _iter_binary_source_rows_with_offsets( + input_jsonl: str, +) -> Iterator[tuple[_SourceRow, int]]: + """Yield file-backed source rows alongside their byte offsets.""" + + with open(input_jsonl, "rb") as source: + index = 0 + while True: + offset = source.tell() + raw_line = source.readline() + if not raw_line: + return + stripped_bytes = raw_line.strip() + if not stripped_bytes: + continue + yield ( + _parse_source_line( + source_index=index, + stripped=stripped_bytes.decode("utf-8"), + ), + offset, + ) + index += 1 + + +def _load_source_row_at_offset( + source_file: IO[bytes] | None, *, offset: int, source_index: int +) -> _SourceRow: + """Seek to ``offset`` and parse one source row from a binary JSONL file.""" + + if source_file is None: + raise RuntimeError("Resume planner requires an open source file.") + source_file.seek(offset) + raw_line = source_file.readline() + if not raw_line: + raise RuntimeError("Resume planner source offset points past EOF.") + stripped = raw_line.strip().decode("utf-8") + if not stripped: + raise RuntimeError("Resume planner source offset points to a blank line.") + return _parse_source_line(source_index=source_index, stripped=stripped) + + +def _compute_source_row_fingerprint(source_row: _SourceRow) -> bytes: + """Return the checkpoint fingerprint for one parsed source row.""" + + if source_row.raw_record is not None: + return _compute_record_fingerprint(source_row.raw_record) + return _compute_parse_error_fingerprint(source_row.raw_line) + + +def _materialize_stdin_source() -> Path: + """Copy stdin to a temporary JSONL file so file-backed runs can replay it.""" + + with tempfile.NamedTemporaryFile( + "w", + encoding="utf-8", + suffix=".jsonl", + delete=False, + ) as file_handle: + for raw_line in sys.stdin: + file_handle.write(raw_line) + return Path(file_handle.name) + + +def _resume_key_for_source_row( + source_row: _SourceRow, fingerprint_counts: dict[bytes, int] +) -> CheckpointKey: + """Return the occurrence-aware resume key for one source row.""" + + fingerprint = _compute_source_row_fingerprint(source_row) + occurrence = fingerprint_counts.get(fingerprint, 0) + fingerprint_counts[fingerprint] = occurrence + 1 + return CheckpointKey(record_fingerprint=fingerprint, occurrence=occurrence) + + +def _iter_source_rows_with_keys( + *, prompt: str | None, input_jsonl: str | None +) -> Iterator[tuple[_SourceRow, CheckpointKey]]: + """Yield ``(source_row, checkpoint_key)`` pairs with occurrence-aware keys.""" + + fingerprint_counts: dict[bytes, int] = {} + for source_row in _iter_source_rows(prompt=prompt, input_jsonl=input_jsonl): + yield source_row, _resume_key_for_source_row(source_row, fingerprint_counts) + + +def _paths_reference_same_file(input_path: Path, output_path: Path) -> bool: + """Return whether two paths resolve to the same file target.""" + + if input_path.exists() and output_path.exists(): + try: + if input_path.samefile(output_path): + return True + except OSError: + pass + return input_path.resolve(strict=False) == output_path.resolve(strict=False) + + +def _validate_distinct_input_output_paths( + *, input_jsonl: str | None, output_jsonl: str | None +) -> None: + 
"""Reject file-backed runs that reuse the same path for input and output.""" + + if input_jsonl is None or output_jsonl is None: + return + if _paths_reference_same_file(Path(input_jsonl), Path(output_jsonl)): + raise ValueError("--input-jsonl and --output-jsonl must be different files.") diff --git a/src/infermesh/cli.py b/src/infermesh/cli.py index 79bfc7c..5c1ed9b 100644 --- a/src/infermesh/cli.py +++ b/src/infermesh/cli.py @@ -3,9 +3,10 @@ from __future__ import annotations import argparse -import json +import os import sys -from pathlib import Path +from collections.abc import Iterator +from contextlib import contextmanager from typing import Any from dotenv import load_dotenv @@ -25,24 +26,17 @@ from infermesh._cli_support import ( ClientConfig, _add_connection_args, + _client_config_from_args, _load_embed_texts, _load_generation_rows, _load_transcription_paths, - _maybe_parse_json, _token_usage_to_dict, _write_jsonl, ) -from infermesh._cli_support import ( - _build_client as _support_build_client, -) -from infermesh._cli_support import ( - _client_config_from_args as _support_client_config_from_args, -) +from infermesh._cli_support import _build_client as _support_build_client from infermesh._utils import batched_cycle +from infermesh._workflow import run_generate_workflow from infermesh.client import LMClient -from infermesh.types import EndpointType - -_client_config_from_args = _support_client_config_from_args def main(argv: list[str] | None = None) -> int: @@ -60,19 +54,27 @@ def main(argv: list[str] | None = None) -> int: def _build_client( - config: ClientConfig, - *, - max_parallel_requests: int | None = None, + config: ClientConfig, *, max_parallel_requests: int | None = None ) -> LMClient: """Build an ``LMClient`` instance for CLI commands.""" return _support_build_client( - config, - max_parallel_requests=max_parallel_requests, - client_cls=LMClient, + config, max_parallel_requests=max_parallel_requests, client_cls=LMClient ) +@contextmanager +def _managed_client( + config: ClientConfig, *, max_parallel_requests: int | None = None +) -> Iterator[LMClient]: + """Build a client and guarantee ``close()`` on exit regardless of errors.""" + client = _build_client(config, max_parallel_requests=max_parallel_requests) + try: + yield client + finally: + client.close() + + def _build_parser() -> argparse.ArgumentParser: """Build the CLI parser.""" @@ -112,8 +114,14 @@ def _add_generate_parser(subparsers: Any) -> None: ), ) generate_parser.add_argument( - "--output-jsonl", - help="Write one result object per input row.", + "--output-jsonl", help="Write one result object per input row." + ) + generate_parser.add_argument( + "--checkpoint-dir", + help=( + "Optional directory for the checkpoint file. Defaults to the output " + "directory, or INFERMESH_CHECKPOINT_DIR when that env var is set." + ), ) generate_parser.add_argument( "--endpoint", @@ -129,8 +137,17 @@ def _add_generate_parser(subparsers: Any) -> None: "--resume", action="store_true", help=( - "Resume a previous run by reading completed _index values from " - "--output-jsonl and appending the remaining rows." + "Resume a previous run by reading the checkpoint file " + "*.checkpoint.sqlite and appending the unsettled rows." + ), + ) + generate_parser.add_argument( + "--mapper", + metavar="MODULE:FUNC", + help=( + "Import path of a mapper function: 'package.module:function'. " + "The function receives a raw source record (dict) and must return " + "a dict with at least an 'input' key." 
), ) generate_parser.set_defaults(handler=_handle_generate) @@ -324,350 +341,195 @@ def _add_bench_embed_parser(bench_subparsers: Any) -> None: bench_embed_parser.set_defaults(handler=_handle_bench_embed) -def _load_completed_generation_indices(output_jsonl: str) -> set[int]: - """Read completed ``_index`` values from an existing output JSONL file.""" - - completed: set[int] = set() - if not Path(output_jsonl).exists(): - return completed - - with open(output_jsonl, encoding="utf-8") as file_handle: - for raw_line in file_handle: - stripped_line = raw_line.strip() - if not stripped_line: - continue - try: - index = json.loads(stripped_line).get("_index") - except json.JSONDecodeError: - continue - if isinstance(index, int): - completed.add(index) - return completed - - -def _build_pending_generation_inputs( - all_rows: list[dict[str, Any]], - done: set[int], -) -> list[tuple[int, Any]]: - """Return the generation inputs that still need to run.""" - - pending: list[tuple[int, Any]] = [] - for orig_idx, row in enumerate(all_rows): - if orig_idx in done: - continue - input_data = ( - row.get("responses_input") or row.get("messages") or row.get("prompt") - ) - if input_data is None: - raise ValueError( - "Generation rows require 'prompt', 'messages', or 'responses_input'." - ) - pending.append((orig_idx, input_data)) - return pending - - -def _build_generation_record( - orig_idx: int, - result: Any, - error: BaseException | None, - *, - parse_json: bool, -) -> dict[str, Any]: - """Convert one generation result into its JSONL output shape.""" - - if result is None: - return { - "_index": orig_idx, - "output_text": None, - "output_parsed": None, - "token_usage": None, - "request_id": None, - "finish_reason": None, - "error": str(error) if error else "unknown error", - } - return { - "_index": orig_idx, - "output_text": result.output_text, - "output_parsed": _maybe_parse_json(result.output_text) if parse_json else None, - "token_usage": _token_usage_to_dict(result), - "request_id": result.request_id, - "finish_reason": result.finish_reason, - "error": None, - } - - -def _write_generation_batch_to_file( - client: LMClient, - *, - inputs: list[Any], - pending: list[tuple[int, Any]], - output_jsonl: str, - resume: bool, - endpoint: EndpointType, - parse_json: bool, - on_progress: Any, -) -> None: - """Stream generation results to disk as each request completes.""" - - file_mode = "a" if resume else "w" - with open(output_jsonl, file_mode, encoding="utf-8") as out_file: - - def on_result( - batch_idx: int, - result: Any, - error: BaseException | None, - ) -> None: - orig_idx = pending[batch_idx][0] - record = _build_generation_record( - orig_idx, - result, - error, - parse_json=parse_json, - ) - out_file.write(json.dumps(record) + "\n") - out_file.flush() - - client.generate_batch( - inputs, - endpoint=endpoint, - on_progress=on_progress, - on_result=on_result, - ) +def _handle_generate(args: argparse.Namespace) -> int: + """Handle the ``generate`` subcommand.""" + resume = bool(getattr(args, "resume", False)) + if resume and not args.output_jsonl: + sys.stderr.write("error: --resume requires --output-jsonl\n") + return 1 -def _write_generation_batch_to_stdout( - client: LMClient, - *, - inputs: list[Any], - pending: list[tuple[int, Any]], - endpoint: EndpointType, - parse_json: bool, - on_progress: Any, -) -> None: - """Collect generation results and write them to stdout together.""" - - batch = client.generate_batch( - inputs, - endpoint=endpoint, - on_progress=on_progress, - ) - records = [] - for 
batch_idx, result in enumerate(batch): - error = batch.errors[batch_idx] if batch.errors else None - records.append( - _build_generation_record( - pending[batch_idx][0], - result, - error, - parse_json=parse_json, - ) - ) - _write_jsonl(records, None) + config = _client_config_from_args(args) + window_size = config.max_parallel_requests or 128 + checkpoint_dir = args.checkpoint_dir or os.getenv("INFERMESH_CHECKPOINT_DIR") + # Use an open-ended progress bar for file-backed runs (total unknown up + # front). For a single --prompt to stdout, suppress it entirely. + disable_bar = args.output_jsonl is None and args.prompt is not None + try: + with ( + _managed_client(config) as client, + tqdm( + desc="Generating", unit="req", disable=disable_bar, file=sys.stderr + ) as bar, + ): -def _run_generate_command( - client: LMClient, - args: argparse.Namespace, - *, - endpoint: EndpointType, -) -> int: - """Execute the core generation workflow after client construction.""" + def report_status(message: str) -> None: + if disable_bar: + sys.stderr.write(message + "\n") + else: + bar.write(message) - resume = bool(getattr(args, "resume", False)) - all_rows = _load_generation_rows(prompt=args.prompt, input_jsonl=args.input_jsonl) - done = ( - _load_completed_generation_indices(args.output_jsonl) - if resume and args.output_jsonl - else set() - ) - if done: - sys.stderr.write(f"Resuming: skipping {len(done)} already-completed row(s).\n") - - pending = _build_pending_generation_inputs(all_rows, done) - if not pending: - sys.stderr.write("Nothing to do — all rows already completed.\n") - return 0 - - inputs = [input_data for _, input_data in pending] - parse_json = bool(getattr(args, "parse_json", False)) - with tqdm( - total=len(pending), - desc="Generating", - unit="req", - disable=(len(pending) <= 1), - file=sys.stderr, - ) as progress_bar: - - def on_progress(_done: int, _total: int) -> None: - progress_bar.update(1) - - if args.output_jsonl: - _write_generation_batch_to_file( + run_generate_workflow( client, - inputs=inputs, - pending=pending, + prompt=args.prompt, + input_jsonl=args.input_jsonl, output_jsonl=args.output_jsonl, + checkpoint_dir=checkpoint_dir, + mapper_spec=getattr(args, "mapper", None), resume=resume, - endpoint=endpoint, - parse_json=parse_json, - on_progress=on_progress, - ) - else: - _write_generation_batch_to_stdout( - client, - inputs=inputs, - pending=pending, - endpoint=endpoint, - parse_json=parse_json, - on_progress=on_progress, + endpoint=config.endpoint, + window_size=window_size, + parse_json=bool(getattr(args, "parse_json", False)), + on_progress=lambda: bar.update(), # noqa: PLW0108 + on_status=report_status, ) - return 0 - - -def _handle_generate(args: argparse.Namespace) -> int: - """Handle the ``generate`` subcommand.""" - - resume = getattr(args, "resume", False) - if resume and not args.output_jsonl: - sys.stderr.write("error: --resume requires --output-jsonl\n") + except Exception as exc: + sys.stderr.write(f"error: {exc}\n") return 1 - - config = _client_config_from_args(args) - client = _build_client(config) - try: - return _run_generate_command(client, args, endpoint=config.endpoint) - finally: - client.close() + return 0 def _handle_embed(args: argparse.Namespace) -> int: """Handle the ``embed`` subcommand.""" - client = _build_client(_client_config_from_args(args)) try: - texts = _load_embed_texts(text=args.text, input_jsonl=args.input_jsonl) - batch = client.embed_batch(texts) - rows: list[dict[str, Any]] = [] - for index, result in enumerate(batch): - if 
result is None: - error = batch.errors[index] if batch.errors else None - rows.append( - { - "embedding": None, - "dimensions": None, - "request_id": None, - "token_usage": None, - "error": str(error) if error else "unknown error", - } - ) - else: - rows.append( - { - "embedding": None if args.no_vectors else result.embedding, - "dimensions": len(result.embedding), - "request_id": result.request_id, - "token_usage": _token_usage_to_dict(result), - "error": None, - } - ) - _write_jsonl(rows, args.output_jsonl) - finally: - client.close() + with _managed_client(_client_config_from_args(args)) as client: + texts = _load_embed_texts(text=args.text, input_jsonl=args.input_jsonl) + batch = client.embed_batch(texts) + rows: list[dict[str, Any]] = [] + for index, result in enumerate(batch): + if result is None: + error = batch.errors[index] if batch.errors else None + rows.append( + { + "embedding": None, + "dimensions": None, + "request_id": None, + "token_usage": None, + "error": str(error) if error else "unknown error", + } + ) + else: + rows.append( + { + "embedding": None if args.no_vectors else result.embedding, + "dimensions": len(result.embedding), + "request_id": result.request_id, + "token_usage": _token_usage_to_dict(result), + "error": None, + } + ) + _write_jsonl(rows, args.output_jsonl) + except (ImportError, ValueError) as exc: + sys.stderr.write(f"error: {exc}\n") + return 1 return 0 def _handle_transcribe(args: argparse.Namespace) -> int: """Handle the ``transcribe`` subcommand.""" - client = _build_client(_client_config_from_args(args)) try: - paths = _load_transcription_paths(path=args.path, input_jsonl=args.input_jsonl) - rows: list[dict[str, Any]] = [] - for path in tqdm( - paths, - desc="Transcribing", - unit="file", - disable=(len(paths) <= 1), - file=sys.stderr, - ): - result = client.transcribe(path) - rows.append( - { - "text": result.text, - "duration_s": result.duration_s, - "language": result.language, - "request_id": result.request_id, - "error": None, - } + with _managed_client(_client_config_from_args(args)) as client: + paths = _load_transcription_paths( + path=args.path, input_jsonl=args.input_jsonl ) - _write_jsonl(rows, args.output_jsonl) - finally: - client.close() + rows: list[dict[str, Any]] = [] + for path in tqdm( + paths, + desc="Transcribing", + unit="file", + disable=(len(paths) <= 1), + file=sys.stderr, + ): + result = client.transcribe(path) + rows.append( + { + "text": result.text, + "duration_s": result.duration_s, + "language": result.language, + "request_id": result.request_id, + "error": None, + } + ) + _write_jsonl(rows, args.output_jsonl) + except (ImportError, ValueError) as exc: + sys.stderr.write(f"error: {exc}\n") + return 1 return 0 def _handle_bench_generate(args: argparse.Namespace) -> int: """Handle ``bench generate``.""" - config = _client_config_from_args(args) - rows = _load_generation_rows(prompt=args.prompt, input_jsonl=args.input_jsonl) - input_items = [ - row.get("responses_input") or row.get("messages") or row.get("prompt") - for row in rows - ] - if not input_items: - raise ValueError("Generation benchmark requires input rows or --prompt.") - - concurrency_levels, recommend = _resolve_sweep_levels( - single=getattr(args, "concurrency", None), - maximum=getattr(args, "max_concurrency", None), - default=DEFAULT_SWEEP, - sweep_fn=_build_concurrency_sweep, - ) - summary = _run_benchmark( - task_name="generate", - warmup=args.warmup, - requests=args.requests, - duration_s=getattr(args, "duration", None), - 
concurrency_sweep=concurrency_levels, - recommend=recommend, - workload_factory=lambda concurrency: _build_client( - config, - max_parallel_requests=concurrency, - ), - runner=lambda client, batch, *, on_progress=None: client.generate_batch( - batch, - endpoint=config.endpoint, - on_progress=on_progress, - ), - workload=batched_cycle(input_items, max(args.requests, args.warmup)), - ) - _write_generate_summary(summary, args.output_json) + try: + config = _client_config_from_args(args) + rows = _load_generation_rows(prompt=args.prompt, input_jsonl=args.input_jsonl) + input_items = [ + row.get("responses_input") or row.get("messages") or row.get("prompt") + for row in rows + ] + if not input_items: + raise ValueError("Generation benchmark requires input rows or --prompt.") + + concurrency_levels, recommend = _resolve_sweep_levels( + single=getattr(args, "concurrency", None), + maximum=getattr(args, "max_concurrency", None), + default=DEFAULT_SWEEP, + sweep_fn=_build_concurrency_sweep, + ) + summary = _run_benchmark( + task_name="generate", + warmup=args.warmup, + requests=args.requests, + duration_s=getattr(args, "duration", None), + concurrency_sweep=concurrency_levels, + recommend=recommend, + workload_factory=lambda concurrency: _build_client( + config, + max_parallel_requests=concurrency, + ), + runner=lambda client, batch, *, on_progress=None: client.generate_batch( + batch, + endpoint=config.endpoint, + on_progress=on_progress, + ), + workload=batched_cycle(input_items, max(args.requests, args.warmup)), + ) + _write_generate_summary(summary, args.output_json) + except (ImportError, ValueError) as exc: + sys.stderr.write(f"error: {exc}\n") + return 1 return 0 def _handle_bench_embed(args: argparse.Namespace) -> int: """Handle ``bench embed``.""" - config = _client_config_from_args(args) - texts = _load_embed_texts(text=args.text, input_jsonl=args.input_jsonl) - if not texts: - raise ValueError("Embedding benchmark requires input rows or --text.") - - batch_sizes, recommend = _resolve_sweep_levels( - single=getattr(args, "batch_size", None), - maximum=getattr(args, "max_batch_size", None), - default=DEFAULT_EMBED_BATCH_SIZES, - sweep_fn=_build_embed_batch_sweep, - ) - summary = _run_embed_batch_benchmark( - texts=texts, - config=config, - warmup=args.warmup, - requests=args.requests, - batch_size_sweep=batch_sizes, - recommend=recommend, - build_client=_build_client, - ) - _write_embed_summary(summary, args.output_json) + try: + config = _client_config_from_args(args) + texts = _load_embed_texts(text=args.text, input_jsonl=args.input_jsonl) + if not texts: + raise ValueError("Embedding benchmark requires input rows or --text.") + + batch_sizes, recommend = _resolve_sweep_levels( + single=getattr(args, "batch_size", None), + maximum=getattr(args, "max_batch_size", None), + default=DEFAULT_EMBED_BATCH_SIZES, + sweep_fn=_build_embed_batch_sweep, + ) + summary = _run_embed_batch_benchmark( + texts=texts, + config=config, + warmup=args.warmup, + requests=args.requests, + batch_size_sweep=batch_sizes, + recommend=recommend, + build_client=_build_client, + ) + _write_embed_summary(summary, args.output_json) + except (ImportError, ValueError) as exc: + sys.stderr.write(f"error: {exc}\n") + return 1 return 0 diff --git a/src/infermesh/client.py b/src/infermesh/client.py index 31d1ed0..1d353bf 100644 --- a/src/infermesh/client.py +++ b/src/infermesh/client.py @@ -384,7 +384,7 @@ def generate( >>> result.output_text """ - return self._sync_runner.run( + return self._run_sync( self.agenerate( 
input_data, endpoint=endpoint, @@ -454,7 +454,7 @@ def generate_batch( >>> [item.output_text if item else None for item in batch.results] """ - return self._sync_runner.run( + return self._run_sync( self.agenerate_batch( input_batch, endpoint=endpoint, @@ -604,7 +604,7 @@ def embed(self, input_data: str, **kwargs: Any) -> EmbeddingResult: >>> len(result.embedding) """ - return self._sync_runner.run(self.aembed(input_data, **kwargs)) + return self._run_sync(self.aembed(input_data, **kwargs)) def embed_batch( self, @@ -657,7 +657,7 @@ def embed_batch( >>> [len(item.embedding) if item else None for item in batch.results] """ - return self._sync_runner.run( + return self._run_sync( self.aembed_batch( input_batch, micro_batch_size=micro_batch_size, @@ -804,7 +804,7 @@ def transcribe( >>> result.text """ - return self._sync_runner.run( + return self._run_sync( self.atranscribe( input_data, max_transcription_bytes=max_transcription_bytes, @@ -866,7 +866,7 @@ def transcribe_batch( >>> [item.text if item else None for item in batch.results] """ - return self._sync_runner.run( + return self._run_sync( self.atranscribe_batch( input_batch, max_transcription_bytes=max_transcription_bytes, diff --git a/src/infermesh/sync_runner.py b/src/infermesh/sync_runner.py index 0e1a832..5f30b51 100644 --- a/src/infermesh/sync_runner.py +++ b/src/infermesh/sync_runner.py @@ -14,7 +14,13 @@ import contextlib import threading from collections.abc import Coroutine -from concurrent.futures import Future +from concurrent.futures import ( + CancelledError, + Future, +) +from concurrent.futures import ( + TimeoutError as FutureTimeoutError, +) from typing import TypeVar T = TypeVar("T") @@ -96,8 +102,89 @@ def run(self, coroutine: Coroutine[object, object, T]) -> T: >>> runner.run(add(1, 2)) 3 """ - future: Future[T] = asyncio.run_coroutine_threadsafe(coroutine, self._loop) - return future.result() + future: Future[T] = Future() + task_future: Future[asyncio.Task[T]] = Future() + + def start_task() -> None: + try: + task = self._loop.create_task(coroutine) + except BaseException as exc: # noqa: BLE001 + coroutine.close() + if not task_future.done(): + task_future.set_exception(exc) + if not future.done(): + future.set_exception(exc) + return + + if not task_future.done(): + task_future.set_result(task) + + def copy_result(completed_task: asyncio.Task[T]) -> None: + try: + result = completed_task.result() + except asyncio.CancelledError: + if not future.done(): + future.cancel() + except BaseException as exc: # noqa: BLE001 + if not future.done(): + future.set_exception(exc) + else: + if not future.done(): + future.set_result(result) + + task.add_done_callback(copy_result) + + try: + self._loop.call_soon_threadsafe(start_task) + except BaseException: + coroutine.close() + raise + + try: + self._wait_for_future(task_future) + return self._wait_for_future(future) + except KeyboardInterrupt: + # The loop owns the task lifecycle, so cancellation must also + # happen on that thread. We wait for the task to finish unwinding + # before re-raising so callers do not tear down shared resources + # underneath an in-flight coroutine. 
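+            # call_soon_threadsafe callbacks run in FIFO order on the loop,
+            # so this cancel callback executes after start_task even when
+            # the interrupt arrives during the task handoff itself.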
+ self._loop.call_soon_threadsafe( + self._cancel_task_after_handoff, + task_future, + ) + self._wait_for_cancellation_cleanup(future) + raise + + @staticmethod + def _cancel_task_after_handoff(task_future: Future[asyncio.Task[T]]) -> None: + """Cancel the loop-owned task after the handoff future completes.""" + + if not task_future.done(): + return + with contextlib.suppress(BaseException): + task_future.result().cancel() + + def _wait_for_cancellation_cleanup(self, future: Future[T]) -> None: + """Block until the cancelled task finishes unwinding on the loop.""" + + while True: + try: + self._wait_for_future(future, timeout=0.1) + return + except FutureTimeoutError: + continue + except KeyboardInterrupt: + continue + except CancelledError: + return + except Exception: + return + + @staticmethod + def _wait_for_future(future: Future[T], timeout: float | None = None) -> T: + """Wait for a cross-thread future result.""" + + return future.result(timeout=timeout) def close(self) -> None: """Stop the background event loop and join the worker thread. diff --git a/tests/fakes.py b/tests/fakes.py index b8fa962..e6c4c1c 100644 --- a/tests/fakes.py +++ b/tests/fakes.py @@ -7,7 +7,22 @@ from pydantic import BaseModel from infermesh import cli +from infermesh._workflow.checkpoint import ( + _STATUS_NAMES, + _STATUS_VALUES, + _checkpoint_path_for, + _connect_checkpoint_db, + _connect_checkpoint_db_read_only, + _initialize_checkpoint_db, +) +from infermesh._workflow.mapping import _compute_mapping_fingerprint +from infermesh._workflow.models import CheckpointKey +from infermesh._workflow.source import ( + _compute_parse_error_fingerprint, + _compute_record_fingerprint, +) from infermesh.client import LMClient +from infermesh.sync_runner import SyncRunner from infermesh.types import ( BatchResult, EmbeddingResult, @@ -186,18 +201,44 @@ async def acompletion(self, **kwargs: Any) -> dict[str, Any]: class FakeCLIClient: def __init__(self, **kwargs: Any) -> None: self.kwargs = kwargs + self._sync_runner = SyncRunner() self.closed = False self.embed_batch_sizes: list[int] = [] self.embed_micro_batch_sizes: list[int | None] = [] + self.generate_inputs: list[Any] = [] + + def _run_sync(self, coroutine: Any) -> Any: + return self._sync_runner.run(coroutine) def close(self) -> None: + self._sync_runner.close() self.closed = True def generate(self, input_data: Any, **kwargs: Any) -> GenerationResult: + self.generate_inputs.append(input_data) + return GenerationResult( + model_id="test-model", + output_text=f"generated:{input_data}", + request_id="req-1", + ) + + async def agenerate(self, input_data: Any, **kwargs: Any) -> GenerationResult: + self.generate_inputs.append(input_data) return GenerationResult( model_id="test-model", output_text=f"generated:{input_data}", request_id="req-1", + token_usage=TokenUsage( + prompt_tokens=10, + completion_tokens=5, + total_tokens=15, + ), + metrics=RequestMetrics( + queue_wait_s=0.01, + service_time_s=0.02, + end_to_end_s=0.03, + deployment="replica-1", + ), ) def generate_batch( @@ -266,6 +307,122 @@ def transcribe(self, path: str, **kwargs: Any) -> TranscriptionResult: ) +def load_resume_state( + checkpoint_path: Path, +) -> dict[CheckpointKey, dict[str, Any]]: + """Load checkpoint items for test assertions.""" + + if not checkpoint_path.exists(): + return {} + + connection = _connect_checkpoint_db_read_only(checkpoint_path) + try: + state: dict[CheckpointKey, dict[str, Any]] = {} + for row in connection.execute( + """ + SELECT record_fingerprint, occurrence, output_index, 
status, error + FROM items + """ + ): + status = int(row[3]) + state[CheckpointKey(bytes(row[0]), int(row[1]))] = { + "_index": int(row[2]), + "status": _STATUS_NAMES[status], + "error": row[4], + } + return state + finally: + connection.close() + + +def checkpoint_item_for_record( + record: dict[str, Any], + *, + occurrence: int, + index: int, + status: str, + error: str | None = None, +) -> dict[str, Any]: + """Build a checkpoint item payload for a well-formed source record.""" + + return { + "record_fingerprint": _compute_record_fingerprint(record), + "occurrence": occurrence, + "_index": index, + "status": status, + "error": error, + } + + +def checkpoint_item_for_parse_error( + raw_line: str, + *, + occurrence: int, + index: int, + status: str, + error: str | None = None, +) -> dict[str, Any]: + """Build a checkpoint item payload for a malformed source line.""" + + return { + "record_fingerprint": _compute_parse_error_fingerprint(raw_line), + "occurrence": occurrence, + "_index": index, + "status": status, + "error": error, + } + + +def write_checkpoint_db( + output_path: Path, + items: list[dict[str, Any]], + *, + mapping_fingerprint: str | None = None, + checkpoint_dir: str | None = None, +) -> Path: + """Create a checkpoint DB for tests and populate it with explicit items.""" + + checkpoint_path = _checkpoint_path_for( + str(output_path), checkpoint_dir=checkpoint_dir + ) + connection = _connect_checkpoint_db(checkpoint_path) + try: + resolved_mapping_fingerprint = ( + mapping_fingerprint + or _compute_mapping_fingerprint( + mapper_spec=None, + mapper=None, + ) + ) + _initialize_checkpoint_db(connection, resolved_mapping_fingerprint) + connection.executemany( + """ + INSERT INTO items ( + record_fingerprint, + occurrence, + output_index, + status, + error + ) + VALUES (?, ?, ?, ?, ?) 
+ """, + [ + ( + item["record_fingerprint"], + item["occurrence"], + item["_index"], + _STATUS_VALUES[item["status"]], + item["error"], + ) + for item in items + ], + ) + connection.commit() + finally: + connection.close() + return checkpoint_path + + @pytest.fixture def fake_client(monkeypatch: pytest.MonkeyPatch) -> LMClient: monkeypatch.setattr(LMClient, "_create_litellm_module", lambda self: FakeLiteLLM()) diff --git a/tests/test_cli.py b/tests/test_cli.py index b47227c..aac22cd 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -8,7 +8,9 @@ import pytest from infermesh import cli -from tests.fakes import FakeCLIClient +from infermesh._cli_support import _client_config_from_args +from infermesh._workflow.checkpoint import _checkpoint_path_for +from tests.fakes import FakeCLIClient, checkpoint_item_for_record, write_checkpoint_db _BASE_ARGS = [ "generate", @@ -19,10 +21,8 @@ ] -def test_embed_with_text( - capsys: pytest.CaptureFixture[str], - fake_client_builder: list[FakeCLIClient], -) -> None: +@pytest.mark.usefixtures("fake_client_builder") +def test_embed_with_text(capsys: pytest.CaptureFixture[str]) -> None: exit_code = cli.main( [ "embed", @@ -65,46 +65,75 @@ def fake_client_ctor(**kwargs: Any) -> FakeCLIClient: "hello", ] ) - cli._build_client(cli._client_config_from_args(args)) + cli._build_client(_client_config_from_args(args)) assert captured["deployments"] is not None assert len(captured["deployments"]) == 2 +@pytest.mark.usefixtures("fake_client_builder") def test_generate_rejects_malformed_rows( tmp_path: Path, - fake_client_builder: list[FakeCLIClient], + capsys: pytest.CaptureFixture[str], ) -> None: input_path = tmp_path / "bad.jsonl" input_path.write_text(json.dumps({"text": "wrong"}) + "\n", encoding="utf-8") - with pytest.raises(ValueError, match="Generation rows require"): - cli.main( - [ - "generate", - "--model", - "openai/test", - "--api-base", - "http://localhost:8000/v1", - "--input-jsonl", - str(input_path), - ] - ) - - -def test_generate_propagates_provider_failure( + exit_code = cli.main( + [ + "generate", + "--model", + "openai/test", + "--api-base", + "http://localhost:8000/v1", + "--input-jsonl", + str(input_path), + ] + ) + assert exit_code == 0 + out = capsys.readouterr().out + rows = [json.loads(line) for line in out.splitlines() if line.strip()] + assert len(rows) == 1 + assert rows[0]["_index"] == 0 + assert rows[0]["output_text"] is None + assert "Generation rows require" in rows[0]["error"] + + +def test_generate_surfaces_provider_failure_as_error_row( monkeypatch: pytest.MonkeyPatch, + capsys: pytest.CaptureFixture[str], ) -> None: class FailingClient(FakeCLIClient): - def generate_batch(self, input_batch: Any, **kwargs: Any) -> Any: + async def agenerate(self, input_data: Any, **kwargs: Any) -> Any: raise RuntimeError("boom") monkeypatch.setattr(cli, "_build_client", lambda *a, **k: FailingClient()) - with pytest.raises(RuntimeError, match="boom"): - cli.main([*_BASE_ARGS, "--prompt", "hello"]) + exit_code = cli.main([*_BASE_ARGS, "--prompt", "hello"]) + + assert exit_code == 0 + row = json.loads(capsys.readouterr().out.strip()) + assert row["error"] == "boom" + assert row["_index"] == 0 +def test_generate_surfaces_workflow_failure_cleanly( + monkeypatch: pytest.MonkeyPatch, + capsys: pytest.CaptureFixture[str], +) -> None: + monkeypatch.setattr(cli, "_build_client", lambda *a, **k: FakeCLIClient()) + + def blow_up(*args: Any, **kwargs: Any) -> None: + raise RuntimeError("disk full") + + monkeypatch.setattr(cli, "run_generate_workflow", 
blow_up) + + exit_code = cli.main([*_BASE_ARGS, "--prompt", "hello"]) + + assert exit_code == 1 + assert "error: disk full" in capsys.readouterr().err + + +@pytest.mark.usefixtures("fake_client_builder") def test_generate_index_field_always_present( tmp_path: Path, - fake_client_builder: list[FakeCLIClient], ) -> None: input_path = tmp_path / "in.jsonl" output_path = tmp_path / "out.jsonl" @@ -122,29 +151,24 @@ def test_generate_index_field_always_present( ] ) assert exit_code == 0 - rows = [ - json.loads(line) - for line in output_path.read_text(encoding="utf-8").splitlines() - ] + rows = sorted( + ( + json.loads(line) + for line in output_path.read_text(encoding="utf-8").splitlines() + ), + key=lambda row: row["_index"], + ) assert rows[0]["_index"] == 0 assert rows[1]["_index"] == 1 -def test_generate_resume_skips_completed_rows( - tmp_path: Path, - fake_client_builder: list[FakeCLIClient], -) -> None: +@pytest.mark.usefixtures("fake_client_builder") +def test_generate_no_resume_overwrites_existing_file(tmp_path: Path) -> None: input_path = tmp_path / "in.jsonl" output_path = tmp_path / "out.jsonl" - input_path.write_text( - "\n".join(json.dumps({"prompt": prompt}) for prompt in ["a", "b", "c"]) + "\n", - encoding="utf-8", - ) + input_path.write_text(json.dumps({"prompt": "hello"}) + "\n", encoding="utf-8") output_path.write_text( - json.dumps({"_index": 0, "output_text": "cached-a"}) - + "\n" - + json.dumps({"_index": 2, "output_text": "cached-c"}) - + "\n", + json.dumps({"_index": 99, "output_text": "stale"}) + "\n", encoding="utf-8", ) exit_code = cli.main( @@ -154,7 +178,6 @@ def test_generate_resume_skips_completed_rows( str(input_path), "--output-jsonl", str(output_path), - "--resume", ] ) assert exit_code == 0 @@ -162,25 +185,22 @@ def test_generate_resume_skips_completed_rows( json.loads(line) for line in output_path.read_text(encoding="utf-8").splitlines() ] - new_rows = [row for row in rows if row.get("_index") == 1] - assert len(new_rows) == 1 - assert new_rows[0]["output_text"] == "generated:b" + assert len(rows) == 1 + assert rows[0]["_index"] == 0 + assert rows[0]["output_text"] == "generated:hello" -def test_generate_resume_appends_to_existing_file( +@pytest.mark.usefixtures("fake_client_builder") +def test_generate_uses_checkpoint_dir_env_var( tmp_path: Path, - fake_client_builder: list[FakeCLIClient], + monkeypatch: pytest.MonkeyPatch, ) -> None: input_path = tmp_path / "in.jsonl" output_path = tmp_path / "out.jsonl" - input_path.write_text( - json.dumps({"prompt": "a"}) + "\n" + json.dumps({"prompt": "b"}) + "\n", - encoding="utf-8", - ) - output_path.write_text( - json.dumps({"_index": 0, "output_text": "cached-a"}) + "\n", - encoding="utf-8", - ) + checkpoint_dir = tmp_path / "env-checkpoints" + input_path.write_text(json.dumps({"prompt": "hello"}) + "\n", encoding="utf-8") + monkeypatch.setenv("INFERMESH_CHECKPOINT_DIR", str(checkpoint_dir)) + exit_code = cli.main( [ *_BASE_ARGS, @@ -188,29 +208,29 @@ def test_generate_resume_appends_to_existing_file( str(input_path), "--output-jsonl", str(output_path), - "--resume", ] ) + + checkpoint_path = _checkpoint_path_for( + str(output_path), checkpoint_dir=str(checkpoint_dir) + ) assert exit_code == 0 - rows = [ - json.loads(line) - for line in output_path.read_text(encoding="utf-8").splitlines() - ] - assert {row["_index"] for row in rows} == {0, 1} - assert any(row["_index"] == 0 and row["output_text"] == "cached-a" for row in rows) + assert checkpoint_path.exists() + assert not 
_checkpoint_path_for(str(output_path)).exists() -def test_generate_no_resume_overwrites_existing_file( +@pytest.mark.usefixtures("fake_client_builder") +def test_generate_checkpoint_dir_flag_overrides_env_var( tmp_path: Path, - fake_client_builder: list[FakeCLIClient], + monkeypatch: pytest.MonkeyPatch, ) -> None: input_path = tmp_path / "in.jsonl" output_path = tmp_path / "out.jsonl" + env_checkpoint_dir = tmp_path / "env-checkpoints" + arg_checkpoint_dir = tmp_path / "arg-checkpoints" input_path.write_text(json.dumps({"prompt": "hello"}) + "\n", encoding="utf-8") - output_path.write_text( - json.dumps({"_index": 99, "output_text": "stale"}) + "\n", - encoding="utf-8", - ) + monkeypatch.setenv("INFERMESH_CHECKPOINT_DIR", str(env_checkpoint_dir)) + exit_code = cli.main( [ *_BASE_ARGS, @@ -218,41 +238,61 @@ def test_generate_no_resume_overwrites_existing_file( str(input_path), "--output-jsonl", str(output_path), + "--checkpoint-dir", + str(arg_checkpoint_dir), ] ) + + env_checkpoint_path = _checkpoint_path_for( + str(output_path), checkpoint_dir=str(env_checkpoint_dir) + ) + arg_checkpoint_path = _checkpoint_path_for( + str(output_path), checkpoint_dir=str(arg_checkpoint_dir) + ) assert exit_code == 0 - rows = [ - json.loads(line) - for line in output_path.read_text(encoding="utf-8").splitlines() - ] - assert len(rows) == 1 - assert rows[0]["_index"] == 0 - assert rows[0]["output_text"] == "generated:hello" + assert arg_checkpoint_path.exists() + assert not env_checkpoint_path.exists() + + +@pytest.mark.usefixtures("fake_client_builder") +def test_generate_rejects_identical_input_and_output_paths( + tmp_path: Path, + capsys: pytest.CaptureFixture[str], +) -> None: + same_path = tmp_path / "same.jsonl" + same_path.write_text(json.dumps({"prompt": "hello"}) + "\n", encoding="utf-8") + + exit_code = cli.main( + [ + *_BASE_ARGS, + "--input-jsonl", + str(same_path), + "--output-jsonl", + str(same_path), + ] + ) + + assert exit_code == 1 + assert "must be different files" in capsys.readouterr().err +@pytest.mark.usefixtures("fake_client_builder") def test_generate_resume_nothing_to_do( tmp_path: Path, capsys: pytest.CaptureFixture[str], - monkeypatch: pytest.MonkeyPatch, ) -> None: input_path = tmp_path / "in.jsonl" output_path = tmp_path / "out.jsonl" - input_path.write_text(json.dumps({"prompt": "a"}) + "\n", encoding="utf-8") + row = {"prompt": "a"} + input_path.write_text(json.dumps(row) + "\n", encoding="utf-8") output_path.write_text( json.dumps({"_index": 0, "output_text": "done"}) + "\n", encoding="utf-8" ) - - called: list[bool] = [] - - class TrackingClient(FakeCLIClient): - def generate_batch(self, input_batch: Any, **kwargs: Any) -> Any: - called.append(True) - return super().generate_batch(input_batch, **kwargs) - - def patched_build(*a: Any, **kw: Any) -> TrackingClient: - return TrackingClient() - - monkeypatch.setattr(cli, "_build_client", patched_build) + checkpoint_path = write_checkpoint_db( + output_path, + [checkpoint_item_for_record(row, occurrence=0, index=0, status="success")], + ) + assert checkpoint_path == _checkpoint_path_for(str(output_path)) exit_code = cli.main( [ *_BASE_ARGS, @@ -265,20 +305,48 @@ def patched_build(*a: Any, **kw: Any) -> TrackingClient: ) assert exit_code == 0 - assert not called assert "Nothing to do" in capsys.readouterr().err -def test_generate_resume_requires_output_jsonl( - fake_client_builder: list[FakeCLIClient], +@pytest.mark.usefixtures("fake_client_builder") +def test_generate_resume_requires_state_file( + tmp_path: Path, + capsys: 
pytest.CaptureFixture[str], ) -> None: + """CLI should surface strict resume validation failures cleanly.""" + input_path = tmp_path / "in.jsonl" + output_path = tmp_path / "out.jsonl" + input_path.write_text( + json.dumps({"prompt": "a"}) + "\n" + json.dumps({"prompt": "b"}) + "\n", + encoding="utf-8", + ) + output_path.write_text( + json.dumps({"_index": 0, "output_text": "cached-a"}) + "\n", + encoding="utf-8", + ) + exit_code = cli.main( + [ + *_BASE_ARGS, + "--input-jsonl", + str(input_path), + "--output-jsonl", + str(output_path), + "--resume", + ] + ) + assert exit_code == 1 + stderr = capsys.readouterr().err + assert "checkpoint file" in stderr.lower() + + +def test_generate_resume_requires_output_jsonl() -> None: exit_code = cli.main([*_BASE_ARGS, "--prompt", "hello", "--resume"]) assert exit_code == 1 +@pytest.mark.usefixtures("fake_client_builder") def test_env_file_loads_secrets( tmp_path: Path, - fake_client_builder: list[FakeCLIClient], monkeypatch: pytest.MonkeyPatch, ) -> None: env_file = tmp_path / ".env" @@ -324,7 +392,7 @@ def test_deployments_toml_rejects_top_level_api_key(tmp_path: Path) -> None: ] ) with pytest.raises(ValueError, match="plaintext secret") as excinfo: - cli._build_client(cli._client_config_from_args(args)) + cli._build_client(_client_config_from_args(args)) assert "deployments.replica.api_key" in str(excinfo.value) assert "--env-file" in str(excinfo.value) @@ -355,6 +423,55 @@ def test_deployments_toml_rejects_nested_extra_kwargs_api_key(tmp_path: Path) -> ] ) with pytest.raises(ValueError, match="plaintext secret") as excinfo: - cli._build_client(cli._client_config_from_args(args)) + cli._build_client(_client_config_from_args(args)) assert "deployments.replica.extra_kwargs.api_key" in str(excinfo.value) assert "--env-file" in str(excinfo.value) + + +def test_deployments_toml_rejects_missing_deployments_table(tmp_path: Path) -> None: + deployments_toml = tmp_path / "deployments.toml" + deployments_toml.write_text('title = "no deployments table"\n', encoding="utf-8") + parser = cli._build_parser() + args = parser.parse_args( + [ + "generate", + "--model", + "gpt-4o", + "--deployments-toml", + str(deployments_toml), + "--prompt", + "hello", + ] + ) + with pytest.raises(ValueError, match="missing a \\[deployments\\] table"): + cli._build_client(_client_config_from_args(args)) + + +def test_handle_generate_surfaces_build_client_error( + tmp_path: Path, + capsys: pytest.CaptureFixture[str], +) -> None: + """_build_client ValueError must produce a clean error message, not a traceback.""" + deployments_toml = tmp_path / "deployments.toml" + deployments_toml.write_text( + ( + "[deployments.replica]\n" + 'model = "openai/gpt-4o"\n' + 'api_base = "https://api.openai.com/v1"\n' + 'api_key = "plaintext-secret"\n' + ), + encoding="utf-8", + ) + exit_code = cli.main( + [ + "generate", + "--model", + "gpt-4o", + "--deployments-toml", + str(deployments_toml), + "--prompt", + "hello", + ] + ) + assert exit_code == 1 + assert "plaintext secret" in capsys.readouterr().err diff --git a/tests/test_sync_runner.py b/tests/test_sync_runner.py new file mode 100644 index 0000000..ebb2031 --- /dev/null +++ b/tests/test_sync_runner.py @@ -0,0 +1,111 @@ +from __future__ import annotations + +import asyncio +import threading +from concurrent.futures import Future +from typing import Any + +import pytest + +from infermesh.sync_runner import SyncRunner + + +def test_sync_runner_waits_for_cancel_cleanup_on_keyboard_interrupt( + monkeypatch: pytest.MonkeyPatch, +) -> None: + runner = 
SyncRunner() + started = threading.Event() + cleanup_started = threading.Event() + allow_cleanup_finish = threading.Event() + cleaned = threading.Event() + returned = threading.Event() + result: dict[str, BaseException | None] = {"error": None} + + async def blocked() -> None: + started.set() + try: + while True: + await asyncio.sleep(0.01) + finally: + cleanup_started.set() + await asyncio.to_thread(allow_cleanup_finish.wait) + cleaned.set() + + original_wait_for_future = runner._wait_for_future + blocking_waits = 0 + + def interrupting_wait_for_future( + future: Future[Any], timeout: float | None = None + ) -> Any: + nonlocal blocking_waits + if timeout is None: + blocking_waits += 1 + if blocking_waits == 2 and timeout is None: + started.wait(timeout=1.0) + blocking_waits += 1 + raise KeyboardInterrupt + return original_wait_for_future(future, timeout) + + monkeypatch.setattr(runner, "_wait_for_future", interrupting_wait_for_future) + + def run_and_capture() -> None: + try: + runner.run(blocked()) + except BaseException as exc: # noqa: BLE001 + result["error"] = exc + finally: + returned.set() + + try: + worker = threading.Thread(target=run_and_capture, daemon=True) + worker.start() + assert cleanup_started.wait(timeout=1.0) + assert not returned.wait(timeout=0.05) + allow_cleanup_finish.set() + worker.join(timeout=1.0) + assert not worker.is_alive() + assert isinstance(result["error"], KeyboardInterrupt) + assert cleaned.is_set() + finally: + allow_cleanup_finish.set() + runner.close() + + +def test_sync_runner_waits_for_cancel_cleanup_when_interrupt_during_handoff( + monkeypatch: pytest.MonkeyPatch, +) -> None: + runner = SyncRunner() + started = threading.Event() + cleaned = threading.Event() + + async def blocked() -> None: + started.set() + try: + while True: + await asyncio.sleep(0.01) + finally: + cleaned.set() + + original_wait_for_future = runner._wait_for_future + blocking_waits = 0 + + def interrupting_wait_for_future( + future: Future[Any], timeout: float | None = None + ) -> Any: + nonlocal blocking_waits + if timeout is None: + blocking_waits += 1 + if blocking_waits == 1 and timeout is None: + started.wait(timeout=1.0) + blocking_waits += 1 + raise KeyboardInterrupt + return original_wait_for_future(future, timeout) + + monkeypatch.setattr(runner, "_wait_for_future", interrupting_wait_for_future) + + try: + with pytest.raises(KeyboardInterrupt): + runner.run(blocked()) + assert cleaned.is_set() + finally: + runner.close() diff --git a/tests/test_workflow.py b/tests/test_workflow.py new file mode 100644 index 0000000..f125489 --- /dev/null +++ b/tests/test_workflow.py @@ -0,0 +1,1324 @@ +"""Unit tests for the workflow engine (infermesh._workflow).""" + +from __future__ import annotations + +import asyncio +import io +import json +import sqlite3 +import sys +import threading +import types +from dataclasses import dataclass +from datetime import datetime +from pathlib import Path +from typing import Any + +import pytest + +import infermesh._workflow.checkpoint as checkpoint_module +from infermesh._workflow import run_generate_workflow +from infermesh._workflow.checkpoint import ( + _checkpoint_path_for, + _connect_checkpoint_db, + _load_run_metadata, +) +from infermesh._workflow.mapping import _compute_mapping_fingerprint +from infermesh._workflow.resume import ResumePlanner +from infermesh.sync_runner import SyncRunner +from tests.fakes import ( + checkpoint_item_for_parse_error, + checkpoint_item_for_record, + load_resume_state, + write_checkpoint_db, +) + +# 
--------------------------------------------------------------------------- +# Fake client +# --------------------------------------------------------------------------- + + +@dataclass +class _FakeResult: + output_text: str + request_id: str = "req-0" + finish_reason: str = "stop" + token_usage: None = None + + +class _FakeClient: + """Async fake that records workflow admissions one item at a time.""" + + def __init__(self) -> None: + self._sync_runner = SyncRunner() + self.inputs: list[Any] = [] + self.active = 0 + self.peak_active = 0 + + def _run_sync(self, coroutine: Any) -> Any: + return self._sync_runner.run(coroutine) + + async def agenerate(self, input_data: Any, **kwargs: Any) -> _FakeResult: + self.inputs.append(input_data) + self.active += 1 + self.peak_active = max(self.peak_active, self.active) + try: + await asyncio.sleep(0) + return _FakeResult(output_text=f"out:{input_data}") + finally: + self.active -= 1 + + def close(self) -> None: + self._sync_runner.close() + + +class _RollingWindowFakeClient(_FakeClient): + """Fake client that exposes whether the workflow refilled before a slow item ended.""" + + def __init__(self) -> None: + super().__init__() + self.tail_started_before_slow_finished = False + self._allow_slow_finish: asyncio.Event | None = None + + async def agenerate(self, input_data: Any, **kwargs: Any) -> _FakeResult: + self.inputs.append(input_data) + self.active += 1 + self.peak_active = max(self.peak_active, self.active) + try: + if self._allow_slow_finish is None: + self._allow_slow_finish = asyncio.Event() + if input_data == "slow": + await asyncio.wait_for(self._allow_slow_finish.wait(), timeout=1.0) + else: + await asyncio.sleep(0) + if input_data == "tail-1": + self.tail_started_before_slow_finished = True + self._allow_slow_finish.set() + return _FakeResult(output_text=f"out:{input_data}") + finally: + self.active -= 1 + + +class _SelectiveFailingFakeClient(_FakeClient): + """Fake client that fails selected prompts while letting siblings continue.""" + + def __init__(self, *, failing_inputs: set[Any]) -> None: + super().__init__() + self._failing_inputs = set(failing_inputs) + + async def agenerate(self, input_data: Any, **kwargs: Any) -> _FakeResult: + self.inputs.append(input_data) + self.active += 1 + self.peak_active = max(self.peak_active, self.active) + try: + await asyncio.sleep(0) + if input_data in self._failing_inputs: + raise RuntimeError(f"boom:{input_data}") + return _FakeResult(output_text=f"out:{input_data}") + finally: + self.active -= 1 + + +class _MapperSignalFakeClient(_FakeClient): + """Fake client that signals when a specific mapped input has completed.""" + + def __init__(self, *, release_event: threading.Event) -> None: + super().__init__() + self._release_event = release_event + + async def agenerate(self, input_data: Any, **kwargs: Any) -> _FakeResult: + self.inputs.append(input_data) + self.active += 1 + self.peak_active = max(self.peak_active, self.active) + try: + await asyncio.sleep(0) + if input_data == "first": + self._release_event.set() + return _FakeResult(output_text=f"out:{input_data}") + finally: + self.active -= 1 + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _write_input(tmp_path: Path, rows: list[dict[str, Any]]) -> Path: + p = tmp_path / "input.jsonl" + p.write_text("\n".join(json.dumps(r) for r in rows) + "\n", encoding="utf-8") + return p + + +def _read_output(path: Path) -> 
list[dict[str, Any]]: + return [ + json.loads(line) + for line in path.read_text(encoding="utf-8").splitlines() + if line.strip() + ] + + +def _run( + client: _FakeClient, + *, + input_path: Path | None = None, + output_path: Path | None = None, + checkpoint_dir: str | None = None, + prompt: str | None = None, + mapper_spec: str | None = None, + resume: bool = False, + window_size: int = 128, + parse_json: bool = False, + on_status: Any = None, +) -> None: + try: + run_generate_workflow( + client, # type: ignore[arg-type] + prompt=prompt, + input_jsonl=str(input_path) if input_path is not None else None, + output_jsonl=str(output_path) if output_path is not None else None, + checkpoint_dir=checkpoint_dir, + mapper_spec=mapper_spec, + resume=resume, + endpoint="chat_completion", + window_size=window_size, + parse_json=parse_json, + on_status=on_status, + ) + finally: + client.close() + + +def _load_mapping_fingerprint(checkpoint_path: Path) -> str: + connection = _connect_checkpoint_db(checkpoint_path) + try: + _, mapping_fingerprint = _load_run_metadata(connection) + return mapping_fingerprint + finally: + connection.close() + + +def _load_checkpoint_journal_mode(checkpoint_path: Path) -> str: + connection = sqlite3.connect(checkpoint_path) + try: + row = connection.execute("PRAGMA journal_mode").fetchone() + return str(row[0]).lower() if row is not None else "" + finally: + connection.close() + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +def test_file_backed_generate_streams_input( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + rows = [{"prompt": f"p{i}"} for i in range(7)] + input_path = _write_input(tmp_path, rows) + output_path = tmp_path / "out.jsonl" + client = _FakeClient() + monkeypatch.setattr(ResumePlanner, "_temp_dir", lambda: tmp_path) + + _run(client, input_path=input_path, output_path=output_path, window_size=3) + + assert client.inputs == [f"p{i}" for i in range(7)] + assert client.peak_active <= 3 + assert not list(tmp_path.glob(".infermesh-resume-plan.*.sqlite")) + + out_rows = _read_output(output_path) + assert len(out_rows) == 7 + + +def test_default_checkpoint_uses_portable_rollback_journaling(tmp_path: Path) -> None: + rows = [{"prompt": "a"}, {"prompt": "b"}] + input_path = _write_input(tmp_path, rows) + output_path = tmp_path / "out.jsonl" + checkpoint_path = _checkpoint_path_for(str(output_path)) + client = _FakeClient() + + _run(client, input_path=input_path, output_path=output_path) + + assert checkpoint_path.exists() + assert _load_checkpoint_journal_mode(checkpoint_path) in {"persist", "delete"} + assert not Path(f"{checkpoint_path}-wal").exists() + assert not Path(f"{checkpoint_path}-shm").exists() + + +def test_checkpoint_override_disambiguates_same_output_basename(tmp_path: Path) -> None: + checkpoint_dir = tmp_path / "checkpoints" + output_a = tmp_path / "run-a" / "out.jsonl" + output_b = tmp_path / "run-b" / "out.jsonl" + + checkpoint_a = _checkpoint_path_for( + str(output_a), checkpoint_dir=str(checkpoint_dir) + ) + checkpoint_b = _checkpoint_path_for( + str(output_b), checkpoint_dir=str(checkpoint_dir) + ) + + assert checkpoint_a.parent == checkpoint_dir + assert checkpoint_b.parent == checkpoint_dir + assert checkpoint_a != checkpoint_b + assert checkpoint_a.name.startswith("out.") + assert checkpoint_b.name.startswith("out.") + + +def test_generate_workflow_supports_running_event_loop(tmp_path: Path) -> 
None: + rows = [{"prompt": "hello"}] + input_path = _write_input(tmp_path, rows) + output_path = tmp_path / "out.jsonl" + client = _FakeClient() + + async def invoke_workflow() -> None: + _run(client, input_path=input_path, output_path=output_path) + + asyncio.run(invoke_workflow()) + + assert _read_output(output_path)[0]["output_text"] == "out:hello" + + +def test_file_backed_generate_refills_window_as_items_finish(tmp_path: Path) -> None: + rows = [ + {"prompt": "slow"}, + {"prompt": "fast-1"}, + {"prompt": "fast-2"}, + {"prompt": "tail-1"}, + {"prompt": "tail-2"}, + ] + input_path = _write_input(tmp_path, rows) + output_path = tmp_path / "out.jsonl" + checkpoint_path = _checkpoint_path_for(str(output_path)) + client = _RollingWindowFakeClient() + + _run(client, input_path=input_path, output_path=output_path, window_size=3) + + assert client.peak_active <= 3 + assert client.tail_started_before_slow_finished + + out_rows = _read_output(output_path) + assert len(out_rows) == 5 + + state = load_resume_state(checkpoint_path) + assert len(state) == 5 + assert {row["status"] for row in state.values()} == {"success"} + + +def test_resume_skips_settled_rows(tmp_path: Path) -> None: + rows = [{"prompt": "a"}, {"prompt": "b"}, {"prompt": "c"}] + input_path = _write_input(tmp_path, rows) + output_path = tmp_path / "out.jsonl" + checkpoint_path = _checkpoint_path_for(str(output_path)) + + output_path.write_text( + json.dumps({"_index": 0, "output_text": "cached-a"}) + "\n", + encoding="utf-8", + ) + write_checkpoint_db( + output_path, + [ + checkpoint_item_for_record( + rows[0], occurrence=0, index=0, status="success" + ), + checkpoint_item_for_record( + rows[1], occurrence=0, index=1, status="pending" + ), + checkpoint_item_for_record( + rows[2], occurrence=0, index=2, status="pending" + ), + ], + ) + + client = _FakeClient() + _run(client, input_path=input_path, output_path=output_path, resume=True) + + assert client.inputs == ["b", "c"] + + rows_by_index = {row["_index"]: row for row in _read_output(output_path)} + assert rows_by_index[1]["output_text"] == "out:b" + assert rows_by_index[2]["output_text"] == "out:c" + + state = load_resume_state(checkpoint_path) + assert {row["status"] for row in state.values()} == {"success"} + + +def test_resume_reports_planner_status_and_cleans_temp_db( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + rows = [{"prompt": "a"}, {"prompt": "b"}] + input_path = _write_input(tmp_path, rows) + output_path = tmp_path / "out.jsonl" + output_path.write_text( + json.dumps({"_index": 0, "output_text": "cached-a"}) + "\n", + encoding="utf-8", + ) + write_checkpoint_db( + output_path, + [ + checkpoint_item_for_record( + rows[0], occurrence=0, index=0, status="success" + ), + checkpoint_item_for_record( + rows[1], occurrence=0, index=1, status="pending" + ), + ], + ) + monkeypatch.setattr(ResumePlanner, "_temp_dir", lambda: tmp_path) + statuses: list[str] = [] + + client = _FakeClient() + _run( + client, + input_path=input_path, + output_path=output_path, + resume=True, + on_status=statuses.append, + ) + + assert "Resume: validating checkpoint file..." in statuses + assert "Resume: validating output artifact..." in statuses + assert "Resume: building resume plan..." in statuses + assert "Resume: locating pending rows..." in statuses + assert "Opening output and checkpoint files..." 
in statuses + assert not list(tmp_path.glob(".infermesh-resume-plan.*.sqlite")) + + +def test_resume_reuses_checkpoint_dir_override(tmp_path: Path) -> None: + rows = [{"prompt": "a"}, {"prompt": "b"}, {"prompt": "c"}] + input_path = _write_input(tmp_path, rows) + output_path = tmp_path / "out.jsonl" + checkpoint_dir = tmp_path / "scratch-checkpoints" + checkpoint_path = _checkpoint_path_for( + str(output_path), checkpoint_dir=str(checkpoint_dir) + ) + + output_path.write_text( + json.dumps({"_index": 0, "output_text": "cached-a"}) + "\n", + encoding="utf-8", + ) + write_checkpoint_db( + output_path, + [ + checkpoint_item_for_record( + rows[0], occurrence=0, index=0, status="success" + ), + checkpoint_item_for_record( + rows[1], occurrence=0, index=1, status="pending" + ), + checkpoint_item_for_record( + rows[2], occurrence=0, index=2, status="pending" + ), + ], + checkpoint_dir=str(checkpoint_dir), + ) + + client = _FakeClient() + _run( + client, + input_path=input_path, + output_path=output_path, + checkpoint_dir=str(checkpoint_dir), + resume=True, + ) + + rows_by_index = {row["_index"]: row for row in _read_output(output_path)} + assert rows_by_index[0]["output_text"] == "cached-a" + assert rows_by_index[1]["output_text"] == "out:b" + assert rows_by_index[2]["output_text"] == "out:c" + + state = load_resume_state(checkpoint_path) + assert {row["status"] for row in state.values()} == {"success"} + + +def test_resume_requires_same_checkpoint_dir_override(tmp_path: Path) -> None: + row = {"prompt": "a"} + input_path = _write_input(tmp_path, [row]) + output_path = tmp_path / "out.jsonl" + checkpoint_dir = tmp_path / "scratch-checkpoints" + + output_path.write_text( + json.dumps({"_index": 0, "output_text": "cached-a"}) + "\n", + encoding="utf-8", + ) + write_checkpoint_db( + output_path, + [checkpoint_item_for_record(row, occurrence=0, index=0, status="success")], + checkpoint_dir=str(checkpoint_dir), + ) + + client = _FakeClient() + with pytest.raises(ValueError, match="requires checkpoint file"): + _run(client, input_path=input_path, output_path=output_path, resume=True) + + +def test_resume_requires_file_backed_run() -> None: + client = _FakeClient() + + with pytest.raises(ValueError, match="--resume requires --output-jsonl"): + _run(client, prompt="hello", resume=True) + + +def test_builtin_row_conventions(tmp_path: Path) -> None: + rows: list[dict[str, Any]] = [ + {"prompt": "hello"}, + {"messages": [{"role": "user", "content": "hi"}]}, + {"responses_input": [{"role": "user", "content": "hey"}]}, + {"prompt": ""}, + {"messages": []}, + {"responses_input": []}, + ] + input_path = _write_input(tmp_path, rows) + output_path = tmp_path / "out.jsonl" + client = _FakeClient() + + _run(client, input_path=input_path, output_path=output_path) + + assert client.inputs == [ + "hello", + [{"role": "user", "content": "hi"}], + [{"role": "user", "content": "hey"}], + "", + [], + [], + ] + out_rows = _read_output(output_path) + assert len(out_rows) == 6 + assert all(r["error"] is None for r in out_rows) + + +def test_non_object_source_rows_become_error_rows_without_aborting_siblings( + tmp_path: Path, +) -> None: + input_path = tmp_path / "input.jsonl" + input_path.write_text( + "\n".join( + [ + json.dumps({"prompt": "good"}), + json.dumps([]), + json.dumps({"prompt": "later"}), + ] + ) + + "\n", + encoding="utf-8", + ) + output_path = tmp_path / "out.jsonl" + client = _FakeClient() + + _run(client, input_path=input_path, output_path=output_path) + + assert client.inputs == ["good", "later"] + 
rows_by_index = {row["_index"]: row for row in _read_output(output_path)} + assert len(rows_by_index) == 3 + assert rows_by_index[0]["error"] is None + assert rows_by_index[1]["output_text"] is None + assert "JSON objects" in rows_by_index[1]["error"] + assert rows_by_index[2]["error"] is None + + +def test_mapper_import_and_metadata(tmp_path: Path) -> None: + fake_mod = types.ModuleType("_test_wf_mapper_mod") + + def my_mapper(record: dict[str, Any]) -> dict[str, Any]: + return { + "input": record["prompt"].upper(), + "metadata": {"original": record["prompt"]}, + } + + fake_mod.my_mapper = my_mapper # type: ignore[attr-defined] + sys.modules["_test_wf_mapper_mod"] = fake_mod + try: + rows = [{"prompt": "hello"}] + input_path = _write_input(tmp_path, rows) + output_path = tmp_path / "out.jsonl" + client = _FakeClient() + + _run( + client, + input_path=input_path, + output_path=output_path, + mapper_spec="_test_wf_mapper_mod:my_mapper", + ) + + assert client.inputs == ["HELLO"] + out_rows = _read_output(output_path) + assert out_rows[0]["metadata"] == {"original": "hello"} + assert out_rows[0]["error"] is None + finally: + del sys.modules["_test_wf_mapper_mod"] + + +def test_mapper_waiting_on_generation_progress_does_not_block_loop( + tmp_path: Path, +) -> None: + fake_mod = types.ModuleType("_test_wf_waiting_mapper_mod") + release_event = threading.Event() + client = _MapperSignalFakeClient(release_event=release_event) + + def waiting_mapper(record: dict[str, Any]) -> dict[str, Any]: + if record["prompt"] == "second" and not release_event.wait(timeout=1.0): + raise RuntimeError("mapper never observed first completion") + return {"input": record["prompt"]} + + fake_mod.waiting_mapper = waiting_mapper # type: ignore[attr-defined] + sys.modules["_test_wf_waiting_mapper_mod"] = fake_mod + try: + rows = [{"prompt": "first"}, {"prompt": "second"}] + input_path = _write_input(tmp_path, rows) + output_path = tmp_path / "out.jsonl" + + _run( + client, + input_path=input_path, + output_path=output_path, + mapper_spec="_test_wf_waiting_mapper_mod:waiting_mapper", + window_size=2, + ) + + rows_by_index = {row["_index"]: row for row in _read_output(output_path)} + assert rows_by_index[0]["output_text"] == "out:first" + assert rows_by_index[0]["error"] is None + assert rows_by_index[1]["output_text"] == "out:second" + assert rows_by_index[1]["error"] is None + finally: + del sys.modules["_test_wf_waiting_mapper_mod"] + + +def test_mapper_ignores_extra_keys(tmp_path: Path) -> None: + fake_mod = types.ModuleType("_test_wf_extra_mod") + + def my_mapper(record: dict[str, Any]) -> dict[str, Any]: + return {"input": record["prompt"], "metadata": None, "extra_ignored_key": 42} + + fake_mod.my_mapper = my_mapper # type: ignore[attr-defined] + sys.modules["_test_wf_extra_mod"] = fake_mod + try: + rows = [{"prompt": "hello"}] + input_path = _write_input(tmp_path, rows) + output_path = tmp_path / "out.jsonl" + client = _FakeClient() + + _run( + client, + input_path=input_path, + output_path=output_path, + mapper_spec="_test_wf_extra_mod:my_mapper", + ) + + out_rows = _read_output(output_path) + assert len(out_rows) == 1 + assert out_rows[0]["error"] is None + finally: + del sys.modules["_test_wf_extra_mod"] + + +def test_mapper_validation_failure_becomes_error_row(tmp_path: Path) -> None: + fake_mod = types.ModuleType("_test_wf_bad_mapper_mod") + + def bad_mapper(record: dict[str, Any]) -> dict[str, Any]: + return {"no_input_key": "oops"} # missing required "input" + + fake_mod.bad_mapper = bad_mapper # type: 
ignore[attr-defined] + sys.modules["_test_wf_bad_mapper_mod"] = fake_mod + try: + rows = [{"prompt": "first"}, {"prompt": "second"}] + input_path = _write_input(tmp_path, rows) + output_path = tmp_path / "out.jsonl" + client = _FakeClient() + + _run( + client, + input_path=input_path, + output_path=output_path, + mapper_spec="_test_wf_bad_mapper_mod:bad_mapper", + ) + + # Both rows become error rows; no generation request is started. + assert client.inputs == [] + out_rows = _read_output(output_path) + assert len(out_rows) == 2 + assert all(r["output_text"] is None for r in out_rows) + assert all(r["error"] is not None for r in out_rows) + finally: + del sys.modules["_test_wf_bad_mapper_mod"] + + +def test_builtin_mapping_failure_becomes_error_row(tmp_path: Path) -> None: + rows = [{"text": "wrong_field"}] + input_path = _write_input(tmp_path, rows) + output_path = tmp_path / "out.jsonl" + client = _FakeClient() + + _run(client, input_path=input_path, output_path=output_path) + + assert client.inputs == [] + out_rows = _read_output(output_path) + assert len(out_rows) == 1 + assert out_rows[0]["output_text"] is None + assert out_rows[0]["error"] is not None + assert ( + "require" in out_rows[0]["error"].lower() + or "prompt" in out_rows[0]["error"].lower() + ) + + +def test_malformed_json_line_becomes_error_row(tmp_path: Path) -> None: + input_path = tmp_path / "input.jsonl" + input_path.write_text( + "not valid json\n" + json.dumps({"prompt": "good"}) + "\n", + encoding="utf-8", + ) + output_path = tmp_path / "out.jsonl" + client = _FakeClient() + + _run(client, input_path=input_path, output_path=output_path) + + out_rows = _read_output(output_path) + assert len(out_rows) == 2 + + error_rows = [r for r in out_rows if r["error"] is not None] + success_rows = [r for r in out_rows if r["error"] is None] + assert len(error_rows) == 1 + assert len(success_rows) == 1 + assert success_rows[0]["output_text"] is not None + + +def test_provider_failure_becomes_error_row_and_settles_checkpoint( + tmp_path: Path, +) -> None: + rows = [{"prompt": "good"}, {"prompt": "bad"}, {"prompt": "also-good"}] + input_path = _write_input(tmp_path, rows) + output_path = tmp_path / "out.jsonl" + checkpoint_path = _checkpoint_path_for(str(output_path)) + client = _SelectiveFailingFakeClient(failing_inputs={"bad"}) + + _run(client, input_path=input_path, output_path=output_path, window_size=2) + + rows_by_index = {row["_index"]: row for row in _read_output(output_path)} + assert rows_by_index[0]["output_text"] == "out:good" + assert rows_by_index[0]["error"] is None + assert rows_by_index[1]["output_text"] is None + assert rows_by_index[1]["error"] == "boom:bad" + assert rows_by_index[2]["output_text"] == "out:also-good" + assert rows_by_index[2]["error"] is None + + state = load_resume_state(checkpoint_path) + assert {row["status"] for row in state.values()} == {"success", "error"} + failed_item = next(row for row in state.values() if row["_index"] == 1) + assert failed_item["status"] == "error" + assert failed_item["error"] == "boom:bad" + + +def test_fresh_run_bootstrap_failure_preserves_existing_artifacts( + tmp_path: Path, +) -> None: + input_path = tmp_path / "missing.jsonl" + output_path = tmp_path / "out.jsonl" + checkpoint_path = _checkpoint_path_for(str(output_path)) + old_output = json.dumps({"_index": 99, "output_text": "keep-me"}) + "\n" + output_path.write_text(old_output, encoding="utf-8") + write_checkpoint_db( + output_path, + [ + checkpoint_item_for_record( + {"prompt": "old"}, + occurrence=0, + 
index=99, + status="success", + ) + ], + ) + old_checkpoint = checkpoint_path.read_bytes() + client = _FakeClient() + + with pytest.raises(FileNotFoundError): + _run(client, input_path=input_path, output_path=output_path) + + assert output_path.read_text(encoding="utf-8") == old_output + assert checkpoint_path.read_bytes() == old_checkpoint + assert client.inputs == [] + + +def test_resume_skips_items_from_checkpoint_file(tmp_path: Path) -> None: + rows = [{"prompt": "a"}, {"prompt": "b"}, {"prompt": "c"}] + input_path = _write_input(tmp_path, rows) + output_path = tmp_path / "out.jsonl" + output_path.write_text( + "\n".join( + json.dumps( + {"_index": index, "output_text": f"cached-{rows[index]['prompt']}"} + ) + for index in (0, 2) + ) + + "\n", + encoding="utf-8", + ) + write_checkpoint_db( + output_path, + [ + checkpoint_item_for_record( + row, + occurrence=0, + index=source_index, + status=status, + ) + for source_index, row, status in [ + (0, rows[0], "success"), + (1, rows[1], "pending"), + (2, rows[2], "success"), + ] + ], + ) + + client = _FakeClient() + _run(client, input_path=input_path, output_path=output_path, resume=True) + + assert client.inputs == ["b"] + + +def test_resume_requires_state_file(tmp_path: Path) -> None: + rows = [{"prompt": "a"}] + input_path = _write_input(tmp_path, rows) + output_path = tmp_path / "out.jsonl" + output_path.write_text( + json.dumps({"_index": 0, "output_text": "cached-a"}) + "\n", + encoding="utf-8", + ) + client = _FakeClient() + + with pytest.raises(ValueError, match="checkpoint file"): + _run(client, input_path=input_path, output_path=output_path, resume=True) + + +def test_resume_rejects_missing_output_file(tmp_path: Path) -> None: + rows = [{"prompt": "a"}] + input_path = _write_input(tmp_path, rows) + output_path = tmp_path / "out.jsonl" + write_checkpoint_db( + output_path, + [checkpoint_item_for_record(rows[0], occurrence=0, index=0, status="success")], + ) + client = _FakeClient() + + with pytest.raises(ValueError, match="requires output file"): + _run(client, input_path=input_path, output_path=output_path, resume=True) + + +def test_resume_rejects_truncated_output_file(tmp_path: Path) -> None: + rows = [{"prompt": "a"}, {"prompt": "b"}] + input_path = _write_input(tmp_path, rows) + output_path = tmp_path / "out.jsonl" + output_path.write_text( + json.dumps({"_index": 0, "output_text": "cached-a"}) + "\n", + encoding="utf-8", + ) + write_checkpoint_db( + output_path, + [ + checkpoint_item_for_record( + row, + occurrence=0, + index=index, + status="success", + ) + for index, row in enumerate(rows) + ], + ) + client = _FakeClient() + + with pytest.raises(ValueError, match="missing settled checkpoint rows"): + _run(client, input_path=input_path, output_path=output_path, resume=True) + + +def test_resume_output_bitmap_handles_sparse_high_indexes(tmp_path: Path) -> None: + rows = [{"prompt": "a"}, {"prompt": "b"}] + input_path = _write_input(tmp_path, rows) + output_path = tmp_path / "out.jsonl" + output_path.write_text( + "\n".join( + [ + json.dumps({"_index": 0, "output_text": "cached-a"}), + json.dumps({"_index": 1000, "output_text": "cached-b"}), + ] + ) + + "\n", + encoding="utf-8", + ) + write_checkpoint_db( + output_path, + [ + checkpoint_item_for_record( + rows[0], occurrence=0, index=0, status="success" + ), + checkpoint_item_for_record( + rows[1], occurrence=0, index=1000, status="success" + ), + ], + ) + + client = _FakeClient() + _run(client, input_path=input_path, output_path=output_path, resume=True) + + assert 
client.inputs == [] + + +def test_completed_run_updates_checkpoint_rows_in_place(tmp_path: Path) -> None: + rows = [{"prompt": "a"}, {"prompt": "b"}, {"prompt": "c"}] + input_path = _write_input(tmp_path, rows) + output_path = tmp_path / "out.jsonl" + checkpoint_path = _checkpoint_path_for(str(output_path)) + client = _FakeClient() + + _run(client, input_path=input_path, output_path=output_path) + + state = load_resume_state(checkpoint_path) + assert len(state) == 3 + assert {row["status"] for row in state.values()} == {"success"} + assert _load_mapping_fingerprint(checkpoint_path) == _compute_mapping_fingerprint( + mapper_spec=None, + mapper=None, + ) + + +def test_resume_tracks_duplicate_records_independently(tmp_path: Path) -> None: + rows = [{"prompt": "dup"}, {"prompt": "dup"}, {"prompt": "tail"}] + input_path = _write_input(tmp_path, rows) + output_path = tmp_path / "out.jsonl" + output_path.write_text( + json.dumps({"_index": 0, "output_text": "cached-dup"}) + "\n", + encoding="utf-8", + ) + write_checkpoint_db( + output_path, + [ + checkpoint_item_for_record( + rows[0], + occurrence=0, + index=0, + status="success", + ), + checkpoint_item_for_record( + rows[1], + occurrence=1, + index=1, + status="pending", + ), + checkpoint_item_for_record( + rows[2], + occurrence=0, + index=2, + status="pending", + ), + ], + ) + + client = _FakeClient() + _run(client, input_path=input_path, output_path=output_path, resume=True) + + assert client.inputs == ["dup", "tail"] + rows_by_index = {row["_index"]: row for row in _read_output(output_path)} + assert rows_by_index[1]["output_text"] == "out:dup" + assert rows_by_index[2]["output_text"] == "out:tail" + + +def test_resume_rejects_changed_mapper_implementation(tmp_path: Path) -> None: + rows = [{"prompt": "first"}, {"prompt": "second"}] + input_path = _write_input(tmp_path, rows) + output_path = tmp_path / "out.jsonl" + module_name = "_test_wf_changed_mapper_mod" + module_path = tmp_path / f"{module_name}.py" + + def write_mapper_module(prefix: str) -> None: + module_path.write_text( + "\n".join( + [ + f'HELPER_PREFIX = "{prefix}"', + "", + "def helper(text):", + ' return f"{HELPER_PREFIX}:{text}"', + "", + "def my_mapper(record):", + ' return {"input": helper(record["prompt"])}', + "", + ] + ), + encoding="utf-8", + ) + + sys.path.insert(0, str(tmp_path)) + try: + write_mapper_module("v1") + sys.modules.pop(module_name, None) + initial_client = _FakeClient() + _run( + initial_client, + input_path=input_path, + output_path=output_path, + mapper_spec=f"{module_name}:my_mapper", + ) + + checkpoint_path = _checkpoint_path_for(str(output_path)) + assert _load_mapping_fingerprint(checkpoint_path) + + write_mapper_module("v2") + sys.modules.pop(module_name, None) + + resume_client = _FakeClient() + with pytest.raises(ValueError, match="Resume mapping does not match"): + _run( + resume_client, + input_path=input_path, + output_path=output_path, + mapper_spec=f"{module_name}:my_mapper", + resume=True, + ) + finally: + sys.modules.pop(module_name, None) + sys.path.remove(str(tmp_path)) + + +def test_resume_allows_empty_custom_mapper_checkpoint(tmp_path: Path) -> None: + input_path = tmp_path / "input.jsonl" + input_path.write_text("", encoding="utf-8") + output_path = tmp_path / "out.jsonl" + + fake_mod = types.ModuleType("_test_wf_empty_mapper_mod") + + def my_mapper(record: dict[str, Any]) -> dict[str, Any]: + return {"input": record["prompt"].upper()} + + my_mapper.__module__ = "_test_wf_empty_mapper_mod" + fake_mod.my_mapper = my_mapper # type: 
ignore[attr-defined] + sys.modules["_test_wf_empty_mapper_mod"] = fake_mod + try: + initial_client = _FakeClient() + _run( + initial_client, + input_path=input_path, + output_path=output_path, + mapper_spec="_test_wf_empty_mapper_mod:my_mapper", + ) + + resume_client = _FakeClient() + _run( + resume_client, + input_path=input_path, + output_path=output_path, + mapper_spec="_test_wf_empty_mapper_mod:my_mapper", + resume=True, + ) + + checkpoint_path = _checkpoint_path_for(str(output_path)) + assert output_path.read_text(encoding="utf-8") == "" + assert load_resume_state(checkpoint_path) == {} + assert resume_client.inputs == [] + finally: + del sys.modules["_test_wf_empty_mapper_mod"] + + +def test_resume_preserves_original_indexes_after_reorder(tmp_path: Path) -> None: + original_rows = [{"prompt": "a"}, {"prompt": "b"}, {"prompt": "c"}] + reordered_rows = [original_rows[2], original_rows[0], original_rows[1]] + input_path = _write_input(tmp_path, reordered_rows) + output_path = tmp_path / "out.jsonl" + output_path.write_text( + json.dumps({"_index": 0, "output_text": "cached-a"}) + "\n", + encoding="utf-8", + ) + write_checkpoint_db( + output_path, + [ + checkpoint_item_for_record( + original_rows[0], + occurrence=0, + index=0, + status="success", + ), + checkpoint_item_for_record( + original_rows[1], + occurrence=0, + index=1, + status="pending", + ), + checkpoint_item_for_record( + original_rows[2], + occurrence=0, + index=2, + status="pending", + ), + ], + ) + + client = _FakeClient() + _run(client, input_path=input_path, output_path=output_path, resume=True) + + assert client.inputs == ["c", "b"] + rows_by_index = {row["_index"]: row for row in _read_output(output_path)} + assert rows_by_index[1]["output_text"] == "out:b" + assert rows_by_index[2]["output_text"] == "out:c" + + +@pytest.mark.parametrize( + ("checkpoint_rows", "input_rows"), + [ + ([{"prompt": "a"}, {"prompt": "a"}], [{"prompt": "a"}]), + ([{"prompt": "a"}], [{"prompt": "a"}, {"prompt": "a"}]), + ], +) +def test_resume_rejects_mismatched_occurrences( + tmp_path: Path, + checkpoint_rows: list[dict[str, Any]], + input_rows: list[dict[str, Any]], +) -> None: + input_path = _write_input(tmp_path, input_rows) + output_path = tmp_path / "out.jsonl" + output_path.write_text("", encoding="utf-8") + write_checkpoint_db( + output_path, + [ + checkpoint_item_for_record( + row, + occurrence=occurrence, + index=occurrence, + status="pending", + ) + for occurrence, row in enumerate(checkpoint_rows) + ], + ) + + client = _FakeClient() + with pytest.raises(ValueError, match="does not match the checkpoint file"): + _run(client, input_path=input_path, output_path=output_path, resume=True) + + +def test_resume_ignores_pending_rows_when_validating_output_rows( + tmp_path: Path, +) -> None: + rows = [{"prompt": "a"}, {"prompt": "b"}] + input_path = _write_input(tmp_path, rows) + output_path = tmp_path / "out.jsonl" + output_path.write_text( + json.dumps({"_index": 0, "output_text": "cached-a"}) + "\n", + encoding="utf-8", + ) + write_checkpoint_db( + output_path, + [ + checkpoint_item_for_record( + rows[0], + occurrence=0, + index=0, + status="success", + ), + checkpoint_item_for_record( + rows[1], + occurrence=0, + index=1, + status="pending", + ), + ], + ) + + client = _FakeClient() + _run(client, input_path=input_path, output_path=output_path, resume=True) + + rows_by_index = {row["_index"]: row for row in _read_output(output_path)} + assert rows_by_index[1]["output_text"] == "out:b" + + +def 
test_resume_tracks_duplicate_parse_errors_by_occurrence(tmp_path: Path) -> None: + input_path = tmp_path / "input.jsonl" + bad_line = "not valid json" + input_path.write_text( + json.dumps({"prompt": "good"}) + "\n" + bad_line + "\n" + bad_line + "\n", + encoding="utf-8", + ) + output_path = tmp_path / "out.jsonl" + output_path.write_text( + json.dumps({"_index": 0, "output_text": None, "error": "cached-parse-error"}) + + "\n", + encoding="utf-8", + ) + write_checkpoint_db( + output_path, + [ + checkpoint_item_for_parse_error( + bad_line, + occurrence=0, + index=0, + status="error", + error="cached-parse-error", + ), + checkpoint_item_for_record( + {"prompt": "good"}, + occurrence=0, + index=1, + status="pending", + ), + checkpoint_item_for_parse_error( + bad_line, + occurrence=1, + index=2, + status="pending", + ), + ], + ) + + client = _FakeClient() + _run(client, input_path=input_path, output_path=output_path, resume=True) + + assert client.inputs == ["good"] + rows_by_index = {row["_index"]: row for row in _read_output(output_path)} + assert rows_by_index[1]["output_text"] == "out:good" + assert rows_by_index[2]["output_text"] is None + assert rows_by_index[2]["error"] is not None + + +def test_stdout_path_creates_no_checkpoint_file( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, + capsys: pytest.CaptureFixture[str], +) -> None: + rows = [{"prompt": "hello"}] + input_path = _write_input(tmp_path, rows) + client = _FakeClient() + monkeypatch.setattr(ResumePlanner, "_temp_dir", lambda: tmp_path) + + _run(client, input_path=input_path, output_path=None) + + assert list(tmp_path.glob("*.checkpoint.sqlite")) == [] + assert not list(tmp_path.glob(".infermesh-resume-plan.*.sqlite")) + + out = capsys.readouterr().out + out_rows = [json.loads(line) for line in out.splitlines() if line.strip()] + assert len(out_rows) == 1 + assert out_rows[0]["error"] is None + + +def test_invalid_metadata_becomes_error_row_without_aborting_siblings( + tmp_path: Path, +) -> None: + fake_mod = types.ModuleType("_test_wf_metadata_mod") + + def my_mapper(record: dict[str, Any]) -> dict[str, Any]: + if record["prompt"] == "bad": + return { + "input": record["prompt"], + "metadata": {"created_at": datetime.now()}, + } + return { + "input": record["prompt"], + "metadata": {"kind": "ok"}, + } + + fake_mod.my_mapper = my_mapper # type: ignore[attr-defined] + sys.modules["_test_wf_metadata_mod"] = fake_mod + try: + rows = [{"prompt": "bad"}, {"prompt": "good"}] + input_path = _write_input(tmp_path, rows) + output_path = tmp_path / "out.jsonl" + client = _FakeClient() + + _run( + client, + input_path=input_path, + output_path=output_path, + mapper_spec="_test_wf_metadata_mod:my_mapper", + ) + + assert client.inputs == ["good"] + out_rows = _read_output(output_path) + assert len(out_rows) == 2 + rows_by_index = {row["_index"]: row for row in out_rows} + assert "json serializable" in rows_by_index[0]["error"].lower() + assert rows_by_index[1]["metadata"] == {"kind": "ok"} + assert rows_by_index[1]["error"] is None + finally: + del sys.modules["_test_wf_metadata_mod"] + + +def test_file_backed_generate_materialises_stdin( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + """When output_jsonl is set but no input_jsonl/prompt, stdin is spooled to a temp file.""" + stdin_data = json.dumps({"prompt": "from-stdin"}) + "\n" + monkeypatch.setattr(sys, "stdin", io.StringIO(stdin_data)) + output_path = tmp_path / "out.jsonl" + client = _FakeClient() + + _run(client, input_path=None, output_path=output_path) + + 
out_rows = _read_output(output_path)
+    assert len(out_rows) == 1
+    assert out_rows[0]["output_text"] == "out:from-stdin"
+    assert out_rows[0]["error"] is None
+    # No stray temp artefacts should remain in the working directory
+    assert not list(tmp_path.glob("*.tmp"))
+
+
+def test_persistence_sink_failure_propagates(
+    tmp_path: Path,
+    monkeypatch: pytest.MonkeyPatch,
+) -> None:
+    original = checkpoint_module._mark_checkpoint_item_settled
+    call_count = 0
+
+    # Let the first item settle normally, then fail the second settlement so
+    # the injected error surfaces mid-run rather than at bootstrap.
+    def failing_mark(
+        connection: Any, checkpoint_key: Any, *, status: Any, error: Any
+    ) -> None:
+        nonlocal call_count
+        call_count += 1
+        if call_count == 2:
+            raise RuntimeError("injected sink failure")
+        original(connection, checkpoint_key, status=status, error=error)
+
+    monkeypatch.setattr(
+        checkpoint_module, "_mark_checkpoint_item_settled", failing_mark
+    )
+
+    rows = [{"prompt": "a"}, {"prompt": "b"}, {"prompt": "c"}]
+    input_path = _write_input(tmp_path, rows)
+    output_path = tmp_path / "out.jsonl"
+    client = _FakeClient()
+
+    with pytest.raises(RuntimeError, match="injected sink failure"):
+        _run(client, input_path=input_path, output_path=output_path)
+
+
+def test_resume_planner_temp_db_is_removed_after_failure(
+    tmp_path: Path,
+    monkeypatch: pytest.MonkeyPatch,
+) -> None:
+    rows = [{"prompt": "a"}, {"prompt": "b"}]
+    input_path = _write_input(tmp_path, rows)
+    output_path = tmp_path / "out.jsonl"
+    output_path.write_text(
+        json.dumps({"_index": 0, "output_text": "cached-a"}) + "\n",
+        encoding="utf-8",
+    )
+    write_checkpoint_db(
+        output_path,
+        [
+            checkpoint_item_for_record(
+                rows[0], occurrence=0, index=0, status="success"
+            ),
+            checkpoint_item_for_record(
+                rows[1], occurrence=0, index=1, status="pending"
+            ),
+        ],
+    )
+
+    def failing_mark(
+        connection: Any, checkpoint_key: Any, *, status: Any, error: Any
+    ) -> None:
+        raise RuntimeError("injected sink failure")
+
+    monkeypatch.setattr(
+        checkpoint_module, "_mark_checkpoint_item_settled", failing_mark
+    )
+    monkeypatch.setattr(ResumePlanner, "_temp_dir", lambda: tmp_path)
+
+    client = _FakeClient()
+    with pytest.raises(RuntimeError, match="injected sink failure"):
+        _run(client, input_path=input_path, output_path=output_path, resume=True)
+
+    assert not list(tmp_path.glob(".infermesh-resume-plan.*.sqlite"))
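+
+
+# Two minimal sketches of documented window behaviour (sequential admission at
+# window_size=1; completion-order output re-sorted via "_index"). They assume
+# the helpers and fake clients defined above behave as exercised by the tests
+# they accompany; the loose assertions are deliberate, since exact completion
+# timing is scheduler-dependent.
+
+
+# With window_size=1 the rolling window should degenerate to strictly
+# sequential admission: the fake never sees more than one in-flight request
+# and receives inputs in source order.
+def test_window_size_one_admits_sequentially(tmp_path: Path) -> None:
+    rows = [{"prompt": f"p{i}"} for i in range(4)]
+    input_path = _write_input(tmp_path, rows)
+    output_path = tmp_path / "out.jsonl"
+    client = _FakeClient()
+
+    _run(client, input_path=input_path, output_path=output_path, window_size=1)
+
+    assert client.peak_active <= 1
+    assert client.inputs == [f"p{i}" for i in range(4)]
+
+
+# Output rows land in completion order, so a slow first row is overtaken by
+# its siblings; sorting on "_index" restores source order, mirroring the
+# re-sort advice in docs/guide.md.
+def test_output_rows_can_be_resorted_by_index(tmp_path: Path) -> None:
+    rows = [{"prompt": "slow"}, {"prompt": "fast-1"}, {"prompt": "tail-1"}]
+    input_path = _write_input(tmp_path, rows)
+    output_path = tmp_path / "out.jsonl"
+    client = _RollingWindowFakeClient()
+
+    _run(client, input_path=input_path, output_path=output_path, window_size=2)
+
+    out_rows = _read_output(output_path)
+    # The blocked "slow" row cannot settle first, so the file is not in
+    # source order...
+    assert out_rows[0]["_index"] != 0
+    # ...but sorting on "_index" recovers it.
+    resorted = sorted(out_rows, key=lambda row: row["_index"])
+    assert [row["_index"] for row in resorted] == [0, 1, 2]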