170 changes: 162 additions & 8 deletions scripts/launch_server.py
@@ -2,6 +2,7 @@
from libinfinicore_infer import DeviceType
from infer_task import InferTask
from kvcache_pool import KVCachePool
import torch

import argparse
import queue
@@ -160,17 +161,27 @@ def worker_loop(app):


def build_task(id_, request_data, request: Request):
messages = request_data.get("messages", [])
input_content = request.app.state.model.tokenizer.apply_chat_template(
conversation=messages,
add_generation_prompt=True,
tokenize=False,
)
tokens = request.app.state.model.tokenizer.encode(input_content)
# Handle both chat and completion formats
if "messages" in request_data:
# Chat format
messages = request_data.get("messages", [])
input_content = request.app.state.model.tokenizer.apply_chat_template(
conversation=messages,
add_generation_prompt=True,
tokenize=False,
)
tokens = request.app.state.model.tokenizer.encode(input_content)
max_tokens = request_data.get("max_tokens", request.app.state.model.max_context_len())
else:
# Completion format
prompt = request_data.get("prompt", "")
tokens = request.app.state.model.tokenizer.encode(prompt)
max_tokens = request_data.get("max_tokens", 0)

return AsyncInferTask(
id_,
tokens,
request_data.get("max_tokens", request.app.state.model.max_context_len()),
max_tokens,
request_data.get("temperature", 1.0),
request_data.get("top_k", 1),
request_data.get("top_p", 1.0),
@@ -277,6 +288,149 @@ async def chat_completions(request: Request):
response = await chat(id_, data, request)
return JSONResponse(content=response)


async def completion(id_, request_data, request: Request):
infer_task = None # Initialize to None to avoid UnboundLocalError
try:
# Check if max_tokens > 0 is requested
max_tokens = request_data.get("max_tokens", 0)
if max_tokens > 0:
return JSONResponse(
content={"error": "max_tokens > 0 is not supported yet. Please use max_tokens=0 for logprobs calculation."},
status_code=400
)

infer_task = build_task(id_, request_data, request)
await request.app.state.kv_cache_pool.acquire(infer_task)

output = []
logprobs = []

# Handle echo and logprobs calculation
echo = request_data.get("echo", False)
if echo:
# Decode the prompt tokens back into text pieces for the echoed output
input_tokens = infer_task.tokens
for token in input_tokens:
content = (
request.app.state.model.tokenizer._tokenizer.id_to_token(token)
.replace("▁", " ")
.replace("<0x0A>", "\n")
)
output.append(content)

# Run one forward pass over the prompt to get per-position logits
from jiuge import JiugeBatchedTask
batch_inputs = JiugeBatchedTask([infer_task])
logits = torch.zeros(
(batch_inputs.ntok, request.app.state.model.meta.dvoc),
dtype=request.app.state.model.meta.torch_dtype_logits
)
from libinfinicore_infer import forward_batch
forward_batch(
request.app.state.model.model_instance,
batch_inputs.tokens,
batch_inputs.ntok,
batch_inputs.req_lens,
batch_inputs.nreq,
batch_inputs.req_pos,
batch_inputs.kv_caches,
logits.data_ptr(),
)

# Convert logits to log-probabilities over the vocabulary
logits = logits.float()
log_probs = torch.nn.functional.log_softmax(logits, dim=-1)

# Collect the logprob of each prompt token given its preceding context
token_logprobs = []
for i in range(len(infer_task.tokens) - 1): # Only up to second-to-last token
next_token = infer_task.tokens[i+1] # Next token to predict
logprob = log_probs[i, next_token].item() # Use position i logits to predict position i+1 token
token_logprobs.append(logprob)

# First token has no context, so logprob is None
logprobs = [None] + token_logprobs
else:
# echo=False: skip logprob calculation since the prompt text is not returned
logprobs = []

# For max_tokens=0, release the KV cache manually since this request never reaches the worker loop
await request.app.state.kv_cache_pool.release(infer_task)
print(f"[DEBUG] {id_} Released KV cache for max_tokens=0")

output_text = "".join(output).strip()

# Build the tokens list and per-token text offsets for the logprobs payload
tokens_list = []
text_offset_list = []
current_offset = 0
for content in output:
tokens_list.append(content)
text_offset_list.append(current_offset)
current_offset += len(content)

# Build response according to DeepSeek API completion format
response = {
"id": id_,
"object": "text_completion",
"created": int(time.time()),
"model": "jiuge",
"choices": [
{
"text": output_text,
"index": 0,
"logprobs": {
"token_logprobs": logprobs,
"tokens": tokens_list,
"text_offset": text_offset_list,
"top_logprobs": []
},
"finish_reason": "stop"
}
],
"usage": {
"prompt_tokens": len(infer_task.tokens),
"prompt_cache_hit_tokens": 0,
"prompt_cache_miss_tokens": len(infer_task.tokens),
"completion_tokens": 0,
"total_tokens": len(infer_task.tokens),
"completion_tokens_details": {
"reasoning_tokens": 0
}
}
}
return response

except Exception as e:
print(f"[Error] ID: {id_} Exception: {e}")
return JSONResponse(content={"error": str(e)}, status_code=500)
finally:
if infer_task and infer_task.finish_reason is None:
infer_task.finish_reason = "cancel"


@App.post("/completions")
async def completions(request: Request):
data = await request.json()

if not data.get("prompt"):
return JSONResponse(content={"error": "No prompt provided"}, status_code=400)

id_ = f"cmpl-{uuid.uuid4().hex}"
response = await completion(id_, data, request)

# Check if response is already a JSONResponse (error case)
if isinstance(response, JSONResponse):
return response
else:
return JSONResponse(content=response)

if __name__ == "__main__":
uvicorn.run(App, host="0.0.0.0", port=8000)
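
For reference, a minimal client-side sketch of the new endpoint, assuming the server is running locally on the port configured above and that the requests package is installed; the field names mirror the request_data.get(...) calls in completion():

import requests

# Query prompt logprobs from the new /completions endpoint.
# max_tokens must be 0 (anything greater is rejected with HTTP 400 above);
# echo=True returns the prompt tokens together with their logprobs.
payload = {
    "prompt": "Hello, world!",
    "max_tokens": 0,
    "echo": True,
    "temperature": 1.0,
}
resp = requests.post("http://localhost:8000/completions", json=payload)
choice = resp.json()["choices"][0]
print(choice["text"])
print(choice["logprobs"]["token_logprobs"])  # first entry is None: token 0 has no preceding context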
