# Copyright (c) Meta Platforms, Inc. and affiliates

+ import json
import os
from datetime import datetime
from unittest.mock import patch

from openai import OpenAIError
from openai.types.chat import ChatCompletion, ChatCompletionMessage
from openai.types.chat.chat_completion import Choice
+ from pydantic import BaseModel

from haystack_integrations.components.generators.meta_llama.chat.chat_generator import (
    MetaLlamaChatGenerator,
@@ -158,12 +160,44 @@ def test_to_dict_default(self, monkeypatch):

    def test_to_dict_with_parameters(self, monkeypatch):
        monkeypatch.setenv("ENV_VAR", "test-api-key")
+
+         class NobelPrizeInfo(BaseModel):
+             recipient_name: str
+             award_year: int
+
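+         # to_dict() is expected to serialize the Pydantic response_format class
+         # into the OpenAI-style JSON schema dict below.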
+         schema = {
+             "json_schema": {
+                 "name": "NobelPrizeInfo",
+                 "schema": {
+                     "additionalProperties": False,
+                     "properties": {
+                         "award_year": {
+                             "title": "Award Year",
+                             "type": "integer",
+                         },
+                         "recipient_name": {
+                             "title": "Recipient Name",
+                             "type": "string",
+                         },
+                     },
+                     "required": [
+                         "recipient_name",
+                         "award_year",
+                     ],
+                     "title": "NobelPrizeInfo",
+                     "type": "object",
+                 },
+                 "strict": True,
+             },
+             "type": "json_schema",
+         }
+
        component = MetaLlamaChatGenerator(
            api_key=Secret.from_env_var("ENV_VAR"),
            model="Llama-4-Scout-17B-16E-Instruct-FP8",
            streaming_callback=print_streaming_chunk,
            api_base_url="test-base-url",
-             generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"},
+             generation_kwargs={"max_tokens": 10, "some_test_param": "test-params", "response_format": NobelPrizeInfo},
        )
        data = component.to_dict()
@@ -177,7 +211,7 @@ def test_to_dict_with_parameters(self, monkeypatch):
            "model": "Llama-4-Scout-17B-16E-Instruct-FP8",
            "api_base_url": "test-base-url",
            "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk",
-             "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"},
+             "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params", "response_format": schema},
        }

        for key, value in expected_params.items():
@@ -317,6 +351,77 @@ def __call__(self, chunk: StreamingChunk) -> None:
        assert callback.counter > 1
        assert "Paris" in callback.responses

+     @pytest.mark.skipif(
+         not os.environ.get("LLAMA_API_KEY", None),
+         reason="Export an env var called LLAMA_API_KEY containing the Llama API key to run this test.",
+     )
+     @pytest.mark.integration
+     def test_live_run_response_format(self):
+         class NobelPrizeInfo(BaseModel):
+             recipient_name: str
+             award_year: int
+             category: str
+             achievement_description: str
+             nationality: str
+
+         chat_messages = [
+             ChatMessage.from_user(
+                 "In 2021, American scientist David Julius received the Nobel Prize in"
+                 " Physiology or Medicine for his groundbreaking discoveries on how the human body"
+                 " senses temperature and touch."
+             )
+         ]
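+         # Passing the Pydantic model as response_format should make the reply
+         # valid JSON that matches the NobelPrizeInfo schema.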
+         component = MetaLlamaChatGenerator(generation_kwargs={"response_format": NobelPrizeInfo})
+         results = component.run(chat_messages)
+         assert isinstance(results, dict)
+         assert "replies" in results
+         assert isinstance(results["replies"], list)
+         assert len(results["replies"]) == 1
+         assert isinstance(results["replies"][0], ChatMessage)
+         message = results["replies"][0]
+         assert isinstance(message.text, str)
+         msg = json.loads(message.text)
+         assert msg["recipient_name"] == "David Julius"
+         assert msg["award_year"] == 2021
+         assert "category" in msg
+         assert "achievement_description" in msg
+         assert msg["nationality"] == "American"
+
+     @pytest.mark.skipif(
+         not os.environ.get("LLAMA_API_KEY", None),
+         reason="Export an env var called LLAMA_API_KEY containing the Llama API key to run this test.",
+     )
+     @pytest.mark.integration
+     def test_live_run_with_response_format_json_schema(self):
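+         # Same structured-output path, but with a hand-written OpenAI-style
+         # json_schema dict instead of a Pydantic model.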
+         response_schema = {
+             "type": "json_schema",
+             "json_schema": {
+                 "name": "CapitalCity",
+                 "strict": True,
+                 "schema": {
+                     "title": "CapitalCity",
+                     "type": "object",
+                     "properties": {
+                         "city": {"title": "City", "type": "string"},
+                         "country": {"title": "Country", "type": "string"},
+                     },
+                     "required": ["city", "country"],
+                     "additionalProperties": False,
+                 },
+             },
+         }
+
+         chat_messages = [ChatMessage.from_user("What's the capital of France?")]
+         comp = MetaLlamaChatGenerator(generation_kwargs={"response_format": response_schema})
+         results = comp.run(chat_messages)
+         assert len(results["replies"]) == 1
+         message: ChatMessage = results["replies"][0]
+         msg = json.loads(message.text)
+         assert "Paris" in msg["city"]
+         assert isinstance(msg["country"], str)
+         assert "France" in msg["country"]
+         assert message.meta["finish_reason"] == "stop"
+
    @pytest.mark.skipif(
        not os.environ.get("LLAMA_API_KEY", None),
        reason="Export an env var called LLAMA_API_KEY containing the Llama API key to run this test.",