
Commit 6d33b49

Authored by xuechendi, jiafuzha, and carsonwang

Agent tool support (#134)

* add test files for openai_tools_agent
* complete to add for tool
* Delete my_app directory to bring ci back
* Add http based test for agent tool
* Update llm_on_ray/inference/api_openai_backend/router_app.py
* remove ref app
* update UT

Signed-off-by: Xue, Chendi <[email protected]>
Signed-off-by: jiafu zhang <[email protected]>
Signed-off-by: Chendi.Xue <[email protected]>
Co-authored-by: jiafu zhang <[email protected]>
Co-authored-by: Carson Wang <[email protected]>

1 parent aa2d08e commit 6d33b49

File tree

14 files changed (+795, -41 lines)

.github/workflows/workflow_inference.yml

Lines changed: 8 additions & 0 deletions
@@ -189,6 +189,14 @@ jobs:
           docker exec "${TARGET}" bash -c "python examples/inference/api_server_openai/query_http_requests.py --model_name ${{ matrix.model }}"
         fi
 
+    - name: Run Agent tool Inference Test with REST API
+      run: |
+        TARGET=${{steps.target.outputs.target}}
+        if [[ ${{ matrix.model }} == "llama-2-7b-chat-hf" ]]; then
+          docker exec "${TARGET}" bash -c "llm_on_ray-serve --models ${{ matrix.model }}"
+          docker exec "${TARGET}" bash -c "python examples/inference/api_server_openai/query_http_requests_tool.py --model_name ${{ matrix.model }}"
+        fi
+
     - name: Stop Ray
       run: |
         TARGET=${{steps.target.outputs.target}}

MANIFEST.in

Lines changed: 1 addition & 0 deletions
@@ -1,2 +1,3 @@
 # with [tools.setuptools] in pyproject.toml, the configs below work in both baremetal and container
 include inference/**/*.yaml
+include inference/**/*.jinja
Lines changed: 129 additions & 0 deletions

#
# Copyright 2023 The LLM-on-Ray Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import argparse
import os
from typing import Optional, Type

from pydantic import BaseModel, Field
from langchain_openai import ChatOpenAI
from langchain.callbacks import StreamingStdOutCallbackHandler, StdOutCallbackHandler
from langchain.agents import AgentExecutor, create_openai_tools_agent
from langchain.tools import BaseTool
from langchain import hub

parser = argparse.ArgumentParser(
    description="Example script to enable a LangChain agent", add_help=True
)
parser.add_argument(
    "--model_name",
    default="mistral-7b-instruct-v0.2",
    type=str,
    help="The name of the model to request",
)
parser.add_argument(
    "--streaming_response",
    default=False,
    action="store_true",
    help="Whether to enable streaming response",
)
parser.add_argument(
    "--prompt_template",
    default="hwchase17/openai-tools-agent",
    type=str,
    help="Prompt template for the OpenAI tools agent",
)
parser.add_argument(
    "--max_tokens",
    default=512,
    type=int,
    help="Max number of tokens to generate in this example",
)

args = parser.parse_args()

if "OPENAI_API_KEY" in os.environ:
    openai_api_key = os.environ["OPENAI_API_KEY"]
else:
    openai_api_key = "not_needed"

if "OPENAI_BASE_URL" in os.environ:
    openai_base_url = os.environ["OPENAI_BASE_URL"]
elif openai_api_key == "not_needed":
    openai_base_url = "http://localhost:8000/v1"
else:
    openai_base_url = "https://api.openai.com/v1"

# ================================================ #
# Let's define a function/tool for getting the weather. In this demo we mock the output.
# In real life, you would call a library/API such as PyOWM (OpenWeatherMap).
# Depending on your app's functionality, you may also call vendor/external or internal custom APIs.


def get_current_weather(location, unit):
    # Call an external API to get relevant information (like serpapi, etc.).
    # Here, for the demo, we return a mock response.
    weather_info = {
        "location": location,
        "temperature": "78",
        "unit": unit,
        "forecast": ["sunny", "with a chance of rain"],
    }
    return weather_info


class GetCurrentWeatherCheckInput(BaseModel):
    # Input schema for get_current_weather
    location: str = Field(
        ..., description="The name of the location for which we need to find the weather"
    )
    unit: str = Field(..., description="The unit for the temperature value")


class GetCurrentWeatherTool(BaseTool):
    name = "get_current_weather"
    description = "Used to find the weather for a given location in said unit"
    args_schema: Optional[Type[BaseModel]] = GetCurrentWeatherCheckInput

    def _run(self, location: str, unit: str):
        return get_current_weather(location, unit)

    def _arun(self, location: str, unit: str):
        raise NotImplementedError("This tool does not support async")


# ================================================ #

tools = [GetCurrentWeatherTool()]
prompt = hub.pull(args.prompt_template)
llm = ChatOpenAI(
    openai_api_base=openai_base_url,
    model_name=args.model_name,
    openai_api_key=openai_api_key,
    max_tokens=args.max_tokens,
    callbacks=[
        StreamingStdOutCallbackHandler() if args.streaming_response else StdOutCallbackHandler()
    ],
    streaming=args.streaming_response,
)
agent = create_openai_tools_agent(tools=tools, llm=llm, prompt=prompt)
agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)
agent_executor.invoke({"input": "what is the weather today in Boston?"})
agent_executor.invoke({"input": "tell me a short joke?"})
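
For comparison, the same mock weather tool can be declared more compactly with LangChain's @tool decorator, which derives the tool name, description, and argument schema from the function signature and docstring. This is a minimal sketch, not part of the commit; it assumes langchain-core is installed and would stand in for the BaseTool subclass above.

from langchain_core.tools import tool


@tool
def get_current_weather(location: str, unit: str) -> dict:
    """Find the weather for a given location in the given unit."""
    # Mock response, mirroring the example above.
    return {
        "location": location,
        "temperature": "78",
        "unit": unit,
        "forecast": ["sunny", "with a chance of rain"],
    }


tools = [get_current_weather]

Either form yields an OpenAI-style function schema for create_openai_tools_agent; the subclass form used in the commit is more verbose but gives explicit control over args_schema.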

examples/inference/api_server_langchain/query_langchain_sdk.py

Lines changed: 3 additions & 0 deletions
@@ -29,6 +29,8 @@
     action="store_true",
     help="Whether to enable streaming response",
 )
+parser.add_argument("--max_tokens", default=256, type=int, help="The maximum number of tokens to generate")
+
 
 args = parser.parse_args()
 
@@ -52,6 +54,7 @@
     model_name=args.model_name,
     openai_api_key=openai_api_key,
     streaming=args.streaming_response,
+    max_tokens=args.max_tokens,
 )
 
 prompt = PromptTemplate(template="list 3 {things}", input_variables=["things"])
Lines changed: 108 additions & 0 deletions

#
# Copyright 2023 The LLM-on-Ray Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import argparse
import os

from openai import OpenAI

parser = argparse.ArgumentParser(
    description="Example script to query with the OpenAI SDK", add_help=True
)
parser.add_argument(
    "--model_name",
    default="mistral-7b-instruct-v0.2",
    type=str,
    help="The name of the model to request",
)
parser.add_argument(
    "--streaming_response",
    default=False,
    action="store_true",
    help="Whether to enable streaming response",
)
parser.add_argument(
    "--max_new_tokens", default=512, type=int, help="The maximum number of tokens to generate"
)
args = parser.parse_args()

if "OPENAI_API_KEY" in os.environ:
    openai_api_key = os.environ["OPENAI_API_KEY"]
else:
    openai_api_key = "not_needed"

if "OPENAI_BASE_URL" in os.environ:
    openai_base_url = os.environ["OPENAI_BASE_URL"]
elif openai_api_key == "not_needed":
    openai_base_url = "http://localhost:8000/v1"
else:
    openai_base_url = "https://api.openai.com/v1"


client = OpenAI(base_url=openai_base_url, api_key=openai_api_key)

tools = [
    {
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Get the current weather in a given location",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The city and state, e.g. San Francisco, CA",
                    },
                    "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
                },
                "required": ["location"],
            },
        },
    }
]
messages = [
    [
        {"role": "system", "content": "You are a helpful assistant"},
        {"role": "user", "content": "What's the weather like in Boston today?"},
    ],
    [
        {"role": "system", "content": "You are a helpful assistant"},
        {"role": "user", "content": "Tell me a short joke?"},
    ],
]
for message in messages:
    print(f"User: {message[1]['content']}")
    print("Assistant:", end=" ", flush=True)
    chat_completion = client.chat.completions.create(
        model=args.model_name,
        messages=message,
        max_tokens=args.max_new_tokens,
        tools=tools,
        tool_choice="auto",
        stream=args.streaming_response,
    )

    if args.streaming_response:
        # Print plain content deltas or tool-call deltas as they arrive.
        for chunk in chat_completion:
            content = chunk.choices[0].delta.content
            if content is not None:
                print(content, end="", flush=True)
            tool_calls = chunk.choices[0].delta.tool_calls
            if tool_calls is not None:
                print(tool_calls, end="", flush=True)
        print("")
    else:
        print(repr(chat_completion.choices[0].message.model_dump()))
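
The script above prints the model's tool call but never executes it. The sketch below shows one way to complete the round trip with the same client: run the requested function locally, append the result as a tool message, and ask the model for a final answer. It is illustrative only; run_tool_round_trip and the inline mock are hypothetical names, and it assumes the server returns OpenAI-style tool_calls in non-streaming responses.

import json


def get_current_weather(location, unit="fahrenheit"):
    # Same mock response as the agent example in this commit.
    return {"location": location, "temperature": "78", "unit": unit}


def run_tool_round_trip(client, model_name, message, tools):
    # First request: the model decides whether to call the tool.
    first = client.chat.completions.create(
        model=model_name, messages=message, tools=tools, tool_choice="auto"
    )
    reply = first.choices[0].message
    if not reply.tool_calls:
        return reply.content
    # Execute each requested call locally and append the results as tool messages.
    followup = message + [reply.model_dump(exclude_none=True)]
    for call in reply.tool_calls:
        fn_args = json.loads(call.function.arguments)
        result = get_current_weather(fn_args["location"], fn_args.get("unit", "fahrenheit"))
        followup.append({"role": "tool", "tool_call_id": call.id, "content": json.dumps(result)})
    # Second request: the model turns the tool output into a user-facing answer.
    second = client.chat.completions.create(model=model_name, messages=followup)
    return second.choices[0].message.content

For example, run_tool_round_trip(client, args.model_name, messages[0], tools) would answer the Boston weather question end to end.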
