diff --git a/CHANGELOG.md b/CHANGELOG.md index 41ad1c16..3191f99d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,54 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +### Changed + +- **Breaking (server mode only):** `amplifier-agent serve chat-completions` now requires `host_config.providers` to be a non-empty dict. Any provider declared there that cannot initialize (missing credentials, module not installed, `list_models()` raises, returns 0 models) causes the server to exit 2 with a structured error listing every problem. The previous behavior — iterating a hardcoded `KNOWN_PROVIDERS` list, silently skipping unreachable providers, and falling back to an unusable placeholder model — is gone. Single-turn mode (`amplifier-agent run`) is unaffected; the `provider` (singular) block continues to work for it. + +- **`POST /v1/chat/completions` now validates `model` against the served registry.** Requests with an unknown model return HTTP 400 `{"error": {"code": "unknown_model", ...}}` immediately, instead of being silently routed to whichever provider loaded first and failing 4 seconds later with an upstream `not_found_error` embedded in `delta.content`. + +- **`stream: false` is now honored.** Requests with that flag return a single JSON body; only `stream: true` (or absent) uses SSE. + +- **Upstream errors raised before any content chunks are emitted now surface as HTTP 502** with a structured OpenAI-shape error envelope, instead of being embedded inside `delta.content` of a 200 SSE response. + +- **`/v1/models` no longer falls back to a placeholder `{"id": "amplifier", ...}` entry.** The lifespan now guarantees `served_models_registry` is non-empty (or the server exits at boot), so the fallback was unreachable in practice. + +### Added + +- **`amplifier-agent serve status / stop / restart` subcommands** — operational lifecycle for the chat-completions HTTP server. 
Status reports whether the server is running, where it's reachable, how many models from which providers it's serving, and self-cleans stale state files when the PID no longer exists. Stop sends SIGTERM with a configurable graceful-exit window (`--timeout`), escalating to SIGKILL on expiry or on `--force`. Restart performs an identity-restart using the args stored at original launch (host, port, api-key, workspace, host_config). State is tracked in `~/.amplifier-agent/state/serve.json` (mode 0600, parent dir 0700; api_key is sensitive — never logged). + +- **`host_config.providers` (plural) registry** — declares which providers the server-mode lifespan loads and how to instantiate each. Schema: `providers: {<provider_id>: {module?: str, config?: dict}}`. The `module` defaults to the provider_id when omitted. Each provider's `config` is passed through as the `extra_config` arg to `list_provider_models()` and then to the provider module's constructor. + +### Internal + +- New `_validate_providers_registry()` in `amplifier_agent_lib/config/loader.py` enforces the closed schema for the new block. +- HTTP-face tests introduced from scratch under `tests/http/` covering lifespan boot scenarios and chat-completions validation. + +### Migration + +For server-mode users on `<= 0.8.0`: add a `providers` block to your `host_config.json`. Minimum to keep working with just Anthropic: + +```json +{ + "providers": { + "anthropic": {} + } +} +``` + +Multi-provider example: + +```json +{ + "providers": { + "anthropic": {}, + "openai": {"config": {"base_url": "https://api.openai.com/v1"}} + } +} +``` + +If you don't pass `host_config.providers`, the server will exit at boot with a clear error message rather than running in a broken half-state. 
+ ## [0.8.0] — 2026-06-20 Adds an OpenAI-compatible chat-completions HTTP face for embedding amplifier-agent in third-party tools (opencode and similar), a persistent `auth` subcommand for provider credentials, and integrates the model-routing matrix for per-provider model selection. Existing JSON-RPC wire protocol unchanged — no wrapper bump required. diff --git a/src/amplifier_agent_cli/admin/serve.py b/src/amplifier_agent_cli/admin/serve.py index af877a35..94473619 100644 --- a/src/amplifier_agent_cli/admin/serve.py +++ b/src/amplifier_agent_cli/admin/serve.py @@ -16,17 +16,31 @@ import logging import os +import signal from pathlib import Path import click import uvicorn +from amplifier_agent_cli.admin.serve_lifecycle import ( + remove_state_file, + restart_command, + status_command, + stop_command, +) + @click.group(name="serve") def serve_group() -> None: """Start a wire face for amplifier-agent.""" +# Register lifecycle subcommands on the group. +serve_group.add_command(status_command, name="status") +serve_group.add_command(stop_command, name="stop") +serve_group.add_command(restart_command, name="restart") + + @serve_group.command(name="chat-completions") @click.option( "--bind", @@ -144,6 +158,11 @@ def chat_completions( raise click.UsageError(f"--config path does not exist or is not a file: {resolved_config_path}") os.environ["AMPLIFIER_AGENT_HTTP_CONFIG_PATH"] = str(resolved_config_path) + # Expose host and port via env so load_config() can stash them in the + # state file (which the lifecycle commands read to know the wire address). + os.environ["AMPLIFIER_AGENT_HTTP_BIND"] = host + os.environ["AMPLIFIER_AGENT_HTTP_PORT"] = str(port) + # Resolve the values that will actually be used, so we can echo them # to stderr (handy for opencode.json setup). 
resolved_api_key = os.environ.get("AMPLIFIER_AGENT_HTTP_API_KEY", "local-dev-secret") @@ -165,6 +184,17 @@ def chat_completions( click.echo(f" Workspace: {resolved_workspace}", err=True) click.echo(f" Config: {resolved_config}", err=True) + # Belt-and-suspenders: remove the state file on SIGTERM/SIGINT from the + # outer process context. uvicorn handles the actual shutdown sequence; + # the lifespan's finally block is the primary cleanup path. These + # handlers ensure cleanup even if the lifespan teardown is skipped (e.g. + # when the server is killed before lifespan has finished setting up). + def _cleanup_state(_signum: int, _frame: object) -> None: + remove_state_file() + + signal.signal(signal.SIGTERM, _cleanup_state) + signal.signal(signal.SIGINT, _cleanup_state) + uvicorn.run( "amplifier_agent_http.app:app", host=host, diff --git a/src/amplifier_agent_cli/admin/serve_lifecycle.py b/src/amplifier_agent_cli/admin/serve_lifecycle.py new file mode 100644 index 00000000..55601806 --- /dev/null +++ b/src/amplifier_agent_cli/admin/serve_lifecycle.py @@ -0,0 +1,429 @@ +"""State file management and lifecycle commands for ``amplifier-agent serve``. + +Single global state file at ``~/.amplifier-agent/state/serve.json``. Records +the running server's PID, wire endpoint, credentials, workspace, and a +summary of served providers — enough for ``serve status / stop / restart`` +to operate without needing the original invocation context. + +The file is atomic-write (tempfile + os.replace), mode 0600, parent dir +0700. ``api_key`` is sensitive; never log it, never include it in error +messages, never leak it via process listings beyond the original invocation. + +This module is the SINGLE owner of the file path and schema. Callers go +through ``read_state_file``, ``write_state_file``, and +``remove_state_file`` — never touch the path directly. 
+""" + +from __future__ import annotations + +import json +import os +import shutil +import signal +import stat +import subprocess +import sys +import tempfile +import time +from pathlib import Path +from typing import Any + +import click +import httpx + +from amplifier_agent_lib.persistence import amplifier_agent_home + +# --------------------------------------------------------------------------- +# Constants +# --------------------------------------------------------------------------- + +SCHEMA_VERSION = 1 +_STATE_FILE_MODE = 0o600 +_STATE_DIR_MODE = 0o700 + + +def _state_dir() -> Path: + """Return the directory that holds ``serve.json``. + + Resolves via :func:`amplifier_agent_home` so tests can redirect with + ``AMPLIFIER_AGENT_HOME``. + """ + return amplifier_agent_home() / "state" + + +def _state_file() -> Path: + """Return the canonical path of the state file.""" + return _state_dir() / "serve.json" + + +# --------------------------------------------------------------------------- +# File-mode helpers +# --------------------------------------------------------------------------- + + +def _ensure_state_dir() -> Path: + """Create the state directory with mode 0700, return its path. + + Raises ``PermissionError`` with a clear message if mode enforcement + fails (rare; non-Unix filesystems that ignore chmod). + """ + d = _state_dir() + d.mkdir(parents=True, exist_ok=True) + try: + d.chmod(_STATE_DIR_MODE) + except NotImplementedError as exc: + raise PermissionError( + f"Cannot set directory permissions on {d}. " + "Your filesystem may not support Unix mode bits. " + "The state file (which contains a sensitive api_key) cannot be " + "written safely without mode 0700 on the parent directory." + ) from exc + # Verify the mode was actually applied (some networked/virtual FSes ignore chmod). + actual = stat.S_IMODE(d.stat().st_mode) + if actual != _STATE_DIR_MODE: + raise PermissionError( + f"Failed to set mode 0700 on {d} (got {oct(actual)}). 
def write_state_file(payload: dict[str, Any]) -> None:
    """Write ``payload`` atomically to the state file (mode 0600, dir 0700).

    The payload is stamped with ``schema_version``, serialized as indented
    JSON, written to a temp file created in the state directory, and then
    moved over the target with ``os.replace``. The temp file is chmod'd to
    0600 and the resulting mode is verified *before* any payload bytes are
    written, so the sensitive ``api_key`` field is never stored at a
    more-permissive mode.

    If any step fails, the temp file is removed again so failed writes do not
    accumulate ``.serve-*.json.tmp`` files next to the state file.

    Raises:
        PermissionError: if the file mode cannot be set or did not take effect.
        OSError: if writing or renaming the temp file fails.
    """
    d = _ensure_state_dir()
    payload = {**payload, "schema_version": SCHEMA_VERSION}
    encoded = json.dumps(payload, indent=2).encode("utf-8")

    # Same directory as the target so os.replace stays on one filesystem.
    fd, tmp_path = tempfile.mkstemp(dir=d, prefix=".serve-", suffix=".json.tmp")
    try:
        try:
            try:
                os.chmod(tmp_path, _STATE_FILE_MODE)
            except NotImplementedError as exc:
                raise PermissionError(
                    f"Cannot set mode 0600 on {tmp_path}. "
                    "Your filesystem may not support Unix mode bits. "
                    "Refusing to write api_key in plaintext without permission enforcement."
                ) from exc
            # Verify enforcement before writing the sensitive payload.
            actual = stat.S_IMODE(os.stat(tmp_path).st_mode)
            if actual != _STATE_FILE_MODE:
                raise PermissionError(
                    f"Failed to set mode 0600 on {tmp_path} (got {oct(actual)}). "
                    "Refusing to write api_key in plaintext without enforced file permissions."
                )
            # os.write may accept fewer bytes than offered; loop until drained.
            remaining = memoryview(encoded)
            while remaining:
                written = os.write(fd, remaining)
                remaining = remaining[written:]
        finally:
            os.close(fd)

        os.replace(tmp_path, _state_file())
    except BaseException:
        # Best-effort cleanup; the original failure is what the caller must see.
        try:
            os.unlink(tmp_path)
        except OSError:
            pass
        raise
def is_pid_alive(pid: int) -> bool:
    """Return True if a single process with id ``pid`` currently exists.

    Uses ``os.kill(pid, 0)``: signal 0 performs the existence/permission
    check without delivering a signal.

    * ``ProcessLookupError`` -> the process is gone -> ``False``.
    * ``PermissionError`` -> the process exists but belongs to another
      user -> ``True``.

    Values that cannot name one process are reported as not alive instead of
    being handed to ``os.kill``. Under POSIX ``kill(2)`` a pid of 0 or a
    negative pid addresses process groups, so a corrupt state file with
    ``"pid": 0`` would otherwise look alive and callers would go on to signal
    their own process group. Booleans, non-integers, and integers too large
    for the platform pid type are rejected the same way.

    NOTE(review): this is a POSIX idiom; confirm behaviour before relying on
    it on Windows.
    """
    if isinstance(pid, bool) or not isinstance(pid, int) or pid <= 0:
        return False
    try:
        os.kill(pid, 0)
    except ProcessLookupError:
        return False
    except PermissionError:
        # Process exists, owned by another user.
        return True
    except OverflowError:
        # Larger than the C pid type can hold -- cannot be a real pid.
        return False
    return True
@click.command(name="status")
def status_command() -> None:
    """Report whether the chat-completions server is running.

    Reads the state file, checks that the recorded PID is still alive, and
    probes ``GET /v1/models`` (2 s timeout, bearer auth with the stored
    ``api_key``) to confirm the server answers over the wire.

    Exit codes:
        0 -- no server recorded, a stale state file was cleaned up, or the
             probe succeeded.
        1 -- the recorded process exists but the probe failed.
    """
    state = read_state_file()

    if state is None:
        click.echo("amplifier-agent serve: not running")
        raise SystemExit(0)

    pid: int = state["pid"]

    if not is_pid_alive(pid):
        # Remove first so "cleaned" is only reported once it is true.
        remove_state_file()
        click.echo(f"amplifier-agent serve: stale state file (PID {pid} no longer exists) — cleaned")
        raise SystemExit(0)

    host: str = state["host"]
    port: int = state["port"]
    api_key: str = state["api_key"]
    workspace: str = state.get("workspace") or "(cwd-derived)"
    # ``or {}`` also covers an explicit null in the file, which would break sum().
    providers_summary: dict[str, int] = state.get("providers_summary") or {}

    # Probe the wire endpoint. httpx.HTTPError is the common base of transport
    # failures (connect, timeout, read, protocol) and HTTPStatusError, so any
    # way the probe can fail is reported as "not responding" rather than
    # surfacing as a traceback.
    try:
        httpx.get(
            f"http://{host}:{port}/v1/models",
            headers={"Authorization": f"Bearer {api_key}"},
            timeout=2.0,
        ).raise_for_status()
    except httpx.HTTPError:
        click.echo(f"amplifier-agent serve: running (PID {pid}) — but http://{host}:{port}/v1/models not responding")
        raise SystemExit(1) from None

    total = sum(providers_summary.values())
    click.echo(f"amplifier-agent serve: running at http://{host}:{port}/v1/ (PID {pid}, workspace={workspace})")
    click.echo(f"  models: {total} total")
    click.echo("  by provider:")
    for provider_id, count in providers_summary.items():
        click.echo(f"    {provider_id}: {count}")

    raise SystemExit(0)
@click.command(name="restart")
def restart_command() -> None:
    """Restart the chat-completions server using the stored launch args.

    Reads host, port, api-key, workspace and host_config_path from the
    existing state file, stops the recorded process (SIGTERM, then SIGKILL
    after 5 s), removes the state file, and re-launches
    ``serve chat-completions`` as a detached process with its stdout/stderr
    discarded. It then polls for up to 30 s for a state file carrying a
    different PID, which signals that the new server finished starting.

    Exit codes:
        0 -- a new state file appeared.
        1 -- nothing to restart, the old process could not be signalled, the
             new process exited before becoming ready, or 30 s elapsed.
    """
    state = read_state_file()

    if state is None:
        click.echo("amplifier-agent serve: not running — nothing to restart", err=True)
        raise SystemExit(1)

    # Capture launch args before stopping (the state file is removed below).
    host: str = state["host"]
    port: int = state["port"]
    api_key: str = state["api_key"]
    workspace: str | None = state.get("workspace")
    host_config_path: str | None = state.get("host_config_path")
    old_pid: int = state["pid"]

    # Stop the running server: SIGTERM, wait, escalate to SIGKILL.
    try:
        if is_pid_alive(old_pid):
            os.kill(old_pid, signal.SIGTERM)
        if not wait_for_exit(old_pid, timeout=5.0):
            if is_pid_alive(old_pid):
                os.kill(old_pid, signal.SIGKILL)
            wait_for_exit(old_pid, timeout=2.0)
    except ProcessLookupError:
        # The process exited between the liveness check and the signal.
        pass
    except PermissionError as exc:
        raise click.ClickException(
            f"Cannot signal PID {old_pid}: the process belongs to another user. "
            "The server was left running and the state file was not touched."
        ) from exc
    remove_state_file()

    # Reconstruct launch command from stored args.
    # NOTE(review): passing the key via --api-key makes it visible in the
    # process list for the lifetime of the new server. Consider handing it to
    # the child through its environment instead -- confirm how
    # `serve chat-completions` resolves AMPLIFIER_AGENT_HTTP_API_KEY first.
    cmd = [
        *_resolve_amplifier_agent_executable(),
        "serve",
        "chat-completions",
        "--bind",
        host,
        "--port",
        str(port),
        "--api-key",
        api_key,
    ]
    if workspace:
        cmd.extend(["--workspace", workspace])
    if host_config_path:
        cmd.extend(["--config", host_config_path])

    # Launch detached. subprocess.DEVNULL discards the child's output without
    # this process opening (and then leaking) its own handle on os.devnull.
    proc = subprocess.Popen(
        cmd,
        start_new_session=True,
        stdout=subprocess.DEVNULL,
        stderr=subprocess.DEVNULL,
    )

    # Wait for the new state file to appear, indicating a successful lifespan.
    deadline = time.monotonic() + 30.0
    while time.monotonic() < deadline:
        try:
            new_state = read_state_file()
        except click.ClickException:
            # Unreadable right now (e.g. unexpected contents); keep polling.
            new_state = None
        new_pid = new_state.get("pid") if new_state is not None else None
        if new_pid is not None and new_pid != old_pid:
            click.echo(f"amplifier-agent serve: restarted at http://{host}:{port}/v1/ (new PID {new_pid})")
            raise SystemExit(0)

        # Fail fast when the child has already died (bad config, port in use,
        # ...) instead of sitting out the rest of the 30 s window.
        exit_code = proc.poll()
        if exit_code is not None:
            click.echo(
                f"amplifier-agent serve: restart failed — new server exited with code {exit_code} before becoming ready",
                err=True,
            )
            raise SystemExit(1)
        time.sleep(0.2)

    click.echo(
        "amplifier-agent serve: restart launched but new server did not become ready within 30s — check logs",
        err=True,
    )
    raise SystemExit(1)
@@ -68,10 +82,18 @@ class ServerConfig: def load_config() -> ServerConfig: """Load ServerConfig from environment.""" + raw_port = os.environ.get("AMPLIFIER_AGENT_HTTP_PORT", "9099") + try: + port = int(raw_port) + except ValueError: + port = 9099 + return ServerConfig( api_key=os.environ.get("AMPLIFIER_AGENT_HTTP_API_KEY", "local-dev-secret"), model_id=os.environ.get("AMPLIFIER_AGENT_HTTP_MODEL_ID", "amplifier"), model_display_name=os.environ.get("AMPLIFIER_AGENT_HTTP_MODEL_NAME", "Amplifier"), + host=os.environ.get("AMPLIFIER_AGENT_HTTP_BIND", "127.0.0.1"), + port=port, # Prefer the HTTP-face-specific env var when set; fall back to the # ecosystem-shared one (which the CLI also reads via # persistence.resolve_workspace). Empty / whitespace = unset. diff --git a/src/amplifier_agent_http/_wire.py b/src/amplifier_agent_http/_wire.py index 3734b28f..b801e90c 100644 --- a/src/amplifier_agent_http/_wire.py +++ b/src/amplifier_agent_http/_wire.py @@ -63,9 +63,10 @@ class ChatCompletionRequest(BaseModel): model: str messages: list[ChatMessage] - stream: bool = False - """Non-streaming requests are accepted but the POC always emits SSE - internally and buffers if needed. opencode always streams.""" + stream: bool | None = None + """Explicit streaming flag. ``True`` → SSE; ``False`` → single JSON body. + ``None`` (field absent) is treated as ``True`` for backward compatibility + so existing clients that omit the field continue to get SSE.""" tools: list[ToolDefinition] | None = None """Host-provided tools. 
Accepted but ignored in the POC -- amplifier never diff --git a/src/amplifier_agent_http/app.py b/src/amplifier_agent_http/app.py index 5f61ce3b..aeca5c01 100644 --- a/src/amplifier_agent_http/app.py +++ b/src/amplifier_agent_http/app.py @@ -13,8 +13,11 @@ import asyncio import logging +import os +import sys from collections.abc import AsyncIterator from contextlib import asynccontextmanager +from datetime import UTC, datetime from importlib.metadata import PackageNotFoundError from importlib.metadata import version as _pkg_version from pathlib import Path @@ -27,7 +30,6 @@ ProviderModuleNotInstalledError, list_provider_models, ) -from amplifier_agent_cli.provider_sources import KNOWN_PROVIDERS from amplifier_agent_http._config import load_config from amplifier_agent_http._session_runner import hydrate_agent_configs from amplifier_agent_http.routes import chat_completions, models @@ -142,76 +144,104 @@ async def lifespan(app: FastAPI) -> AsyncIterator[None]: app.state.resolved_workspace = resolved_workspace app.state.agent_configs = hydrate_agent_configs(prepared) - # Eager-load every reachable provider's model list. Iterates the CLI's - # ``KNOWN_PROVIDERS`` catalog and calls ``list_provider_models`` for each - # provider whose credentials are present in the environment; the typed - # exceptions ``ProviderCredentialsMissingError`` and - # ``ProviderModuleNotInstalledError`` tell us to skip silently rather - # than abort startup. Each model dict is tagged with ``_provider`` so - # ``chat_completions`` can route the per-request injection. + # Explicit per-provider model enumeration. ``host_config.providers`` is + # authoritative: we load exactly the providers declared there, fail loudly + # on any that cannot initialize. The previous behavior (iterate + # KNOWN_PROVIDERS, skip silently, fall back to a placeholder) is removed. 
# - # ``list_provider_models`` is a sync function that may call ``asyncio.run()`` - # internally for async providers; we wrap each call in ``to_thread`` so - # the lifespan event loop is not blocked. Failures are non-fatal -- - # ``/v1/models`` falls back to advertising the configured ``model_id`` - # alone when nothing could be enumerated. + # ``list_provider_models`` is sync and may call ``asyncio.run()`` internally; + # wrap each call in ``to_thread`` so the lifespan event loop is not blocked. # served_models_registry: maps wire model id -> provider id (e.g. "anthropic") # so chat_completions can route the per-request inject_provider() call. + providers_block = (app.state.host_config or {}).get("providers") + + if not isinstance(providers_block, dict) or not providers_block: + logger.error( + "amplifier-agent serve chat-completions requires `host_config.providers` " + "to be a non-empty dict. Declare at least one provider explicitly. " + "There is no implicit registry — KNOWN_PROVIDERS is no longer iterated." 
+ ) + sys.exit(2) + app.state.available_models = [] app.state.served_models_registry = {} - for provider_id in KNOWN_PROVIDERS: + errors: list[str] = [] + + for provider_id, entry in providers_block.items(): + module_id = entry.get("module", provider_id) + provider_config = entry.get("config", {}) try: - models = await asyncio.to_thread( + provider_models = await asyncio.to_thread( list_provider_models, - provider_id, + module_id, 15.0, + provider_config, # extra_config — passes per-provider config through ) - except ProviderCredentialsMissingError: - logger.info( - "Skipping provider %r -- no credentials in env", - provider_id, - ) + except ProviderCredentialsMissingError as exc: + errors.append(f"provider {provider_id!r}: credentials missing — {exc}") continue - except ProviderModuleNotInstalledError: - logger.info( - "Skipping provider %r -- module not installed", - provider_id, - ) + except ProviderModuleNotInstalledError as exc: + errors.append(f"provider {provider_id!r}: module {module_id!r} not installed — {exc}") continue except Exception as exc: - logger.warning( - "Could not enumerate models from provider %r (%s: %s)", - provider_id, - type(exc).__name__, - exc, - ) + errors.append(f"provider {provider_id!r}: failed to enumerate models — {type(exc).__name__}: {exc}") continue - for m in models: + if not provider_models: + errors.append(f"provider {provider_id!r}: list_models() returned 0 models") + continue + for m in provider_models: d = m.model_dump() if hasattr(m, "model_dump") else dict(m) d["_provider"] = provider_id app.state.available_models.append(d) app.state.served_models_registry[d["id"]] = provider_id - logger.info( - "Loaded %d models from provider %r", - len(models), - provider_id, - ) + logger.info("Loaded %d models from provider %r", len(provider_models), provider_id) - if not app.state.available_models: - logger.warning( - "No providers could be enumerated. 
/v1/models will advertise only %r.", - config.model_id, + if errors: + logger.error( + "amplifier-agent serve failed to initialize %d of %d declared providers:\n %s", + len(errors), + len(providers_block), + "\n ".join(errors), ) + sys.exit(2) logger.info( - "Prepared bundle loaded with provider; %d agents hydrated. Ready to serve.", + "Prepared bundle loaded with providers; %d agents hydrated. Ready to serve.", len(app.state.agent_configs), ) + # Write the state file so lifecycle commands (serve status / stop / restart) + # can discover the running server without re-parsing CLI flags. + # providers_summary maps provider_id -> model count for the status display. + from amplifier_agent_cli.admin.serve_lifecycle import ( + remove_state_file, + write_state_file, + ) + + providers_summary: dict[str, int] = {} + for m in app.state.available_models: + pid_ = m.get("_provider", "unknown") + providers_summary[pid_] = providers_summary.get(pid_, 0) + 1 + + write_state_file( + { + "pid": os.getpid(), + "started_at": datetime.now(UTC).isoformat().replace("+00:00", "Z"), + "host": app.state.config.host, + "port": app.state.config.port, + "api_key": app.state.config.api_key, + "workspace": app.state.resolved_workspace, + "host_config_path": app.state.config.host_config_path or None, + "providers_summary": providers_summary, + } + ) + logger.info("State file written; server is discoverable via 'amplifier-agent serve status'.") + try: yield finally: logger.info("amplifier-agent HTTP face shutting down") + remove_state_file() def build_app() -> FastAPI: diff --git a/src/amplifier_agent_http/routes/chat_completions.py b/src/amplifier_agent_http/routes/chat_completions.py index e746131d..d74656b4 100644 --- a/src/amplifier_agent_http/routes/chat_completions.py +++ b/src/amplifier_agent_http/routes/chat_completions.py @@ -24,12 +24,13 @@ import asyncio import json import logging +import time from collections.abc import AsyncGenerator from decimal import Decimal, InvalidOperation from 
typing import Any from fastapi import APIRouter, Depends, HTTPException, Request, status -from fastapi.responses import StreamingResponse +from fastapi.responses import JSONResponse, StreamingResponse from amplifier_agent_http._auth import require_bearer from amplifier_agent_http._event_translator import extract_usage, translate_event @@ -282,39 +283,30 @@ def _extract_text(msg: ChatMessage) -> str: async def _stream_chat_completion( *, - prepared: Any, - agent_configs: dict[str, dict[str, Any]], - history: list[dict[str, Any]], - prompt: str, chunk_id: str, model_id: str, - tools: list[dict[str, Any]] | None = None, - workspace: str | None = None, - provider_id: str = "anthropic", - upstream_model: str | None = None, + turn_task: asyncio.Task[str], + signal_task: asyncio.Task[None], + event_queue: asyncio.Queue[Any], + display: HttpQueueDisplaySystem, + host_tool_yield_state: dict[str, Any], ) -> AsyncGenerator[str, None]: - # Shared mutable state for the host-tool hook to signal yields back to us - # WITHOUT depending on BaseException subclass preservation across the - # kernel's session.execute() bridge (which wraps everything as RuntimeError). - # The hook writes ``yielded=True`` plus the tool name/id when a host-tool's - # tool:pre fires. We read it after turn_task settles to pick finish_reason. - host_tool_yield_state: dict[str, Any] = {"yielded": False, "tool_name": "", "tool_call_id": ""} """Drive a single chat completion and yield SSE chunks. - This is the heart of Slice 2. It coordinates: - 1. Setting up the display queue. - 2. Spawning the turn task. - 3. Draining the queue -> translating events -> yielding SSE chunks. - 4. Joining the turn task and emitting the final chunk. - 5. Cleaning up on cancellation. 
+ Accepts a pre-started ``turn_task`` and its associated infrastructure + (event_queue, display, signal_task, host_tool_yield_state) from the caller + so the caller can perform a pre-flight error check (Edit C) BEFORE returning + a StreamingResponse to FastAPI. If the caller detects an immediate failure + it raises HTTPException(502) itself -- by the time this generator is + iterated the HTTP 200 headers are already committed and no status change + is possible. + + The generator coordinates: + 1. Yielding the role chunk to open the SSE stream. + 2. Draining the event_queue -> translating events -> yielding SSE chunks. + 3. Joining the turn task and emitting the final stop/tool_calls chunk. + 4. Cleaning up on cancellation. """ - # Per-request queue: each emit-able event is one slot. ``maxsize=0`` = - # unbounded; for the POC this is fine. If we observe memory pressure - # under burst loads in Slice 3, we can bound this and shed events. - event_queue: asyncio.Queue[Any] = asyncio.Queue() - display = HttpQueueDisplaySystem(event_queue) - approval = HttpAutoApprovalSystem() - # Accumulate usage across multiple kernel ``usage`` events. A single turn # may make several internal LLM calls (e.g. subagent delegation, retry on # tool error) -- emitting only the last one understates total cost. @@ -343,37 +335,6 @@ async def _stream_chat_completion( # with no content, matching every other OpenAI-compatible provider. yield sse_data(role_chunk(chunk_id, model_id)) - # Spawn the turn task. It runs concurrently with our drain loop. 
- turn_task: asyncio.Task[str] = asyncio.create_task( - run_chat_turn( - prepared=prepared, - agent_configs=agent_configs, - history=history, - prompt=prompt, - display=display, - approval=approval, - tools=tools, - host_tool_yield_state=host_tool_yield_state, - workspace=workspace, - provider_id=provider_id, - upstream_model=upstream_model, - ) - ) - - # Watcher coroutine: when the turn task finishes (success or failure), - # post the sentinel to wake our drain loop. Avoids polling. - async def _signal_done() -> None: - try: - await asyncio.shield(turn_task) - except BaseException: - # Errors are handled in the main flow when we ``await turn_task``. - # Here we just need to wake the drain loop. - pass - finally: - display.close() - - signal_task = asyncio.create_task(_signal_done()) - try: # Drain loop: pump events until the sentinel arrives. ``asyncio.wait_for`` # bounds each ``queue.get()`` so we can emit SSE keepalive comments @@ -446,7 +407,8 @@ async def _signal_done() -> None: type(exc).__name__, ) else: - # Surface as inline text + log -- proper error envelope is v2. + # chunks_emitted is True here (role chunk already sent) so we + # cannot raise HTTP 502 -- embed the error in delta.content. logger.exception("turn task raised: %s", exc) err_chunk = content_delta_chunk( chunk_id, @@ -500,14 +462,79 @@ async def _signal_done() -> None: display.close() -@router.post("/v1/chat/completions", dependencies=[Depends(require_bearer)]) -async def chat_completions(payload: ChatCompletionRequest, request: Request) -> StreamingResponse: - """Streaming chat completion endpoint. - - Slice 2: real AmplifierSession execution. the request's ``stream`` field is - effectively ignored -- we always stream because that's the only path the - kernel exposes via display.emit. If a client requests stream=false in - the future we should buffer and return a JSON ChatCompletion -- Slice 3+. 
+async def _collect_completion( + gen: AsyncGenerator[str, None], + *, + chunk_id: str, + model: str, +) -> dict[str, Any]: + """Buffer a streaming generator into a single non-streaming ChatCompletion. + + Consumes all SSE strings from ``gen``, parses each data line, accumulates + assistant content from delta chunks, and extracts finish_reason and usage + from the terminal stop chunk. The returned dict matches the OpenAI + ``chat.completion`` (non-streaming) shape. + """ + content_parts: list[str] = [] + finish_reason: str = "stop" + usage_block: dict[str, Any] | None = None + created = int(time.time()) + + async for sse_str in gen: + for line in sse_str.splitlines(): + if not line.startswith("data: "): + continue + payload_str = line[6:] + if payload_str == "[DONE]": + continue + try: + chunk_obj = json.loads(payload_str) + except json.JSONDecodeError: + continue + for choice in chunk_obj.get("choices", []): + delta = choice.get("delta", {}) + if isinstance(delta.get("content"), str) and delta["content"]: + content_parts.append(delta["content"]) + fr = choice.get("finish_reason") + if fr: + finish_reason = fr + if "usage" in chunk_obj: + usage_block = chunk_obj["usage"] + + return { + "id": chunk_id, + "object": "chat.completion", + "created": created, + "model": model, + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "".join(content_parts), + }, + "finish_reason": finish_reason, + } + ], + "usage": usage_block or {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}, + } + + +@router.post("/v1/chat/completions", dependencies=[Depends(require_bearer)], response_model=None) +async def chat_completions( + payload: ChatCompletionRequest, + request: Request, +) -> StreamingResponse | JSONResponse: + """Chat completion endpoint — streaming (SSE) or non-streaming (JSON). + + ``stream: true`` (or absent/null) → Server-Sent Events, Content-Type: + text/event-stream. 
``stream: false`` → single JSON body, Content-Type: + application/json, matching the OpenAI non-streaming chat.completion shape. + + Upstream errors raised before any content has been produced surface as + HTTP 502 with a structured OpenAI-shape error envelope. Errors that occur + mid-stream (after the role chunk has been emitted) are embedded in + ``delta.content`` — there is no other option once SSE has started. """ config = request.app.state.config prepared = getattr(request.app.state, "prepared", None) @@ -526,21 +553,25 @@ async def chat_completions(payload: ChatCompletionRequest, request: Request) -> }, ) - # Look up which provider serves this model. The registry is built at - # lifespan from the CLI's KNOWN_PROVIDERS catalog (one entry per provider - # whose credentials are present in env). Fall back to "anthropic" when - # the requested model is not in the registry -- preserves the existing - # behavior (warn + serve) for clients that send a raw model name like - # "amplifier" rather than a model id from /v1/models. - served_registry: dict[str, str] = getattr(request.app.state, "served_models_registry", {}) - provider_id = served_registry.get(payload.model, "anthropic") - if payload.model not in served_registry: - logger.warning( - "Request model=%r not in served_models_registry (%d entries); " - "falling back to provider=%r with the bundle's default upstream model.", - payload.model, - len(served_registry), - provider_id, + # Look up which provider serves this model. The registry is built at + # lifespan from ``host_config.providers`` (one entry per provider that + # successfully enumerated models). An unknown model is a hard 400 -- + # there is no silent fallback to a hardcoded provider. 
+ served_registry: dict[str, str] = getattr(request.app.state, "served_models_registry", {}) or {} + provider_id = served_registry.get(payload.model) + if provider_id is None: + raise HTTPException( + status_code=400, + detail={ + "error": { + "type": "invalid_request_error", + "code": "unknown_model", + "message": ( + f"model {payload.model!r} is not served by this instance. " + "Call GET /v1/models for the list of served models." + ), + } + }, ) history, prompt = _split_history_and_prompt(payload.messages) @@ -594,19 +625,91 @@ async def chat_completions(payload: ChatCompletionRequest, request: Request) -> client_session_id, ) + # Edit C: set up the turn infrastructure HERE (in the route handler, not in + # the async generator) so we can detect immediate initialization failures + # BEFORE returning a StreamingResponse. Once FastAPI returns a + # StreamingResponse object, Starlette commits the HTTP 200 status line + # before iterating the body generator, making it impossible to switch to 502. + # By doing the pre-flight check here we still have a clean slate. + event_queue: asyncio.Queue[Any] = asyncio.Queue() + display = HttpQueueDisplaySystem(event_queue) + approval = HttpAutoApprovalSystem() + host_tool_yield_state: dict[str, Any] = {"yielded": False, "tool_name": "", "tool_call_id": ""} + + turn_task: asyncio.Task[Any] = asyncio.create_task( + run_chat_turn( + prepared=prepared, + agent_configs=agent_configs, + history=history, + prompt=prompt, + display=display, + approval=approval, + tools=tools_payload, + host_tool_yield_state=host_tool_yield_state, + workspace=workspace, + provider_id=provider_id, + upstream_model=payload.model, + ) + ) + + # Pre-flight: give the task a brief window (50 ms) to fail immediately. + # An immediately-failing coroutine (mock with side_effect, or a provider + # that raises before its first IO await) completes well within 50 ms. + # Normal turns are waiting on an LLM response so they remain pending. 
+ _PREFLIGHT_TIMEOUT_SECONDS: float = 0.05 + done, _ = await asyncio.wait([turn_task], timeout=_PREFLIGHT_TIMEOUT_SECONDS) + if turn_task in done: + try: + await turn_task + except asyncio.CancelledError: + raise + except Exception as exc: + raise HTTPException( + status_code=502, + detail={ + "error": { + "type": "upstream_error", + "code": "upstream_error", + "message": (f"Provider initialization failed: {type(exc).__name__}: {exc}"), + } + }, + ) from exc + + # Watcher: when turn_task finishes, post the sentinel to wake the drain loop. + async def _signal_done() -> None: + try: + await asyncio.shield(turn_task) + except BaseException: + pass + finally: + display.close() + + signal_task: asyncio.Task[None] = asyncio.create_task(_signal_done()) + generator = _stream_chat_completion( - prepared=prepared, - agent_configs=agent_configs, - history=history, - prompt=prompt, chunk_id=chunk_id, model_id=config.model_id, - tools=tools_payload, - workspace=workspace, - provider_id=provider_id, - upstream_model=payload.model, + turn_task=turn_task, + signal_task=signal_task, + event_queue=event_queue, + display=display, + host_tool_yield_state=host_tool_yield_state, ) + # Edit B: honor the ``stream`` flag. + # ``stream: false`` → buffer all SSE chunks and return a single JSON body. + # ``stream: true`` → SSE streaming (the original path). + # ``stream: null`` → treated as ``true`` for backward compatibility + # (clients that omit the field get SSE, matching the + # behavior before this flag was honored). 
+ if payload.stream is False: + completion = await _collect_completion( + generator, + chunk_id=chunk_id, + model=payload.model, + ) + return JSONResponse(content=completion) + return StreamingResponse( generator, media_type="text/event-stream", diff --git a/src/amplifier_agent_http/routes/models.py b/src/amplifier_agent_http/routes/models.py index 86dda13e..e5a166c0 100644 --- a/src/amplifier_agent_http/routes/models.py +++ b/src/amplifier_agent_http/routes/models.py @@ -58,7 +58,7 @@ def _to_openai_entry(model_obj: Any, *, now: int) -> dict[str, Any]: "created": now, "owned_by": "amplifier-agent", } - # _provider is set by app.py lifespan when iterating KNOWN_PROVIDERS. + # _provider is set by app.py lifespan when iterating the providers registry. # Surface it so clients can see which provider serves each model. The # field is non-standard but additive (standard OpenAI clients ignore # unknown fields). @@ -89,33 +89,14 @@ async def list_models(request: Request) -> dict: backs ``amplifier-agent models list``. No drift between the two surfaces; if the CLI sees a model, /v1/models does too. - Falls back to a minimal single-model placeholder (the configured - ``--model-id``) when the lifespan probe couldn't load the list (no - credentials, provider module not installed, network error, etc.). - The CLI handles this case the same way -- empty list means "could - not enumerate" rather than "no models exist". + The lifespan guarantees ``available_models`` is non-empty at startup + (or exits 2). If ``available_models`` is somehow empty at request time + the route returns an empty list rather than synthesizing a placeholder. """ available = getattr(request.app.state, "available_models", None) or [] now = int(time.time()) - if available: - return { - "object": "list", - "data": [_to_openai_entry(m, now=now) for m in available], - } - - # Fallback: advertise just the configured model_id when list_models - # failed at lifespan. 
Preserves the minimum-viable shape so clients - # that hit /v1/models for smoke tests don't get an empty list. - config = request.app.state.config return { "object": "list", - "data": [ - { - "id": config.model_id, - "object": "model", - "created": now, - "owned_by": "amplifier-agent", - } - ], + "data": [_to_openai_entry(m, now=now) for m in available], } diff --git a/src/amplifier_agent_lib/config/loader.py b/src/amplifier_agent_lib/config/loader.py index c453e1d3..d12c412f 100644 --- a/src/amplifier_agent_lib/config/loader.py +++ b/src/amplifier_agent_lib/config/loader.py @@ -28,7 +28,7 @@ __all__ = ["VALID_APPROVAL_MODES", "ConfigError", "load_config"] -_VALID_TOP_LEVEL_KEYS = frozenset({"mcp", "approval", "provider", "allowProtocolSkew", "skills"}) +_VALID_TOP_LEVEL_KEYS = frozenset({"mcp", "approval", "provider", "providers", "allowProtocolSkew", "skills"}) _VALID_PROVIDER_MODULES = frozenset({"anthropic", "openai", "azure-openai", "ollama"}) # G3: explicit set of host-supplied approval modes. ``CliApprovalSystem`` accepts # exactly these three strings; any other value must be rejected at parse time @@ -202,6 +202,73 @@ def _validate_provider_module(provider_block: Any, path: Path) -> None: ) +def _validate_providers_registry(providers_block: Any, path: Path) -> None: + """Validate the ``providers`` (plural) registry for server mode. + + Schema:: + + providers: dict[str, dict] + : + module: str (optional; defaults to ) + config: dict (optional; defaults to {}) + + Empty dict is allowed at validation time; the HTTP lifespan rejects it + separately at boot (so single-turn mode never trips on an empty + ``providers`` block someone left in their host_config.json). 
+ """ + if not isinstance(providers_block, dict): + raise ConfigError( + code="config_invalid_type", + message=(f"`providers` at {path} must be a JSON object, got {type(providers_block).__name__}."), + classification="protocol", + ) + for provider_id, entry in providers_block.items(): + if not isinstance(provider_id, str) or not provider_id: + raise ConfigError( + code="config_invalid_type", + message=(f"`providers` at {path}: keys must be non-empty strings; got {provider_id!r}."), + classification="protocol", + ) + if not isinstance(entry, dict): + raise ConfigError( + code="config_invalid_type", + message=(f"`providers.{provider_id}` at {path} must be a JSON object, got {type(entry).__name__}."), + classification="protocol", + ) + # `module` is optional; if present it must be a known provider module string. + module = entry.get("module", provider_id) + if not isinstance(module, str) or not module: + raise ConfigError( + code="config_invalid_type", + message=(f"`providers.{provider_id}.module` at {path} must be a non-empty string."), + classification="protocol", + ) + if module not in _VALID_PROVIDER_MODULES: + raise ConfigError( + code="config_invalid_provider_module", + message=( + f"`providers.{provider_id}.module` at {path} is {module!r}; " + f"must be one of: {sorted(_VALID_PROVIDER_MODULES)}." + ), + classification="protocol", + ) + # `config` is optional; if present it must be a dict. + if "config" in entry and not isinstance(entry["config"], dict): + raise ConfigError( + code="config_invalid_type", + message=(f"`providers.{provider_id}.config` at {path} must be a JSON object."), + classification="protocol", + ) + # Reject unknown keys inside the entry (closed schema). 
+ unknown = set(entry.keys()) - {"module", "config"} + if unknown: + raise ConfigError( + code="config_unknown_key", + message=(f"`providers.{provider_id}` at {path} has unknown keys: {sorted(unknown)}."), + classification="protocol", + ) + + class ConfigError(AaaError): """Recoverable configuration error raised by loader/merger. @@ -298,4 +365,6 @@ def load_config(config_arg: str | None) -> dict[str, Any] | None: _validate_approval_mode(parsed.get("approval"), path) _validate_provider_module(parsed.get("provider"), path) _validate_skills_block(parsed.get("skills"), path) + if "providers" in parsed: + _validate_providers_registry(parsed["providers"], path) return parsed diff --git a/src/amplifier_agent_lib/config/merger.py b/src/amplifier_agent_lib/config/merger.py index 7da87014..540bb03c 100644 --- a/src/amplifier_agent_lib/config/merger.py +++ b/src/amplifier_agent_lib/config/merger.py @@ -202,6 +202,11 @@ def merge_config( if isinstance(skills_block, dict): _merge_skills(merged, skills_block) + # `providers` (plural) is server-mode only and read directly by the HTTP + # lifespan from host_config; no overlay into bundle_modules here. + # The loader's _validate_providers_registry enforces the schema; the merger + # intentionally ignores this key to keep single-turn mode unaffected. + # D4: ``allowProtocolSkew`` is engine-level, not a module pass-through. # Surface it as a separate return field so the engine boot path can read # it without re-parsing host_config. Defaults to False when absent or diff --git a/tests/cli/test_serve_lifecycle.py b/tests/cli/test_serve_lifecycle.py new file mode 100644 index 00000000..96168be6 --- /dev/null +++ b/tests/cli/test_serve_lifecycle.py @@ -0,0 +1,489 @@ +"""Tests for ``amplifier_agent_cli.admin.serve_lifecycle``. + +Covers state-file IO helpers (write/read/remove), process-liveness helpers, +and the three CLI subcommands (status, stop, restart). 
+ +Isolation strategy: every test that touches the state file uses +``AMPLIFIER_AGENT_HOME`` env-var redirection so we never pollute the real +``~/.amplifier-agent/`` directory. Tests that exercise PID logic spawn +throwaway subprocesses via ``subprocess.Popen("python -c ...")`` and clean +up after themselves. +""" + +from __future__ import annotations + +import json +import os +import stat +import subprocess +import sys +import time +from pathlib import Path +from unittest.mock import MagicMock, patch + +import pytest +from click.testing import CliRunner + +from amplifier_agent_cli.admin.serve_lifecycle import ( + SCHEMA_VERSION, + _state_dir, + _state_file, + is_pid_alive, + read_state_file, + remove_state_file, + restart_command, + status_command, + stop_command, + write_state_file, +) + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture() +def isolated_home(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> Path: + """Redirect AMPLIFIER_AGENT_HOME to a temp dir for the duration of each test.""" + monkeypatch.setenv("AMPLIFIER_AGENT_HOME", str(tmp_path)) + return tmp_path + + +@pytest.fixture() +def sample_state(isolated_home: Path) -> dict: + """A minimal valid state payload (no file on disk; just the dict).""" + return { + "pid": os.getpid(), + "started_at": "2026-06-21T19:30:00Z", + "host": "127.0.0.1", + "port": 9099, + "api_key": "local-dev-secret", + "workspace": "test-workspace", + "host_config_path": None, + "providers_summary": {"anthropic": 3, "openai": 6}, + } + + +# --------------------------------------------------------------------------- +# write_state_file +# --------------------------------------------------------------------------- + + +def test_write_state_file_creates_with_correct_modes(isolated_home: Path, sample_state: dict) -> None: + """write_state_file creates serve.json at mode 0600; parent dir at 
mode 0700.""" + write_state_file(sample_state) + + sf = _state_file() + assert sf.exists(), "State file was not created" + + file_mode = stat.S_IMODE(sf.stat().st_mode) + assert file_mode == 0o600, f"Expected mode 0600 on file, got {oct(file_mode)}" + + dir_mode = stat.S_IMODE(_state_dir().stat().st_mode) + assert dir_mode == 0o700, f"Expected mode 0700 on directory, got {oct(dir_mode)}" + + +def test_write_state_file_content_round_trips(isolated_home: Path, sample_state: dict) -> None: + """Written payload round-trips through read_state_file (with schema_version injected).""" + write_state_file(sample_state) + result = read_state_file() + assert result is not None + assert result["schema_version"] == SCHEMA_VERSION + assert result["pid"] == sample_state["pid"] + assert result["host"] == "127.0.0.1" + assert result["port"] == 9099 + assert result["workspace"] == "test-workspace" + # api_key must round-trip but never appear in error output + assert result["api_key"] == "local-dev-secret" + assert result["providers_summary"] == {"anthropic": 3, "openai": 6} + + +def test_write_state_file_is_atomic(isolated_home: Path, sample_state: dict) -> None: + """State file does not appear until os.replace completes. + + We simulate an interrupted write by patching os.replace to a no-op AFTER + the tempfile is written, then verify the state file does not exist. + Subsequently we let a real write complete and verify the file appears. + """ + sf = _state_file() + + # First: intercept os.replace so the tempfile is created but never renamed. + with patch("amplifier_agent_cli.admin.serve_lifecycle.os.replace"): + write_state_file(sample_state) + + # The state file must NOT exist after the intercepted (no-op) rename. + assert not sf.exists(), "State file appeared despite os.replace being no-op'd" + + # Real write succeeds. 
+ write_state_file(sample_state) + assert sf.exists() + + +# --------------------------------------------------------------------------- +# read_state_file +# --------------------------------------------------------------------------- + + +def test_read_state_file_missing_returns_none(isolated_home: Path) -> None: + """read_state_file returns None when the file does not exist.""" + assert read_state_file() is None + + +def test_read_state_file_unknown_schema_version_raises(isolated_home: Path) -> None: + """read_state_file raises ClickException for unknown schema_version.""" + import click + + _state_dir().mkdir(parents=True, exist_ok=True) + bad_payload = json.dumps({"schema_version": 999, "pid": 1}) + _state_file().write_text(bad_payload, encoding="utf-8") + + with pytest.raises(click.ClickException) as exc_info: + read_state_file() + + assert "schema_version" in str(exc_info.value.message) + + +# --------------------------------------------------------------------------- +# remove_state_file +# --------------------------------------------------------------------------- + + +def test_remove_state_file_idempotent(isolated_home: Path, sample_state: dict) -> None: + """remove_state_file is idempotent — calling it twice raises no exception.""" + write_state_file(sample_state) + assert _state_file().exists() + + remove_state_file() + assert not _state_file().exists() + + # Second call must not raise. 
+ remove_state_file() + + +# --------------------------------------------------------------------------- +# is_pid_alive +# --------------------------------------------------------------------------- + + +def test_is_pid_alive_running_process(isolated_home: Path) -> None: + """is_pid_alive returns True for a live subprocess.""" + proc = subprocess.Popen( + [sys.executable, "-c", "import time; time.sleep(60)"], + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + ) + try: + assert is_pid_alive(proc.pid) + finally: + proc.kill() + proc.wait() + + +def test_is_pid_alive_dead_process(isolated_home: Path) -> None: + """is_pid_alive returns False after the process exits.""" + proc = subprocess.Popen( + [sys.executable, "-c", "pass"], + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + ) + proc.wait(timeout=5) + assert not is_pid_alive(proc.pid) + + +def test_is_pid_alive_init(isolated_home: Path) -> None: + """PID 1 (init / launchd) is always alive.""" + assert is_pid_alive(1) + + +# --------------------------------------------------------------------------- +# ``serve status`` +# --------------------------------------------------------------------------- + + +def test_status_no_state_file_reports_not_running(isolated_home: Path) -> None: + """``serve status`` exits 0 and reports 'not running' when no state file.""" + runner = CliRunner() + result = runner.invoke(status_command, []) + assert result.exit_code == 0, result.output + assert "not running" in result.output + + +def test_status_stale_pid_self_cleans(isolated_home: Path) -> None: + """``serve status`` removes a state file whose PID is gone and exits 0.""" + # Write a state file with a PID that definitely does not exist. 
+ dead_proc = subprocess.Popen( + [sys.executable, "-c", "pass"], + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + ) + dead_pid = dead_proc.pid + dead_proc.wait(timeout=5) + + write_state_file( + { + "pid": dead_pid, + "started_at": "2026-06-21T19:30:00Z", + "host": "127.0.0.1", + "port": 9099, + "api_key": "local-dev-secret", + "workspace": "test", + "host_config_path": None, + "providers_summary": {}, + } + ) + assert _state_file().exists() + + runner = CliRunner() + result = runner.invoke(status_command, []) + assert result.exit_code == 0, result.output + assert "stale" in result.output + assert not _state_file().exists() + + +def test_status_running_and_reachable(isolated_home: Path, sample_state: dict) -> None: + """``serve status`` reports 'running at http://...' with provider counts.""" + write_state_file(sample_state) + + models_response = MagicMock() + models_response.raise_for_status = MagicMock() + + with ( + patch( + "amplifier_agent_cli.admin.serve_lifecycle.is_pid_alive", + return_value=True, + ), + patch( + "amplifier_agent_cli.admin.serve_lifecycle.httpx.get", + return_value=models_response, + ), + ): + runner = CliRunner() + result = runner.invoke(status_command, []) + + assert result.exit_code == 0, result.output + assert "running at http://127.0.0.1:9099/v1/" in result.output + assert "workspace=test-workspace" in result.output + # Total model count from sample_state: 3 + 6 = 9 + assert "9 total" in result.output + assert "anthropic: 3" in result.output + assert "openai: 6" in result.output + + +# --------------------------------------------------------------------------- +# ``serve stop`` +# --------------------------------------------------------------------------- + + +def test_stop_no_state_file_exits_1(isolated_home: Path) -> None: + """``serve stop`` exits 1 when there is no state file.""" + runner = CliRunner() + result = runner.invoke(stop_command, []) + assert result.exit_code == 1 + + +def 
test_stop_graceful_sends_sigterm(isolated_home: Path) -> None: + """``serve stop`` sends SIGTERM and exits 0 when the process exits cleanly.""" + proc = subprocess.Popen( + [sys.executable, "-c", "import time; time.sleep(60)"], + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + ) + write_state_file( + { + "pid": proc.pid, + "started_at": "2026-06-21T19:30:00Z", + "host": "127.0.0.1", + "port": 9099, + "api_key": "local-dev-secret", + "workspace": "test", + "host_config_path": None, + "providers_summary": {}, + } + ) + + try: + runner = CliRunner() + result = runner.invoke(stop_command, ["--timeout", "5"]) + assert result.exit_code == 0, result.output + assert "stopped" in result.output + # Reap the zombie so the PID is fully released before checking liveness. + proc.wait(timeout=5) + assert proc.returncode is not None, "Process should have exited" + assert not _state_file().exists() + finally: + # Clean up if stop didn't kill it. + if proc.poll() is None: + proc.kill() + proc.wait() + + +def test_stop_force_sends_sigkill(isolated_home: Path) -> None: + """``serve stop --force`` sends SIGKILL immediately.""" + proc = subprocess.Popen( + [sys.executable, "-c", "import time; time.sleep(60)"], + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + ) + write_state_file( + { + "pid": proc.pid, + "started_at": "2026-06-21T19:30:00Z", + "host": "127.0.0.1", + "port": 9099, + "api_key": "local-dev-secret", + "workspace": "test", + "host_config_path": None, + "providers_summary": {}, + } + ) + + try: + runner = CliRunner() + result = runner.invoke(stop_command, ["--force"]) + assert result.exit_code == 0, result.output + assert "SIGKILL" in result.output + # Reap the zombie so we can confirm the exit. 
+ proc.wait(timeout=5) + assert proc.returncode is not None, "Process should have exited" + assert not _state_file().exists() + finally: + if proc.poll() is None: + proc.kill() + proc.wait() + + +def test_stop_timeout_escalates(isolated_home: Path) -> None: + """``serve stop`` escalates to SIGKILL when the graceful window expires.""" + # A process that ignores SIGTERM. + proc = subprocess.Popen( + [ + sys.executable, + "-c", + ("import signal, time; signal.signal(signal.SIGTERM, lambda s, f: None); time.sleep(60)"), + ], + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + ) + # Give the process time to register the SIGTERM handler. + time.sleep(0.3) + + write_state_file( + { + "pid": proc.pid, + "started_at": "2026-06-21T19:30:00Z", + "host": "127.0.0.1", + "port": 9099, + "api_key": "local-dev-secret", + "workspace": "test", + "host_config_path": None, + "providers_summary": {}, + } + ) + + try: + runner = CliRunner() + result = runner.invoke(stop_command, ["--timeout", "0.5"]) + assert result.exit_code == 0, result.output + assert "SIGKILL" in result.output + # Reap the zombie to confirm the exit. 
+ proc.wait(timeout=5) + assert proc.returncode is not None, "Process should have exited" + assert not _state_file().exists() + finally: + if proc.poll() is None: + proc.kill() + proc.wait() + + +# --------------------------------------------------------------------------- +# ``serve restart`` +# --------------------------------------------------------------------------- + + +def test_restart_no_state_file_exits_1(isolated_home: Path) -> None: + """``serve restart`` exits 1 when there is no state file.""" + runner = CliRunner() + result = runner.invoke(restart_command, []) + assert result.exit_code == 1 + assert "nothing to restart" in result.output + + +def test_restart_invokes_stop_then_start(isolated_home: Path) -> None: + """``serve restart`` assembles the correct command from state file args.""" + # Spawn a dummy long-lived process to act as the "old server". + old_proc = subprocess.Popen( + [sys.executable, "-c", "import time; time.sleep(60)"], + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + ) + old_pid = old_proc.pid + + write_state_file( + { + "pid": old_pid, + "started_at": "2026-06-21T19:30:00Z", + "host": "127.0.0.1", + "port": 9099, + "api_key": "local-dev-secret", + "workspace": "my-workspace", + "host_config_path": "/tmp/cfg.json", + "providers_summary": {"anthropic": 1}, + } + ) + + captured_cmd: list[list[str]] = [] + + # Immediately write a fresh state file so restart_command thinks the + # new server came up. + def _fake_popen(cmd: list[str], **kwargs: object) -> MagicMock: + captured_cmd.append(cmd) + # Write the new state file so the restart polling loop exits quickly. 
+ write_state_file( + { + "pid": 99999, # different from old_pid + "started_at": "2026-06-21T20:00:00Z", + "host": "127.0.0.1", + "port": 9099, + "api_key": "local-dev-secret", + "workspace": "my-workspace", + "host_config_path": "/tmp/cfg.json", + "providers_summary": {"anthropic": 1}, + } + ) + m = MagicMock() + m.pid = 99999 + return m + + try: + with patch( + "amplifier_agent_cli.admin.serve_lifecycle.subprocess.Popen", + side_effect=_fake_popen, + ): + runner = CliRunner() + result = runner.invoke(restart_command, []) + + assert result.exit_code == 0, result.output + assert "restarted" in result.output + assert len(captured_cmd) == 1 + cmd = captured_cmd[0] + + # Must include the original launch args. + assert "serve" in cmd + assert "chat-completions" in cmd + assert "--bind" in cmd + assert "127.0.0.1" in cmd + assert "--port" in cmd + assert "9099" in cmd + assert "--workspace" in cmd + assert "my-workspace" in cmd + assert "--config" in cmd + assert "/tmp/cfg.json" in cmd + + # api_key must NOT appear in any readable error output or logs — + # but it IS in the cmd (passed as a flag to the sub-process, same as + # the original invocation). We just verify the key was forwarded. + assert "--api-key" in cmd + finally: + if is_pid_alive(old_pid): + old_proc.kill() + old_proc.wait() diff --git a/tests/http/__init__.py b/tests/http/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/http/test_chat_completions_validation.py b/tests/http/test_chat_completions_validation.py new file mode 100644 index 00000000..c3aa2681 --- /dev/null +++ b/tests/http/test_chat_completions_validation.py @@ -0,0 +1,234 @@ +"""Tests for chat-completions validation: model registry, streaming flag, upstream errors. + +Covers: +- Unknown model → HTTP 400 with structured error body. +- ``stream: false`` → single JSON body (chat.completion shape). +- ``stream: true`` → SSE (text/event-stream). 
def _make_test_app(*, registry: dict[str, str] | None = None) -> FastAPI:
    """Create a FastAPI app for chat-completions tests without the real lifespan.

    The production lifespan is replaced by a stub that only seeds
    ``app.state``, so entering ``TestClient`` never loads bundles, enumerates
    providers, or calls ``sys.exit()``.

    Args:
        registry: Mapping stored as ``app.state.served_models_registry``.
            ``None`` (or an empty mapping) results in an empty registry.

    Returns:
        An app with the chat-completions and models routers mounted.
    """
    served = registry if registry else {}

    bundle = MagicMock()
    bundle.mount_plan = {}

    @asynccontextmanager
    async def _seed_state(application: FastAPI):
        # Populate every app.state attribute the routes read, then hand
        # control back; there is nothing to tear down afterwards.
        state = application.state
        state.config = MagicMock(model_id="amplifier", api_key="test-key")
        state.prepared = bundle
        state.agent_configs = {}
        state.resolved_workspace = None
        state.host_config = {}
        state.available_models = []
        state.served_models_registry = served
        yield

    test_app = FastAPI(lifespan=_seed_state)
    for route_module in (cc_module, models_module):
        test_app.include_router(route_module.router)
    return test_app
def test_stream_true_returns_sse() -> None:
    """A request with ``stream: true`` is answered as a ``text/event-stream`` response."""
    app = _make_test_app(registry={"claude-3-5-sonnet-20241022": "anthropic"})

    async def _emit_one_delta(**kwargs: Any) -> str:
        # ``emit`` is a coroutine function, so it has to be awaited.
        await kwargs["display"].emit({"type": "text:delta", "text": "hi"})
        return "done"

    run_turn_patch = patch(
        "amplifier_agent_http.routes.chat_completions.run_chat_turn",
        side_effect=_emit_one_delta,
    )
    with run_turn_patch:
        with TestClient(app, raise_server_exceptions=False) as client:
            resp = client.post(
                "/v1/chat/completions",
                json=_chat_payload(stream=True),
                headers=AUTH,
            )

    assert resp.status_code == 200
    content_type = resp.headers["content-type"]
    assert "text/event-stream" in content_type
    # An SSE body is made of "data:" lines.
    assert "data:" in resp.text
+ """ + import asyncio as _asyncio + + registry = {"claude-3-5-sonnet-20241022": "anthropic"} + app = _make_test_app(registry=registry) + + async def _run_with_late_error(**kwargs: Any) -> str: + display = kwargs["display"] + # Emit a real content event so the task is known-running at preflight time. + # Must await display.emit because it is an async method. + await display.emit({"type": "text:delta", "text": "partial response"}) + # Sleep longer than the preflight window (50 ms) so the pre-flight check + # sees a still-running task and allows the SSE stream to start. + await _asyncio.sleep(0.1) + raise RuntimeError("mid-stream provider error") + + with ( + patch( + "amplifier_agent_http.routes.chat_completions.run_chat_turn", + side_effect=_run_with_late_error, + ), + TestClient(app, raise_server_exceptions=False) as client, + ): + resp = client.post( + "/v1/chat/completions", + json=_chat_payload(stream=True), + headers=AUTH, + ) + + # Response is still 200 SSE (headers already committed before error occurs) + assert resp.status_code == 200 + assert "text/event-stream" in resp.headers["content-type"] + # The error should be embedded somewhere in the SSE body + sse_body = resp.text + assert "amplifier-agent error" in sse_body or "mid-stream provider error" in sse_body diff --git a/tests/http/test_lifespan_providers.py b/tests/http/test_lifespan_providers.py new file mode 100644 index 00000000..1939a103 --- /dev/null +++ b/tests/http/test_lifespan_providers.py @@ -0,0 +1,277 @@ +"""Tests for the explicit-providers lifespan boot semantics. + +Verifies that: +- Server exits 2 when ``host_config.providers`` is absent, empty, or wrong type. +- Server exits 2 when any declared provider fails (missing credentials, + missing module, list_models() raises, list_models() returns 0 models). +- Server boots successfully when all providers initialise correctly. +- Each provider's ``config`` block is forwarded as ``extra_config`` to + ``list_provider_models``. 
def _model_info(model_id: str, provider_id: str) -> MagicMock:
    """Return a stand-in model object whose ``model_dump()`` yields ``{"id": model_id}``.

    The result is a ``MagicMock`` (not a dict): code under test is expected to
    call ``model_dump()`` on it to obtain the plain mapping.

    Args:
        model_id: Identifier placed under the ``"id"`` key of the dumped dict.
        provider_id: Currently unused; kept so existing call sites that pass
            it keep working.

    Returns:
        A ``MagicMock`` with ``model_dump`` pre-configured.
    """
    info = MagicMock()
    info.model_dump.return_value = {"id": model_id}
    return info
@pytest.mark.asyncio
async def test_lifespan_exits_when_providers_block_empty_dict(base_mocks) -> None:
    """An empty ``providers`` mapping makes the lifespan exit with status 2."""
    host_config_loader = base_mocks["load_host_config"]
    host_config_loader.return_value = {"providers": {}}
    bare_app = FastAPI()

    with pytest.raises(SystemExit) as boot_failure:
        async with lifespan(bare_app):
            pass  # pragma: no cover

    exit_status = boot_failure.value.code
    assert exit_status == 2
@pytest.mark.asyncio
async def test_lifespan_exits_when_provider_returns_no_models(base_mocks) -> None:
    """A provider whose model listing is empty makes the lifespan exit with status 2."""
    host_config_loader = base_mocks["load_host_config"]
    host_config_loader.return_value = {"providers": {"anthropic": {}}}
    bare_app = FastAPI()

    # The patched lister reports zero models for every provider.
    empty_listing = patch("amplifier_agent_http.app.list_provider_models", return_value=[])
    with empty_listing:
        with pytest.raises(SystemExit) as boot_failure:
            async with lifespan(bare_app):
                pass  # pragma: no cover

    exit_status = boot_failure.value.code
    assert exit_status == 2
@pytest.mark.asyncio
async def test_lifespan_passes_extra_config_to_list_provider_models(base_mocks) -> None:
    """The ``config`` block of a provider reaches ``list_provider_models`` as ``extra_config``."""
    extra_cfg = {"base_url": "https://api.openai.com/v1", "filtered": False}
    host_config = {"providers": {"openai": {"config": extra_cfg}}}
    base_mocks["load_host_config"].return_value = host_config
    bare_app = FastAPI()

    listed_model = MagicMock()
    listed_model.model_dump.return_value = {"id": "gpt-4o"}

    # One record per invocation of the patched lister.
    seen: list[tuple] = []

    def _record(provider_id: str, timeout: float, extra_config: dict | None = None):
        seen.append((provider_id, timeout, extra_config))
        return [listed_model]

    lister_patch = patch("amplifier_agent_http.app.list_provider_models", side_effect=_record)
    with lister_patch:
        async with lifespan(bare_app):
            pass

    assert len(seen) == 1
    only_call = seen[0]
    assert only_call[0] == "openai"
    assert only_call[2] == extra_cfg