Commit 1f8fe00

[Feat] Prompt Management - Allow specifying just prompt_id in a request to a model (#16834)
* test_dotprompt_auto_detection_with_model_only
* fix _auto_detect_prompt_management_logger
1 parent 5f94b37 commit 1f8fe00
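
In practice, the change means a managed prompt can be referenced by prompt_id alone, with a plain model name. A minimal usage sketch based on the new test in this commit, assuming a dotprompt directory containing chat_prompt.prompt (the ./prompts path is a placeholder):

import asyncio

import litellm
from litellm.integrations.dotprompt import DotpromptManager

# Register a prompt management integration as a callback; before this commit,
# reaching it required a prefixed model name like "dotprompt/gpt-4".
litellm.callbacks = [DotpromptManager(prompt_directory="./prompts")]


async def main() -> None:
    response = await litellm.acompletion(
        model="gpt-4",  # no "dotprompt/" prefix needed
        prompt_id="chat_prompt",  # owning integration is auto-detected
        prompt_variables={"user_message": "Hello world"},
        messages=[{"role": "user", "content": "This will be ignored"}],
    )
    print(response.choices[0].message.content)


asyncio.run(main())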

File tree

2 files changed: 146 additions & 4 deletions


litellm/litellm_core_utils/litellm_logging.py

Lines changed: 70 additions & 4 deletions
@@ -585,7 +585,10 @@ def get_chat_completion_prompt(
         custom_logger = (
             prompt_management_logger
             or self.get_custom_logger_for_prompt_management(
-                model=model, non_default_params=non_default_params
+                model=model,
+                non_default_params=non_default_params,
+                prompt_id=prompt_id,
+                dynamic_callback_params=self.standard_callback_dynamic_params,
             )
         )

@@ -622,7 +625,11 @@ async def async_get_chat_completion_prompt(
         custom_logger = (
             prompt_management_logger
             or self.get_custom_logger_for_prompt_management(
-                model=model, tools=tools, non_default_params=non_default_params
+                model=model,
+                tools=tools,
+                non_default_params=non_default_params,
+                prompt_id=prompt_id,
+                dynamic_callback_params=self.standard_callback_dynamic_params,
             )
         )

@@ -646,19 +653,69 @@ async def async_get_chat_completion_prompt(
         self.messages = messages
         return model, messages, non_default_params

+    def _auto_detect_prompt_management_logger(
+        self,
+        prompt_id: str,
+        dynamic_callback_params: StandardCallbackDynamicParams,
+    ) -> Optional[CustomLogger]:
+        """
+        Auto-detect which prompt management system owns the given prompt_id.
+
+        This allows a user to just pass prompt_id in the completion call and it will be auto-detected which system owns this prompt.
+
+        Args:
+            prompt_id: The prompt ID to check
+            dynamic_callback_params: Dynamic callback parameters for should_run_prompt_management checks
+
+        Returns:
+            A CustomLogger instance if a matching prompt management system is found, None otherwise
+        """
+        prompt_management_loggers = (
+            litellm.logging_callback_manager.get_custom_loggers_for_type(
+                callback_type=CustomPromptManagement
+            )
+        )
+
+        for logger in prompt_management_loggers:
+            if isinstance(logger, CustomPromptManagement):
+                try:
+                    if logger.should_run_prompt_management(
+                        prompt_id=prompt_id,
+                        dynamic_callback_params=dynamic_callback_params,
+                    ):
+                        self.model_call_details["prompt_integration"] = (
+                            logger.__class__.__name__
+                        )
+                        return logger
+                except Exception:
+                    # If check fails, continue to next logger
+                    continue
+
+        return None
+
     def get_custom_logger_for_prompt_management(
-        self, model: str, non_default_params: Dict, tools: Optional[List[Dict]] = None
+        self,
+        model: str,
+        non_default_params: Dict,
+        tools: Optional[List[Dict]] = None,
+        prompt_id: Optional[str] = None,
+        dynamic_callback_params: Optional[StandardCallbackDynamicParams] = None,
     ) -> Optional[CustomLogger]:
         """
         Get a custom logger for prompt management based on model name or available callbacks.

         Args:
             model: The model name to check for prompt management integration
+            non_default_params: Non-default parameters passed to the completion call
+            tools: Optional tools passed to the completion call
+            prompt_id: Optional prompt ID to auto-detect which system owns this prompt
+            dynamic_callback_params: Dynamic callback parameters for should_run_prompt_management checks

         Returns:
             A CustomLogger instance if one is found, None otherwise
         """
         # First check if model starts with a known custom logger compatible callback
+        # This takes precedence for backward compatibility
         for callback_name in litellm._known_custom_logger_compatible_callbacks:
             if model.startswith(callback_name):
                 custom_logger = _init_custom_logger_compatible_class(

@@ -670,7 +727,16 @@ def get_custom_logger_for_prompt_management(
                 self.model_call_details["prompt_integration"] = model.split("/")[0]
                 return custom_logger

-        # Then check for any registered CustomPromptManagement loggers
+        # If prompt_id is provided, try to auto-detect which system has this prompt
+        if prompt_id and dynamic_callback_params is not None:
+            auto_detected_logger = self._auto_detect_prompt_management_logger(
+                prompt_id=prompt_id,
+                dynamic_callback_params=dynamic_callback_params,
+            )
+            if auto_detected_logger is not None:
+                return auto_detected_logger
+
+        # Then check for any registered CustomPromptManagement loggers (fallback)
         prompt_management_loggers = (
             litellm.logging_callback_manager.get_custom_loggers_for_type(
                 callback_type=CustomPromptManagement
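
The detection loop above probes each registered CustomPromptManagement logger via should_run_prompt_management and returns the first one that claims the prompt_id; explicit provider-prefixed model names keep precedence. As a hedged illustration of what an integration supplies for this to work, here is a hypothetical subclass (MyPromptStore and its prompt set are invented; the import paths are assumed from litellm's integrations package):

from litellm.integrations.custom_prompt_management import CustomPromptManagement
from litellm.types.utils import StandardCallbackDynamicParams


class MyPromptStore(CustomPromptManagement):
    """Hypothetical integration that owns a fixed set of prompt IDs."""

    # Other members required by the base class, if any, are omitted for brevity.
    KNOWN_PROMPT_IDS = {"chat_prompt", "welcome_prompt"}

    def should_run_prompt_management(
        self,
        prompt_id: str,
        dynamic_callback_params: StandardCallbackDynamicParams,
    ) -> bool:
        # Claim ownership only of known prompt IDs; returning False lets
        # _auto_detect_prompt_management_logger move on to the next logger.
        return prompt_id in self.KNOWN_PROMPT_IDS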

tests/test_litellm/integrations/dotprompt/test_prompt_manager.py

Lines changed: 76 additions & 0 deletions
@@ -521,3 +521,79 @@ def test_prompt_main():
     """
     # TODO: Implement once PromptManager is integrated with litellm completion
     pass
+
+
+@pytest.mark.asyncio
+async def test_dotprompt_auto_detection_with_model_only():
+    """
+    Test that dotprompt prompts can be auto-detected when passing model="gpt-4" and prompt_id,
+    without needing to specify model="dotprompt/gpt-4".
+    """
+    from litellm.integrations.dotprompt import DotpromptManager
+
+    prompt_dir = Path(__file__).parent
+    dotprompt_manager = DotpromptManager(prompt_directory=str(prompt_dir))
+
+    # Register the dotprompt manager in callbacks
+    original_callbacks = litellm.callbacks.copy()
+    litellm.callbacks = [dotprompt_manager]
+
+    try:
+        # Mock the HTTP handler to avoid actual API calls
+        with patch("litellm.llms.custom_httpx.llm_http_handler.AsyncHTTPHandler.post") as mock_post:
+            mock_response_data = litellm.ModelResponse(
+                choices=[
+                    litellm.Choices(
+                        message=litellm.Message(content="Hello!"),
+                        index=0,
+                        finish_reason="stop",
+                    )
+                ]
+            ).model_dump()
+
+            # Create a proper mock response
+            mock_response = MagicMock()
+            mock_response.status_code = 200
+            mock_response.text = json.dumps(mock_response_data)
+            mock_response.headers = {"Content-Type": "application/json"}
+            mock_response.json.return_value = mock_response_data
+
+            mock_post.return_value = mock_response
+
+            # Call with model="gpt-4" (no "dotprompt/" prefix) and prompt_id
+            await litellm.acompletion(
+                model="gpt-4",
+                prompt_id="chat_prompt",
+                prompt_variables={"user_message": "Hello world"},
+                messages=[{"role": "user", "content": "This will be ignored"}],
+            )
+
+            mock_post.assert_called_once()
+
+            # Get request body from the call (it's passed as the 'data' parameter, a JSON string)
+            data_str = mock_post.call_args.kwargs.get("data", "{}")
+            request_body = json.loads(data_str)
+
+            print(f"Request body: {json.dumps(request_body, indent=2)}")
+
+            # Verify the prompt was auto-detected and used
+            # The chat_prompt.prompt has metadata: model: gpt-4, temperature: 0.7, max_tokens: 150
+            assert request_body["model"] == "gpt-4"
+
+            # Note: the OpenAI API might strip out temperature/max_tokens if they're not in the request
+            # The key test is that the messages were transformed
+
+            # Verify the messages were transformed using the prompt template
+            # chat_prompt template: "User: {{user_message}}"
+            messages = request_body["messages"]
+            assert len(messages) >= 1
+
+            # The first message should be from the prompt template with the variable substituted
+            # Template is: "User: {{user_message}}" with user_message="Hello world"
+            first_message_content = messages[0]["content"]
+            print(f"First message content: {first_message_content}")
+            assert "Hello world" in first_message_content
+
+    finally:
+        # Restore original callbacks
+        litellm.callbacks = original_callbacks
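
The test relies on a chat_prompt.prompt file sitting next to the test module. Going by the comments in the test (model gpt-4, temperature 0.7, max_tokens 150, template "User: {{user_message}}"), a matching file would look roughly like the string below; the exact frontmatter keys are an assumption about the dotprompt format, and the snippet simply materializes the file for a local run:

from pathlib import Path

# Hypothetical reconstruction of chat_prompt.prompt from the test's comments;
# the frontmatter layout is assumed, not taken from the commit.
CHAT_PROMPT = """\
---
model: gpt-4
temperature: 0.7
max_tokens: 150
---
User: {{user_message}}
"""

Path("chat_prompt.prompt").write_text(CHAT_PROMPT)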
