32 changes: 31 additions & 1 deletion litellm/proxy/common_utils/openai_endpoint_utils.py
@@ -2,7 +2,7 @@
Contains utils used by OpenAI compatible endpoints
"""

from typing import Optional
from typing import List, Optional

from fastapi import Request

@@ -33,6 +33,36 @@ def remove_sensitive_info_from_deployment(deployment_dict: dict) -> dict:
return deployment_dict


def process_model_info_fields_from_deployment(
deployment_dict: dict, model_info_fields: Optional[List[str]]
) -> dict:
"""
Keeps only the specified fields in a deployment dictionary (whitelist approach).

This function filters the deployment dictionary to include only the fields
specified in model_info_fields. All other fields are removed.

Args:
deployment_dict (dict): The deployment dictionary to filter fields from.
model_info_fields (Optional[List[str]]): List of field names to keep in
the deployment dictionary. If None, all fields are kept.

Returns:
dict: The modified deployment dictionary with only specified fields kept.
"""
if model_info_fields is None:
return deployment_dict

# Keep only fields that are in the model_info_fields list
filtered_dict = {
key: value
for key, value in deployment_dict.items()
if key in model_info_fields
}

return filtered_dict


async def get_custom_llm_provider_from_request_body(request: Request) -> Optional[str]:
"""
Get the `custom_llm_provider` from the request body
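For reference, a quick usage sketch of the new whitelist helper (the import path comes from this diff; the deployment dict is illustrative only):

from litellm.proxy.common_utils.openai_endpoint_utils import (
    process_model_info_fields_from_deployment,
)

deployment = {
    "model_name": "gpt-4",
    "litellm_params": {"model": "gpt-4"},
    "model_info": {"id": "test-model-id", "db_model": False},
}

# Keep only the whitelisted top-level keys.
filtered = process_model_info_fields_from_deployment(
    deployment_dict=deployment, model_info_fields=["model_name", "model_info"]
)
assert set(filtered) == {"model_name", "model_info"}

# None means "keep everything": the original dict is returned unchanged.
assert (
    process_model_info_fields_from_deployment(
        deployment_dict=deployment, model_info_fields=None
    )
    is deployment
)
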
35 changes: 31 additions & 4 deletions litellm/proxy/proxy_server.py
@@ -225,7 +225,7 @@ def generate_feedback_box():
get_file_contents_from_s3,
)
from litellm.proxy.common_utils.openai_endpoint_utils import (
remove_sensitive_info_from_deployment,
    process_model_info_fields_from_deployment,
    remove_sensitive_info_from_deployment,
)
from litellm.proxy.common_utils.proxy_state import ProxyState
from litellm.proxy.common_utils.reset_budget_job import ResetBudgetJob
@@ -7497,7 +7497,7 @@ async def model_metrics_exceptions(
return {"data": response, "exception_types": list(exception_types)}


def _get_proxy_model_info(model: dict) -> dict:
def _get_proxy_model_info(
    model: dict,
    model_info_fields: Optional[List[str]],
    return_litellm_params: bool,
) -> dict:
# provided model_info in config.yaml
model_info = model.get("model_info", {})

@@ -7534,6 +7534,9 @@ def _get_proxy_model_info(model: dict) -> dict:
model["model_info"] = model_info
# don't return the llm credentials
model = remove_sensitive_info_from_deployment(deployment_dict=model)
    model = process_model_info_fields_from_deployment(
        deployment_dict=model, model_info_fields=model_info_fields
    )
if not return_litellm_params and "litellm_params" in model:
del model["litellm_params"]

return model

@@ -7551,6 +7554,8 @@ def _get_proxy_model_info(model: dict) -> dict:
async def model_info_v1( # noqa: PLR0915
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
litellm_model_id: Optional[str] = None,
model_info_fields: Optional[List[str]] = None,
    return_litellm_params: bool = True,
):
"""
Provides more info about each model in /models, including config.yaml descriptions (except api key and api base)
@@ -7561,6 +7566,16 @@ async def model_info_v1( # noqa: PLR0915
- When litellm_model_id is passed, it will return the info for that specific model
- When litellm_model_id is not passed, it will return the info for all models

model_info_fields: Optional[List[str]] = None (list of field names to include in the response)

- When model_info_fields is passed, only the specified fields will be included in each model's data
- When model_info_fields is not passed, all fields are included in the response

return_litellm_params: bool = True (controls whether to include litellm_params in the response)

- When return_litellm_params is True (default), litellm_params are included with sensitive data masked
- When return_litellm_params is False, litellm_params are completely removed from the response

Returns:
Returns a dictionary containing information about each model.

@@ -7603,6 +7618,14 @@ async def model_info_v1( # noqa: PLR0915
_deployment_info_dict = remove_sensitive_info_from_deployment(
deployment_dict=_deployment_info_dict
)
_deployment_info_dict = process_model_info_fields_from_deployment(
deployment_dict=_deployment_info_dict,
model_info_fields=model_info_fields,
)

if not return_litellm_params and "litellm_params" in _deployment_info_dict:
del _deployment_info_dict["litellm_params"]

return {"data": _deployment_info_dict}

if llm_model_list is None:
@@ -7632,7 +7655,9 @@ async def model_info_v1( # noqa: PLR0915
},
)
_deployment_info_dict = _get_proxy_model_info(
model=deployment_info.model_dump(exclude_none=True)
model=deployment_info.model_dump(exclude_none=True),
model_info_fields=model_info_fields,
return_litellm_params=return_litellm_params,
)
return {"data": [_deployment_info_dict]}

@@ -7675,7 +7700,9 @@ async def model_info_v1( # noqa: PLR0915
all_models = []

    for index, in_place_model in enumerate(all_models):
        in_place_model = _get_proxy_model_info(model=in_place_model)
        # Write the processed entry back: the field whitelist returns a new dict,
        # so reassigning only the loop variable would leave all_models unchanged.
        all_models[index] = _get_proxy_model_info(
            model=in_place_model,
            model_info_fields=model_info_fields,
            return_litellm_params=return_litellm_params,
        )

verbose_proxy_logger.debug("all_models: %s", all_models)
return {"data": all_models}
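Taken together, the per-entry post-processing order in _get_proxy_model_info is: mask credentials via remove_sensitive_info_from_deployment, apply the model_info_fields whitelist, then optionally drop litellm_params. A minimal sketch of how the two new model_info_v1 arguments combine on a single, already-masked entry (the shape_entry helper and the data below are illustrative, not part of this PR):

from typing import List, Optional

from litellm.proxy.common_utils.openai_endpoint_utils import (
    process_model_info_fields_from_deployment,
)


def shape_entry(
    entry: dict,
    model_info_fields: Optional[List[str]] = None,
    return_litellm_params: bool = True,
) -> dict:
    # Mirrors the steps added to _get_proxy_model_info after credential masking.
    entry = process_model_info_fields_from_deployment(
        deployment_dict=entry, model_info_fields=model_info_fields
    )
    if not return_litellm_params and "litellm_params" in entry:
        del entry["litellm_params"]
    return entry


entry = {
    "model_name": "gpt-4",
    "litellm_params": {"model": "gpt-4"},  # credentials assumed already masked upstream
    "model_info": {"id": "test-model-id", "db_model": False},
}

# Whitelist only: litellm_params is dropped because it is not in the whitelist.
print(shape_entry(dict(entry), model_info_fields=["model_name", "model_info"]))

# No whitelist, but litellm_params explicitly excluded via return_litellm_params=False.
print(shape_entry(dict(entry), return_litellm_params=False))
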
140 changes: 140 additions & 0 deletions tests/test_litellm/proxy/test_proxy_server.py
@@ -2048,6 +2048,146 @@ async def test_model_info_v1_oci_secrets_not_leaked():
assert "/path/to/oci_api_key.pem" not in result_str


@pytest.mark.asyncio
async def test_model_info_v1_with_model_info_fields_filter():
"""Test that model_info_fields parameter keeps only specified fields (whitelist approach)"""
from unittest.mock import MagicMock, patch

from litellm.proxy._types import LitellmUserRoles, UserAPIKeyAuth
from litellm.proxy.proxy_server import model_info_v1
from litellm.types.router import Deployment as RouterDeployment
from litellm.types.router import LiteLLM_Params
from litellm.types.router import ModelInfo

# Mock user
user_api_key_dict = UserAPIKeyAuth(
user_id="test_user",
api_key="test_key",
user_role=LitellmUserRoles.PROXY_ADMIN,
)

# Create mock model data with multiple fields
mock_model_data = {
"model_name": "gpt-4",
"litellm_params": {
"model": "gpt-4",
"api_key": "sk-***",
"api_base": "https://api.openai.com/v1",
},
"model_info": {
"id": "test-model-id",
"db_model": False,
"description": "Test model description",
"tags": ["production", "test"],
},
}

# Mock router with deployment
mock_router = MagicMock()
mock_deployment = RouterDeployment(
model_name="gpt-4",
litellm_params=LiteLLM_Params(model="gpt-4"),
model_info=ModelInfo(id="test-model-id", db_model=False),
)
mock_router.get_deployment.return_value = mock_deployment

with (
patch("litellm.proxy.proxy_server.llm_router", mock_router),
patch("litellm.proxy.proxy_server.llm_model_list", [mock_model_data]),
patch(
"litellm.proxy.proxy_server.general_settings",
{"infer_model_from_keys": False},
),
patch("litellm.proxy.proxy_server.user_model", None),
):

# Call model_info_v1 with model_info_fields to keep only "model_name" and "model_info"
result = await model_info_v1(
user_api_key_dict=user_api_key_dict,
litellm_model_id="test-model-id",
model_info_fields=["model_name", "model_info"],
)

# Verify the result structure
assert "data" in result
assert len(result["data"]) == 1

model_data = result["data"][0]

assert set(model_data.keys()) == {
"model_name",
"model_info",
}, f"Expected only ['model_name', 'model_info'], but got {list(model_data.keys())}"
assert "model_name" in model_data, "model_name should be present"
assert "model_info" in model_data, "model_info should be present"


@pytest.mark.asyncio
async def test_model_info_v1_with_return_litellm_params_false():
"""Test that return_litellm_params=False removes litellm_params from /v1/model/info response"""
from unittest.mock import MagicMock, patch

from litellm.proxy._types import LitellmUserRoles, UserAPIKeyAuth
from litellm.proxy.proxy_server import model_info_v1
from litellm.types.router import Deployment as RouterDeployment
from litellm.types.router import LiteLLM_Params
from litellm.types.router import ModelInfo

# Mock user
user_api_key_dict = UserAPIKeyAuth(
user_id="test_user",
api_key="test_key",
user_role=LitellmUserRoles.PROXY_ADMIN,
)

# Create mock model data
mock_model_data = {
"model_name": "gpt-4",
"litellm_params": {
"model": "gpt-4",
"api_key": "sk-***",
"api_base": "https://api.openai.com/v1",
},
"model_info": {"id": "test-model-id", "db_model": False},
}

# Mock router with deployment
mock_router = MagicMock()
mock_deployment = RouterDeployment(
model_name="gpt-4",
litellm_params=LiteLLM_Params(model="gpt-4"),
model_info=ModelInfo(id="test-model-id", db_model=False),
)
mock_router.get_deployment.return_value = mock_deployment

with (
patch("litellm.proxy.proxy_server.llm_router", mock_router),
patch("litellm.proxy.proxy_server.llm_model_list", [mock_model_data]),
patch(
"litellm.proxy.proxy_server.general_settings",
{"infer_model_from_keys": False},
),
patch("litellm.proxy.proxy_server.user_model", None),
):

# Call model_info_v1 with return_litellm_params=False
result = await model_info_v1(
user_api_key_dict=user_api_key_dict,
litellm_model_id="test-model-id",
return_litellm_params=False,
)

# Verify the result structure
assert "data" in result
assert len(result["data"]) == 1

model_data = result["data"][0]

# Verify litellm_params is NOT present
assert (
"litellm_params" not in model_data
), "litellm_params should be removed when return_litellm_params=False"

def test_add_callback_from_db_to_in_memory_litellm_callbacks():
"""
Test that _add_callback_from_db_to_in_memory_litellm_callbacks correctly adds callbacks