27 changes: 26 additions & 1 deletion litellm/proxy/anthropic_endpoints/endpoints.py
@@ -154,8 +154,13 @@ async def anthropic_response( # noqa: PLR0915

response = responses[1]

# Extract model_id from request metadata (set by router during routing)
litellm_metadata = data.get("litellm_metadata", {}) or {}
model_info = litellm_metadata.get("model_info", {}) or {}
model_id = model_info.get("id", "") or ""

# Get other metadata from hidden_params
hidden_params = getattr(response, "_hidden_params", {}) or {}
model_id = hidden_params.get("model_id", None) or ""
cache_key = hidden_params.get("cache_key", None) or ""
api_base = hidden_params.get("api_base", None) or ""
response_cost = hidden_params.get("response_cost", None) or ""
@@ -216,12 +221,32 @@ async def anthropic_response( # noqa: PLR0915
str(e)
)
)

# Extract model_id from request metadata (same as success path)
litellm_metadata = data.get("litellm_metadata", {}) or {}
model_info = litellm_metadata.get("model_info", {}) or {}
model_id = model_info.get("id", "") or ""

# Get headers
headers = ProxyBaseLLMRequestProcessing.get_custom_headers(
user_api_key_dict=user_api_key_dict,
call_id=data.get("litellm_call_id", ""),
model_id=model_id,
version=version,
response_cost=0,
model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
request_data=data,
timeout=getattr(e, "timeout", None),
litellm_logging_obj=None,
)

error_msg = f"{str(e)}"
raise ProxyException(
message=getattr(e, "message", error_msg),
type=getattr(e, "type", "None"),
param=getattr(e, "param", "None"),
code=getattr(e, "status_code", 500),
headers=headers,
)


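The new extraction in both hunks above reads the router-selected deployment id from the request's litellm_metadata rather than relying only on the response's _hidden_params. A minimal, standalone sketch of that lookup (the helper name and example values are illustrative, not part of the PR):

from typing import Any, Dict


def model_id_from_request(data: Dict[str, Any]) -> str:
    """Read the router-selected deployment id out of the request metadata."""
    litellm_metadata = data.get("litellm_metadata", {}) or {}
    model_info = litellm_metadata.get("model_info", {}) or {}
    return model_info.get("id", "") or ""


# The router attaches the chosen deployment under litellm_metadata.model_info:
request_data = {"litellm_metadata": {"model_info": {"id": "deployment-abc"}}}
assert model_id_from_request(request_data) == "deployment-abc"

# Missing metadata degrades to an empty string, matching the endpoint code above.
assert model_id_from_request({}) == ""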
64 changes: 64 additions & 0 deletions litellm/proxy/common_request_processing.py
@@ -340,6 +340,7 @@ async def common_processing_pre_call_logic(
user_max_tokens: Optional[int] = None,
user_api_base: Optional[str] = None,
model: Optional[str] = None,
llm_router: Optional[Router] = None,
) -> Tuple[dict, LiteLLMLoggingObj]:
start_time = datetime.now() # start before calling guardrail hooks

@@ -490,6 +491,7 @@ async def base_process_llm_request(
user_api_base=user_api_base,
model=model,
route_type=route_type,
llm_router=llm_router,
)

tasks = []
@@ -528,6 +530,13 @@

hidden_params = getattr(response, "_hidden_params", {}) or {}
model_id = hidden_params.get("model_id", None) or ""

# Fallback: extract model_id from litellm_metadata if not in hidden_params
if not model_id:
litellm_metadata = self.data.get("litellm_metadata", {}) or {}
model_info = litellm_metadata.get("model_info", {}) or {}
model_id = model_info.get("id", "") or ""

cache_key = hidden_params.get("cache_key", None) or ""
api_base = hidden_params.get("api_base", None) or ""
response_cost = hidden_params.get("response_cost", None) or ""
@@ -748,11 +757,19 @@ async def _handle_llm_api_exception(
_litellm_logging_obj: Optional[LiteLLMLoggingObj] = self.data.get(
"litellm_logging_obj", None
)

# Attempt to get model_id from logging object
#
# Note: We check the direct model_info path first (not nested in metadata) because that's where the router sets it.
# The nested metadata path is only a fallback for cases where model_info wasn't set at the top level.
model_id = self.maybe_get_model_id(_litellm_logging_obj)

custom_headers = ProxyBaseLLMRequestProcessing.get_custom_headers(
user_api_key_dict=user_api_key_dict,
call_id=(
_litellm_logging_obj.litellm_call_id if _litellm_logging_obj else None
),
model_id=model_id,
version=version,
response_cost=0,
model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
@@ -1065,3 +1082,50 @@ def _inject_cost_into_usage_dict(obj: dict, model_name: str) -> Optional[dict]:
obj.setdefault("usage", {})["cost"] = cost_val
return obj
return None

def maybe_get_model_id(self, _logging_obj: Optional[LiteLLMLoggingObj]) -> Optional[str]:
"""
Get model_id from logging object or request metadata.

The router sets model_info.id when selecting a deployment. This tries multiple locations
where the ID might be stored depending on the request lifecycle stage.
"""
model_id = None
if _logging_obj:
# 1. Try getting from litellm_params (updated during call)
if (
hasattr(_logging_obj, "litellm_params")
and _logging_obj.litellm_params
):
# First check direct model_info path (set by router.py with selected deployment)
model_info = _logging_obj.litellm_params.get("model_info") or {}
model_id = model_info.get("id", None)

# Fallback to nested metadata path
if not model_id:
metadata = _logging_obj.litellm_params.get("metadata") or {}
model_info = metadata.get("model_info") or {}
model_id = model_info.get("id", None)

# 2. Fallback to kwargs (initial)
if not model_id:
_kwargs = getattr(_logging_obj, "kwargs", None)
if _kwargs:
litellm_params = _kwargs.get("litellm_params", {})
# First check direct model_info path
model_info = litellm_params.get("model_info") or {}
model_id = model_info.get("id", None)

# Fallback to nested metadata path
if not model_id:
metadata = litellm_params.get("metadata") or {}
model_info = metadata.get("model_info") or {}
model_id = model_info.get("id", None)

# 3. Final fallback to self.data["litellm_metadata"] (for routes like /v1/responses that populate data before error)
if not model_id:
litellm_metadata = self.data.get("litellm_metadata", {}) or {}
model_info = litellm_metadata.get("model_info", {}) or {}
model_id = model_info.get("id", None)

return model_id
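
For readability, here is a standalone restatement of the precedence maybe_get_model_id implements; it does not import litellm, and the function names are illustrative only:

from typing import Any, Dict, Optional


def _id_from(litellm_params: Dict[str, Any]) -> Optional[str]:
    """Direct model_info path first (set by the router), nested metadata path as fallback."""
    model_id = (litellm_params.get("model_info") or {}).get("id")
    if not model_id:
        metadata = litellm_params.get("metadata") or {}
        model_id = (metadata.get("model_info") or {}).get("id")
    return model_id


def resolve_model_id(logging_obj: Any, request_data: Dict[str, Any]) -> Optional[str]:
    model_id = None
    if logging_obj is not None:
        # 1. litellm_params on the logging object (updated during the call)
        model_id = _id_from(getattr(logging_obj, "litellm_params", None) or {})
        # 2. The initial kwargs passed to the logging object
        if not model_id:
            kwargs = getattr(logging_obj, "kwargs", None) or {}
            model_id = _id_from(kwargs.get("litellm_params", {}) or {})
    # 3. The request body's litellm_metadata (e.g. /v1/responses routes that populate data before an error)
    if not model_id:
        litellm_metadata = request_data.get("litellm_metadata", {}) or {}
        model_id = (litellm_metadata.get("model_info", {}) or {}).get("id")
    return model_id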
19 changes: 19 additions & 0 deletions litellm/responses/streaming_iterator.py
@@ -8,7 +8,9 @@
import litellm
from litellm.constants import STREAM_SSE_DONE_STRING
from litellm.litellm_core_utils.asyncify import run_async_function
from litellm.litellm_core_utils.core_helpers import process_response_headers
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.litellm_core_utils.llm_response_utils.get_api_base import get_api_base
from litellm.litellm_core_utils.thread_pool_executor import executor
from litellm.llms.base_llm.responses.transformation import BaseResponsesAPIConfig
from litellm.responses.utils import ResponsesAPIRequestUtils
@@ -51,6 +53,23 @@ def __init__(
self.litellm_metadata = litellm_metadata
self.custom_llm_provider = custom_llm_provider

# set hidden params for response headers (e.g., x-litellm-model-id)
# This matches the stream wrapper in litellm/litellm_core_utils/streaming_handler.py
_api_base = get_api_base(
model=model or "",
optional_params=self.logging_obj.model_call_details.get(
"litellm_params", {}
),
)
_model_info: Dict = litellm_metadata.get("model_info", {}) if litellm_metadata else {}
self._hidden_params = {
"model_id": _model_info.get("id", None),
"api_base": _api_base,
}
self._hidden_params["additional_headers"] = process_response_headers(
self.response.headers or {}
) # GUARANTEE OPENAI HEADERS IN RESPONSE

def _process_chunk(self, chunk) -> Optional[ResponsesAPIStreamingResponse]:
"""Process a single chunk of data from the stream"""
if not chunk:
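The _hidden_params dict set in __init__ above is what lets the proxy attach deployment details to streaming responses. A hedged sketch of how a consumer might turn those hidden params into response headers; the header names follow the x-litellm-model-id convention mentioned in the comment above, but this consuming code is an assumption, not part of the PR:

from typing import Any, Dict


def headers_from_stream(stream_iterator: Any) -> Dict[str, str]:
    """Build response headers from an iterator's _hidden_params (illustrative only)."""
    hidden = getattr(stream_iterator, "_hidden_params", {}) or {}
    headers: Dict[str, str] = {}
    if hidden.get("model_id"):
        headers["x-litellm-model-id"] = str(hidden["model_id"])
    if hidden.get("api_base"):
        headers["x-litellm-api-base"] = str(hidden["api_base"])
    # Pass through any upstream provider headers captured from the raw response.
    for key, value in (hidden.get("additional_headers") or {}).items():
        headers[str(key)] = str(value)
    return headers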