diff --git a/AGENTS.md b/AGENTS.md index 316f904..c8d4836 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -230,7 +230,7 @@ erDiagram ## Data Contracts (the model layer) -`backend/database/models.py` is the **model layer**: domain data contracts (a `TypedDict` per table-group row, plus the `PhraseGroup` union — `list[str] | LiteralPhraseGroup | RegexPhraseGroup`, a discriminated union keyed on `kind`) that describe the *shape* of persisted data and depend on nothing else in the codebase. The dependency rule is one-way — every other layer points its dependencies **inward**, toward the data, and `backend/database/` must never import "up" into `passes/` or `orchestrator.py`. (The introducing commit moved `PhraseGroup` *down* into `models.py` from `slop_detector.py` to kill the last upward import; anything in `database/` that reaches up for a shared shape is an architectural inversion — put the shape here instead.) +`backend/database/models.py` is the **model layer**: domain data contracts (a `TypedDict` per table-group row, plus the `PhraseGroup` union — `list[str] | LiteralPhraseGroup | RegexPhraseGroup`, a discriminated union keyed on `kind`) that describe the *shape* of persisted data and depend on nothing else in the codebase. The dependency rule is one-way — every other layer points its dependencies **inward**, toward the data, and `backend/database/` must never import "up" into `passes/`, `orchestrator.py`, or `workflows/`. (The introducing commit moved `PhraseGroup` *down* into `models.py` from `slop_detector.py` to kill the last upward import; anything in `database/` that reaches up for a shared shape is an architectural inversion — put the shape here instead.) When the database layer genuinely needs higher-layer *behavior* at a fixed seam — `add_message` persisting workflow attachments inside its own write transaction — it declares the contract and the higher layer registers an implementation (dependency inversion): `database/queries/messages.py` owns `register_workflow_attachment_persister`, and `workflows/attachment_cache.py` registers `insert_workflow_attachments` into it at import. Don't reintroduce a lazy `import backend.workflows` inside a `database/` function to dodge the rule — that hides the inversion from the import graph without removing it. - **The TypedDicts label plain dicts, with zero runtime change.** The query layer still returns ordinary `dict(row)` objects; each query stamps the shape at its boundary with `cast(SomeRow, ...)` (a `TypedDict` is not assignable from a bare `dict`). So `row["col"]` access is checked against the schema without any wrapper object, validation, or runtime cost. Each `queries/*.py` module imports just the contract(s) for its tables (`SettingsRow`, `ConversationRow`/`ConversationListRow`, `MessageRow`/`MessageWithAttachments`, `EndpointRow`, `ModelConfigRow`, `WorldRow`, `LorebookEntryRow`, `CharacterCardRow`, `DirectorStateRow`, `DirectorFragmentRow`, `MoodFragmentRow`, `UserPersonaRow`, `ConversationLogRow`, `PhraseBankRow`, and the attachment rows). - **Every row-shaped query return is typed; only free-form blobs stay `dict`.** A query that returns table rows uses a contract. The lone exception is the per-workflow JSON state/config accessors (`get_workflow_state`, `get_workflow_message_state`, `get_workflow_character_state`, `get_workflow_config`) — these decode an arbitrary per-workflow slot with no fixed schema, so they correctly return bare `dict`/`dict | None`. Don't invent a contract for those. @@ -389,7 +389,7 @@ Because the writer's KV cache now lives on a different server than the agent pas Orb sends the **full active message path** (leaf to root) every turn — no automatic truncation or rolling window. Inactive sibling branches are not included. -- `updateContextCounter()` calls `GET /api/conversations/{cid}/context-size` which computes a per-component token breakdown (system prompt, persona, scenario, messages, director injection, lorebook, post-history) using `chars / 3.5` per component +- `updateContextCounter()` calls `GET /api/conversations/{cid}/context-size` which computes a per-component token breakdown (system prompt, persona, scenario, messages, director injection, lorebook, post-history) using `chars / 4` per component - **Manual compress flow**: `POST /summarize` → LLM writes narrative summary → user reviews → `POST /compress` → creates new conversation with summary + last N messages - No RAG, no background compaction, no automatic summarization @@ -463,4 +463,4 @@ See [docs/architecture/secondary-workflow.md](docs/architecture/secondary-workfl 9. **Lorebook scan depth** — Hard-coded to 6 messages (`LOREBOOK_SCAN_DEPTH` in `prompt_builder.py`). Only the last 6 messages are scanned for lorebook keyword matches. -10. **Macros resolve at different levels** — `resolve_message()` expands everything ({{user}}, {{char}}, inline macros like {{roll}}). `resolve_prompt()` only does {{user}}/{{char}} substitution. Use `resolve_prompt()` for historical messages where inline macros shouldn't fire. +10. **Macros resolve at different levels** — `resolve_message()` expands everything ({{user}}, {{char}}, inline macros like {{roll}}). `resolve_prompt()` only does {{user}}/{{char}} substitution. Use `resolve_prompt()` for historical messages where inline macros shouldn't fire. `macros.py` is a **dependency-free leaf** (it imports nothing else in the codebase — like `database/models.py` and `llm_types.py`): it transforms strings and message dicts, and knows nothing about the LLM client. The transport-boundary catch-all that scrubs `{{user}}`/`{{char}}` from *every* outgoing message (the director's tool prompt embeds user-authored fragment text that can carry `{{char}}`) is `Macros.resolve_prompt_messages`, wired in as the `CachedBase.resolve` hook in `kv_tracker.py` — applied to `[*prefix, *trailing]` right before the call, so the KV tracker snapshots the exact resolved bytes sent. There is **no** macro-resolving `LLMClient` subclass/wrapper; don't reintroduce one. diff --git a/backend/database/queries/messages.py b/backend/database/queries/messages.py index 447a7b4..fcb71ff 100644 --- a/backend/database/queries/messages.py +++ b/backend/database/queries/messages.py @@ -3,7 +3,7 @@ import json import sqlite3 from datetime import datetime, timezone -from typing import Any, List, Mapping, Optional, Sequence, cast +from typing import Any, List, Mapping, Optional, Protocol, Sequence, cast from ..connection import get_db from ..models import ( @@ -15,6 +15,36 @@ from .conversations import get_conversation +class _WorkflowAttachmentPersister(Protocol): + """Persists workflow attachments inside ``add_message``'s transaction. + + Implemented by ``backend.workflows.attachment_cache.insert_workflow_attachments`` + and registered at import time -- see the dependency-inversion note below. + """ + + async def __call__(self, message_id: int, attachments: list[dict], *, db: Any = None) -> tuple[list[int], list[dict]]: ... + + +# Dependency-inversion seam. ``add_message`` must insert workflow attachments +# inside its own write transaction (the cache layer's read->evict->insert runs +# under the same lock as the message INSERT), yet the database layer must never +# import "up" into ``backend.workflows``. So the workflow layer registers its +# persister here at import time and ``add_message`` calls through this slot. +# Left None in DB-only contexts that never produce workflow attachments; in +# that state a workflow attachment reaching ``add_message`` is a wiring bug, so +# we fail loudly rather than silently dropping bytes. +_workflow_attachment_persister: "_WorkflowAttachmentPersister | None" = None + + +def register_workflow_attachment_persister(fn: "_WorkflowAttachmentPersister") -> None: + """Wire the workflow-attachment persister into ``add_message``. + + Called once, at import of ``backend.workflows.attachment_cache``. + """ + global _workflow_attachment_persister + _workflow_attachment_persister = fn + + async def get_path_to_leaf(cid: str, leaf_id: int) -> list[MessageWithAttachments]: """Walk parent_id chain from leaf to root, return ordered root→leaf.""" async with get_db() as db: @@ -243,12 +273,16 @@ async def add_message( now, ), ) - # Lazy import: the database package must not depend on - # workflows at import time (would invert the layering). + # Persist workflow attachments through the registered persister + # (see register_workflow_attachment_persister) so the database layer + # never imports up into backend.workflows. if workflow_atts: - from backend.workflows.attachment_cache import insert_workflow_attachments - - _, rejected_workflow_atts = await insert_workflow_attachments(message_id, workflow_atts, db=db) + if _workflow_attachment_persister is None: + raise RuntimeError( + "workflow attachments supplied to add_message but no persister is " + "registered -- import backend.workflows before producing them" + ) + _, rejected_workflow_atts = await _workflow_attachment_persister(message_id, workflow_atts, db=db) await db.execute("UPDATE conversations SET updated_at = ? WHERE id = ?", (now, cid)) await db.commit() diff --git a/backend/database/queries/workflow_attachments.py b/backend/database/queries/workflow_attachments.py index 723930a..92fc79a 100644 --- a/backend/database/queries/workflow_attachments.py +++ b/backend/database/queries/workflow_attachments.py @@ -25,6 +25,13 @@ logger = logging.getLogger(__name__) +# Sentinel string written into ``data_b64`` when an artifact's bytes are +# evicted (the other columns stay intact so a later rehydrate can recover the +# bytes from stored parameters). Defined here -- in the database boundary -- +# because it describes the persisted shape of the column, not cache policy. +# ``backend.workflows.attachment_cache`` re-exports it for the eviction layer. +EVICTED_MARKER = "[evicted]" + def _encode_metadata_field(value: object, field_name: str, workflow_id: str, filename: str) -> str | None: """JSON-encode a dict-shaped metadata field, or return None for absent/bad shape. @@ -133,9 +140,6 @@ async def insert_workflow_attachment_row( raise ValueError("attachment data is empty") if insert_as_evicted: - # Lazy import keeps queries module free of attachment_cache cycle. - from backend.workflows.attachment_cache import EVICTED_MARKER - data_b64 = EVICTED_MARKER elif has_path: with open(attachment["path"], "rb") as f: diff --git a/backend/kv_tracker.py b/backend/kv_tracker.py index 87dc83f..bcb7c9e 100644 --- a/backend/kv_tracker.py +++ b/backend/kv_tracker.py @@ -32,7 +32,7 @@ import json import logging from dataclasses import dataclass -from typing import Any, AsyncIterator, Mapping, Sequence +from typing import Any, AsyncIterator, Callable, Mapping, Sequence logger = logging.getLogger(__name__) @@ -208,7 +208,11 @@ def log_summary(self) -> None: elif stats["source"] in ("unrecognized", "no_cache_fields"): provider_note = f"provider: prompt={stats['prompt_tokens']} tok cached=N/A [{stats['source']}]" else: - pt, ct, cw = stats["prompt_tokens"], stats["cached_tokens"], stats["cache_write_tokens"] + pt, ct, cw = ( + stats["prompt_tokens"], + stats["cached_tokens"], + stats["cache_write_tokens"], + ) total_cached += ct total_prompt += pt pct = (ct / pt * 100) if pt else 0.0 @@ -291,11 +295,20 @@ class CachedBase: tools when it runs on a different server than the agent" — is then just a property of how the writer's base is built (empty ``tools``), not a flag threaded through the writer pass. + + ``resolve`` is the last step of turning the assembled stack into the literal + bytes on the wire: an opaque ``messages -> messages`` transform applied to + ``[*prefix, *trailing]`` immediately before the call (in practice + ``Macros.resolve_prompt_messages``, scrubbing ``{{user}}``/``{{char}}`` from + whatever a pass appended). Keeping it on the base means the tracker snapshot + is taken from the *resolved* bytes — the same ones sent — so it cannot drift. + ``None`` means send the assembled stack unchanged. """ prefix: tuple[Mapping[str, Any], ...] tools: tuple[dict, ...] model: str + resolve: Callable[[Sequence[Mapping[str, Any]]], list[dict]] | None = None def complete( self, @@ -312,13 +325,17 @@ def complete( per-pass top of the stack). The cached bottom — prefix + tools + model — comes solely from ``self``; only *trailing* and *tool_choice* vary. - Delegates to :func:`cached_complete` so the tracker snapshot is taken - from the exact bytes sent. + The assembled stack is run through ``self.resolve`` (if set) to produce + the final wire bytes, then handed to :func:`cached_complete` so the + tracker snapshot is taken from the exact bytes sent. """ + messages: Sequence[Mapping[str, Any]] = [*self.prefix, *trailing] + if self.resolve is not None: + messages = self.resolve(messages) return cached_complete( client, label=label, - messages=[*self.prefix, *trailing], + messages=messages, model=self.model, tools=list(self.tools) or None, tool_choice=tool_choice, diff --git a/backend/macros.py b/backend/macros.py index d41276a..acbb545 100644 --- a/backend/macros.py +++ b/backend/macros.py @@ -1,6 +1,13 @@ """ macros.py — Macro resolution for prompts and messages. +A dependency-free leaf: it turns ``{{user}}``/``{{char}}`` and inline macros +like ``{{roll}}`` into literal text and imports nothing else in the codebase. +It knows about *strings and message dicts*, not about the LLM client — the +pipeline applies :meth:`Macros.resolve_prompt_messages` at the transport +boundary (the cached-base ``resolve`` hook in ``kv_tracker.py``) rather than +this module reaching up into the client layer. + Public API: resolve_message(text, user_name, char_name) — Full resolution ({{user}}/{{char}} + inline macros like {{roll}}). @@ -15,7 +22,8 @@ Macros.resolve_message(text) — instance method, full resolution Macros.resolve_prompt(text) — instance method, substitution only Macros.resolve_prompt_messages(msgs) — batch prompt-level res on message list - Macros.wrap_client(client) — wraps LLMClient for prompt-level resolution + (the transport-boundary catch-all that guarantees no placeholder + reaches the model, whatever a pass assembled) Macros.from_settings(...) — factory from app settings """ @@ -25,8 +33,6 @@ import re from typing import Any, Mapping, NamedTuple, Sequence -from .llm_client import LLMClient - # --------------------------------------------------------------------------- # Internal helpers @@ -120,55 +126,22 @@ def resolve_prompt(self, text: str) -> str: """Only {{user}}/{{char}} substitution (no inline macros).""" return resolve_prompt(text, self.user, self.char) - def _resolve_prompt_on_message(self, msg: dict) -> dict: + def _resolve_prompt_on_message(self, msg: Mapping[str, Any]) -> dict: """Apply prompt-level resolution (substitution only) to a single message dict.""" return { **msg, "content": _apply_content(msg.get("content"), lambda t: self.resolve_prompt(t)), } - def resolve_prompt_messages(self, messages: list[dict]) -> list[dict]: - """Apply prompt-level resolution to a list of message dicts.""" + def resolve_prompt_messages(self, messages: Sequence[Mapping[str, Any]]) -> list[dict]: + """Apply prompt-level resolution to every message in a list. + + This is the transport-boundary catch-all: passed to a cached base's + ``resolve`` hook so the fully-assembled wire messages are scrubbed of + ``{{user}}``/``{{char}}`` just before they are sent, no matter which + pass built them (e.g. the director's tool prompt embeds user-authored + fragment text that can carry ``{{char}}``). Inline macros like + ``{{roll}}`` are intentionally *not* fired here — those are resolved on + the latest user message and prefix content when it is built. + """ return [self._resolve_prompt_on_message(m) for m in messages] - - def wrap_client(self, client: LLMClient) -> "_PlaceholderClient": - return _PlaceholderClient(client, self.user, self.char) - - -class _PlaceholderClient(LLMClient): - """Wraps LLMClient to resolve {{user}}/{{char}} on all messages before completion. - - Only applies prompt-level resolution (no inline macros) — inline macros - must be resolved on the latest user message before it reaches this client. - """ - - def __init__(self, inner: LLMClient, user_name: str, char_name: str) -> None: - self._inner = inner - self._user_name = user_name - self._char_name = char_name - # Share the inner client's abort token so the inherited abort()/ - # is_aborted reflect the same turn-wide stop signal — no delegation - # overrides needed. Transport config (base_url/profile/…) is left unset - # since complete() delegates to the inner client rather than using it. - self.abort_token = inner.abort_token - - async def complete( - self, - messages: Sequence[Mapping[str, Any]], - model: str, - tools: list[dict] | None = None, - tool_choice: dict | str | None = None, - **params, - ): - msgs = [ - { - **msg, - "content": _apply_content( - msg.get("content"), - lambda t: resolve_prompt(t, self._user_name, self._char_name), - ), - } - for msg in messages - ] - async for item in self._inner.complete(msgs, model, tools=tools, tool_choice=tool_choice, **params): - yield item diff --git a/backend/main.py b/backend/main.py index e6074a6..2817605 100644 --- a/backend/main.py +++ b/backend/main.py @@ -134,6 +134,7 @@ from . import card_downloader from . import prompt_builder from .summarizer import ConversationSummarizer +from .utils import estimate_tokens logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) @@ -1674,9 +1675,6 @@ async def api_get_context_size(cid: str): recent_messages = messages[-scan_depth:] if len(messages) >= scan_depth else messages lorebook_block = prompt_builder.compute_lorebook_injection_block(recent_messages, lorebook_entries, macros) - def est(chars): - return max(1, round(chars / 3.5)) - breakdown = {} for label, chars in [ ("system_prompt", len(sys_text)), @@ -1689,12 +1687,12 @@ def est(chars): ("director_injection", len(inj_block)), ("lorebook", len(lorebook_block)), ]: - breakdown[label] = {"chars": chars, "tokens_est": est(chars)} + breakdown[label] = {"chars": chars, "tokens_est": estimate_tokens(chars)} total_chars = sum(v["chars"] for v in breakdown.values()) return { "total_chars": total_chars, - "total_tokens_est": est(total_chars), + "total_tokens_est": estimate_tokens(total_chars), "breakdown": breakdown, "message_count": len(messages), } diff --git a/backend/orchestrator.py b/backend/orchestrator.py index 44752b4..18179ef 100644 --- a/backend/orchestrator.py +++ b/backend/orchestrator.py @@ -14,7 +14,12 @@ from . import database as db from .llm_client import AbortToken, LLMClient, reasoning_cfg from .endpoint_profiles import profile_for -from .tool_defs import TOOLS, POST_WRITER_TOOLS, build_direct_scene_tool, enabled_schemas +from .tool_defs import ( + TOOLS, + POST_WRITER_TOOLS, + build_direct_scene_tool, + enabled_schemas, +) from .prompt_builder import ( build_prefix, compute_style_injection_block, @@ -34,8 +39,8 @@ from .workflows.attachment_cache import OVERSIZE_NO_METADATA_REASON from .llm_types import ChatMessage from .utils import LengthGuard, extract_hyperparams -from .passes.director import DirectorResult, _director_pass -from .passes.writer import _writer_pass, build_writer_content +from .passes.director import DirectorResult, director_pass +from .passes.writer import writer_pass, build_writer_content from .passes.editor import editor_pass from .database.models import ( CharacterCardRow, @@ -57,8 +62,9 @@ @dataclass(frozen=True) class ModelLane: - """One model's call surface for the turn: a macro-wrapped client paired with - its byte-identical cached bottom (prefix + tools + model). + """One model's call surface for the turn: a client paired with its + byte-identical cached bottom (prefix + tools + model + the macro ``resolve`` + hook that scrubs placeholders from the final wire bytes). A turn has two lanes — ``writer`` and ``agent`` (director + editor). In single-model mode they are the *same object* (the writer's lane is reused for @@ -98,8 +104,6 @@ class _PipelineConfig: writer_reasoning_on: bool editor_reasoning_on: bool audit_enabled: bool - length_guard_enabled: bool - length_guard_enforce: bool length_guard: LengthGuard | None do_edit: bool writer_enabled_tools: Mapping[str, bool] @@ -135,13 +139,13 @@ def _resolve_pipeline_config( # (not user-toggleable); this feature flag is its only enable path. if length_guard_enabled: enabled_tools = {**enabled_tools, "editor_rewrite": True} - length_guard_enforce = bool(settings.get("length_guard_enforce", 0)) if agent_on else False # length_guard_enabled already folds in agent_on (it is False whenever the - # agent is off), so no extra `and agent_on` guard is needed below. + # agent is off). The dict is built *only* when enabled, so its presence is the + # on/off state downstream — `cfg.length_guard is not None` means enabled. length_guard: LengthGuard | None = ( { - "enabled": length_guard_enabled, + "enforce": bool(settings.get("length_guard_enforce", 0)), "max_words": int(settings.get("length_guard_max_words", 240)), "max_paragraphs": int(settings.get("length_guard_max_paragraphs", 4)), } @@ -163,21 +167,23 @@ def _resolve_pipeline_config( # The tool blobs are built via enabled_schemas exactly once each so no pass # can rebuild them differently. writer_lane = ModelLane( - client=macros.wrap_client(client), + client=client, base=CachedBase( prefix=tuple(prefix), tools=tuple(enabled_schemas(writer_enabled_tools, schema_overrides)), model=settings["model_name"], + resolve=macros.resolve_prompt_messages, ), ) if dual_model: assert agent_client is not None # dual_model is True iff agent_client was resolved agent_lane = ModelLane( - client=macros.wrap_client(agent_client), + client=agent_client, base=CachedBase( prefix=tuple(agent_prefix or prefix), tools=tuple(enabled_schemas(enabled_tools, schema_overrides)), model=settings.get("agent_model_name", settings["model_name"]), + resolve=macros.resolve_prompt_messages, ), ) else: @@ -193,8 +199,6 @@ def _resolve_pipeline_config( writer_reasoning_on=bool(reasoning_passes.get("writer", False)), editor_reasoning_on=bool(reasoning_passes.get("editor", False)), audit_enabled=audit_enabled, - length_guard_enabled=length_guard_enabled, - length_guard_enforce=length_guard_enforce, length_guard=length_guard, do_edit=audit_enabled or length_guard_enabled, writer_enabled_tools=writer_enabled_tools, @@ -334,7 +338,7 @@ async def _run_pipeline( has_pre_writer_tools = any(cfg.enabled_tools.get(n, False) for n in TOOLS if n not in POST_WRITER_TOOLS) if cfg.agent_on and has_pre_writer_tools: yield {"event": "director_start"} - async for event in _director_pass( + async for event in director_pass( cfg.agent_lane.client, cfg.agent_lane.base, user_message, @@ -402,27 +406,22 @@ async def _run_pipeline( } # --- Writer pass --- + # Built once here and threaded into both the writer pass and (later) the + # editor, which replays it verbatim to extend the writer's KV-cached prefix. writer_content = build_writer_content( lorebook_block, inj_block, cfg.writer_enabled_tools, effective_msg, attachments, - cfg.length_guard_enforce, cfg.length_guard, ) resp_text = "" - async for item in _writer_pass( + async for item in writer_pass( cfg.writer_lane.client, cfg.writer_lane.base, settings, - cfg.writer_enabled_tools, - inj_block=inj_block, - lorebook_block=lorebook_block, - effective_msg=effective_msg, - attachments=attachments, - length_guard_enforce=cfg.length_guard_enforce, - length_guard=cfg.length_guard, + writer_content, kv_tracker=kv_tracker, reasoning_on=cfg.writer_reasoning_on, ): @@ -510,7 +509,11 @@ def _make_result(final_text: str, staged: list[dict], staged_state: dict | None except Exception as e: logger.error("editor pass failed, keeping original: %s", e, exc_info=True) else: - logger.info("Editor pass skipped (do_edit=%s, draft=%d chars)", cfg.do_edit, len(resp_text)) + logger.info( + "Editor pass skipped (do_edit=%s, draft=%d chars)", + cfg.do_edit, + len(resp_text), + ) # --- Post-pipeline workflow iteration --- # PostCtx.director_output is a read-only mapping (workflow contract), so this @@ -639,7 +642,10 @@ async def _run_post_pipeline( continue draft = new_draft replaced_this_hook = True - yield {"event": "writer_rewrite", "data": {"refined_text": draft}} + yield { + "event": "writer_rewrite", + "data": {"refined_text": draft}, + } continue if t == "attach_artifact": # Only workflows declared with produces_artifacts=True @@ -1067,7 +1073,12 @@ def _build_prefixes( prefix = _build_prefix_from_ctx(ctx, history, extra_system_blocks=extra_system_blocks) agent_sp = ctx.agent_system_prompt agent_prefix = ( - _build_prefix_from_ctx(ctx, history, system_prompt=agent_sp, extra_system_blocks=extra_system_blocks) + _build_prefix_from_ctx( + ctx, + history, + system_prompt=agent_sp, + extra_system_blocks=extra_system_blocks, + ) if agent_sp is not None else None ) @@ -1243,7 +1254,10 @@ async def _resolve_target_and_parent( async def _prepare_regen_context( - ctx: PipelineContext, conversation_id: str, target: Mapping[str, Any], user_msg: Mapping[str, Any] + ctx: PipelineContext, + conversation_id: str, + target: Mapping[str, Any], + user_msg: Mapping[str, Any], ) -> tuple[Sequence[Mapping[str, Any]], Sequence[Mapping[str, Any]]]: """Prepare history and attachments for a regeneration pass. @@ -1795,7 +1809,10 @@ async def handle_regenerate( history=history, pipeline_settings=settings, last_user_message=user_msg["content"], - lorebook_messages=[*history, {"role": "user", "content": user_msg["content"]}], + lorebook_messages=[ + *history, + {"role": "user", "content": user_msg["content"]}, + ], user_message=user_msg["content"], attachments=attachments, user_msg_id=user_msg_id, @@ -1926,7 +1943,10 @@ async def handle_magic_rewrite( # Stream reasoning deltas like the main pipeline does, labelled # "writer" since this is a writer-style rewrite with no # director/editor passes. - yield {"event": "reasoning", "data": {"pass": "writer", "delta": item["delta"]}} + yield { + "event": "reasoning", + "data": {"pass": "writer", "delta": item["delta"]}, + } elif item["type"] == "content": accumulated += item["delta"] yield {"event": "token", "data": item["delta"]} diff --git a/backend/passes/director.py b/backend/passes/director.py index 69ba2d8..52b5736 100644 --- a/backend/passes/director.py +++ b/backend/passes/director.py @@ -78,7 +78,7 @@ def apply_tool_calls( # ── Agent pass ──────────────────────────────────────────────────────────────── -async def _director_pass( +async def director_pass( client: LLMClient, base: CachedBase, user_message: str, diff --git a/backend/passes/editor/editor.py b/backend/passes/editor/editor.py index 9e500b9..30be9e1 100644 --- a/backend/passes/editor/editor.py +++ b/backend/passes/editor/editor.py @@ -15,8 +15,9 @@ if TYPE_CHECKING: from ...database.models import PhraseGroup -from .opening_monotony import FlaggedOpener, MonotonyResult, _split_sentences +from .opening_monotony import FlaggedOpener, MonotonyResult from .template_repetition import FlaggedTemplate, TemplateResult +from .text_segmentation import split_narration_sentences from ...llm_client import LLMClient, parse_tool_calls, reasoning_cfg from ...kv_tracker import CachedBase from ...tool_defs import ( @@ -36,7 +37,7 @@ def _split_target_sentences(target_text: str) -> set[str]: """Split *target_text* into a sentence set using the same heuristic as the detectors.""" - return set(_split_sentences(target_text)) + return set(split_narration_sentences(target_text)) def _filter_flagged_items(items, sentences: set[str], total: int, *, cls, label_field: str): @@ -385,7 +386,7 @@ async def editor_pass( # the list mid-loop would bust the KV cache every iteration. Which single tool # the model must call is steered entirely by tool_choice (see _pick_tool_choice, # recomputed each iteration) while base.tools stays byte-identical throughout. - if length_guard and length_guard["enabled"]: + if length_guard is not None: word_count = len(draft.split()) max_words = length_guard["max_words"] max_paragraphs = length_guard["max_paragraphs"] diff --git a/backend/passes/editor/opening_monotony.py b/backend/passes/editor/opening_monotony.py index cd6139e..36118a1 100644 --- a/backend/passes/editor/opening_monotony.py +++ b/backend/passes/editor/opening_monotony.py @@ -15,11 +15,10 @@ from __future__ import annotations import os -import re import sys from dataclasses import dataclass, field -from .text_segmentation import split_narration_sentences +from .text_segmentation import normalize_word, split_narration_sentences DEBUG = "DEBUG_OPENING_MONOTONY" in os.environ # ---------- public dataclasses (unchanged) ---------- @@ -44,31 +43,19 @@ class MonotonyResult: # ---------- narration extraction ---------- # Paragraph/sentence/dialogue segmentation lives in text_segmentation so every -# audit pass splits text identically. +# audit pass splits text identically. `_split_sentences` strips dialogue. - -def _split_sentences(text: str) -> list[str]: - """Paragraph-aware sentence splitter that strips dialogue (with DEBUG trace).""" - if DEBUG: - sys.stderr.write(f"[opening_monotony] splitting text: {repr(text)}\n") - sentences = split_narration_sentences(text) - if DEBUG: - sys.stderr.write(f"[opening_monotony] extracted sentences: {sentences}\n") - return sentences +_split_sentences = split_narration_sentences # ---------- opener analysis (unchanged logic) ---------- -def _normalize(word: str) -> str: - return re.sub(r"[^a-z0-9']", "", word.lower()) - - def _get_opener(sentence: str, n_words: int) -> str | None: words = sentence.split() if len(words) < n_words: return None - normalized = [_normalize(w) for w in words[:n_words]] + normalized = [normalize_word(w) for w in words[:n_words]] if any(w == "" for w in normalized): return None return " ".join(normalized) diff --git a/backend/passes/editor/template_repetition.py b/backend/passes/editor/template_repetition.py index 7931ff5..f4ae1ff 100644 --- a/backend/passes/editor/template_repetition.py +++ b/backend/passes/editor/template_repetition.py @@ -22,12 +22,11 @@ from __future__ import annotations import os -import re import sys from collections import defaultdict from dataclasses import dataclass, field -from .text_segmentation import split_narration_sentences +from .text_segmentation import normalize_word, split_narration_sentences DEBUG = "DEBUG_TEMPLATE_REPETITION" in os.environ @@ -59,27 +58,14 @@ class TemplateResult: # ---------- text processing ---------- # Paragraph/sentence/dialogue segmentation lives in text_segmentation so every -# audit pass splits text identically. +# audit pass splits text identically. `_split_sentences` strips dialogue. - -def _split_sentences(text: str) -> list[str]: - """Paragraph-aware sentence splitter that strips dialogue (with DEBUG trace).""" - if DEBUG: - sys.stderr.write(f"[template_repetition] splitting text: {repr(text)}\n") - sentences = split_narration_sentences(text) - if DEBUG: - sys.stderr.write(f"[template_repetition] extracted sentences: {sentences}\n") - return sentences +_split_sentences = split_narration_sentences # ---------- template analysis ---------- -def _normalize(word: str) -> str: - """Normalize a word for template matching.""" - return re.sub(r"[^a-z0-9']", "", word.lower()) - - def _get_template(sentence: str, max_words: int) -> str | None: """Extract the template (first N words) from a sentence.""" words = sentence.split() @@ -87,7 +73,7 @@ def _get_template(sentence: str, max_words: int) -> str | None: return None # Take up to max_words words template_words = words[:max_words] - normalized = [_normalize(w) for w in template_words] + normalized = [normalize_word(w) for w in template_words] # Filter out empty words after normalization normalized = [w for w in normalized if w] if len(normalized) < 3: diff --git a/backend/passes/editor/text_segmentation.py b/backend/passes/editor/text_segmentation.py index 5648de8..6060161 100644 --- a/backend/passes/editor/text_segmentation.py +++ b/backend/passes/editor/text_segmentation.py @@ -39,9 +39,22 @@ "split_narration_sentences", "find_quote_spans", "count_sentences", + "normalize_word", ] +# ---------- word normalization ---------- + + +def normalize_word(word: str) -> str: + """Lowercase a word and strip everything but ``a-z0-9'``. + + Shared by the opener/template scanners so they key on the same normalized + token form. + """ + return re.sub(r"[^a-z0-9']", "", word.lower()) + + # ---------- canonical patterns ---------- # Paragraph break: a blank line, optionally filled with whitespace. diff --git a/backend/passes/writer.py b/backend/passes/writer.py index b9cc4d0..4bac584 100644 --- a/backend/passes/writer.py +++ b/backend/passes/writer.py @@ -6,7 +6,7 @@ import json import logging -from typing import Any, AsyncIterator, Mapping, Optional, Sequence +from typing import Any, AsyncIterator, Mapping, Sequence from ..llm_client import LLMClient, reasoning_cfg from ..kv_tracker import CachedBase @@ -22,13 +22,15 @@ def build_writer_content( enabled_tools: Mapping[str, bool], effective_msg: str, attachments: Sequence[Mapping[str, Any]] | None, - length_guard_enforce: bool, length_guard: LengthGuard | None, ) -> "str | list[ContentPart]": """Build the writer's user-message content (string or multimodal list). - Extracted so the orchestrator can pass the exact value to the editor, - letting it replicate the writer's last user message for KV-cache reuse. + Built once by the orchestrator and threaded into both the writer pass and + the editor, which replays it verbatim to reuse the writer's KV-cached prefix. + The length-guard nudge is the *preventive* arm: it fires only in enforce mode + (``length_guard["enforce"]``); a non-None ``length_guard`` already means the + feature is enabled. """ tail = "" if lorebook_block: @@ -37,7 +39,7 @@ def build_writer_content( tail += "___\n\n" + inj_block + "\n\n" if enabled_tools: tail += "**Do not use tool or function calls this turn.**\n\n" - if length_guard_enforce and length_guard and length_guard["enabled"]: + if length_guard and length_guard["enforce"]: max_words = length_guard["max_words"] max_paragraphs = length_guard["max_paragraphs"] tail += f"**Keep your response under {max_words} words and {max_paragraphs} paragraphs.**\n\n" @@ -46,37 +48,21 @@ def build_writer_content( return build_multimodal_content(tail, attachments) -async def _writer_pass( +async def writer_pass( client: LLMClient, base: CachedBase, settings: Mapping[str, Any], - enabled_tools: Mapping[str, bool], + content: "str | list[ContentPart]", *, - inj_block: str = "", - lorebook_block: str = "", - effective_msg: str, - attachments: Optional[Sequence[Mapping[str, Any]]] = None, - length_guard_enforce: bool = False, - length_guard: LengthGuard | None = None, kv_tracker=None, reasoning_on: bool = True, ) -> AsyncIterator[dict]: """Yields {"type": "content"|"reasoning", "delta": str} dicts. - *enabled_tools* still drives the in-prompt "do not use tools" notice; the - tool *schema* blob comes from ``base`` (built from the same enabled-tool set) - so it stays byte-identical with the director/editor passes. + *content* is the writer's user-message body, prebuilt by the orchestrator via + ``build_writer_content`` (and shared with the editor). The tool *schema* blob + comes from ``base`` so it stays byte-identical with the director/editor passes. """ - content = build_writer_content( - lorebook_block, - inj_block, - enabled_tools, - effective_msg, - attachments, - length_guard_enforce, - length_guard, - ) - trailing: list[ChatMessage] = [{"role": "user", "content": content}] hyperparams = extract_hyperparams(settings) diff --git a/backend/utils.py b/backend/utils.py index 53d820b..021705f 100644 --- a/backend/utils.py +++ b/backend/utils.py @@ -12,16 +12,30 @@ class LengthGuard(TypedDict): """Resolved length-guard limits threaded through the pipeline. - Built by the orchestrator only when the length guard is enabled (``None`` - otherwise) and consumed by the writer and editor passes. ``enabled`` mirrors - that on/off state so a hook receiving the dict need not re-derive it. + Built by the orchestrator only when the length guard is enabled, so its mere + presence *is* the on/off state — ``None`` means disabled, and any non-None + value means enabled. Consumed by the writer (preventive nudge, only when + ``enforce``) and the editor (corrective rewrite). ``enforce`` carries the + enforce-mode flag so it travels with the limits instead of as a sidecar. """ - enabled: bool + enforce: bool max_words: int max_paragraphs: int +#: Heuristic characters-per-token ratio used for rough context-size estimates. +#: This is the one convention referenced throughout (see AGENTS.md → Context +#: Management); keep all chars→token estimation going through ``estimate_tokens`` +#: rather than re-spelling the constant. +CHARS_PER_TOKEN = 4 + + +def estimate_tokens(chars: int) -> int: + """Rough token estimate from a character count (min 1 for any non-empty text).""" + return max(1, round(chars / CHARS_PER_TOKEN)) + + def extract_hyperparams(settings: Mapping[str, Any], *, defaults: Mapping[str, Any] | None = None) -> dict: """Extract LLM hyperparameters from a settings dict. diff --git a/backend/workflows/_forced_call.py b/backend/workflows/_forced_call.py index daa7507..f6534ae 100644 --- a/backend/workflows/_forced_call.py +++ b/backend/workflows/_forced_call.py @@ -15,8 +15,8 @@ from types import MappingProxyType from typing import Any, AsyncIterator, Mapping, Sequence -from backend.llm_client import parse_tool_calls, reasoning_cfg -from backend.tool_defs import TOOLS, STANDALONE_TOOLS, enabled_schemas +from ..llm_client import parse_tool_calls, reasoning_cfg +from ..tool_defs import TOOLS, STANDALONE_TOOLS, enabled_schemas logger = logging.getLogger(__name__) diff --git a/backend/workflows/attachment_cache.py b/backend/workflows/attachment_cache.py index 2bd5944..dd60084 100644 --- a/backend/workflows/attachment_cache.py +++ b/backend/workflows/attachment_cache.py @@ -17,14 +17,21 @@ import os from typing import Any -from backend.database.connection import get_db -from backend.database.queries.workflow_attachments import _encode_metadata_field, insert_workflow_attachment_row +from ..database.connection import get_db +from ..database.queries.messages import register_workflow_attachment_persister +from ..database.queries.workflow_attachments import ( + EVICTED_MARKER, + _encode_metadata_field, + insert_workflow_attachment_row, +) from .registry import get_workflow logger = logging.getLogger(__name__) -EVICTED_MARKER = "[evicted]" +# EVICTED_MARKER is re-exported from the database boundary above (where it +# describes the persisted ``data_b64`` shape) so the eviction layer here and +# the route layer in main.py can keep importing it from this module. class RehydrateAlreadyDoneError(ValueError): @@ -949,3 +956,10 @@ async def delete_workflow_attachments( "root_id": new_root, "active_sibling_id": new_active, } + + +# Wire this module's batch persister into the database layer's add_message +# seam (dependency inversion -- the DB layer must not import up into +# backend.workflows). Registered at import; backend.workflows is always +# imported before any workflow attachment reaches add_message. +register_workflow_attachment_persister(insert_workflow_attachments) diff --git a/backend/workflows/registry.py b/backend/workflows/registry.py index 6fe0b24..638bad0 100644 --- a/backend/workflows/registry.py +++ b/backend/workflows/registry.py @@ -18,7 +18,7 @@ from dataclasses import dataclass, field from typing import Callable, Mapping, Optional -from backend.database import ( +from ..database import ( get_workflow_character_state as _db_get_workflow_character_state, get_workflow_config as _db_get_workflow_config, get_workflow_message_state as _db_get_workflow_message_state, @@ -28,7 +28,7 @@ set_workflow_message_state as _db_set_workflow_message_state, set_workflow_state as _db_set_workflow_state, ) -from backend.tool_defs import ( +from ..tool_defs import ( BUILTIN_TOOL_NAMES, STANDALONE_TOOLS, TOOLS, diff --git a/backend/workflows/toolkit.py b/backend/workflows/toolkit.py index 79fbb98..3d68c0c 100644 --- a/backend/workflows/toolkit.py +++ b/backend/workflows/toolkit.py @@ -20,7 +20,7 @@ from __future__ import annotations -from backend.database import ( +from ..database import ( get_character_card, get_conversation, get_director_fragments, @@ -31,18 +31,18 @@ get_phrase_bank, get_user_personas, ) -from backend.llm_client import LLMClient, parse_tool_calls, reasoning_cfg -from backend.macros import Macros -from backend.passes.editor.audit import format_report, run_audit -from backend.prompt_builder import ( +from ..llm_client import LLMClient, parse_tool_calls, reasoning_cfg +from ..macros import Macros +from ..passes.editor.audit import format_report, run_audit +from ..prompt_builder import ( build_prefix, compute_lorebook_injection_block, compute_style_injection_block, format_message_with_attachments, ) -from backend.tool_defs import STANDALONE_TOOLS, TOOLS, enabled_schemas +from ..tool_defs import STANDALONE_TOOLS, TOOLS, enabled_schemas -from backend.locks import ( +from ..locks import ( workflow_character_state_lock, workflow_config_lock, workflow_state_lock, diff --git a/backend/workflows/tts/__init__.py b/backend/workflows/tts/__init__.py index 65097a6..14a7877 100644 --- a/backend/workflows/tts/__init__.py +++ b/backend/workflows/tts/__init__.py @@ -10,7 +10,7 @@ from __future__ import annotations -from backend.workflows.registry import Workflow +from ..registry import Workflow _CONFIG_SCHEMA = { diff --git a/backend/workflows/tts/hooks.py b/backend/workflows/tts/hooks.py index b5704d6..f612460 100644 --- a/backend/workflows/tts/hooks.py +++ b/backend/workflows/tts/hooks.py @@ -20,7 +20,7 @@ import base64 import logging -from backend.workflows.toolkit import ( +from ..toolkit import ( get_message_by_id, get_workflow_character_state, get_workflow_config, diff --git a/docs/architecture/secondary-workflow.md b/docs/architecture/secondary-workflow.md index e2b4d8d..33295c9 100644 --- a/docs/architecture/secondary-workflow.md +++ b/docs/architecture/secondary-workflow.md @@ -404,7 +404,7 @@ Runs unconditionally (subject to each step's own guard): Then, only when `resp_text.strip()`: -3. `db.add_message(..., attachments=staged, ...)` -- single transaction. Internally lazy-imports `insert_workflow_attachments` from cache. Returns `(asst_id, rejected_workflow_atts)`. +3. `db.add_message(..., attachments=staged, ...)` -- single transaction. It persists workflow attachments by calling through a registered persister seam (the database layer must not import "up" into `backend.workflows`; `attachment_cache` registers `insert_workflow_attachments` via `register_workflow_attachment_persister` at import time). Returns `(asst_id, rejected_workflow_atts)`. 4. For each post-pipeline `set_message_state` entry, `db.set_workflow_message_state(asst_id, wid, payload)`. The assistant `mid` is first known here; unlocked because the row is not yet the active leaf and no other caller can name it. 5. `db.set_active_leaf(conversation_id, asst_id)`. diff --git a/tests/integration/_llm_mock.py b/tests/integration/_llm_mock.py index 7d63d22..f3a7c6c 100644 --- a/tests/integration/_llm_mock.py +++ b/tests/integration/_llm_mock.py @@ -114,10 +114,20 @@ class FakeLLMClient: """ def __init__(self) -> None: - self._queues: dict[str, list[dict]] = {"director": [], "writer": [], "editor": [], "workflow": []} - self._gates: dict[str, list[PassGate]] = {"director": [], "writer": [], "editor": [], "workflow": []} - # Mirror LLMClient: a wrapping _PlaceholderClient shares this token, so - # an abort signalled on either is visible to both. + self._queues: dict[str, list[dict]] = { + "director": [], + "writer": [], + "editor": [], + "workflow": [], + } + self._gates: dict[str, list[PassGate]] = { + "director": [], + "writer": [], + "editor": [], + "workflow": [], + } + # Mirror LLMClient: the turn's clients share one abort token, so an + # abort signalled on any of them is visible to all. self.abort_token = AbortToken() # Public assertion surface: tests inspect ``calls`` directly for # dispatch order and invocation counts, so its shape is part of @@ -228,7 +238,11 @@ async def complete( payload = self._queues["director"].pop(0) if self._queues["director"] else {"tool_calls": []} yield { "type": "done", - "message": {"role": "assistant", "content": "", "tool_calls": payload.get("tool_calls", [])}, + "message": { + "role": "assistant", + "content": "", + "tool_calls": payload.get("tool_calls", []), + }, } diff --git a/tests/integration/test_kv_cache_real_stack.py b/tests/integration/test_kv_cache_real_stack.py index 50bb7a1..e8788c7 100644 --- a/tests/integration/test_kv_cache_real_stack.py +++ b/tests/integration/test_kv_cache_real_stack.py @@ -68,8 +68,8 @@ async def _configure_all_features(client) -> None: async def _make_conversation(client) -> str: - # Macros in the card text exercise the per-pass macro-resolving client - # wrapper; they must resolve identically on every pass and every turn. + # Macros in the card text exercise the cached base's macro ``resolve`` hook; + # they must resolve identically on every pass and every turn. card = await client.post( "/api/characters", json={ @@ -130,6 +130,17 @@ async def test_within_turn_all_passes_share_prefix_and_tools_through_build_prefi "shared prefix — build_prefix or a pass rendered the system/history differently across passes." ) + # The base's macro ``resolve`` hook must scrub every {{char}}/{{user}} from + # the bytes each pass actually shipped — including the card text carried in + # the shared prefix. The recorded messages are post-resolution, so a raw + # placeholder surviving here means the hook was dropped. + for c in calls: + sent = _serialize_messages(c["messages"]) + assert ( + "{{char}}" not in sent and "{{user}}" not in sent + ), f"MACRO LEAK: pass {c['pass']!r} shipped an unresolved placeholder to the model." + assert "Aria" in prefix_bytes # {{char}} → the card name, resolved in the shared prefix + # Inv-3 — wire-faithful tools blob identical across every pass, non-empty. blobs = {_wire_tools(c["tools"]) for c in calls} assert len(blobs) == 1, f"CACHE BUST: tools blob differs across passes; distinct sizes {sorted(len(b) for b in blobs)}" @@ -178,7 +189,11 @@ async def test_cross_turn_prefix_is_append_only_through_persistence(client, llm_ # Sanity: the DB really did persist the turn-1 exchange. roles = [m["role"] for m in await get_messages(cid)] - assert roles[:3] == ["assistant", "user", "assistant"], f"unexpected persisted history: {roles}" + assert roles[:3] == [ + "assistant", + "user", + "assistant", + ], f"unexpected persisted history: {roles}" # Director's dynamic schema, rebuilt from get_director_fragments() each turn, # must be byte-identical across turns (this is the ONLY place a DB row-order diff --git a/tests/integration/workflows/test_pipeline_hooks.py b/tests/integration/workflows/test_pipeline_hooks.py index fc22728..d5d5d9e 100644 --- a/tests/integration/workflows/test_pipeline_hooks.py +++ b/tests/integration/workflows/test_pipeline_hooks.py @@ -487,7 +487,7 @@ async def post_hook(post_ctx): reroll_gen=lambda ctx, params, seed: b"", ) with register_for_test(w): - with patch("backend.orchestrator._writer_pass", new=mock_writer): + with patch("backend.orchestrator.writer_pass", new=mock_writer): events = await _drain( _run_pipeline( client, @@ -534,7 +534,7 @@ async def post_hook(post_ctx): reroll_gen=lambda ctx, params, seed: b"", ) with register_for_test(w): - with patch("backend.orchestrator._writer_pass", new=mock_writer): + with patch("backend.orchestrator.writer_pass", new=mock_writer): events = await _drain( _run_pipeline( client, @@ -564,7 +564,7 @@ async def post_hook(post_ctx): w = make_workflow("rewriter", post_pipeline=post_hook) with register_for_test(w): - with patch("backend.orchestrator._writer_pass", new=mock_writer): + with patch("backend.orchestrator.writer_pass", new=mock_writer): events = await _drain( _run_pipeline( client, @@ -623,7 +623,7 @@ async def post_hook(post_ctx): ): pass - with patch("backend.orchestrator._writer_pass", new=mock_writer): + with patch("backend.orchestrator.writer_pass", new=mock_writer): await _drain( _run_pipeline( client, @@ -658,7 +658,7 @@ async def post_hook(post_ctx): w = make_workflow("scratch_lifetime", post_pipeline=post_hook) with register_for_test(w): - with patch("backend.orchestrator._writer_pass", new=mock_writer): + with patch("backend.orchestrator.writer_pass", new=mock_writer): for _ in range(2): await _drain( _run_pipeline( @@ -687,7 +687,7 @@ async def test_run_pipeline_empty_registry_emits_single_result_no_staged(): async def mock_writer(c, *args, **kwargs): yield {"type": "content", "delta": "plain draft"} - with patch("backend.orchestrator._writer_pass", new=mock_writer): + with patch("backend.orchestrator.writer_pass", new=mock_writer): events = await _drain( _run_pipeline( client, @@ -720,7 +720,7 @@ async def crasher(post_ctx): w = make_workflow("crasher", post_pipeline=crasher) with register_for_test(w): - with patch("backend.orchestrator._writer_pass", new=mock_writer): + with patch("backend.orchestrator.writer_pass", new=mock_writer): events = await _drain( _run_pipeline( client, @@ -755,7 +755,7 @@ async def post_hook(post_ctx): w = make_workflow("never_runs", post_pipeline=post_hook) with register_for_test(w): - with patch("backend.orchestrator._writer_pass", new=mock_writer): + with patch("backend.orchestrator.writer_pass", new=mock_writer): events = await _drain( _run_pipeline( client, @@ -785,7 +785,7 @@ async def post_hook(post_ctx): w = make_workflow("ms", post_pipeline=post_hook) with register_for_test(w): - with patch("backend.orchestrator._writer_pass", new=mock_writer): + with patch("backend.orchestrator.writer_pass", new=mock_writer): events = await _drain( _run_pipeline( client, @@ -815,7 +815,7 @@ async def post_hook(post_ctx): w = make_workflow("ms", post_pipeline=post_hook) with register_for_test(w): - with patch("backend.orchestrator._writer_pass", new=mock_writer): + with patch("backend.orchestrator.writer_pass", new=mock_writer): events = await _drain( _run_pipeline( client, @@ -847,7 +847,7 @@ async def hook_b(post_ctx): wa = make_workflow("wf_a", post_pipeline=hook_a) wb = make_workflow("wf_b", post_pipeline=hook_b) with register_for_test(wa), register_for_test(wb): - with patch("backend.orchestrator._writer_pass", new=mock_writer): + with patch("backend.orchestrator.writer_pass", new=mock_writer): events = await _drain( _run_pipeline( client, @@ -877,7 +877,7 @@ async def post_hook(post_ctx): w = make_workflow("hist", post_pipeline=post_hook) with register_for_test(w): - with patch("backend.orchestrator._writer_pass", new=mock_writer): + with patch("backend.orchestrator.writer_pass", new=mock_writer): await _drain( _run_pipeline( client, @@ -913,7 +913,7 @@ async def post_hook(post_ctx): w = make_workflow("ms_persist", post_pipeline=post_hook) with register_for_test(w): - with patch("backend.orchestrator._writer_pass", new=mock_writer): + with patch("backend.orchestrator.writer_pass", new=mock_writer): pipeline = _run_pipeline( _make_client(), _SETTINGS, @@ -945,7 +945,7 @@ async def post_hook(post_ctx): w = make_workflow("ms_empty", post_pipeline=post_hook) with register_for_test(w): - with patch("backend.orchestrator._writer_pass", new=mock_writer): + with patch("backend.orchestrator.writer_pass", new=mock_writer): pipeline = _run_pipeline( _make_client(), _SETTINGS, diff --git a/tests/integration/workflows/test_tts_hooks.py b/tests/integration/workflows/test_tts_hooks.py index 6a45c40..352d4ea 100644 --- a/tests/integration/workflows/test_tts_hooks.py +++ b/tests/integration/workflows/test_tts_hooks.py @@ -124,7 +124,7 @@ async def test_run_pipeline_autogenerates_attachment_end_to_end(client, fake_ada async def mock_writer(c, *args, **kwargs): yield {"type": "content", "delta": '"Hello there."'} - with patch("backend.orchestrator._writer_pass", new=mock_writer): + with patch("backend.orchestrator.writer_pass", new=mock_writer): events = [ ev async for ev in _run_pipeline( diff --git a/tests/integration/workflows/test_workflow_character_fields.py b/tests/integration/workflows/test_workflow_character_fields.py index d0856d3..379d003 100644 --- a/tests/integration/workflows/test_workflow_character_fields.py +++ b/tests/integration/workflows/test_workflow_character_fields.py @@ -95,7 +95,7 @@ async def post_hook(post_ctx): w = make_workflow("cf_post", post_pipeline=post_hook) with register_for_test(w): - with patch("backend.orchestrator._writer_pass", new=mock_writer): + with patch("backend.orchestrator.writer_pass", new=mock_writer): await _drain( _run_pipeline( LLMClient("http://localhost:9999"), diff --git a/tests/unit/test_abort_pipeline.py b/tests/unit/test_abort_pipeline.py index 3cb392a..2e3ea5e 100644 --- a/tests/unit/test_abort_pipeline.py +++ b/tests/unit/test_abort_pipeline.py @@ -62,8 +62,8 @@ async def mock_writer(*args, **kwargs): "reasoning_enabled_passes": {}, } - with patch("backend.orchestrator._director_pass", new=mock_director), patch( - "backend.orchestrator._writer_pass", new=mock_writer + with patch("backend.orchestrator.director_pass", new=mock_director), patch( + "backend.orchestrator.writer_pass", new=mock_writer ): await _drain( _run_pipeline( @@ -101,7 +101,7 @@ async def mock_editor(*args, **kwargs): "reasoning_enabled_passes": {}, } - with patch("backend.orchestrator._writer_pass", new=mock_writer), patch( + with patch("backend.orchestrator.writer_pass", new=mock_writer), patch( "backend.orchestrator.editor_pass", new=mock_editor ): await _drain( diff --git a/tests/unit/test_kv_cache_invariants.py b/tests/unit/test_kv_cache_invariants.py index 54484e1..ff8fa58 100644 --- a/tests/unit/test_kv_cache_invariants.py +++ b/tests/unit/test_kv_cache_invariants.py @@ -114,7 +114,7 @@ class CapturingClient: def __init__(self, model: str) -> None: self.model = model self.calls: list[dict] = [] - # Shared with a wrapping _PlaceholderClient, mirroring LLMClient. + # The turn's clients share one abort token, mirroring LLMClient. self.abort_token = AbortToken() # FIFO of editor tool-call messages to return, one per ReAct iteration. # Empty → the editor returns no tool call and the loop stops. @@ -199,7 +199,10 @@ async def complete(self, messages, model, tools=None, tool_choice=None, **params if label == "writer": yield {"type": "content", "delta": _WRITER_DRAFT} - yield {"type": "done", "message": {"role": "assistant", "content": _WRITER_DRAFT}} + yield { + "type": "done", + "message": {"role": "assistant", "content": _WRITER_DRAFT}, + } return if label == "editor": @@ -216,7 +219,13 @@ async def complete(self, messages, model, tools=None, tool_choice=None, **params "message": { "role": "assistant", "content": "", - "tool_calls": [{"id": "c1", "type": "function", "function": {"name": name, "arguments": args}}], + "tool_calls": [ + { + "id": "c1", + "type": "function", + "function": {"name": name, "arguments": args}, + } + ], }, } @@ -397,7 +406,9 @@ async def test_single_model_prefix_and_tools_are_byte_identical_across_passes(): @pytest.mark.parametrize("system_prompt", ["You are a narrator.", "ANOTHER totally different system body."]) -async def test_dual_model_agent_passes_share_agent_prefix_and_writer_drops_tools(system_prompt): +async def test_dual_model_agent_passes_share_agent_prefix_and_writer_drops_tools( + system_prompt, +): """Dual-model: director+editor run on the agent server and must share the agent prefix + a byte-identical tools blob; the writer runs on its own server and must send NO tools (Inv-5).""" @@ -622,11 +633,20 @@ async def test_editor_tools_blob_constant_across_tool_switch(): client.enqueue_editor_rewrite(" ".join([f"A {banned}."] * 4)) settings = {"model_name": "editor-model", "editor_audit_toggles": None} - length_guard = {"enabled": True, "max_words": 5, "max_paragraphs": 1} + length_guard = {"enforce": False, "max_words": 5, "max_paragraphs": 1} # 3-tool enabled set so a narrow-to-one would be visible as a byte change. base = CachedBase( prefix=tuple(prefix), - tools=tuple(enabled_schemas({"direct_scene": True, "editor_apply_patch": True, "editor_rewrite": True}, {})), + tools=tuple( + enabled_schemas( + { + "direct_scene": True, + "editor_apply_patch": True, + "editor_rewrite": True, + }, + {}, + ) + ), model="editor-model", ) async for _ in editor_pass( @@ -656,5 +676,10 @@ async def test_editor_tools_blob_constant_across_tool_switch(): "schema list must stay byte-identical. Distinct blob sizes: " + json.dumps(sorted(len(b) for b in blobs)) ) # And it must genuinely be the full 3-tool set, not a coincidental match. - full_blob = _wire_tools(enabled_schemas({"direct_scene": True, "editor_apply_patch": True, "editor_rewrite": True}, {})) + full_blob = _wire_tools( + enabled_schemas( + {"direct_scene": True, "editor_apply_patch": True, "editor_rewrite": True}, + {}, + ) + ) assert next(iter(blobs)) == full_blob, "editor shipped a tools blob that is not the full enabled set"