diff --git a/dspy/adapters/base.py b/dspy/adapters/base.py
index 9ad6d12565..1237685b90 100644
--- a/dspy/adapters/base.py
+++ b/dspy/adapters/base.py
@@ -1,10 +1,13 @@
+import logging
 from typing import TYPE_CHECKING, Any, Optional, Type
 
-from dspy.adapters.types import History
+from dspy.adapters.types import BaseType, History
 from dspy.adapters.types.base_type import split_message_content_for_custom_types
 from dspy.signatures.signature import Signature
 from dspy.utils.callback import BaseCallback, with_callbacks
 
+logger = logging.getLogger(__name__)
+
 if TYPE_CHECKING:
     from dspy.clients.lm import LM
 
@@ -22,20 +25,41 @@ def __init_subclass__(cls, **kwargs) -> None:
 
     def _call_post_process(self, outputs: list[dict[str, Any]], signature: Type[Signature]) -> list[dict[str, Any]]:
         values = []
+        output_field_name, field_info = next(iter(signature.output_fields.items()))
+        annotation = field_info.annotation
 
         for output in outputs:
             output_logprobs = None
 
-            if isinstance(output, dict):
-                output, output_logprobs = output["text"], output["logprobs"]
-
-            value = self.parse(signature, output)
+            if "text" in output:
+                output, output_logprobs = output["text"], output.get("logprobs")
+            elif "content" in output:
+                output_content = output["content"]
+                if isinstance(output_content, list) and output_content and "text" in output_content[0]:
+                    output = output_content[0]["text"]
+                else:
+                    output = ""
+                output_logprobs = None
+            else:
+                output = str(output)
+                output_logprobs = None
+            try:
+                if issubclass(annotation, BaseType):
+                    try:
+                        parsed = annotation.parse(output)
+                        value = {output_field_name: parsed}
+                    except Exception as e:
+                        logger.warning(f"Output is not of expected annotation field '{output_field_name}': {e}")
+                        continue
+                else:
+                    value = self.parse(signature, output)
+            except TypeError:
+                value = self.parse(signature, output)
 
             if output_logprobs is not None:
                 value["logprobs"] = output_logprobs
-
             values.append(value)
-
+
         return values
 
     def __call__(
@@ -48,6 +72,13 @@
     ) -> list[dict[str, Any]]:
         inputs = self.format(signature, demos, inputs)
 
+        if getattr(lm, "model_type", None) == "responses":
+            for msg in inputs:
+                if msg["role"] == "user" and isinstance(msg["content"], list):
+                    for block in msg["content"]:
+                        if block.get("type") == "text":
+                            block["type"] = "input_text"
+
         outputs = lm(messages=inputs, **lm_kwargs)
         return self._call_post_process(outputs, signature)
 
diff --git a/dspy/adapters/types/base_type.py b/dspy/adapters/types/base_type.py
index f2983ef463..9443371e25 100644
--- a/dspy/adapters/types/base_type.py
+++ b/dspy/adapters/types/base_type.py
@@ -28,6 +28,10 @@ def format(self) -> list[dict[str, Any]]:
 
     def format(self) -> list[dict[str, Any]]:
         raise NotImplementedError
+
+    @classmethod
+    def parse(cls, raw: Any) -> "BaseType":
+        return cls(**raw)
 
     @pydantic.model_serializer()
     def serialize_model(self):
diff --git a/dspy/adapters/types/image.py b/dspy/adapters/types/image.py
index 693b407fb6..bc3cf12c46 100644
--- a/dspy/adapters/types/image.py
+++ b/dspy/adapters/types/image.py
@@ -33,8 +33,29 @@ def format(self) -> Union[list[dict[str, Any]], str]:
             image_url = encode_image(self.url)
         except Exception as e:
             raise ValueError(f"Failed to format image for DSPy: {e}")
-        return [{"type": "image_url", "image_url": {"url": image_url}}]
+        if isinstance(image_url, str):
+            return [{"type": "input_image", "image_url": image_url}]
+        else:
+            return [{"type": "input_image", "image_url": {"url": image_url}}]
 
+    @classmethod
+    def parse(cls, raw: Any) -> "Image":
+        if isinstance(raw, dict):
+            if "result" in raw:
+                b64 = raw["result"]
+            else:
+                raise TypeError("Input for parsing is missing 'result'")
+        elif hasattr(raw, "result"):
+            b64 = getattr(raw, "result")
+        elif hasattr(raw, "data"):
+            b64 = raw.data[0].b64_json
+        elif isinstance(raw, str):
+            b64 = raw
+        else:
+            raise TypeError(f"Unsupported type {type(raw)} for Image.parse")
+        uri = b64 if b64.startswith("data:") else f"data:image/png;base64,{b64}"
+        return cls(url=uri)
+
     @pydantic.model_validator(mode="before")
     @classmethod
     def validate_input(cls, values):
diff --git a/dspy/clients/base_lm.py b/dspy/clients/base_lm.py
index 0c3ae2ee76..e42b9ade84 100644
--- a/dspy/clients/base_lm.py
+++ b/dspy/clients/base_lm.py
@@ -50,7 +50,20 @@ def __init__(self, model, model_type="chat", temperature=0.0, max_tokens=1000, c
 
     def _process_lm_response(self, response, prompt, messages, **kwargs):
         merged_kwargs = {**self.kwargs, **kwargs}
-        if merged_kwargs.get("logprobs"):
+        if hasattr(response, "output") or (isinstance(response, dict) and "output" in response):
+            output_items = response.output if hasattr(response, "output") else response["output"]
+            output_items = list(output_items)
+            if any(not (hasattr(item, "content") or (isinstance(item, dict) and "content" in item)) for item in output_items):
+                outputs = output_items
+            else:
+                outputs = []
+                for item in output_items:
+                    content_list = item.get("content", []) if isinstance(item, dict) else getattr(item, "content", [])
+                    for c in content_list:
+                        if (isinstance(c, dict) and c.get("type") == "output_text") or (hasattr(c, "type") and getattr(c, "type", None) == "output_text"):
+                            text_val = c.get("text") if isinstance(c, dict) else getattr(c, "text", "")
+                            outputs.append(text_val)
+        elif merged_kwargs.get("logprobs"):
             outputs = [
                 {
                     "text": c.message.content if hasattr(c, "message") else c["text"],
@@ -133,7 +146,7 @@ def copy(self, **kwargs):
         return new_instance
 
     def inspect_history(self, n: int = 1):
-        return inspect_history(self.history, n)
+        return pretty_print_history(self.history, n)
 
     def update_global_history(self, entry):
         if settings.disable_history:
diff --git a/dspy/clients/lm.py b/dspy/clients/lm.py
index 384960a114..ed985b62cd 100644
--- a/dspy/clients/lm.py
+++ b/dspy/clients/lm.py
@@ -29,7 +29,7 @@ class LM(BaseLM):
     def __init__(
         self,
         model: str,
-        model_type: Literal["chat", "text"] = "chat",
+        model_type: Literal["chat", "text", "responses"] = "chat",
        temperature: float = 0.0,
         max_tokens: int = 4000,
         cache: bool = True,
@@ -120,8 +120,15 @@ def forward(self, prompt=None, messages=None, **kwargs):
 
         messages = messages or [{"role": "user", "content": prompt}]
         kwargs = {**self.kwargs, **kwargs}
-
-        completion = litellm_completion if self.model_type == "chat" else litellm_text_completion
+
+
+        if self.model_type == "chat":
+            completion = litellm_completion
+        elif self.model_type == "text":
+            completion = litellm_text_completion
+        elif self.model_type == "responses":
+            completion = litellm_responses_completion
+
         completion, litellm_cache_args = self._get_cached_completion_fn(completion, cache, enable_memory_cache)
 
         results = completion(
@@ -129,15 +136,25 @@
             num_retries=self.num_retries,
             cache=litellm_cache_args,
         )
+        if self.model_type != "responses":
+            if any(c.finish_reason == "length" for c in results["choices"]):
+                logger.warning(
+                    f"LM response was truncated due to exceeding max_tokens={self.kwargs['max_tokens']}. "
+                    "You can inspect the latest LM interactions with `dspy.inspect_history()`. "
+                    "To avoid truncation, consider passing a larger max_tokens when setting up dspy.LM. "
+                    f"You may also consider increasing the temperature (currently {self.kwargs['temperature']}) "
+                    " if the reason for truncation is repetition."
+                )
+        else:
+            if results.get("truncation") == "enabled":
+                logger.warning(
+                    f"LM response was truncated due to exceeding max_tokens={self.kwargs['max_tokens']}. "
+                    "You can inspect the latest LM interactions with `dspy.inspect_history()`. "
+                    "To avoid truncation, consider passing a larger max_tokens when setting up dspy.LM. "
+                    f"You may also consider increasing the temperature (currently {self.kwargs['temperature']}) "
+                    " if the reason for truncation is repetition."
+                )
 
-        if any(c.finish_reason == "length" for c in results["choices"]):
-            logger.warning(
-                f"LM response was truncated due to exceeding max_tokens={self.kwargs['max_tokens']}. "
-                "You can inspect the latest LM interactions with `dspy.inspect_history()`. "
-                "To avoid truncation, consider passing a larger max_tokens when setting up dspy.LM. "
-                f"You may also consider increasing the temperature (currently {self.kwargs['temperature']}) "
-                " if the reason for truncation is repetition."
-            )
 
         if not getattr(results, "cache_hit", False) and dspy.settings.usage_tracker and hasattr(results, "usage"):
             settings.usage_tracker.add_usage(self.model, dict(results.usage))
@@ -151,7 +168,12 @@ async def aforward(self, prompt=None, messages=None, **kwargs):
         messages = messages or [{"role": "user", "content": prompt}]
         kwargs = {**self.kwargs, **kwargs}
 
-        completion = alitellm_completion if self.model_type == "chat" else alitellm_text_completion
+        if self.model_type == "chat":
+            completion = alitellm_completion
+        elif self.model_type == "text":
+            completion = alitellm_text_completion
+        elif self.model_type == "responses":
+            completion = alitellm_responses_completion
         completion, litellm_cache_args = self._get_cached_completion_fn(completion, cache, enable_memory_cache)
 
         results = await completion(
@@ -160,14 +182,24 @@
             num_retries=self.num_retries,
             cache=litellm_cache_args,
         )
-        if any(c.finish_reason == "length" for c in results["choices"]):
-            logger.warning(
-                f"LM response was truncated due to exceeding max_tokens={self.kwargs['max_tokens']}. "
-                "You can inspect the latest LM interactions with `dspy.inspect_history()`. "
-                "To avoid truncation, consider passing a larger max_tokens when setting up dspy.LM. "
-                f"You may also consider increasing the temperature (currently {self.kwargs['temperature']}) "
-                " if the reason for truncation is repetition."
-            )
+        if self.model_type != "responses":
+            if any(c.finish_reason == "length" for c in results["choices"]):
+                logger.warning(
+                    f"LM response was truncated due to exceeding max_tokens={self.kwargs['max_tokens']}. "
+                    "You can inspect the latest LM interactions with `dspy.inspect_history()`. "
+                    "To avoid truncation, consider passing a larger max_tokens when setting up dspy.LM. "
+                    f"You may also consider increasing the temperature (currently {self.kwargs['temperature']}) "
+                    " if the reason for truncation is repetition."
+                )
+        else:
+            if results.get("truncation") == "enabled":
+                logger.warning(
+                    f"LM response was truncated due to exceeding max_tokens={self.kwargs['max_tokens']}. "
+                    "You can inspect the latest LM interactions with `dspy.inspect_history()`. "
+                    "To avoid truncation, consider passing a larger max_tokens when setting up dspy.LM. "
+                    f"You may also consider increasing the temperature (currently {self.kwargs['temperature']}) "
+                    " if the reason for truncation is repetition."
+                )
 
         if not getattr(results, "cache_hit", False) and dspy.settings.usage_tracker and hasattr(results, "usage"):
             settings.usage_tracker.add_usage(self.model, dict(results.usage))
@@ -373,3 +405,41 @@ async def alitellm_text_completion(request: Dict[str, Any], num_retries: int, ca
         retry_strategy="exponential_backoff_retry",
         **request,
     )
+
+
+def litellm_responses_completion(request: Dict[str, Any], num_retries: int, cache: Optional[Dict[str, Any]] = None):
+    cache = cache or {"no-cache": True, "no-store": True}
+    if "messages" in request:
+        content_blocks = []
+        for msg in request.pop("messages"):
+            c = msg.get("content")
+            if isinstance(c, str):
+                content_blocks.append({"type": "input_text", "text": c})
+            elif isinstance(c, list):
+                content_blocks.extend(c)
+        request["input"] = [{"role": msg.get("role", "user"), "content": content_blocks}]
+
+    return litellm.responses(
+        cache=cache,
+        num_retries=num_retries,
+        retry_strategy="exponential_backoff_retry",
+        **request,
+    )
+
+
+async def alitellm_responses_completion(request: Dict[str, Any], num_retries: int, cache: Optional[Dict[str, Any]] = None):
+    cache = cache or {"no-cache": True, "no-store": True}
+    if "messages" in request:
+        content_blocks = []
+        for msg in request.pop("messages"):
+            c = msg.get("content")
+            if isinstance(c, str):
+                content_blocks.append({"type": "input_text", "text": c})
+            elif isinstance(c, list):
+                content_blocks.extend(c)
+        request["input"] = [{"role": msg.get("role", "user"), "content": content_blocks}]
+    return await litellm.aresponses(
+        cache=cache,
+        num_retries=num_retries,
+        retry_strategy="exponential_backoff_retry",
+        **request,
+    )
\ No newline at end of file
diff --git a/dspy/utils/inspect_history.py b/dspy/utils/inspect_history.py
index 16baefd303..032c01f890 100644
--- a/dspy/utils/inspect_history.py
+++ b/dspy/utils/inspect_history.py
@@ -41,10 +41,26 @@ def pretty_print_history(history, n: int = 1):
                             else:
                                 image_str = f"<image_url: {c['image_url']['url']}>"
                             print(_blue(image_str.strip()))
+                        elif c["type"] == "input_audio":
+                            audio_format = c["input_audio"]["format"]
+                            len_audio = len(c["input_audio"]["data"])
+                            audio_str = f"