diff --git a/strix/llm/dedupe.py b/strix/llm/dedupe.py index 0ea608850..c505a7b39 100644 --- a/strix/llm/dedupe.py +++ b/strix/llm/dedupe.py @@ -182,6 +182,16 @@ def check_duplicate( if api_base: completion_kwargs["api_base"] = api_base + try: + from strix.telemetry.tracer import get_global_tracer + + tracer = get_global_tracer() + if tracer: + run_id = tracer.run_id + completion_kwargs["metadata"] = {"$ai_trace_id": run_id} + except Exception as e: + logger.error(f"Could not set trace metadata: {e}") + response = litellm.completion(**completion_kwargs) content = response.choices[0].message.content diff --git a/strix/llm/llm.py b/strix/llm/llm.py index d941361b2..654a879dd 100644 --- a/strix/llm/llm.py +++ b/strix/llm/llm.py @@ -1,4 +1,5 @@ import asyncio +import logging from collections.abc import AsyncIterator from dataclasses import dataclass from typing import Any @@ -18,10 +19,14 @@ parse_tool_invocations, ) from strix.skills import load_skills +from strix.telemetry import posthog from strix.tools import get_tools_prompt from strix.utils.resource_paths import get_strix_resource_path +logger = logging.getLogger(__name__) + + litellm.drop_params = True litellm.modify_params = True @@ -75,6 +80,11 @@ def __init__(self, config: LLMConfig, agent_name: str | None = None): else: self._reasoning_effort = "high" + try: + posthog.configure_litellm_posthog() + except Exception as e: + logger.error(f"Could not config posthog traces: {e}") + def _load_system_prompt(self, agent_name: str | None) -> str: if not agent_name: return "" @@ -129,18 +139,14 @@ async def generate( async def _stream(self, messages: list[dict[str, Any]]) -> AsyncIterator[LLMResponse]: accumulated = "" chunks: list[Any] = [] - done_streaming = 0 + found_function_end = False self._total_stats.requests += 1 - response = await acompletion(**self._build_completion_args(messages), stream=True) + completion_args = self._build_completion_args(messages) + response = await acompletion(**completion_args, 
stream=True) async for chunk in response: - chunks.append(chunk) - if done_streaming: - done_streaming += 1 - if getattr(chunk, "usage", None) or done_streaming > 5: - break - continue + chunks.append(chunk) delta = self._get_chunk_content(chunk) if delta: accumulated += delta @@ -149,12 +155,15 @@ async def _stream(self, messages: list[dict[str, Any]]) -> AsyncIterator[LLMResp pos = accumulated.find(end_tag) accumulated = accumulated[: pos + len(end_tag)] yield LLMResponse(content=accumulated) - done_streaming = 1 + found_function_end = True continue - yield LLMResponse(content=accumulated) + + if not found_function_end: + yield LLMResponse(content=accumulated) if chunks: - self._update_usage_stats(stream_chunk_builder(chunks)) + final_response = stream_chunk_builder(chunks) + self._update_usage_stats(final_response) accumulated = normalize_tool_format(accumulated) accumulated = fix_incomplete_tool_call(_truncate_to_first_function(accumulated)) @@ -209,6 +218,29 @@ def _build_completion_args(self, messages: list[dict[str, Any]]) -> dict[str, An args["api_key"] = self.config.api_key if self.config.api_base: args["api_base"] = self.config.api_base + metadata: dict[str, Any] = {} + + try: + from strix.telemetry.tracer import get_global_tracer + + tracer = get_global_tracer() + if tracer: + run_id = tracer.run_id + metadata["$ai_trace_id"] = run_id + except Exception as e: + logger.error(f"Could not set trace metadata: {e}") + if metadata: + args["metadata"] = metadata + + if api_key := Config.get("llm_api_key"): + args["api_key"] = api_key + if api_base := ( + Config.get("llm_api_base") + or Config.get("openai_api_base") + or Config.get("litellm_base_url") + or Config.get("ollama_api_base") + ): + args["api_base"] = api_base if self._supports_reasoning(): args["reasoning_effort"] = self._reasoning_effort diff --git a/strix/llm/memory_compressor.py b/strix/llm/memory_compressor.py index 8cad51078..f707451b8 100644 --- a/strix/llm/memory_compressor.py +++ 
b/strix/llm/memory_compressor.py @@ -117,6 +117,18 @@ def _summarize_messages( if api_base: completion_args["api_base"] = api_base + try: + from strix.telemetry.tracer import get_global_tracer + + tracer = get_global_tracer() + if tracer: + run_id = tracer.run_id + completion_args["metadata"] = { + "$ai_trace_id": run_id, + } + except Exception as e: + logger.error(f"Could not set trace metadata: {e}") + response = litellm.completion(**completion_args) summary = response.choices[0].message.content or "" if not summary.strip(): diff --git a/strix/telemetry/posthog.py b/strix/telemetry/posthog.py index fd66bcc06..00369ee89 100644 --- a/strix/telemetry/posthog.py +++ b/strix/telemetry/posthog.py @@ -1,4 +1,9 @@ +from litellm import CALLBACK_TYPES + + import json +import logging +import os import platform import sys import urllib.request @@ -6,20 +11,55 @@ from typing import TYPE_CHECKING, Any from uuid import uuid4 +import litellm + from strix.config import Config +logger = logging.getLogger(__name__) + if TYPE_CHECKING: from strix.telemetry.tracer import Tracer -_POSTHOG_PUBLIC_API_KEY = "phc_7rO3XRuNT5sgSKAl6HDIrWdSGh1COzxw0vxVIAR6vVZ" -_POSTHOG_HOST = "https://us.i.posthog.com" +_POSTHOG_PRIMARY_API_KEY = "phc_7rO3XRuNT5sgSKAl6HDIrWdSGh1COzxw0vxVIAR6vVZ" +_POSTHOG_PRIMARY_HOST = "https://us.i.posthog.com" + +_POSTHOG_LLM_API_KEY = os.environ.get("POSTHOG_LLM_API_KEY") +_POSTHOG_LLM_HOST = os.environ.get("POSTHOG_LLM_HOST") _SESSION_ID = uuid4().hex[:16] def _is_enabled() -> bool: - return (Config.get("strix_telemetry") or "1").lower() not in ("0", "false", "no", "off") + telemetry_value = Config.get("strix_telemetry") or "1" + return telemetry_value.lower() not in ("0", "false", "no", "off") + + +def configure_litellm_posthog() -> None: + """Configure LiteLLM to send LLM traces to env postHog account.""" + + should_send_trace_to_posthog = _POSTHOG_LLM_API_KEY is not None and _POSTHOG_LLM_HOST is not None + + if not _is_enabled(): + logger.info("PostHog telemetry 
(traces) is disabled") + return + + if not should_send_trace_to_posthog: + logger.info("PostHog LLM trace destination not configured; LLM traces disabled") + return + + os.environ["POSTHOG_API_KEY"] = _POSTHOG_LLM_API_KEY + os.environ["POSTHOG_API_URL"] = _POSTHOG_LLM_HOST + + if "posthog" not in (litellm.success_callback or []): + callbacks = list[CALLBACK_TYPES](litellm.success_callback or []) + callbacks.append("posthog") + litellm.success_callback = callbacks + + if "posthog" not in (litellm.failure_callback or []): + callbacks = list[CALLBACK_TYPES](litellm.failure_callback or []) + callbacks.append("posthog") + litellm.failure_callback = callbacks def _is_first_run() -> bool: @@ -44,22 +84,24 @@ def _get_version() -> str: def _send(event: str, properties: dict[str, Any]) -> None: + """Send custom events to Instance A (Primary) for manual tracking.""" if not _is_enabled(): return try: payload = { - "api_key": _POSTHOG_PUBLIC_API_KEY, + "api_key": _POSTHOG_PRIMARY_API_KEY, "event": event, "distinct_id": _SESSION_ID, "properties": properties, } req = urllib.request.Request( # noqa: S310 - f"{_POSTHOG_HOST}/capture/", + f"{_POSTHOG_PRIMARY_HOST}/capture/", data=json.dumps(payload).encode(), headers={"Content-Type": "application/json"}, ) with urllib.request.urlopen(req, timeout=10): # noqa: S310 # nosec B310 pass + logger.debug(f"Sent custom event '{event}' to primary PostHog instance") except Exception: # noqa: BLE001, S110 pass # nosec B110