diff --git a/litellm/proxy/common_utils/openai_endpoint_utils.py b/litellm/proxy/common_utils/openai_endpoint_utils.py
index 7b1a2945ba6b..4fdd51b51eaf 100644
--- a/litellm/proxy/common_utils/openai_endpoint_utils.py
+++ b/litellm/proxy/common_utils/openai_endpoint_utils.py
@@ -2,7 +2,7 @@
 Contains utils used by OpenAI compatible endpoints
 """
 
-from typing import Optional
+from typing import Optional, List
 
 from fastapi import Request
 
@@ -33,6 +33,36 @@ def remove_sensitive_info_from_deployment(deployment_dict: dict) -> dict:
     return deployment_dict
 
 
+def process_model_info_fields_from_deployment(
+    deployment_dict: dict, model_info_fields: Optional[List[str]]
+) -> dict:
+    """
+    Keeps only the specified fields in a deployment dictionary (whitelist approach).
+
+    This function filters the deployment dictionary to include only the fields
+    specified in model_info_fields. All other fields are removed.
+
+    Args:
+        deployment_dict (dict): The deployment dictionary to filter fields from.
+        model_info_fields (Optional[List[str]]): List of field names to keep in
+            the deployment dictionary. If None, all fields are kept.
+
+    Returns:
+        dict: The modified deployment dictionary with only specified fields kept.
+    """
+    if model_info_fields is None:
+        return deployment_dict
+
+    # Keep only fields that are in the model_info_fields list
+    filtered_dict = {
+        key: value
+        for key, value in deployment_dict.items()
+        if key in model_info_fields
+    }
+
+    return filtered_dict
+
+
 async def get_custom_llm_provider_from_request_body(request: Request) -> Optional[str]:
     """
     Get the `custom_llm_provider` from the request body
diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index 4040b0aa707a..a3daa81b928e 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -225,7 +225,7 @@ def generate_feedback_box():
     get_file_contents_from_s3,
 )
 from litellm.proxy.common_utils.openai_endpoint_utils import (
-    remove_sensitive_info_from_deployment,
+    remove_sensitive_info_from_deployment, process_model_info_fields_from_deployment,
 )
 from litellm.proxy.common_utils.proxy_state import ProxyState
 from litellm.proxy.common_utils.reset_budget_job import ResetBudgetJob
@@ -7497,7 +7497,7 @@ async def model_metrics_exceptions(
     return {"data": response, "exception_types": list(exception_types)}
 
 
-def _get_proxy_model_info(model: dict) -> dict:
+def _get_proxy_model_info(model: dict, model_info_fields: Optional[List[str]], return_litellm_params: bool) -> dict:
     # provided model_info in config.yaml
     model_info = model.get("model_info", {})
 
@@ -7534,6 +7534,9 @@ def _get_proxy_model_info(model: dict) -> dict:
     model["model_info"] = model_info
     # don't return the llm credentials
     model = remove_sensitive_info_from_deployment(deployment_dict=model)
+    model = process_model_info_fields_from_deployment(deployment_dict=model, model_info_fields=model_info_fields)
+    if not return_litellm_params and "litellm_params" in model:
+        del model["litellm_params"]
     return model
 
 
@@ -7551,6 +7554,8 @@ def _get_proxy_model_info(model: dict) -> dict:
 async def model_info_v1(  # noqa: PLR0915
     user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
     litellm_model_id: Optional[str] = None,
+    model_info_fields: Optional[List[str]] = None,
+    return_litellm_params: bool = True
 ):
     """
     Provides more info about each model in /models, including config.yaml descriptions (except api key and api base)
@@ -7561,6 +7566,16 @@ async def model_info_v1(  # noqa: PLR0915
         - When litellm_model_id is passed, it will return the info for that specific model
         - When litellm_model_id is not passed, it will return the info for all models
 
+        model_info_fields: Optional[List[str]] = None (list of field names to include in the response)
+
+        - When model_info_fields is passed, only the specified fields will be included in each model's data
+        - When model_info_fields is not passed, all fields are included in the response
+
+        return_litellm_params: bool = True (controls whether to include litellm_params in the response)
+
+        - When return_litellm_params is True (default), litellm_params are included with sensitive data masked
+        - When return_litellm_params is False, litellm_params are completely removed from the response
+
     Returns:
         Returns a dictionary containing information about each model.
 
@@ -7603,6 +7618,14 @@ async def model_info_v1(  # noqa: PLR0915
             _deployment_info_dict = remove_sensitive_info_from_deployment(
                 deployment_dict=_deployment_info_dict
             )
+            _deployment_info_dict = process_model_info_fields_from_deployment(
+                deployment_dict=_deployment_info_dict,
+                model_info_fields=model_info_fields,
+            )
+
+            if not return_litellm_params and "litellm_params" in _deployment_info_dict:
+                del _deployment_info_dict["litellm_params"]
+
             return {"data": _deployment_info_dict}
 
     if llm_model_list is None:
@@ -7632,7 +7655,9 @@ async def model_info_v1(  # noqa: PLR0915
             },
         )
         _deployment_info_dict = _get_proxy_model_info(
-            model=deployment_info.model_dump(exclude_none=True)
+            model=deployment_info.model_dump(exclude_none=True),
+            model_info_fields=model_info_fields,
+            return_litellm_params=return_litellm_params,
         )
         return {"data": [_deployment_info_dict]}
 
@@ -7675,7 +7700,9 @@ async def model_info_v1(  # noqa: PLR0915
         all_models = []
 
     for in_place_model in all_models:
-        in_place_model = _get_proxy_model_info(model=in_place_model)
+        in_place_model = _get_proxy_model_info(
+            model=in_place_model, model_info_fields=model_info_fields, return_litellm_params=return_litellm_params
+        )
 
     verbose_proxy_logger.debug("all_models: %s", all_models)
     return {"data": all_models}
diff --git a/tests/test_litellm/proxy/test_proxy_server.py b/tests/test_litellm/proxy/test_proxy_server.py
index 8e87b67933fc..12e69523f198 100644
--- a/tests/test_litellm/proxy/test_proxy_server.py
+++ b/tests/test_litellm/proxy/test_proxy_server.py
@@ -2048,6 +2048,146 @@ async def test_model_info_v1_oci_secrets_not_leaked():
     assert "/path/to/oci_api_key.pem" not in result_str
 
 
+@pytest.mark.asyncio
+async def test_model_info_v1_with_model_info_fields_filter():
+    """Test that model_info_fields parameter keeps only specified fields (whitelist approach)"""
+    from unittest.mock import MagicMock, patch
+
+    from litellm.proxy._types import LitellmUserRoles, UserAPIKeyAuth
+    from litellm.proxy.proxy_server import model_info_v1
+    from litellm.types.router import Deployment as RouterDeployment
+    from litellm.types.router import LiteLLM_Params
+    from litellm.types.router import ModelInfo
+
+    # Mock user
+    user_api_key_dict = UserAPIKeyAuth(
+        user_id="test_user",
+        api_key="test_key",
+        user_role=LitellmUserRoles.PROXY_ADMIN,
+    )
+
+    # Create mock model data with multiple fields
+    mock_model_data = {
+        "model_name": "gpt-4",
+        "litellm_params": {
+            "model": "gpt-4",
+            "api_key": "sk-***",
+            "api_base": "https://api.openai.com/v1",
+        },
+        "model_info": {
+            "id": "test-model-id",
+            "db_model": False,
+            "description": "Test model description",
+            "tags": ["production", "test"],
+        },
+    }
+
+    # Mock router with deployment
+    mock_router = MagicMock()
+    mock_deployment = RouterDeployment(
model_name="gpt-4", + litellm_params=LiteLLM_Params(model="gpt-4"), + model_info=ModelInfo(id="test-model-id", db_model=False), + ) + mock_router.get_deployment.return_value = mock_deployment + + with ( + patch("litellm.proxy.proxy_server.llm_router", mock_router), + patch("litellm.proxy.proxy_server.llm_model_list", [mock_model_data]), + patch( + "litellm.proxy.proxy_server.general_settings", + {"infer_model_from_keys": False}, + ), + patch("litellm.proxy.proxy_server.user_model", None), + ): + + # Call model_info_v1 with model_info_fields to keep only "model_name" and "model_info" + result = await model_info_v1( + user_api_key_dict=user_api_key_dict, + litellm_model_id="test-model-id", + model_info_fields=["model_name", "model_info"], + ) + + # Verify the result structure + assert "data" in result + assert len(result["data"]) == 1 + + model_data = result["data"][0] + + assert set(model_data.keys()) == { + "model_name", + "model_info", + }, f"Expected only ['model_name', 'model_info'], but got {list(model_data.keys())}" + assert "model_name" in model_data, "model_name should be present" + assert "model_info" in model_data, "model_info should be present" + + +@pytest.mark.asyncio +async def test_model_info_v1_with_return_litellm_params_false(): + """Test that return_litellm_params=False removes litellm_params from /v1/model/info response""" + from unittest.mock import MagicMock, patch + + from litellm.proxy._types import LitellmUserRoles, UserAPIKeyAuth + from litellm.proxy.proxy_server import model_info_v1 + from litellm.types.router import Deployment as RouterDeployment + from litellm.types.router import LiteLLM_Params + from litellm.types.router import ModelInfo + + # Mock user + user_api_key_dict = UserAPIKeyAuth( + user_id="test_user", + api_key="test_key", + user_role=LitellmUserRoles.PROXY_ADMIN, + ) + + # Create mock model data + mock_model_data = { + "model_name": "gpt-4", + "litellm_params": { + "model": "gpt-4", + "api_key": "sk-***", + "api_base": "https://api.openai.com/v1", + }, + "model_info": {"id": "test-model-id", "db_model": False}, + } + + # Mock router with deployment + mock_router = MagicMock() + mock_deployment = RouterDeployment( + model_name="gpt-4", + litellm_params=LiteLLM_Params(model="gpt-4"), + model_info=ModelInfo(id="test-model-id", db_model=False), + ) + mock_router.get_deployment.return_value = mock_deployment + + with ( + patch("litellm.proxy.proxy_server.llm_router", mock_router), + patch("litellm.proxy.proxy_server.llm_model_list", [mock_model_data]), + patch( + "litellm.proxy.proxy_server.general_settings", + {"infer_model_from_keys": False}, + ), + patch("litellm.proxy.proxy_server.user_model", None), + ): + + # Call model_info_v1 with return_litellm_params=False + result = await model_info_v1( + user_api_key_dict=user_api_key_dict, + litellm_model_id="test-model-id", + return_litellm_params=False, + ) + + # Verify the result structure + assert "data" in result + assert len(result["data"]) == 1 + + model_data = result["data"][0] + + # Verify litellm_params is NOT present + assert ( + "litellm_params" not in model_data + ), "litellm_params should be removed when return_litellm_params=False" + def test_add_callback_from_db_to_in_memory_litellm_callbacks(): """ Test that _add_callback_from_db_to_in_memory_litellm_callbacks correctly adds callbacks