Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -105,11 +105,16 @@ def __init__(self, params: Params):
self._lock = threading.Lock()
self._adapter_start_time = time.time() # Record adapter initialization time
self._stats = {
# Average statistics
# Average statistics (computed from sums at save time)
"avg_prompt_tokens": None,
"avg_total_tokens": None,
"avg_completion_tokens": None,
"avg_latency_ms": None,
# Exact sum statistics (used to compute precise averages)
"sum_prompt_tokens": 0,
"sum_total_tokens": 0,
"sum_completion_tokens": 0,
"sum_latency_ms": 0.0,
# Maximum statistics
"max_prompt_tokens": None,
"max_total_tokens": None,
Expand All @@ -123,13 +128,21 @@ def __init__(self, params: Params):
"finish_reason": {},
"stop_reason": {},
"status_codes": {},
# Retry deduplication
"retry_count": 0,
# Time tracking
"inference_time": 0.0,
"run_id": 0,
"last_request_time": None,
"inference_run_times": {}, # {run_id: {"start": time, "end": time, "inference_time": time}}
}

# Set of cache_keys (request content hashes) already counted in stats.
# Used to detect retries across chain-job allocations: when a job
# requeues, previously in-flight requests are re-sent but should not
# be double-counted in aggregated statistics.
self._seen_cache_keys: set[str] = set()

# Always initialize cache database
cache_path = Path(self.cache_dir)
cache_path.mkdir(parents=True, exist_ok=True)
Expand Down Expand Up @@ -209,10 +222,47 @@ def _load_aggregated_cached_stats(self) -> None:
status_codes[key] = value
aggregated_stats["status_codes"] = status_codes

# Backward compatibility: if cached stats lack sum_* fields (from
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[Power Review] ⚠️ Backward-compatibility migration bolted inline into _load_aggregated_cached_stats (Geological Layers) · 78% confidence

The architecture document warns against 'Geological Layers: new code bolted on top of old code without integrating — duplicate data paths, parallel config mechanisms, wrapper-on-wrapper stacking.' Lines 225–257 add ~35 lines of multi-branch migration logic (back-computing sum_* from avg_*, handling total_*, ensuring retry_count) inline at the top of the existing load method rather than extracting it into a focused helper. This creates a layered accumulation of unrelated responsibilities in one method.

💡 Suggestion: Extract the migration block into a private method such as _migrate_legacy_cached_stats(aggregated_stats: dict) -> dict. _load_aggregated_cached_stats then calls it as a single delegation step, keeping the load method's intent readable:

aggregated_stats = self._migrate_legacy_cached_stats(aggregated_stats)

# older versions that used running averages), back-compute sums
# from avg * successful_count.
successful_count = aggregated_stats.get("successful_count", 0)
for token_type in ["prompt_tokens", "total_tokens", "completion_tokens"]:
sum_key = f"sum_{token_type}"
avg_key = f"avg_{token_type}"
total_key = f"total_{token_type}"
if sum_key not in aggregated_stats:
# Try total_* first (from new save format), then back-compute from avg
if total_key in aggregated_stats:
aggregated_stats[sum_key] = aggregated_stats.pop(total_key)
elif (
aggregated_stats.get(avg_key) is not None
and successful_count > 0
):
aggregated_stats[sum_key] = (
aggregated_stats[avg_key] * successful_count
)
else:
aggregated_stats[sum_key] = 0

if "sum_latency_ms" not in aggregated_stats:
avg_latency = aggregated_stats.get("avg_latency_ms")
if avg_latency is not None and successful_count > 0:
aggregated_stats["sum_latency_ms"] = avg_latency * successful_count
else:
aggregated_stats["sum_latency_ms"] = 0.0

# Backward compatibility: ensure retry_count exists
if "retry_count" not in aggregated_stats:
aggregated_stats["retry_count"] = 0

# Set current stats to cached data (cached stats already contain accumulated data)
self._stats = aggregated_stats
# Note: run_id increment is handled in _save_run_ids_info()

# Restore seen cache keys for retry deduplication
if "seen_cache_keys" in interceptor_state:
self._seen_cache_keys = set(interceptor_state["seen_cache_keys"])

self.logger.info(
f"Loaded interceptor state with run_id {aggregated_stats.get('run_id', 0)}, count={aggregated_stats.get('count', 0)}"
)
Expand Down Expand Up @@ -256,26 +306,31 @@ def _update_basic_stats(self, resp: AdapterResponse, current_time: float) -> Non
self._stats["inference_time"] += delta

def _update_running_stats(self, stat_name: str, value: float) -> None:
"""Update running average and max for a given statistic."""
"""Update sum, average, and max for a given statistic.

Tracks exact sums to avoid floating-point error accumulation from
running averages. The average is recomputed from sum / count at each
step for live monitoring, and rounded to 2 decimal places only at
file-save time.
"""
# Skip if value is not a valid number
if not isinstance(value, (int, float)):
self.logger.warning(
f"Invalid value for {stat_name}: {value} (expected number)"
)
return

# Calculate running average using current successful count
sum_key = f"sum_{stat_name}"
avg_key = f"avg_{stat_name}"
if self._stats[avg_key] is None:
self._stats[avg_key] = value
else:
self._stats[avg_key] = round(
(self._stats[avg_key] * self._stats["successful_count"] + value)
/ (self._stats["successful_count"] + 1),
2,
)

# Update max valuename
# Accumulate exact sum
self._stats[sum_key] += value

# Compute average from exact sum (successful_count not yet incremented)
new_count = self._stats["successful_count"] + 1
self._stats[avg_key] = self._stats[sum_key] / new_count

# Update max value
max_key = f"max_{stat_name}"
if self._stats[max_key] is None or value > self._stats[max_key]:
self._stats[max_key] = value
Expand Down Expand Up @@ -404,6 +459,14 @@ def intercept_response(
return resp
status_code = resp.r.status_code

# Detect retries: if we've already counted a successful response for
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[Power Review] ⚠️ intercept_response grows significantly beyond 50 lines with retry logic added inline · 72% confidence

The architecture document warns: 'Monolithic Functions (>50 lines): AI generates long, flat functions instead of composing smaller units.' The diff adds ~55 lines of new code to intercept_response (retry detection at lines 462–468, retry branch at 480–488, restructured non-retry block at 490–519). Even if the function was borderline before, these additions push it well past the ~50-line guideline and mix three distinct concerns: retry detection, retry accounting, and normal stats collection.

💡 Suggestion: Extract the retry detection and retry-path handling into a dedicated private method, e.g. _handle_retry_response(resp, cache_key, context). The main intercept_response orchestrates at a higher level:

is_retry = self._is_retry(resp)
if is_retry:
    self._handle_retry(resp, cache_key)
else:
    self._handle_successful_response(resp, context, status_code)

# the same request content (identified by cache_key), this is a retry
# from a chain-job requeue. We still track it in count/status_codes
# (total API calls) but skip updating token stats and successful_count
# to avoid inflating averages.
cache_key = getattr(resp.rctx, "cache_key", None)
is_retry = cache_key is not None and cache_key in self._seen_cache_keys

# Update time tracking with current timestamp
current_time = time.time()
self._update_time_tracking(current_time)
Expand All @@ -414,29 +477,46 @@ def intercept_response(
# Always add basic response stats (count, status_code)
self._add_basic_response_stats(resp, context)

if is_retry:
with self._lock:
self._stats["retry_count"] += 1
self.logger.debug(
"Detected retry request, skipping token stats update",
request_id=resp.rctx.request_id,
cache_key=cache_key[:8] + "..." if cache_key else None,
)
# Extract detailed stats once and reuse them
detailed_stats = None
try:
# Try to parse response as JSON
response_data = resp.r.json()
if not is_retry:
try:
# Try to parse response as JSON
response_data = resp.r.json()

if status_code == 200:
detailed_stats = self._extract_detailed_response_stats(
response_data
)

if status_code == 200:
detailed_stats = self._extract_detailed_response_stats(response_data)
# Add detailed stats for aggregation
self._update_response_stats(detailed_stats)

# Add detailed stats for aggregation
self._update_response_stats(detailed_stats)
# Mark this cache_key as seen
if cache_key is not None:
self._seen_cache_keys.add(cache_key)

self.logger.debug(
"Collected detailed response stats",
request_id=resp.rctx.request_id,
response_count=self._stats["count"],
status_code=status_code,
)
self.logger.debug(
"Collected detailed response stats",
request_id=resp.rctx.request_id,
response_count=self._stats["count"],
status_code=status_code,
)

except (json.JSONDecodeError, Exception) as e:
# Handle both JSON parsing errors and other exceptions
# In case of any error, only basic stats are collected
self.logger.warning(f"Error parsing response body for token counting: {e}")
except (json.JSONDecodeError, Exception) as e:
# Handle both JSON parsing errors and other exceptions
# In case of any error, only basic stats are collected
self.logger.warning(
f"Error parsing response body for token counting: {e}"
)

# Save stats to file if interval reached
if (
Expand Down Expand Up @@ -478,6 +558,22 @@ def _save_stats_to_file(self, context: AdapterGlobalContext) -> None:
self.logger.debug("No response statistics collected, skipping file write")
return

# Round averages to 2 decimal places for display
for avg_key in [
"avg_prompt_tokens",
"avg_total_tokens",
"avg_completion_tokens",
"avg_latency_ms",
]:
if stats[avg_key] is not None:
stats[avg_key] = round(stats[avg_key], 2)

# Expose exact totals as total_* fields and remove internal sum_* fields
for token_type in ["prompt_tokens", "total_tokens", "completion_tokens"]:
stats[f"total_{token_type}"] = stats.pop(f"sum_{token_type}", 0)
# Remove internal sum_latency_ms (not useful as an output field)
stats.pop("sum_latency_ms", None)

# Convert timestamps to readable dates in inference_run_times and add time_to_first_request
if "inference_run_times" in stats:
for run_id, run_data in stats["inference_run_times"].items():
Expand Down Expand Up @@ -591,6 +687,9 @@ def _save_aggregated_stats_to_cache(self) -> None:
# Update aggregated stats in interceptor state
interceptor_state["aggregated_stats"] = stats_to_cache

# Persist seen cache keys for retry deduplication across allocations
interceptor_state["seen_cache_keys"] = list(self._seen_cache_keys)

# Save updated interceptor state
self._save_interceptor_state(interceptor_state)

Expand Down
Loading