From 4c1528e67c441264ee4c78b68892c8ce7cfd0481 Mon Sep 17 00:00:00 2001
From: zhuyue
Date: Mon, 18 Aug 2025 17:03:42 +0800
Subject: [PATCH 1/3] add completions endpoint

---
 scripts/launch_server.py | 169 +++++++++++++++++++++++++++++++++++++++
 1 file changed, 169 insertions(+)

diff --git a/scripts/launch_server.py b/scripts/launch_server.py
index a315b4e6..5399da4a 100644
--- a/scripts/launch_server.py
+++ b/scripts/launch_server.py
@@ -3,6 +3,7 @@
 from libinfinicore_infer import DeviceType
 from infer_task import InferTask
 from kvcache_pool import KVCachePool
+import torch

 import argparse
 import queue
@@ -294,6 +295,174 @@ async def chat_completions(request: Request):
     return JSONResponse(content=response)


+def build_completion_task(id_, request_data, request: Request):
+    prompt = request_data.get("prompt", "")
+    tokens = request.app.state.model.tokenizer.encode(prompt)
+    return AsyncInferTask(
+        id_,
+        tokens,
+        request_data.get("max_tokens", 0),
+        request_data.get("temperature", 1.0),
+        request_data.get("top_k", 1),
+        request_data.get("top_p", 1.0),
+        request.app.state.model.eos_token_id,
+    )
+
+
+async def completion(id_, request_data, request: Request):
+    try:
+        infer_task = build_completion_task(id_, request_data, request)
+        await request.app.state.kv_cache_pool.acquire(infer_task)
+
+        output = []
+        logprobs = []
+
+        # If echo is True, we need to return the input tokens as well
+        echo = request_data.get("echo", False)
+        if echo:
+            # Add input tokens to output
+            input_tokens = infer_task.tokens
+            for token in input_tokens:
+                content = (
+                    request.app.state.model.tokenizer._tokenizer.id_to_token(token)
+                    .replace("▁", " ")
+                    .replace("<0x0A>", "\n")
+                )
+                output.append(content)
+                # For input tokens, we don't have logprobs, so add None
+                logprobs.append(None)
+
+        # Handle different max_tokens scenarios
+        max_tokens = request_data.get("max_tokens", 0)
+        if max_tokens == 0:
+            # Only calculate logprobs if echo=true
+            if echo:
+                # Only get logprobs of input tokens, no generation
+                from jiuge import JiugeBatchedTask
+                batch_inputs = JiugeBatchedTask([infer_task])
+                logits = torch.zeros(
+                    (batch_inputs.ntok, request.app.state.model.meta.dvoc),
+                    dtype=request.app.state.model.meta.torch_dtype_logits
+                )
+                from libinfinicore_infer import forward_batch
+                forward_batch(
+                    request.app.state.model.model_instance,
+                    batch_inputs.tokens,
+                    batch_inputs.ntok,
+                    batch_inputs.req_lens,
+                    batch_inputs.nreq,
+                    batch_inputs.req_pos,
+                    batch_inputs.kv_caches,
+                    logits.data_ptr(),
+                )
+
+                # Calculate logprobs for input tokens
+                logits = logits.float()
+                log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
+
+                # Calculate correct logprobs for input tokens
+                token_logprobs = []
+                for i in range(len(infer_task.tokens) - 1):  # Only up to second-to-last token
+                    next_token = infer_task.tokens[i+1]  # Next token to predict
+                    logprob = log_probs[i, next_token].item()  # Use position i logits to predict position i+1 token
+                    token_logprobs.append(logprob)
+
+                # First token has no context, so logprob is None
+                logprobs = [None] + token_logprobs
+            else:
+                # echo=false: don't calculate logprobs since user can't see input text
+                logprobs = []
+
+            # For max_tokens=0, we need to manually release the KV cache since we don't go through worker
+            await request.app.state.kv_cache_pool.release(infer_task)
+            print(f"[DEBUG] {id_} Released KV cache for max_tokens=0")
+        else:
+            # Generate new tokens with logprobs
+            request.app.state.request_queue.sync_q.put(infer_task)
+
+            while True:
+                if (
+                    infer_task.finish_reason is not None
+                    and infer_task.output_queue.async_q.empty()
+                ):
+                    break
+
+                token = await infer_task.output_queue.async_q.get()
+                content = (
+                    request.app.state.model.tokenizer._tokenizer.id_to_token(token)
+                    .replace("▁", " ")
+                    .replace("<0x0A>", "\n")
+                )
+                output.append(content)
+
+                # For generated tokens, we need to get logprobs
+                # This is a simplified implementation - in practice you'd need to get logprobs during generation
+                logprobs.append(-1.0)
+
+        output_text = "".join(output).strip()
+
+        # Prepare tokens list for logprobs
+        tokens_list = []
+        text_offset_list = []
+        current_offset = 0
+
+        # Build tokens list and text offsets
+        for i, content in enumerate(output):
+            tokens_list.append(content)
+            text_offset_list.append(current_offset)
+            current_offset += len(content)
+
+        # Build response according to DeepSeek API completion format
+        response = {
+            "id": id_,
+            "object": "text_completion",
+            "created": int(time.time()),
+            "model": "jiuge",
+            "choices": [
+                {
+                    "text": output_text,
+                    "index": 0,
+                    "logprobs": {
+                        "token_logprobs": logprobs,
+                        "tokens": tokens_list,
+                        "text_offset": text_offset_list,
+                        "top_logprobs": []
+                    },
+                    "finish_reason": infer_task.finish_reason or "stop"
+                }
+            ],
+            "usage": {
+                "prompt_tokens": len(infer_task.tokens),
+                "prompt_cache_hit_tokens": 0,
+                "prompt_cache_miss_tokens": len(infer_task.tokens),
+                "completion_tokens": len(output) - len(infer_task.tokens) if echo else len(output),
+                "total_tokens": len(infer_task.tokens) + (len(output) - len(infer_task.tokens) if echo else len(output)),
+                "completion_tokens_details": {
+                    "reasoning_tokens": 0
+                }
+            }
+        }
+        return response
+
+    except Exception as e:
+        print(f"[Error] ID: {id_} Exception: {e}")
+        return JSONResponse(content={"error": str(e)}, status_code=500)
+    finally:
+        if infer_task.finish_reason is None:
+            infer_task.finish_reason = "cancel"
+
+
+@App.post("/completions")
+async def completions(request: Request):
+    data = await request.json()
+
+    if not data.get("prompt"):
+        return JSONResponse(content={"error": "No prompt provided"}, status_code=400)
+
+    id_ = f"cmpl-{uuid.uuid4().hex}"
+    response = await completion(id_, data, request)
+    return JSONResponse(content=response)
+
 if __name__ == "__main__":
     uvicorn.run(App, host="0.0.0.0", port=8000)
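The endpoint added in PATCH 1/3 follows the text-completion request/response shape, so it can be exercised with any HTTP client once the server is running on the host and port configured at the bottom of launch_server.py (0.0.0.0:8000). The snippet below is an illustrative client call, not part of the patch; the prompt text and sampling values are arbitrary, and max_tokens=0 with echo=true is the combination the later patches settle on for prompt-logprob queries.

    # Illustrative client for the /completions endpoint (not part of the patch).
    import requests

    payload = {
        "prompt": "The capital of France is",
        "max_tokens": 0,   # PATCH 2/3 restricts the endpoint to this value
        "echo": True,      # required so prompt-token logprobs are computed
        "temperature": 1.0,
    }
    resp = requests.post("http://localhost:8000/completions", json=payload)
    body = resp.json()
    print(body["choices"][0]["logprobs"]["token_logprobs"])
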
From b9f48f8f7320f2c5ac31fbbbe85d3c08d38d9837 Mon Sep 17 00:00:00 2001
From: zhuyue
Date: Wed, 20 Aug 2025 16:26:16 +0800
Subject: [PATCH 2/3] Add completions endpoint, only support max_tokens=0

---
 scripts/launch_server.py | 172 ++++++++++++++++++---------------------
 1 file changed, 78 insertions(+), 94 deletions(-)

diff --git a/scripts/launch_server.py b/scripts/launch_server.py
index 5399da4a..4f157332 100644
--- a/scripts/launch_server.py
+++ b/scripts/launch_server.py
@@ -177,17 +177,27 @@ def worker_loop(app):


 def build_task(id_, request_data, request: Request):
-    messages = request_data.get("messages", [])
-    input_content = request.app.state.model.tokenizer.apply_chat_template(
-        conversation=messages,
-        add_generation_prompt=True,
-        tokenize=False,
-    )
-    tokens = request.app.state.model.tokenizer.encode(input_content)
+    # Handle both chat and completion formats
+    if "messages" in request_data:
+        # Chat format
+        messages = request_data.get("messages", [])
+        input_content = request.app.state.model.tokenizer.apply_chat_template(
+            conversation=messages,
+            add_generation_prompt=True,
+            tokenize=False,
+        )
+        tokens = request.app.state.model.tokenizer.encode(input_content)
+        max_tokens = request_data.get("max_tokens", request.app.state.model.max_context_len())
+    else:
+        # Completion format
+        prompt = request_data.get("prompt", "")
+        tokens = request.app.state.model.tokenizer.encode(prompt)
+        max_tokens = request_data.get("max_tokens", 0)
+
     return AsyncInferTask(
         id_,
         tokens,
-        request_data.get("max_tokens", request.app.state.model.max_context_len()),
+        max_tokens,
         request_data.get("temperature", 1.0),
         request_data.get("top_k", 1),
         request_data.get("top_p", 1.0),
@@ -295,29 +305,27 @@ async def chat_completions(request: Request):
     return JSONResponse(content=response)


-def build_completion_task(id_, request_data, request: Request):
-    prompt = request_data.get("prompt", "")
-    tokens = request.app.state.model.tokenizer.encode(prompt)
-    return AsyncInferTask(
-        id_,
-        tokens,
-        request_data.get("max_tokens", 0),
-        request_data.get("temperature", 1.0),
-        request_data.get("top_k", 1),
-        request_data.get("top_p", 1.0),
-        request.app.state.model.eos_token_id,
-    )
+

 async def completion(id_, request_data, request: Request):
+    infer_task = None  # Initialize to None to avoid UnboundLocalError
     try:
-        infer_task = build_completion_task(id_, request_data, request)
+        # Check if max_tokens > 0 is requested
+        max_tokens = request_data.get("max_tokens", 0)
+        if max_tokens > 0:
+            return JSONResponse(
+                content={"error": "max_tokens > 0 is not supported yet. Please use max_tokens=0 for logprobs calculation."},
+                status_code=400
+            )
+
+        infer_task = build_task(id_, request_data, request)
         await request.app.state.kv_cache_pool.acquire(infer_task)

         output = []
         logprobs = []

-        # If echo is True, we need to return the input tokens as well
+        # Handle echo and logprobs calculation
         echo = request_data.get("echo", False)
         if echo:
             # Add input tokens to output
@@ -329,75 +337,46 @@ async def completion(id_, request_data, request: Request):
                     .replace("<0x0A>", "\n")
                 )
                 output.append(content)
-                # For input tokens, we don't have logprobs, so add None
-                logprobs.append(None)
-
-        # Handle different max_tokens scenarios
-        max_tokens = request_data.get("max_tokens", 0)
-        if max_tokens == 0:
-            # Only calculate logprobs if echo=true
-            if echo:
-                # Only get logprobs of input tokens, no generation
-                from jiuge import JiugeBatchedTask
-                batch_inputs = JiugeBatchedTask([infer_task])
-                logits = torch.zeros(
-                    (batch_inputs.ntok, request.app.state.model.meta.dvoc),
-                    dtype=request.app.state.model.meta.torch_dtype_logits
-                )
-                from libinfinicore_infer import forward_batch
-                forward_batch(
-                    request.app.state.model.model_instance,
-                    batch_inputs.tokens,
-                    batch_inputs.ntok,
-                    batch_inputs.req_lens,
-                    batch_inputs.nreq,
-                    batch_inputs.req_pos,
-                    batch_inputs.kv_caches,
-                    logits.data_ptr(),
-                )
-
-                # Calculate logprobs for input tokens
-                logits = logits.float()
-                log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
-
-                # Calculate correct logprobs for input tokens
-                token_logprobs = []
-                for i in range(len(infer_task.tokens) - 1):  # Only up to second-to-last token
-                    next_token = infer_task.tokens[i+1]  # Next token to predict
-                    logprob = log_probs[i, next_token].item()  # Use position i logits to predict position i+1 token
-                    token_logprobs.append(logprob)
-
-                # First token has no context, so logprob is None
-                logprobs = [None] + token_logprobs
-            else:
-                # echo=false: don't calculate logprobs since user can't see input text
-                logprobs = []
-
-            # For max_tokens=0, we need to manually release the KV cache since we don't go through worker
-            await request.app.state.kv_cache_pool.release(infer_task)
-            print(f"[DEBUG] {id_} Released KV cache for max_tokens=0")
-        else:
-            # Generate new tokens with logprobs
-            request.app.state.request_queue.sync_q.put(infer_task)
+            # Calculate logprobs for input tokens
+            from jiuge import JiugeBatchedTask
+            batch_inputs = JiugeBatchedTask([infer_task])
+            logits = torch.zeros(
+                (batch_inputs.ntok, request.app.state.model.meta.dvoc),
+                dtype=request.app.state.model.meta.torch_dtype_logits
+            )
+            from libinfinicore_infer import forward_batch
+            forward_batch(
+                request.app.state.model.model_instance,
+                batch_inputs.tokens,
+                batch_inputs.ntok,
+                batch_inputs.req_lens,
+                batch_inputs.nreq,
+                batch_inputs.req_pos,
+                batch_inputs.kv_caches,
+                logits.data_ptr(),
+            )

-            while True:
-                if (
-                    infer_task.finish_reason is not None
-                    and infer_task.output_queue.async_q.empty()
-                ):
-                    break
-
-                token = await infer_task.output_queue.async_q.get()
-                content = (
-                    request.app.state.model.tokenizer._tokenizer.id_to_token(token)
-                    .replace("▁", " ")
-                    .replace("<0x0A>", "\n")
-                )
-                output.append(content)
-
-                # For generated tokens, we need to get logprobs
-                # This is a simplified implementation - in practice you'd need to get logprobs during generation
-                logprobs.append(-1.0)
+            # Calculate logprobs for input tokens
+            logits = logits.float()
+            log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
+
+            # Calculate correct logprobs for input tokens
+            token_logprobs = []
+            for i in range(len(infer_task.tokens) - 1):  # Only up to second-to-last token
+                next_token = infer_task.tokens[i+1]  # Next token to predict
+                logprob = log_probs[i, next_token].item()  # Use position i logits to predict position i+1 token
+                token_logprobs.append(logprob)
+
+            # First token has no context, so logprob is None
+            logprobs = [None] + token_logprobs
+        else:
+            # echo=false: don't calculate logprobs since user can't see input text
+            logprobs = []
+
+        # For max_tokens=0, we need to manually release the KV cache since we don't go through worker
+        await request.app.state.kv_cache_pool.release(infer_task)
+        print(f"[DEBUG] {id_} Released KV cache for max_tokens=0")

         output_text = "".join(output).strip()
@@ -428,15 +407,15 @@ async def completion(id_, request_data, request: Request):
                         "text_offset": text_offset_list,
                         "top_logprobs": []
                     },
-                    "finish_reason": infer_task.finish_reason or "stop"
+                    "finish_reason": "stop"
                 }
             ],
             "usage": {
                 "prompt_tokens": len(infer_task.tokens),
                 "prompt_cache_hit_tokens": 0,
                 "prompt_cache_miss_tokens": len(infer_task.tokens),
-                "completion_tokens": len(output) - len(infer_task.tokens) if echo else len(output),
-                "total_tokens": len(infer_task.tokens) + (len(output) - len(infer_task.tokens) if echo else len(output)),
+                "completion_tokens": 0,
+                "total_tokens": len(infer_task.tokens),
                 "completion_tokens_details": {
                     "reasoning_tokens": 0
                 }
@@ -448,7 +427,7 @@ async def completion(id_, request_data, request: Request):
         print(f"[Error] ID: {id_} Exception: {e}")
         return JSONResponse(content={"error": str(e)}, status_code=500)
     finally:
-        if infer_task.finish_reason is None:
+        if infer_task and infer_task.finish_reason is None:
             infer_task.finish_reason = "cancel"


@@ -461,7 +440,12 @@ async def completions(request: Request):

     id_ = f"cmpl-{uuid.uuid4().hex}"
     response = await completion(id_, data, request)
-    return JSONResponse(content=response)
+
+    # Check if response is already a JSONResponse (error case)
+    if isinstance(response, JSONResponse):
+        return response
+    else:
+        return JSONResponse(content=response)

 if __name__ == "__main__":
     uvicorn.run(App, host="0.0.0.0", port=8000)
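The indexing convention used above for prompt logprobs is easy to get wrong, so it is worth restating in isolation: row i of the (ntok, vocab) log-probability matrix scores the prediction of token i+1, and the first token has no score. The sketch below restates just that alignment with a stand-in tensor; it is an illustration under those assumptions, not code from the patch.

    import torch

    def prompt_token_logprobs(log_probs: torch.Tensor, tokens: list) -> list:
        # Row i of log_probs predicts tokens[i + 1]; the first token gets None.
        scores = [None]
        for i in range(len(tokens) - 1):
            scores.append(log_probs[i, tokens[i + 1]].item())
        return scores

    # Tiny self-check with random logits shaped (ntok, vocab).
    tokens = [3, 7, 2, 9]
    log_probs = torch.log_softmax(torch.randn(len(tokens), 16), dim=-1)
    print(prompt_token_logprobs(log_probs, tokens))  # [None, x, y, z]
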
From 189d8d526d4757184c2b41cb0e8b345782858bab Mon Sep 17 00:00:00 2001
From: zhuyue
Date: Fri, 19 Sep 2025 15:03:11 +0800
Subject: [PATCH 3/3] Replace torch's log_softmax with InfiniCore's logSoftmax
 operator, consistent with jiuge_ppl's chunking method.

---
 scripts/jiuge.py                      |  8 ++++----
 scripts/launch_server.py              | 11 ++++-------
 scripts/test_ppl.py                   |  5 +++--
 src/cache_manager/opcache_manager.hpp |  2 ++
 src/models/inference_context.cpp      | 20 ++++++++++++++++++++
 src/models/inference_context.hpp      |  6 ++++++
 src/models/jiuge/jiuge.cpp            |  6 +++++-
 7 files changed, 44 insertions(+), 14 deletions(-)

diff --git a/scripts/jiuge.py b/scripts/jiuge.py
index 523820c9..e23f032f 100644
--- a/scripts/jiuge.py
+++ b/scripts/jiuge.py
@@ -616,7 +616,7 @@ def perplexity(self, test_sequences: List[Sequence[int]], batch_size=10):
             batch_id += 1

         batch_inputs = JiugeBatchedTask(tasks[:batch_id])
-        logits = torch.zeros(
+        log_probs = torch.zeros(
             (batch_inputs.ntok, self.meta.dvoc), dtype=self.meta.torch_dtype_logits
         )
         self.jiuge_model.forward_batch(
@@ -627,12 +627,12 @@ def perplexity(self, test_sequences: List[Sequence[int]], batch_size=10):
             batch_inputs.nreq,
             batch_inputs.req_pos,
             batch_inputs.kv_caches,
-            logits.data_ptr(),
+            log_probs.data_ptr(),
         )

-        logits = logits.float()
+        # forward_batch now returns log_softmax results, no need for additional calculation
+        log_probs = log_probs.float()
         token_ids = torch.tensor(true_tokens, dtype=torch.int64)  # [ntok,]
-        log_probs = torch.nn.functional.log_softmax(logits, dim=-1)  # (ntok, vocab)
         token_logprobs = log_probs[
             torch.arange(batch_inputs.ntok), token_ids
         ]  # (ntok,)
diff --git a/scripts/launch_server.py b/scripts/launch_server.py
index 4f157332..903a5396 100644
--- a/scripts/launch_server.py
+++ b/scripts/launch_server.py
@@ -341,12 +341,11 @@ async def completion(id_, request_data, request: Request):
             # Calculate logprobs for input tokens
             from jiuge import JiugeBatchedTask
             batch_inputs = JiugeBatchedTask([infer_task])
-            logits = torch.zeros(
+            log_probs = torch.zeros(
                 (batch_inputs.ntok, request.app.state.model.meta.dvoc),
                 dtype=request.app.state.model.meta.torch_dtype_logits
             )
-            from libinfinicore_infer import forward_batch
-            forward_batch(
+            request.app.state.model.jiuge_model.forward_batch(
                 request.app.state.model.model_instance,
                 batch_inputs.tokens,
                 batch_inputs.ntok,
@@ -354,12 +353,10 @@ async def completion(id_, request_data, request: Request):
                 batch_inputs.nreq,
                 batch_inputs.req_pos,
                 batch_inputs.kv_caches,
-                logits.data_ptr(),
+                log_probs.data_ptr(),
             )

-            # Calculate logprobs for input tokens
-            logits = logits.float()
-            log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
+            log_probs = log_probs.float()

             # Calculate correct logprobs for input tokens
             token_logprobs = []
diff --git a/scripts/test_ppl.py b/scripts/test_ppl.py
index 268a9f7d..5627dc57 100644
--- a/scripts/test_ppl.py
+++ b/scripts/test_ppl.py
@@ -33,8 +33,9 @@
     # endcode, chunk and decode
     tokens = tokenizer.encode(text, add_special_tokens=False)

-    for i in range(0, len(tokens), CHUNK_SIZE):
-        chunk_tokens = tokens[i : min(i + CHUNK_SIZE, len(tokens))]
+    # Use the same chunking logic as jiuge_ppl.py: only process complete chunks
+    for i in range(0, len(tokens) - CHUNK_SIZE + 1, CHUNK_SIZE):
+        chunk_tokens = tokens[i : i + CHUNK_SIZE]
         chunk_text = tokenizer.decode(chunk_tokens)

         resp = requests.post(
diff --git a/src/cache_manager/opcache_manager.hpp b/src/cache_manager/opcache_manager.hpp
index 333583e8..83ef5aed 100644
--- a/src/cache_manager/opcache_manager.hpp
+++ b/src/cache_manager/opcache_manager.hpp
@@ -158,6 +158,7 @@ class CacheManager {
     DECLARE_OP_CACHE(RoPE)
     DECLARE_OP_CACHE(Rearrange)
     DECLARE_OP_CACHE(CausalSoftmax)
+    DECLARE_OP_CACHE(LogSoftmax)
     DECLARE_OP_CACHE(Topkrouter)
     DECLARE_OP_CACHE(SwiGLU)
     DECLARE_OP_CACHE(RandomSample)
@@ -170,6 +171,7 @@ class CacheManager {
           RoPE_cache(capacity, DESTROY_FUNC(RoPE)),
           Rearrange_cache(capacity, DESTROY_FUNC(Rearrange)),
           CausalSoftmax_cache(capacity, DESTROY_FUNC(CausalSoftmax)),
+          LogSoftmax_cache(capacity, DESTROY_FUNC(LogSoftmax)),
           Topkrouter_cache(capacity, DESTROY_FUNC(Topkrouter)),
           SwiGLU_cache(capacity, DESTROY_FUNC(SwiGLU)),
           RandomSample_cache(capacity, DESTROY_FUNC(RandomSample)),
diff --git a/src/models/inference_context.cpp b/src/models/inference_context.cpp
index e41e4bb3..15517538 100644
--- a/src/models/inference_context.cpp
+++ b/src/models/inference_context.cpp
@@ -143,6 +143,26 @@ void InferenceContext::causalSoftmax(std::shared_ptr<Tensor> y,
                                  y->data(), x->data(), stream));
 }

+void InferenceContext::logSoftmax(std::shared_ptr<Tensor> y,
+                                  std::shared_ptr<Tensor> x) {
+    size_t key = CacheManager::createDescriptorKey(y, x);
+
+    infiniopLogSoftmaxDescriptor_t desc;
+    if (!cache_manager->getLogSoftmaxDescriptor(key, desc)) {
+        RUN_INFINI(infiniopCreateLogSoftmaxDescriptor(
+            op_handle, &desc, y->desc(), x->desc()));
+        cache_manager->putLogSoftmaxDescriptor(key, desc);
+    }
+
+    size_t workspace_size = 0;
+    RUN_INFINI(infiniopGetLogSoftmaxWorkspaceSize(desc, &workspace_size));
+    ensure_workspace(workspace_size);
+    void *workspace = workspace_storage->memory();
+
+    RUN_INFINI(infiniopLogSoftmax(desc, workspace, workspace_size,
+                                  y->data(), x->data(), stream));
+}
+
 void InferenceContext::topkrouter(std::shared_ptr<Tensor> values,  // F32
                                   std::shared_ptr<Tensor> indices, // I32
                                   std::shared_ptr<Tensor> x,
diff --git a/src/models/inference_context.hpp b/src/models/inference_context.hpp
index 0cf93f6f..d8597b5c 100644
--- a/src/models/inference_context.hpp
+++ b/src/models/inference_context.hpp
@@ -37,6 +37,8 @@ struct InferenceContext {
                   infiniopRoPEAlgo_t algo);
     void causalSoftmax(std::shared_ptr<Tensor> y,
                        std::shared_ptr<Tensor> x);
+    void logSoftmax(std::shared_ptr<Tensor> y,
+                    std::shared_ptr<Tensor> x);
     void topkrouter(std::shared_ptr<Tensor> values,  // F32
                     std::shared_ptr<Tensor> indices, // I32
                     std::shared_ptr<Tensor> x,
@@ -111,6 +113,10 @@ inline void causalSoftmax(std::shared_ptr<Tensor> y, std::shared_ptr<Tensor> x) {
     getInferenceContext().causalSoftmax(y, x);
 }

+inline void logSoftmax(std::shared_ptr<Tensor> y, std::shared_ptr<Tensor> x) {
+    getInferenceContext().logSoftmax(y, x);
+}
+
 inline void topkrouter(std::shared_ptr<Tensor> values,  // F32
                        std::shared_ptr<Tensor> indices, // I32
                        std::shared_ptr<Tensor> x,
diff --git a/src/models/jiuge/jiuge.cpp b/src/models/jiuge/jiuge.cpp
index 059842cc..f33b3e1e 100644
--- a/src/models/jiuge/jiuge.cpp
+++ b/src/models/jiuge/jiuge.cpp
@@ -262,8 +262,12 @@ void inferDeviceBatch(const JiugeMeta &meta, JiugeDeviceResource &rsrc,
         rmsnorm(logits_out, logits_in, rsrc.w_out_norm, meta.epsilon);
         auto last_logits_buf = Tensor::buffer(dt_logits, {ntok, dvoc}, rsrc.memory_pool);
         linear(last_logits_buf, logits_out, rsrc.w_out_embd, 1.0, 0.0, nullptr, nullptr);
+
+        auto log_logits_buf = Tensor::buffer(dt_logits, {ntok, dvoc}, rsrc.memory_pool);
+        logSoftmax(log_logits_buf, last_logits_buf);
+
         RUN_INFINI(infinirtStreamSynchronize(stream));
-        RUN_INFINI(infinirtMemcpy(last_logits, last_logits_buf->data(), dsize(dt_logits) * ntok * dvoc, INFINIRT_MEMCPY_D2H));
+        RUN_INFINI(infinirtMemcpy(last_logits, log_logits_buf->data(), dsize(dt_logits) * ntok * dvoc, INFINIRT_MEMCPY_D2H));
     }
     if (output != nullptr) {
         size_t token_offset = 0;
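After PATCH 3/3, forward_batch writes log-softmax values straight into the output buffer, so the Python side only gathers per-token values; the perplexity path in scripts/jiuge.py reduces to a gather plus exp(-mean). The fragment below sketches that reduction with a random stand-in for the buffer instead of a real forward_batch call; shapes and names mirror the patch, the data is illustrative only.

    import torch

    # Stand-in for the (ntok, dvoc) buffer that forward_batch now fills
    # with log-softmax values.
    ntok, dvoc = 8, 32
    log_probs = torch.log_softmax(torch.randn(ntok, dvoc), dim=-1).float()

    # Next-token targets for each of the ntok positions.
    true_tokens = torch.randint(0, dvoc, (ntok,), dtype=torch.int64)

    token_logprobs = log_probs[torch.arange(ntok), true_tokens]  # (ntok,)
    perplexity = torch.exp(-token_logprobs.mean())
    print(float(perplexity))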