Commit 3c01c98

revamp
1 parent bf7c359 commit 3c01c98

1 file changed: +97 −15

llama-index-integrations/llms/llama-index-llms-huggingface-api/llama_index/llms/huggingface_api/base.py
```diff
@@ -1,10 +1,12 @@
+import json
 import logging
 from huggingface_hub import AsyncInferenceClient, InferenceClient, model_info
 from huggingface_hub.hf_api import ModelInfo
 from huggingface_hub.inference._generated.types import (
     ChatCompletionOutput,
     ChatCompletionStreamOutput,
     ChatCompletionOutputToolCall,
+    ChatCompletionOutputFunctionDefinition,
 )
 from typing import Any, Callable, Dict, List, Optional, Sequence, Union
 
```
```diff
@@ -136,20 +138,12 @@ def class_name(cls) -> str:
         gt=0.0,
     )
     is_chat_model: bool = Field(
-        default=False,
-        description=(
-            LLMMetadata.model_fields["is_chat_model"].description
-            + " Unless chat templating is intentionally applied, Hugging Face models"
-            " are not chat models."
-        ),
+        default=True,
+        description="Controls whether the chat or text generation methods are used.",
     )
     is_function_calling_model: bool = Field(
         default=False,
-        description=(
-            LLMMetadata.model_fields["is_function_calling_model"].description
-            + " As of 10/17/2023, Hugging Face doesn't support function calling"
-            " messages."
-        ),
+        description="Controls whether the function calling methods are used.",
     )
 
     def __init__(self, **kwargs: Any) -> None:
```
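With the default flipped, chat methods are used unless a caller opts out. A minimal usage sketch of the new defaults; the class and parameter names (`HuggingFaceInferenceAPI`, `model_name`) come from this package's public API rather than this diff, and the model name is illustrative:

```python
from llama_index.llms.huggingface_api import HuggingFaceInferenceAPI

llm = HuggingFaceInferenceAPI(
    model_name="mistralai/Mistral-7B-Instruct-v0.2",  # illustrative chat model
    # is_chat_model now defaults to True; pass False to force text generation.
    is_chat_model=True,
    # is_function_calling_model still defaults to False, and __init__ below
    # resets it to False even if a caller passes True.
)
```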
```diff
@@ -169,10 +163,21 @@ def __init__(self, **kwargs: Any) -> None:
         else:
             task = kwargs["task"].lower()
 
+        if kwargs.get("is_function_calling_model", False):
+            print(
+                "Function calling is currently not supported for Hugging Face Inference API, setting is_function_calling_model to False"
+            )
+            kwargs["is_function_calling_model"] = False
+
         super().__init__(**kwargs)  # Populate pydantic Fields
         self._sync_client = InferenceClient(**self._get_inference_client_kwargs())
         self._async_client = AsyncInferenceClient(**self._get_inference_client_kwargs())
 
+        # set context window if not provided
+        info = self._sync_client.get_endpoint_info()
+        if "max_input_tokens" in info and kwargs.get("context_window") is None:
+            self.context_window = info["max_input_tokens"]
+
     def _get_inference_client_kwargs(self) -> Dict[str, Any]:
         """Extract the Hugging Face InferenceClient construction parameters."""
         return {
```
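The context-window auto-detection leans on `InferenceClient.get_endpoint_info()`, which returns a dict of server metadata for text-generation-inference-backed endpoints. A standalone sketch of the same probe, assuming such an endpoint (the model name is again illustrative):

```python
from huggingface_hub import InferenceClient

client = InferenceClient(model="mistralai/Mistral-7B-Instruct-v0.2")  # illustrative
info = client.get_endpoint_info()
# "max_input_tokens" is the server-reported input limit when available.
if "max_input_tokens" in info:
    context_window = info["max_input_tokens"]
```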
```diff
@@ -194,7 +199,53 @@ def _get_model_kwargs(self, **kwargs: Any) -> Dict[str, Any]:
     def _to_huggingface_messages(
         self, messages: Sequence[ChatMessage]
     ) -> List[Dict[str, Any]]:
-        return [{"role": m.role.value, "content": m.content} for m in messages]
+        hf_dicts = []
+        for m in messages:
+            hf_dicts.append(
+                {"role": m.role.value, "content": m.content if m.content else ""}
+            )
+            if m.additional_kwargs.get("tool_calls", []):
+                tool_call_dicts = []
+                for tool_call in m.additional_kwargs["tool_calls"]:
+                    function_dict = {
+                        "name": tool_call.id,
+                        "arguments": tool_call.function.arguments,
+                    }
+                    tool_call_dicts.append(
+                        {"type": "function", "function": function_dict}
+                    )
+
+                hf_dicts[-1]["tool_calls"] = tool_call_dicts
+
+            if m.role == MessageRole.TOOL:
+                hf_dicts[-1]["name"] = m.additional_kwargs.get("tool_call_id")
+
+        return hf_dicts
+
+    def _parse_streaming_tool_calls(
+        self, tool_call_strs: List[str]
+    ) -> List[ToolSelection | str]:
+        tool_calls = []
+        # Try to parse into complete objects, otherwise keep as strings
+        for tool_call_str in tool_call_strs:
+            try:
+                tool_call_dict = json.loads(tool_call_str)
+                args = tool_call_dict["function"]
+                name = args.pop("_name")
+                tool_calls.append(
+                    ChatCompletionOutputToolCall(
+                        id=name,
+                        type="function",
+                        function=ChatCompletionOutputFunctionDefinition(
+                            arguments=args,
+                            name=name,
+                        ),
+                    )
+                )
+            except Exception as e:
+                tool_calls.append(tool_call_str)
+
+        return tool_calls
 
     def get_model_info(self, **kwargs: Any) -> "ModelInfo":
         """Get metadata on the current model from Hugging Face."""
```
```diff
@@ -260,6 +311,8 @@ def stream_chat(
 
         def gen() -> ChatResponseGen:
             response = ""
+            tool_call_strs = []
+            cur_index = -1
             for chunk in self._sync_client.chat_completion(
                 messages=self._to_huggingface_messages(messages),
                 stream=True,
```
```diff
@@ -269,9 +322,18 @@ def gen() -> ChatResponseGen:
 
                 delta = chunk.choices[0].delta.content or ""
                 response += delta
-                tool_calls = chunk.choices[0].delta.tool_calls or []
+                tool_call_delta = chunk.choices[0].delta.tool_calls
+                if tool_call_delta:
+                    if tool_call_delta.index != cur_index:
+                        cur_index = tool_call_delta.index
+                        tool_call_strs.append(tool_call_delta.function.arguments)
+                    else:
+                        tool_call_strs[
+                            cur_index
+                        ] += tool_call_delta.function.arguments
+
+                tool_calls = self._parse_streaming_tool_calls(tool_call_strs)
                 additional_kwargs = {"tool_calls": tool_calls} if tool_calls else {}
-
                 yield ChatResponse(
                     message=ChatMessage(
                         role=MessageRole.ASSISTANT,
```
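The delta handling above buffers argument fragments by tool-call index: a new index opens a new buffer, a repeated index appends to the current one, and all buffers are re-parsed on every chunk. The same accumulation pattern in isolation, with hypothetical deltas:

```python
# Standalone sketch of the index-based accumulation in gen() above.
deltas = [  # (index, argument fragment) pairs, invented for illustration
    (0, '{"function": {"_name": "add", '),
    (0, '"x": 1, "y": 2}}'),
    (1, '{"function": {"_name": "mul", "x": 3'),
]
tool_call_strs = []
cur_index = -1
for index, fragment in deltas:
    if index != cur_index:
        cur_index = index
        tool_call_strs.append(fragment)        # start a new tool-call buffer
    else:
        tool_call_strs[cur_index] += fragment  # extend the current buffer
# tool_call_strs[0] is now complete JSON; tool_call_strs[1] is still partial.
```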
```diff
@@ -359,16 +421,32 @@ async def astream_chat(
 
         async def gen() -> ChatResponseAsyncGen:
             response = ""
+            tool_call_strs = []
+            cur_index = -1
             async for chunk in await self._async_client.chat_completion(
                 messages=self._to_huggingface_messages(messages),
                 stream=True,
                 **model_kwargs,
             ):
+                if chunk.choices[0].finish_reason is not None:
+                    break
+
                 chunk: ChatCompletionStreamOutput = chunk
 
                 delta = chunk.choices[0].delta.content or ""
                 response += delta
-                tool_calls = chunk.choices[0].delta.tool_calls or []
+                tool_call_delta = chunk.choices[0].delta.tool_calls
+                if tool_call_delta:
+                    if tool_call_delta.index != cur_index:
+                        cur_index = tool_call_delta.index
+                        tool_call_strs.append(tool_call_delta.function.arguments)
+                    else:
+                        tool_call_strs[
+                            cur_index
+                        ] += tool_call_delta.function.arguments
+
+                tool_calls = self._parse_streaming_tool_calls(tool_call_strs)
+
                 additional_kwargs = {"tool_calls": tool_calls} if tool_calls else {}
 
                 yield ChatResponse(
```
```diff
@@ -474,6 +552,10 @@ def get_tool_calls_from_response(
 
         tool_selections = []
         for tool_call in tool_calls:
+            # while streaming, tool_call is a string
+            if isinstance(tool_call, str):
+                continue
+
             tool_selections.append(
                 ToolSelection(
                     tool_id=tool_call.id,
```
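Because a streamed response can carry both parsed tool-call objects and still-partial strings in `additional_kwargs["tool_calls"]`, consumers of intermediate chunks see the same mix that `get_tool_calls_from_response` filters above. A hedged consumption sketch, reusing the hypothetical `llm` from the earlier example:

```python
from llama_index.core.llms import ChatMessage

messages = [ChatMessage(role="user", content="What is 2 x 3?")]
for chunk in llm.stream_chat(messages):  # llm from the construction sketch above
    tool_calls = chunk.message.additional_kwargs.get("tool_calls", [])
    # Only fully parsed tool calls have structured fields to read.
    ready = [tc for tc in tool_calls if not isinstance(tc, str)]
    for tc in ready:
        print(tc.id)  # a parsed ChatCompletionOutputToolCall
```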
