Skip to content
Open
91 changes: 84 additions & 7 deletions cascadeflow/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@

# Phase 2B + v2.5: Telemetry module imports (with CostCalculator)
from .telemetry import CallbackManager, CostCalculator, MetricsCollector
from .integrations.litellm import LITELLM_AVAILABLE, LiteLLMCostProvider

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -1650,6 +1651,9 @@ async def _execute_cascade_with_timing(
timing["quality_verification"] = result.metadata.get("quality_check_ms", 0)
timing["verifier_generation"] = result.metadata.get("verifier_latency_ms", 0)
timing["cascade_overhead"] = result.metadata.get("cascade_overhead_ms", 0)
timing["tool_complexity_analysis_ms"] = result.metadata.get(
"tool_complexity_analysis_ms", 0
)
else:
timing["cascade_total"] = cascade_total

Expand Down Expand Up @@ -1714,8 +1718,33 @@ async def _execute_direct_with_timing(
)
direct_latency = (time.time() - direct_start) * 1000

tokens_used = response.tokens_used if hasattr(response, "tokens_used") else max_tokens
cost = best_model.cost * (tokens_used / 1000)
prompt_tokens = None
completion_tokens = None
total_tokens = None
if hasattr(response, "metadata") and response.metadata:
prompt_tokens = response.metadata.get("prompt_tokens")
completion_tokens = response.metadata.get("completion_tokens")
total_tokens = response.metadata.get("total_tokens")
if total_tokens is None and prompt_tokens is not None and completion_tokens is not None:
total_tokens = prompt_tokens + completion_tokens

cost = None
if LITELLM_AVAILABLE and prompt_tokens is not None and completion_tokens is not None:
try:
provider = LiteLLMCostProvider()
cost = provider.calculate_cost(
model=best_model.name,
input_tokens=prompt_tokens,
output_tokens=completion_tokens,
)
except Exception as exc:
logger.warning(f"LiteLLM direct cost failed for {best_model.name}: {exc}")

if cost is None:
tokens_used = response.tokens_used if hasattr(response, "tokens_used") else None
if not tokens_used:
tokens_used = total_tokens if total_tokens is not None else max_tokens
cost = best_model.cost * (tokens_used / 1000)

result = self._create_direct_result(
response.content,
Expand All @@ -1724,6 +1753,9 @@ async def _execute_direct_with_timing(
direct_latency,
reason,
tool_calls=getattr(response, "tool_calls", None),
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=total_tokens,
)

timing = {
Expand All @@ -1732,6 +1764,7 @@ async def _execute_direct_with_timing(
"quality_verification": 0,
"verifier_generation": 0,
"cascade_overhead": 0,
"tool_complexity_analysis_ms": 0.0,
}

return result, timing
Expand Down Expand Up @@ -2132,7 +2165,18 @@ def _build_cascade_result(
# HELPER METHODS - WITH TOOL SUPPORT
# ========================================================================

def _create_direct_result(self, content, model, cost, latency, reason, tool_calls=None):
def _create_direct_result(
self,
content,
model,
cost,
latency,
reason,
tool_calls=None,
prompt_tokens: Optional[int] = None,
completion_tokens: Optional[int] = None,
total_tokens: Optional[int] = None,
):
"""
Create result object for direct routing with tool support.

Expand All @@ -2142,7 +2186,18 @@ def _create_direct_result(self, content, model, cost, latency, reason, tool_call
class DirectResult:
"""Mimics cascade results for telemetry compatibility."""

def __init__(self, content, model, cost, latency, reason, tool_calls=None):
def __init__(
self,
content,
model,
cost,
latency,
reason,
tool_calls=None,
prompt_tokens: Optional[int] = None,
completion_tokens: Optional[int] = None,
total_tokens: Optional[int] = None,
):
# Core attributes
self.content = content
self.model_used = model
Expand All @@ -2158,6 +2213,16 @@ def __init__(self, content, model, cost, latency, reason, tool_calls=None):
self.verifier_confidence = 0.95
self.speedup = 1.0

token_total = total_tokens
if (
token_total is None
and prompt_tokens is not None
and completion_tokens is not None
):
token_total = prompt_tokens + completion_tokens
if token_total is None:
token_total = int(len(content.split()) * 1.3)

# Complete metadata
self.metadata = {
"reason": reason,
Expand All @@ -2184,10 +2249,12 @@ def __init__(self, content, model, cost, latency, reason, tool_calls=None):
"quality_threshold": None,
"quality_check_passed": None,
"rejection_reason": None,
"prompt_tokens": prompt_tokens or 0,
"completion_tokens": completion_tokens or 0,
"total_tokens": token_total,
"tokens_generated": int(len(content.split()) * 1.3),
"total_tokens": int(len(content.split()) * 1.3),
"draft_tokens": 0,
"verifier_tokens": int(len(content.split()) * 1.3),
"verifier_tokens": token_total,
"draft_model": None,
"verifier_model": model,
"tool_calls": tool_calls,
Expand Down Expand Up @@ -2216,7 +2283,17 @@ def to_dict(self):
"metadata": self.metadata,
}

return DirectResult(content, model, cost, latency, reason, tool_calls)
return DirectResult(
content,
model,
cost,
latency,
reason,
tool_calls,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=total_tokens,
)

def _dict_to_result(self, data):
"""Convert dict to result object."""
Expand Down
2 changes: 2 additions & 0 deletions tests/benchmarks/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ Professional benchmarks to validate CascadeFlow performance across real-world us
2. **Bitext Customer Support** - Customer service Q&A (27,000+ examples)
3. **Banking77** - Banking intent classification (13,000+ examples)
4. **GSM8K** - Grade school math reasoning (8,500+ problems)
5. **ToolCalls Real-World** - Tool routing with multi-turn context

#### Metrics

Expand All @@ -16,6 +17,7 @@ Each benchmark measures:
- **Quality maintenance** (accuracy/pass rate)
- **Latency** improvements
- **Escalation rates** (drafter acceptance %)
- **Direct routing** counts and **cascade overhead** latency

#### Running Benchmarks

Expand Down
11 changes: 9 additions & 2 deletions tests/benchmarks/banking77_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,12 +388,19 @@ async def run_cascade(self, query: str) -> dict[str, Any]:
"model_used": result.model_used,
"accepted": result.draft_accepted,
"quality_score": result.quality_score,
"routing_strategy": result.routing_strategy,
"drafter_cost": result.draft_cost,
"verifier_cost": result.verifier_cost,
"total_cost": result.total_cost,
"latency_ms": latency_ms,
"tokens_input": 0,
"tokens_output": 0,
"cascadeflow_latency_ms": (
(result.complexity_detection_ms or 0)
+ (result.metadata.get("domain_detection_ms", 0) if result.metadata else 0)
+ (result.metadata.get("tool_complexity_analysis_ms", 0) if result.metadata else 0)
+ (result.quality_verification_ms or 0)
),
"tokens_input": result.metadata.get("prompt_tokens", 0),
"tokens_output": result.metadata.get("completion_tokens", 0),
}


Expand Down
Loading