Commit 3ebe489

[Feat] Prompt Management - Add support for versioning prompts (#16836)
* test_dotprompt_auto_detection_with_model_only
* fix _auto_detect_prompt_management_logger
* test_dotprompt_with_prompt_version
* add v1, v2 tests
* add _compile_prompt_helper
* fix _compile_prompt_helper
* test_dotprompt_with_prompt_version
* test_dotprompt_with_prompt_version, test_get_prompt_with_version
1 parent: 1f8fe00 · commit: 3ebe489

File tree: 5 files changed, +209 −12 lines changed
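The feature adds a prompt_version parameter next to prompt_id. A minimal usage sketch, distilled from the test added in this commit (it assumes the dotprompt manager is registered as a litellm callback, the call runs inside an async function, and the prompt directory holds a base chat_prompt.prompt plus the versioned files shown further down):

import litellm
from litellm.integrations.dotprompt import DotpromptManager

# Register a dotprompt manager that loads .prompt files from a directory.
litellm.callbacks = [DotpromptManager(prompt_directory="path/to/prompts")]

# prompt_version=2 selects chat_prompt.v2.prompt; omitting it falls back to the
# base chat_prompt.prompt. The versioned file's metadata (model, temperature,
# max_tokens) is applied to the request.
response = await litellm.acompletion(
    model="gpt-4",
    prompt_id="chat_prompt",
    prompt_version=2,
    prompt_variables={"user_message": "Hello there"},
    messages=[],
)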

litellm/integrations/dotprompt/dotprompt_manager.py

Lines changed: 15 additions & 6 deletions
@@ -108,21 +108,30 @@ def _compile_prompt_helper(
         Compile a .prompt file into a PromptManagementClient structure.
 
         This method:
-        1. Loads the prompt template from the .prompt file
+        1. Loads the prompt template from the .prompt file (with optional version)
         2. Renders it with the provided variables
         3. Converts the rendered text into chat messages
         4. Extracts model and optional parameters from metadata
         """
 
         try:
 
-            # Get the prompt template
-            template = self.prompt_manager.get_prompt(prompt_id)
+            # Get the prompt template (versioned or base)
+            template = self.prompt_manager.get_prompt(
+                prompt_id=prompt_id, version=prompt_version
+            )
             if template is None:
-                raise ValueError(f"Prompt '{prompt_id}' not found in prompt directory")
+                version_str = f" (version {prompt_version})" if prompt_version else ""
+                raise ValueError(
+                    f"Prompt '{prompt_id}'{version_str} not found in prompt directory"
+                )
 
-            # Render the template with variables
-            rendered_content = self.prompt_manager.render(prompt_id, prompt_variables)
+            # Render the template with variables (pass version for proper lookup)
+            rendered_content = self.prompt_manager.render(
+                prompt_id=prompt_id,
+                prompt_variables=prompt_variables,
+                version=prompt_version,
+            )
 
             # Convert rendered content to chat messages
             messages = self._convert_to_messages(rendered_content)

litellm/integrations/dotprompt/prompt_manager.py

Lines changed: 31 additions & 6 deletions
@@ -183,14 +183,18 @@ def _parse_frontmatter(self, content: str) -> Tuple[Dict[str, Any], str]:
         return frontmatter, template_content
 
     def render(
-        self, prompt_id: str, prompt_variables: Optional[Dict[str, Any]] = None
+        self,
+        prompt_id: str,
+        prompt_variables: Optional[Dict[str, Any]] = None,
+        version: Optional[int] = None,
     ) -> str:
         """
         Render a prompt template with the given variables.
 
         Args:
             prompt_id: The ID of the prompt template to render
             prompt_variables: Variables to substitute in the template
+            version: Optional version number. If provided, looks for {prompt_id}.v{version}
 
         Returns:
             The rendered prompt string

@@ -199,13 +203,16 @@ def render(
             KeyError: If prompt_id is not found
             ValueError: If template rendering fails
         """
-        if prompt_id not in self.prompts:
+        # Get the template (versioned or base)
+        template = self.get_prompt(prompt_id=prompt_id, version=version)
+
+        if template is None:
             available_prompts = list(self.prompts.keys())
+            version_str = f" (version {version})" if version else ""
             raise KeyError(
-                f"Prompt '{prompt_id}' not found. Available prompts: {available_prompts}"
+                f"Prompt '{prompt_id}'{version_str} not found. Available prompts: {available_prompts}"
             )
 
-        template = self.prompts[prompt_id]
         variables = prompt_variables or {}
 
         # Validate input variables against schema if defined

@@ -254,8 +261,26 @@ def _get_python_type(self, schema_type: str) -> Union[type, tuple]:
 
         return type_mapping.get(schema_type.lower(), str)  # type: ignore
 
-    def get_prompt(self, prompt_id: str) -> Optional[PromptTemplate]:
-        """Get a prompt template by ID."""
+    def get_prompt(
+        self, prompt_id: str, version: Optional[int] = None
+    ) -> Optional[PromptTemplate]:
+        """
+        Get a prompt template by ID and optional version.
+
+        Args:
+            prompt_id: The base prompt ID
+            version: Optional version number. If provided, looks for {prompt_id}.v{version}
+
+        Returns:
+            The prompt template if found, None otherwise
+        """
+        if version is not None:
+            # Try versioned prompt first: prompt_id.v{version}
+            versioned_id = f"{prompt_id}.v{version}"
+            if versioned_id in self.prompts:
+                return self.prompts[versioned_id]
+
+        # Fall back to base prompt_id
         return self.prompts.get(prompt_id)
 
     def list_prompts(self) -> List[str]:
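A short sketch of the lookup behavior added above, assuming PromptManager is exported from the litellm.integrations.dotprompt package and the prompt directory contains chat_prompt.prompt, chat_prompt.v1.prompt, and chat_prompt.v2.prompt (the layout used by the tests below):

from litellm.integrations.dotprompt import PromptManager

manager = PromptManager(prompt_directory="path/to/prompts")

# version=1 resolves the "chat_prompt.v1" entry if it was loaded...
v1 = manager.get_prompt(prompt_id="chat_prompt", version=1)

# ...while a version with no matching file falls back to the base "chat_prompt"
# template instead of failing, since get_prompt() ends with self.prompts.get().
fallback = manager.get_prompt(prompt_id="chat_prompt", version=99)

# render() forwards the version, so the same resolution applies when templating.
text = manager.render(
    prompt_id="chat_prompt",
    prompt_variables={"user_message": "hi"},
    version=1,
)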
tests/test_litellm/integrations/dotprompt/chat_prompt.v1.prompt (new file; path inferred from the test below, which loads versioned prompts from its own directory)

Lines changed: 10 additions & 0 deletions

@@ -0,0 +1,10 @@
+---
+model: gpt-3.5-turbo
+temperature: 0.5
+max_tokens: 100
+input:
+  schema:
+    user_message: string
+---
+
+Version 1: {{user_message}}
tests/test_litellm/integrations/dotprompt/chat_prompt.v2.prompt (new file; path inferred from the test below)

Lines changed: 10 additions & 0 deletions

@@ -0,0 +1,10 @@
+---
+model: gpt-4
+temperature: 0.9
+max_tokens: 200
+input:
+  schema:
+    user_message: string
+---
+
+Version 2: {{user_message}}
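Both files follow the .v{version}.prompt naming convention referenced in the test below. Assuming the loader keys each file by its stem (the resolution code above looks up f"{prompt_id}.v{version}" in self.prompts), the versioned templates would sit alongside the base one in the registry:

from litellm.integrations.dotprompt import PromptManager

manager = PromptManager(prompt_directory="path/to/prompts")
# Hypothetical output; list_prompts() is expected to include the versioned keys:
# ["chat_prompt", "chat_prompt.v1", "chat_prompt.v2", ...]
print(manager.list_prompts())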

tests/test_litellm/integrations/dotprompt/test_prompt_manager.py

Lines changed: 143 additions & 0 deletions
@@ -195,6 +195,29 @@ def test_get_prompt_metadata():
     assert "output" in metadata
 
 
+def test_get_prompt_with_version():
+    """Test that get_prompt correctly retrieves versioned prompts."""
+    prompt_dir = Path(__file__).parent
+    manager = PromptManager(prompt_directory=str(prompt_dir))
+
+    # Get base prompt (no version)
+    base_prompt = manager.get_prompt(prompt_id="chat_prompt")
+    assert base_prompt is not None
+    assert "User: {{user_message}}" in base_prompt.content
+
+    # Get version 1
+    v1_prompt = manager.get_prompt(prompt_id="chat_prompt", version=1)
+    assert v1_prompt is not None
+    assert "Version 1:" in v1_prompt.content
+    assert v1_prompt.model == "gpt-3.5-turbo"
+
+    # Get version 2
+    v2_prompt = manager.get_prompt(prompt_id="chat_prompt", version=2)
+    assert v2_prompt is not None
+    assert "Version 2:" in v2_prompt.content
+    assert v2_prompt.model == "gpt-4"
+
+
 def test_add_prompt_programmatically():
     """Test adding prompts programmatically."""
     prompt_dir = Path(
@@ -597,3 +620,123 @@ async def test_dotprompt_auto_detection_with_model_only():
     finally:
         # Restore original callbacks
         litellm.callbacks = original_callbacks
+
+
+@pytest.mark.asyncio
+async def test_dotprompt_with_prompt_version():
+    """
+    Test that dotprompt can load and use specific prompt versions.
+    Versions are stored as separate files with .v{version}.prompt naming convention.
+    """
+    from litellm.integrations.dotprompt import DotpromptManager
+    from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
+
+    prompt_dir = Path(__file__).parent
+    dotprompt_manager = DotpromptManager(prompt_directory=str(prompt_dir))
+
+    # Register the dotprompt manager in callbacks
+    original_callbacks = litellm.callbacks.copy()
+    litellm.callbacks = [dotprompt_manager]
+
+    try:
+        # Mock the HTTP handler to avoid actual API calls
+        with patch("litellm.llms.custom_httpx.llm_http_handler.AsyncHTTPHandler.post") as mock_post:
+            mock_response_data = litellm.ModelResponse(
+                choices=[
+                    litellm.Choices(
+                        message=litellm.Message(content="Hello!"),
+                        index=0,
+                        finish_reason="stop",
+                    )
+                ]
+            ).model_dump()
+
+            # Create a proper mock response
+            mock_response = MagicMock()
+            mock_response.status_code = 200
+            mock_response.text = json.dumps(mock_response_data)
+            mock_response.headers = {"Content-Type": "application/json"}
+            mock_response.json.return_value = mock_response_data
+
+            mock_post.return_value = mock_response
+
+            # Test version 1
+            await litellm.acompletion(
+                model="gpt-3.5-turbo",
+                prompt_id="chat_prompt",
+                prompt_version=1,
+                prompt_variables={"user_message": "Test v1"},
+                messages=[],
+            )
+
+            assert mock_post.call_count >= 1
+            data_str = mock_post.call_args.kwargs.get("data", "{}")
+            request_body = json.loads(data_str)
+
+            print(f"Version 1 request body: {json.dumps(request_body, indent=2)}")
+
+            # Verify version 1 prompt was used
+            # chat_prompt.v1.prompt has: model: gpt-3.5-turbo, temperature: 0.5, max_tokens: 100
+            assert request_body["model"] == "gpt-3.5-turbo"
+
+            # Verify the message contains "Version 1:" prefix from v1 template
+            messages = request_body["messages"]
+            assert len(messages) >= 1
+            first_message_content = messages[0]["content"]
+            print(f"Version 1 message: {first_message_content}")
+            assert "Version 1:" in first_message_content
+            assert "Test v1" in first_message_content
+
+            # Reset mock for version 2 test
+            mock_post.reset_mock()
+
+        # Test version 2
+        with patch("litellm.llms.custom_httpx.llm_http_handler.AsyncHTTPHandler.post") as mock_post:
+            mock_response_data = litellm.ModelResponse(
+                choices=[
+                    litellm.Choices(
+                        message=litellm.Message(content="Hello!"),
+                        index=0,
+                        finish_reason="stop",
+                    )
+                ]
+            ).model_dump()
+
+            # Create a proper mock response
+            mock_response = MagicMock()
+            mock_response.status_code = 200
+            mock_response.text = json.dumps(mock_response_data)
+            mock_response.headers = {"Content-Type": "application/json"}
+            mock_response.json.return_value = mock_response_data
+
+            mock_post.return_value = mock_response
+
+            await litellm.acompletion(
+                model="gpt-4",
+                prompt_id="chat_prompt",
+                prompt_version=2,
+                prompt_variables={"user_message": "Test v2"},
+                messages=[],
+            )
+
+            mock_post.assert_called_once()
+            data_str = mock_post.call_args.kwargs.get("data", "{}")
+            request_body = json.loads(data_str)
+
+            print(f"Version 2 request body: {json.dumps(request_body, indent=2)}")
+
+            # Verify version 2 prompt was used
+            # chat_prompt.v2.prompt has: model: gpt-4, temperature: 0.9, max_tokens: 200
+            assert request_body["model"] == "gpt-4"
+
+            # Verify the message contains "Version 2:" prefix from v2 template
+            messages = request_body["messages"]
+            assert len(messages) >= 1
+            first_message_content = messages[0]["content"]
+            print(f"Version 2 message: {first_message_content}")
+            assert "Version 2:" in first_message_content
+            assert "Test v2" in first_message_content
+
+    finally:
+        # Restore original callbacks
+        litellm.callbacks = original_callbacks
