Commit e1d2550

feat: support structured outputs in MetaLlamaChatGenerator (#2410)
* Structured outputs support for MetaLlama
* Add example
* Updates
* Update chat_generator.py
1 parent 069df42 commit e1d2550

File tree: 3 files changed, +169 -5 lines changed

New file (example script for structured outputs)

Lines changed: 35 additions & 0 deletions
@@ -0,0 +1,35 @@
+# SPDX-FileCopyrightText: 2024-present deepset GmbH <[email protected]>
+#
+# SPDX-License-Identifier: Apache-2.0
+
+
+# This example demonstrates how to use the MetaLlamaChatGenerator component
+# with structured outputs.
+# To run this example, you will need to
+# set `LLAMA_API_KEY` environment variable
+
+from haystack.dataclasses import ChatMessage
+from pydantic import BaseModel
+
+from haystack_integrations.components.generators.meta_llama import MetaLlamaChatGenerator
+
+
+class NobelPrizeInfo(BaseModel):
+    recipient_name: str
+    award_year: int
+    category: str
+    achievement_description: str
+    nationality: str
+
+
+chat_messages = [
+    ChatMessage.from_user(
+        "In 2021, American scientist David Julius received the Nobel Prize in"
+        " Physiology or Medicine for his groundbreaking discoveries on how the human body"
+        " senses temperature and touch."
+    )
+]
+component = MetaLlamaChatGenerator(generation_kwargs={"response_format": NobelPrizeInfo})
+results = component.run(chat_messages)
+
+# print(results)
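Because `response_format` constrains the reply to the `NobelPrizeInfo` schema, the reply text is JSON that the model class can validate. A minimal follow-on sketch, not part of the commit, assuming the example above has already run and `results` is in scope:

# Parse the JSON reply back into the Pydantic model (Pydantic v2 API).
reply_text = results["replies"][0].text
prize = NobelPrizeInfo.model_validate_json(reply_text)
print(prize.recipient_name, prize.award_year)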

integrations/meta_llama/src/haystack_integrations/components/generators/meta_llama/chat/chat_generator.py

Lines changed: 27 additions & 3 deletions
@@ -8,9 +8,11 @@
 from haystack import component, default_to_dict, logging
 from haystack.components.generators.chat import OpenAIChatGenerator
 from haystack.dataclasses import ChatMessage, StreamingCallbackT
-from haystack.tools import ToolsType
+from haystack.tools import ToolsType, serialize_tools_or_toolset
 from haystack.utils import serialize_callable
 from haystack.utils.auth import Secret
+from openai.lib._pydantic import to_strict_json_schema
+from pydantic import BaseModel

 logger = logging.getLogger(__name__)

@@ -91,6 +93,12 @@ def __init__(
             events as they become available, with the stream terminated by a data: [DONE] message.
         - `safe_prompt`: Whether to inject a safety prompt before all conversations.
         - `random_seed`: The seed to use for random sampling.
+        - `response_format`: A JSON schema or a Pydantic model that enforces the structure of the model's response.
+          If provided, the output will always be validated against this
+          format (unless the model returns a tool call).
+          For details, see the [OpenAI Structured Outputs documentation](https://platform.openai.com/docs/guides/structured-outputs).
+          For structured outputs with streaming, the `response_format` must be a JSON
+          schema and not a Pydantic model.
     :param tools:
         A list of Tool and/or Toolset objects, or a single Toolset for which the model can prepare calls.
         Each tool should have a unique name.
@@ -134,13 +142,29 @@ def to_dict(self) -> Dict[str, Any]:
             The serialized component as a dictionary.
         """
         callback_name = serialize_callable(self.streaming_callback) if self.streaming_callback else None
+        generation_kwargs = self.generation_kwargs.copy()
+        response_format = generation_kwargs.get("response_format")
+
+        # If the response format is a Pydantic model, it's converted to openai's json schema format
+        # If it's already a json schema, it's left as is
+        if response_format and isinstance(response_format, type) and issubclass(response_format, BaseModel):
+            json_schema = {
+                "type": "json_schema",
+                "json_schema": {
+                    "name": response_format.__name__,
+                    "strict": True,
+                    "schema": to_strict_json_schema(response_format),
+                },
+            }
+
+            generation_kwargs["response_format"] = json_schema

         return default_to_dict(
             self,
             model=self.model,
             streaming_callback=callback_name,
             api_base_url=self.api_base_url,
-            generation_kwargs=self.generation_kwargs,
+            generation_kwargs=generation_kwargs,
             api_key=self.api_key.to_dict(),
-            tools=[tool.to_dict() for tool in self.tools] if self.tools else None,
+            tools=serialize_tools_or_toolset(self.tools),
         )
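For reference, the Pydantic-to-schema conversion that `to_dict` now performs can be reproduced standalone with the same `to_strict_json_schema` helper. A minimal sketch; the `CapitalCity` model here is illustrative, not part of the commit:

from openai.lib._pydantic import to_strict_json_schema
from pydantic import BaseModel


class CapitalCity(BaseModel):  # illustrative model
    city: str
    country: str


# Mirrors the dict built inside to_dict() when response_format is a Pydantic model.
response_format = {
    "type": "json_schema",
    "json_schema": {
        "name": CapitalCity.__name__,
        "strict": True,
        "schema": to_strict_json_schema(CapitalCity),
    },
}
print(response_format)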

integrations/meta_llama/tests/test_llama_chat_generator.py

Lines changed: 107 additions & 2 deletions
@@ -1,5 +1,6 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates

+import json
 import os
 from datetime import datetime
 from unittest.mock import patch
@@ -15,6 +16,7 @@
 from openai import OpenAIError
 from openai.types.chat import ChatCompletion, ChatCompletionMessage
 from openai.types.chat.chat_completion import Choice
+from pydantic import BaseModel

 from haystack_integrations.components.generators.meta_llama.chat.chat_generator import (
     MetaLlamaChatGenerator,
@@ -158,12 +160,44 @@ def test_to_dict_default(self, monkeypatch):

     def test_to_dict_with_parameters(self, monkeypatch):
         monkeypatch.setenv("ENV_VAR", "test-api-key")
+
+        class NobelPrizeInfo(BaseModel):
+            recipient_name: str
+            award_year: int
+
+        schema = {
+            "json_schema": {
+                "name": "NobelPrizeInfo",
+                "schema": {
+                    "additionalProperties": False,
+                    "properties": {
+                        "award_year": {
+                            "title": "Award Year",
+                            "type": "integer",
+                        },
+                        "recipient_name": {
+                            "title": "Recipient Name",
+                            "type": "string",
+                        },
+                    },
+                    "required": [
+                        "recipient_name",
+                        "award_year",
+                    ],
+                    "title": "NobelPrizeInfo",
+                    "type": "object",
+                },
+                "strict": True,
+            },
+            "type": "json_schema",
+        }
+
         component = MetaLlamaChatGenerator(
             api_key=Secret.from_env_var("ENV_VAR"),
             model="Llama-4-Scout-17B-16E-Instruct-FP8",
             streaming_callback=print_streaming_chunk,
             api_base_url="test-base-url",
-            generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"},
+            generation_kwargs={"max_tokens": 10, "some_test_param": "test-params", "response_format": NobelPrizeInfo},
         )
         data = component.to_dict()

@@ -177,7 +211,7 @@ def test_to_dict_with_parameters(self, monkeypatch):
             "model": "Llama-4-Scout-17B-16E-Instruct-FP8",
             "api_base_url": "test-base-url",
             "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk",
-            "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"},
+            "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params", "response_format": schema},
         }

         for key, value in expected_params.items():
@@ -317,6 +351,77 @@ def __call__(self, chunk: StreamingChunk) -> None:
         assert callback.counter > 1
         assert "Paris" in callback.responses

+    @pytest.mark.skipif(
+        not os.environ.get("LLAMA_API_KEY", None),
+        reason="Export an env var called LLAMA_API_KEY containing the Llama API key to run this test.",
+    )
+    @pytest.mark.integration
+    def test_live_run_response_format(self):
+        class NobelPrizeInfo(BaseModel):
+            recipient_name: str
+            award_year: int
+            category: str
+            achievement_description: str
+            nationality: str
+
+        chat_messages = [
+            ChatMessage.from_user(
+                "In 2021, American scientist David Julius received the Nobel Prize in"
+                " Physiology or Medicine for his groundbreaking discoveries on how the human body"
+                " senses temperature and touch."
+            )
+        ]
+        component = MetaLlamaChatGenerator(generation_kwargs={"response_format": NobelPrizeInfo})
+        results = component.run(chat_messages)
+        assert isinstance(results, dict)
+        assert "replies" in results
+        assert isinstance(results["replies"], list)
+        assert len(results["replies"]) == 1
+        assert isinstance(results["replies"][0], ChatMessage)
+        message = results["replies"][0]
+        assert isinstance(message.text, str)
+        msg = json.loads(message.text)
+        assert msg["recipient_name"] == "David Julius"
+        assert msg["award_year"] == 2021
+        assert "category" in msg
+        assert "achievement_description" in msg
+        assert msg["nationality"] == "American"
+
+    @pytest.mark.skipif(
+        not os.environ.get("LLAMA_API_KEY", None),
+        reason="Export an env var called LLAMA_API_KEY containing the Llama API key to run this test.",
+    )
+    @pytest.mark.integration
+    def test_live_run_with_response_format_json_schema(self):
+        response_schema = {
+            "type": "json_schema",
+            "json_schema": {
+                "name": "CapitalCity",
+                "strict": True,
+                "schema": {
+                    "title": "CapitalCity",
+                    "type": "object",
+                    "properties": {
+                        "city": {"title": "City", "type": "string"},
+                        "country": {"title": "Country", "type": "string"},
+                    },
+                    "required": ["city", "country"],
+                    "additionalProperties": False,
+                },
+            },
+        }
+
+        chat_messages = [ChatMessage.from_user("What's the capital of France?")]
+        comp = MetaLlamaChatGenerator(generation_kwargs={"response_format": response_schema})
+        results = comp.run(chat_messages)
+        assert len(results["replies"]) == 1
+        message: ChatMessage = results["replies"][0]
+        msg = json.loads(message.text)
+        assert "Paris" in msg["city"]
+        assert isinstance(msg["country"], str)
+        assert "France" in msg["country"]
+        assert message.meta["finish_reason"] == "stop"
+
     @pytest.mark.skipif(
         not os.environ.get("LLAMA_API_KEY", None),
         reason="Export an env var called LLAMA_API_KEY containing the OpenAI API key to run this test.",
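The docstring added in chat_generator.py notes that streaming requires the JSON-schema form of `response_format` rather than a Pydantic model. A minimal sketch of that combination, not part of the commit (requires `LLAMA_API_KEY`; reuses the `CapitalCity` schema shape from the test above):

from haystack.components.generators.utils import print_streaming_chunk
from haystack.dataclasses import ChatMessage

from haystack_integrations.components.generators.meta_llama import MetaLlamaChatGenerator

# JSON-schema response_format, as required when streaming is enabled.
response_schema = {
    "type": "json_schema",
    "json_schema": {
        "name": "CapitalCity",
        "strict": True,
        "schema": {
            "type": "object",
            "properties": {
                "city": {"type": "string"},
                "country": {"type": "string"},
            },
            "required": ["city", "country"],
            "additionalProperties": False,
        },
    },
}

generator = MetaLlamaChatGenerator(
    streaming_callback=print_streaming_chunk,  # streams chunks to stdout as they arrive
    generation_kwargs={"response_format": response_schema},
)
results = generator.run([ChatMessage.from_user("What's the capital of France?")])
print(results["replies"][0].text)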
