
[WIP] support for OpenAI Responses API and image gen #8331


Open · wants to merge 1 commit into base: main
45 changes: 38 additions & 7 deletions dspy/adapters/base.py
@@ -1,10 +1,13 @@
import logging
from typing import TYPE_CHECKING, Any, Optional, Type

from dspy.adapters.types import History
from dspy.adapters.types import BaseType, History
from dspy.adapters.types.base_type import split_message_content_for_custom_types
from dspy.signatures.signature import Signature
from dspy.utils.callback import BaseCallback, with_callbacks

logger = logging.getLogger(__name__)

if TYPE_CHECKING:
from dspy.clients.lm import LM

@@ -22,20 +25,41 @@ def __init_subclass__(cls, **kwargs) -> None:

def _call_post_process(self, outputs: list[dict[str, Any]], signature: Type[Signature]) -> list[dict[str, Any]]:
values = []
output_field_name, field_info = next(iter(signature.output_fields.items()))
annotation = field_info.annotation

for output in outputs:
output_logprobs = None

if isinstance(output, dict):
output, output_logprobs = output["text"], output["logprobs"]

value = self.parse(signature, output)
if "text" in output:
output, output_logprobs = output["text"], output.get("logprobs")
elif "content" in output:
output_content = output["content"]
if isinstance(output_content, list) and output_content and "text" in output_content[0]:
output = output_content[0]["text"]
else:
output = ""
output_logprobs = None
else:
output = str(output)
output_logprobs = None
try:
if issubclass(annotation, BaseType):
try:
parsed = annotation.parse(output)
value = {output_field_name: parsed}
except Exception as e:
Review comment (Collaborator Author):
this is related to the implicit warning that certain custom fields must be the only OutputField (aka can't have 2 Image OutputFields or 1 Image, 1 ToolCall OutputField) but maybe we can express this better? @chenmoneygithub @TomeHirata
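A minimal sketch of the constraint being discussed (illustrative only, not part of this diff): the post-processing above takes only the first output field's annotation and parses the entire raw output into it, so a signature with a single custom-typed OutputField works, while two custom OutputFields cannot both be populated.

import dspy

# Supported by this branch: exactly one custom-typed output field.
class GenerateImage(dspy.Signature):
    """Generate an image from a text prompt."""
    prompt: str = dspy.InputField()
    image: dspy.Image = dspy.OutputField()

# The case the comment warns about: two custom-typed output fields.
# Only the first annotation is parsed, so `second` could not be filled.
class GenerateTwoImages(dspy.Signature):
    """Generate two images from a text prompt."""
    prompt: str = dspy.InputField()
    first: dspy.Image = dspy.OutputField()
    second: dspy.Image = dspy.OutputField()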

logger.warning(f"Output is not of expected annotation field '{output_field_name}': {e}")
continue
else:
value = self.parse(signature, output)
except TypeError:
value = self.parse(signature, output)

if output_logprobs is not None:
value["logprobs"] = output_logprobs

values.append(value)

return values

def __call__(
@@ -48,6 +72,13 @@ def __call__(
) -> list[dict[str, Any]]:
inputs = self.format(signature, demos, inputs)

if getattr(lm, "model_type", None) == "responses":
for msg in inputs:
if msg["role"] == "user" and isinstance(msg["content"], list):
for block in msg["content"]:
if block.get("type") == "text":
block["type"] = "input_text"

outputs = lm(messages=inputs, **lm_kwargs)
return self._call_post_process(outputs, signature)

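For context, a before/after sketch (not part of the diff; the URL is made up) of the content-block rewrite that __call__ now applies when the LM's model_type is "responses":

# A user message content list as formatted by the adapter:
before = [
    {"type": "text", "text": "Describe this picture."},
    {"type": "input_image", "image_url": "https://example.com/cat.jpg"},
]

# After the in-place rewrite above, the text block uses the Responses API name:
after = [
    {"type": "input_text", "text": "Describe this picture."},
    {"type": "input_image", "image_url": "https://example.com/cat.jpg"},
]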
4 changes: 4 additions & 0 deletions dspy/adapters/types/base_type.py
@@ -28,6 +28,10 @@ def format(self) -> list[dict[str, Any]]:

def format(self) -> list[dict[str, Any]]:
raise NotImplementedError

@classmethod
def parse(cls, raw: Any) -> "BaseType":
return cls(**raw)

@pydantic.model_serializer()
def serialize_model(self):
23 changes: 22 additions & 1 deletion dspy/adapters/types/image.py
@@ -33,8 +33,29 @@ def format(self) -> Union[list[dict[str, Any]], str]:
image_url = encode_image(self.url)
except Exception as e:
raise ValueError(f"Failed to format image for DSPy: {e}")
return [{"type": "image_url", "image_url": {"url": image_url}}]
if isinstance(image_url, str):
return [{"type": "input_image", "image_url": image_url}]
else:
return [{"type": "input_image", "image_url": {"url": image_url}}]

@classmethod
def parse(cls, raw: Any) -> "Image":
if isinstance(raw, dict):
if "result" in raw:
b64 = raw["result"]
else:
raise TypeError("Input for parsing is missing 'result'")
elif hasattr(raw, "result"):
b64 = getattr(raw, "result")
elif hasattr(raw, "data"):
b64 = raw.data[0].b64_json
elif isinstance(raw, str):
b64 = raw
else:
raise TypeError(f"Unsupported type {type(raw)} for Image.parse")
uri = b64 if b64.startswith("data:") else f"data:image/png;base64,{b64}"
return cls(url=uri)

@pydantic.model_validator(mode="before")
@classmethod
def validate_input(cls, values):
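A rough usage sketch of the new Image.parse hook (the base64 payload is made up): it normalizes the raw shapes the Responses image-generation path can return into a data URI.

from dspy.adapters.types.image import Image

fake_b64 = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAAB"  # made-up base64 payload

img1 = Image.parse({"result": fake_b64})  # dict carrying a "result" key
img2 = Image.parse(fake_b64)              # bare base64 string
# Image.parse(123) would raise TypeError: unsupported type

print(img1.url[:22])  # expected: "data:image/png;base64," (img2 gets the same prefix)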
17 changes: 15 additions & 2 deletions dspy/clients/base_lm.py
@@ -50,7 +50,20 @@ def __init__(self, model, model_type="chat", temperature=0.0, max_tokens=1000, c

def _process_lm_response(self, response, prompt, messages, **kwargs):
merged_kwargs = {**self.kwargs, **kwargs}
if merged_kwargs.get("logprobs"):
if hasattr(response, "output") or (isinstance(response, dict) and "output" in response):
output_items = response.output if hasattr(response, "output") else response["output"]
output_items = list(output_items)
if any(not (hasattr(item, "content") or (isinstance(item, dict) and "content" in item)) for item in output_items):
outputs = output_items
else:
outputs = []
for item in output_items:
content_list = item.get("content", []) if isinstance(item, dict) else getattr(item, "content", [])
for c in content_list:
if (isinstance(c, dict) and c.get("type") == "output_text") or (hasattr(c, "type") and getattr(c, "type", None) == "output_text"):
text_val = c.get("text") if isinstance(c, dict) else getattr(c, "text", "")
outputs.append(text_val)
elif merged_kwargs.get("logprobs"):
outputs = [
{
"text": c.message.content if hasattr(c, "message") else c["text"],
@@ -133,7 +146,7 @@ def copy(self, **kwargs):
return new_instance

def inspect_history(self, n: int = 1):
return inspect_history(self.history, n)
return pretty_print_history(self.history, n)

def update_global_history(self, entry):
if settings.disable_history:
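For readers unfamiliar with the Responses API payload, an invented sketch of the shape the new branch in _process_lm_response walks: a top-level output list whose message items carry a content list with output_text parts.

# Invented example, shaped like a Responses API result:
response = {
    "output": [
        {
            "type": "message",
            "role": "assistant",
            "content": [
                {"type": "output_text", "text": "[[ ## answer ## ]]\n42"},
            ],
        },
    ],
}
# The branch above flattens this to outputs == ["[[ ## answer ## ]]\n42"].
# Output items without a "content" field (e.g. image generation results)
# are passed through unchanged instead.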
110 changes: 90 additions & 20 deletions dspy/clients/lm.py
@@ -29,7 +29,7 @@ class LM(BaseLM):
def __init__(
self,
model: str,
model_type: Literal["chat", "text"] = "chat",
model_type: Literal["chat", "text", "responses"] = "chat",
temperature: float = 0.0,
max_tokens: int = 4000,
cache: bool = True,
@@ -120,24 +120,41 @@ def forward(self, prompt=None, messages=None, **kwargs):

messages = messages or [{"role": "user", "content": prompt}]
kwargs = {**self.kwargs, **kwargs}

completion = litellm_completion if self.model_type == "chat" else litellm_text_completion


if self.model_type == "chat":
completion = litellm_completion
elif self.model_type == "text":
completion = litellm_text_completion
elif self.model_type == "responses":
completion = litellm_responses_completion

completion, litellm_cache_args = self._get_cached_completion_fn(completion, cache, enable_memory_cache)

results = completion(
request=dict(model=self.model, messages=messages, **kwargs),
num_retries=self.num_retries,
cache=litellm_cache_args,
)
if self.model_type != "responses":
if any(c.finish_reason == "length" for c in results["choices"]):
logger.warning(
f"LM response was truncated due to exceeding max_tokens={self.kwargs['max_tokens']}. "
"You can inspect the latest LM interactions with `dspy.inspect_history()`. "
"To avoid truncation, consider passing a larger max_tokens when setting up dspy.LM. "
f"You may also consider increasing the temperature (currently {self.kwargs['temperature']}) "
" if the reason for truncation is repetition."
)
else:
if results.get("truncation") == "enabled":
logger.warning(
f"LM response was truncated due to exceeding max_tokens={self.kwargs['max_tokens']}. "
"You can inspect the latest LM interactions with `dspy.inspect_history()`. "
"To avoid truncation, consider passing a larger max_tokens when setting up dspy.LM. "
f"You may also consider increasing the temperature (currently {self.kwargs['temperature']}) "
" if the reason for truncation is repetition."
)

if any(c.finish_reason == "length" for c in results["choices"]):
logger.warning(
f"LM response was truncated due to exceeding max_tokens={self.kwargs['max_tokens']}. "
"You can inspect the latest LM interactions with `dspy.inspect_history()`. "
"To avoid truncation, consider passing a larger max_tokens when setting up dspy.LM. "
f"You may also consider increasing the temperature (currently {self.kwargs['temperature']}) "
" if the reason for truncation is repetition."
)

if not getattr(results, "cache_hit", False) and dspy.settings.usage_tracker and hasattr(results, "usage"):
settings.usage_tracker.add_usage(self.model, dict(results.usage))
@@ -151,7 +168,12 @@ async def aforward(self, prompt=None, messages=None, **kwargs):
messages = messages or [{"role": "user", "content": prompt}]
kwargs = {**self.kwargs, **kwargs}

completion = alitellm_completion if self.model_type == "chat" else alitellm_text_completion
if self.model_type == "chat":
completion = alitellm_completion
elif self.model_type == "text":
completion = alitellm_text_completion
elif self.model_type == "responses":
completion = alitellm_responses_completion
completion, litellm_cache_args = self._get_cached_completion_fn(completion, cache, enable_memory_cache)

results = await completion(
@@ -160,14 +182,24 @@ async def aforward(self, prompt=None, messages=None, **kwargs):
cache=litellm_cache_args,
)

if any(c.finish_reason == "length" for c in results["choices"]):
logger.warning(
f"LM response was truncated due to exceeding max_tokens={self.kwargs['max_tokens']}. "
"You can inspect the latest LM interactions with `dspy.inspect_history()`. "
"To avoid truncation, consider passing a larger max_tokens when setting up dspy.LM. "
f"You may also consider increasing the temperature (currently {self.kwargs['temperature']}) "
" if the reason for truncation is repetition."
)
if self.model_type != "responses":
if any(c.finish_reason == "length" for c in results["choices"]):
logger.warning(
f"LM response was truncated due to exceeding max_tokens={self.kwargs['max_tokens']}. "
"You can inspect the latest LM interactions with `dspy.inspect_history()`. "
"To avoid truncation, consider passing a larger max_tokens when setting up dspy.LM. "
f"You may also consider increasing the temperature (currently {self.kwargs['temperature']}) "
" if the reason for truncation is repetition."
)
else:
if results.get("truncation") == "enabled":
logger.warning(
f"LM response was truncated due to exceeding max_tokens={self.kwargs['max_tokens']}. "
"You can inspect the latest LM interactions with `dspy.inspect_history()`. "
"To avoid truncation, consider passing a larger max_tokens when setting up dspy.LM. "
f"You may also consider increasing the temperature (currently {self.kwargs['temperature']}) "
" if the reason for truncation is repetition."
)

if not getattr(results, "cache_hit", False) and dspy.settings.usage_tracker and hasattr(results, "usage"):
settings.usage_tracker.add_usage(self.model, dict(results.usage))
@@ -373,3 +405,41 @@ async def alitellm_text_completion(request: Dict[str, Any], num_retries: int, ca
retry_strategy="exponential_backoff_retry",
**request,
)

def litellm_responses_completion(request: Dict[str, Any], num_retries: int, cache: Optional[Dict[str, Any]] = None):
cache = cache or {"no-cache": True, "no-store": True}
if "messages" in request:
content_blocks = []
for msg in request.pop("messages"):
c = msg.get("content")
if isinstance(c, str):
content_blocks.append({"type": "input_text", "text": c})
elif isinstance(c, list):
content_blocks.extend(c)
request["input"] = [{"role": msg.get("role", "user"), "content": content_blocks}]

return litellm.responses(
cache=cache,
num_retries=num_retries,
retry_strategy="exponential_backoff_retry",
**request,
)


async def alitellm_responses_completion(request: Dict[str, Any], num_retries: int, cache: Optional[Dict[str, Any]] = None):
cache = cache or {"no-cache": True, "no-store": True}
if "messages" in request:
content_blocks = []
for msg in request.pop("messages"):
c = msg.get("content")
if isinstance(c, str):
content_blocks.append({"type": "input_text", "text": c})
elif isinstance(c, list):
content_blocks.extend(c)
request["input"] = [{"role": msg.get("role", "user"), "content": content_blocks}]
return await litellm.aresponses(
cache=cache,
num_retries=num_retries,
retry_strategy="exponential_backoff_retry",
**request,
)
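A minimal end-to-end sketch of how the new model_type is meant to be used (the model name is a placeholder and this snippet is not exercised by the PR's tests):

import dspy

# Route requests through litellm.responses() via the code path added above.
lm = dspy.LM("openai/gpt-4.1-mini", model_type="responses", max_tokens=4000)
dspy.configure(lm=lm)

class Describe(dspy.Signature):
    """Describe the image in one sentence."""
    image: dspy.Image = dspy.InputField()
    description: str = dspy.OutputField()

result = dspy.Predict(Describe)(image=dspy.Image(url="https://example.com/cat.jpg"))
print(result.description)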
18 changes: 17 additions & 1 deletion dspy/utils/inspect_history.py
@@ -41,10 +41,26 @@ def pretty_print_history(history, n: int = 1):
else:
image_str = f"<image_url: {c['image_url']['url']}>"
print(_blue(image_str.strip()))
elif c["type"] == "input_audio":
Review comment (Collaborator Author):
this got lost in an earlier PR, adding back in

audio_format = c["input_audio"]["format"]
len_audio = len(c["input_audio"]["data"])
audio_str = f"<audio format='{audio_format}' base64-encoded, length={len_audio}>"
print(_blue(audio_str.strip()))
print("\n")

print(_red("Response:"))
print(_green(outputs[0].strip()))
out = outputs[0]
if isinstance(out, str):
print(_green(out.strip()))
elif hasattr(out, "result") and isinstance(out.result, str):
b64 = out.result
if b64.startswith("data:"):
head, b64data = b64.split(",", 1)
print(_green(f"<Image output: {head}base64,<IMAGE_BASE_64_ENCODED({len(b64data)})>>"))
else:
print(_green(f"<Image output: base64,<IMAGE_BASE_64_ENCODED({len(b64)})>>"))
else:
print(_green(f"<Non-string output: {repr(out)}>"))

if len(outputs) > 1:
choices_text = f" \t (and {len(outputs) - 1} other completions)"
27 changes: 16 additions & 11 deletions tests/adapters/test_chat_adapter.py
@@ -219,7 +219,7 @@ class MySignature(dspy.Signature):
assert user_message_content[2]["type"] == "text"

# Assert that the image is formatted correctly
expected_image_content = {"type": "image_url", "image_url": {"url": "https://example.com/image.jpg"}}
expected_image_content = {"type": "input_image", "image_url": "https://example.com/image.jpg"}
Review comment (Collaborator Author):
all test related changes are for the refactoring of dspy.Image to handle the Responses Image format (the API strictly requires input_image)
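Concretely, the format shift these tests now assert (using the test's fixture URL):

# Chat Completions-style block the adapter used to emit:
old_block = {"type": "image_url", "image_url": {"url": "https://example.com/image.jpg"}}

# Responses API-style block it emits after this refactor:
new_block = {"type": "input_image", "image_url": "https://example.com/image.jpg"}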

assert expected_image_content in user_message_content


@@ -245,9 +245,10 @@ class MySignature(dspy.Signature):
# 1 system message, 2 few shot examples (1 user and assistant message for each example), 1 user message
assert len(messages) == 6

assert {"type": "image_url", "image_url": {"url": "https://example.com/image1.jpg"}} in messages[1]["content"]
assert {"type": "image_url", "image_url": {"url": "https://example.com/image2.jpg"}} in messages[3]["content"]
assert {"type": "image_url", "image_url": {"url": "https://example.com/image3.jpg"}} in messages[5]["content"]
assert {"type": "input_image", "image_url": "https://example.com/image1.jpg"} in messages[1]["content"]
assert {"type": "input_image", "image_url": "https://example.com/image2.jpg"} in messages[3]["content"]
assert {"type": "input_image", "image_url": "https://example.com/image3.jpg"} in messages[5]["content"]



def test_chat_adapter_formats_image_with_nested_images():
@@ -268,9 +269,11 @@ class MySignature(dspy.Signature):
adapter = dspy.ChatAdapter()
messages = adapter.format(MySignature, [], {"image": image_wrapper})

expected_image1_content = {"type": "image_url", "image_url": {"url": "https://example.com/image1.jpg"}}
expected_image2_content = {"type": "image_url", "image_url": {"url": "https://example.com/image2.jpg"}}
expected_image3_content = {"type": "image_url", "image_url": {"url": "https://example.com/image3.jpg"}}
expected_image1_content = {"type": "input_image", "image_url": "https://example.com/image1.jpg"}
expected_image2_content = {"type": "input_image", "image_url": "https://example.com/image2.jpg"}
expected_image3_content = {"type": "input_image", "image_url": "https://example.com/image3.jpg"}



assert expected_image1_content in messages[1]["content"]
assert expected_image2_content in messages[1]["content"]
@@ -305,12 +308,14 @@ class MySignature(dspy.Signature):
assert len(messages) == 4

# Image information in the few-shot example's user message
expected_image1_content = {"type": "image_url", "image_url": {"url": "https://example.com/image1.jpg"}}
expected_image2_content = {"type": "image_url", "image_url": {"url": "https://example.com/image2.jpg"}}
expected_image3_content = {"type": "image_url", "image_url": {"url": "https://example.com/image3.jpg"}}
expected_image1_content = {"type": "input_image", "image_url": "https://example.com/image1.jpg"}
expected_image2_content = {"type": "input_image", "image_url": "https://example.com/image2.jpg"}
expected_image3_content = {"type": "input_image", "image_url": "https://example.com/image3.jpg"}


assert expected_image1_content in messages[1]["content"]
assert expected_image2_content in messages[1]["content"]
assert expected_image3_content in messages[1]["content"]

# The query image is formatted in the last user message
assert {"type": "image_url", "image_url": {"url": "https://example.com/image4.jpg"}} in messages[-1]["content"]
assert {"type": "input_image", "image_url": "https://example.com/image4.jpg"} in messages[-1]["content"]