-
Notifications
You must be signed in to change notification settings - Fork 28
fix: eliminate floating-point error in response stats averages #783
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -105,11 +105,16 @@ def __init__(self, params: Params): | |
| self._lock = threading.Lock() | ||
| self._adapter_start_time = time.time() # Record adapter initialization time | ||
| self._stats = { | ||
| # Average statistics | ||
| # Average statistics (computed from sums at save time) | ||
| "avg_prompt_tokens": None, | ||
| "avg_total_tokens": None, | ||
| "avg_completion_tokens": None, | ||
| "avg_latency_ms": None, | ||
| # Exact sum statistics (used to compute precise averages) | ||
| "sum_prompt_tokens": 0, | ||
| "sum_total_tokens": 0, | ||
| "sum_completion_tokens": 0, | ||
| "sum_latency_ms": 0.0, | ||
| # Maximum statistics | ||
| "max_prompt_tokens": None, | ||
| "max_total_tokens": None, | ||
|
|
@@ -123,13 +128,21 @@ def __init__(self, params: Params): | |
| "finish_reason": {}, | ||
| "stop_reason": {}, | ||
| "status_codes": {}, | ||
| # Retry deduplication | ||
| "retry_count": 0, | ||
| # Time tracking | ||
| "inference_time": 0.0, | ||
| "run_id": 0, | ||
| "last_request_time": None, | ||
| "inference_run_times": {}, # {run_id: {"start": time, "end": time, "inference_time": time}} | ||
| } | ||
|
|
||
| # Set of cache_keys (request content hashes) already counted in stats. | ||
| # Used to detect retries across chain-job allocations: when a job | ||
| # requeues, previously in-flight requests are re-sent but should not | ||
| # be double-counted in aggregated statistics. | ||
| self._seen_cache_keys: set[str] = set() | ||
|
|
||
| # Always initialize cache database | ||
| cache_path = Path(self.cache_dir) | ||
| cache_path.mkdir(parents=True, exist_ok=True) | ||
|
|
@@ -209,10 +222,47 @@ def _load_aggregated_cached_stats(self) -> None: | |
| status_codes[key] = value | ||
| aggregated_stats["status_codes"] = status_codes | ||
|
|
||
| # Backward compatibility: if cached stats lack sum_* fields (from | ||
| # older versions that used running averages), back-compute sums | ||
| # from avg * successful_count. | ||
| successful_count = aggregated_stats.get("successful_count", 0) | ||
| for token_type in ["prompt_tokens", "total_tokens", "completion_tokens"]: | ||
| sum_key = f"sum_{token_type}" | ||
| avg_key = f"avg_{token_type}" | ||
| total_key = f"total_{token_type}" | ||
| if sum_key not in aggregated_stats: | ||
| # Try total_* first (from new save format), then back-compute from avg | ||
| if total_key in aggregated_stats: | ||
| aggregated_stats[sum_key] = aggregated_stats.pop(total_key) | ||
| elif ( | ||
| aggregated_stats.get(avg_key) is not None | ||
| and successful_count > 0 | ||
| ): | ||
| aggregated_stats[sum_key] = ( | ||
| aggregated_stats[avg_key] * successful_count | ||
| ) | ||
| else: | ||
| aggregated_stats[sum_key] = 0 | ||
|
|
||
| if "sum_latency_ms" not in aggregated_stats: | ||
| avg_latency = aggregated_stats.get("avg_latency_ms") | ||
| if avg_latency is not None and successful_count > 0: | ||
| aggregated_stats["sum_latency_ms"] = avg_latency * successful_count | ||
| else: | ||
| aggregated_stats["sum_latency_ms"] = 0.0 | ||
|
|
||
| # Backward compatibility: ensure retry_count exists | ||
| if "retry_count" not in aggregated_stats: | ||
| aggregated_stats["retry_count"] = 0 | ||
|
|
||
| # Set current stats to cached data (cached stats already contain accumulated data) | ||
| self._stats = aggregated_stats | ||
| # Note: run_id increment is handled in _save_run_ids_info() | ||
|
|
||
| # Restore seen cache keys for retry deduplication | ||
| if "seen_cache_keys" in interceptor_state: | ||
| self._seen_cache_keys = set(interceptor_state["seen_cache_keys"]) | ||
|
|
||
| self.logger.info( | ||
| f"Loaded interceptor state with run_id {aggregated_stats.get('run_id', 0)}, count={aggregated_stats.get('count', 0)}" | ||
| ) | ||
|
|
@@ -256,26 +306,31 @@ def _update_basic_stats(self, resp: AdapterResponse, current_time: float) -> Non | |
| self._stats["inference_time"] += delta | ||
|
|
||
| def _update_running_stats(self, stat_name: str, value: float) -> None: | ||
| """Update running average and max for a given statistic.""" | ||
| """Update sum, average, and max for a given statistic. | ||
|
|
||
| Tracks exact sums to avoid floating-point error accumulation from | ||
| running averages. The average is recomputed from sum / count at each | ||
| step for live monitoring, and rounded to 2 decimal places only at | ||
| file-save time. | ||
| """ | ||
| # Skip if value is not a valid number | ||
| if not isinstance(value, (int, float)): | ||
| self.logger.warning( | ||
| f"Invalid value for {stat_name}: {value} (expected number)" | ||
| ) | ||
| return | ||
|
|
||
| # Calculate running average using current successful count | ||
| sum_key = f"sum_{stat_name}" | ||
| avg_key = f"avg_{stat_name}" | ||
| if self._stats[avg_key] is None: | ||
| self._stats[avg_key] = value | ||
| else: | ||
| self._stats[avg_key] = round( | ||
| (self._stats[avg_key] * self._stats["successful_count"] + value) | ||
| / (self._stats["successful_count"] + 1), | ||
| 2, | ||
| ) | ||
|
|
||
| # Update max valuename | ||
| # Accumulate exact sum | ||
| self._stats[sum_key] += value | ||
|
|
||
| # Compute average from exact sum (successful_count not yet incremented) | ||
| new_count = self._stats["successful_count"] + 1 | ||
| self._stats[avg_key] = self._stats[sum_key] / new_count | ||
|
|
||
| # Update max value | ||
| max_key = f"max_{stat_name}" | ||
| if self._stats[max_key] is None or value > self._stats[max_key]: | ||
| self._stats[max_key] = value | ||
|
|
@@ -404,6 +459,14 @@ def intercept_response( | |
| return resp | ||
| status_code = resp.r.status_code | ||
|
|
||
| # Detect retries: if we've already counted a successful response for | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.

[Power Review] The architecture document warns: 'Monolithic Functions (>50 lines): AI generates long, flat functions instead of composing smaller units.' The diff adds ~55 lines of new code to `intercept_response`.

💡 Suggestion: Extract the retry detection and retry-path handling into dedicated private methods, e.g.:

```python
is_retry = self._is_retry(resp)
if is_retry:
    self._handle_retry(resp, cache_key)
else:
    self._handle_successful_response(resp, context, status_code)
```
||
| # the same request content (identified by cache_key), this is a retry | ||
| # from a chain-job requeue. We still track it in count/status_codes | ||
| # (total API calls) but skip updating token stats and successful_count | ||
| # to avoid inflating averages. | ||
| cache_key = getattr(resp.rctx, "cache_key", None) | ||
| is_retry = cache_key is not None and cache_key in self._seen_cache_keys | ||
|
|
||
| # Update time tracking with current timestamp | ||
| current_time = time.time() | ||
| self._update_time_tracking(current_time) | ||
|
|
@@ -414,29 +477,46 @@ def intercept_response( | |
| # Always add basic response stats (count, status_code) | ||
| self._add_basic_response_stats(resp, context) | ||
|
|
||
| if is_retry: | ||
| with self._lock: | ||
| self._stats["retry_count"] += 1 | ||
| self.logger.debug( | ||
| "Detected retry request, skipping token stats update", | ||
| request_id=resp.rctx.request_id, | ||
| cache_key=cache_key[:8] + "..." if cache_key else None, | ||
| ) | ||
| # Extract detailed stats once and reuse them | ||
| detailed_stats = None | ||
| try: | ||
| # Try to parse response as JSON | ||
| response_data = resp.r.json() | ||
| if not is_retry: | ||
| try: | ||
| # Try to parse response as JSON | ||
| response_data = resp.r.json() | ||
|
|
||
| if status_code == 200: | ||
| detailed_stats = self._extract_detailed_response_stats( | ||
| response_data | ||
| ) | ||
|
|
||
| if status_code == 200: | ||
| detailed_stats = self._extract_detailed_response_stats(response_data) | ||
| # Add detailed stats for aggregation | ||
| self._update_response_stats(detailed_stats) | ||
|
|
||
| # Add detailed stats for aggregation | ||
| self._update_response_stats(detailed_stats) | ||
| # Mark this cache_key as seen | ||
| if cache_key is not None: | ||
| self._seen_cache_keys.add(cache_key) | ||
|
|
||
| self.logger.debug( | ||
| "Collected detailed response stats", | ||
| request_id=resp.rctx.request_id, | ||
| response_count=self._stats["count"], | ||
| status_code=status_code, | ||
| ) | ||
| self.logger.debug( | ||
| "Collected detailed response stats", | ||
| request_id=resp.rctx.request_id, | ||
| response_count=self._stats["count"], | ||
| status_code=status_code, | ||
| ) | ||
|
|
||
| except (json.JSONDecodeError, Exception) as e: | ||
| # Handle both JSON parsing errors and other exceptions | ||
| # In case of any error, only basic stats are collected | ||
| self.logger.warning(f"Error parsing response body for token counting: {e}") | ||
| except (json.JSONDecodeError, Exception) as e: | ||
| # Handle both JSON parsing errors and other exceptions | ||
| # In case of any error, only basic stats are collected | ||
| self.logger.warning( | ||
| f"Error parsing response body for token counting: {e}" | ||
| ) | ||
|
|
||
| # Save stats to file if interval reached | ||
| if ( | ||
|
|
@@ -478,6 +558,22 @@ def _save_stats_to_file(self, context: AdapterGlobalContext) -> None: | |
| self.logger.debug("No response statistics collected, skipping file write") | ||
| return | ||
|
|
||
| # Round averages to 2 decimal places for display | ||
| for avg_key in [ | ||
| "avg_prompt_tokens", | ||
| "avg_total_tokens", | ||
| "avg_completion_tokens", | ||
| "avg_latency_ms", | ||
| ]: | ||
| if stats[avg_key] is not None: | ||
| stats[avg_key] = round(stats[avg_key], 2) | ||
|
|
||
| # Expose exact totals as total_* fields and remove internal sum_* fields | ||
| for token_type in ["prompt_tokens", "total_tokens", "completion_tokens"]: | ||
| stats[f"total_{token_type}"] = stats.pop(f"sum_{token_type}", 0) | ||
| # Remove internal sum_latency_ms (not useful as an output field) | ||
| stats.pop("sum_latency_ms", None) | ||
|
|
||
| # Convert timestamps to readable dates in inference_run_times and add time_to_first_request | ||
| if "inference_run_times" in stats: | ||
| for run_id, run_data in stats["inference_run_times"].items(): | ||
|
|
@@ -591,6 +687,9 @@ def _save_aggregated_stats_to_cache(self) -> None: | |
| # Update aggregated stats in interceptor state | ||
| interceptor_state["aggregated_stats"] = stats_to_cache | ||
|
|
||
| # Persist seen cache keys for retry deduplication across allocations | ||
| interceptor_state["seen_cache_keys"] = list(self._seen_cache_keys) | ||
|
|
||
| # Save updated interceptor state | ||
| self._save_interceptor_state(interceptor_state) | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[Power Review] ⚠️ Backward-compatibility migration bolted inline into `_load_aggregated_cached_stats` (Geological Layers) · 78% confidence

The architecture document warns against 'Geological Layers: new code bolted on top of old code without integrating — duplicate data paths, parallel config mechanisms, wrapper-on-wrapper stacking.' Lines 225–257 add ~35 lines of multi-branch migration logic (back-computing `sum_*` from `avg_*`, handling `total_*`, ensuring `retry_count`) inline at the top of the existing load method rather than extracting it into a focused helper. This creates a layered accumulation of unrelated responsibilities in one method.

💡 Suggestion: Extract the migration block into a private method such as `_migrate_legacy_cached_stats(aggregated_stats: dict) -> dict`. `_load_aggregated_cached_stats` then calls it as a single delegation step, keeping the load method's intent readable.