feat!: HuggingFaceLocalChatGenerator unified support for tools #8827

Status: Open. Wants to merge 9 commits into base: main.
108 changes: 93 additions & 15 deletions haystack/components/generators/chat/hugging_face_local.py
@@ -2,12 +2,15 @@
#
# SPDX-License-Identifier: Apache-2.0

import json
import re
import sys
from typing import Any, Callable, Dict, List, Literal, Optional, Union

from haystack import component, default_from_dict, default_to_dict, logging
from haystack.dataclasses import ChatMessage, StreamingChunk, ToolCall
from haystack.lazy_imports import LazyImport
from haystack.tools import Tool, _check_duplicate_tool_names, deserialize_tools_inplace
from haystack.utils import (
ComponentDevice,
Secret,
@@ -33,6 +36,41 @@

PIPELINE_SUPPORTED_TASKS = ["text-generation", "text2text-generation"]

DEFAULT_TOOL_PATTERN = (
r"(?:<tool_call>)?"
r'(?:\s*\{.*?"name"\s*:\s*"([^"]+)".*?"arguments"\s*:\s*(\{[^}]+\}).*?\}'
r'|\{.*?"function"\s*:\s*\{.*?"name"\s*:\s*"([^"]+)".*?"arguments"\s*:\s*(\{[^}]+\}).*?\})'
)


def default_tool_parser(text: str) -> Optional[List[ToolCall]]:
"""
Default implementation for parsing tool calls from model output text.

Uses DEFAULT_TOOL_PATTERN to extract tool calls.

:param text: The text to parse for tool calls.
:returns: A list containing a single ToolCall if a valid tool call is found, None otherwise.
"""

[Reviewer comment (maintainer), on the single-call limitation]: we are only extracting a single tool call, but this is clearly explained. We can improve this in future if there are requests from the community.
try:
match = re.search(DEFAULT_TOOL_PATTERN, text, re.DOTALL)
except re.error:
logger.warning("Invalid regex pattern for tool parsing: {pattern}", pattern=DEFAULT_TOOL_PATTERN)
return None

if not match:
return None

name = match.group(1) or match.group(3)
args_str = match.group(2) or match.group(4)

try:
arguments = json.loads(args_str)
return [ToolCall(tool_name=name, arguments=arguments)]
except json.JSONDecodeError:
logger.warning("Failed to parse tool call arguments: {args_str}", args_str=args_str)
return None
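A quick illustration (not part of the diff) of how default_tool_parser behaves on the two output shapes the pattern targets. It assumes this branch is installed; the tool name and arguments are made up:

from haystack.components.generators.chat.hugging_face_local import default_tool_parser

# Shape 1: a <tool_call>-wrapped JSON object with "name" and "arguments" keys
calls = default_tool_parser('<tool_call>{"name": "get_weather", "arguments": {"city": "Paris"}}</tool_call>')
print(calls[0].tool_name, calls[0].arguments)  # get_weather {'city': 'Paris'}

# Shape 2: a JSON object with a nested "function" key
calls = default_tool_parser('{"function": {"name": "get_weather", "arguments": {"city": "Paris"}}}')
print(calls[0].tool_name, calls[0].arguments)  # get_weather {'city': 'Paris'}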


@component
class HuggingFaceLocalChatGenerator:
@@ -83,6 +121,8 @@ def __init__( # pylint: disable=too-many-positional-arguments
huggingface_pipeline_kwargs: Optional[Dict[str, Any]] = None,
stop_words: Optional[List[str]] = None,
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
tools: Optional[List[Tool]] = None,
tool_parsing_function: Optional[Callable[[str], Optional[List[ToolCall]]]] = None,
):
"""
Initializes the HuggingFaceLocalChatGenerator component.
Expand Down Expand Up @@ -121,9 +161,17 @@ def __init__( # pylint: disable=too-many-positional-arguments
For some chat models, the output includes both the new text and the original prompt.
In these cases, make sure your prompt has no stop words.
:param streaming_callback: An optional callable for handling streaming responses.
:param tools: A list of tools for which the model can prepare calls.
:param tool_parsing_function:
A callable that takes a string and returns a list of ToolCall objects or None.
If None, the default_tool_parser is used, which extracts tool calls using a predefined pattern.
"""
torch_and_transformers_import.check()

if tools and streaming_callback is not None:
raise ValueError("Using tools and streaming at the same time is not supported. Please choose one.")
_check_duplicate_tool_names(tools)

huggingface_pipeline_kwargs = huggingface_pipeline_kwargs or {}
generation_kwargs = generation_kwargs or {}
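To make the new parameters concrete, here is a minimal construction sketch (not part of the diff; the model name and the tool definition are illustrative assumptions):

from haystack.components.generators.chat.hugging_face_local import HuggingFaceLocalChatGenerator
from haystack.tools import Tool

def get_weather(city: str) -> str:
    # Toy tool implementation, used only for this example
    return f"Sunny in {city}"

weather_tool = Tool(
    name="get_weather",
    description="Get the current weather for a city.",
    parameters={"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]},
    function=get_weather,
)

generator = HuggingFaceLocalChatGenerator(
    model="Qwen/Qwen2.5-1.5B-Instruct",  # assumption: any local model whose chat template accepts tools
    tools=[weather_tool],
)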

@@ -167,11 +215,13 @@ def __init__( # pylint: disable=too-many-positional-arguments
generation_kwargs["stop_sequences"] = generation_kwargs.get("stop_sequences", [])
generation_kwargs["stop_sequences"].extend(stop_words or [])

self.tool_parsing_function = tool_parsing_function or default_tool_parser
self.huggingface_pipeline_kwargs = huggingface_pipeline_kwargs
self.generation_kwargs = generation_kwargs
self.chat_template = chat_template
self.streaming_callback = streaming_callback
self.pipeline = None
self.tools = tools

def _get_telemetry_data(self) -> Dict[str, Any]:
"""
@@ -196,13 +246,16 @@ def to_dict(self) -> Dict[str, Any]:
Dictionary with serialized data.
"""
callback_name = serialize_callable(self.streaming_callback) if self.streaming_callback else None
serialized_tools = [tool.to_dict() for tool in self.tools] if self.tools else None
serialization_dict = default_to_dict(
self,
huggingface_pipeline_kwargs=self.huggingface_pipeline_kwargs,
generation_kwargs=self.generation_kwargs,
streaming_callback=callback_name,
token=self.token.to_dict() if self.token else None,
chat_template=self.chat_template,
tools=serialized_tools,
tool_parsing_function=serialize_callable(self.tool_parsing_function),
)

huggingface_pipeline_kwargs = serialization_dict["init_parameters"]["huggingface_pipeline_kwargs"]
@@ -223,11 +276,16 @@ def from_dict(cls, data: Dict[str, Any]) -> "HuggingFaceLocalChatGenerator":
"""
torch_and_transformers_import.check() # leave this, cls method
deserialize_secrets_inplace(data["init_parameters"], keys=["token"])
deserialize_tools_inplace(data["init_parameters"], key="tools")
init_params = data.get("init_parameters", {})
serialized_callback_handler = init_params.get("streaming_callback")
if serialized_callback_handler:
data["init_parameters"]["streaming_callback"] = deserialize_callable(serialized_callback_handler)

tool_parsing_function = init_params.get("tool_parsing_function")
if tool_parsing_function:
init_params["tool_parsing_function"] = deserialize_callable(tool_parsing_function)

huggingface_pipeline_kwargs = init_params.get("huggingface_pipeline_kwargs", {})
deserialize_hf_model_kwargs(huggingface_pipeline_kwargs)
return default_from_dict(cls, data)
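Together, to_dict and from_dict give a serialization round trip: tools are stored as dicts and tool_parsing_function as an importable path string. A sketch, reusing the generator built above:

data = generator.to_dict()
restored = HuggingFaceLocalChatGenerator.from_dict(data)
assert restored.tools[0].name == "get_weather"
assert restored.tool_parsing_function is not None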
@@ -238,19 +296,28 @@ def run(
messages: List[ChatMessage],
generation_kwargs: Optional[Dict[str, Any]] = None,
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
tools: Optional[List[Tool]] = None,
):
"""
Invoke text generation inference based on the provided messages and generation parameters.

:param messages: A list of ChatMessage objects representing the input messages.
:param generation_kwargs: Additional keyword arguments for text generation.
:param streaming_callback: An optional callable for handling streaming responses.
:param tools:
A list of tools for which the model can prepare calls. If set, it will override the `tools` parameter
provided during initialization.
:returns:
A list containing the generated responses as ChatMessage instances.
"""
if self.pipeline is None:
raise RuntimeError("The generation model has not been loaded. Please call warm_up() before running.")

tools = tools or self.tools
if tools and streaming_callback is not None:
raise ValueError("Using tools and streaming at the same time is not supported. Please choose one.")
_check_duplicate_tool_names(tools)

tokenizer = self.pipeline.tokenizer

# Check and update generation parameters
@@ -279,11 +346,14 @@ def run(
# streamer parameter hooks into HF streaming, HFTokenStreamingHandler is an adapter to our streaming
generation_kwargs["streamer"] = HFTokenStreamingHandler(tokenizer, streaming_callback, stop_words)

# convert messages to HF format
hf_messages = [convert_message_to_hf_format(message) for message in messages]

# Prepare the prompt for the model
prepared_prompt = tokenizer.apply_chat_template(
hf_messages,
tokenize=False,
chat_template=self.chat_template,
add_generation_prompt=True,
tools=[tc.tool_spec for tc in tools] if tools else None,
)

# Avoid some unnecessary warnings in the generation pipeline call
@@ -299,11 +369,13 @@ def run(
for stop_word in stop_words:
replies = [reply.replace(stop_word, "").rstrip() for reply in replies]

# Create ChatMessage instances for each reply
chat_messages = [
self.create_message(
reply, r_index, tokenizer, prepared_prompt, generation_kwargs, parse_tool_calls=bool(tools)
)
for r_index, reply in enumerate(replies)
]

return {"replies": chat_messages}

def create_message( # pylint: disable=too-many-positional-arguments
@@ -313,6 +385,7 @@ def create_message( # pylint: disable=too-many-positional-arguments
tokenizer: Union["PreTrainedTokenizer", "PreTrainedTokenizerFast"],
prompt: str,
generation_kwargs: Dict[str, Any],
parse_tool_calls: bool = False,
) -> ChatMessage:
"""
Create a ChatMessage instance from the provided text, populated with metadata.
@@ -322,17 +395,23 @@ def create_message( # pylint: disable=too-many-positional-arguments
:param tokenizer: The tokenizer used for generation.
:param prompt: The prompt used for generation.
:param generation_kwargs: The generation parameters.
:param parse_tool_calls: Whether to attempt parsing tool calls from the text.
:returns: A ChatMessage instance.
"""

completion_tokens = len(tokenizer.encode(text, add_special_tokens=False))
prompt_token_count = len(tokenizer.encode(prompt, add_special_tokens=False))
total_tokens = prompt_token_count + completion_tokens

tool_calls = self.tool_parsing_function(text) if parse_tool_calls else None

# Determine finish reason based on context
if completion_tokens >= generation_kwargs.get("max_new_tokens", sys.maxsize):
finish_reason = "length"
elif tool_calls:
finish_reason = "tool_calls"
else:
finish_reason = "stop"

meta = {
"finish_reason": finish_reason,
@@ -345,7 +424,8 @@ def create_message( # pylint: disable=too-many-positional-arguments
},
}

# If tool calls are detected, don't include the text content since it contains the raw tool call format
return ChatMessage.from_assistant(tool_calls=tool_calls, text=None if tool_calls else text, meta=meta)
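When the parser finds a tool call, the reply carries it in tool_calls and its text is None; otherwise text holds the generated answer. Inspecting the reply from the run() sketch above (values illustrative):

if reply.tool_calls:
    tool_call = reply.tool_calls[0]
    print(tool_call.tool_name, tool_call.arguments)  # e.g. get_weather {'city': 'Paris'}
    print(reply.meta["finish_reason"])               # "tool_calls"
else:
    print(reply.text)  # plain assistant text; finish_reason is "stop" or "length"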

def _validate_stop_words(self, stop_words: Optional[List[str]]) -> Optional[List[str]]:
"""
@@ -362,6 +442,4 @@ def _validate_stop_words(self, stop_words: Optional[List[str]]) -> Optional[List[str]]:
)
return None

return list(set(stop_words or []))
New release note file:
@@ -0,0 +1,4 @@
---
features:
  - |
    Adds tooling support to HuggingFaceLocalChatGenerator