
Commit a1c2cec

Author: deepanshu (committed)

add support for subfolders in git
1 parent f5305a7 commit a1c2cec

File tree

4 files changed: +278 -21 lines changed


litellm/integrations/gitlab/gitlab_prompt_manager.py

Lines changed: 47 additions & 11 deletions
@@ -12,10 +12,25 @@
 )
 from litellm.types.llms.openai import AllMessageValues
 from litellm.types.utils import StandardCallbackDynamicParams
-
+from litellm._logging import verbose_proxy_logger
 from litellm.integrations.gitlab.gitlab_client import GitLabClient


+GITLAB_PREFIX = "gitlab::"
+
+
+def encode_prompt_id(raw_id: str) -> str:
+    """Convert GitLab path IDs like 'invoice/extract' → 'gitlab::invoice::extract'."""
+    if raw_id.startswith(GITLAB_PREFIX):
+        return raw_id  # already encoded
+    return f"{GITLAB_PREFIX}{raw_id.replace('/', '::')}"
+
+
+def decode_prompt_id(encoded_id: str) -> str:
+    """Convert 'gitlab::invoice::extract' → 'invoice/extract'."""
+    if not encoded_id.startswith(GITLAB_PREFIX):
+        return encoded_id
+    return encoded_id[len(GITLAB_PREFIX):].replace("::", "/")
+
+
 class GitLabPromptTemplate:
     def __init__(
         self,
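
encode_prompt_id and decode_prompt_id are plain string transforms, so the subfolder handling can be checked in isolation. A minimal sketch (the 'invoice/extract' ID is illustrative):

from litellm.integrations.gitlab.gitlab_prompt_manager import (
    decode_prompt_id,
    encode_prompt_id,
)

# Subfolder paths are flattened to '::'-separated IDs behind the 'gitlab::' prefix
assert encode_prompt_id("invoice/extract") == "gitlab::invoice::extract"
# Encoding is idempotent: already-encoded IDs pass through unchanged
assert encode_prompt_id("gitlab::invoice::extract") == "gitlab::invoice::extract"
# Decoding restores the repo-relative path used to locate the .prompt file
assert decode_prompt_id("gitlab::invoice::extract") == "invoice/extract"
# IDs without the prefix are returned as-is
assert decode_prompt_id("summarize") == "summarize"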
@@ -87,6 +102,7 @@ def __init__(

     def _id_to_repo_path(self, prompt_id: str) -> str:
         """Map a prompt_id to a repo path (respects prompts_path and adds .prompt)."""
+        prompt_id = decode_prompt_id(prompt_id)
         if self.prompts_path:
             return f"{self.prompts_path}/{prompt_id}.prompt"
         return f"{prompt_id}.prompt"
@@ -101,26 +117,27 @@ def _repo_path_to_id(self, repo_path: str) -> str:
             path = path[len(self.prompts_path.strip("/")) + 1 :]
         if path.endswith(".prompt"):
             path = path[: -len(".prompt")]
-        return path
+        return encode_prompt_id(path)

     # ---------- loading ----------

     def _load_prompt_from_gitlab(self, prompt_id: str, *, ref: Optional[str] = None) -> None:
         """Load a specific .prompt file from GitLab (scoped under prompts_path if set)."""
         try:
+            # prompt_id = decode_prompt_id(prompt_id)
             file_path = self._id_to_repo_path(prompt_id)
             prompt_content = self.gitlab_client.get_file_content(file_path, ref=ref)
             if prompt_content:
                 template = self._parse_prompt_file(prompt_content, prompt_id)
                 self.prompts[prompt_id] = template
         except Exception as e:
-            raise Exception(f"Failed to load prompt '{prompt_id}' from GitLab: {e}")
+            raise Exception(f"Failed to load prompt '{encode_prompt_id(prompt_id)}' from GitLab: {e}")

     def load_all_prompts(self, *, recursive: bool = True) -> List[str]:
         """
         Eagerly load all .prompt files from prompts_path. Returns loaded IDs.
         """
-        files = self.list_templates(recursive=recursive)  # reuse logic
+        files = self.list_templates(recursive=recursive)
         loaded: List[str] = []
         for pid in files:
             if pid not in self.prompts:
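
Since _repo_path_to_id now runs its result through encode_prompt_id, prompts discovered in subfolders are listed under their 'gitlab::' form. A sketch of the expected mapping, again assuming prompts_path="prompts" (illustrative):

# repo file                          -> prompt ID returned by list_templates()
# "prompts/invoice/extract.prompt"   -> "gitlab::invoice::extract"
# "prompts/summarize.prompt"         -> "gitlab::summarize"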
@@ -195,9 +212,6 @@ def get_template(self, template_id: str) -> Optional[GitLabPromptTemplate]:
         return self.prompts.get(template_id)

     def list_templates(self, *, recursive: bool = True) -> List[str]:
-        """
-        List available prompt IDs discovered under prompts_path (no extension, relative to prompts_path).
-        """
         """
         List available prompt IDs under prompts_path (no extension).
         Compatible with both list_files signatures:
@@ -438,13 +452,20 @@ def _compile_prompt_helper(
         prompt_version: Optional[int] = None,
     ) -> PromptManagementClient:
         try:
-            if prompt_id not in self.prompt_manager.prompts:
+            verbose_proxy_logger.debug(f"GitLabPromptManager._compile_prompt_helper called with "
+                                       f"prompt_id={prompt_id}, prompt_variables={prompt_variables}, ")
+            decoded_id = decode_prompt_id(prompt_id)
+            verbose_proxy_logger.debug(f"Decoded prompt_id: {decoded_id}")
+            if decoded_id not in self.prompt_manager.prompts:
                 git_ref = getattr(dynamic_callback_params, "extra", {}).get("git_ref") if hasattr(dynamic_callback_params, "extra") else None
-                self.prompt_manager._load_prompt_from_gitlab(prompt_id, ref=git_ref)
+                self.prompt_manager._load_prompt_from_gitlab(decoded_id, ref=git_ref)
+

             rendered_prompt, prompt_metadata = self.get_prompt_template(
                 prompt_id, prompt_variables
             )
+            verbose_proxy_logger.debug(f"Rendered prompt: {rendered_prompt}")
+            verbose_proxy_logger.debug(f"Prompt metadata: {prompt_metadata}")

             messages = self._parse_prompt_to_messages(rendered_prompt)
             template_model = prompt_metadata.get("model")
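
The new messages go through verbose_proxy_logger, a standard library logger, so they only appear once that logger is at DEBUG level (for example when the proxy is started with detailed debugging). A minimal sketch of enabling it programmatically, assuming no other logging configuration is in play:

import logging

from litellm._logging import verbose_proxy_logger

# Surface the _compile_prompt_helper / get_chat_completion_prompt debug lines added above
verbose_proxy_logger.setLevel(logging.DEBUG)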
@@ -475,6 +496,11 @@ def get_chat_completion_prompt(
475496
prompt_label: Optional[str] = None,
476497
prompt_version: Optional[int] = None,
477498
) -> Tuple[str, List[AllMessageValues], dict]:
499+
verbose_proxy_logger.debug(f"GitLabPromptManager.get_chat_completion_prompt "
500+
f"called with prompt_id={prompt_id},"
501+
f" prompt_variables={prompt_variables}, "
502+
f"dynamic_callback_params={dynamic_callback_params},"
503+
f" prompt_label={prompt_label}, prompt_version={prompt_version}")
478504
return PromptManagementBase.get_chat_completion_prompt(
479505
self,
480506
model,
@@ -568,7 +594,10 @@ def load_all(self, *, recursive: bool = True) -> Dict[str, Dict[str, Any]]:
             entry = self._template_to_json(pid, tmpl)

             self._by_file[file_path] = entry
-            self._by_id[pid] = entry
+            # prefixed_id = pid if pid.startswith("gitlab::") else f"gitlab::{pid}"
+            encoded_id = encode_prompt_id(pid)
+            self._by_id[encoded_id] = entry
+            # self._by_id[pid] = entry

         return self._by_id

@@ -592,7 +621,14 @@ def get_by_file(self, file_path: str) -> Optional[Dict[str, Any]]:

     def get_by_id(self, prompt_id: str) -> Optional[Dict[str, Any]]:
         """Get a cached prompt JSON by prompt ID (relative to prompts_path)."""
-        return self._by_id.get(prompt_id)
+        if prompt_id in self._by_id:
+            return self._by_id[prompt_id]
+
+        # Try normalized forms
+        decoded = decode_prompt_id(prompt_id)
+        encoded = encode_prompt_id(decoded)
+
+        return self._by_id.get(encoded) or self._by_id.get(decoded)

     # -------------------------
     # Internals
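
The fallback lets callers use either spelling of a prompt ID against a cache that load_all() keys by the encoded form. A short sketch (cache contents are illustrative):

# _by_id = {"gitlab::invoice::extract": {...}}   # populated by load_all()
#
# get_by_id("gitlab::invoice::extract")  -> direct hit
# get_by_id("invoice/extract")           -> decode + re-encode, then found under the encoded key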

litellm/proxy/prompts/prompt_endpoints.py

Lines changed: 184 additions & 2 deletions
@@ -6,8 +6,8 @@
 from pathlib import Path
 from typing import Any, Dict, List, Optional, cast

-from fastapi import APIRouter, Depends, File, HTTPException, UploadFile
-from pydantic import BaseModel
+from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, Body
+from pydantic import BaseModel, Field

 from litellm._logging import verbose_proxy_logger
 from litellm.proxy._types import CommonProxyErrors, LitellmUserRoles, UserAPIKeyAuth
@@ -679,3 +679,185 @@ async def convert_prompt_file_to_json(
         except OSError:
             pass  # Directory not empty or other error

+class PromptCompletionRequest(BaseModel):
+    prompt_id: str = Field(..., description="Unique ID of the prompt registered in PromptHub.")
+    prompt_version: Optional[str] = Field(None, description="Optional version identifier.")
+    prompt_variables: Dict[str, Any] = Field(default_factory=dict, description="Key-value mapping for template variables.")
+
+
+class PromptCompletionResponse(BaseModel):
+    prompt_id: str
+    prompt_version: Optional[str]
+    model: str
+    metadata: Dict[str, Any]
+    variables: Dict[str, Any]
+    completion_text: str
+    raw_response: Dict[str, Any]
+
+
+@router.post(
+    "/prompts/completions",
+    tags=["Prompt Completions"],
+    dependencies=[Depends(user_api_key_auth)],
+)
+async def generate_completion_from_prompt_id(
+    request: PromptCompletionRequest = Body(...),
+    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
+):
+    """
+    Generate a model completion using a managed prompt.
+
+    Parameter merge priority:
+    1. Prompt metadata/config (base defaults)
+    2. Prompt-level litellm_params overrides
+    3. User-supplied request.extra_params (highest precedence)
+    """
+
+    import litellm
+    from litellm.proxy.prompts.prompt_registry import PROMPT_HUB
+    from litellm.integrations.custom_prompt_management import CustomPromptManagement
+    from litellm.integrations.gitlab import GitLabPromptManager
+    from litellm.integrations.dotprompt import DotpromptManager
+    from litellm.proxy._types import LitellmUserRoles
+
+    prompt_id = request.prompt_id
+    variables = request.prompt_variables or {}
+
+    # ------------------------------------------------------------
+    # Step 1: Access validation
+    # ------------------------------------------------------------
+    prompts: Optional[List[str]] = None
+    if user_api_key_dict.metadata is not None:
+        prompts = cast(Optional[List[str]], user_api_key_dict.metadata.get("prompts", None))
+        if prompts is not None and prompt_id not in prompts:
+            raise HTTPException(status_code=400, detail=f"Prompt {prompt_id} not found")
+
+    if user_api_key_dict.user_role not in (
+        LitellmUserRoles.PROXY_ADMIN,
+        LitellmUserRoles.PROXY_ADMIN.value,
+    ):
+        raise HTTPException(
+            status_code=403,
+            detail=f"You are not authorized to access this prompt. Your role - {user_api_key_dict.user_role}, Your key's prompts - {prompts}",
+        )
+
+    # ------------------------------------------------------------
+    # Step 2: Load prompt and callback
+    # ------------------------------------------------------------
+    prompt_spec = PROMPT_HUB.get_prompt_by_id(prompt_id)
+    if prompt_spec is None:
+        raise HTTPException(status_code=404, detail=f"Prompt {prompt_id} not found")
+
+    prompt_callback: Optional[CustomPromptManagement] = PROMPT_HUB.get_prompt_callback_by_id(prompt_id)
+    if prompt_callback is None:
+        raise HTTPException(status_code=404, detail=f"No callback found for prompt {prompt_id}")
+
+    prompt_template: Optional[PromptTemplateBase] = None
+
+    if isinstance(prompt_callback, DotpromptManager):
+        template = prompt_callback.prompt_manager.get_all_prompts_as_json()
+        if template and len(template) == 1:
+            tid = list(template.keys())[0]
+            prompt_template = PromptTemplateBase(
+                litellm_prompt_id=tid,
+                content=template[tid]["content"],
+                metadata=template[tid]["metadata"],
+            )
+
+    elif isinstance(prompt_callback, GitLabPromptManager):
+        prompt_json = prompt_spec.model_dump()
+        prompt_template = PromptTemplateBase(
+            litellm_prompt_id=prompt_json.get("prompt_id", ""),
+            content=prompt_json.get("litellm_params", {}).get("model_config", {}).get("content", ""),
+            metadata=prompt_json.get("litellm_params", {}).get("model_config", {}).get("metadata", {}),
+        )
+
+    if not prompt_template:
+        raise HTTPException(status_code=400, detail=f"Could not load prompt template for {prompt_id}")
+
+    # ------------------------------------------------------------
+    # Step 3: Fill in template variables
+    # ------------------------------------------------------------
+    try:
+        filled_prompt = prompt_template.content.format(**variables)
+    except KeyError as e:
+        raise HTTPException(status_code=400, detail=f"Missing variable: {str(e)}")
+
+    metadata = prompt_template.metadata or {}
+    model = metadata.get("model")
+    if not model:
+        raise HTTPException(status_code=400, detail=f"Model not specified in metadata for {prompt_id}")
+
+    # ------------------------------------------------------------
+    # Step 4: Build messages using prompt callback
+    # ------------------------------------------------------------
+    system_prompt = metadata.get("config", {}).get("system_prompt", "You are a helpful assistant.")
+
+    completion_prompt = prompt_callback.get_chat_completion_prompt(
+        model=model,
+        messages=[{"role": "system", "content": system_prompt}],
+        non_default_params=metadata,
+        prompt_id=prompt_id,
+        prompt_variables=variables,
+        dynamic_callback_params={},
+        prompt_label=None,
+        prompt_version=request.prompt_version,
+    )
+
+    # ------------------------------------------------------------
+    # Step 5: Merge parameters from multiple sources
+    # ------------------------------------------------------------
+    base_params = metadata.get("config", {}) or {}
+    prompt_params = (
+        prompt_spec.litellm_params.get("config", {})
+        if hasattr(prompt_spec, "litellm_params") and isinstance(prompt_spec.litellm_params, dict)
+        else {}
+    )
+    user_overrides = getattr(request, "extra_body", {}) or {}
+
+    # Flatten nested "config" keys that sometimes leak through metadata
+    def flatten_config(d: dict) -> dict:
+        if "config" in d and isinstance(d["config"], dict):
+            flattened = {**d, **d["config"]}
+            flattened.pop("config", None)
+            return flattened
+        return d
+
+    base_params = flatten_config(base_params)
+    prompt_params = flatten_config(prompt_params)
+    user_overrides = flatten_config(user_overrides)
+
+    # Merge priority: base < prompt-level < user overrides
+    merged_params = {**base_params, **prompt_params, **user_overrides}
+    merged_params.setdefault("stream", False)
+    merged_params["user"] = user_api_key_dict.user_id
+    merged_params.pop("model", None)
+    merged_params.pop("messages", None)
+    # ------------------------------------------------------------
+    # Step 6: Invoke model
+    # ------------------------------------------------------------
+    try:
+        response = await litellm.acompletion(
+            model=completion_prompt[0],
+            messages=completion_prompt[1],
+            **merged_params,
+        )
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=f"Error invoking model: {str(e)}")
+
+    # ------------------------------------------------------------
+    # Step 7: Extract text & return structured response
+    # ------------------------------------------------------------
+    completion_text = (
+        response.get("choices", [{}])[0].get("message", {}).get("content", "")
+    )
+
+    return PromptCompletionResponse(
+        prompt_id=prompt_id,
+        prompt_version=request.prompt_version,
+        model=model,
+        metadata=metadata,
+        variables=variables,
+        completion_text=completion_text,
+        raw_response=response.model_dump() if hasattr(response, "model_dump") else response,
+    )
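
One way to exercise the new route from a client, assuming the proxy runs locally on port 4000 and sk-1234 is a key whose role and prompt list pass the checks above (all values illustrative); the prompt_id uses the encoded GitLab form introduced in the first file:

import requests

resp = requests.post(
    "http://localhost:4000/prompts/completions",
    headers={"Authorization": "Bearer sk-1234"},
    json={
        "prompt_id": "gitlab::invoice::extract",
        "prompt_variables": {"invoice_text": "Total due: $42.00"},
    },
    timeout=60,
)
resp.raise_for_status()
body = resp.json()
print(body["model"], body["completion_text"])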
