+ import json
import logging
from huggingface_hub import AsyncInferenceClient, InferenceClient, model_info
from huggingface_hub.hf_api import ModelInfo
from huggingface_hub.inference._generated.types import (
    ChatCompletionOutput,
    ChatCompletionStreamOutput,
    ChatCompletionOutputToolCall,
+   ChatCompletionOutputFunctionDefinition,
)
from typing import Any, Callable, Dict, List, Optional, Sequence, Union
@@ -136,20 +138,12 @@ def class_name(cls) -> str:
        gt=0.0,
    )
    is_chat_model: bool = Field(
-       default=False,
-       description=(
-           LLMMetadata.model_fields["is_chat_model"].description
-           + " Unless chat templating is intentionally applied, Hugging Face models"
-           " are not chat models."
-       ),
+       default=True,
+       description="Controls whether the chat or text generation methods are used.",
    )
    is_function_calling_model: bool = Field(
        default=False,
-       description=(
-           LLMMetadata.model_fields["is_function_calling_model"].description
-           + " As of 10/17/2023, Hugging Face doesn't support function calling"
-           " messages."
-       ),
+       description="Controls whether the function calling methods are used.",
    )

    def __init__(self, **kwargs: Any) -> None:
@@ -169,10 +163,21 @@ def __init__(self, **kwargs: Any) -> None:
        else:
            task = kwargs["task"].lower()

+       if kwargs.get("is_function_calling_model", False):
+           print(
+               "Function calling is currently not supported for Hugging Face Inference API, setting is_function_calling_model to False"
+           )
+           kwargs["is_function_calling_model"] = False
+
        super().__init__(**kwargs)  # Populate pydantic Fields
        self._sync_client = InferenceClient(**self._get_inference_client_kwargs())
        self._async_client = AsyncInferenceClient(**self._get_inference_client_kwargs())

+       # set context window if not provided
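+       # (endpoint info is only exposed by TGI/TEI-powered deployments,
+       # so this call can fail on other backends)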
+       info = self._sync_client.get_endpoint_info()
+       if "max_input_tokens" in info and kwargs.get("context_window") is None:
+           self.context_window = info["max_input_tokens"]
+

    def _get_inference_client_kwargs(self) -> Dict[str, Any]:
        """Extract the Hugging Face InferenceClient construction parameters."""
        return {
@@ -194,7 +199,53 @@ def _get_model_kwargs(self, **kwargs: Any) -> Dict[str, Any]:
    def _to_huggingface_messages(
        self, messages: Sequence[ChatMessage]
    ) -> List[Dict[str, Any]]:
-       return [{"role": m.role.value, "content": m.content} for m in messages]
+       hf_dicts = []
+       for m in messages:
+           hf_dicts.append(
+               {"role": m.role.value, "content": m.content if m.content else ""}
+           )
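+           # attach any tool calls recorded on the message, converted to
+           # Hugging Face's {"type": "function", "function": {...}} shape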
+           if m.additional_kwargs.get("tool_calls", []):
+               tool_call_dicts = []
+               for tool_call in m.additional_kwargs["tool_calls"]:
+                   function_dict = {
+                       "name": tool_call.id,
+                       "arguments": tool_call.function.arguments,
+                   }
+                   tool_call_dicts.append(
+                       {"type": "function", "function": function_dict}
+                   )
+
+               hf_dicts[-1]["tool_calls"] = tool_call_dicts
+
+           if m.role == MessageRole.TOOL:
+               hf_dicts[-1]["name"] = m.additional_kwargs.get("tool_call_id")
+
+       return hf_dicts
+
+   def _parse_streaming_tool_calls(
+       self, tool_call_strs: List[str]
+   ) -> List[Union[ChatCompletionOutputToolCall, str]]:
+       tool_calls = []
+       # Try to parse into complete objects, otherwise keep as strings
+       for tool_call_str in tool_call_strs:
+           try:
+               tool_call_dict = json.loads(tool_call_str)
+               args = tool_call_dict["function"]
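+               # TGI-style streams nest the tool name under "_name"
+               # inside the accumulated arguments object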
+               name = args.pop("_name")
+               tool_calls.append(
+                   ChatCompletionOutputToolCall(
+                       id=name,
+                       type="function",
+                       function=ChatCompletionOutputFunctionDefinition(
+                           arguments=args,
+                           name=name,
+                       ),
+                   )
+               )
+           except Exception:
+               # partial JSON while the stream is still arriving; keep the raw string
+               tool_calls.append(tool_call_str)
+
+       return tool_calls

    def get_model_info(self, **kwargs: Any) -> "ModelInfo":
        """Get metadata on the current model from Hugging Face."""
@@ -260,6 +311,8 @@ def stream_chat(

        def gen() -> ChatResponseGen:
            response = ""
+           tool_call_strs = []
+           cur_index = -1
            for chunk in self._sync_client.chat_completion(
                messages=self._to_huggingface_messages(messages),
                stream=True,
@@ -269,9 +322,18 @@ def gen() -> ChatResponseGen:

                delta = chunk.choices[0].delta.content or ""
                response += delta
-               tool_calls = chunk.choices[0].delta.tool_calls or []
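+               # accumulate argument fragments per tool call: a new index
+               # starts a new string, otherwise extend the current one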
+               tool_call_delta = chunk.choices[0].delta.tool_calls
+               if tool_call_delta:
+                   if tool_call_delta.index != cur_index:
+                       cur_index = tool_call_delta.index
+                       tool_call_strs.append(tool_call_delta.function.arguments)
+                   else:
+                       tool_call_strs[
+                           cur_index
+                       ] += tool_call_delta.function.arguments
+
+               tool_calls = self._parse_streaming_tool_calls(tool_call_strs)
                additional_kwargs = {"tool_calls": tool_calls} if tool_calls else {}
-
                yield ChatResponse(
                    message=ChatMessage(
                        role=MessageRole.ASSISTANT,
@@ -359,16 +421,32 @@ async def astream_chat(

        async def gen() -> ChatResponseAsyncGen:
            response = ""
+           tool_call_strs = []
+           cur_index = -1
            async for chunk in await self._async_client.chat_completion(
                messages=self._to_huggingface_messages(messages),
                stream=True,
                **model_kwargs,
            ):
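+               # stop streaming once the server reports a finish reason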
+               if chunk.choices[0].finish_reason is not None:
+                   break
+
                chunk: ChatCompletionStreamOutput = chunk

                delta = chunk.choices[0].delta.content or ""
                response += delta
-               tool_calls = chunk.choices[0].delta.tool_calls or []
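+               # same per-index accumulation as in the synchronous generator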
+               tool_call_delta = chunk.choices[0].delta.tool_calls
+               if tool_call_delta:
+                   if tool_call_delta.index != cur_index:
+                       cur_index = tool_call_delta.index
+                       tool_call_strs.append(tool_call_delta.function.arguments)
+                   else:
+                       tool_call_strs[
+                           cur_index
+                       ] += tool_call_delta.function.arguments
+
+               tool_calls = self._parse_streaming_tool_calls(tool_call_strs)
+
                additional_kwargs = {"tool_calls": tool_calls} if tool_calls else {}

                yield ChatResponse(
@@ -474,6 +552,10 @@ def get_tool_calls_from_response(

        tool_selections = []
        for tool_call in tool_calls:
+           # while streaming, tool_call is a string
+           if isinstance(tool_call, str):
+               continue
+
            tool_selections.append(
                ToolSelection(
                    tool_id=tool_call.id,
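
Note on the streamed tool-call format: below is a minimal sketch of the payload shape `_parse_streaming_tool_calls` expects once a call's fragments have been fully accumulated, assuming a TGI-style backend that nests the tool name under "_name" in the arguments (the exact wire format depends on the deployed server):

    import json

    # hypothetical fully accumulated fragment for one streamed tool call
    tool_call_str = '{"function": {"_name": "get_weather", "location": "Paris"}}'

    tool_call_dict = json.loads(tool_call_str)
    args = tool_call_dict["function"]  # {"_name": "get_weather", "location": "Paris"}
    name = args.pop("_name")           # "get_weather"; the rest are the call arguments

Until a fragment's JSON is complete, json.loads raises and the raw string is kept as-is, which is why get_tool_calls_from_response skips str entries while streaming.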