From 4dc4a434316c0e35a6bb2226d1fc4560706ace24 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Thu, 6 Feb 2025 11:11:27 +0100 Subject: [PATCH 1/8] Add tools to HuggingFaceLocalChatGenerator --- .../generators/chat/hugging_face_local.py | 100 +++++++++++++++--- 1 file changed, 85 insertions(+), 15 deletions(-) diff --git a/haystack/components/generators/chat/hugging_face_local.py b/haystack/components/generators/chat/hugging_face_local.py index d5d05ae487..a42bcd39af 100644 --- a/haystack/components/generators/chat/hugging_face_local.py +++ b/haystack/components/generators/chat/hugging_face_local.py @@ -2,12 +2,15 @@ # # SPDX-License-Identifier: Apache-2.0 +import json +import re import sys from typing import Any, Callable, Dict, List, Literal, Optional, Union from haystack import component, default_from_dict, default_to_dict, logging -from haystack.dataclasses import ChatMessage, StreamingChunk +from haystack.dataclasses import ChatMessage, StreamingChunk, ToolCall from haystack.lazy_imports import LazyImport +from haystack.tools import Tool, _check_duplicate_tool_names from haystack.utils import ( ComponentDevice, Secret, @@ -33,6 +36,12 @@ PIPELINE_SUPPORTED_TASKS = ["text-generation", "text2text-generation"] +DEFAULT_TOOL_PATTERN = ( + r"(?:)?" + r'(?:\s*\{.*?"name"\s*:\s*"([^"]+)".*?"arguments"\s*:\s*(\{[^}]+\}).*?\}' + r'|\{.*?"function"\s*:\s*\{.*?"name"\s*:\s*"([^"]+)".*?"arguments"\s*:\s*(\{[^}]+\}).*?\})' +) + @component class HuggingFaceLocalChatGenerator: @@ -83,6 +92,8 @@ def __init__( # pylint: disable=too-many-positional-arguments huggingface_pipeline_kwargs: Optional[Dict[str, Any]] = None, stop_words: Optional[List[str]] = None, streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, + tools: Optional[List[Tool]] = None, + tool_pattern: Optional[Union[str, Callable[[str], Optional[List[ToolCall]]]]] = None, ): """ Initializes the HuggingFaceLocalChatGenerator component. @@ -121,9 +132,19 @@ def __init__( # pylint: disable=too-many-positional-arguments For some chat models, the output includes both the new text and the original prompt. In these cases, make sure your prompt has no stop words. :param streaming_callback: An optional callable for handling streaming responses. + :param tools: A list of tools for which the model can prepare calls. + :param tool_pattern: + A pattern or callable to parse tool calls from model output. + If a string, it will be used as a regex pattern to extract ToolCall object. + If a callable, it should take a string and return a ToolCall object or None. + If None, a default pattern will be used. """ torch_and_transformers_import.check() + if tools and streaming_callback is not None: + raise ValueError("Using tools and streaming at the same time is not supported. 
Please choose one.") + _check_duplicate_tool_names(tools) + huggingface_pipeline_kwargs = huggingface_pipeline_kwargs or {} generation_kwargs = generation_kwargs or {} @@ -167,11 +188,13 @@ def __init__( # pylint: disable=too-many-positional-arguments generation_kwargs["stop_sequences"] = generation_kwargs.get("stop_sequences", []) generation_kwargs["stop_sequences"].extend(stop_words or []) + self.tool_pattern = tool_pattern or DEFAULT_TOOL_PATTERN self.huggingface_pipeline_kwargs = huggingface_pipeline_kwargs self.generation_kwargs = generation_kwargs self.chat_template = chat_template self.streaming_callback = streaming_callback self.pipeline = None + self.tools = tools def _get_telemetry_data(self) -> Dict[str, Any]: """ @@ -238,6 +261,7 @@ def run( messages: List[ChatMessage], generation_kwargs: Optional[Dict[str, Any]] = None, streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, + tools: Optional[List[Tool]] = None, ): """ Invoke text generation inference based on the provided messages and generation parameters. @@ -245,12 +269,20 @@ def run( :param messages: A list of ChatMessage objects representing the input messages. :param generation_kwargs: Additional keyword arguments for text generation. :param streaming_callback: An optional callable for handling streaming responses. + :param tools: + A list of tools for which the model can prepare calls. If set, it will override the `tools` parameter + provided during initialization. :returns: A list containing the generated responses as ChatMessage instances. """ if self.pipeline is None: raise RuntimeError("The generation model has not been loaded. Please call warm_up() before running.") + tools = tools or self.tools + if tools and streaming_callback is not None: + raise ValueError("Using tools and streaming at the same time is not supported. 
Please choose one.") + _check_duplicate_tool_names(tools) + tokenizer = self.pipeline.tokenizer # Check and update generation parameters @@ -279,11 +311,14 @@ def run( # streamer parameter hooks into HF streaming, HFTokenStreamingHandler is an adapter to our streaming generation_kwargs["streamer"] = HFTokenStreamingHandler(tokenizer, streaming_callback, stop_words) + # convert messages to HF format hf_messages = [convert_message_to_hf_format(message) for message in messages] - - # Prepare the prompt for the model prepared_prompt = tokenizer.apply_chat_template( - hf_messages, tokenize=False, chat_template=self.chat_template, add_generation_prompt=True + hf_messages, + tokenize=False, + chat_template=self.chat_template, + add_generation_prompt=True, + tools=[tc.tool_spec for tc in tools] if tools else None, ) # Avoid some unnecessary warnings in the generation pipeline call @@ -299,11 +334,13 @@ def run( for stop_word in stop_words: replies = [reply.replace(stop_word, "").rstrip() for reply in replies] - # Create ChatMessage instances for each reply chat_messages = [ - self.create_message(reply, r_index, tokenizer, prepared_prompt, generation_kwargs) + self.create_message( + reply, r_index, tokenizer, prepared_prompt, generation_kwargs, parse_tool_calls=bool(tools) + ) for r_index, reply in enumerate(replies) ] + return {"replies": chat_messages} def create_message( # pylint: disable=too-many-positional-arguments @@ -313,6 +350,7 @@ def create_message( # pylint: disable=too-many-positional-arguments tokenizer: Union["PreTrainedTokenizer", "PreTrainedTokenizerFast"], prompt: str, generation_kwargs: Dict[str, Any], + parse_tool_calls: bool = False, ) -> ChatMessage: """ Create a ChatMessage instance from the provided text, populated with metadata. @@ -322,17 +360,23 @@ def create_message( # pylint: disable=too-many-positional-arguments :param tokenizer: The tokenizer used for generation. :param prompt: The prompt used for generation. :param generation_kwargs: The generation parameters. + :param parse_tool_calls: Whether to attempt parsing tool calls from the text. :returns: A ChatMessage instance. 
""" + completion_tokens = len(tokenizer.encode(text, add_special_tokens=False)) prompt_token_count = len(tokenizer.encode(prompt, add_special_tokens=False)) total_tokens = prompt_token_count + completion_tokens - # not the most sophisticated finish_reason detection, improve later to match - # https://platform.openai.com/docs/guides/text-generation/chat-completions-response-format - finish_reason = ( - "length" if completion_tokens >= generation_kwargs.get("max_new_tokens", sys.maxsize) else "stop" - ) + tool_calls = self._parse_tool_call(text) if parse_tool_calls else None + + # Determine finish reason based on context + if completion_tokens >= generation_kwargs.get("max_new_tokens", sys.maxsize): + finish_reason = "length" + elif tool_calls: + finish_reason = "tool_calls" + else: + finish_reason = "stop" meta = { "finish_reason": finish_reason, @@ -345,7 +389,7 @@ def create_message( # pylint: disable=too-many-positional-arguments }, } - return ChatMessage.from_assistant(text, meta=meta) + return ChatMessage.from_assistant(tool_calls=tool_calls, text=text, meta=meta) def _validate_stop_words(self, stop_words: Optional[List[str]]) -> Optional[List[str]]: """ @@ -362,6 +406,32 @@ def _validate_stop_words(self, stop_words: Optional[List[str]]) -> Optional[List ) return None - # deduplicate stop words - stop_words = list(set(stop_words or [])) - return stop_words + return list(set(stop_words or [])) + + def _parse_tool_call(self, text: str) -> Optional[List[ToolCall]]: + """ + Parse a tool call from model output text. + + :param text: The text to parse for tool calls. + :returns: A ToolCall object if a valid tool call is found, None otherwise. + """ + # if the tool pattern is a callable, call it with the text and return the result + if callable(self.tool_pattern): + return self.tool_pattern(text) + + # if the tool pattern is a regex pattern, search for it in the text + match = re.search(self.tool_pattern, text, re.DOTALL) + if not match: + return None + + # seem like most models are not producing tool ids, so we omit them + # and just use the tool name and arguments + name = match.group(1) or match.group(3) + args_str = match.group(2) or match.group(4) + + try: + arguments = json.loads(args_str) + return [ToolCall(tool_name=name, arguments=arguments)] + except json.JSONDecodeError: + logger.warning("Failed to parse tool call arguments: {args_str}", args_str=args_str) + return None From 0641206234a4268fe848e16891605762e88b86c6 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Thu, 6 Feb 2025 16:25:10 +0100 Subject: [PATCH 2/8] Add reno --- .../hf-local-chatgenerator-add-tooling-d8676dff2bdf0323.yaml | 4 ++++ 1 file changed, 4 insertions(+) create mode 100644 releasenotes/notes/hf-local-chatgenerator-add-tooling-d8676dff2bdf0323.yaml diff --git a/releasenotes/notes/hf-local-chatgenerator-add-tooling-d8676dff2bdf0323.yaml b/releasenotes/notes/hf-local-chatgenerator-add-tooling-d8676dff2bdf0323.yaml new file mode 100644 index 0000000000..b95e8532eb --- /dev/null +++ b/releasenotes/notes/hf-local-chatgenerator-add-tooling-d8676dff2bdf0323.yaml @@ -0,0 +1,4 @@ +--- +features: + - | + Adds tooling support to HuggingFaceLocalChatGenerator From 842c6f20f7560f15d5747bc6aa415e349429975d Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Thu, 6 Feb 2025 17:30:57 +0100 Subject: [PATCH 3/8] Fix types --- haystack/components/generators/chat/hugging_face_local.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/haystack/components/generators/chat/hugging_face_local.py 
b/haystack/components/generators/chat/hugging_face_local.py index a42bcd39af..5a8e67f6a1 100644 --- a/haystack/components/generators/chat/hugging_face_local.py +++ b/haystack/components/generators/chat/hugging_face_local.py @@ -164,9 +164,11 @@ def __init__( # pylint: disable=too-many-positional-arguments if "task" in huggingface_pipeline_kwargs: task = huggingface_pipeline_kwargs["task"] elif isinstance(huggingface_pipeline_kwargs["model"], str): - task = model_info( + pipeline_tag = model_info( huggingface_pipeline_kwargs["model"], token=huggingface_pipeline_kwargs["token"] ).pipeline_tag + # Ensure pipeline_tag is one of our supported tasks + task = pipeline_tag if pipeline_tag in PIPELINE_SUPPORTED_TASKS else None if task not in PIPELINE_SUPPORTED_TASKS: raise ValueError( From dcc2f34d5c78c95e52428649046649551a90c0d8 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Fri, 7 Feb 2025 09:32:38 +0100 Subject: [PATCH 4/8] Small post merge fix --- haystack/components/generators/chat/hugging_face_local.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/haystack/components/generators/chat/hugging_face_local.py b/haystack/components/generators/chat/hugging_face_local.py index 13194c3b02..7850eeb2bf 100644 --- a/haystack/components/generators/chat/hugging_face_local.py +++ b/haystack/components/generators/chat/hugging_face_local.py @@ -164,7 +164,7 @@ def __init__( # pylint: disable=too-many-positional-arguments if "task" in huggingface_pipeline_kwargs: task = huggingface_pipeline_kwargs["task"] elif isinstance(huggingface_pipeline_kwargs["model"], str): - pipeline_tag = model_info( + task = model_info( huggingface_pipeline_kwargs["model"], token=huggingface_pipeline_kwargs["token"] ).pipeline_tag # type: ignore[assignment] # we'll check below if task is in supported tasks From 7cbcbd677df5cbc528cd6340359f9002464cbc19 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Fri, 7 Feb 2025 09:52:34 +0100 Subject: [PATCH 5/8] Add unit tests --- .../generators/chat/hugging_face_local.py | 7 +- .../chat/test_hugging_face_local.py | 124 +++++++++++++++++- 2 files changed, 129 insertions(+), 2 deletions(-) diff --git a/haystack/components/generators/chat/hugging_face_local.py b/haystack/components/generators/chat/hugging_face_local.py index 7850eeb2bf..8ba6c9a982 100644 --- a/haystack/components/generators/chat/hugging_face_local.py +++ b/haystack/components/generators/chat/hugging_face_local.py @@ -420,7 +420,12 @@ def _parse_tool_call(self, text: str) -> Optional[List[ToolCall]]: return self.tool_pattern(text) # if the tool pattern is a regex pattern, search for it in the text - match = re.search(self.tool_pattern, text, re.DOTALL) + try: + match = re.search(self.tool_pattern, text, re.DOTALL) + except re.error: + logger.warning("Invalid regex pattern for tool parsing: {pattern}", pattern=self.tool_pattern) + return None + if not match: return None diff --git a/test/components/generators/chat/test_hugging_face_local.py b/test/components/generators/chat/test_hugging_face_local.py index 1f6b478370..0285cb3667 100644 --- a/test/components/generators/chat/test_hugging_face_local.py +++ b/test/components/generators/chat/test_hugging_face_local.py @@ -8,9 +8,10 @@ from transformers import PreTrainedTokenizer from haystack.components.generators.chat import HuggingFaceLocalChatGenerator -from haystack.dataclasses import ChatMessage, ChatRole +from haystack.dataclasses import ChatMessage, ChatRole, ToolCall from haystack.utils import ComponentDevice from haystack.utils.auth import 
Secret +from haystack.tools import Tool # used to test serialization of streaming_callback @@ -49,6 +50,18 @@ def mock_pipeline_tokenizer(): return mock_pipeline +@pytest.fixture +def tools(): + tool_parameters = {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]} + tool = Tool( + name="weather", + description="useful to determine the weather in a given location", + parameters=tool_parameters, + function=lambda x: x, + ) + return [tool] + + class TestHuggingFaceLocalChatGenerator: def test_initialize_with_valid_model_and_generation_parameters(self, model_info_mock): model = "HuggingFaceH4/zephyr-7b-alpha" @@ -307,3 +320,112 @@ def test_live_run(self, monkeypatch): assert "replies" in result assert isinstance(result["replies"][0], ChatMessage) assert "climate change" in result["replies"][0].text.lower() + + def test_init_fail_with_duplicate_tool_names(self, model_info_mock, tools): + duplicate_tools = [tools[0], tools[0]] + with pytest.raises(ValueError, match="Duplicate tool names found"): + HuggingFaceLocalChatGenerator(model="irrelevant", tools=duplicate_tools) + + def test_init_fail_with_tools_and_streaming(self, model_info_mock, tools): + with pytest.raises(ValueError, match="Using tools and streaming at the same time is not supported"): + HuggingFaceLocalChatGenerator( + model="irrelevant", tools=tools, streaming_callback=streaming_callback_handler + ) + + def test_run_with_tools(self, model_info_mock, tools): + generator = HuggingFaceLocalChatGenerator(model="meta-llama/Llama-2-13b-chat-hf", tools=tools) + + # Mock pipeline and tokenizer + mock_pipeline = Mock(return_value=[{"generated_text": '{"name": "weather", "arguments": {"city": "Paris"}}'}]) + mock_tokenizer = Mock(spec=PreTrainedTokenizer) + mock_tokenizer.encode.return_value = ["some", "tokens"] + mock_tokenizer.pad_token_id = 100 + mock_tokenizer.apply_chat_template.return_value = "test prompt" + mock_pipeline.tokenizer = mock_tokenizer + generator.pipeline = mock_pipeline + + messages = [ChatMessage.from_user("What's the weather in Paris?")] + results = generator.run(messages=messages) + + assert len(results["replies"]) == 1 + message = results["replies"][0] + assert message.tool_calls + tool_call = message.tool_calls[0] + assert isinstance(tool_call, ToolCall) + assert tool_call.tool_name == "weather" + assert tool_call.arguments == {"city": "Paris"} + assert message.meta["finish_reason"] == "tool_calls" + + def test_run_with_tools_in_run_method(self, model_info_mock, tools): + generator = HuggingFaceLocalChatGenerator(model="meta-llama/Llama-2-13b-chat-hf") + + # Mock pipeline and tokenizer + mock_pipeline = Mock(return_value=[{"generated_text": '{"name": "weather", "arguments": {"city": "Paris"}}'}]) + mock_tokenizer = Mock(spec=PreTrainedTokenizer) + mock_tokenizer.encode.return_value = ["some", "tokens"] + mock_tokenizer.pad_token_id = 100 + mock_tokenizer.apply_chat_template.return_value = "test prompt" + mock_pipeline.tokenizer = mock_tokenizer + generator.pipeline = mock_pipeline + + messages = [ChatMessage.from_user("What's the weather in Paris?")] + results = generator.run(messages=messages, tools=tools) + + assert len(results["replies"]) == 1 + message = results["replies"][0] + assert message.tool_calls + tool_call = message.tool_calls[0] + assert isinstance(tool_call, ToolCall) + assert tool_call.tool_name == "weather" + assert tool_call.arguments == {"city": "Paris"} + assert message.meta["finish_reason"] == "tool_calls" + + def test_run_with_tools_and_tool_response(self, 
model_info_mock, tools): + generator = HuggingFaceLocalChatGenerator(model="meta-llama/Llama-2-13b-chat-hf") + + # Mock pipeline and tokenizer + mock_pipeline = Mock(return_value=[{"generated_text": "The weather in Paris is 22°C"}]) + mock_tokenizer = Mock(spec=PreTrainedTokenizer) + mock_tokenizer.encode.return_value = ["some", "tokens"] + mock_tokenizer.pad_token_id = 100 + mock_tokenizer.apply_chat_template.return_value = "test prompt" + mock_pipeline.tokenizer = mock_tokenizer + generator.pipeline = mock_pipeline + + tool_call = ToolCall(tool_name="weather", arguments={"city": "Paris"}) + messages = [ + ChatMessage.from_user("What's the weather in Paris?"), + ChatMessage.from_assistant(tool_calls=[tool_call]), + ChatMessage.from_tool(tool_result="22°C", origin=tool_call), + ] + results = generator.run(messages=messages) + + assert len(results["replies"]) == 1 + message = results["replies"][0] + assert not message.tool_calls # No tool calls in the final response + assert "22°C" in message.text + assert message.meta["finish_reason"] == "stop" + + def test_run_with_invalid_tool_pattern(self, model_info_mock, tools): + generator = HuggingFaceLocalChatGenerator( + model="meta-llama/Llama-2-13b-chat-hf", + tools=tools, + tool_pattern=r"invalid[pattern", # Invalid regex pattern + ) + + # Mock pipeline and tokenizer + mock_pipeline = Mock(return_value=[{"generated_text": '{"name": "weather", "arguments": {"city": "Paris"}}'}]) + mock_tokenizer = Mock(spec=PreTrainedTokenizer) + mock_tokenizer.encode.return_value = ["some", "tokens"] + mock_tokenizer.pad_token_id = 100 + mock_tokenizer.apply_chat_template.return_value = "test prompt" + mock_pipeline.tokenizer = mock_tokenizer + generator.pipeline = mock_pipeline + + messages = [ChatMessage.from_user("What's the weather in Paris?")] + results = generator.run(messages=messages) + + assert len(results["replies"]) == 1 + message = results["replies"][0] + assert not message.tool_calls # No tool calls due to invalid pattern + assert message.meta["finish_reason"] == "stop" From 89d8b765bff1b02d5b35b3cc5df7e7b2464bf6a9 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Fri, 7 Feb 2025 10:51:20 +0100 Subject: [PATCH 6/8] Add tools serde and tests --- .../generators/chat/hugging_face_local.py | 5 ++- .../chat/test_hugging_face_local.py | 41 +++++++++++++++---- 2 files changed, 38 insertions(+), 8 deletions(-) diff --git a/haystack/components/generators/chat/hugging_face_local.py b/haystack/components/generators/chat/hugging_face_local.py index 8ba6c9a982..de48d2241b 100644 --- a/haystack/components/generators/chat/hugging_face_local.py +++ b/haystack/components/generators/chat/hugging_face_local.py @@ -10,7 +10,7 @@ from haystack import component, default_from_dict, default_to_dict, logging from haystack.dataclasses import ChatMessage, StreamingChunk, ToolCall from haystack.lazy_imports import LazyImport -from haystack.tools import Tool, _check_duplicate_tool_names +from haystack.tools import Tool, _check_duplicate_tool_names, deserialize_tools_inplace from haystack.utils import ( ComponentDevice, Secret, @@ -219,6 +219,7 @@ def to_dict(self) -> Dict[str, Any]: Dictionary with serialized data. 
""" callback_name = serialize_callable(self.streaming_callback) if self.streaming_callback else None + serialized_tools = [tool.to_dict() for tool in self.tools] if self.tools else None serialization_dict = default_to_dict( self, huggingface_pipeline_kwargs=self.huggingface_pipeline_kwargs, @@ -226,6 +227,7 @@ def to_dict(self) -> Dict[str, Any]: streaming_callback=callback_name, token=self.token.to_dict() if self.token else None, chat_template=self.chat_template, + tools=serialized_tools, ) huggingface_pipeline_kwargs = serialization_dict["init_parameters"]["huggingface_pipeline_kwargs"] @@ -246,6 +248,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "HuggingFaceLocalChatGenerator": """ torch_and_transformers_import.check() # leave this, cls method deserialize_secrets_inplace(data["init_parameters"], keys=["token"]) + deserialize_tools_inplace(data["init_parameters"], key="tools") init_params = data.get("init_parameters", {}) serialized_callback_handler = init_params.get("streaming_callback") if serialized_callback_handler: diff --git a/test/components/generators/chat/test_hugging_face_local.py b/test/components/generators/chat/test_hugging_face_local.py index 0285cb3667..358cfe95bc 100644 --- a/test/components/generators/chat/test_hugging_face_local.py +++ b/test/components/generators/chat/test_hugging_face_local.py @@ -19,6 +19,11 @@ def streaming_callback_handler(x): return x +def get_weather(city: str) -> str: + """Get the weather for a given city.""" + return f"Weather data for {city}" + + @pytest.fixture def chat_messages(): return [ @@ -57,8 +62,9 @@ def tools(): name="weather", description="useful to determine the weather in a given location", parameters=tool_parameters, - function=lambda x: x, + function=get_weather, ) + return [tool] @@ -151,14 +157,15 @@ def test_init_invalid_task(self): with pytest.raises(ValueError, match="is not supported."): HuggingFaceLocalChatGenerator(task="text-classification") - def test_to_dict(self, model_info_mock): + def test_to_dict(self, model_info_mock, tools): generator = HuggingFaceLocalChatGenerator( model="NousResearch/Llama-2-7b-chat-hf", token=Secret.from_env_var("ENV_VAR", strict=False), generation_kwargs={"n": 5}, stop_words=["stop", "words"], - streaming_callback=streaming_callback_handler, + streaming_callback=None, chat_template="irrelevant", + tools=tools, ) # Call the to_dict method @@ -170,16 +177,28 @@ def test_to_dict(self, model_info_mock): assert init_params["huggingface_pipeline_kwargs"]["model"] == "NousResearch/Llama-2-7b-chat-hf" assert "token" not in init_params["huggingface_pipeline_kwargs"] assert init_params["generation_kwargs"] == {"max_new_tokens": 512, "n": 5, "stop_sequences": ["stop", "words"]} - assert init_params["streaming_callback"] == "chat.test_hugging_face_local.streaming_callback_handler" + assert init_params["streaming_callback"] is None assert init_params["chat_template"] == "irrelevant" + assert init_params["tools"] == [ + { + "type": "haystack.tools.tool.Tool", + "data": { + "name": "weather", + "description": "useful to determine the weather in a given location", + "parameters": {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]}, + "function": "chat.test_hugging_face_local.get_weather", + }, + } + ] - def test_from_dict(self, model_info_mock): + def test_from_dict(self, model_info_mock, tools): generator = HuggingFaceLocalChatGenerator( model="NousResearch/Llama-2-7b-chat-hf", generation_kwargs={"n": 5}, stop_words=["stop", "words"], - 
streaming_callback=streaming_callback_handler, + streaming_callback=None, chat_template="irrelevant", + tools=tools, ) # Call the to_dict method result = generator.to_dict() @@ -188,8 +207,16 @@ def test_from_dict(self, model_info_mock): assert generator_2.token == Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False) assert generator_2.generation_kwargs == {"max_new_tokens": 512, "n": 5, "stop_sequences": ["stop", "words"]} - assert generator_2.streaming_callback is streaming_callback_handler + assert generator_2.streaming_callback is None assert generator_2.chat_template == "irrelevant" + assert len(generator_2.tools) == 1 + assert generator_2.tools[0].name == "weather" + assert generator_2.tools[0].description == "useful to determine the weather in a given location" + assert generator_2.tools[0].parameters == { + "type": "object", + "properties": {"city": {"type": "string"}}, + "required": ["city"], + } @patch("haystack.components.generators.chat.hugging_face_local.pipeline") def test_warm_up(self, pipeline_mock, monkeypatch): From de50d2fafe3582136c6b79525db0d467719f62dd Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Fri, 7 Feb 2025 13:52:28 +0100 Subject: [PATCH 7/8] PR feedback --- .../generators/chat/hugging_face_local.py | 84 +++++++++---------- .../chat/test_hugging_face_local.py | 50 +++++++---- 2 files changed, 76 insertions(+), 58 deletions(-) diff --git a/haystack/components/generators/chat/hugging_face_local.py b/haystack/components/generators/chat/hugging_face_local.py index de48d2241b..f54c50181c 100644 --- a/haystack/components/generators/chat/hugging_face_local.py +++ b/haystack/components/generators/chat/hugging_face_local.py @@ -43,6 +43,35 @@ ) +def default_tool_parser(text: str) -> Optional[List[ToolCall]]: + """ + Default implementation for parsing tool calls from model output text. + + Uses DEFAULT_TOOL_PATTERN to extract tool calls. + + :param text: The text to parse for tool calls. + :returns: A list containing a single ToolCall if a valid tool call is found, None otherwise. + """ + try: + match = re.search(DEFAULT_TOOL_PATTERN, text, re.DOTALL) + except re.error: + logger.warning("Invalid regex pattern for tool parsing: {pattern}", pattern=DEFAULT_TOOL_PATTERN) + return None + + if not match: + return None + + name = match.group(1) or match.group(3) + args_str = match.group(2) or match.group(4) + + try: + arguments = json.loads(args_str) + return [ToolCall(tool_name=name, arguments=arguments)] + except json.JSONDecodeError: + logger.warning("Failed to parse tool call arguments: {args_str}", args_str=args_str) + return None + + @component class HuggingFaceLocalChatGenerator: """ @@ -93,7 +122,7 @@ def __init__( # pylint: disable=too-many-positional-arguments stop_words: Optional[List[str]] = None, streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, tools: Optional[List[Tool]] = None, - tool_pattern: Optional[Union[str, Callable[[str], Optional[List[ToolCall]]]]] = None, + tool_parsing_function: Optional[Callable[[str], Optional[List[ToolCall]]]] = None, ): """ Initializes the HuggingFaceLocalChatGenerator component. @@ -133,11 +162,9 @@ def __init__( # pylint: disable=too-many-positional-arguments In these cases, make sure your prompt has no stop words. :param streaming_callback: An optional callable for handling streaming responses. :param tools: A list of tools for which the model can prepare calls. - :param tool_pattern: - A pattern or callable to parse tool calls from model output. 
- If a string, it will be used as a regex pattern to extract ToolCall object. - If a callable, it should take a string and return a ToolCall object or None. - If None, a default pattern will be used. + :param tool_parsing_function: + A callable that takes a string and returns a list of ToolCall objects or None. + If None, the default_tool_parser will be used which extracts tool calls using a predefined pattern. """ torch_and_transformers_import.check() @@ -188,7 +215,7 @@ def __init__( # pylint: disable=too-many-positional-arguments generation_kwargs["stop_sequences"] = generation_kwargs.get("stop_sequences", []) generation_kwargs["stop_sequences"].extend(stop_words or []) - self.tool_pattern = tool_pattern or DEFAULT_TOOL_PATTERN + self.tool_parsing_function = tool_parsing_function or default_tool_parser self.huggingface_pipeline_kwargs = huggingface_pipeline_kwargs self.generation_kwargs = generation_kwargs self.chat_template = chat_template @@ -228,6 +255,7 @@ def to_dict(self) -> Dict[str, Any]: token=self.token.to_dict() if self.token else None, chat_template=self.chat_template, tools=serialized_tools, + tool_parsing_function=serialize_callable(self.tool_parsing_function), ) huggingface_pipeline_kwargs = serialization_dict["init_parameters"]["huggingface_pipeline_kwargs"] @@ -254,6 +282,10 @@ def from_dict(cls, data: Dict[str, Any]) -> "HuggingFaceLocalChatGenerator": if serialized_callback_handler: data["init_parameters"]["streaming_callback"] = deserialize_callable(serialized_callback_handler) + tool_parsing_function = init_params.get("tool_parsing_function") + if tool_parsing_function: + init_params["tool_parsing_function"] = deserialize_callable(tool_parsing_function) + huggingface_pipeline_kwargs = init_params.get("huggingface_pipeline_kwargs", {}) deserialize_hf_model_kwargs(huggingface_pipeline_kwargs) return default_from_dict(cls, data) @@ -371,7 +403,7 @@ def create_message( # pylint: disable=too-many-positional-arguments prompt_token_count = len(tokenizer.encode(prompt, add_special_tokens=False)) total_tokens = prompt_token_count + completion_tokens - tool_calls = self._parse_tool_call(text) if parse_tool_calls else None + tool_calls = self.tool_parsing_function(text) if parse_tool_calls else None # Determine finish reason based on context if completion_tokens >= generation_kwargs.get("max_new_tokens", sys.maxsize): @@ -392,7 +424,8 @@ def create_message( # pylint: disable=too-many-positional-arguments }, } - return ChatMessage.from_assistant(tool_calls=tool_calls, text=text, meta=meta) + # If tool calls are detected, don't include the text content since it contains the raw tool call format + return ChatMessage.from_assistant(tool_calls=tool_calls, text=None if tool_calls else text, meta=meta) def _validate_stop_words(self, stop_words: Optional[List[str]]) -> Optional[List[str]]: """ @@ -410,36 +443,3 @@ def _validate_stop_words(self, stop_words: Optional[List[str]]) -> Optional[List return None return list(set(stop_words or [])) - - def _parse_tool_call(self, text: str) -> Optional[List[ToolCall]]: - """ - Parse a tool call from model output text. - - :param text: The text to parse for tool calls. - :returns: A ToolCall object if a valid tool call is found, None otherwise. 
- """ - # if the tool pattern is a callable, call it with the text and return the result - if callable(self.tool_pattern): - return self.tool_pattern(text) - - # if the tool pattern is a regex pattern, search for it in the text - try: - match = re.search(self.tool_pattern, text, re.DOTALL) - except re.error: - logger.warning("Invalid regex pattern for tool parsing: {pattern}", pattern=self.tool_pattern) - return None - - if not match: - return None - - # seem like most models are not producing tool ids, so we omit them - # and just use the tool name and arguments - name = match.group(1) or match.group(3) - args_str = match.group(2) or match.group(4) - - try: - arguments = json.loads(args_str) - return [ToolCall(tool_name=name, arguments=arguments)] - except json.JSONDecodeError: - logger.warning("Failed to parse tool call arguments: {args_str}", args_str=args_str) - return None diff --git a/test/components/generators/chat/test_hugging_face_local.py b/test/components/generators/chat/test_hugging_face_local.py index 358cfe95bc..226afca37c 100644 --- a/test/components/generators/chat/test_hugging_face_local.py +++ b/test/components/generators/chat/test_hugging_face_local.py @@ -2,6 +2,7 @@ # # SPDX-License-Identifier: Apache-2.0 from unittest.mock import Mock, patch +from typing import Optional, List from haystack.dataclasses.streaming_chunk import StreamingChunk import pytest @@ -68,6 +69,11 @@ def tools(): return [tool] +def custom_tool_parser(text: str) -> Optional[List[ToolCall]]: + """Test implementation of a custom tool parser.""" + return [ToolCall(tool_name="weather", arguments={"city": "Berlin"})] + + class TestHuggingFaceLocalChatGenerator: def test_initialize_with_valid_model_and_generation_parameters(self, model_info_mock): model = "HuggingFaceH4/zephyr-7b-alpha" @@ -433,26 +439,38 @@ def test_run_with_tools_and_tool_response(self, model_info_mock, tools): assert "22°C" in message.text assert message.meta["finish_reason"] == "stop" - def test_run_with_invalid_tool_pattern(self, model_info_mock, tools): + def test_run_with_custom_tool_parser(self, model_info_mock, tools): + """Test that a custom tool parsing function works correctly.""" generator = HuggingFaceLocalChatGenerator( - model="meta-llama/Llama-2-13b-chat-hf", - tools=tools, - tool_pattern=r"invalid[pattern", # Invalid regex pattern + model="meta-llama/Llama-2-13b-chat-hf", tools=tools, tool_parsing_function=custom_tool_parser ) + generator.pipeline = Mock(return_value=[{"generated_text": "Let me check the weather for you"}]) + generator.pipeline.tokenizer = Mock() + generator.pipeline.tokenizer.encode.return_value = [1, 2, 3] + generator.pipeline.tokenizer.pad_token_id = 1 - # Mock pipeline and tokenizer - mock_pipeline = Mock(return_value=[{"generated_text": '{"name": "weather", "arguments": {"city": "Paris"}}'}]) - mock_tokenizer = Mock(spec=PreTrainedTokenizer) - mock_tokenizer.encode.return_value = ["some", "tokens"] - mock_tokenizer.pad_token_id = 100 - mock_tokenizer.apply_chat_template.return_value = "test prompt" - mock_pipeline.tokenizer = mock_tokenizer - generator.pipeline = mock_pipeline + messages = [ChatMessage.from_user("What's the weather like in Berlin?")] + results = generator.run(messages=messages) - messages = [ChatMessage.from_user("What's the weather in Paris?")] + assert len(results["replies"]) == 1 + assert len(results["replies"][0].tool_calls) == 1 + assert results["replies"][0].tool_calls[0].tool_name == "weather" + assert results["replies"][0].tool_calls[0].arguments == {"city": "Berlin"} + + 
def test_default_tool_parser(self, model_info_mock, tools): + """Test that the default tool parser works correctly with valid tool call format.""" + generator = HuggingFaceLocalChatGenerator(model="meta-llama/Llama-2-13b-chat-hf", tools=tools) + generator.pipeline = Mock( + return_value=[{"generated_text": '{"name": "weather", "arguments": {"city": "Berlin"}}'}] + ) + generator.pipeline.tokenizer = Mock() + generator.pipeline.tokenizer.encode.return_value = [1, 2, 3] + generator.pipeline.tokenizer.pad_token_id = 1 + + messages = [ChatMessage.from_user("What's the weather like in Berlin?")] results = generator.run(messages=messages) assert len(results["replies"]) == 1 - message = results["replies"][0] - assert not message.tool_calls # No tool calls due to invalid pattern - assert message.meta["finish_reason"] == "stop" + assert len(results["replies"][0].tool_calls) == 1 + assert results["replies"][0].tool_calls[0].tool_name == "weather" + assert results["replies"][0].tool_calls[0].arguments == {"city": "Berlin"} From be712f83295944a4529de549e87345072fd38d2d Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Mon, 10 Feb 2025 09:27:18 +0100 Subject: [PATCH 8/8] PR feedback --- test/components/generators/chat/test_hugging_face_local.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/components/generators/chat/test_hugging_face_local.py b/test/components/generators/chat/test_hugging_face_local.py index 226afca37c..38d25ec91a 100644 --- a/test/components/generators/chat/test_hugging_face_local.py +++ b/test/components/generators/chat/test_hugging_face_local.py @@ -444,7 +444,7 @@ def test_run_with_custom_tool_parser(self, model_info_mock, tools): generator = HuggingFaceLocalChatGenerator( model="meta-llama/Llama-2-13b-chat-hf", tools=tools, tool_parsing_function=custom_tool_parser ) - generator.pipeline = Mock(return_value=[{"generated_text": "Let me check the weather for you"}]) + generator.pipeline = Mock(return_value=[{"mocked_response": "Mocked response, we don't use it"}]) generator.pipeline.tokenizer = Mock() generator.pipeline.tokenizer.encode.return_value = [1, 2, 3] generator.pipeline.tokenizer.pad_token_id = 1
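
For reference, a minimal end-to-end sketch of the feature introduced by this series (not part of the patches themselves): it wires a single Tool into HuggingFaceLocalChatGenerator and reads the parsed ToolCall back from the reply. The model name is the one the tests mock; a real run needs a locally runnable chat model whose chat template supports tools, and warm_up() must be called before run().

from haystack.components.generators.chat import HuggingFaceLocalChatGenerator
from haystack.dataclasses import ChatMessage
from haystack.tools import Tool


def get_weather(city: str) -> str:
    """Toy tool implementation used only for this sketch."""
    return f"Sunny, 22 degrees C in {city}"


weather_tool = Tool(
    name="weather",
    description="useful to determine the weather in a given location",
    parameters={"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]},
    function=get_weather,
)

# tools and streaming_callback are mutually exclusive, as enforced in __init__ and run()
generator = HuggingFaceLocalChatGenerator(
    model="meta-llama/Llama-2-13b-chat-hf",  # illustrative; any local chat model with tool support in its template
    tools=[weather_tool],
)
generator.warm_up()

result = generator.run(messages=[ChatMessage.from_user("What's the weather in Paris?")])
reply = result["replies"][0]
if reply.tool_calls:
    # finish_reason is "tool_calls"; after PATCH 7/8 the reply carries ToolCall objects and no text
    call = reply.tool_calls[0]
    print(call.tool_name, call.arguments)  # e.g. "weather" {"city": "Paris"}
else:
    print(reply.text)

As the tests illustrate, the tool result can then be fed back with ChatMessage.from_tool(tool_result=..., origin=call) and a second run() call to obtain the final natural-language answer.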