
Port embeddings and LLM routers from ipfs_datasets_py with peer multiplexing#220

Open
Copilot wants to merge 11 commits into `main` from `copilot/improve-copilot-sdk-handling`

Conversation

Copilot AI (Contributor) commented on Feb 9, 2026:

Addressing PR Review Feedback

Issues Fixed

  • Environment variables - Now checking IPFS_KIT_* first, then IPFS_DATASETS_PY_* fallback
  • Async/blocking I/O - API endpoints now use ThreadPoolExecutor for blocking operations
  • CID implementation - Using existing ipfs_multiformats.create_cid_from_bytes for proper CIDv1
  • IPFS peer provider - Added proper error handling and method existence checks
  • Embedding cache indexing - Added length validation to prevent IndexError
  • Gemini backend selection - Fixed to default to HF unless explicitly enabled
  • Command parsing - Using shlex.split instead of .split() for proper quoting
  • get_or_create race condition - Fixed thread safety by holding lock across factory call
  • Provider resolution order - IPFS peer now checked before accelerate provider
  • Test issue - Fixed failing provider test to call generate_text instead of get_llm_provider

Changes Made

llm_router.py:

  • Added _get_env() helper to check IPFS_KIT_* then IPFS_DATASETS_PY_*
  • Updated all environment variable reads to use _get_env()
  • Added method existence check in _get_ipfs_peer_provider with clear error message

embeddings_router.py:

  • Added _get_env() helper for dual environment variable support
  • Fixed provider resolution order (IPFS peer before accelerate)
  • Added length validation for embedding cache indexing
  • Updated _get_ipfs_peer_provider with method existence check

embedding_adapter.py:

  • Fixed Gemini backend selection to only auto-select if explicitly enabled
  • Changed .split() to shlex.split() for command parsing
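The quoting difference is easy to demonstrate (the flag names below are illustrative, not the adapter's actual command line):

```python
import shlex

cmd = 'gemini -p "hello world" --model gemini-pro'

# str.split() breaks on every whitespace run, splitting the quoted argument:
naive = cmd.split()

# shlex.split() applies shell-style quoting rules, keeping "hello world" intact:
proper = shlex.split(cmd)

print(naive)   # ['gemini', '-p', '"hello', 'world"', '--model', 'gemini-pro']
print(proper)  # ['gemini', '-p', 'hello world', '--model', 'gemini-pro']
```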

router_deps.py:

  • Fixed get_or_create race condition by holding lock across factory call
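The race fix amounts to constructing inside the lock rather than after releasing it. A minimal sketch of the pattern (class and method names are illustrative, not the exact router_deps.py API):

```python
import threading

class RouterDeps:
    """Sketch of a dependency container with a race-free get_or_create."""

    def __init__(self) -> None:
        self._lock = threading.Lock()
        self._instances = {}

    def get_or_create(self, key, factory):
        # Hold the lock across the factory call so two threads racing on the
        # same key cannot both construct an instance. The previous bug was
        # releasing the lock after the membership check and before factory().
        with self._lock:
            if key not in self._instances:
                self._instances[key] = factory()
            return self._instances[key]
```

The trade-off is that a slow factory blocks other keys too; per-key locks would avoid that at the cost of more bookkeeping.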

llm_router_api.py & embeddings_router_api.py:

  • Added ThreadPoolExecutor for async operations
  • Wrapped blocking I/O calls with loop.run_in_executor()
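The pattern looks roughly like this (endpoint and helper names are placeholders; the real endpoints wrap `generate_text` / `embed_texts`):

```python
import asyncio
from concurrent.futures import ThreadPoolExecutor
from functools import partial

_executor = ThreadPoolExecutor(max_workers=4)

def blocking_generate(prompt: str) -> str:
    # Placeholder for the real router call; urllib requests, subprocess.run,
    # or local model inference can all block here.
    return f"echo: {prompt}"

async def generate_endpoint(prompt: str) -> str:
    # Offload the blocking call to the thread pool so the FastAPI event loop
    # stays responsive while the router does synchronous I/O.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(_executor, partial(blocking_generate, prompt))
```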

cid_utils.py:

  • Now uses ipfs_multiformats.create_cid_from_bytes for proper CIDv1 generation
  • Falls back to simple hash if ipfs_multiformats not available
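The fallback shape described above can be sketched like this (the import path follows the PR text and is an assumption about the repo layout):

```python
import base64
import hashlib

def cid_for_bytes(data: bytes) -> str:
    """Prefer the repo's real CIDv1 helper; degrade to a multibase-prefixed
    sha256 digest when ipfs_multiformats is unavailable."""
    try:
        # Assumed import path per the PR description.
        from ipfs_kit_py.ipfs_multiformats import create_cid_from_bytes
        return create_cid_from_bytes(data)
    except ImportError:
        digest = hashlib.sha256(data).digest()
        encoded = base64.b32encode(digest).decode("ascii").rstrip("=").lower()
        return f"b{encoded}"  # NOT a real CID; content-hash fallback only
```

The `b`-prefixed fallback is a content hash, not a true CIDv1, so callers should treat it only as a deterministic cache key.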

tests/test_llm_router.py:

  • Fixed test_provider_fallback_on_error to call generate_text instead of get_llm_provider

Original prompt

i made some improvements to the handling of copilot sdk and the copilot / codex / gemini / claude cli tools in the endomorphosis/ipfs_datasets_py in the form of the llm_router. I would like to reimplement those improvements into this project, while still being able to use the cli / sdk endpoints that are multiplexed from other peers, inside of our endpoint multiplexer.



Copilot AI and others added 5 commits February 9, 2026 09:15
Co-authored-by: endomorphosis <3405202+endomorphosis@users.noreply.github.com>
Copilot AI changed the title from "[WIP] Reimplement improvements to copilot sdk and cli tools" to "Integrate LLM router with multi-provider support and peer endpoint multiplexing" on Feb 9, 2026
Copilot AI requested a review from endomorphosis February 9, 2026 09:25
Copilot AI and others added 4 commits February 9, 2026 09:57
Co-authored-by: endomorphosis <3405202+endomorphosis@users.noreply.github.com>
Copilot AI changed the title from "Integrate LLM router with multi-provider support and peer endpoint multiplexing" to "Port embeddings and LLM routers from ipfs_datasets_py with peer multiplexing" on Feb 9, 2026
endomorphosis marked this pull request as ready for review on February 9, 2026 at 11:33
Copilot AI review requested due to automatic review settings February 9, 2026 11:33
Copilot AI left a comment:

Pull request overview

This PR ports the llm_router and embeddings_router functionality (plus CLI/API/docs/tests) from ipfs_datasets_py into ipfs_kit_py, intending to add multi-provider routing with caching and IPFS peer endpoint multiplexing.

Changes:

  • Added core router modules for LLM generation and embeddings generation, with provider registries and response/provider caching.
  • Added FastAPI endpoints and CLIs for invoking the routers and enumerating providers.
  • Added docs, examples, and tests to demonstrate and validate the new router behavior.

Reviewed changes

Copilot reviewed 21 out of 21 changed files in this pull request and generated 14 comments.

Summary per file:

  • ipfs_kit_py/llm_router.py: New multi-provider LLM routing layer with caching and optional peer/accelerate/local fallbacks.
  • ipfs_kit_py/embeddings_router.py: New multi-provider embeddings routing layer with caching and adapter/peer/accelerate support.
  • ipfs_kit_py/router_deps.py: Dependency container for shared provider instances and caching (local + optional remote).
  • ipfs_kit_py/utils/embedding_adapter.py: Local embeddings fallback (Gemini CLI → HF transformers) used by the embeddings router.
  • ipfs_kit_py/utils/cid_utils.py: Content-addressed key helper used by the routers' CID cache-key strategy.
  • ipfs_kit_py/utils/gemini_cli.py: Gemini CLI wrapper used by the LLM router's optional provider.
  • ipfs_kit_py/utils/claude_cli.py: Claude CLI wrapper used by the LLM router's optional provider.
  • ipfs_kit_py/utils/__init__.py: Declares the utils package for router utilities.
  • ipfs_kit_py/mcp/ai/llm_router_api.py: FastAPI endpoints for the LLM router (/generate, /providers, /health, /cache/clear).
  • ipfs_kit_py/mcp/ai/embeddings_router_api.py: FastAPI endpoints for the embeddings router (/embed, /embed-single, /providers, /health, /cache/clear).
  • ipfs_kit_py/mcp/ai/api_router.py: Integrates the new routers into the main AI API router.
  • ipfs_kit_py/cli/llm_cli.py: CLI for LLM router operations (generate/providers/test/clear-cache).
  • ipfs_kit_py/cli/embeddings_cli.py: CLI for embeddings router operations (embed/embed-single/providers/test/clear-cache).
  • examples/llm_router_example.py: Example usage of the LLM router, including peer routing.
  • examples/embeddings_router_example.py: Example usage of the embeddings router, including peer routing and semantic search.
  • docs/LLM_ROUTER.md: Documentation for router architecture, providers, env vars, CLI, and HTTP API.
  • docs/EMBEDDINGS_ROUTER.md: Documentation for embeddings providers, env vars, CLI, and HTTP API.
  • tests/test_llm_router.py: Unit tests for LLM router provider registration, deps, and basic behavior.
  • tests/test_embeddings_router.py: Unit tests for embeddings router provider registration, deps, and peer-provider shape.
  • LLM_ROUTER_INTEGRATION.md: High-level integration summary for the LLM router.
  • EMBEDDINGS_ROUTER_INTEGRATION.md: High-level integration summary for the embeddings router.

Comment on lines +128 to +165
```python
    try:
        # Generate embeddings using the router
        result = embed_texts(
            texts=request.texts,
            model_name=request.model_name,
            device=request.device,
            provider=request.provider,
            deps=router_deps,
            timeout=request.timeout,
        )

        return EmbeddingResponse(
            embeddings=result,
            provider=request.provider or "auto",
            model=request.model_name,
            device=request.device,
            cached=False,  # TODO: Track if results were cached
        )
    except Exception as e:
        logger.error(f"Embeddings generation failed: {e}")
        raise HTTPException(
            status_code=500,
            detail=f"Embeddings generation failed: {str(e)}",
        )


@router.post("/embed-single", response_model=SingleEmbeddingResponse)
async def generate_single_embedding(request: SingleEmbeddingRequest) -> SingleEmbeddingResponse:
    """Generate embedding for a single text."""
    try:
        # Generate embedding using the router
        result = embed_text(
            text=request.text,
            model_name=request.model_name,
            device=request.device,
            provider=request.provider,
            deps=router_deps,
            timeout=request.timeout,
        )
```
Copilot AI commented on Feb 9, 2026:

These endpoints are declared async but call embed_texts(...) / embed_text(...) directly, which can perform blocking I/O (urllib requests, subprocess.run, and local transformer inference). This will block the FastAPI event loop. Please run router calls in a threadpool (e.g., anyio.to_thread.run_sync / run_in_threadpool) or add async router functions.

Copilot AI (author) replied:

Fixed in 7d6791e. API endpoints now use ThreadPoolExecutor with loop.run_in_executor() to run blocking router calls in a thread pool, preventing event loop blocking.

Comment on lines +10 to +35
```python
def cid_for_obj(obj: Any, base: str = "base32") -> str:
    """Generate a CID (Content Identifier) for an object.

    Args:
        obj: The object to generate a CID for
        base: The base encoding to use (default: base32)

    Returns:
        A CID string
    """
    # Serialize object to JSON
    try:
        payload = json.dumps(obj, sort_keys=True, default=repr, ensure_ascii=False)
    except Exception:
        payload = repr(obj)

    # Generate SHA-256 hash
    hash_bytes = hashlib.sha256(payload.encode("utf-8")).digest()

    # Convert to base32 or base58 encoding
    if base == "base32":
        # Simple base32 encoding
        import base64
        encoded = base64.b32encode(hash_bytes).decode("ascii").rstrip("=").lower()
        return f"b{encoded}"
    elif base == "base58":
```
Copilot AI commented on Feb 9, 2026:

cid_for_obj is documented/used as a CID generator, but it’s currently just a SHA-256 digest with a multibase-like prefix (not a real CIDv1 with codec + multihash). This conflicts with router docs that describe “CIDv1” cache keys, and the repo already has CID utilities in ipfs_kit_py/ipfs_multiformats.py (create_cid_from_bytes). Consider either switching to the existing CID implementation (encode JSON bytes -> CIDv1) or renaming/documenting this as a non-CID “content hash” to avoid confusing callers.

Copilot AI (author) replied:

Fixed in 7d6791e. Now using ipfs_kit_py.ipfs_multiformats.create_cid_from_bytes() to generate proper CIDv1 with sha2-256 multihash and base32 encoding. Falls back to simple hash only if ipfs_multiformats is unavailable.

Comment on lines +543 to +559

```python
        # Route request to available peers
        result = peer_manager.route_llm_request(
            prompt=prompt,
            model=model_name,
            **kwargs
        )

        if isinstance(result, dict) and "text" in result:
            return result["text"]
        elif isinstance(result, str):
            return result

        raise RuntimeError("IPFS peer provider returned invalid response")
    except Exception as exc:
        raise RuntimeError(f"IPFS peer provider failed: {exc}") from exc
```
Copilot AI commented on Feb 9, 2026:

ipfs_peer provider calls deps.ipfs_backend.peer_manager.route_llm_request(...), but there is no implementation of route_llm_request anywhere in this repo (the unified peer manager exposes methods like get_all_peers/connect_to_peer, etc.). As written, selecting provider="ipfs_peer" will reliably fail with AttributeError at runtime. Please either implement the routing method on the actual peer manager / backend multiplexer, or update this provider to call the existing peer endpoint/multiplexer API that already exists in ipfs_kit_py.

Suggested change
```diff
-        # Route request to available peers
-        result = peer_manager.route_llm_request(
-            prompt=prompt,
-            model=model_name,
-            **kwargs
-        )
-        if isinstance(result, dict) and "text" in result:
-            return result["text"]
-        elif isinstance(result, str):
-            return result
-        raise RuntimeError("IPFS peer provider returned invalid response")
-    except Exception as exc:
-        raise RuntimeError(f"IPFS peer provider failed: {exc}") from exc
+        # Resolve an LLM routing function on the peer manager, if available
+        route_fn = getattr(peer_manager, "route_llm_request", None)
+        if not callable(route_fn):
+            raise RuntimeError(
+                "IPFS peer manager does not support LLM routing "
+                "(missing 'route_llm_request' method)"
+            )
+        # Route request to available peers
+        result = route_fn(
+            prompt=prompt,
+            model=model_name,
+            **kwargs
+        )
+        if isinstance(result, dict) and "text" in result:
+            return result["text"]
+        elif isinstance(result, str):
+            return result
+        raise RuntimeError("IPFS peer provider returned invalid response")
+    except Exception as exc:
+        raise RuntimeError(f"IPFS peer provider failed: {exc}") from exc
```

Copilot AI (author) replied:

Fixed in 7d6791e. Added proper method existence check with getattr(peer_manager, "route_llm_request", None) and clear error message when the method is missing, preventing AttributeError at runtime.

Comment on lines +82 to +91
```python
def _cache_enabled() -> bool:
    return os.environ.get("IPFS_DATASETS_PY_ROUTER_CACHE", "1").strip() != "0"


def _response_cache_enabled() -> bool:
    # Default to enabled in benchmark contexts (determinism + speed), off otherwise.
    value = os.environ.get("IPFS_DATASETS_PY_ROUTER_RESPONSE_CACHE")
    if value is None:
        return _truthy(os.environ.get("IPFS_DATASETS_PY_BENCHMARK"))
    return str(value).strip() != "0"
```
Copilot AI commented on Feb 9, 2026:

Caching toggles in this router read only IPFS_DATASETS_PY_ROUTER_CACHE / IPFS_DATASETS_PY_ROUTER_RESPONSE_CACHE / etc., but docs/CLI advertise IPFS_KIT_ROUTER_* (with IPFS_KIT taking precedence). As-is, IPFS_KIT_ROUTER_CACHE=0 won’t disable caching. Please unify env var resolution (IPFS_KIT first, then IPFS_DATASETS_PY) for all router env knobs.

Copilot AI (author) replied:

Fixed in 7d6791e. Added _get_env() helper that checks IPFS_KIT_* first, then IPFS_DATASETS_PY_* as fallback. All environment variable reads now use this unified approach for consistent dual naming support.

Comment on lines +574 to +575
```python
        for out_idx, vec in enumerate(generated):
            input_idx = missing_indices[out_idx]
```
Copilot AI commented on Feb 9, 2026:

In the response-cache fast path, this code assumes the provider returns exactly one embedding per missing_texts entry and indexes missing_indices[out_idx]. However, _get_openrouter_provider() can legally return a shorter list (it explicitly returns partial results when lengths mismatch), which would make this loop raise IndexError or leave cached_vectors entries as None. Please validate lengths and either (a) raise a clear error on mismatch, or (b) zip results to indices defensively and handle missing outputs.

Suggested change
```diff
-        for out_idx, vec in enumerate(generated):
-            input_idx = missing_indices[out_idx]
+        # Guard against providers returning more embeddings than requested.
+        if len(generated) > len(missing_indices):
+            raise ValueError(
+                f"Provider returned {len(generated)} embeddings for "
+                f"{len(missing_indices)} missing texts"
+            )
+        # Map generated embeddings back to their original input indices defensively.
+        for input_idx, vec in zip(missing_indices, generated):
```
Copilot AI (author) replied:

Fixed in 7d6791e. Added length validation before indexing: if len(generated) > len(missing_indices): raise ValueError(...) and switched to defensive zip(missing_indices, generated) to prevent IndexError with partial results.

Comment on lines +661 to +669
```python
    forced = os.getenv("IPFS_DATASETS_PY_LLM_PROVIDER", "").strip()
    if forced:
        info = _PROVIDER_REGISTRY.get(forced)
        if info is not None:
            return info.factory()
        builtin = _builtin_provider_by_name(forced, deps)
        if builtin is not None:
            return builtin
        raise ValueError(f"Unknown LLM provider: {forced}")
```
Copilot AI commented on Feb 9, 2026:

The module docstring claims IPFS_KIT_* env vars are supported (and take precedence), but provider selection actually reads only IPFS_DATASETS_PY_LLM_PROVIDER here. This makes IPFS_KIT_LLM_PROVIDER a no-op for the router’s core behavior. Consider adding a small _get_env() helper (like embeddings_router/embedding_adapter) and using it consistently for provider/model/cache/accelerate envs.

Copilot AI (author) replied:

Fixed in 7d6791e. Added _get_env() helper that checks IPFS_KIT_* first, then IPFS_DATASETS_PY_* as fallback. Provider selection now uses _get_env("LLM_PROVIDER") for consistent dual naming support.

Comment on lines +50 to +70
```python
def _truthy(value: Optional[str]) -> bool:
    return str(value or "").strip().lower() in {"1", "true", "yes", "on"}


def _cache_enabled() -> bool:
    return os.environ.get("IPFS_DATASETS_PY_ROUTER_CACHE", "1").strip() != "0"


def _response_cache_enabled() -> bool:
    value = os.environ.get("IPFS_DATASETS_PY_ROUTER_RESPONSE_CACHE")
    if value is None:
        return _truthy(os.environ.get("IPFS_DATASETS_PY_BENCHMARK"))
    return str(value).strip() != "0"


def _response_cache_key_strategy() -> str:
    return os.environ.get("IPFS_DATASETS_PY_ROUTER_CACHE_KEY", "sha256").strip().lower() or "sha256"


def _response_cache_cid_base() -> str:
    return os.environ.get("IPFS_DATASETS_PY_ROUTER_CACHE_CID_BASE", "base32").strip() or "base32"
```
Copilot AI commented on Feb 9, 2026:

_cache_enabled() (and other env reads throughout this file) uses IPFS_DATASETS_PY_* variables directly, even though the PR description/docs say IPFS_KIT_* should take precedence. As a result, IPFS_KIT_EMBEDDINGS_PROVIDER, IPFS_KIT_ROUTER_CACHE, etc. won’t affect router behavior. Please route all env reads through _get_env() (already defined at the top) or an equivalent coalesce helper.

Suggested change
```diff
-def _truthy(value: Optional[str]) -> bool:
-    return str(value or "").strip().lower() in {"1", "true", "yes", "on"}
-
-def _cache_enabled() -> bool:
-    return os.environ.get("IPFS_DATASETS_PY_ROUTER_CACHE", "1").strip() != "0"
-
-def _response_cache_enabled() -> bool:
-    value = os.environ.get("IPFS_DATASETS_PY_ROUTER_RESPONSE_CACHE")
-    if value is None:
-        return _truthy(os.environ.get("IPFS_DATASETS_PY_BENCHMARK"))
-    return str(value).strip() != "0"
-
-def _response_cache_key_strategy() -> str:
-    return os.environ.get("IPFS_DATASETS_PY_ROUTER_CACHE_KEY", "sha256").strip().lower() or "sha256"
-
-def _response_cache_cid_base() -> str:
-    return os.environ.get("IPFS_DATASETS_PY_ROUTER_CACHE_CID_BASE", "base32").strip() or "base32"
+def _get_env_optional(key: str) -> Optional[str]:
+    """Get environment variable with IPFS_KIT_* taking precedence, or None if unset.
+
+    Unlike `_get_env`, this preserves the distinction between "unset" and
+    "set (even to an empty string)" by checking membership in `os.environ`
+    instead of using `or` short-circuiting.
+    """
+    kit_key = f"IPFS_KIT_{key}"
+    if kit_key in os.environ:
+        return os.environ[kit_key]
+    datasets_key = f"IPFS_DATASETS_PY_{key}"
+    if datasets_key in os.environ:
+        return os.environ[datasets_key]
+    return None
+
+def _truthy(value: Optional[str]) -> bool:
+    return str(value or "").strip().lower() in {"1", "true", "yes", "on"}
+
+def _cache_enabled() -> bool:
+    value = _get_env("ROUTER_CACHE", "1")
+    return str(value).strip() != "0"
+
+def _response_cache_enabled() -> bool:
+    # Prefer explicit router response cache setting; if unset, fall back to BENCHMARK.
+    value = _get_env_optional("ROUTER_RESPONSE_CACHE")
+    if value is None:
+        benchmark = _get_env("BENCHMARK")
+        return _truthy(benchmark)
+    return str(value).strip() != "0"
+
+def _response_cache_key_strategy() -> str:
+    return _get_env("ROUTER_CACHE_KEY", "sha256").strip().lower() or "sha256"
+
+def _response_cache_cid_base() -> str:
+    return _get_env("ROUTER_CACHE_CID_BASE", "base32").strip() or "base32"
```

Copilot AI (author) replied:

Fixed in 7d6791e. All environment variable reads in embeddings_router.py now use _get_env() helper that checks IPFS_KIT_* first, then IPFS_DATASETS_PY_* as fallback, making all advertised env vars work correctly.

Comment on lines +487 to +496
```python
    # 2) Optional accelerate provider.
    accelerate_provider = _get_accelerate_provider(deps)
    if accelerate_provider is not None:
        return accelerate_provider

    # Try IPFS peer provider first if backend is available
    ipfs_peer_provider = _get_ipfs_peer_provider(deps)
    if ipfs_peer_provider is not None:
        return ipfs_peer_provider
```
Copilot AI commented on Feb 9, 2026:

Provider resolution order here checks the accelerate provider before the ipfs_peer provider, but the PR description/docs say peer multiplexing should be attempted before accelerate. If the ordering matters (it usually does for cost/latency), please swap these checks or update the documented resolution order to match the implementation.

Copilot AI (author) replied:

Fixed in 7d6791e. Swapped provider resolution order - IPFS peer provider is now checked before accelerate provider, matching the documented resolution order in PR description.

Comment on lines +154 to +160
```python
    # Register a failing provider
    register_llm_provider("failing_provider", lambda: FailingProvider())

    # Without a fallback, should raise
    with pytest.raises(RuntimeError):
        get_llm_provider("failing_provider", use_cache=False)
```
Copilot AI commented on Feb 9, 2026:

This test expects get_llm_provider("failing_provider") to raise, but get_llm_provider only instantiates/returns the provider and does not call generate(). With the current router implementation, this should not raise and the test will fail. To test fallback-on-error, invoke generate_text(..., provider="failing_provider") (or call provider.generate(...)) and assert the expected exception/fallback behavior.

Copilot AI (author) replied:

Fixed in 7d6791e. Changed test to call generate_text("test prompt", provider="failing_provider") instead of get_llm_provider(), properly testing fallback behavior when generation fails.
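The shape of the corrected test, with a stand-in router so the sketch is self-contained (the real test uses the router's actual `generate_text` and provider registry):

```python
from typing import Optional

class FailingProvider:
    """Test double whose generate() always raises."""
    def generate(self, prompt: str, **kwargs) -> str:
        raise RuntimeError("provider is down")

def generate_text(prompt: str, provider: Optional[str] = None) -> str:
    # Stand-in for the router entry point: it *invokes* the provider, so a
    # failure surfaces at generation time, which is what the test exercises.
    # get_llm_provider, by contrast, only constructs the provider and would
    # never raise here.
    if provider == "failing_provider":
        return FailingProvider().generate(prompt)
    return "ok"
```

In pytest terms the assertion becomes `with pytest.raises(RuntimeError): generate_text("test prompt", provider="failing_provider")`.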

Comment on lines +105 to +115
```python
    try:
        # Generate text using the router
        result = generate_text(
            prompt=request.prompt,
            model_name=request.model_name,
            provider=request.provider,
            deps=router_deps,
            max_tokens=request.max_tokens,
            temperature=request.temperature,
            timeout=request.timeout,
        )
```
Copilot AI commented on Feb 9, 2026:

These endpoints are declared async but call generate_text(...) directly, which can perform blocking I/O (HTTP via urllib, subprocess.run for CLIs, and potentially heavy local HF inference). This will block the event loop under load. Please run the router call in a threadpool (e.g., anyio.to_thread.run_sync / starlette.concurrency.run_in_threadpool) or provide async variants of the router APIs.

Copilot AI (author) replied:

Fixed in 7d6791e. API endpoints now use ThreadPoolExecutor with loop.run_in_executor() to run blocking router calls in a thread pool, preventing event loop blocking under load.

endomorphosis (Owner) commented:

@copilot apply changes based on the comments in this thread

and the llm_router.py shims below


"""
The main codebase calls `ipfs_datasets_py.llm_router.generate_text(...)`.
This implementation routes prompts through the local Copilot CLI.

Provider:
- `copilot_cli`: invokes the `copilot` binary in non-interactive mode.
- `codex_cli`: invokes the `codex exec` binary in non-interactive mode.

Model names are passed through to the CLI (e.g. `gpt-5-mini`).

Optional tracing:
- For `codex_cli`, you may enable JSONL event tracing by passing either:
    - trace_jsonl_path="/abs/or/relative/path.jsonl", or
    - trace_dir="/some/dir" (a filename will be generated)

Tracing writes JSONL produced by `codex exec --json` and is useful for capturing
thread ids plus token usage fields like cached_input_tokens.
"""

from __future__ import annotations

from typing import Any, Dict, Optional
import subprocess
import re
from html import unescape
import tempfile
import os
import json
from datetime import datetime, timezone
import asyncio
from concurrent.futures import ThreadPoolExecutor


class LLMRouterError(RuntimeError):
    pass


_LEADING_MARKER_RE = re.compile(r"^[\s\u2022\u25CF\u25E6\u25AA\u25AB\u2219\u00B7\*\-]+")
_HTML_TAG_RE = re.compile(r"<[^>]+>")


def _clean_copilot_output(text: str) -> str:
    text = (text or "").strip()
    text = _LEADING_MARKER_RE.sub("", text).strip()
    text = unescape(text)
    # Copilot CLI output is sometimes rendered as simple HTML (e.g. <p>...</p>).
    if "<" in text and ">" in text:
        text = _HTML_TAG_RE.sub("", text)
    return text.strip()


def _clean_codex_output(text: str) -> str:
    # Codex output is typically plain text, but keep the same guardrails.
    text = (text or "").strip()
    text = _LEADING_MARKER_RE.sub("", text).strip()
    text = unescape(text)
    if "<" in text and ">" in text:
        text = _HTML_TAG_RE.sub("", text)
    return text.strip()


def _clean_claude_output(text: str) -> str:
    # Claude Code CLI returns plain text in --print mode by default.
    return _clean_codex_output(text)


def _clean_gemini_output(text: str) -> str:
    # Keep it conservative and reuse the same scrubbers.
    return _clean_codex_output(text)


def generate_text(
    *,
    prompt: str,
    provider: Optional[str] = None,
    model_name: Optional[str] = None,
    **config: Any,
) -> str:
    """Generate text for a prompt via the configured provider.

    Args:
        prompt: Prompt text.
        provider: Provider identifier. Supported: `copilot_cli`, `codex_cli`, `copilot_sdk`.
        model_name: Model name to pass to the provider (default: `gpt-5-mini`; `gpt-5.3-codex` for `codex_cli`).
        **config: Extra options; tracing and session keys are consumed, remaining keys are ignored by this shim.

    Returns:
        The provider's response text.
    """

    provider = (provider or "copilot_cli").strip()
    if model_name is None:
        if provider == "codex_cli":
            model_name = "gpt-5.3-codex"
        else:
            model_name = "gpt-5-mini"
    model_name = model_name.strip()

    # Optional tracing configuration. Providers may ignore these.
    trace_jsonl_path = config.pop("trace_jsonl_path", None)
    trace_dir = config.pop("trace_dir", None)
    trace_enabled = bool(config.pop("trace", False) or trace_jsonl_path or trace_dir)

    def _utc_stamp() -> str:
        return datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")

    def _best_effort_write(path: str, text: str) -> None:
        try:
            os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
            with open(path, "w", encoding="utf-8") as f:
                f.write(text)
        except OSError:
            pass

    if provider == "copilot_cli":
        # The Copilot CLI bundled with VS Code supports non-interactive prompting.
        # `-s/--silent` ensures stdout is just the model response (useful for APIs).
        # Optional session/persistence controls.
        copilot_config_dir = config.pop("copilot_config_dir", None)
        copilot_log_dir = config.pop("copilot_log_dir", None)
        resume_session_id = config.pop("resume_session_id", None)
        continue_session = bool(config.pop("continue_session", False))

        share_path: Optional[str] = None
        if trace_enabled:
            # For copilot, trace_jsonl_path is treated as a JSONL metadata log.
            # The full session transcript is written via --share.
            share_base_dir: Optional[str] = None
            if isinstance(trace_dir, str) and trace_dir.strip():
                share_base_dir = trace_dir.strip()
            elif isinstance(trace_jsonl_path, str) and trace_jsonl_path.strip():
                share_base_dir = os.path.dirname(trace_jsonl_path.strip()) or "."

            if share_base_dir:
                os.makedirs(share_base_dir, exist_ok=True)
                share_path = os.path.join(
                    share_base_dir,
                    f"copilot_session_{_utc_stamp()}_{os.getpid()}.md",
                )

        cmd = [
            "copilot",
            "-s",
            "--stream",
            "off",
            "--model",
            model_name,
            "-p",
            prompt,
        ]

        if isinstance(copilot_config_dir, str) and copilot_config_dir.strip():
            cmd.extend(["--config-dir", copilot_config_dir.strip()])
        if isinstance(copilot_log_dir, str) and copilot_log_dir.strip():
            cmd.extend(["--log-dir", copilot_log_dir.strip()])
        elif trace_enabled and isinstance(trace_dir, str) and trace_dir.strip():
            # If tracing, keep copilot logs alongside the project trace dir by default.
            cmd.extend(["--log-dir", trace_dir.strip()])

        appended_continue = False
        if isinstance(resume_session_id, str) and resume_session_id.strip():
            cmd.extend(["--resume", resume_session_id.strip()])
        elif continue_session:
            cmd.append("--continue")
            appended_continue = True

        if share_path:
            cmd.extend(["--share", share_path])

        def _run_copilot(c: list[str]):
            return subprocess.run(
                c,
                check=True,
                text=True,
                capture_output=True,
            )

        try:
            completed = _run_copilot(cmd)
        except FileNotFoundError as e:
            raise LLMRouterError("Copilot CLI binary not found on PATH") from e
        except subprocess.CalledProcessError as e:
            stderr = (e.stderr or "").strip()

            # If the caller requested session continuation but no session exists yet
            # (common on the first turn), retry once without --continue.
            if appended_continue:
                msg = (stderr or str(e) or "").lower()
                retryable_continue = any(
                    s in msg
                    for s in (
                        "no session",
                        "no previous session",
                        "nothing to continue",
                        "cannot continue",
                        "could not continue",
                        "unable to continue",
                        "not found",
                    )
                )
                if retryable_continue:
                    cmd2 = [x for x in cmd if x != "--continue"]
                    try:
                        completed = _run_copilot(cmd2)
                        cmd = cmd2
                    except subprocess.CalledProcessError as e2:
                        stderr2 = (e2.stderr or "").strip()
                        raise LLMRouterError(f"Copilot CLI failed: {stderr2 or e2}") from e2
                else:
                    raise LLMRouterError(f"Copilot CLI failed: {stderr or e}") from e
            else:
                raise LLMRouterError(f"Copilot CLI failed: {stderr or e}") from e

        cleaned = _clean_copilot_output(completed.stdout or "")

        if trace_enabled and (isinstance(trace_jsonl_path, str) and trace_jsonl_path.strip()):
            record = {
                "ts": datetime.now(timezone.utc).isoformat(),
                "provider": "copilot_cli",
                "model": model_name,
                "cmd": cmd,
                "share_path": share_path,
                "stdout_chars": len(completed.stdout or ""),
                "stderr_chars": len(completed.stderr or ""),
            }
            try:
                os.makedirs(os.path.dirname(trace_jsonl_path.strip()) or ".", exist_ok=True)
                with open(trace_jsonl_path.strip(), "a", encoding="utf-8") as f:
                    f.write(json.dumps(record, ensure_ascii=False) + "\n")
            except OSError:
                pass

        return cleaned

    if provider == "copilot_sdk":
        # Copilot Python SDK (technical preview): programmatic JSON-RPC control of Copilot CLI.
        # This is an optional dependency and is only imported if used.

        try:
            from copilot import CopilotClient  # type: ignore
        except Exception as e:  # pragma: no cover
            raise LLMRouterError(
                "Copilot Python SDK is not installed. Install it from https://github.com/github/copilot-sdk (python package) "
                "and ensure the `copilot` CLI is installed."
            ) from e

        copilot_sdk_client_opts: Dict[str, Any] = {}
        cli_path = config.pop("copilot_sdk_cli_path", None)
        cli_url = config.pop("copilot_sdk_cli_url", None)
        log_level = config.pop("copilot_sdk_log_level", None)
        use_stdio = config.pop("copilot_sdk_use_stdio", None)
        port = config.pop("copilot_sdk_port", None)
        auto_start = config.pop("copilot_sdk_auto_start", None)
        auto_restart = config.pop("copilot_sdk_auto_restart", None)
        cwd = config.pop("copilot_sdk_cwd", None)
        github_token = config.pop("copilot_sdk_github_token", None)
        use_logged_in_user = config.pop("copilot_sdk_use_logged_in_user", None)

        if isinstance(cli_path, str) and cli_path.strip():
            copilot_sdk_client_opts["cli_path"] = cli_path.strip()
        if isinstance(cli_url, str) and cli_url.strip():
            copilot_sdk_client_opts["cli_url"] = cli_url.strip()
        if isinstance(log_level, str) and log_level.strip():
            copilot_sdk_client_opts["log_level"] = log_level.strip()
        if isinstance(cwd, str) and cwd.strip():
            copilot_sdk_client_opts["cwd"] = cwd.strip()
        if isinstance(use_stdio, bool):
            copilot_sdk_client_opts["use_stdio"] = use_stdio
        if isinstance(port, int):
            copilot_sdk_client_opts["port"] = port
        if isinstance(auto_start, bool):
            copilot_sdk_client_opts["auto_start"] = auto_start
        if isinstance(auto_restart, bool):
            copilot_sdk_client_opts["auto_restart"] = auto_restart
        if isinstance(github_token, str) and github_token.strip():
            copilot_sdk_client_opts["github_token"] = github_token.strip()
        if isinstance(use_logged_in_user, bool):
            copilot_sdk_client_opts["use_logged_in_user"] = use_logged_in_user

        copilot_sdk_session_cfg: Dict[str, Any] = {"model": model_name}
        streaming = config.pop("copilot_sdk_streaming", None)
        session_id = config.pop("copilot_sdk_session_id", None)
        infinite_sessions = config.pop("copilot_sdk_infinite_sessions", None)
        system_message = config.pop("copilot_sdk_system_message", None)

        if isinstance(streaming, bool):
            copilot_sdk_session_cfg["streaming"] = streaming
        if isinstance(session_id, str) and session_id.strip():
            copilot_sdk_session_cfg["session_id"] = session_id.strip()
        if isinstance(infinite_sessions, dict):
            copilot_sdk_session_cfg["infinite_sessions"] = infinite_sessions
        if isinstance(system_message, dict):
            copilot_sdk_session_cfg["system_message"] = system_message

        async def _run() -> str:
            client = CopilotClient(copilot_sdk_client_opts)
            await client.start()

            session = await client.create_session(copilot_sdk_session_cfg)
            done = asyncio.Event()
            final_message: Dict[str, Optional[str]] = {"content": None}

            def on_event(event: Any) -> None:
                try:
                    event_type = getattr(getattr(event, "type", None), "value", None)
                    if event_type == "assistant.message":
                        final_message["content"] = getattr(getattr(event, "data", None), "content", None)
                    elif event_type == "session.idle":
                        done.set()
                except Exception:
                    # Keep event handler best-effort; don't let logging break inference.
                    pass

            session.on(on_event)

            try:
                await session.send({"prompt": prompt})
                await done.wait()
                content = final_message["content"] or ""

                if trace_enabled and (isinstance(trace_jsonl_path, str) and trace_jsonl_path.strip()):
                    record = {
                        "ts": datetime.now(timezone.utc).isoformat(),
                        "provider": "copilot_sdk",
                        "model": model_name,
                        "session_id": getattr(session, "id", None),
                        "workspace_path": getattr(session, "workspace_path", None),
                        "prompt_chars": len(prompt or ""),
                        "response_chars": len(content or ""),
                    }
                    try:
                        os.makedirs(os.path.dirname(trace_jsonl_path.strip()) or ".", exist_ok=True)
                        with open(trace_jsonl_path.strip(), "a", encoding="utf-8") as f:
                            f.write(json.dumps(record, ensure_ascii=False) + "\n")
                    except OSError:
                        pass

                return _clean_copilot_output(content)
            finally:
                try:
                    await session.destroy()
                except Exception:
                    pass
                try:
                    await client.stop()
                except Exception:
                    pass

        def _run_sync() -> str:
            try:
                asyncio.get_running_loop()
            except RuntimeError:
                return asyncio.run(_run())

            # If we're already in an event loop, run in a separate thread.
            with ThreadPoolExecutor(max_workers=1) as ex:
                fut = ex.submit(lambda: asyncio.run(_run()))
                return fut.result()

        return _run_sync()

    if provider == "claude_code":
        # Claude Code CLI (Anthropic) non-interactive mode.
        # Uses `--print` so stdout is the assistant response.
        # To keep this shim safe/non-invasive by default, we disable tools unless explicitly enabled.

        claude_add_dir = config.pop("claude_add_dir", None)
        claude_permission_mode = config.pop("claude_permission_mode", None)
        claude_system_prompt = config.pop("claude_system_prompt", None)
        claude_append_system_prompt = config.pop("claude_append_system_prompt", None)
        claude_output_format = config.pop("claude_output_format", None)  # text|json|stream-json
        claude_input_format = config.pop("claude_input_format", None)  # text|stream-json
        claude_include_partial = bool(config.pop("claude_include_partial_messages", False))
        claude_no_session_persistence = bool(config.pop("claude_no_session_persistence", False))

        resume_session_id = config.pop("resume_session_id", None)
        continue_session = bool(config.pop("continue_session", False))
        claude_session_id = config.pop("claude_session_id", None)
        claude_fork_session = bool(config.pop("claude_fork_session", False))

        claude_tools = config.pop("claude_tools", None)
        claude_allowed_tools = config.pop("claude_allowed_tools", None)
        claude_disallowed_tools = config.pop("claude_disallowed_tools", None)

        if claude_output_format is None:
            claude_output_format = "text"
        if claude_input_format is None:
            claude_input_format = "text"

        # Trace outputs
        trace_text_path: Optional[str] = None
        if trace_enabled and isinstance(trace_dir, str) and trace_dir.strip():
            d = trace_dir.strip()
            os.makedirs(d, exist_ok=True)
            stamp = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")
            pid = os.getpid()
            trace_text_path = os.path.join(d, f"claude_print_{stamp}_{pid}.txt")

        cmd = [
            "claude",
            "--print",
            "--output-format",
            str(claude_output_format),
            "--input-format",
            str(claude_input_format),
        ]

        # Model: allow passing aliases like "sonnet" or full model ids.
        if isinstance(model_name, str) and model_name.strip():
            cmd.extend(["--model", model_name.strip()])

        if claude_include_partial and str(claude_output_format) == "stream-json":
            cmd.append("--include-partial-messages")

        if claude_no_session_persistence:
            cmd.append("--no-session-persistence")

        if isinstance(claude_permission_mode, str) and claude_permission_mode.strip():
            cmd.extend(["--permission-mode", claude_permission_mode.strip()])

        if isinstance(claude_system_prompt, str) and claude_system_prompt.strip():
            cmd.extend(["--system-prompt", claude_system_prompt.strip()])
        if isinstance(claude_append_system_prompt, str) and claude_append_system_prompt.strip():
            cmd.extend(["--append-system-prompt", claude_append_system_prompt.strip()])

        # Session controls (best-effort): reuse the same keys as other providers.
        if isinstance(claude_session_id, str) and claude_session_id.strip():
            cmd.extend(["--session-id", claude_session_id.strip()])
        if isinstance(resume_session_id, str) and resume_session_id.strip():
            cmd.extend(["--resume", resume_session_id.strip()])
            if claude_fork_session:
                cmd.append("--fork-session")
        elif continue_session:
            cmd.append("--continue")

        # Default to tools disabled for safety/LLM-like behavior.
        # Note: `--tools` is variadic in Claude Code, so we always provide the prompt via stdin
        # (otherwise `--tools ...` can consume the positional prompt argument).
        if claude_tools is None:
            cmd.extend(["--tools", ""])  # docs: empty string disables all tools
        else:
            # Allow passing "default" or a comma/space-separated list.
            cmd.extend(["--tools", str(claude_tools)])

        if claude_add_dir:
            # Can be a string or list of strings.
            if isinstance(claude_add_dir, str) and claude_add_dir.strip():
                cmd.extend(["--add-dir", claude_add_dir.strip()])
            elif isinstance(claude_add_dir, (list, tuple)):
                for d in claude_add_dir:
                    if isinstance(d, str) and d.strip():
                        cmd.extend(["--add-dir", d.strip()])

        if claude_allowed_tools:
            if isinstance(claude_allowed_tools, str) and claude_allowed_tools.strip():
                cmd.extend(["--allowed-tools", claude_allowed_tools.strip()])
            elif isinstance(claude_allowed_tools, (list, tuple)):
                cmd.append("--allowed-tools")
                cmd.extend([str(t) for t in claude_allowed_tools if str(t).strip()])

        if claude_disallowed_tools:
            if isinstance(claude_disallowed_tools, str) and claude_disallowed_tools.strip():
                cmd.extend(["--disallowed-tools", claude_disallowed_tools.strip()])
            elif isinstance(claude_disallowed_tools, (list, tuple)):
                cmd.append("--disallowed-tools")
                cmd.extend([str(t) for t in claude_disallowed_tools if str(t).strip()])

        try:
            completed = subprocess.run(
                cmd,
                input=prompt,
                check=True,
                text=True,
                capture_output=True,
            )
        except FileNotFoundError as e:
            raise LLMRouterError("Claude CLI binary not found on PATH") from e
        except subprocess.CalledProcessError as e:
            stderr = (e.stderr or "").strip()
            stdout = (e.stdout or "").strip()
            msg = stderr or stdout or str(e)
            if "not logged in" in msg.lower():
                raise LLMRouterError(
                    "Claude CLI is installed but not logged in. Run `claude` and complete `/login` (or `claude setup-token`) and retry."
                ) from e
            raise LLMRouterError(f"Claude CLI failed: {msg}") from e

        if trace_text_path:
            payload = (completed.stdout or "")
            if completed.stderr:
                payload += "\n\n[stderr]\n" + (completed.stderr or "")
            _best_effort_write(trace_text_path, payload)

        if trace_enabled and (isinstance(trace_jsonl_path, str) and trace_jsonl_path.strip()):
            record = {
                "ts": datetime.now(timezone.utc).isoformat(),
                "provider": "claude_code",
                "model": model_name,
                "cmd": cmd,
                "trace_text_path": trace_text_path,
                "stdout_chars": len(completed.stdout or ""),
                "stderr_chars": len(completed.stderr or ""),
            }
            try:
                os.makedirs(os.path.dirname(trace_jsonl_path.strip()) or ".", exist_ok=True)
                with open(trace_jsonl_path.strip(), "a", encoding="utf-8") as f:
                    f.write(json.dumps(record, ensure_ascii=False) + "\n")
            except OSError:
                pass

        return _clean_claude_output(completed.stdout or "")

    if provider == "gemini_cli":
        # Gemini CLI wrapper.
        # There are multiple community/official Gemini CLIs with different flag surfaces.
        #
        # This shim supports a configurable command, and includes a first-class path for
        # `npx @google/gemini-cli`.
        #
        # IMPORTANT: `@google/gemini-cli` requires Node >=20 at runtime. Instead of
        # upgrading system Node, we run it under a transient Node 20 binary via:
        #   npx -y -p node@20 -p @google/gemini-cli gemini ...
        # This keeps the repo reusable and avoids requiring sudo.
        #
        # Config options:
        # - gemini_cli_path: executable path/name (default: "gemini")
        # - gemini_cmd: full command override (string or list). If set, takes precedence.
        # - gemini_args: extra args (list or string) appended after the base command.
        # - gemini_env: dict of env vars to add/override for the subprocess.

        gemini_cli_path = config.pop("gemini_cli_path", None)
        gemini_cmd = config.pop("gemini_cmd", None)
        gemini_args = config.pop("gemini_args", None)
        gemini_env = config.pop("gemini_env", None)

        # Shared session controls (match other CLI providers where possible).
        resume_session_id = config.pop("resume_session_id", None)
        continue_session = bool(config.pop("continue_session", False))

        base_cmd: list[str]
        import shlex  # local import: shell-aware parsing of user-provided command strings

        if isinstance(gemini_cmd, str) and gemini_cmd.strip():
            # Shell-like split that respects quoting (e.g. paths containing spaces).
            base_cmd = shlex.split(gemini_cmd.strip())
        elif isinstance(gemini_cmd, (list, tuple)) and gemini_cmd:
            base_cmd = [str(x) for x in gemini_cmd if str(x).strip()]
        else:
            # Support a shorthand for npx-installed Gemini CLI.
            # Example: gemini_cli_path="npx:@google/gemini-cli"
            if isinstance(gemini_cli_path, str) and gemini_cli_path.strip().startswith("npx:"):
                pkg = gemini_cli_path.strip().split(":", 1)[1].strip()
                base_cmd = ["npx", "-y", pkg]
            else:
                exe = "gemini"
                if isinstance(gemini_cli_path, str) and gemini_cli_path.strip():
                    exe = gemini_cli_path.strip()
                base_cmd = [exe]

        extra_args: list[str] = []
        if isinstance(gemini_args, str) and gemini_args.strip():
            # Use shlex so quoted arguments survive splitting.
            import shlex
            extra_args = shlex.split(gemini_args.strip())
        elif isinstance(gemini_args, (list, tuple)) and gemini_args:
            extra_args = [str(x) for x in gemini_args if str(x).strip()]

        cmd = base_cmd + extra_args

        # Detect the official Google Gemini CLI.
        # If so, prefer its headless flags and run it under Node 20 via npx.
        is_google_gemini_cli = (
            cmd[:3] == ["npx", "-y", "@google/gemini-cli"]
            or cmd[:2] == ["@google/gemini-cli", "gemini"]
            or cmd[:1] == ["gemini"]
        )

        def _google_gemini_cmd() -> list[str]:
            base = ["npx", "-y", "-p", "node@20", "-p", "@google/gemini-cli", "gemini"]
            # Safety: use read-only/plan mode by default to avoid tool prompts.
            # Allow overriding via gemini_args.
            if "--approval-mode" not in extra_args and "--yolo" not in extra_args and "-y" not in extra_args:
                base += ["--approval-mode", "plan"]
            if "--output-format" not in extra_args and "-o" not in extra_args:
                base += ["--output-format", "text"]
            if ("--model" not in extra_args and "-m" not in extra_args) and isinstance(model_name, str) and model_name.strip():
                base += ["--model", model_name.strip()]
            if isinstance(resume_session_id, str) and resume_session_id.strip():
                # Google CLI supports --resume latest / index strings.
                base += ["--resume", resume_session_id.strip()]
            elif continue_session:
                base += ["--resume", "latest"]
            # Headless prompt.
            base += ["--prompt", prompt]
            return base

        trace_text_path: Optional[str] = None
        if trace_enabled and isinstance(trace_dir, str) and trace_dir.strip():
            d = trace_dir.strip()
            os.makedirs(d, exist_ok=True)
            stamp = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")
            pid = os.getpid()
            trace_text_path = os.path.join(d, f"gemini_exec_{stamp}_{pid}.txt")

        env = os.environ.copy()
        if isinstance(gemini_env, dict):
            for k, v in gemini_env.items():
                if k is None:
                    continue
                env[str(k)] = "" if v is None else str(v)

        def _run_cmd(run_cmd: list[str], *, send_stdin: bool):
            return subprocess.run(
                run_cmd,
                input=(prompt if send_stdin else None),
                check=True,
                text=True,
                capture_output=True,
                env=env,
            )

        try:
            if is_google_gemini_cli:
                cmd = _google_gemini_cmd()
                completed = _run_cmd(cmd, send_stdin=False)
            else:
                completed = _run_cmd(cmd, send_stdin=True)
        except FileNotFoundError as e:
            # If `gemini` isn't installed as a binary, try the common `npx @google/gemini-cli` approach.
            if base_cmd != ["gemini"]:
                raise LLMRouterError(
                    "Gemini CLI binary not found on PATH. Set `gemini_cli_path` or `gemini_cmd` in config to the installed command."
                ) from e
            try:
                cmd = _google_gemini_cmd()
                completed = _run_cmd(cmd, send_stdin=False)
            except FileNotFoundError as e2:
                raise LLMRouterError(
                    "Gemini CLI binary not found on PATH. Install Node+npx, or set `gemini_cmd` / `gemini_cli_path` to the correct command."
                ) from e2
            except subprocess.CalledProcessError as e2:
                stderr2 = (e2.stderr or "").strip()
                stdout2 = (e2.stdout or "").strip()
                msg2 = stderr2 or stdout2 or str(e2)
                raise LLMRouterError(f"Gemini CLI failed: {msg2}") from e2
        except subprocess.CalledProcessError as e:
            stderr = (e.stderr or "").strip()
            stdout = (e.stdout or "").strip()
            msg = stderr or stdout or str(e)
            # Retry once using Node 20 overlay if we hit the classic Node18 regexp /v failure.
            if "invalid regular expression flags" in msg.lower() and "node.js v18" in msg.lower():
                try:
                    cmd = _google_gemini_cmd()
                    completed = _run_cmd(cmd, send_stdin=False)
                except Exception as e2:
                    raise LLMRouterError(
                        "Gemini CLI failed under Node 18. This wrapper can run @google/gemini-cli under a transient Node 20 binary "
                        "via `npx -p node@20 -p @google/gemini-cli gemini ...`, but that attempt also failed."
                    ) from e2
            else:
                raise LLMRouterError(f"Gemini CLI failed: {msg}") from e

        if trace_text_path:
            payload = (completed.stdout or "")
            if completed.stderr:
                payload += "\n\n[stderr]\n" + (completed.stderr or "")
            _best_effort_write(trace_text_path, payload)

        if trace_enabled and (isinstance(trace_jsonl_path, str) and trace_jsonl_path.strip()):
            record = {
                "ts": datetime.now(timezone.utc).isoformat(),
                "provider": "gemini_cli",
                "model": model_name,
                "cmd": cmd,
                "trace_text_path": trace_text_path,
                "stdout_chars": len(completed.stdout or ""),
                "stderr_chars": len(completed.stderr or ""),
            }
            try:
                os.makedirs(os.path.dirname(trace_jsonl_path.strip()) or ".", exist_ok=True)
                with open(trace_jsonl_path.strip(), "a", encoding="utf-8") as f:
                    f.write(json.dumps(record, ensure_ascii=False) + "\n")
            except OSError:
                pass

        return _clean_gemini_output(completed.stdout or "")

    if provider == "codex_cli":
        # Use `--output-last-message` so we can ignore progress logs on stdout.
        tmp_path: Optional[str] = None
        trace_path: Optional[str] = None
        try:
            tmp = tempfile.NamedTemporaryFile(prefix="codex_last_message_", delete=False)
            tmp_path = tmp.name
            tmp.close()

            if trace_enabled:
                # Allow caller to control where trace logs live (e.g. project run artifacts)
                # so the submodule stays reusable and doesn't accumulate logs.
                if isinstance(trace_jsonl_path, str) and trace_jsonl_path.strip():
                    trace_path = trace_jsonl_path.strip()
                elif isinstance(trace_dir, str) and trace_dir.strip():
                    d = trace_dir.strip()
                    os.makedirs(d, exist_ok=True)
                    stamp = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")
                    pid = os.getpid()
                    safe_model = re.sub(r"[^a-zA-Z0-9._-]+", "_", model_name)
                    trace_path = os.path.join(d, f"codex_exec_{safe_model}_{stamp}_{pid}.jsonl")

            cmd = [
                "codex",
                "exec",
                "--color",
                "never",
                "--sandbox",
                "read-only",
                "--skip-git-repo-check",
                "--model",
                model_name,
                "--output-last-message",
                tmp_path,
                "-",
            ]
            if trace_enabled:
                cmd.insert(cmd.index("--output-last-message"), "--json")

            try:
                completed = subprocess.run(
                    cmd,
                    input=prompt,
                    check=True,
                    text=True,
                    capture_output=True,
                )
            except FileNotFoundError as e:
                raise LLMRouterError("Codex CLI binary not found on PATH") from e
            except subprocess.CalledProcessError as e:
                stderr = (e.stderr or "").strip()
                stdout = (e.stdout or "").strip()
                raise LLMRouterError(f"Codex CLI failed: {stderr or stdout or e}") from e

            # Persist JSONL events if tracing is enabled.
            if trace_path:
                _best_effort_write(trace_path, (completed.stdout or "").strip() + "\n")

            try:
                with open(tmp_path, "r", encoding="utf-8") as f:
                    last_message = f.read()
            except OSError as e:
                raise LLMRouterError("Codex CLI did not produce a readable last message") from e

            return _clean_codex_output(last_message)
        finally:
            if tmp_path:
                try:
                    os.remove(tmp_path)
                except OSError:
                    pass

    raise LLMRouterError(f"Unsupported provider: {provider}")

…tion, and more

Co-authored-by: endomorphosis <3405202+endomorphosis@users.noreply.github.com>