diff --git a/include/jllm_engine.h b/include/jllm_engine.h index f2cb105..1702516 100644 --- a/include/jllm_engine.h +++ b/include/jllm_engine.h @@ -74,7 +74,9 @@ using TokenCallback = std::function; struct GenStats { int prompt_tokens = 0; + int prefill_tokens = 0; int completion_tokens = 0; + int kv_cache_reused_tokens = 0; float prompt_ms = 0; float decode_ms = 0; // Time-to-first-token: end-to-end wall clock from the start of diff --git a/src/engine/decode.cu b/src/engine/decode.cu index c9cd0dd..2fc5360 100644 --- a/src/engine/decode.cu +++ b/src/engine/decode.cu @@ -1667,6 +1667,7 @@ GenStats Engine::generate(const std::string& prompt, const GenParams& params, // range. 0 means cold prefill (no cache file or no match). const int prefill_start = try_hydrate_kv(prompt_tokens, params.conversation_id); + stats.kv_cache_reused_tokens = prefill_start; if (debug_kernels_enabled()) { fprintf(stderr, "[tokenizer] prompt tokens:"); for (int i = 0; i < (int)prompt_tokens.size() && i < 16; i++) { @@ -1707,6 +1708,7 @@ GenStats Engine::generate(const std::string& prompt, const GenParams& params, // the cache already covered the full prompt — handled as a special // case below (no Path A residual, fall through to fallback decode). const int M = N - prefill_start; + stats.prefill_tokens = M; if (M > 0 && batched_prefill_enabled() && batched_fits) { // Path B (issue #12): layer-major prefill. Allocate one @@ -1773,10 +1775,13 @@ GenStats Engine::generate(const std::string& prompt, const GenParams& params, cudaStreamSynchronize(stream_); auto t1 = Clock::now(); stats.prompt_ms = Ms(t1 - t0).count(); - if (stats.prompt_tokens > 0) - stats.prompt_tok_per_sec = stats.prompt_tokens / (stats.prompt_ms / 1000.0f); - fprintf(stderr, "[engine] Prefill: %d tokens in %.0f ms (%.1f tok/s)\n", - stats.prompt_tokens, stats.prompt_ms, stats.prompt_tok_per_sec); + if (stats.prefill_tokens > 0 && stats.prompt_ms > 0) + stats.prompt_tok_per_sec = stats.prefill_tokens / (stats.prompt_ms / 1000.0f); + fprintf(stderr, + "[engine] Prefill: %d new / %d prompt tokens in %.0f ms " + "(%.1f tok/s, kv_reused=%d)\n", + stats.prefill_tokens, stats.prompt_tokens, stats.prompt_ms, + stats.prompt_tok_per_sec, stats.kv_cache_reused_tokens); // Decode timer starts here so the Path A first-token sampling below // and the subsequent decode-loop iterations are both attributed to diff --git a/src/server/http_server.cpp b/src/server/http_server.cpp index 1428350..41ece9c 100644 --- a/src/server/http_server.cpp +++ b/src/server/http_server.cpp @@ -18,9 +18,13 @@ #include #include +#include +#include #include +#include #include #include +#include #include #include #include @@ -34,6 +38,16 @@ namespace jllm { // stomp on the KV cache pool. static std::mutex g_engine_mutex; +struct AgentRequestHints { + bool present = false; + std::string session_id; + int priority = 0; + int osl = 0; + bool speculative_prefill = false; + bool cache_ephemeral = false; + int cache_ttl_seconds = 0; +}; + // ── /health response ───────────────────────────────────────────────────── static json build_health(const Engine& engine) { @@ -150,6 +164,115 @@ struct ThinkSplit { } }; +static bool json_get_int(const json& obj, const char* key, int* out) { + if (!out || !obj.is_object() || !obj.contains(key)) return false; + const auto& v = obj.at(key); + if (!v.is_number_integer() && !v.is_number_unsigned()) return false; + long long raw = v.get(); + raw = std::max(std::numeric_limits::min(), + std::min(std::numeric_limits::max(), raw)); + *out = (int)raw; + return true; +} + +static std::string sanitize_session_id(const std::string& raw) { + std::string out; + out.reserve(std::min(raw.size(), 64)); + bool last_was_us = false; + for (unsigned char c : raw) { + if (std::isalnum(c) || c == '_' || c == '-') { + out.push_back((char)c); + last_was_us = false; + } else if (!last_was_us) { + out.push_back('_'); + last_was_us = true; + } + if (out.size() == 64) break; + } + while (!out.empty() && out.front() == '_') out.erase(out.begin()); + while (!out.empty() && out.back() == '_') out.pop_back(); + return out; +} + +static int parse_ttl_seconds(const json& ttl) { + if (ttl.is_number_integer() || ttl.is_number_unsigned()) { + long long n = ttl.get(); + if (n < 0) return 0; + return (int)std::min(n, std::numeric_limits::max()); + } + if (!ttl.is_string()) return 0; + std::string s = ttl.get(); + if (s.empty()) return 0; + char unit = s.back(); + int multiplier = 1; + if (unit == 's' || unit == 'S') { + multiplier = 1; + s.pop_back(); + } else if (unit == 'm' || unit == 'M') { + multiplier = 60; + s.pop_back(); + } else if (unit == 'h' || unit == 'H') { + multiplier = 3600; + s.pop_back(); + } + char* end = nullptr; + long long n = std::strtoll(s.c_str(), &end, 10); + if (end == s.c_str() || n < 0) return 0; + long long seconds = n * multiplier; + return (int)std::min(seconds, std::numeric_limits::max()); +} + +static AgentRequestHints parse_agent_request_hints(const json& body) { + AgentRequestHints out; + if (!body.is_object()) return out; + + if (body.contains("conversation_id") && body["conversation_id"].is_string()) { + out.session_id = sanitize_session_id(body["conversation_id"].get()); + out.present = out.present || !out.session_id.empty(); + } + + const json* nvext = nullptr; + if (body.contains("nvext") && body["nvext"].is_object()) { + nvext = &body["nvext"]; + } + const json* agent_hints = nullptr; + if (nvext && nvext->contains("agent_hints") && + (*nvext)["agent_hints"].is_object()) { + agent_hints = &(*nvext)["agent_hints"]; + } + if (agent_hints) { + out.present = true; + if (agent_hints->contains("session_id") && + (*agent_hints)["session_id"].is_string()) { + std::string sid = + sanitize_session_id((*agent_hints)["session_id"].get()); + if (!sid.empty()) out.session_id = sid; + } + (void)json_get_int(*agent_hints, "priority", &out.priority); + if (!json_get_int(*agent_hints, "osl", &out.osl)) { + (void)json_get_int(*agent_hints, "output_sequence_length", &out.osl); + } + if (agent_hints->contains("speculative_prefill") && + (*agent_hints)["speculative_prefill"].is_boolean()) { + out.speculative_prefill = + (*agent_hints)["speculative_prefill"].get(); + } + } + + if (nvext && nvext->contains("cache_control") && + (*nvext)["cache_control"].is_object()) { + const json& cc = (*nvext)["cache_control"]; + out.present = true; + out.cache_ephemeral = + cc.value("type", std::string()) == "ephemeral"; + if (cc.contains("ttl")) { + out.cache_ttl_seconds = parse_ttl_seconds(cc["ttl"]); + } + } + + return out; +} + // ── Chat completion (non-streaming) ────────────────────────────────────── static json build_completion(Engine& engine, const std::string& prompt, @@ -172,6 +295,19 @@ static json build_completion(Engine& engine, const std::string& prompt, json message = {{"role", "assistant"}, {"content", content}}; if (!reasoning.empty()) message["reasoning_content"] = reasoning; + json cache = { + {"prompt_tokens", stats.prompt_tokens}, + {"prefill_tokens", stats.prefill_tokens}, + {"kv_reused_tokens", stats.kv_cache_reused_tokens}, + {"kv_reuse_ratio", + stats.prompt_tokens > 0 + ? (double)stats.kv_cache_reused_tokens / (double)stats.prompt_tokens + : 0.0}, + }; + if (!params.conversation_id.empty()) { + cache["conversation_id"] = params.conversation_id; + } + return json{ {"id", "jllm-" + std::to_string(std::time(nullptr))}, {"object", "chat.completion"}, @@ -195,10 +331,30 @@ static json build_completion(Engine& engine, const std::string& prompt, {"ttft_ms", stats.ttft_ms}, {"peak_mem_mb", stats.peak_memory_mb}, {"peak_temp_c", stats.peak_thermal_c}, + {"cache", cache}, }}, }; } +static void attach_agent_hint_report(json& completion, + const AgentRequestHints& hints) { + if (!hints.present) return; + json report = json::object(); + if (!hints.session_id.empty()) report["session_id"] = hints.session_id; + if (hints.priority != 0) report["priority"] = hints.priority; + if (hints.osl > 0) report["osl"] = hints.osl; + if (hints.speculative_prefill) { + report["speculative_prefill"] = true; + } + if (hints.cache_ephemeral || hints.cache_ttl_seconds > 0) { + report["cache_control"] = { + {"type", hints.cache_ephemeral ? "ephemeral" : "default"}, + {"ttl_seconds", hints.cache_ttl_seconds}, + }; + } + completion["jetson"]["agent_hints"] = report; +} + // Reply with `{ "error": { "message": ..., "type": ... } }` (OpenAI shape). static void send_error(httplib::Response& res, int code, const std::string& msg, const std::string& type) { @@ -381,7 +537,11 @@ void run_server(Engine& engine, int port, bool default_kv_int8) { // the wrong format and the model emits garbage tokens. params.kv_int8 = body.value("kv_int8", default_kv_int8); + AgentRequestHints request_hints = parse_agent_request_hints(body); std::string conv_id = body.value("conversation_id", ""); + if (conv_id.empty() && !request_hints.session_id.empty()) { + conv_id = request_hints.session_id; + } if (!conv_id.empty()) { if (validate_conversation_id(conv_id)) { params.conversation_id = conv_id; @@ -396,10 +556,15 @@ void run_server(Engine& engine, int port, bool default_kv_int8) { const std::string prompt = format_qwen_chat(messages, think); fprintf(stderr, "[http] chat request body_bytes=%zu messages=%zu " - "prompt_bytes=%zu stream=%d max_tokens=%d think=%d\n", + "prompt_bytes=%zu stream=%d max_tokens=%d think=%d " + "session=%s priority=%d osl=%d cache_ttl_s=%d\n", req.body.size(), messages.size(), prompt.size(), body.value("stream", false) ? 1 : 0, - params.max_tokens, think ? 1 : 0); + params.max_tokens, think ? 1 : 0, + params.conversation_id.empty() ? "-" : params.conversation_id.c_str(), + request_hints.priority, + request_hints.osl, + request_hints.cache_ttl_seconds); if (body.value("stream", false)) { // SSE path — mutex is taken inside the chunked-content @@ -412,8 +577,9 @@ void run_server(Engine& engine, int port, bool default_kv_int8) { try { std::lock_guard lk(g_engine_mutex); - res.set_content(build_completion(engine, prompt, params).dump(), - "application/json"); + json completion = build_completion(engine, prompt, params); + attach_agent_hint_report(completion, request_hints); + res.set_content(completion.dump(), "application/json"); } catch (const std::length_error& e) { fprintf(stderr, "[http] chat rejected: %s\n", e.what()); send_error(res, 400, e.what(), "invalid_request_error");