service.py
from __future__ import annotations
import base64, io, logging, traceback, typing, argparse, asyncio, os
import bentoml, fastapi, PIL.Image, typing_extensions, annotated_types
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
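# Engine configuration passed to vLLM; MAX_TOKENS caps the per-request completion length.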
ENGINE_CONFIG = {
    'model': 'deepseek-ai/DeepSeek-R1-Distill-Llama-8B',
    'max_model_len': 4096,
    'enable_prefix_caching': True,
}
MAX_TOKENS = 4096
openai_api_app = fastapi.FastAPI()
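# BentoML service definition: runs on a single NVIDIA L4 GPU, expects an HF_TOKEN env var,
# and mounts the FastAPI app above under /v1 to expose OpenAI-compatible routes.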
@bentoml.asgi_app(openai_api_app, path='/v1')
@bentoml.service(
    name='bentovllm-r1-llama3.1-8b-tool-calling-service',
    traffic={'timeout': 300},
    resources={'gpu': 1, 'gpu_type': 'nvidia-l4'},
    envs=[{'name': 'HF_TOKEN'}],
    labels={'owner': 'bentoml-team', 'type': 'prebuilt'},
    image=bentoml.images.PythonImage(python_version='3.11').requirements_file('requirements.txt'),
)
class VLLM:
    model_id = ENGINE_CONFIG['model']
    # Pull the weights from the Hugging Face Hub, skipping the original torch checkpoints.
    model = bentoml.models.HuggingFaceModel(model_id, exclude=['*.pth', '*.pt'])
    def __init__(self) -> None:
        from vllm import AsyncEngineArgs, AsyncLLMEngine
        import vllm.entrypoints.openai.api_server as vllm_api_server

        # Mount vLLM's OpenAI-compatible routes onto the FastAPI app served under /v1.
        OPENAI_ENDPOINTS = [
            ['/chat/completions', vllm_api_server.create_chat_completion, ['POST']],
            ['/models', vllm_api_server.show_available_models, ['GET']],
        ]
        for route, endpoint, methods in OPENAI_ENDPOINTS:
            openai_api_app.add_api_route(path=route, endpoint=endpoint, methods=methods, include_in_schema=True)

        ENGINE_ARGS = AsyncEngineArgs(**dict(ENGINE_CONFIG, model=self.model))
        engine = AsyncLLMEngine.from_engine_args(ENGINE_ARGS)
        model_config = engine.engine.get_model_config()

        # Arguments consumed by vllm.entrypoints.openai.api_server.init_app_state.
        args = argparse.Namespace()
        args.model = self.model
        args.disable_log_requests = True
        args.max_log_len = 1000
        args.response_role = 'assistant'
        args.served_model_name = [self.model_id]
        args.chat_template = None
        args.chat_template_content_format = 'auto'
        args.lora_modules = None
        args.prompt_adapters = None
        args.request_logger = None
        args.disable_log_stats = True
        args.return_tokens_as_token_ids = False
        args.enable_prompt_tokens_details = False
        args.enable_reasoning = False
        # Tool calling and reasoning parsing for DeepSeek-R1-Distill-Llama.
        args.enable_auto_tool_choice = True
        args.enable_tool_call_parser = True
        args.tool_call_parser = 'llama3_json'
        args.reasoning_parser = 'deepseek_r1'

        asyncio.create_task(vllm_api_server.init_app_state(engine, model_config, openai_api_app.state, args))
    @bentoml.api
    async def generate(
        self,
        prompt: str = 'Who are you? Please respond in pirate speak!',
        max_tokens: typing_extensions.Annotated[
            int, annotated_types.Ge(128), annotated_types.Le(MAX_TOKENS)
        ] = MAX_TOKENS,
    ) -> typing.AsyncGenerator[str, None]:
        from openai import AsyncOpenAI

        # Proxy through the local OpenAI-compatible endpoint mounted at /v1 and stream the reply.
        client = AsyncOpenAI(base_url='http://127.0.0.1:3000/v1', api_key='dummy')
        try:
            completion = await client.chat.completions.create(
                model=self.model_id,
                messages=[dict(role='user', content=[dict(type='text', text=prompt)])],
                stream=True,
                max_tokens=max_tokens,
            )
            async for chunk in completion:
                yield chunk.choices[0].delta.content or ''
        except Exception:
            logger.error(traceback.format_exc())
            yield 'Internal error; check server logs for more information.'
            return
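

# A minimal client sketch, assuming the service is running locally on port 3000
# (the same address the generate() API above targets), e.g. started with `bentoml serve`:
#
#   from openai import OpenAI
#
#   client = OpenAI(base_url='http://127.0.0.1:3000/v1', api_key='dummy')
#   completion = client.chat.completions.create(
#       model='deepseek-ai/DeepSeek-R1-Distill-Llama-8B',
#       messages=[{'role': 'user', 'content': 'Write a haiku about GPUs.'}],
#       max_tokens=256,
#   )
#   print(completion.choices[0].message.content)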