forked from OpenBMB/MiniCPM
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path function_calling.py
More file actions
120 lines (114 loc) · 3.78 KB
/
function_calling.py
File metadata and controls
120 lines (114 loc) · 3.78 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
#!/usr/bin/env python
# encoding: utf-8
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams
import json
# Hugging Face hub id (or local path) passed to both the tokenizer and vLLM.
model_path = "openbmb/MiniCPM3-4B"

# Tools offered to the model, in the OpenAI-style function-calling schema that
# is handed to `tokenizer.apply_chat_template(..., tools=tools)`.
tools = [
    {
        "type": "function",
        "function": {
            "name": "get_delivery_date",
            "description": (
                "Get the delivery date for a customer's order. "
                "Call this whenever you need to know the delivery date, "
                "for example when a customer asks 'Where is my package'"
            ),
            "parameters": {
                "type": "object",
                "properties": {
                    "order_id": {
                        "type": "string",
                        "description": "The customer's order ID.",
                    },
                },
                "required": ["order_id"],
                "additionalProperties": False,
            },
        },
    }
]

# Seed conversation: a system instruction plus one user request. The agent
# loop appends assistant and tool turns to this list in place.
messages = [
    {
        "role": "system",
        "content": "You are a helpful customer support assistant. Use the supplied tools to assist the user.",
    },
    {
        "role": "user",
        "content": "Hi, can you tell me the delivery date for my order? The order id is 1234 and 4321.",
    },
    # Illustrative shape of the turns added by one tool round (example values):
    #   {"role": "assistant", "content": "", "tool_calls": [
    #       {"type": "function", "id": "call_b4ab0b4ec4b5442e86f017fe0385e22e",
    #        "function": {"name": "get_delivery_date",
    #                     "arguments": {"order_id": "1234"}}},
    #       {"type": "function", "id": "call_628965479dd84794bbb72ab9bdda0c39",
    #        "function": {"name": "get_delivery_date",
    #                     "arguments": {"order_id": "4321"}}}]},
    #   {"role": "tool", "tool_call_id": "call_b4ab0b4ec4b5442e86f017fe0385e22e",
    #    "content": '{"delivery_date": "2024-09-05", "order_id": "1234"}'},
    #   {"role": "tool", "tool_call_id": "call_628965479dd84794bbb72ab9bdda0c39",
    #    "content": '{"delivery_date": "2024-09-05", "order_id": "4321"}'},
    #   {"role": "assistant",
    #    "content": "Both your orders will be delivered on 2024-09-05.",
    #    "thought": "\nI have the information you need, both orders will be "
    #               "delivered on the same date, 2024-09-05.\n"},
]
# Load the tokenizer and the vLLM engine for the same checkpoint.
# trust_remote_code=True lets the checkpoint supply its own Python code. The
# loop relies on `tokenizer.decode_function_call`, which is not a standard
# transformers tokenizer method and is presumably provided by that remote code.
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
llm = LLM(model_path, trust_remote_code=True)

# Sampling settings reused for every generation round of the agent loop.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=1000)

# NOTE: the chat prompt is deliberately not pre-rendered here. It depends on
# `messages`, which changes after every tool round, so the loop renders it
# fresh on each iteration.
def fake_tool_execute(toolcall):
    """Simulate the ``get_delivery_date`` tool for one decoded tool call.

    Args:
        toolcall: One entry of an assistant message's ``tool_calls`` list,
            e.g. ``{"type": "function", "id": ..., "function": {"name": ...,
            "arguments": {...}}}``. ``arguments`` may be a dict or a
            JSON-encoded string (the OpenAI wire format). Missing, ``None``
            or malformed pieces are tolerated rather than raising.

    Returns:
        A JSON string ``{"delivery_date": "2024-09-05", "order_id": ...}``,
        used as the ``content`` of a ``"tool"`` role message. The order id
        falls back to the literal placeholder ``"order_id"`` when it cannot
        be found in the call's arguments.
    """
    # `or {}` also covers keys that are present but set to None.
    function = toolcall.get("function") or {}
    arguments = function.get("arguments") or {}
    if isinstance(arguments, str):
        # OpenAI-style tool calls carry their arguments as a JSON string.
        try:
            arguments = json.loads(arguments)
        except json.JSONDecodeError:
            arguments = {}  # best-effort fake tool: fall back to the placeholder
    if not isinstance(arguments, dict):
        arguments = {}
    data = {
        "delivery_date": "2024-09-05",
        "order_id": arguments.get("order_id", "order_id"),
    }
    return json.dumps(data)
# Agent loop: render the conversation, generate a reply and, if the model asked
# for tools, run them and feed the results back. Stop at the first reply that
# carries no tool calls.
MAX_TOOL_ROUNDS = 10  # safety cap so a model that keeps calling tools cannot loop forever

for _ in range(MAX_TOOL_ROUNDS):
    # Re-render every round because `messages` grows with each tool exchange.
    prompt = tokenizer.apply_chat_template(
        messages, tools=tools, tokenize=False, add_generation_prompt=True
    )
    outputs = llm.generate([prompt], sampling_params)
    response = outputs[0].outputs[0].text
    # Parse the raw completion into an assistant message; requested calls, if
    # any, appear under "tool_calls" (assumed to be a dict — verify against
    # the tokenizer's remote code).
    msg = tokenizer.decode_function_call(response)
    messages.append(msg)
    print(msg)

    tool_calls = msg.get("tool_calls")
    if not tool_calls:
        # Missing, None or empty tool_calls: this is the final answer.
        break

    for toolcall in tool_calls:
        tool_msg = {
            "role": "tool",
            "content": fake_tool_execute(toolcall),
            "tool_call_id": toolcall["id"],
        }
        messages.append(tool_msg)
        print(tool_msg)
else:
    # Only reached when the cap was hit without a `break`.
    print(f"Stopped after {MAX_TOOL_ROUNDS} rounds without a final (tool-free) answer.")