diff --git a/rtp_llm/config/engine_config.py b/rtp_llm/config/engine_config.py index 7aedfccd94..8d8debc880 100644 --- a/rtp_llm/config/engine_config.py +++ b/rtp_llm/config/engine_config.py @@ -331,8 +331,17 @@ def update_worker_addrs( return worker_addrs = [] worker_grpc_addrs = [] + all_worker_grpc_addrs = [""] * parallelism_config.world_size local_rank = parallelism_config.local_rank for member in world_info.members: + member_grpc_addr = f"{member.ip}:{member.rpc_server_port}" + if 0 <= member.world_rank < len(all_worker_grpc_addrs): + all_worker_grpc_addrs[member.world_rank] = member_grpc_addr + else: + logging.warning( + f"world_info member world_rank {member.world_rank} out of range " + f"for world_size {parallelism_config.world_size}" + ) if ( int( (member.world_rank / parallelism_config.tp_size) @@ -343,7 +352,7 @@ def update_worker_addrs( worker_addrs.append( f"{member.ip}:{member.cache_store_listen_port}:{member.cache_store_rdma_listen_port}" ) - worker_grpc_addrs.append(f"{member.ip}:{member.rpc_server_port}") + worker_grpc_addrs.append(member_grpc_addr) logging.info( f"append member for pd sep " f"{member.ip}:{member.rpc_server_port}, {member.cache_store_listen_port}, " @@ -351,6 +360,7 @@ def update_worker_addrs( ) runtime_config.worker_grpc_addrs = worker_grpc_addrs runtime_config.worker_addrs = worker_addrs + runtime_config.all_worker_grpc_addrs = all_worker_grpc_addrs def setup_pd_sep_config( diff --git a/rtp_llm/config/exceptions.py b/rtp_llm/config/exceptions.py index d07b95659b..f8212f7e68 100644 --- a/rtp_llm/config/exceptions.py +++ b/rtp_llm/config/exceptions.py @@ -76,8 +76,16 @@ class ExceptionType(IntEnum): # master error MASTER_NO_AVAILABLE_WORKER = 8400 + MASTER_NO_PREFILL_WORKER = 8402 + MASTER_NO_DECODE_WORKER = 8403 + MASTER_NO_PDFUSION_WORKER = 8404 + MASTER_NO_VIT_WORKER = 8405 + MASTER_INVALID_REQUEST = 8406 # route error + ROUTER_QUEUE_FULL = 8502 + ROUTER_QUEUE_TIMEOUT = 8503 + ROUTER_REQUEST_CANCELLED = 8504 ROUTE_ERROR = 8500 # multimodal error diff --git a/rtp_llm/config/generate_config.py b/rtp_llm/config/generate_config.py index 31ffa4f20d..7dba6593fb 100644 --- a/rtp_llm/config/generate_config.py +++ b/rtp_llm/config/generate_config.py @@ -51,14 +51,16 @@ class RoleAddr(BaseModel): @field_validator("role", mode="before") @classmethod def validate_role(cls, v): - """Convert string to RoleType enum for deserialization.""" - if isinstance(v, str): - return getattr(RoleType, v) + """Convert proto enum (int) to RoleType enum for deserialization.""" + if isinstance(v, int): + return RoleType(v) elif isinstance(v, RoleType): return v + elif isinstance(v, str): + return getattr(RoleType, v.upper()) else: raise ValueError( - f"RoleType must be a string or RoleType enum, got {type(v)}" + f"RoleType must be an int, str, or RoleType enum, got {type(v)}" ) @field_serializer("role") @@ -174,9 +176,7 @@ class GenerateConfig(BaseModel): enable_memory_cache: bool = True enable_remote_cache: bool = True - # 是否强制相同 request_id 的 stream 在一批中调度 - force_batch: bool = False - batch_group_timeout: Optional[int] = None # ms + group_timeout: Optional[int] = None # ms unique_key: str = "" diff --git a/rtp_llm/config/py_config_modules.py b/rtp_llm/config/py_config_modules.py index 0d66b31efe..b5f9ce1b66 100644 --- a/rtp_llm/config/py_config_modules.py +++ b/rtp_llm/config/py_config_modules.py @@ -391,15 +391,21 @@ def __init__(self): self.master_queue_reject_threshold: int = 100000 self.master_default_timeout_ms: int = 3600000 self.master_max_connect_pool_size: int = 100000 + self.master_connector_limit_per_host: int = 0 # Session total timeout in seconds. If < 0: auto (3600 when queue mode, 0.5 otherwise). self.master_session_timeout_s: float = -1 + # When True, disable domain fallback routing when master is unavailable or not configured. + # Requests will fail with ROUTE_ERROR instead of falling back to VipServer domain routing. + self.disable_domain_fallback: bool = False def to_string(self): return ( f"master_queue_reject_threshold: {self.master_queue_reject_threshold}\n" f"master_default_timeout_ms: {self.master_default_timeout_ms}\n" f"master_max_connect_pool_size: {self.master_max_connect_pool_size}\n" - f"master_session_timeout_s: {self.master_session_timeout_s}" + f"master_connector_limit_per_host: {self.master_connector_limit_per_host}\n" + f"master_session_timeout_s: {self.master_session_timeout_s}\n" + f"disable_domain_fallback: {self.disable_domain_fallback}" ) diff --git a/rtp_llm/config/server_config_setup.py b/rtp_llm/config/server_config_setup.py index 9f82753da1..6b698f399f 100644 --- a/rtp_llm/config/server_config_setup.py +++ b/rtp_llm/config/server_config_setup.py @@ -467,10 +467,23 @@ def fetch_model_files_to_local(py_env_configs: PyEnvConfigs): ) +def get_cuda_device_id_for_local_rank(local_rank: int) -> int: + """Map logical local rank to CUDA device id. + + RTP_LLM_LOCAL_DEVICE_OFFSET is only intended for local multi-part smoke tests + that simulate multiple nodes in separate server processes on one host. + """ + return local_rank + int(os.environ.get("RTP_LLM_LOCAL_DEVICE_OFFSET", "0")) + + def setup_cuda_device_and_accl_env(local_rank: int) -> None: """Apply CUDA device and ACCL env side effects (same as ParallelInfo.from_params).""" + cuda_device_id = get_cuda_device_id_for_local_rank(local_rank) if torch.cuda.is_available(): - torch.cuda.set_device(local_rank) + torch.cuda.set_device(cuda_device_id) + logging.info( + "local rank %s mapped to cuda device %s", local_rank, cuda_device_id + ) if os.environ.get("ACCL_SELECT_PATH") == "1": select_port = str(local_rank % 2) diff --git a/rtp_llm/cpp/api_server/InferenceService.cc b/rtp_llm/cpp/api_server/InferenceService.cc index ca9ba2e4fb..94b47f538a 100644 --- a/rtp_llm/cpp/api_server/InferenceService.cc +++ b/rtp_llm/cpp/api_server/InferenceService.cc @@ -186,7 +186,7 @@ void InferenceService::inferResponse(int64_t auto input = fillGenerateInput(request_id, req.input_texts[i], req.input_urls[i], req.generate_configs[i]); inputs.push_back(input); } - auto ori_streams = engine_->batchEnqueue(inputs); + auto ori_streams = engine_->enqueueMultiple(inputs); std::vector> streams; streams.reserve(ori_streams.size()); for (size_t idx = 0; idx < ori_streams.size(); ++idx) { diff --git a/rtp_llm/cpp/api_server/test/InferenceServiceTest.cc b/rtp_llm/cpp/api_server/test/InferenceServiceTest.cc index 79daadd75f..2df3d2426b 100644 --- a/rtp_llm/cpp/api_server/test/InferenceServiceTest.cc +++ b/rtp_llm/cpp/api_server/test/InferenceServiceTest.cc @@ -272,7 +272,7 @@ TEST_F(InferenceServiceTest, InferResponseSuccess) { auto mock_stream = CreateMockGenerateStream(); auto stream = std::dynamic_pointer_cast(mock_stream); std::vector streams({stream}); - EXPECT_CALL(*mock_engine_, batchEnqueue(Matcher>&>(_))) + EXPECT_CALL(*mock_engine_, enqueueMultiple(Matcher>&>(_))) .WillOnce(Return(streams)); // stream diff --git a/rtp_llm/cpp/api_server/test/mock/MockEngineBase.h b/rtp_llm/cpp/api_server/test/mock/MockEngineBase.h index 5e71165b62..a1729cebe6 100644 --- a/rtp_llm/cpp/api_server/test/mock/MockEngineBase.h +++ b/rtp_llm/cpp/api_server/test/mock/MockEngineBase.h @@ -15,7 +15,7 @@ class MockEngineBase: public EngineBase { public: MOCK_METHOD1(enqueue, std::shared_ptr(const std::shared_ptr&)); MOCK_METHOD1(enqueue, void(std::shared_ptr&)); - MOCK_METHOD1(batchEnqueue, + MOCK_METHOD1(enqueueMultiple, std::vector(const std::vector>& inputs)); MOCK_METHOD0(stop, absl::Status()); MOCK_METHOD2(preRun, absl::StatusOr(const std::shared_ptr&, preRunMode)); diff --git a/rtp_llm/cpp/config/ConfigModules.cc b/rtp_llm/cpp/config/ConfigModules.cc index 6717816904..c55fcdc3b3 100644 --- a/rtp_llm/cpp/config/ConfigModules.cc +++ b/rtp_llm/cpp/config/ConfigModules.cc @@ -416,6 +416,13 @@ std::string RuntimeConfig::to_string() const { if (i < worker_addrs.size() - 1) oss << ", "; } + oss << "]\n" + << "all_worker_grpc_addrs: ["; + for (size_t i = 0; i < all_worker_grpc_addrs.size(); ++i) { + oss << all_worker_grpc_addrs[i]; + if (i < all_worker_grpc_addrs.size() - 1) + oss << ", "; + } oss << "]\n" << "specify_gpu_arch: " << specify_gpu_arch; return oss.str(); @@ -588,7 +595,13 @@ std::string PDSepConfig::to_string() const { << "load_cache_timeout_ms: " << load_cache_timeout_ms << "\n" << "max_rpc_timeout_ms: " << max_rpc_timeout_ms << "\n" << "worker_port_offset: " << worker_port_offset << "\n" - << "decode_entrance: " << decode_entrance; + << "decode_entrance: " << decode_entrance << "\n" + << "batch_dispatch_timeout_ms: " << batch_dispatch_timeout_ms << "\n" + << "batch_prepare_timeout_ms: " << batch_prepare_timeout_ms << "\n" + << "batch_load_timeout_ms: " << batch_load_timeout_ms << "\n" + << "prefill_enqueue_pool_size: " << prefill_enqueue_pool_size << "\n" + << "prefill_worker_lambda_pool_size: " << prefill_worker_lambda_pool_size << "\n" + << "prefill_slot_pool_size: " << prefill_slot_pool_size; return oss.str(); } diff --git a/rtp_llm/cpp/config/ConfigModules.h b/rtp_llm/cpp/config/ConfigModules.h index 8e8b03235e..70cdbc14dd 100644 --- a/rtp_llm/cpp/config/ConfigModules.h +++ b/rtp_llm/cpp/config/ConfigModules.h @@ -175,17 +175,17 @@ struct KVCacheConfig { bool enable_device_cache = true; bool enable_memory_cache = false; // When true, memory-cache H2D/D2H may use split-KV SM scatter/gather (CUDA) when layout is eligible. - bool enable_memory_cache_sm_copy = false; - bool enable_remote_cache = false; - bool write_cache_sync = false; - bool enable_tiered_memory_cache = false; - bool enable_gpu_prefix_tree = true; - bool enable_prefix_tree_memory_cache = true; - bool enable_legacy_memory_connector_fallback = true; - int64_t prefix_tree_memory_state_swa_pool_ratio = 0; + bool enable_memory_cache_sm_copy = false; + bool enable_remote_cache = false; + bool write_cache_sync = false; + bool enable_tiered_memory_cache = false; + bool enable_gpu_prefix_tree = true; + bool enable_prefix_tree_memory_cache = true; + bool enable_legacy_memory_connector_fallback = true; + int64_t prefix_tree_memory_state_swa_pool_ratio = 0; bool enable_dsv4_state_block_independent_eviction = false; - int64_t device_cache_min_free_blocks = 0; - int load_cache_retry_times = 1; // Maximum retry attempts for load cache transfer failures + int64_t device_cache_min_free_blocks = 0; + int load_cache_retry_times = 1; // Maximum retry attempts for load cache transfer failures // DSV4 fixed-allocation pool block count. 0 means the fixed regions // (INDEXER_STATE / CSA_STATE / HCA_STATE / SWA_KV) use the normal @@ -371,9 +371,9 @@ struct BatchDecodeSchedulerConfig { }; struct FIFOSchedulerConfig { - int64_t max_context_batch_size = 1; - int64_t max_batch_tokens_size = 0; - bool cp_force_single_prefill = true; + int64_t max_context_batch_size = 1; + int64_t max_batch_tokens_size = 0; + bool cp_force_single_prefill = true; int64_t max_inited_kv_cache_streams = 0; std::string to_string() const; }; @@ -405,6 +405,7 @@ struct RuntimeConfig { std::string model_name = ""; std::vector worker_grpc_addrs; std::vector worker_addrs; + std::vector all_worker_grpc_addrs; // Fields merged from PyDeviceResourceConfig std::string specify_gpu_arch = ""; @@ -433,6 +434,17 @@ struct PDSepConfig { int64_t max_rpc_timeout_ms = 2 * 3600 * 1000; // 2h default int64_t worker_port_offset = 0; bool decode_entrance = false; + int64_t batch_dispatch_timeout_ms = 60000; // 60s, cross-DP dispatch + int64_t batch_prepare_timeout_ms = 10000; // 10s, prepareAllocateResource + int64_t batch_load_timeout_ms = 10000; // 10s, remoteLoadCacheStart + + // ========== Prefill Thread Pool Configuration ========== + // enqueue pool size (L1 DP dispatch, fast ms-level). 0 = use formula default. + int64_t prefill_enqueue_pool_size = 0; + // worker lambda pool size (heavy EnqueueGroup coordination, I/O-bound). 0 = use formula default. + int64_t prefill_worker_lambda_pool_size = 0; + // slot pool size (L2 Prepare + L3 Load + L4 Finish). 0 = use formula default. + int64_t prefill_slot_pool_size = 0; std::string to_string() const; }; diff --git a/rtp_llm/cpp/disaggregate/cache_store/LoadContext.cpp b/rtp_llm/cpp/disaggregate/cache_store/LoadContext.cpp index 21f2ecc881..126662d536 100644 --- a/rtp_llm/cpp/disaggregate/cache_store/LoadContext.cpp +++ b/rtp_llm/cpp/disaggregate/cache_store/LoadContext.cpp @@ -68,7 +68,8 @@ void SyncContext::updateResult(bool succes autil::TimeUtility::currentTimeInMilliSeconds() - start_time_ms_); } - if (++done_layer_cnt_ == expect_layer_cnt_) { + ++done_layer_cnt_; + if (done_layer_cnt_ == expect_layer_cnt_ || !success) { cond_.notify_all(); } } @@ -81,6 +82,14 @@ void SyncContext::waitDone() { return; } + if (error_info_.hasError()) { + RTP_LLM_LOG_INFO("load context wait done on early error: %s, done %d/%d layers", + error_info_.ToString().c_str(), + done_layer_cnt_.load(), + expect_layer_cnt_); + return; + } + if (autil::TimeUtility::currentTimeInMilliSeconds() >= deadline_ms_) { auto error_code = ErrorCode::CACHE_STORE_LOAD_BUFFER_TIMEOUT; error_info_ = ErrorInfo(error_code, ErrorCodeToString(error_code)); @@ -97,7 +106,7 @@ void SyncContext::waitDone() { // sync wait, safe to use this if (cond_.wait_for(lock, std::chrono::milliseconds(once_time_ms), [this] { - return done_layer_cnt_ == expect_layer_cnt_; + return done_layer_cnt_ == expect_layer_cnt_ || error_info_.hasError(); })) { return; } diff --git a/rtp_llm/cpp/distribute/BUILD b/rtp_llm/cpp/distribute/BUILD index 2e058ede4e..9420232822 100644 --- a/rtp_llm/cpp/distribute/BUILD +++ b/rtp_llm/cpp/distribute/BUILD @@ -17,6 +17,27 @@ cc_library( copts = copts(), ) +cc_library( + name = "rpc_cpu_tp_broadcaster_hdr", + hdrs = ["RpcCpuTpBroadcaster.h"], + deps = [ + "//rtp_llm/cpp/model_rpc:broadcast_manager", + "//rtp_llm/cpp/model_rpc/proto:model_rpc_service_cc_proto", + ], + copts = copts(), +) + +cc_library( + name = "rpc_cpu_tp_broadcaster", + srcs = ["RpcCpuTpBroadcaster.cc"], + hdrs = ["RpcCpuTpBroadcaster.h"], + deps = [ + ":rpc_cpu_tp_broadcaster_hdr", + "//rtp_llm/cpp/utils:core_utils", + ], + copts = copts(), +) + cc_test( name = "cpu_tp_broadcaster_test", srcs = ["test/CpuTpBroadcasterTest.cc"], @@ -28,3 +49,15 @@ cc_test( copts = copts(), timeout = "short", ) + +cc_test( + name = "rpc_cpu_tp_broadcaster_test", + srcs = ["test/RpcCpuTpBroadcasterTest.cc"], + deps = [ + ":rpc_cpu_tp_broadcaster", + "@com_google_googletest//:gtest", + "@com_google_googletest//:gtest_main", + ], + copts = copts(), + timeout = "short", +) diff --git a/rtp_llm/cpp/distribute/RpcCpuTpBroadcaster.cc b/rtp_llm/cpp/distribute/RpcCpuTpBroadcaster.cc new file mode 100644 index 0000000000..016b6acfaa --- /dev/null +++ b/rtp_llm/cpp/distribute/RpcCpuTpBroadcaster.cc @@ -0,0 +1,288 @@ +#include "rtp_llm/cpp/distribute/RpcCpuTpBroadcaster.h" + +#include "rtp_llm/cpp/utils/AssertUtils.h" +#include "rtp_llm/cpp/utils/Logger.h" + +#include +#include +#include + +namespace rtp_llm { + +namespace { + +constexpr int kDefaultTimeoutMs = 30000; + +int normalizeTimeoutMs(int timeout_ms) { + return timeout_ms > 0 ? timeout_ms : kDefaultTimeoutMs; +} + +} // namespace + +RpcCpuTpBroadcaster& RpcCpuTpBroadcaster::instance() { + static RpcCpuTpBroadcaster i; + return i; +} + +std::size_t RpcCpuTpBroadcaster::InboxKeyHash::operator()(const InboxKey& key) const { + std::size_t h = std::hash{}(key.group_key); + h ^= std::hash{}(key.seq) + 0x9e3779b97f4a7c15ULL + (h << 6) + (h >> 2); + h ^= std::hash{}(key.dst_tp_rank) + 0x9e3779b97f4a7c15ULL + (h << 6) + (h >> 2); + return h; +} + +std::string RpcCpuTpBroadcaster::makeGroupKey(int dp_rank, int tp_size, int world_size) const { + std::ostringstream oss; + oss << "tp_cpu_broadcast:dp=" << dp_rank << ":tp=" << tp_size << ":world=" << world_size; + return oss.str(); +} + +void RpcCpuTpBroadcaster::initialize(int tp_rank, + int tp_size, + int dp_rank, + int world_size, + const std::vector& worker_grpc_addrs, + int timeout_ms) { + std::lock_guard lock(mu_); + timeout_ms = normalizeTimeoutMs(timeout_ms); + + if (initialized_.load(std::memory_order_acquire)) { + const std::string new_group_key = makeGroupKey(dp_rank, tp_size, world_size); + RTP_LLM_CHECK_WITH_INFO(tp_rank_ == tp_rank && tp_size_ == tp_size && dp_rank_ == dp_rank + && world_size_ == world_size && group_key_ == new_group_key, + "RpcCpuTpBroadcaster re-init mismatch: was rank=%d size=%d dp=%d world=%d group=%s, " + "now rank=%d size=%d dp=%d world=%d group=%s", + tp_rank_, + tp_size_, + dp_rank_, + world_size_, + group_key_.c_str(), + tp_rank, + tp_size, + dp_rank, + world_size, + new_group_key.c_str()); + return; + } + + if (tp_size <= 1) { + tp_rank_ = tp_rank; + tp_size_ = tp_size; + dp_rank_ = dp_rank; + world_size_ = world_size; + timeout_ms_ = timeout_ms; + group_key_ = makeGroupKey(dp_rank, tp_size, world_size); + initialized_.store(true, std::memory_order_release); + return; + } + + RTP_LLM_CHECK_WITH_INFO(tp_rank >= 0 && tp_rank < tp_size, + "RpcCpuTpBroadcaster bad tp_rank=%d tp_size=%d", + tp_rank, + tp_size); + RTP_LLM_CHECK_WITH_INFO(static_cast(worker_grpc_addrs.size()) >= world_size, + "RpcCpuTpBroadcaster worker_grpc_addrs too small: addrs=%zu world_size=%d", + worker_grpc_addrs.size(), + world_size); + + tp_rank_ = tp_rank; + tp_size_ = tp_size; + dp_rank_ = dp_rank; + world_size_ = world_size; + timeout_ms_ = timeout_ms; + group_key_ = makeGroupKey(dp_rank, tp_size, world_size); + seq_.store(0, std::memory_order_release); + inbox_.clear(); + peer_addrs_.clear(); + peer_tp_ranks_.clear(); + broadcast_manager_.reset(); + + if (tp_rank_ == 0) { + peer_addrs_.reserve(tp_size - 1); + peer_tp_ranks_.reserve(tp_size - 1); + for (int peer_tp_rank = 1; peer_tp_rank < tp_size; ++peer_tp_rank) { + const int world_rank = dp_rank * tp_size + peer_tp_rank; + RTP_LLM_CHECK_WITH_INFO(world_rank >= 0 && world_rank < static_cast(worker_grpc_addrs.size()), + "RpcCpuTpBroadcaster bad peer world_rank=%d addrs=%zu", + world_rank, + worker_grpc_addrs.size()); + peer_addrs_.push_back(worker_grpc_addrs[world_rank]); + peer_tp_ranks_.push_back(peer_tp_rank); + } + broadcast_manager_ = std::make_shared(peer_addrs_); + RTP_LLM_CHECK_WITH_INFO(broadcast_manager_->init(), + "RpcCpuTpBroadcaster BroadcastManager init failed for %zu peer(s)", + peer_addrs_.size()); + } + + initialized_.store(true, std::memory_order_release); + cv_.notify_all(); + RTP_LLM_LOG_INFO("Initialized RpcCpuTpBroadcaster rank=%d tp_size=%d dp_rank=%d world_size=%d peers=%zu timeout_ms=%d", + tp_rank_, + tp_size_, + dp_rank_, + world_size_, + peer_addrs_.size(), + timeout_ms_); +} + +void RpcCpuTpBroadcaster::reset() { + { + std::lock_guard lock(mu_); + inbox_.clear(); + peer_addrs_.clear(); + peer_tp_ranks_.clear(); + broadcast_manager_.reset(); + tp_rank_ = 0; + tp_size_ = 1; + dp_rank_ = 0; + world_size_ = 1; + timeout_ms_ = kDefaultTimeoutMs; + group_key_.clear(); + seq_.store(0, std::memory_order_release); + initialized_.store(false, std::memory_order_release); + } + cv_.notify_all(); +} + +uint64_t RpcCpuTpBroadcaster::nextSeq() { + return seq_.fetch_add(1, std::memory_order_acq_rel); +} + +void RpcCpuTpBroadcaster::broadcast(void* buf, std::size_t nbytes, int root) { + RTP_LLM_CHECK_WITH_INFO(initialized_.load(std::memory_order_acquire), + "RpcCpuTpBroadcaster::broadcast called before initialize"); + if (tp_size_ <= 1 || nbytes == 0) { + return; + } + RTP_LLM_CHECK_WITH_INFO(root == 0, "RpcCpuTpBroadcaster supports only root=0; got %d", root); + + const uint64_t seq = nextSeq(); + if (tp_rank_ == 0) { + std::shared_ptr manager; + std::vector peer_tp_ranks; + std::string group_key; + int timeout_ms = kDefaultTimeoutMs; + { + std::lock_guard lock(mu_); + manager = broadcast_manager_; + peer_tp_ranks = peer_tp_ranks_; + group_key = group_key_; + timeout_ms = timeout_ms_; + } + RTP_LLM_CHECK_WITH_INFO(manager != nullptr, "RpcCpuTpBroadcaster root has no BroadcastManager"); + + std::vector requests; + requests.reserve(peer_tp_ranks.size()); + for (int peer_tp_rank : peer_tp_ranks) { + CpuTpBroadcastRequestPB request; + request.set_group_key(group_key); + request.set_seq(seq); + request.set_root(root); + request.set_src_tp_rank(tp_rank_); + request.set_dst_tp_rank(peer_tp_rank); + request.set_nbytes(static_cast(nbytes)); + request.set_payload(buf, nbytes); + requests.push_back(std::move(request)); + } + + auto rpc_call = [](std::shared_ptr& stub, + std::shared_ptr& ctx, + const CpuTpBroadcastRequestPB& request, + grpc::CompletionQueue* cq) { + return stub->AsyncCpuTpBroadcast(ctx.get(), request, cq); + }; + + auto result = manager->broadcast( + requests, timeout_ms, rpc_call); + RTP_LLM_CHECK_WITH_INFO(result != nullptr, + "RpcCpuTpBroadcaster broadcast setup failed seq=%lu nbytes=%zu", + seq, + nbytes); + RTP_LLM_CHECK_WITH_INFO(result->waitDone(timeout_ms), + "RpcCpuTpBroadcaster broadcast wait timeout seq=%lu timeout_ms=%d", + seq, + timeout_ms); + RTP_LLM_CHECK_WITH_INFO(result->success(), "RpcCpuTpBroadcaster broadcast RPC failed seq=%lu", seq); + for (const auto& response : result->responses()) { + RTP_LLM_CHECK_WITH_INFO(response.success(), + "RpcCpuTpBroadcaster peer rejected seq=%lu: %s", + seq, + response.error_message().c_str()); + } + return; + } + + InboxKey key; + std::string payload; + int timeout_ms = kDefaultTimeoutMs; + { + std::unique_lock lock(mu_); + key = InboxKey{group_key_, seq, tp_rank_}; + timeout_ms = timeout_ms_; + const bool ready = cv_.wait_for(lock, std::chrono::milliseconds(timeout_ms), [&] { + return !initialized_.load(std::memory_order_acquire) || inbox_.find(key) != inbox_.end(); + }); + RTP_LLM_CHECK_WITH_INFO(ready && initialized_.load(std::memory_order_acquire), + "RpcCpuTpBroadcaster receive timeout seq=%lu rank=%d timeout_ms=%d", + seq, + tp_rank_, + timeout_ms); + auto it = inbox_.find(key); + RTP_LLM_CHECK_WITH_INFO(it != inbox_.end(), "RpcCpuTpBroadcaster missing inbox payload seq=%lu", seq); + payload = std::move(it->second); + inbox_.erase(it); + } + + RTP_LLM_CHECK_WITH_INFO(payload.size() == nbytes, + "RpcCpuTpBroadcaster size mismatch seq=%lu rank=%d expected=%zu actual=%zu", + seq, + tp_rank_, + nbytes, + payload.size()); + std::memcpy(buf, payload.data(), nbytes); +} + +bool RpcCpuTpBroadcaster::handleBroadcastRequest(const CpuTpBroadcastRequestPB& request, + CpuTpBroadcastResponsePB* response) { + auto fail = [&](const std::string& message) { + response->set_success(false); + response->set_error_message(message); + RTP_LLM_LOG_WARNING("RpcCpuTpBroadcaster rejected request: %s", message.c_str()); + return false; + }; + + std::unique_lock lock(mu_); + if (!initialized_.load(std::memory_order_acquire)) { + cv_.wait_for(lock, std::chrono::milliseconds(kDefaultTimeoutMs), [&] { + return initialized_.load(std::memory_order_acquire); + }); + } + if (!initialized_.load(std::memory_order_acquire)) { + return fail("broadcaster is not initialized"); + } + if (request.group_key() != group_key_) { + return fail("group_key mismatch: got " + request.group_key() + ", expected " + group_key_); + } + if (request.root() != 0 || request.src_tp_rank() != 0) { + return fail("only root tp_rank 0 is supported"); + } + if (request.dst_tp_rank() != tp_rank_) { + return fail("dst_tp_rank mismatch"); + } + if (request.nbytes() != request.payload().size()) { + return fail("payload size mismatch"); + } + + InboxKey key{request.group_key(), request.seq(), request.dst_tp_rank()}; + if (inbox_.find(key) != inbox_.end()) { + return fail("duplicate payload"); + } + inbox_.emplace(std::move(key), request.payload()); + response->set_success(true); + response->clear_error_message(); + cv_.notify_all(); + return true; +} + +} // namespace rtp_llm diff --git a/rtp_llm/cpp/distribute/RpcCpuTpBroadcaster.h b/rtp_llm/cpp/distribute/RpcCpuTpBroadcaster.h new file mode 100644 index 0000000000..01537629d5 --- /dev/null +++ b/rtp_llm/cpp/distribute/RpcCpuTpBroadcaster.h @@ -0,0 +1,86 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "rtp_llm/cpp/model_rpc/BroadcastManager.h" +#include "rtp_llm/cpp/model_rpc/proto/model_rpc_service.pb.h" + +namespace rtp_llm { + +// Cross-node CPU TP broadcaster over RpcService. Root rank fanouts bytes to TP +// peers; non-root ranks wait on a local inbox filled by the gRPC server thread. +// The logical API intentionally matches CpuTpBroadcaster so execBroadcastCpu can +// choose this path without changing tpSyncModelInputs' packing/unpacking logic. +class RpcCpuTpBroadcaster { +public: + static RpcCpuTpBroadcaster& instance(); + + void initialize(int tp_rank, + int tp_size, + int dp_rank, + int world_size, + const std::vector& worker_grpc_addrs, + int timeout_ms); + + void reset(); + + bool isInitialized() const { + return initialized_.load(std::memory_order_acquire); + } + + void broadcast(void* buf, std::size_t nbytes, int root); + + bool handleBroadcastRequest(const CpuTpBroadcastRequestPB& request, CpuTpBroadcastResponsePB* response); + +private: + struct InboxKey { + std::string group_key; + uint64_t seq = 0; + int dst_tp_rank = 0; + + bool operator==(const InboxKey& other) const { + return group_key == other.group_key && seq == other.seq && dst_tp_rank == other.dst_tp_rank; + } + }; + + struct InboxKeyHash { + std::size_t operator()(const InboxKey& key) const; + }; + + RpcCpuTpBroadcaster() = default; + ~RpcCpuTpBroadcaster() = default; + RpcCpuTpBroadcaster(const RpcCpuTpBroadcaster&) = delete; + RpcCpuTpBroadcaster& operator=(const RpcCpuTpBroadcaster&) = delete; + + uint64_t nextSeq(); + std::string makeGroupKey(int dp_rank, int tp_size, int world_size) const; + +private: + mutable std::mutex mu_; + std::condition_variable cv_; + std::atomic initialized_{false}; + std::atomic seq_{0}; + + int tp_rank_ = 0; + int tp_size_ = 1; + int dp_rank_ = 0; + int world_size_ = 1; + int timeout_ms_ = 3000; + std::string group_key_; + + std::vector peer_addrs_; + std::vector peer_tp_ranks_; + std::shared_ptr broadcast_manager_; + + std::unordered_map inbox_; +}; + +} // namespace rtp_llm diff --git a/rtp_llm/cpp/distribute/test/RpcCpuTpBroadcasterTest.cc b/rtp_llm/cpp/distribute/test/RpcCpuTpBroadcasterTest.cc new file mode 100644 index 0000000000..b75cf7d9fb --- /dev/null +++ b/rtp_llm/cpp/distribute/test/RpcCpuTpBroadcasterTest.cc @@ -0,0 +1,62 @@ +#include "rtp_llm/cpp/distribute/RpcCpuTpBroadcaster.h" + +#include +#include + +#include "gtest/gtest.h" + +namespace rtp_llm { +namespace { + +CpuTpBroadcastRequestPB makeRequest(uint64_t seq, int dst_tp_rank, const std::string& payload) { + CpuTpBroadcastRequestPB request; + request.set_group_key("tp_cpu_broadcast:dp=0:tp=2:world=2"); + request.set_seq(seq); + request.set_root(0); + request.set_src_tp_rank(0); + request.set_dst_tp_rank(dst_tp_rank); + request.set_nbytes(payload.size()); + request.set_payload(payload); + return request; +} + +TEST(RpcCpuTpBroadcasterTest, NonRootWaitsUntilServerThreadPublishesPayload) { + auto& bcast = RpcCpuTpBroadcaster::instance(); + bcast.reset(); + bcast.initialize(/*tp_rank=*/1, + /*tp_size=*/2, + /*dp_rank=*/0, + /*world_size=*/2, + /*worker_grpc_addrs=*/{"unused-root", "unused-rank1"}, + /*timeout_ms=*/1000); + + std::vector recv(5, 0); + std::thread waiter([&] { bcast.broadcast(recv.data(), recv.size(), /*root=*/0); }); + + CpuTpBroadcastResponsePB response; + ASSERT_TRUE(bcast.handleBroadcastRequest(makeRequest(/*seq=*/0, /*dst_tp_rank=*/1, "hello"), &response)); + EXPECT_TRUE(response.success()); + + waiter.join(); + EXPECT_EQ(std::string(recv.begin(), recv.end()), "hello"); + bcast.reset(); +} + +TEST(RpcCpuTpBroadcasterTest, RejectsWrongDestinationRank) { + auto& bcast = RpcCpuTpBroadcaster::instance(); + bcast.reset(); + bcast.initialize(/*tp_rank=*/1, + /*tp_size=*/2, + /*dp_rank=*/0, + /*world_size=*/2, + /*worker_grpc_addrs=*/{"unused-root", "unused-rank1"}, + /*timeout_ms=*/1000); + + CpuTpBroadcastResponsePB response; + EXPECT_FALSE(bcast.handleBroadcastRequest(makeRequest(/*seq=*/0, /*dst_tp_rank=*/0, "bad"), &response)); + EXPECT_FALSE(response.success()); + bcast.reset(); +} + +} // namespace +} // namespace rtp_llm diff --git a/rtp_llm/cpp/engine_base/EngineBase.cc b/rtp_llm/cpp/engine_base/EngineBase.cc index f8cd35acf2..071ef9658b 100644 --- a/rtp_llm/cpp/engine_base/EngineBase.cc +++ b/rtp_llm/cpp/engine_base/EngineBase.cc @@ -14,7 +14,7 @@ EngineBase::EngineBase(const EngineInitParams& params) { EngineBase::~EngineBase() {} -std::vector EngineBase::batchEnqueue(const std::vector>& inputs) { +std::vector EngineBase::enqueueMultiple(const std::vector>& inputs) { throw std::runtime_error("not implemeted"); } @@ -28,7 +28,8 @@ void EngineBase::initRuntime(const EngineInitParams& params) { Logger::getEngineLogger().setRank(rank); Logger::getEngineLogger().flush(); size_t device_id = params.parallelism_config.world_rank % params.parallelism_config.local_world_size; - mla_ops_type_ = rtp_llm::initRuntime(device_id, + device_id += autil::EnvUtil::getEnv("RTP_LLM_LOCAL_DEVICE_OFFSET", 0); + mla_ops_type_ = rtp_llm::initRuntime(device_id, params.profiling_debug_logging_config.trace_memory, params.device_resource_config.enable_comm_overlap, params.model_config_.mla_ops_type); diff --git a/rtp_llm/cpp/engine_base/EngineBase.h b/rtp_llm/cpp/engine_base/EngineBase.h index 6030d21931..f282f7e98e 100644 --- a/rtp_llm/cpp/engine_base/EngineBase.h +++ b/rtp_llm/cpp/engine_base/EngineBase.h @@ -52,7 +52,7 @@ class EngineBase { virtual void enqueue(std::shared_ptr& stream) = 0; - virtual std::vector batchEnqueue(const std::vector>& inputs); + virtual std::vector enqueueMultiple(const std::vector>& inputs); virtual std::shared_ptr makeStream(const std::shared_ptr& input); diff --git a/rtp_llm/cpp/engine_base/WorkerStatusInfo.h b/rtp_llm/cpp/engine_base/WorkerStatusInfo.h index 8a828125c7..a191986adb 100644 --- a/rtp_llm/cpp/engine_base/WorkerStatusInfo.h +++ b/rtp_llm/cpp/engine_base/WorkerStatusInfo.h @@ -8,11 +8,12 @@ #include #include #include "rtp_llm/cpp/engine_base/schedulers/EngineScheduleInfo.h" +#include "rtp_llm/cpp/config/RoleTypes.h" namespace rtp_llm { struct WorkerStatusInfo { - std::string role; + RoleType role; EngineScheduleInfo engine_schedule_info; int64_t status_version; int64_t latest_finished_version; diff --git a/rtp_llm/cpp/engine_base/schedulers/BatchDecodeScheduler.h b/rtp_llm/cpp/engine_base/schedulers/BatchDecodeScheduler.h index 7e9f9d3b15..b2bce914c7 100644 --- a/rtp_llm/cpp/engine_base/schedulers/BatchDecodeScheduler.h +++ b/rtp_llm/cpp/engine_base/schedulers/BatchDecodeScheduler.h @@ -50,7 +50,7 @@ class BatchDecodeScheduler: public SchedulerBase { return absl::OkStatus(); } - std::vector batchEnqueue(const std::vector& streams) override { + std::vector enqueueGroup(const std::vector& streams) override { return {}; // Not implemented for BatchDecodeScheduler } @@ -66,41 +66,6 @@ class BatchDecodeScheduler: public SchedulerBase { RTP_LLM_LOG_INFO("BatchDecodeScheduler update batch size to %d, mode to %d", batch_size_, int(scheduler_type_)); } - // 根据状态机转移后的目标状态,将 stream 路由到对应的队列 - void addStreamToNewState(const GenerateStreamPtr& stream, StreamState new_state) { - switch (new_state) { - case StreamState::WAITING: - waiting_streams_.push_back(stream); - break; - case StreamState::LOADING_CACHE: - loading_cache_streams_.push_back(stream); - break; - case StreamState::RUNNING: - running_streams_.push_back(stream); - break; - case StreamState::FINISHED: - break; - default: - RTP_LLM_LOG_ERROR( - "Unknown state: %d for stream [%ld]", static_cast(new_state), stream->streamId()); - break; - } - } - - // 通过 GenerateStateMachine 驱动每个 stream 的状态转移,状态变化的 stream 移入对应队列 - void evaluateAndUpdateStreams(std::list& streams) { - for (auto it = streams.begin(); it != streams.end();) { - auto state = (*it)->getStatus(); - auto new_state = (*it)->moveToNext(); - if (new_state != state) { - addStreamToNewState(*it, new_state); - it = streams.erase(it); - } else { - it++; - } - } - } - void evaluateWaitingStreams() { // 清理 waiting_streams_ 中有错误的 stream waiting_streams_.remove_if([](const auto& s) { return s->hasError(); }); @@ -118,14 +83,16 @@ class BatchDecodeScheduler: public SchedulerBase { // 凑到batch_size_个stream再统一入队 if (new_streams.size() >= batch_size_) { for (auto& stream : new_streams) { - stream->reportEvent(StreamEvents::CanRun); - // 忙等stream load cache done, 和原有SyncLoadCache逻辑等效 - while (stream->getStatus() != StreamState::FINISHED && stream->moveToNext() != StreamState::RUNNING) { + stream->prepare(); + while (stream->alive() && !stream->isReady()) { std::this_thread::sleep_for(std::chrono::milliseconds(1)); } + if (stream->alive()) { + stream->activate(); + } } - // 过滤 FINISHED stream,仅将 RUNNING stream 加入 running_streams_ - new_streams.remove_if([](const auto& s) { return s->getStatus() == StreamState::FINISHED; }); + // Filter out streams that are no longer alive + new_streams.remove_if([](const auto& s) { return !s->alive(); }); running_streams_.insert(running_streams_.end(), new_streams.begin(), new_streams.end()); // 从waiting_streams_中移除已调度的stream for (auto& stream : new_streams) { @@ -147,7 +114,7 @@ class BatchDecodeScheduler: public SchedulerBase { if (scheduler_type_ == SchedulerType::kBatchDecode) { (*it)->setIsContextStream(false); // for linear attn, incrKVBlock to clear unused linear block - (*it)->moveToNext(); + (*it)->advance(); } } } @@ -155,14 +122,18 @@ class BatchDecodeScheduler: public SchedulerBase { absl::StatusOr> schedule() override { std::unique_lock lock(lock_); cond_.wait_for(lock, std::chrono::seconds(30), [this] { - return waiting_streams_.size() >= batch_size_ || running_streams_.size() > 0 - || !loading_cache_streams_.empty(); + return waiting_streams_.size() >= batch_size_ || !running_streams_.empty(); }); - // 统一通过状态机驱动各队列中 stream 的状态转移 - // LOADING_CACHE -> DONE/WAITING: error / load cache done - evaluateAndUpdateStreams(loading_cache_streams_); - evaluateAndUpdateStreams(running_streams_); + // running: advance + cleanup finished + for (auto it = running_streams_.begin(); it != running_streams_.end();) { + (*it)->advance(); + if (!(*it)->alive()) { + it = running_streams_.erase(it); + } else { + ++it; + } + } if (running_streams_.empty() && waiting_streams_.size() >= batch_size_) { evaluateWaitingStreams(); @@ -192,14 +163,13 @@ class BatchDecodeScheduler: public SchedulerBase { int64_t onflightStreams() override { std::lock_guard lock(lock_); - return waiting_streams_.size() + loading_cache_streams_.size() + running_streams_.size(); + return waiting_streams_.size() + running_streams_.size(); } private: std::mutex lock_; std::condition_variable cond_; std::list waiting_streams_; - std::list loading_cache_streams_; std::list running_streams_; uint32_t batch_size_; bool reorder_request_; diff --git a/rtp_llm/cpp/engine_base/schedulers/EngineScheduleInfo.h b/rtp_llm/cpp/engine_base/schedulers/EngineScheduleInfo.h index b0045988b0..72d6fe86b5 100644 --- a/rtp_llm/cpp/engine_base/schedulers/EngineScheduleInfo.h +++ b/rtp_llm/cpp/engine_base/schedulers/EngineScheduleInfo.h @@ -5,15 +5,25 @@ namespace rtp_llm { +enum class TaskPhase { + PENDING = 0, + RECEIVED = 1, + KV_ALLOCATED = 2, + RUNNING = 3, +}; + struct EngineScheduleInfo { struct TaskInfo { - int64_t request_id; - int64_t prefix_length; - int64_t input_length; - int64_t waiting_time_ms; - int64_t iterate_count = 0; - int64_t end_time_ms = -1; - bool is_waiting = true; + int64_t request_id; + int64_t prefix_length; + int64_t input_length; + int64_t waiting_time_ms; + int64_t iterate_count = 0; + int64_t end_time_ms = -1; + TaskPhase phase = TaskPhase::PENDING; + int64_t error_code = 0; + std::string error_message; + int64_t batch_id = -1; }; std::vector running_task_info_list; std::vector finished_task_info_list; diff --git a/rtp_llm/cpp/engine_base/schedulers/FIFOScheduler.cc b/rtp_llm/cpp/engine_base/schedulers/FIFOScheduler.cc index 232e7b54df..91496d59fc 100644 --- a/rtp_llm/cpp/engine_base/schedulers/FIFOScheduler.cc +++ b/rtp_llm/cpp/engine_base/schedulers/FIFOScheduler.cc @@ -28,18 +28,11 @@ FIFOScheduler::FIFOScheduler(const RuntimeConfig& runtime_conf max_seq_len_(model_config.max_seq_len), max_batch_tokens_size_(runtime_config.fifo_scheduler_config.max_batch_tokens_size), max_generate_batch_size_(runtime_config.max_generate_batch_size), - max_inited_kv_cache_streams_( - std::max(runtime_config.fifo_scheduler_config.max_inited_kv_cache_streams, 0)), need_fill_fake_stream_(parallelism_config.dp_size > 1 && parallelism_config.tp_rank == 0), - cp_force_single_prefill_(parallelism_config.prefill_cp_config.is_enabled() - && runtime_config.fifo_scheduler_config.cp_force_single_prefill), metrics_reporter_(metrics_reporter) { - RTP_LLM_LOG_INFO("max_generate_batch_size is [%zu], max_batch_tokens_size is [%zu], " - "cp_force_single_prefill is [%d], max_inited_kv_cache_streams is [%zu]", + RTP_LLM_LOG_INFO("max_generate_batch_size is [%zu], max_batch_tokens_size is [%zu]", max_generate_batch_size_, - max_batch_tokens_size_, - cp_force_single_prefill_, - max_inited_kv_cache_streams_); + max_batch_tokens_size_); } FIFOScheduler::~FIFOScheduler() { @@ -49,15 +42,17 @@ FIFOScheduler::~FIFOScheduler() { bool FIFOScheduler::empty() { lock_guard lock(lock_); - return waiting_streams_.empty() && loading_cache_streams_.empty() && running_streams_.empty(); + return waiting_.empty() && loading_.empty() && running_.empty(); } -void FIFOScheduler::cancelStreams(std::list& streams) { - for (auto& stream : streams) { - stream->reportError(ErrorCode::CANCELLED, "scheduler stopped"); - stream->moveToNext(); // Stream should be finished after moveToNext +void FIFOScheduler::cancelUnits(std::list& units) { + for (auto& unit : units) { + for (auto& stream : unit.streams) { + stream->reportError(ErrorCode::CANCELLED, "scheduler stopped"); + stream->finish(); + } } - streams.clear(); + units.clear(); } absl::Status FIFOScheduler::stop() { @@ -65,9 +60,9 @@ absl::Status FIFOScheduler::stop() { { lock_guard lock(lock_); stop_ = true; - cancelStreams(waiting_streams_); - cancelStreams(loading_cache_streams_); - cancelStreams(running_streams_); + cancelUnits(waiting_); + cancelUnits(loading_); + cancelUnits(running_); } cond_.notify_all(); return absl::OkStatus(); @@ -77,20 +72,18 @@ int64_t FIFOScheduler::lastScheduleTime() { return empty() ? autil::TimeUtility::currentTimeInMilliSeconds() : last_schedule_time_.load(); } -// 在入队前校验输入长度,避免无效请求进入等待队列 -// 检查输入长度、投机解码预留空间和 batch token 上限。 bool FIFOScheduler::checkInputLength(const GenerateStreamPtr& stream) { const auto input_length = static_cast(stream->inputLength()); const auto reserve_step = stream->reserveStep(); if (reserve_step > 0 && !(input_length <= max_seq_len_ && reserve_step <= max_seq_len_ - input_length)) { const auto allowed_input_length = reserve_step <= max_seq_len_ ? max_seq_len_ - reserve_step : 0; - auto error_info = autil::StringUtil::formatString( - "input len %zu with speculative reserve_step %zu exceeds max seq len %zu, " - "allowed max input len for speculative decoding is %zu", - input_length, - reserve_step, - max_seq_len_, - allowed_input_length); + auto error_info = + autil::StringUtil::formatString("input len %zu with speculative reserve_step %zu exceeds max seq len %zu, " + "allowed max input len for speculative decoding is %zu", + input_length, + reserve_step, + max_seq_len_, + allowed_input_length); stream->reportError(ErrorCode::LONG_PROMPT_ERROR, error_info); return false; } @@ -99,7 +92,7 @@ bool FIFOScheduler::checkInputLength(const GenerateStreamPtr& stream) { autil::StringUtil::formatString("input len " + std::to_string(stream->inputLength()) + " is greater than kv cache max available tokens num " + std::to_string(cache_manager_->maxAvailableTokensNum()))); - return false; // Input length exceeds max available tokens + return false; } else if ((size_t)stream->inputLength() * stream->currentBatchSize() > max_batch_tokens_size_) { auto error_info = autil::StringUtil::formatString("input len [%d] * batch size [%d] > max_batch_tokens_size [%d]", @@ -119,247 +112,166 @@ absl::Status FIFOScheduler::enqueue(const GenerateStreamPtr& stream) { } { std::lock_guard lock(lock_); - waiting_streams_.emplace_back(stream); + ScheduleUnit unit; + unit.group_id = -1; + unit.streams.push_back(stream); + waiting_.push_back(std::move(unit)); schedule_trigger_ = true; } cond_.notify_all(); return absl::OkStatus(); } -std::vector> FIFOScheduler::batchEnqueue(const vector& streams) { +std::vector> FIFOScheduler::enqueueGroup(const vector& streams) { RTP_LLM_PROFILE_FUNCTION(); - std::vector> stream_enqueued; - for (auto it = streams.begin(); it != streams.end(); ++it) { - if (checkInputLength((*it))) { - stream_enqueued.emplace_back((*it)); + std::vector> valid_streams; + valid_streams.reserve(streams.size()); + for (const auto& stream : streams) { + if (checkInputLength(stream)) { + valid_streams.emplace_back(stream); } } - { + if (!valid_streams.empty()) { std::lock_guard lock(lock_); - waiting_streams_.insert(waiting_streams_.end(), stream_enqueued.begin(), stream_enqueued.end()); + bool is_group = !valid_streams.empty() && valid_streams[0]->isGroup(); + if (is_group) { + ScheduleUnit unit; + unit.group_id = valid_streams[0]->groupId(); + unit.streams = valid_streams; + waiting_.push_back(std::move(unit)); + } else { + for (auto& stream : valid_streams) { + ScheduleUnit unit; + unit.group_id = -1; + unit.streams.push_back(stream); + waiting_.push_back(std::move(unit)); + } + } schedule_trigger_ = true; } cond_.notify_all(); - return stream_enqueued; + return valid_streams; } -bool FIFOScheduler::evaluateRunningBatch(const list& streams, - const GenerateStreamPtr& new_stream) const { - RTP_LLM_PROFILE_FUNCTION(); - if (pd_sep_config_.role_type == RoleType::DECODE) { - // Decode-only scheduling can top up an existing running decode batch. - // max_generate_batch_size_ is an inclusive cap; only requests above it - // should be rejected. - if (running_streams_.size() + streams.size() + 1 <= max_generate_batch_size_) { - return true; +void FIFOScheduler::accountBatchMetrics(const GenerateStreamPtr& new_stream) { + for (auto& unit : running_) { + for (auto& stream : unit.streams) { + stream->incBatchWithPrefillTimes(1); + stream->incBatchWithPrefillLen(new_stream->currentExecuteTokenSize()); } } - // prefill and decode not mixed together - if (!running_streams_.empty()) { - return false; - } - // Conservative CP prefill mode: cap at one stream per round unless runtime - // config explicitly allows CP prefill batching. - if (cp_force_single_prefill_ && !streams.empty()) { - return false; - } - if (running_streams_.size() + streams.size() + 1 > max_generate_batch_size_) { - return false; - } - - int max_token_size = new_stream->contextLength(); - if (streams.empty() && max_token_size + running_streams_.size() < int(max_seq_len_)) { - return true; - } - for (auto& stream : streams) { - max_token_size = std::max(max_token_size, stream->contextLength()); - } - // 这里的判断是要求当前调度轮所有请求参与计算的 token 数之和小于 max_batch_tokens_size_,loading_cache_streams - // 这一轮实际不参与计算,不需要计入。 - return max_token_size * (streams.size() + 1) + running_streams_.size() < int(max_batch_tokens_size_); } -size_t FIFOScheduler::countInitedKVCacheStreams() const { - auto count_inited = [](const list& streams) { - size_t count = 0; - for (const auto& stream : streams) { - if (stream && stream->curBlocksNum() > 0) { - ++count; - } - } - return count; - }; - return count_inited(waiting_streams_) + count_inited(loading_cache_streams_) + count_inited(running_streams_); +bool FIFOScheduler::waitPredicate() { + return stop_ || schedule_trigger_ || !waiting_.empty() || !loading_.empty() || !running_.empty(); } -void FIFOScheduler::accountBatchMetrics(const GenerateStreamPtr& new_stream) { - for (auto& stream : running_streams_) { - stream->incBatchWithPrefillTimes(1); - stream->incBatchWithPrefillLen(new_stream->currentExecuteTokenSize()); +size_t FIFOScheduler::countStreams(const std::list& queue) const { + size_t total = 0; + for (const auto& unit : queue) { + total += unit.size(); } + return total; } -bool FIFOScheduler::waitPredicate() { - // Check streams directly without calling empty() which acquires lock_ (already held by schedule()) - return stop_ || schedule_trigger_ || !waiting_streams_.empty() || !loading_cache_streams_.empty() - || !running_streams_.empty(); -} - -// 通过 GenerateStateMachine 驱动每个 stream 的状态转移,状态变化的 stream 移入对应队列 -void FIFOScheduler::evaluateAndUpdateStreams(list& streams) { - RTP_LLM_PROFILE_FUNCTION(); - for (auto it = streams.begin(); it != streams.end();) { - auto state = (*it)->getStatus(); - auto new_state = (*it)->moveToNext(); - if (new_state != state) { - addStreamToNewState(*it, new_state); - it = streams.erase(it); - } else { - it++; +std::list FIFOScheduler::flattenRunning() const { + std::list result; + for (const auto& unit : running_) { + for (const auto& stream : unit.streams) { + result.push_back(stream); } } + return result; } -void FIFOScheduler::evaluateWaitingStreams(list& waiting_streams) { - RTP_LLM_PROFILE_FUNCTION(); - list admitted_streams; - std::unordered_set admitted_stream_ptrs; - const size_t inited_kv_streams = - max_inited_kv_cache_streams_ > 0 ? countInitedKVCacheStreams() : 0; - size_t admitted_new_init_streams = 0; - - // Batch group scheduling support: - // 1. Group completeness: force_batch streams with same batch_group_id are scheduled together - // only when group size reaches batch_group_size - // 2. Timeout fallback: if batch_group_timeout expires, incomplete group is scheduled as normal - // 3. Batch isolation: each scheduling round handles only one type: - // - normal streams, OR - // - streams from a single force_batch group - - struct GroupInfo { - int64_t first_arrival_time = 0; - int count = 0; - }; - std::unordered_map request_group_info; +bool FIFOScheduler::canAdmitUnit(size_t admitted_count, + size_t admitted_total_tokens, + size_t running_count, + const ScheduleUnit& unit) const { + if (pd_sep_config_.role_type == RoleType::DECODE) { + return running_count + admitted_count + unit.size() <= max_generate_batch_size_; + } + if (running_count > 0) { + return false; + } + if (admitted_count + unit.size() > max_generate_batch_size_) { + return false; + } + if (unit.isGroup()) { + return true; + } + size_t unit_tokens = 0; + for (const auto& s : unit.streams) { + unit_tokens += s->contextLength(); + } + return admitted_total_tokens + unit_tokens < max_batch_tokens_size_; +} - int64_t now = autil::TimeUtility::currentTimeInMilliSeconds(); +void FIFOScheduler::admitWaitingUnits() { + size_t admitted_count = 0; + size_t admitted_total_tokens = 0; + size_t running_count = countStreams(running_); + int64_t admitted_group_id = -1; - // Build group info statistics for force_batch streams - for (const auto& stream : waiting_streams) { - if (stream->forceBatch() && stream->batchGroupId() != -1) { - auto& info = request_group_info[stream->batchGroupId()]; - if (info.count == 0) { - info.first_arrival_time = stream->enqueueTime() / 1000; - } - info.count++; + // Remove units with pre-existing errors to avoid zombie entries + for (auto it = waiting_.begin(); it != waiting_.end();) { + if (it->hasError()) { + it = waiting_.erase(it); + } else { + ++it; } } - int64_t force_batch_group_id = -1; - - for (auto it = waiting_streams.begin(); it != waiting_streams.end();) { - auto& stream = *it; - bool force_batch = stream->forceBatch(); - - // Check if this stream can be scheduled based on batch group rules - if (force_batch && stream->batchGroupId() != -1) { - auto& info = request_group_info[stream->batchGroupId()]; - // Check timeout: if expired, treat as normal stream - if (now - info.first_arrival_time > stream->batchGroupTimeout()) { - force_batch = false; - } else if (info.count < stream->batchGroupSize()) { - // Group incomplete, skip this stream - it++; - continue; + for (auto it = waiting_.begin(); it != waiting_.end();) { + auto& unit = *it; + if (admitted_count > 0) { + if (admitted_group_id != -1) { + break; } - } - - // Batch isolation: force_batch streams and normal streams cannot mix in the same round. - // The first stream that passes checks determines the batch type for this round. - if (!admitted_streams.empty()) { - if (force_batch_group_id != -1) { - // Already in force_batch mode, only accept same group - if (!force_batch || stream->batchGroupId() != force_batch_group_id) { - it++; - continue; - } - } else { - // Already in normal mode, skip force_batch streams - if (force_batch) { - it++; - continue; - } + if (unit.isGroup()) { + ++it; + continue; } } - - // Check for errors and memory constraints - // - // Some PD decode streams already carry CanRun before entering FIFO: DecodeRpcServer uses - // CanRun to drive the pre-enqueue KV allocation path. CanRun is a permanent event, so it - // cannot be used as proof that FIFO has admitted this stream in the current scheduling - // round. Always run FIFO capacity checks and only advance streams admitted here. - const bool already_inited_kv = stream->curBlocksNum() > 0; - if (max_inited_kv_cache_streams_ > 0 && !already_inited_kv - && inited_kv_streams + admitted_new_init_streams >= max_inited_kv_cache_streams_) { - it++; + if (!canAdmitUnit(admitted_count, admitted_total_tokens, running_count, unit)) { + ++it; continue; } - - if (!stream->hasError() && evaluateRunningBatch(admitted_streams, stream)) { - if (!stream->hasEvent(StreamEvents::CanRun)) { - stream->reportEvent(StreamEvents::CanRun); - } - admitted_streams.push_back(stream); - admitted_stream_ptrs.insert(stream.get()); - if (max_inited_kv_cache_streams_ > 0 && !already_inited_kv) { - ++admitted_new_init_streams; - } - - // Lock batch type based on first scheduled stream - if (admitted_streams.size() == 1 && force_batch && stream->batchGroupId() != -1) { - force_batch_group_id = stream->batchGroupId(); - } + if (admitted_count == 0 && unit.isGroup()) { + admitted_group_id = unit.group_id; } - it++; - } - - for (auto it = waiting_streams.begin(); it != waiting_streams.end();) { - auto& stream = *it; - if (!stream->hasError() && admitted_stream_ptrs.find(stream.get()) == admitted_stream_ptrs.end()) { - it++; + bool needs_loading = unit.prepare(); + if (!unit.alive()) { + it = waiting_.erase(it); continue; } - auto state = stream->getStatus(); - auto new_state = stream->moveToNext(); - if (new_state != state) { - addStreamToNewState(stream, new_state); - it = waiting_streams.erase(it); + admitted_count += unit.size(); + for (const auto& s : unit.streams) { + admitted_total_tokens += s->contextLength(); + } + if (needs_loading) { + loading_.splice(loading_.end(), waiting_, it++); } else { - it++; + unit.activate(); + if (unit.alive()) { + for (auto& s : unit.streams) { + if (s->isContextStream()) { + RTP_LLM_ACCESS_LOG_INFO("request_activated: %s role=prefill input_len=%d", + s->streamLogTag().c_str(), s->inputLength()); + } else { + RTP_LLM_ACCESS_LOG_INFO("request_activated: %s role=decode seq_len=%d", + s->streamLogTag().c_str(), s->seqLength()); + } + accountBatchMetrics(s); + } + running_.splice(running_.end(), waiting_, it++); + } else { + it = waiting_.erase(it); + } } } } -void FIFOScheduler::addStreamToNewState(const GenerateStreamPtr& stream, StreamState new_state) { - switch (new_state) { - case StreamState::WAITING: - waiting_streams_.push_back(stream); - break; - case StreamState::LOADING_CACHE: - loading_cache_streams_.push_back(stream); - break; - case StreamState::RUNNING: - accountBatchMetrics(stream); - new_streams_.push_back(stream); - break; - case StreamState::FINISHED: - break; - default: - RTP_LLM_LOG_ERROR("Unknown state: %d for stream [%ld]", static_cast(new_state), stream->streamId()); - break; - } -} - absl::StatusOr> FIFOScheduler::schedule() { unique_lock lock(lock_); if (need_fill_fake_stream_) { @@ -367,60 +279,102 @@ absl::StatusOr> FIFOScheduler::schedule() { } else { cond_.wait(lock, [this] { return waitPredicate(); }); } - schedule_trigger_ = false; - // LOADING_CACHE -> DONE/WAITING: error / load cache done - evaluateAndUpdateStreams(loading_cache_streams_); - // RUNNING -> DONE: error / finished - evaluateAndUpdateStreams(running_streams_); + // 1. running: advance + cleanup + for (auto it = running_.begin(); it != running_.end();) { + it->advance(); + if (!it->alive()) { + for (auto& s : it->streams) { + RTP_LLM_ACCESS_LOG_INFO("request_finished: %s output_len=%ld iter_count=%ld", + s->streamLogTag().c_str(), + s->outputTokenLen(), s->iterCount()); + } + it = running_.erase(it); + } else { + ++it; + } + } - // WAITING -> RUNNING: can run - // WAITING -> LOADING_CACHE: load cache ok - // - // WAITING streams are advanced only after FIFO admits them in this scheduling round. - // This matters for PD decode: DecodeRpcServer may pre-set CanRun before enqueue to - // allocate KV blocks, so a permanent CanRun bit alone must not bypass capacity checks. - size_t prev_waiting_size = waiting_streams_.size(); - evaluateWaitingStreams(waiting_streams_); - running_streams_.insert(running_streams_.end(), new_streams_.begin(), new_streams_.end()); - new_streams_.clear(); + // 2. loading: check ready -> activate -> move to running + // DECODE streams always enter loading_ (async KV cache loading from prefill). + // PREFILL streams may also enter loading_ when prefix caching is enabled + // (asyncLoadCache returns true based on cache connector configuration). + // Use activated_count to track streams moved to running in this cycle, + // so the batch size check accounts for previously activated units. + // Similarly, accumulate admitted_total_tokens for token budget checking + // (relevant for PREFILL non-group units with prefix caching). + size_t running_at_step2 = countStreams(running_); + size_t activated_count = 0; + size_t admitted_total_tokens = 0; + for (auto it = loading_.begin(); it != loading_.end();) { + if (it->isReady()) { + if (!canAdmitUnit(activated_count, admitted_total_tokens, running_at_step2, *it)) { + ++it; + continue; + } + it->activate(); + if (it->alive()) { + activated_count += it->size(); + for (auto& s : it->streams) { + if (s->isContextStream()) { + RTP_LLM_ACCESS_LOG_INFO("request_activated: %s role=prefill input_len=%d", + s->streamLogTag().c_str(), s->inputLength()); + } else { + RTP_LLM_ACCESS_LOG_INFO("request_activated: %s role=decode seq_len=%d", + s->streamLogTag().c_str(), s->seqLength()); + } + admitted_total_tokens += s->contextLength(); + accountBatchMetrics(s); + } + running_.splice(running_.end(), loading_, it++); + } else { + it = loading_.erase(it); + } + } else { + ++it; + } + } - // If streams were scheduled, trigger next scheduling round - if (waiting_streams_.size() < prev_waiting_size) { + // 3. waiting: admit -> prepare -> loading or running + size_t prev_waiting_size = countStreams(waiting_); + admitWaitingUnits(); + if (countStreams(waiting_) < prev_waiting_size) { schedule_trigger_ = true; } reportMetrics(); last_schedule_time_ = autil::TimeUtility::currentTimeInMilliSeconds(); - return running_streams_; + return flattenRunning(); } int64_t FIFOScheduler::waitingStreamsSize() { std::lock_guard lock(lock_); - return waiting_streams_.size(); + return countStreams(waiting_); } int64_t FIFOScheduler::runningStreamsSize() { std::lock_guard lock(lock_); - return running_streams_.size(); + return countStreams(running_); } int64_t FIFOScheduler::onflightStreams() { std::lock_guard lock(lock_); - return waiting_streams_.size() + loading_cache_streams_.size() + running_streams_.size(); + return countStreams(waiting_) + countStreams(loading_) + countStreams(running_); } std::vector FIFOScheduler::waitingTaskList() { std::lock_guard lock(lock_); waiting_task_list_.clear(); - waiting_task_list_.reserve(waiting_streams_.size()); - for (const auto& stream : waiting_streams_) { - EngineScheduleInfo::TaskInfo task_info; - task_info.request_id = stream->streamId(); - task_info.prefix_length = stream->prefixLength(); - task_info.input_length = stream->inputLength(); - waiting_task_list_.emplace_back(task_info); + for (const auto& unit : waiting_) { + for (const auto& stream : unit.streams) { + EngineScheduleInfo::TaskInfo task_info; + task_info.request_id = stream->streamId(); + task_info.prefix_length = stream->prefixLength(); + task_info.input_length = stream->inputLength(); + task_info.batch_id = unit.group_id; + waiting_task_list_.emplace_back(task_info); + } } return waiting_task_list_; } @@ -428,13 +382,15 @@ std::vector FIFOScheduler::waitingTaskList() { std::vector FIFOScheduler::runningTaskList() { std::lock_guard lock(lock_); running_task_list_.clear(); - running_task_list_.reserve(running_streams_.size()); - for (const auto& stream : running_streams_) { - EngineScheduleInfo::TaskInfo task_info; - task_info.request_id = stream->streamId(); - task_info.prefix_length = stream->prefixLength(); - task_info.input_length = stream->inputLength(); - running_task_list_.emplace_back(task_info); + for (const auto& unit : running_) { + for (const auto& stream : unit.streams) { + EngineScheduleInfo::TaskInfo task_info; + task_info.request_id = stream->streamId(); + task_info.prefix_length = stream->prefixLength(); + task_info.input_length = stream->inputLength(); + task_info.batch_id = unit.group_id; + running_task_list_.emplace_back(task_info); + } } return running_task_list_; } @@ -442,9 +398,9 @@ std::vector FIFOScheduler::runningTaskList() { void FIFOScheduler::reportMetrics() { if (metrics_reporter_) { RtpLLMSchedulerMetricsCollector collector; - collector.wait_stream_size = waiting_streams_.size(); - collector.running_stream_size = running_streams_.size(); - collector.loading_cache_stream_size = loading_cache_streams_.size(); + collector.wait_stream_size = countStreams(waiting_); + collector.running_stream_size = countStreams(running_); + collector.loading_cache_stream_size = countStreams(loading_); metrics_reporter_->report(nullptr, &collector); } return; diff --git a/rtp_llm/cpp/engine_base/schedulers/FIFOScheduler.h b/rtp_llm/cpp/engine_base/schedulers/FIFOScheduler.h index 3d8860a5ad..fb70e9503f 100644 --- a/rtp_llm/cpp/engine_base/schedulers/FIFOScheduler.h +++ b/rtp_llm/cpp/engine_base/schedulers/FIFOScheduler.h @@ -8,6 +8,7 @@ #include "rtp_llm/cpp/cache/KVCacheManager.h" #include "rtp_llm/cpp/engine_base/stream/GenerateTypes.h" #include "rtp_llm/cpp/engine_base/schedulers/SchedulerBase.h" +#include "rtp_llm/cpp/engine_base/schedulers/ScheduleUnit.h" #include "kmonitor/client/MetricsReporter.h" #include "rtp_llm/cpp/config/ConfigModules.h" #include "rtp_llm/cpp/engine_base/schedulers/EngineScheduleInfo.h" @@ -26,15 +27,8 @@ class FIFOScheduler: public SchedulerBase { ~FIFOScheduler() override; - // Enqueue a single stream. Returns OkStatus on success, InvalidArgumentError if checkInputLength fails. - // On failure, the stream's error is reported via reportError() but the stream is NOT queued. - // Caller must check the return status to know whether the stream was actually enqueued. - absl::Status enqueue(const GenerateStreamPtr& stream) override; - - // Enqueue multiple streams. Silently filters out streams that fail checkInputLength (their errors - // are reported via reportError()). Returns only the streams that were successfully enqueued. - // Caller should compare the returned vector size with the input size to detect dropped streams. - std::vector> batchEnqueue(const std::vector& streams) override; + absl::Status enqueue(const GenerateStreamPtr& stream) override; + std::vector> enqueueGroup(const std::vector& streams) override; absl::StatusOr> schedule() override; absl::Status stop() override; bool empty() override; @@ -50,37 +44,31 @@ class FIFOScheduler: public SchedulerBase { int64_t onflightStreams() override; private: - int64_t lastScheduleTime() override; - bool evaluateRunningBatch(const std::list& streams, const GenerateStreamPtr& new_stream) const; - size_t countInitedKVCacheStreams() const; - void accountBatchMetrics(const GenerateStreamPtr& new_stream); - bool waitPredicate(); - void addStreamToNewState(const GenerateStreamPtr& stream, StreamState new_state); - void evaluateWaitingStreams(std::list& streams); - void cancelStreams(std::list& streams); - bool checkInputLength(const GenerateStreamPtr& stream); - -protected: - void evaluateAndUpdateStreams(std::list& streams); + int64_t lastScheduleTime() override; + void accountBatchMetrics(const GenerateStreamPtr& new_stream); + bool waitPredicate(); + bool checkInputLength(const GenerateStreamPtr& stream); + void admitWaitingUnits(); + bool canAdmitUnit(size_t admitted_count, + size_t admitted_total_tokens, + size_t running_count, + const ScheduleUnit& unit) const; + std::list flattenRunning() const; + size_t countStreams(const std::list& queue) const; + void cancelUnits(std::list& units); protected: PDSepConfig pd_sep_config_; ModelSpecificConfig model_specific_config_; - std::list waiting_streams_; - std::list loading_cache_streams_; - std::list running_streams_; - std::list new_streams_; + std::list waiting_; + std::list loading_; + std::list running_; std::shared_ptr cache_manager_; std::atomic last_schedule_time_ = autil::TimeUtility::currentTimeInMilliSeconds(); - size_t max_seq_len_ = 0; - size_t max_batch_tokens_size_ = 0; - size_t max_generate_batch_size_ = 1; - size_t max_inited_kv_cache_streams_ = 0; - const bool need_fill_fake_stream_ = false; - // Optional guard for Context-Parallel prefill: when enabled, force prefill - // to one stream per round. This remains the conservative default while - // newer dsv4 CP paths can opt in to batched prefill through runtime config. - const bool cp_force_single_prefill_ = false; + size_t max_seq_len_ = 0; + size_t max_batch_tokens_size_ = 0; + size_t max_generate_batch_size_ = 1; + const bool need_fill_fake_stream_ = false; std::atomic stop_ = false; bool schedule_trigger_ = false; std::mutex lock_; @@ -89,8 +77,6 @@ class FIFOScheduler: public SchedulerBase { std::vector waiting_task_list_; std::vector running_task_list_; - - // TODO @wangyin support different beams run togather }; } // namespace rtp_llm diff --git a/rtp_llm/cpp/engine_base/schedulers/GatherBatchScheduler.h b/rtp_llm/cpp/engine_base/schedulers/GatherBatchScheduler.h index 608230733d..aeeedde7ed 100644 --- a/rtp_llm/cpp/engine_base/schedulers/GatherBatchScheduler.h +++ b/rtp_llm/cpp/engine_base/schedulers/GatherBatchScheduler.h @@ -1,4 +1,5 @@ #pragma once +#include #include "rtp_llm/cpp/engine_base/schedulers/FIFOScheduler.h" namespace rtp_llm { @@ -48,61 +49,116 @@ class GatherBatchScheduler: virtual public FIFOScheduler { absl::StatusOr> schedule() override { std::unique_lock lock(lock_); cond_.wait_for(lock, std::chrono::seconds(30), [this] { - return waiting_streams_.size() >= static_cast(gather_batch_size_) || running_streams_.size() > 0 - || !loading_cache_streams_.empty(); + return gatherCountStreams(waiting_) >= static_cast(gather_batch_size_) || !running_.empty() + || !loading_.empty(); }); - // LOADING_CACHE -> DONE/WAITING: error / load cache done - evaluateAndUpdateStreams(loading_cache_streams_); - // RUNNING -> DONE: error / finished - evaluateAndUpdateStreams(running_streams_); + // running: advance + cleanup + for (auto it = running_.begin(); it != running_.end();) { + it->advance(); + if (!it->alive()) { + it = running_.erase(it); + } else { + ++it; + } + } + + // loading: check ready -> activate -> move to running + for (auto it = loading_.begin(); it != loading_.end();) { + if (it->isReady()) { + it->activate(); + if (it->alive()) { + running_.splice(running_.end(), loading_, it++); + } else { + it = loading_.erase(it); + } + } else { + ++it; + } + } - // PyWrappedModel currently does not support a mixed prefill+decode batch (see - // PyWrappedModel::buildPyAttentionInputs cu_seqlens slicing). Defer the gather - // until running streams drain so the next batch is pure prefill. - // NOTE: the `load_python_model` flag was removed upstream in commit 901d077f1; - // we now always assume a python model and gate solely on running-stream count. - const bool python_model_busy = !running_streams_.empty(); - if (waiting_streams_.size() >= static_cast(gather_batch_size_) && !python_model_busy) { - // Gather exactly gather_batch_size_ streams + const bool python_model_busy = !running_.empty(); + if (gatherCountStreams(waiting_) >= static_cast(gather_batch_size_) && !python_model_busy) { std::list new_streams; - for (auto it = waiting_streams_.begin(); it != waiting_streams_.end(); it++) { - if (!(*it)->hasError()) { - new_streams.push_back(*it); + for (auto& unit : waiting_) { + for (auto& s : unit.streams) { + if (!s->hasError()) { + new_streams.push_back(s); + } + if (new_streams.size() >= static_cast(gather_batch_size_)) { + break; + } } if (new_streams.size() >= static_cast(gather_batch_size_)) { break; } } - // Only schedule when we have enough streams if (new_streams.size() >= static_cast(gather_batch_size_)) { + std::unordered_set scheduled_ids; for (auto& stream : new_streams) { - stream->reportEvent(StreamEvents::CanRun); - // busy wait for loading cache done, equivalent to to original logic. - while (stream->getStatus() != StreamState::FINISHED - && stream->moveToNext() != StreamState::RUNNING) { + stream->prepare(); + while (stream->alive() && !stream->isReady()) { std::this_thread::sleep_for(std::chrono::milliseconds(1)); } + if (stream->alive()) { + stream->activate(); + } } - // 过滤 FINISHED stream,仅将 RUNNING stream 加入 running_streams_ - new_streams.remove_if([](const auto& s) { return s->getStatus() == StreamState::FINISHED; }); - // 按 streamId 排序以保证 CI 确定性结果 + new_streams.remove_if([](const auto& s) { return !s->alive(); }); new_streams.sort([](const GenerateStreamPtr& a, const GenerateStreamPtr& b) { return a->streamId() < b->streamId(); }); - running_streams_.insert(running_streams_.end(), new_streams.begin(), new_streams.end()); - // Remove scheduled streams from waiting_streams_ - for (auto& stream : new_streams) { - waiting_streams_.remove(stream); + for (auto& s : new_streams) { + scheduled_ids.insert(s->streamId()); + } + // Move scheduled streams into a running unit + ScheduleUnit run_unit; + run_unit.group_id = -1; + for (auto& s : new_streams) { + run_unit.streams.push_back(s); + } + running_.push_back(std::move(run_unit)); + // Remove scheduled streams from waiting_ + for (auto it = waiting_.begin(); it != waiting_.end();) { + for (auto sit = it->streams.begin(); sit != it->streams.end();) { + if (scheduled_ids.count((*sit)->streamId())) { + sit = it->streams.erase(sit); + } else { + ++sit; + } + } + if (it->streams.empty()) { + it = waiting_.erase(it); + } else { + ++it; + } } } gather_batch_size_ = 1; } - return running_streams_; + return gatherFlattenRunning(); } protected: + size_t gatherCountStreams(const std::list& queue) const { + size_t total = 0; + for (const auto& unit : queue) { + total += unit.size(); + } + return total; + } + + std::list gatherFlattenRunning() const { + std::list result; + for (const auto& unit : running_) { + for (const auto& stream : unit.streams) { + result.push_back(stream); + } + } + return result; + } + int gather_batch_size_; }; diff --git a/rtp_llm/cpp/engine_base/schedulers/ScheduleUnit.h b/rtp_llm/cpp/engine_base/schedulers/ScheduleUnit.h new file mode 100644 index 0000000000..8dff4d241d --- /dev/null +++ b/rtp_llm/cpp/engine_base/schedulers/ScheduleUnit.h @@ -0,0 +1,80 @@ +#pragma once + +#include +#include +#include "rtp_llm/cpp/engine_base/stream/GenerateTypes.h" + +namespace rtp_llm { + +struct ScheduleUnit { + int64_t group_id = -1; + std::vector streams; + + bool prepare() { + bool any_loading = false; + for (auto it = streams.begin(); it != streams.end();) { + bool needs = (*it)->prepare(); + if (!(*it)->alive()) { + it = streams.erase(it); + continue; + } + if (needs) + any_loading = true; + ++it; + } + return any_loading; + } + + bool isReady() { + for (auto& s : streams) { + if (!s->isReady()) + return false; + } + return true; + } + + void activate() { + for (auto it = streams.begin(); it != streams.end();) { + (*it)->activate(); + if (!(*it)->alive()) { + it = streams.erase(it); + continue; + } + ++it; + } + } + + void advance() { + for (auto it = streams.begin(); it != streams.end();) { + (*it)->advance(); + if (!(*it)->alive()) { + it = streams.erase(it); + continue; + } + ++it; + } + } + + bool alive() const { + return !streams.empty(); + } + + bool hasError() const { + for (const auto& s : streams) { + if (s->hasError()) { + return true; + } + } + return false; + } + + bool isGroup() const { + return group_id != -1; + } + + size_t size() const { + return streams.size(); + } +}; + +} // namespace rtp_llm diff --git a/rtp_llm/cpp/engine_base/schedulers/SchedulerBase.h b/rtp_llm/cpp/engine_base/schedulers/SchedulerBase.h index 660807fe6d..f92dbe246b 100644 --- a/rtp_llm/cpp/engine_base/schedulers/SchedulerBase.h +++ b/rtp_llm/cpp/engine_base/schedulers/SchedulerBase.h @@ -16,7 +16,7 @@ class SchedulerBase { public: virtual ~SchedulerBase() {} virtual absl::Status enqueue(const GenerateStreamPtr& stream) = 0; - virtual std::vector batchEnqueue(const std::vector& streams) = 0; + virtual std::vector enqueueGroup(const std::vector& streams) = 0; virtual absl::StatusOr> schedule() = 0; // Conservative-KV scheduling variant for async execution. The async path diff --git a/rtp_llm/cpp/engine_base/schedulers/test/BUILD b/rtp_llm/cpp/engine_base/schedulers/test/BUILD index 32662f7419..170846341b 100644 --- a/rtp_llm/cpp/engine_base/schedulers/test/BUILD +++ b/rtp_llm/cpp/engine_base/schedulers/test/BUILD @@ -3,6 +3,18 @@ load("//bazel:arch_select.bzl", "torch_deps") cc_test = cc_test_wrapper +cc_import( + name = "cuda13_torch_nvshmem", + shared_library = "@pip_gpu_cuda13_torch_torch//:site-packages/torch/lib/libtorch_nvshmem.so", +) + +cuda13_torch_link_deps = select({ + "@//:using_cuda13_x86": [ + ":cuda13_torch_nvshmem", + ], + "//conditions:default": [], +}) + test_deps = [ "//rtp_llm/cpp/testing:device_test_utils", "//rtp_llm/models_py/bindings/cuda/ops:cuda_impl", @@ -14,7 +26,7 @@ test_deps = [ "@com_google_googletest//:gtest_main", "@local_config_cuda//cuda:cuda_headers", "@local_config_cuda//cuda:cudart", -] + torch_deps() +] + torch_deps() + cuda13_torch_link_deps async_cache_test_deps = test_deps + [ "//rtp_llm/cpp/cache/test:cache_config_test_utils", diff --git a/rtp_llm/cpp/engine_base/schedulers/test/FIFOSchedulerAsyncCacheTest.cc b/rtp_llm/cpp/engine_base/schedulers/test/FIFOSchedulerAsyncCacheTest.cc index ccabba49e7..502bf93e33 100644 --- a/rtp_llm/cpp/engine_base/schedulers/test/FIFOSchedulerAsyncCacheTest.cc +++ b/rtp_llm/cpp/engine_base/schedulers/test/FIFOSchedulerAsyncCacheTest.cc @@ -8,6 +8,7 @@ #define private public #define protected public #include "rtp_llm/cpp/engine_base/schedulers/FIFOScheduler.h" +#include "rtp_llm/cpp/engine_base/schedulers/ScheduleUnit.h" #include "rtp_llm/cpp/engine_base/stream/GenerateStream.h" #include "rtp_llm/cpp/engine_base/stream/StreamCacheResource.h" #include "rtp_llm/cpp/normal_engine/NormalGenerateStream.h" @@ -122,16 +123,16 @@ TEST_F(FIFOSchedulerAsyncCacheTest, testScheduleNew_NoReuseCache_DirectlyRunning auto result = scheduler->schedule(); ASSERT_TRUE(result.ok()); ASSERT_EQ(result.value().size(), 1); - ASSERT_EQ(scheduler->loading_cache_streams_.size(), 0); + ASSERT_EQ(scheduler->countStreams(scheduler->loading_), 0u); ASSERT_EQ(scheduler->waitingStreamsSize(), 0); ASSERT_EQ(scheduler->runningStreamsSize(), 1); } // ============================================================================ -// 2. scheduleNew: stream with reuse_cache and connector enters LOADING_CACHE +// 2. scheduleNew: stream with reuse_cache and connector enters loading_ queue // ============================================================================ -TEST_F(FIFOSchedulerAsyncCacheTest, testScheduleNew_WithReuseCache_EntersLoadingCache) { +TEST_F(FIFOSchedulerAsyncCacheTest, testScheduleNew_WithReuseCache_EntersLoadingQueue) { setupMockCoordinator(); auto pending_ctx = createPendingAsyncContext(); EXPECT_CALL(*mock_coord_, asyncRead(_)).WillOnce(Return(std::static_pointer_cast(pending_ctx))); @@ -142,19 +143,18 @@ TEST_F(FIFOSchedulerAsyncCacheTest, testScheduleNew_WithReuseCache_EntersLoading ASSERT_TRUE(scheduler->enqueue(stream).ok()); auto result = scheduler->schedule(); ASSERT_TRUE(result.ok()); - // Stream is in LOADING_CACHE, not in running + // Stream is in loading_ queue, not in running ASSERT_EQ(result.value().size(), 0); - ASSERT_TRUE(stream->getStatus() == StreamState::LOADING_CACHE); - ASSERT_EQ(scheduler->loading_cache_streams_.size(), 1); + ASSERT_EQ(scheduler->countStreams(scheduler->loading_), 1u); ASSERT_EQ(scheduler->waitingStreamsSize(), 0); ASSERT_EQ(scheduler->runningStreamsSize(), 0); } // ============================================================================ -// 3. evaluateLoadingCacheStreams: stream load done -> moves to waiting -> then running +// 3. loading check: stream load done -> moves to waiting -> then running // ============================================================================ -TEST_F(FIFOSchedulerAsyncCacheTest, testEvaluateLoadingCache_LoadDone_MovesToRunning) { +TEST_F(FIFOSchedulerAsyncCacheTest, testLoadingCheck_LoadDone_MovesToRunning) { setupMockCoordinator(); // Mock context: done() returns true when checked (load completes immediately) @@ -170,30 +170,29 @@ TEST_F(FIFOSchedulerAsyncCacheTest, testEvaluateLoadingCache_LoadDone_MovesToRun ASSERT_TRUE(scheduler->enqueue(stream).ok()); - // First schedule: stream enters LOADING_CACHE - // (evaluateLoadingCacheStreams runs before scheduleNew, so loading_cache_streams_ is empty at that point) + // First schedule: stream enters loading_ queue + // (loading check runs before scheduleNew, so loading_ is empty at that point) auto result1 = scheduler->schedule(); ASSERT_TRUE(result1.ok()); ASSERT_EQ(result1.value().size(), 0); - ASSERT_EQ(scheduler->loading_cache_streams_.size(), 1); - ASSERT_TRUE(stream->getStatus() == StreamState::LOADING_CACHE); + ASSERT_EQ(scheduler->countStreams(scheduler->loading_), 1u); - // Second schedule: evaluateLoadingCacheStreams -> loadCacheDone()=true -> WAITING -> scheduleNew -> RUNNING + // Second schedule: loading check -> loadCacheDone()=true -> WAITING -> scheduleNew -> RUNNING auto result2 = scheduler->schedule(); ASSERT_TRUE(result2.ok()); ASSERT_EQ(result2.value().size(), 1); - ASSERT_EQ(scheduler->loading_cache_streams_.size(), 0); + ASSERT_EQ(scheduler->countStreams(scheduler->loading_), 0u); ASSERT_EQ(scheduler->runningStreamsSize(), 1); } // ============================================================================ -// 4. evaluateLoadingCacheStreams: stream with error during loading -> evicted +// 4. loading check: stream with error during loading -> evicted // ============================================================================ -TEST_F(FIFOSchedulerAsyncCacheTest, testEvaluateLoadingCache_ErrorDuringLoading_Evicted) { +TEST_F(FIFOSchedulerAsyncCacheTest, testLoadingCheck_ErrorDuringLoading_Evicted) { setupMockCoordinator(); - // Mock context: done() returns true so evaluateLoadingCacheStreams proceeds to error check + // Mock context: done() returns true so loading check proceeds to error check auto mock_ctx = std::make_shared>(); ON_CALL(*mock_ctx, done()).WillByDefault(Return(true)); ON_CALL(*mock_ctx, success()).WillByDefault(Return(true)); @@ -206,10 +205,10 @@ TEST_F(FIFOSchedulerAsyncCacheTest, testEvaluateLoadingCache_ErrorDuringLoading_ ASSERT_TRUE(scheduler->enqueue(stream).ok()); - // First schedule: enters LOADING_CACHE + // First schedule: enters loading_ queue auto result1 = scheduler->schedule(); ASSERT_TRUE(result1.ok()); - ASSERT_EQ(scheduler->loading_cache_streams_.size(), 1); + ASSERT_EQ(scheduler->countStreams(scheduler->loading_), 1u); // Simulate external error (e.g., cancel from gRPC) stream->reportError(ErrorCode::CANCELLED, "cancelled by client"); @@ -217,16 +216,16 @@ TEST_F(FIFOSchedulerAsyncCacheTest, testEvaluateLoadingCache_ErrorDuringLoading_ // Second schedule: loadCacheDone()=true, hasError()=true -> stream evicted and finished auto result2 = scheduler->schedule(); ASSERT_TRUE(result2.ok()); - ASSERT_EQ(scheduler->loading_cache_streams_.size(), 0); + ASSERT_EQ(scheduler->countStreams(scheduler->loading_), 0u); ASSERT_EQ(result2.value().size(), 0); ASSERT_TRUE(stream->isFinished()); } // ============================================================================ -// 5. loading_cache_streams_ counted in evaluateRunningBatch (batch size limit) +// 5. admitted_count limits per-round batch size // ============================================================================ -TEST_F(FIFOSchedulerAsyncCacheTest, testLoadingCacheStreams_CountedInBatchLimit) { +TEST_F(FIFOSchedulerAsyncCacheTest, testLoadingQueue_CountedInBatchLimit) { setupMockCoordinator(); // Set max batch size to 2 @@ -255,26 +254,26 @@ TEST_F(FIFOSchedulerAsyncCacheTest, testLoadingCacheStreams_CountedInBatchLimit) auto result = scheduler->schedule(); ASSERT_TRUE(result.ok()); - // loading_cache_streams_ should count toward max_generate_batch_size - // With max=2, only 2 streams should be scheduled (into LOADING_CACHE) + // loading_ queue should count toward max_generate_batch_size + // With max=2, only 2 streams should be scheduled (into loading_ queue) // The 3rd stream should remain in waiting ASSERT_LE(result.value().size(), 2); } // ============================================================================ -// 6. scheduleNew: stream returning from LOADING_CACHE (already has blocks) skips asyncLoadCache +// 6. scheduleNew: stream returning from loading_ queue (already has blocks) skips asyncLoadCache // ============================================================================ -TEST_F(FIFOSchedulerAsyncCacheTest, testScheduleNew_ReturningFromLoadingCache_SkipsAsyncLoad) { +TEST_F(FIFOSchedulerAsyncCacheTest, testScheduleNew_ReturningFromLoadingQueue_SkipsAsyncLoad) { setupMockCoordinator(); - // Mock context: done() returns true when checked in evaluateLoadingCacheStreams + // Mock context: done() returns true when checked in loading check auto mock_ctx = std::make_shared>(); ON_CALL(*mock_ctx, done()).WillByDefault(Return(true)); ON_CALL(*mock_ctx, success()).WillByDefault(Return(true)); ON_CALL(*mock_ctx, waitDone()).WillByDefault(Return()); - // asyncRead should only be called ONCE (for the first time entering LOADING_CACHE) + // asyncRead should only be called ONCE (for the first time entering loading_ queue) EXPECT_CALL(*mock_coord_, asyncRead(_)).Times(1).WillOnce(Return(std::static_pointer_cast(mock_ctx))); auto scheduler = createScheduler(); @@ -282,10 +281,10 @@ TEST_F(FIFOSchedulerAsyncCacheTest, testScheduleNew_ReturningFromLoadingCache_Sk ASSERT_TRUE(scheduler->enqueue(stream).ok()); - // First schedule: stream -> LOADING_CACHE + // First schedule: stream -> loading_ queue auto result1 = scheduler->schedule(); ASSERT_TRUE(result1.ok()); - ASSERT_EQ(scheduler->loading_cache_streams_.size(), 1); + ASSERT_EQ(scheduler->countStreams(scheduler->loading_), 1u); // Second schedule: load done -> back to WAITING -> scheduleNew -> RUNNING (skips asyncLoadCache) auto result2 = scheduler->schedule(); @@ -295,10 +294,10 @@ TEST_F(FIFOSchedulerAsyncCacheTest, testScheduleNew_ReturningFromLoadingCache_Sk } // ============================================================================ -// 7. loading_cache_streams_ included in empty() and onflightStreams() +// 7. loading_ queue included in empty() and onflightStreams() // ============================================================================ -TEST_F(FIFOSchedulerAsyncCacheTest, testLoadingCacheStreams_IncludedInEmptyAndOnflight) { +TEST_F(FIFOSchedulerAsyncCacheTest, testLoadingQueue_IncludedInEmptyAndOnflight) { setupMockCoordinator(); auto pending_ctx = createPendingAsyncContext(); @@ -310,26 +309,29 @@ TEST_F(FIFOSchedulerAsyncCacheTest, testLoadingCacheStreams_IncludedInEmptyAndOn ASSERT_TRUE(scheduler->enqueue(stream).ok()); auto result = scheduler->schedule(); ASSERT_TRUE(result.ok()); - ASSERT_EQ(scheduler->loading_cache_streams_.size(), 1); + ASSERT_EQ(scheduler->countStreams(scheduler->loading_), 1u); - // Scheduler should NOT be empty when there are loading_cache_streams_ + // Scheduler should NOT be empty when there are streams in loading_ queue ASSERT_FALSE(scheduler->empty()); - // onflightStreams should include loading_cache_streams_ + // onflightStreams should include loading_ queue ASSERT_EQ(scheduler->onflightStreams(), 1); } // ============================================================================ -// 8. loading_cache_streams_ included in waitPredicate() +// 8. loading_ queue included in waitPredicate() // ============================================================================ -TEST_F(FIFOSchedulerAsyncCacheTest, testWaitPredicate_IncludesLoadingCacheStreams) { +TEST_F(FIFOSchedulerAsyncCacheTest, testWaitPredicate_IncludesLoadingQueue) { auto scheduler = createScheduler(); // Empty scheduler -> waitPredicate should be false ASSERT_FALSE(scheduler->waitPredicate()); - // Add a fake stream to loading_cache_streams_ - auto stream = createStream({1, 2, 3}); - scheduler->loading_cache_streams_.emplace_back(stream); + // Add a fake stream to loading_ queue via ScheduleUnit + auto stream = createStream({1, 2, 3}); + ScheduleUnit unit; + unit.group_id = -1; + unit.streams.push_back(stream); + scheduler->loading_.push_back(std::move(unit)); ASSERT_TRUE(scheduler->waitPredicate()); } @@ -381,21 +383,20 @@ TEST_F(FIFOSchedulerAsyncCacheTest, testMixedAsyncAndDirectStreams) { ASSERT_TRUE(scheduler->enqueue(stream1).ok()); ASSERT_TRUE(scheduler->enqueue(stream2).ok()); - // Single schedule: stream1 -> LOADING_CACHE (async load), stream2 -> RUNNING (directly) + // Single schedule: stream1 -> loading_ queue (async load), stream2 -> RUNNING (directly) auto result = scheduler->schedule(); ASSERT_TRUE(result.ok()); ASSERT_EQ(result.value().size(), 1); // Only stream2 is running - ASSERT_TRUE(stream1->getStatus() == StreamState::LOADING_CACHE); - ASSERT_EQ(scheduler->loading_cache_streams_.size(), 1); + ASSERT_EQ(scheduler->countStreams(scheduler->loading_), 1u); ASSERT_EQ(scheduler->waitingStreamsSize(), 0); ASSERT_EQ(scheduler->runningStreamsSize(), 1); } // ============================================================================ -// 11. evaluateLoadingCacheStreams: stream still loading -> stays in queue +// 11. loading check: stream still loading -> stays in queue // ============================================================================ -TEST_F(FIFOSchedulerAsyncCacheTest, testEvaluateLoadingCache_StillLoading_StaysInQueue) { +TEST_F(FIFOSchedulerAsyncCacheTest, testLoadingCheck_StillLoading_StaysInQueue) { setupMockCoordinator(); auto pending_ctx = createPendingAsyncContext(); @@ -406,27 +407,26 @@ TEST_F(FIFOSchedulerAsyncCacheTest, testEvaluateLoadingCache_StillLoading_StaysI ASSERT_TRUE(scheduler->enqueue(stream).ok()); - // First schedule: enters LOADING_CACHE + // First schedule: enters loading_ queue auto result1 = scheduler->schedule(); ASSERT_TRUE(result1.ok()); - ASSERT_EQ(scheduler->loading_cache_streams_.size(), 1); + ASSERT_EQ(scheduler->countStreams(scheduler->loading_), 1u); // Second schedule: still pending (done() returns false) auto result2 = scheduler->schedule(); ASSERT_TRUE(result2.ok()); - ASSERT_EQ(scheduler->loading_cache_streams_.size(), 1); - ASSERT_TRUE(stream->getStatus() == StreamState::LOADING_CACHE); + ASSERT_EQ(scheduler->countStreams(scheduler->loading_), 1u); ASSERT_EQ(result2.value().size(), 0); } // ============================================================================ -// 12. schedule() ordering: load_done_streams inserted at head of waiting_streams_ +// 12. schedule() ordering: load_done_streams inserted at head of waiting_ // ============================================================================ TEST_F(FIFOSchedulerAsyncCacheTest, testScheduleOrdering_LoadDoneStreamsAtWaitingHead) { setupMockCoordinator(); - // Mock context: done() returns true when checked in evaluateLoadingCacheStreams + // Mock context: done() returns true when checked in loading check auto mock_ctx = std::make_shared>(); ON_CALL(*mock_ctx, done()).WillByDefault(Return(true)); ON_CALL(*mock_ctx, success()).WillByDefault(Return(true)); @@ -436,18 +436,18 @@ TEST_F(FIFOSchedulerAsyncCacheTest, testScheduleOrdering_LoadDoneStreamsAtWaitin auto scheduler = createScheduler(); - // Stream1: will enter LOADING_CACHE first + // Stream1: will enter loading_ queue first auto stream1 = createStream({1, 2}, /*reuse_cache=*/true, /*enable_memory_cache=*/true); ASSERT_TRUE(scheduler->enqueue(stream1).ok()); auto result1 = scheduler->schedule(); ASSERT_TRUE(result1.ok()); - ASSERT_EQ(scheduler->loading_cache_streams_.size(), 1); + ASSERT_EQ(scheduler->countStreams(scheduler->loading_), 1u); // Stream2: enqueued later while stream1 is loading auto stream2 = createStream({3, 4}, /*reuse_cache=*/false); ASSERT_TRUE(scheduler->enqueue(stream2).ok()); - // Second schedule: stream1 load done -> moves to WAITING head -> should be scheduled before stream2 + // Second schedule: stream1 load done -> moves to waiting_ head -> should be scheduled before stream2 auto result2 = scheduler->schedule(); ASSERT_TRUE(result2.ok()); // Both streams should be running now diff --git a/rtp_llm/cpp/engine_base/schedulers/test/FIFOSchedulerCancelTest.cc b/rtp_llm/cpp/engine_base/schedulers/test/FIFOSchedulerCancelTest.cc index c567921a51..a65dfd8caa 100644 --- a/rtp_llm/cpp/engine_base/schedulers/test/FIFOSchedulerCancelTest.cc +++ b/rtp_llm/cpp/engine_base/schedulers/test/FIFOSchedulerCancelTest.cc @@ -68,20 +68,20 @@ class FIFOSchedulerCancelTest: public DeviceTestBase { return cache_manager_->freeBlocksNum(); } - // Helper function to schedule a stream through WAITING->LOADING_CACHE->WAITING->RUNNING - // Returns the result of the final schedule() call when stream is RUNNING + // Helper function to schedule a stream through to RUNNING state + // Returns the result of the schedule() call that transitions stream(s) to RUNNING absl::StatusOr> scheduleToRunning(std::shared_ptr& scheduler) { - // First schedule: WAITING -> LOADING_CACHE + // First schedule: stream should transition to RUNNING auto result1 = scheduler->schedule(); if (!result1.ok() || result1.value().size() > 0) { - return result1; // Unexpected: already RUNNING or error + return result1; } - // Second schedule: LOADING_CACHE -> WAITING (with CanRun event set) + // If not yet RUNNING, try again (e.g., loading cache) auto result2 = scheduler->schedule(); if (!result2.ok() || result2.value().size() > 0) { - return result2; // Unexpected: error or already done + return result2; } - // Third schedule: WAITING -> RUNNING + // Third attempt return scheduler->schedule(); } @@ -153,7 +153,7 @@ TEST_F(FIFOSchedulerCancelTest, CancelWhileRunning) { // ============================================================================ // 3. Cancel during resource allocation (stream transitions through WAITING -// where initKVBlock happens inside moveToNext) +// where initKVBlock happens inside prepare()) // ============================================================================ TEST_F(FIFOSchedulerCancelTest, CancelDuringResourceAllocation) { auto scheduler = createScheduler(); @@ -163,7 +163,7 @@ TEST_F(FIFOSchedulerCancelTest, CancelDuringResourceAllocation) { ASSERT_TRUE(scheduler->enqueue(stream).ok()); // Report error before the first schedule() — the stream is WAITING - // and moveToNext() will see the Error event before attempting initKVBlock + // and prepare() will see the Error event via alive() check stream->reportError(ErrorCode::CANCELLED, "cancelled during init"); auto result = scheduler->schedule(); @@ -244,9 +244,9 @@ TEST_F(FIFOSchedulerCancelTest, VerifyStateAndErrorAfterCancel) { ASSERT_FALSE(stream->getStatus() == StreamState::RUNNING); ASSERT_FALSE(stream->getStatus() == StreamState::WAITING); ASSERT_EQ(stream->stopReason(), "user requested cancel"); - // moveToNext on FINISHED stream should not crash (idempotent) - auto state = stream->moveToNext(); - ASSERT_EQ(state, StreamState::FINISHED); + // finish() on FINISHED stream should not crash (idempotent) + stream->finish(); + ASSERT_TRUE(stream->isFinished()); } // ============================================================================ @@ -315,8 +315,8 @@ TEST_F(FIFOSchedulerCancelTest, ConcurrentCancelDuringSchedule) { // Run schedule() while cancel thread is active. // Only call schedule() when there are running streams to avoid blocking // on an empty scheduler (waitPredicate would return false and cv blocks forever). - // reportEvent() only sets event flags; the streams remain in running_streams_ - // until schedule() -> evaluateAndUpdateStreams() calls moveToNext(). + // reportEvent() only sets event flags; the streams remain in running_ + // until schedule() -> advance() detects the error and calls finish(). while (scheduler->runningStreamsSize() > 0) { auto result = scheduler->schedule(); ASSERT_TRUE(result.ok()); diff --git a/rtp_llm/cpp/engine_base/schedulers/test/FIFOSchedulerTest.cc b/rtp_llm/cpp/engine_base/schedulers/test/FIFOSchedulerTest.cc index 5394332f25..1a649dbf75 100644 --- a/rtp_llm/cpp/engine_base/schedulers/test/FIFOSchedulerTest.cc +++ b/rtp_llm/cpp/engine_base/schedulers/test/FIFOSchedulerTest.cc @@ -180,7 +180,7 @@ TEST_F(FIFOSchedulerTest, testRejectInputWithoutSpeculativeReserveSpace) { auto valid_stream = make_stream(16); auto invalid_stream2 = make_stream(17); - auto enqueued = scheduler.batchEnqueue({invalid_stream2, valid_stream}); + auto enqueued = scheduler.enqueueGroup({invalid_stream2, valid_stream}); ASSERT_EQ(enqueued.size(), 1); ASSERT_EQ(enqueued[0], valid_stream); ASSERT_TRUE(invalid_stream2->hasError()); @@ -495,7 +495,7 @@ TEST_F(FIFOSchedulerTest, testMaxContextBatchSize) { } } -TEST_F(FIFOSchedulerTest, testBatchEnqueue) { +TEST_F(FIFOSchedulerTest, testEnqueueGroup) { CacheConfig cache_config = makeMhaCacheConfig(1, 4, 1, 4, 8, rtp_llm::DataType::TYPE_FP16); std::shared_ptr cache_manager = std::make_shared(cache_config); ASSERT_TRUE(cache_manager->init()); @@ -530,7 +530,7 @@ TEST_F(FIFOSchedulerTest, testBatchEnqueue) { make_shared(query, model_config, runtime_config, resource_context, nullptr); streams.push_back(stream); } - auto enqueued = scheduler.batchEnqueue(streams); + auto enqueued = scheduler.enqueueGroup(streams); ASSERT_EQ(enqueued.size(), streams.size()); // Single schedule: both streams transition to RUNNING (no cache loading needed) @@ -543,6 +543,327 @@ TEST_F(FIFOSchedulerTest, testBatchEnqueue) { ASSERT_EQ(scheduler.runningStreamsSize(), 2); } +namespace { + +std::shared_ptr makeGroupedStream(int64_t group_id, + int group_size, + const ModelConfig& model_config, + const RuntimeConfig& runtime_config, + const ResourceContext& resource_context, + std::vector tokens = {1, 2, 3}) { + auto query = std::make_shared(); + query->input_ids = torch::tensor(tokens, torch::kInt32); + query->generate_config = std::make_shared(); + query->group_id = group_id; + query->group_size = group_size; + return std::make_shared(query, model_config, runtime_config, resource_context, nullptr); +} + +std::shared_ptr makeSingleStream(const ModelConfig& model_config, + const RuntimeConfig& runtime_config, + const ResourceContext& resource_context, + std::vector tokens = {1, 2, 3}) { + auto query = std::make_shared(); + query->input_ids = torch::tensor(tokens, torch::kInt32); + query->generate_config = std::make_shared(); + return std::make_shared(query, model_config, runtime_config, resource_context, nullptr); +} + +} // namespace + +TEST_F(FIFOSchedulerTest, groupIsolation_size2) { + CacheConfig cache_config = makeMhaCacheConfig(1, 4, 1, 4, 8, rtp_llm::DataType::TYPE_FP16); + auto cache_manager = std::make_shared(cache_config); + ASSERT_TRUE(cache_manager->init()); + ResourceContext resource_context; + resource_context.cache_manager = cache_manager; + + ModelConfig model_config; + model_config.max_seq_len = 8192; + RuntimeConfig runtime_config; + runtime_config.max_generate_batch_size = 100; + runtime_config.fifo_scheduler_config.max_batch_tokens_size = 8192; + PDSepConfig pd_sep_config; + ParallelismConfig parallelism_config; + ModelSpecificConfig model_specific_config; + FIFOScheduler scheduler( + runtime_config, model_config, pd_sep_config, parallelism_config, model_specific_config, cache_manager); + + vector streams = { + makeGroupedStream(100, 2, model_config, runtime_config, resource_context), + makeGroupedStream(100, 2, model_config, runtime_config, resource_context), + }; + scheduler.enqueueGroup(streams); + + auto result = scheduler.schedule(); + ASSERT_TRUE(result.ok()); + ASSERT_EQ(scheduler.runningStreamsSize(), 2); + ASSERT_EQ(scheduler.waitingStreamsSize(), 0); + for (const auto& t : scheduler.runningTaskList()) { + ASSERT_EQ(t.batch_id, 100); + } +} + +TEST_F(FIFOSchedulerTest, groupIsolation_size3) { + CacheConfig cache_config = makeMhaCacheConfig(1, 4, 1, 4, 8, rtp_llm::DataType::TYPE_FP16); + auto cache_manager = std::make_shared(cache_config); + ASSERT_TRUE(cache_manager->init()); + ResourceContext resource_context; + resource_context.cache_manager = cache_manager; + + ModelConfig model_config; + model_config.max_seq_len = 8192; + RuntimeConfig runtime_config; + runtime_config.max_generate_batch_size = 100; + runtime_config.fifo_scheduler_config.max_batch_tokens_size = 8192; + PDSepConfig pd_sep_config; + ParallelismConfig parallelism_config; + ModelSpecificConfig model_specific_config; + FIFOScheduler scheduler( + runtime_config, model_config, pd_sep_config, parallelism_config, model_specific_config, cache_manager); + + vector streams; + for (int i = 0; i < 3; ++i) { + streams.push_back(makeGroupedStream(100, 3, model_config, runtime_config, resource_context)); + } + scheduler.enqueueGroup(streams); + + auto result = scheduler.schedule(); + ASSERT_TRUE(result.ok()); + ASSERT_EQ(scheduler.runningStreamsSize(), 3); + ASSERT_EQ(scheduler.waitingStreamsSize(), 0); + for (const auto& t : scheduler.runningTaskList()) { + ASSERT_EQ(t.batch_id, 100); + } +} + +TEST_F(FIFOSchedulerTest, groupIsolation_size4) { + CacheConfig cache_config = makeMhaCacheConfig(1, 6, 1, 4, 8, rtp_llm::DataType::TYPE_FP16); + auto cache_manager = std::make_shared(cache_config); + ASSERT_TRUE(cache_manager->init()); + ResourceContext resource_context; + resource_context.cache_manager = cache_manager; + + ModelConfig model_config; + model_config.max_seq_len = 8192; + RuntimeConfig runtime_config; + runtime_config.max_generate_batch_size = 100; + runtime_config.fifo_scheduler_config.max_batch_tokens_size = 8192; + PDSepConfig pd_sep_config; + ParallelismConfig parallelism_config; + ModelSpecificConfig model_specific_config; + FIFOScheduler scheduler( + runtime_config, model_config, pd_sep_config, parallelism_config, model_specific_config, cache_manager); + + vector streams; + for (int i = 0; i < 4; ++i) { + streams.push_back(makeGroupedStream(100, 4, model_config, runtime_config, resource_context)); + } + scheduler.enqueueGroup(streams); + + auto result = scheduler.schedule(); + ASSERT_TRUE(result.ok()); + ASSERT_EQ(scheduler.runningStreamsSize(), 4); + ASSERT_EQ(scheduler.waitingStreamsSize(), 0); + for (const auto& t : scheduler.runningTaskList()) { + ASSERT_EQ(t.batch_id, 100); + } +} + +TEST_F(FIFOSchedulerTest, groupIsolation_groupNotMixedWithSingles_groupFirst) { + CacheConfig cache_config = makeMhaCacheConfig(1, 4, 1, 4, 8, rtp_llm::DataType::TYPE_FP16); + auto cache_manager = std::make_shared(cache_config); + ASSERT_TRUE(cache_manager->init()); + ResourceContext resource_context; + resource_context.cache_manager = cache_manager; + + ModelConfig model_config; + model_config.max_seq_len = 8192; + RuntimeConfig runtime_config; + runtime_config.max_generate_batch_size = 100; + runtime_config.fifo_scheduler_config.max_batch_tokens_size = 8192; + PDSepConfig pd_sep_config; + ParallelismConfig parallelism_config; + ModelSpecificConfig model_specific_config; + FIFOScheduler scheduler( + runtime_config, model_config, pd_sep_config, parallelism_config, model_specific_config, cache_manager); + + vector group_streams = { + makeGroupedStream(100, 2, model_config, runtime_config, resource_context), + makeGroupedStream(100, 2, model_config, runtime_config, resource_context), + }; + scheduler.enqueueGroup(group_streams); + scheduler.enqueue(makeSingleStream(model_config, runtime_config, resource_context)); + + auto r1 = scheduler.schedule(); + ASSERT_TRUE(r1.ok()); + ASSERT_EQ(scheduler.runningStreamsSize(), 2); + ASSERT_EQ(scheduler.waitingStreamsSize(), 1); + + auto running = scheduler.runningTaskList(); + for (const auto& t : running) { + ASSERT_EQ(t.batch_id, 100); + } +} + +TEST_F(FIFOSchedulerTest, groupIsolation_groupNotMixedWithSingles_singlesFirst) { + CacheConfig cache_config = makeMhaCacheConfig(1, 4, 1, 4, 8, rtp_llm::DataType::TYPE_FP16); + auto cache_manager = std::make_shared(cache_config); + ASSERT_TRUE(cache_manager->init()); + ResourceContext resource_context; + resource_context.cache_manager = cache_manager; + + ModelConfig model_config; + model_config.max_seq_len = 8192; + RuntimeConfig runtime_config; + runtime_config.max_generate_batch_size = 100; + runtime_config.fifo_scheduler_config.max_batch_tokens_size = 8192; + PDSepConfig pd_sep_config; + ParallelismConfig parallelism_config; + ModelSpecificConfig model_specific_config; + FIFOScheduler scheduler( + runtime_config, model_config, pd_sep_config, parallelism_config, model_specific_config, cache_manager); + + scheduler.enqueue(makeSingleStream(model_config, runtime_config, resource_context)); + scheduler.enqueue(makeSingleStream(model_config, runtime_config, resource_context)); + vector group_streams = { + makeGroupedStream(100, 2, model_config, runtime_config, resource_context), + makeGroupedStream(100, 2, model_config, runtime_config, resource_context), + }; + scheduler.enqueueGroup(group_streams); + + auto r1 = scheduler.schedule(); + ASSERT_TRUE(r1.ok()); + // Singles admitted first, group waits (cannot mix with already-admitted singles) + ASSERT_EQ(scheduler.runningStreamsSize(), 2); + ASSERT_EQ(scheduler.waitingStreamsSize(), 2); + + auto running = scheduler.runningTaskList(); + for (const auto& t : running) { + ASSERT_EQ(t.batch_id, -1); + } + + // Second schedule: singles still running, group still cannot be admitted + auto r2 = scheduler.schedule(); + ASSERT_TRUE(r2.ok()); + ASSERT_EQ(scheduler.runningStreamsSize(), 2); + ASSERT_EQ(scheduler.waitingStreamsSize(), 2); +} + +TEST_F(FIFOSchedulerTest, groupIsolation_twoGroupsNotMixed) { + CacheConfig cache_config = makeMhaCacheConfig(1, 4, 1, 4, 8, rtp_llm::DataType::TYPE_FP16); + auto cache_manager = std::make_shared(cache_config); + ASSERT_TRUE(cache_manager->init()); + ResourceContext resource_context; + resource_context.cache_manager = cache_manager; + + ModelConfig model_config; + model_config.max_seq_len = 8192; + RuntimeConfig runtime_config; + runtime_config.max_generate_batch_size = 100; + runtime_config.fifo_scheduler_config.max_batch_tokens_size = 8192; + PDSepConfig pd_sep_config; + ParallelismConfig parallelism_config; + ModelSpecificConfig model_specific_config; + FIFOScheduler scheduler( + runtime_config, model_config, pd_sep_config, parallelism_config, model_specific_config, cache_manager); + + vector group_a = { + makeGroupedStream(100, 2, model_config, runtime_config, resource_context), + makeGroupedStream(100, 2, model_config, runtime_config, resource_context), + }; + vector group_b = { + makeGroupedStream(200, 3, model_config, runtime_config, resource_context), + makeGroupedStream(200, 3, model_config, runtime_config, resource_context), + makeGroupedStream(200, 3, model_config, runtime_config, resource_context), + }; + scheduler.enqueueGroup(group_a); + scheduler.enqueueGroup(group_b); + + auto r1 = scheduler.schedule(); + ASSERT_TRUE(r1.ok()); + ASSERT_EQ(scheduler.runningStreamsSize(), 2); + ASSERT_EQ(scheduler.waitingStreamsSize(), 3); + + auto running = scheduler.runningTaskList(); + for (const auto& t : running) { + ASSERT_EQ(t.batch_id, 100); + } +} + +TEST_F(FIFOSchedulerTest, groupIsolation_singlesCanMix) { + CacheConfig cache_config = makeMhaCacheConfig(1, 4, 1, 4, 8, rtp_llm::DataType::TYPE_FP16); + auto cache_manager = std::make_shared(cache_config); + ASSERT_TRUE(cache_manager->init()); + ResourceContext resource_context; + resource_context.cache_manager = cache_manager; + + ModelConfig model_config; + model_config.max_seq_len = 8192; + RuntimeConfig runtime_config; + runtime_config.max_generate_batch_size = 100; + runtime_config.fifo_scheduler_config.max_batch_tokens_size = 8192; + PDSepConfig pd_sep_config; + ParallelismConfig parallelism_config; + ModelSpecificConfig model_specific_config; + FIFOScheduler scheduler( + runtime_config, model_config, pd_sep_config, parallelism_config, model_specific_config, cache_manager); + + scheduler.enqueue(makeSingleStream(model_config, runtime_config, resource_context)); + scheduler.enqueue(makeSingleStream(model_config, runtime_config, resource_context)); + scheduler.enqueue(makeSingleStream(model_config, runtime_config, resource_context)); + + auto result = scheduler.schedule(); + ASSERT_TRUE(result.ok()); + ASSERT_EQ(scheduler.runningStreamsSize(), 3); + ASSERT_EQ(scheduler.waitingStreamsSize(), 0); + + auto running = scheduler.runningTaskList(); + for (const auto& t : running) { + ASSERT_EQ(t.batch_id, -1); + } +} + +TEST_F(FIFOSchedulerTest, groupIsolation_interleavedSinglesAndGroup) { + CacheConfig cache_config = makeMhaCacheConfig(1, 4, 1, 4, 8, rtp_llm::DataType::TYPE_FP16); + auto cache_manager = std::make_shared(cache_config); + ASSERT_TRUE(cache_manager->init()); + ResourceContext resource_context; + resource_context.cache_manager = cache_manager; + + ModelConfig model_config; + model_config.max_seq_len = 8192; + RuntimeConfig runtime_config; + runtime_config.max_generate_batch_size = 100; + runtime_config.fifo_scheduler_config.max_batch_tokens_size = 8192; + PDSepConfig pd_sep_config; + ParallelismConfig parallelism_config; + ModelSpecificConfig model_specific_config; + FIFOScheduler scheduler( + runtime_config, model_config, pd_sep_config, parallelism_config, model_specific_config, cache_manager); + + // Enqueue order: single_A, group(100, 2), single_B + scheduler.enqueue(makeSingleStream(model_config, runtime_config, resource_context)); + + vector group_streams = { + makeGroupedStream(100, 2, model_config, runtime_config, resource_context), + makeGroupedStream(100, 2, model_config, runtime_config, resource_context), + }; + scheduler.enqueueGroup(group_streams); + scheduler.enqueue(makeSingleStream(model_config, runtime_config, resource_context)); + + auto r1 = scheduler.schedule(); + ASSERT_TRUE(r1.ok()); + // Two singles admitted together, group skipped (cannot join already-admitted singles) + ASSERT_EQ(scheduler.runningStreamsSize(), 2); + ASSERT_EQ(scheduler.waitingStreamsSize(), 2); + + auto running = scheduler.runningTaskList(); + for (const auto& t : running) { + ASSERT_EQ(t.batch_id, -1); + } +} + TEST_F(FIFOSchedulerTest, testPdDecodePreCanRunStillRespectsMaxGenerateBatchSize) { CacheConfig cache_config = makeMhaCacheConfig(1, 10, 1, 4, 8, rtp_llm::DataType::TYPE_FP16); std::shared_ptr cache_manager = std::make_shared(cache_config); @@ -572,7 +893,8 @@ TEST_F(FIFOSchedulerTest, testPdDecodePreCanRunStillRespectsMaxGenerateBatchSize // DecodeRpcServer pre-sets CanRun to drive pre-enqueue KV allocation. stream->reportEvent(StreamEvents::CanRun); - EXPECT_EQ(stream->moveToNext(), StreamState::WAITING); + stream->prepare(); + EXPECT_EQ(stream->getStatus(), StreamState::WAITING); EXPECT_TRUE(stream->hasEvent(StreamEvents::CanRun)); EXPECT_TRUE(stream->hasEvent(StreamEvents::LoadInitiated)); stream->setIsContextStream(false); @@ -627,7 +949,8 @@ TEST_F(FIFOSchedulerTest, testPdDecodePreCanRunCanTopUpToMaxGenerateBatchSize) { // DecodeRpcServer pre-sets CanRun to drive pre-enqueue KV allocation. stream->reportEvent(StreamEvents::CanRun); - EXPECT_EQ(stream->moveToNext(), StreamState::WAITING); + stream->prepare(); + EXPECT_EQ(stream->getStatus(), StreamState::WAITING); EXPECT_TRUE(stream->hasEvent(StreamEvents::CanRun)); EXPECT_TRUE(stream->hasEvent(StreamEvents::LoadInitiated)); stream->setIsContextStream(false); @@ -691,7 +1014,8 @@ TEST_F(FIFOSchedulerTest, testMaxInitedKVCacheStreamsAllowsAlreadyInitedStreams) auto stream = make_shared(query, model_config, runtime_config, resource_context, nullptr); stream->reportEvent(StreamEvents::CanRun); - EXPECT_EQ(stream->moveToNext(), StreamState::WAITING); + stream->prepare(); + EXPECT_EQ(stream->getStatus(), StreamState::WAITING); EXPECT_GT(stream->curBlocksNum(), 0); stream->setIsContextStream(false); return stream; @@ -738,7 +1062,8 @@ TEST_F(FIFOSchedulerTest, testPdDecodePreCanRunWithPendingAsyncStillCountsRunnin // DecodeRpcServer pre-sets CanRun to drive pre-enqueue KV allocation. stream->reportEvent(StreamEvents::CanRun); - EXPECT_EQ(stream->moveToNext(), StreamState::WAITING); + stream->prepare(); + EXPECT_EQ(stream->getStatus(), StreamState::WAITING); EXPECT_TRUE(stream->hasEvent(StreamEvents::CanRun)); EXPECT_TRUE(stream->hasEvent(StreamEvents::LoadInitiated)); stream->setIsContextStream(false); @@ -792,9 +1117,9 @@ TEST_F(FIFOSchedulerTest, testCpForceSinglePrefillConfig) { ModelConfig model_config; model_config.max_seq_len = 8192; RuntimeConfig runtime_config; - runtime_config.max_generate_batch_size = 100; - runtime_config.fifo_scheduler_config.max_batch_tokens_size = 8192; - runtime_config.fifo_scheduler_config.cp_force_single_prefill = cp_force_single_prefill; + runtime_config.max_generate_batch_size = 100; + runtime_config.fifo_scheduler_config.max_batch_tokens_size = 8192; + runtime_config.fifo_scheduler_config.cp_force_single_prefill = cp_force_single_prefill; PDSepConfig pd_sep_config; ParallelismConfig parallelism_config; ModelSpecificConfig model_specific_config; @@ -810,7 +1135,7 @@ TEST_F(FIFOSchedulerTest, testCpForceSinglePrefillConfig) { streams.push_back( make_shared(query, model_config, runtime_config, resource_context, nullptr)); } - scheduler.batchEnqueue(streams); + scheduler.enqueueGroup(streams); auto streams_status = scheduler.schedule(); EXPECT_TRUE(streams_status.ok()); return streams_status.value().size(); @@ -843,27 +1168,25 @@ TEST_F(FIFOSchedulerTest, testForceBatchGroupComplete) { // Enqueue only 2 of 3 — group incomplete, should not be scheduled { - std::shared_ptr query = make_shared(); - query->input_ids = torch::tensor({1}, torch::kInt32); - query->generate_config = make_shared(); - query->generate_config->force_batch = true; - query->generate_config->batch_group_timeout = 10; - query->batch_group_id = group_id; - query->batch_group_size = group_size; - query->begin_time_us = autil::TimeUtility::currentTimeInMicroSeconds(); + std::shared_ptr query = make_shared(); + query->input_ids = torch::tensor({1}, torch::kInt32); + query->generate_config = make_shared(); + query->generate_config->group_timeout = 10; + query->group_id = group_id; + query->group_size = group_size; + query->begin_time_us = autil::TimeUtility::currentTimeInMicroSeconds(); shared_ptr stream = make_shared(query, model_config, runtime_config, resource_context, nullptr); ASSERT_TRUE(scheduler.enqueue(stream).ok()); } { - std::shared_ptr query = make_shared(); - query->input_ids = torch::tensor({1}, torch::kInt32); - query->generate_config = make_shared(); - query->generate_config->force_batch = true; - query->generate_config->batch_group_timeout = 10; - query->batch_group_id = group_id; - query->batch_group_size = group_size; - query->begin_time_us = autil::TimeUtility::currentTimeInMicroSeconds(); + std::shared_ptr query = make_shared(); + query->input_ids = torch::tensor({1}, torch::kInt32); + query->generate_config = make_shared(); + query->generate_config->group_timeout = 10; + query->group_id = group_id; + query->group_size = group_size; + query->begin_time_us = autil::TimeUtility::currentTimeInMicroSeconds(); shared_ptr stream = make_shared(query, model_config, runtime_config, resource_context, nullptr); ASSERT_TRUE(scheduler.enqueue(stream).ok()); @@ -877,14 +1200,13 @@ TEST_F(FIFOSchedulerTest, testForceBatchGroupComplete) { // Enqueue the 3rd — group complete, all 3 should be scheduled together { - std::shared_ptr query = make_shared(); - query->input_ids = torch::tensor({1}, torch::kInt32); - query->generate_config = make_shared(); - query->generate_config->force_batch = true; - query->generate_config->batch_group_timeout = 10; - query->batch_group_id = group_id; - query->batch_group_size = group_size; - query->begin_time_us = autil::TimeUtility::currentTimeInMicroSeconds(); + std::shared_ptr query = make_shared(); + query->input_ids = torch::tensor({1}, torch::kInt32); + query->generate_config = make_shared(); + query->generate_config->group_timeout = 10; + query->group_id = group_id; + query->group_size = group_size; + query->begin_time_us = autil::TimeUtility::currentTimeInMicroSeconds(); shared_ptr stream = make_shared(query, model_config, runtime_config, resource_context, nullptr); ASSERT_TRUE(scheduler.enqueue(stream).ok()); @@ -898,6 +1220,49 @@ TEST_F(FIFOSchedulerTest, testForceBatchGroupComplete) { ASSERT_EQ(scheduler.runningStreamsSize(), 3); } +TEST_F(FIFOSchedulerTest, testForceBatchCompleteGroupSkipsTokenCapAfterTimeout) { + CacheConfig cache_config = makeMhaCacheConfig(1, 11, 1, 4, 8, rtp_llm::DataType::TYPE_FP16); + std::shared_ptr cache_manager = std::make_shared(cache_config); + ASSERT_TRUE(cache_manager->init()); + ResourceContext resource_context; + resource_context.cache_manager = cache_manager; + + ModelConfig model_config; + model_config.max_seq_len = 8192; + RuntimeConfig runtime_config; + runtime_config.max_generate_batch_size = 100; + runtime_config.fifo_scheduler_config.max_batch_tokens_size = 2; + PDSepConfig pd_sep_config; + ParallelismConfig parallelism_config; + ModelSpecificConfig model_specific_config; + FIFOScheduler scheduler( + runtime_config, model_config, pd_sep_config, parallelism_config, model_specific_config, cache_manager); + + int64_t group_id = 101; + int group_size = 3; + int timeout_ms = 10; + int64_t past_time = autil::TimeUtility::currentTimeInMicroSeconds() - (timeout_ms + 100) * 1000; + + for (int i = 0; i < group_size; ++i) { + std::shared_ptr query = make_shared(); + query->input_ids = torch::tensor({1}, torch::kInt32); + query->generate_config = make_shared(); + query->generate_config->group_timeout = timeout_ms; + query->group_id = group_id; + query->group_size = group_size; + query->begin_time_us = past_time; + shared_ptr stream = + make_shared(query, model_config, runtime_config, resource_context, nullptr); + ASSERT_TRUE(scheduler.enqueue(stream).ok()); + } + + auto result = scheduler.schedule(); + ASSERT_TRUE(result.ok()); + ASSERT_EQ(result.value().size(), 3); + ASSERT_EQ(scheduler.waitingStreamsSize(), 0); + ASSERT_EQ(scheduler.runningStreamsSize(), 3); +} + TEST_F(FIFOSchedulerTest, testForceBatchTimeout) { CacheConfig cache_config = makeMhaCacheConfig(1, 11, 1, 4, 8, rtp_llm::DataType::TYPE_FP16); std::shared_ptr cache_manager = std::make_shared(cache_config); @@ -923,27 +1288,25 @@ TEST_F(FIFOSchedulerTest, testForceBatchTimeout) { // Enqueue only 2 of 3 with begin_time far in the past so timeout has expired { - std::shared_ptr query = make_shared(); - query->input_ids = torch::tensor({1}, torch::kInt32); - query->generate_config = make_shared(); - query->generate_config->force_batch = true; - query->generate_config->batch_group_timeout = timeout_ms; - query->batch_group_id = group_id; - query->batch_group_size = group_size; - query->begin_time_us = past_time; + std::shared_ptr query = make_shared(); + query->input_ids = torch::tensor({1}, torch::kInt32); + query->generate_config = make_shared(); + query->generate_config->group_timeout = timeout_ms; + query->group_id = group_id; + query->group_size = group_size; + query->begin_time_us = past_time; shared_ptr stream = make_shared(query, model_config, runtime_config, resource_context, nullptr); ASSERT_TRUE(scheduler.enqueue(stream).ok()); } { - std::shared_ptr query = make_shared(); - query->input_ids = torch::tensor({1}, torch::kInt32); - query->generate_config = make_shared(); - query->generate_config->force_batch = true; - query->generate_config->batch_group_timeout = timeout_ms; - query->batch_group_id = group_id; - query->batch_group_size = group_size; - query->begin_time_us = past_time; + std::shared_ptr query = make_shared(); + query->input_ids = torch::tensor({1}, torch::kInt32); + query->generate_config = make_shared(); + query->generate_config->group_timeout = timeout_ms; + query->group_id = group_id; + query->group_size = group_size; + query->begin_time_us = past_time; shared_ptr stream = make_shared(query, model_config, runtime_config, resource_context, nullptr); ASSERT_TRUE(scheduler.enqueue(stream).ok()); @@ -956,6 +1319,49 @@ TEST_F(FIFOSchedulerTest, testForceBatchTimeout) { ASSERT_EQ(scheduler.waitingStreamsSize(), 0); } +TEST_F(FIFOSchedulerTest, testIncompleteForceBatchTimeoutUsesNormalTokenCap) { + CacheConfig cache_config = makeMhaCacheConfig(1, 11, 1, 4, 8, rtp_llm::DataType::TYPE_FP16); + std::shared_ptr cache_manager = std::make_shared(cache_config); + ASSERT_TRUE(cache_manager->init()); + ResourceContext resource_context; + resource_context.cache_manager = cache_manager; + + ModelConfig model_config; + model_config.max_seq_len = 8192; + RuntimeConfig runtime_config; + runtime_config.max_generate_batch_size = 100; + runtime_config.fifo_scheduler_config.max_batch_tokens_size = 2; + PDSepConfig pd_sep_config; + ParallelismConfig parallelism_config; + ModelSpecificConfig model_specific_config; + FIFOScheduler scheduler( + runtime_config, model_config, pd_sep_config, parallelism_config, model_specific_config, cache_manager); + + int64_t group_id = 201; + int group_size = 3; + int timeout_ms = 10; + int64_t past_time = autil::TimeUtility::currentTimeInMicroSeconds() - (timeout_ms + 100) * 1000; + + for (int i = 0; i < 2; ++i) { + std::shared_ptr query = make_shared(); + query->input_ids = torch::tensor({1}, torch::kInt32); + query->generate_config = make_shared(); + query->generate_config->group_timeout = timeout_ms; + query->group_id = group_id; + query->group_size = group_size; + query->begin_time_us = past_time; + shared_ptr stream = + make_shared(query, model_config, runtime_config, resource_context, nullptr); + ASSERT_TRUE(scheduler.enqueue(stream).ok()); + } + + auto result = scheduler.schedule(); + ASSERT_TRUE(result.ok()); + ASSERT_EQ(result.value().size(), 1); + ASSERT_EQ(scheduler.waitingStreamsSize(), 1); + ASSERT_EQ(scheduler.runningStreamsSize(), 1); +} + TEST_F(FIFOSchedulerTest, testForceBatchIsolation) { CacheConfig cache_config = makeMhaCacheConfig(1, 11, 1, 4, 8, rtp_llm::DataType::TYPE_FP16); std::shared_ptr cache_manager = std::make_shared(cache_config); @@ -989,27 +1395,25 @@ TEST_F(FIFOSchedulerTest, testForceBatchIsolation) { ASSERT_TRUE(scheduler.enqueue(normal_stream).ok()); } { - std::shared_ptr query = make_shared(); - query->input_ids = torch::tensor({1}, torch::kInt32); - query->generate_config = make_shared(); - query->generate_config->force_batch = true; - query->generate_config->batch_group_timeout = 10; - query->batch_group_id = group_id; - query->batch_group_size = group_size; - query->begin_time_us = autil::TimeUtility::currentTimeInMicroSeconds(); + std::shared_ptr query = make_shared(); + query->input_ids = torch::tensor({1}, torch::kInt32); + query->generate_config = make_shared(); + query->generate_config->group_timeout = 10; + query->group_id = group_id; + query->group_size = group_size; + query->begin_time_us = autil::TimeUtility::currentTimeInMicroSeconds(); shared_ptr stream = make_shared(query, model_config, runtime_config, resource_context, nullptr); ASSERT_TRUE(scheduler.enqueue(stream).ok()); } { - std::shared_ptr query = make_shared(); - query->input_ids = torch::tensor({1}, torch::kInt32); - query->generate_config = make_shared(); - query->generate_config->force_batch = true; - query->generate_config->batch_group_timeout = 10; - query->batch_group_id = group_id; - query->batch_group_size = group_size; - query->begin_time_us = autil::TimeUtility::currentTimeInMicroSeconds(); + std::shared_ptr query = make_shared(); + query->input_ids = torch::tensor({1}, torch::kInt32); + query->generate_config = make_shared(); + query->generate_config->group_timeout = 10; + query->group_id = group_id; + query->group_size = group_size; + query->begin_time_us = autil::TimeUtility::currentTimeInMicroSeconds(); shared_ptr stream = make_shared(query, model_config, runtime_config, resource_context, nullptr); ASSERT_TRUE(scheduler.enqueue(stream).ok()); @@ -1059,27 +1463,25 @@ TEST_F(FIFOSchedulerTest, testTwoForceBatchGroupsIsolation) { // Enqueue group A (2 streams), then group B (2 streams), both complete vector> group_a_streams; for (int i = 0; i < group_size; i++) { - std::shared_ptr query = make_shared(); - query->input_ids = torch::tensor({1}, torch::kInt32); - query->generate_config = make_shared(); - query->generate_config->force_batch = true; - query->generate_config->batch_group_timeout = 10; - query->batch_group_id = group_id_a; - query->batch_group_size = group_size; - query->begin_time_us = autil::TimeUtility::currentTimeInMicroSeconds(); + std::shared_ptr query = make_shared(); + query->input_ids = torch::tensor({1}, torch::kInt32); + query->generate_config = make_shared(); + query->generate_config->group_timeout = 10; + query->group_id = group_id_a; + query->group_size = group_size; + query->begin_time_us = autil::TimeUtility::currentTimeInMicroSeconds(); auto stream = make_shared(query, model_config, runtime_config, resource_context, nullptr); group_a_streams.push_back(stream); ASSERT_TRUE(scheduler.enqueue(stream).ok()); } for (int i = 0; i < group_size; i++) { - std::shared_ptr query = make_shared(); - query->input_ids = torch::tensor({1}, torch::kInt32); - query->generate_config = make_shared(); - query->generate_config->force_batch = true; - query->generate_config->batch_group_timeout = 10; - query->batch_group_id = group_id_b; - query->batch_group_size = group_size; - query->begin_time_us = autil::TimeUtility::currentTimeInMicroSeconds(); + std::shared_ptr query = make_shared(); + query->input_ids = torch::tensor({1}, torch::kInt32); + query->generate_config = make_shared(); + query->generate_config->group_timeout = 10; + query->group_id = group_id_b; + query->group_size = group_size; + query->begin_time_us = autil::TimeUtility::currentTimeInMicroSeconds(); auto stream = make_shared(query, model_config, runtime_config, resource_context, nullptr); ASSERT_TRUE(scheduler.enqueue(stream).ok()); } diff --git a/rtp_llm/cpp/engine_base/stream/GenerateConfig.h b/rtp_llm/cpp/engine_base/stream/GenerateConfig.h index 6710462eb1..97421d401b 100644 --- a/rtp_llm/cpp/engine_base/stream/GenerateConfig.h +++ b/rtp_llm/cpp/engine_base/stream/GenerateConfig.h @@ -93,8 +93,7 @@ class GenerateConfig: public autil::legacy::Jsonizable { bool enable_memory_cache = true; bool enable_remote_cache = true; std::string trace_id; - bool force_batch = false; // If true, streams with same batch_group_id must be scheduled together - std::optional batch_group_timeout; + std::optional group_timeout; std::string unique_key; bool top1() { @@ -156,8 +155,7 @@ class GenerateConfig: public autil::legacy::Jsonizable { << ", gen_timeline: " << gen_timeline << ", profile_step: " << profile_step << ", reuse_cache: " << reuse_cache << ", enable_device_cache: " << enable_device_cache << ", enable_memory_cache: " << enable_memory_cache - << ", enable_remote_cache: " << enable_remote_cache << ", force_batch: " << force_batch - << ", unique_key: " << unique_key << "}"; + << ", enable_remote_cache: " << enable_remote_cache << ", unique_key: " << unique_key << "}"; return debug_string.str(); } @@ -243,9 +241,8 @@ class GenerateConfig: public autil::legacy::Jsonizable { JSONIZE(enable_device_cache); JSONIZE(enable_memory_cache); JSONIZE(enable_remote_cache); - JSONIZE(force_batch); JSONIZE(aux_info); - JSONIZE_OPTIONAL(batch_group_timeout); + JSONIZE_OPTIONAL(group_timeout); JSONIZE(unique_key); #undef JSONIZE #undef JSONIZE_OPTIONAL diff --git a/rtp_llm/cpp/engine_base/stream/GenerateStateMachine.cc b/rtp_llm/cpp/engine_base/stream/GenerateStateMachine.cc index 917191525d..f6abb4c8d7 100644 --- a/rtp_llm/cpp/engine_base/stream/GenerateStateMachine.cc +++ b/rtp_llm/cpp/engine_base/stream/GenerateStateMachine.cc @@ -1,175 +1,3 @@ #include "rtp_llm/cpp/engine_base/stream/GenerateStateMachine.h" -#include "rtp_llm/cpp/engine_base/stream/GenerateStream.h" -#include "rtp_llm/cpp/engine_base/stream/StreamCacheResource.h" -#include "rtp_llm/cpp/config/RoleTypes.h" -#include -#include -using namespace std; - -namespace rtp_llm { -namespace { - -bool asyncDebugEnabled() { - const char* env = std::getenv("RTP_LLM_ASYNC_DEBUG"); - return env != nullptr && std::string(env) == "1"; -} - -} // namespace -// ============================================================================ -// GenerateStateMachine method implementations -// ============================================================================ - -StreamState GenerateStateMachine::moveToNext() { - // Error 最高优先级,任何状态下直接终止 - if (events_.has(StreamEvents::Error)) { - status.store(StreamState::FINISHED, std::memory_order_release); - releaseResource(); - return StreamState::FINISHED; - } - - switch (status.load(std::memory_order_acquire)) { - case StreamState::WAITING: - handleWaiting(); - break; - case StreamState::LOADING_CACHE: - handleLoading(); - break; - case StreamState::RUNNING: - handleRunning(); - break; - case StreamState::FINISHED: - break; - default: - RTP_LLM_LOG_ERROR("Error: Unrecognized Generate State"); - if (error_info.ok()) { - error_info = ErrorInfo(ErrorCode::UNKNOWN_ERROR, "Error: Unrecognized Generate State"); - } - status.store(StreamState::FINISHED, std::memory_order_release); - releaseResource(); - break; - } - return status.load(std::memory_order_acquire); -} - -void GenerateStateMachine::handleWaiting() { - if (!events_.has(StreamEvents::CanRun)) { - return; - } - // LoadInitiated 未设置时,必须先执行 initKVBlock 和 asyncLoadCache - if (!events_.has(StreamEvents::LoadInitiated)) { - auto result = stream_cache_resource_->initKVBlock(reserve_step_); - if (!result.ok()) { - error_info = ErrorInfo(ErrorCode::MALLOC_FAILED, "LACK MEM"); - status.store(StreamState::FINISHED, std::memory_order_release); - releaseResource(); - return; - } - bool ret = stream_cache_resource_->asyncLoadCache(); - // 设置 LoadInitiated 标志,表示已尝试asyncLoadCache. 当前实现即便asyncLoadCache失败也不再重试 - reportEvent(StreamEvents::LoadInitiated); - if (ret) { - status.store(StreamState::LOADING_CACHE, std::memory_order_release); - } else if (stream_cache_resource_->resourceContext().role_type != RoleType::DECODE) { - // Loading cache 失败或不需要loading,直接触发重计算 - // 当前decodeRpcServer会调用moveToNext,判断role type避免decodeRpcServer在enqueue前提早走到running状态 - status.store(StreamState::RUNNING, std::memory_order_release); - } - return; - } - - // A PREFILL role normally runs only the context pass, so it must not call - // incrKVBlock while the stream is still a context stream. With PD fallback, - // the same role can continue into decode after GenerateStream::update() - // flips isContextStream() to false; from that point block tables must grow - // exactly like a decode stream. - if (stream_cache_resource_->resourceContext().role_type == RoleType::PREFILL - && stream_cache_resource_->isContextStream()) { - status.store(StreamState::RUNNING, std::memory_order_release); - return; - } - - // Decode streams, including PREFILL-role streams after fallback, must keep - // cache block tables aligned with the growing sequence length. - auto result = stream_cache_resource_->incrKVBlock(reserve_step_); - if (!result.ok()) { - error_info = ErrorInfo(ErrorCode::MALLOC_FAILED, "LACK MEM"); - status.store(StreamState::FINISHED, std::memory_order_release); - releaseResource(); - return; - } - status.store(StreamState::RUNNING, std::memory_order_release); - return; -} - -void GenerateStateMachine::handleLoading() { - if (stream_cache_resource_->loadCacheDone()) { - status.store(StreamState::WAITING, std::memory_order_release); - } -} - -void GenerateStateMachine::handleRunning() { - // in pd sep case,kvcache could be released after remote load done. - if (events_.has(StreamEvents::GenerateDone)) { - status.store(StreamState::FINISHED, std::memory_order_release); - releaseResource(); - return; - } - if (stream_cache_resource_->resourceContext().role_type == RoleType::PREFILL - && stream_cache_resource_->isContextStream()) { - return; - } - // Use the publish-time seqLength so incrKVBlock doesn't race the async - // worker's update() — a stale read skips the block-boundary allocation. - // Prefer Normal state; fall back to MTP state if the stream is MTP. - int seq_len_override = -1; - GenerateStream* stream = stream_cache_resource_->stream(); - if (stream != nullptr) { - const int normal_override = stream->getNormalAsyncDeviceState().next_real_seq_len; - if (normal_override > 0) { - seq_len_override = normal_override; - } else { - const auto& mtp_state = stream->getMtpAsyncDeviceState(); - const int mtp_override = mtp_state.next_real_seq_len; - if (mtp_override > 0) { - seq_len_override = mtp_override; - } - } - if (asyncDebugEnabled() && stream->hasPendingAsyncBookkeeping()) { - RTP_LLM_LOG_WARNING("[async-debug] handleRunning while async bookkeeping pending: stream=%ld pd_sep=%d " - "status=%s seq_len=%d normal_last_real=%d normal_next_real=%d " - "mtp_next_real=%d override=%d", - stream->streamId(), - stream->queryPdSep(), - StreamStateToString(status.load(std::memory_order_acquire)).c_str(), - stream->seqLength(), - stream->getNormalAsyncDeviceState().last_real_seq_len, - stream->getNormalAsyncDeviceState().next_real_seq_len, - stream->getMtpAsyncDeviceState().next_real_seq_len, - seq_len_override); - } - } - auto result = stream_cache_resource_->incrKVBlock(reserve_step_, seq_len_override); - if (!result.ok()) { - // Report Error event so moveToNext() won't be called again on this stream - reportEvent(StreamEvents::Error, ErrorCode::MALLOC_FAILED, "incrKVBlock failed: LACK MEM"); - status.store(StreamState::FINISHED, std::memory_order_release); - releaseResource(); - } -} - -void GenerateStateMachine::releaseResource() { - if (stream_cache_resource_->isResourceReleased()) { - return; - } - // releaseResource runs under GenerateStream::mutex_; do not wait here. - // If a worker still owns KV blocks, mark deferred and let its dec path - // perform the release after the pending count drains. - GenerateStream* stream = stream_cache_resource_->stream(); - if (stream != nullptr && stream->hasPendingAsyncBookkeeping()) { - stream->markDeferredRelease(); - return; - } - stream_cache_resource_->releaseResource(); -} -} // namespace rtp_llm +namespace rtp_llm {} // namespace rtp_llm diff --git a/rtp_llm/cpp/engine_base/stream/GenerateStateMachine.h b/rtp_llm/cpp/engine_base/stream/GenerateStateMachine.h index 1092b4de02..4dff8dfab0 100644 --- a/rtp_llm/cpp/engine_base/stream/GenerateStateMachine.h +++ b/rtp_llm/cpp/engine_base/stream/GenerateStateMachine.h @@ -12,12 +12,12 @@ namespace rtp_llm { class StreamCacheResource; // forward declaration -// Stream 生命周期状态机,将原先分散在 FIFOScheduler 中的状态转移逻辑集中管理。 -// 状态转移路径: WAITING -> LOADING_CACHE -> WAITING -> RUNNING -> FINISHED -// 每次调度轮调用 moveToNext() 驱动状态转移,由 FIFOScheduler::evaluateAndUpdateStreams 统一调用。 -// 外部通过 reportEvent() 投递事件(替代原先分散的 reportXX 接口),moveToNext() 消费累积事件后决策转移。 -// 线程安全说明:GenerateStateMachine 本身不提供同步机制,外部调用者需保证 reportEvent() 和 moveToNext() -// 的调用串行化(通常通过 GenerateStream::mutex_ 保护)。 +// Stream lifecycle state holder. Lifecycle transitions (WAITING -> RUNNING -> +// FINISHED, cache loading, error handling) are now driven directly by +// GenerateStream's lifecycle methods (prepare/isReady/activate/advance/finish). +// This struct retains event accumulation, status storage, and error reporting. +// Thread safety: callers must serialise reportEvent() and status writes +// (typically via GenerateStream::mutex_). struct GenerateStateMachine { public: GenerateStateMachine(std::shared_ptr stream_cache_resource): @@ -39,8 +39,6 @@ struct GenerateStateMachine { return events_.has(event); } - StreamState moveToNext(); - StreamState getStatus() const { return status.load(std::memory_order_acquire); } @@ -50,16 +48,11 @@ struct GenerateStateMachine { } // 公开的状态和错误信息,GenerateStream 等外部代码直接访问 - // status 使用 atomic 保证线程安全:moveToNext() 在 mutex_ 下写入,getStatus() 无锁读取 + // status 使用 atomic 保证线程安全:lifecycle methods 在 mutex_ 下写入,getStatus() 无锁读取 std::atomic status = StreamState::WAITING; ErrorInfo error_info; private: - void handleWaiting(); - void handleLoading(); - void handleRunning(); - void releaseResource(); - StreamEvents events_; std::shared_ptr stream_cache_resource_ = nullptr; diff --git a/rtp_llm/cpp/engine_base/stream/GenerateStream.cc b/rtp_llm/cpp/engine_base/stream/GenerateStream.cc index 8b2ca5d2be..926aa79735 100644 --- a/rtp_llm/cpp/engine_base/stream/GenerateStream.cc +++ b/rtp_llm/cpp/engine_base/stream/GenerateStream.cc @@ -227,24 +227,6 @@ int64_t GenerateStream::streamId() const { return generate_input_->request_id; } -std::string GenerateStream::streamLogTag() const { - const auto& request_info = generate_input_->request_info; - std::string tag = std::string("request_id=") + std::to_string(streamId()) + " trace_id=" + traceId(); - if (!request_info.request_id.empty()) { - tag += " source_request_id=" + request_info.request_id; - } - if (!request_info.frontend_ip.empty()) { - tag += " frontend_ip=" + request_info.frontend_ip; - } - if (!request_info.dash_ip.empty()) { - tag += " dash_ip=" + request_info.dash_ip; - } - if (!request_info.source_role.empty()) { - tag += " source_role=" + request_info.source_role; - } - return tag; -} - std::string GenerateStream::adapterName() const { return generate_input_->generate_config->adapter_name; } @@ -582,6 +564,17 @@ void GenerateStream::checkTimeout() { } } +void GenerateStream::checkTimeoutWithoutLock() { + auto running_time_ms = (autil::TimeUtility::currentTimeInMicroSeconds() - begin_time_us_) / 1000; + auto timeout_ms = getTimeoutMs(); + if (timeout_ms > 0 && timeout_ms < running_time_ms) { + generate_status_->reportEvent(StreamEvents::Error, + ErrorCode::GENERATE_TIMEOUT, + "query has been running " + std::to_string(running_time_ms) + " ms, " + + "timeout_ms = " + std::to_string(timeout_ms) + ", it's timeout"); + } +} + // 统一的事件上报接口,替代原先所有 reportXX 方法。 // 外部线程调用时自动加锁保护 error_info 和 events_ 的一致性。 void GenerateStream::reportEvent(StreamEvents::EventType event, ErrorCode error_code, const std::string& error_msg) { @@ -589,7 +582,7 @@ void GenerateStream::reportEvent(StreamEvents::EventType event, ErrorCode error_ generate_status_->reportEvent(event, error_code, error_msg); } -// 无锁版本,供已持有 mutex_ 的内部调用路径使用(如 update/specUpdate/moveToNext 链路)。 +// 无锁版本,供已持有 mutex_ 的内部调用路径使用(如 update/specUpdate 链路)。 void GenerateStream::reportEventWithoutLock(StreamEvents::EventType event, ErrorCode error_code, const std::string& error_msg) { @@ -623,16 +616,112 @@ void GenerateStream::setReserveStep(size_t reserve_step) { generate_status_->setReserveStep(reserve_step); } -StreamState GenerateStream::moveToNext() { - checkTimeout(); +bool GenerateStream::prepare() { std::lock_guard lock(*mutex_); - StreamState state = generate_status_->moveToNext(); + generate_status_->reportEvent(StreamEvents::CanRun); + auto result = streamCacheResource().initKVBlock(reserveStep()); + if (!result.ok()) { + generate_status_->reportEvent(StreamEvents::Error, + ErrorCode::MALLOC_FAILED, + std::string("initKVBlock failed: ") + + std::string(result.message().data(), result.message().size())); + finish_internal(); + return false; + } + needs_cache_loading_ = streamCacheResource().asyncLoadCache(); + generate_status_->reportEvent(StreamEvents::LoadInitiated); + return needs_cache_loading_; +} + +bool GenerateStream::isReady() { + if (!needs_cache_loading_) + return true; + return streamCacheResource().loadCacheDone(); +} + +void GenerateStream::activate() { + std::lock_guard lock(*mutex_); + if (generate_status_->hasEvent(StreamEvents::Error)) { + finish_internal(); + return; + } + if (streamCacheResource().resourceContext().role_type == RoleType::PREFILL && isContextStream()) { + generate_status_->status.store(StreamState::RUNNING, std::memory_order_release); + return; + } + auto result = streamCacheResource().incrKVBlock(reserveStep()); + if (!result.ok()) { + generate_status_->reportEvent(StreamEvents::Error, + ErrorCode::MALLOC_FAILED, + std::string("incrKVBlock(activate) failed: ") + + std::string(result.message().data(), result.message().size())); + finish_internal(); + return; + } + generate_status_->status.store(StreamState::RUNNING, std::memory_order_release); +} - // notify one thread waiting for stream completion - if (getStatus() == StreamState::FINISHED) { - cv_->notify_one(); +void GenerateStream::advance() { + std::lock_guard lock(*mutex_); + checkTimeoutWithoutLock(); + if (generate_status_->hasEvent(StreamEvents::Error)) { + finish_internal(); + return; } - return state; + if (generate_status_->hasEvent(StreamEvents::GenerateDone)) { + finish_internal(); + return; + } + if (streamCacheResource().resourceContext().role_type == RoleType::PREFILL && isContextStream()) { + return; + } + // Use the publish-time seqLength so incrKVBlock doesn't race the async + // worker's update() — a stale read skips the block-boundary allocation. + // Prefer Normal state; fall back to MTP state if the stream is MTP. + int seq_len_override = -1; + const int normal_override = getNormalAsyncDeviceState().next_real_seq_len; + if (normal_override > 0) { + seq_len_override = normal_override; + } else { + const auto& mtp_state = getMtpAsyncDeviceState(); + const int mtp_override = mtp_state.next_real_seq_len; + if (mtp_override > 0) { + seq_len_override = mtp_override; + } + } + auto result = streamCacheResource().incrKVBlock(reserveStep(), seq_len_override); + if (!result.ok()) { + generate_status_->reportEvent(StreamEvents::Error, + ErrorCode::MALLOC_FAILED, + std::string("incrKVBlock(advance) failed: ") + + std::string(result.message().data(), result.message().size())); + finish_internal(); + return; + } +} + +bool GenerateStream::alive() { + return generate_status_->getStatus() != StreamState::FINISHED; +} + +void GenerateStream::finish() { + std::lock_guard lock(*mutex_); + finish_internal(); +} + +void GenerateStream::finish_internal() { + if (generate_status_->getStatus() == StreamState::FINISHED) + return; + // releaseResource runs under mutex_; do not wait here. + // If a worker still owns KV blocks, mark deferred and let its dec path + // perform the release after the pending count drains. + if (hasPendingAsyncBookkeeping()) { + markDeferredRelease(); + } else if (!streamCacheResource().isResourceReleased()) { + streamCacheResource().releaseResource(); + } + generate_status_->status.store(StreamState::FINISHED, std::memory_order_release); + cv_->notify_one(); } bool GenerateStream::hasError() const { diff --git a/rtp_llm/cpp/engine_base/stream/GenerateStream.h b/rtp_llm/cpp/engine_base/stream/GenerateStream.h index 4c1e286ef5..b2f9581344 100644 --- a/rtp_llm/cpp/engine_base/stream/GenerateStream.h +++ b/rtp_llm/cpp/engine_base/stream/GenerateStream.h @@ -235,6 +235,7 @@ class GenerateStream: public std::enable_shared_from_this { int64_t getTimeoutMs() const; void checkTimeout(); + void checkTimeoutWithoutLock(); void reportEvent(StreamEvents::EventType event, ErrorCode error_code = ErrorCode::NONE_ERROR, @@ -249,11 +250,17 @@ class GenerateStream: public std::enable_shared_from_this { ErrorInfo statusInfo(); std::string stopReason(); - void setReserveStep(size_t reserve_step); - size_t reserveStep() const { + void setReserveStep(size_t reserve_step); + size_t reserveStep() const { return reserve_step_; } - StreamState moveToNext(); + // Lifecycle methods — replace moveToNext(). + bool prepare(); + bool isReady(); + void activate(); + void advance(); + bool alive(); + void finish(); virtual StreamState getStatus() const; bool isFinished() const; // Returns true if stream is active (no error and not finished) @@ -434,19 +441,20 @@ class GenerateStream: public std::enable_shared_from_this { return generate_input_->generate_config->trace_id; } - int batchGroupSize() const { - return generate_input_->batch_group_size; + int groupSize() const { + return generate_input_->group_size; } - int batchGroupTimeout() const { - return generate_input_->generate_config->batch_group_timeout.value_or(100); + int groupTimeout() const { + return generate_input_->generate_config->group_timeout.value_or(100); } - bool forceBatch() const { - return generate_input_->generate_config->force_batch; + bool isGroup() const { + return generate_input_->group_id != -1; } - int64_t batchGroupId() const { - return generate_input_->batch_group_id; + + int64_t groupId() const { + return generate_input_->group_id; } int64_t enqueueTime() const { @@ -454,7 +462,14 @@ class GenerateStream: public std::enable_shared_from_this { } /// Log-friendly stream id: numeric ``streamId()`` (``request_id`` / ``inter_request_id``) + ``trace_id`` string. - std::string streamLogTag() const; + std::string streamLogTag() const { + char buf[256]; + std::string tid = traceId(); + snprintf(buf, sizeof(buf), "trace_id=%s req_id=%ld", + tid.empty() ? "-" : tid.c_str(), + streamId()); + return std::string(buf); + } std::vector getAllLogitsProcessorPtr() const { return logits_processor_list_; @@ -684,6 +699,7 @@ class GenerateStream: public std::enable_shared_from_this { void reportStreamMetrics(); void reportCacheReuseMetrics() const; + void finish_internal(); protected: uint64_t stream_magic_ = STREAM_MAGIC; @@ -749,6 +765,7 @@ class GenerateStream: public std::enable_shared_from_this { size_t propose_step_ = 0; size_t score_len_ = 0; size_t reserve_step_ = 0; + bool needs_cache_loading_ = false; bool acceped_bouns_token_ = false; int sp_edit_search_index_ = 0; bool sp_edit_first_time_ = true; @@ -777,10 +794,10 @@ class GenerateStream: public std::enable_shared_from_this { // Stream-async device-resident state for the next decode step's prepare. // These structs stay default-constructed (epoch=0, undefined tensors) until // their corresponding async/sync publisher installs a usable state. - MtpAsyncDeviceState mtp_async_state_; - uint64_t mtp_async_epoch_counter_ = 0; - NormalAsyncDeviceState normal_async_state_; - uint64_t normal_async_epoch_counter_ = 0; + MtpAsyncDeviceState mtp_async_state_; + uint64_t mtp_async_epoch_counter_ = 0; + NormalAsyncDeviceState normal_async_state_; + uint64_t normal_async_epoch_counter_ = 0; std::shared_ptr> grpc_normal_device_state_pending_ = std::make_shared>(false); bool return_all_hidden_states_ = false; diff --git a/rtp_llm/cpp/engine_base/stream/GenerateTypes.h b/rtp_llm/cpp/engine_base/stream/GenerateTypes.h index 972c8e4e7a..73b6cb30a5 100644 --- a/rtp_llm/cpp/engine_base/stream/GenerateTypes.h +++ b/rtp_llm/cpp/engine_base/stream/GenerateTypes.h @@ -76,8 +76,12 @@ class GenerateInput { int64_t begin_time_us = 0; // Batch grouping params - int batch_group_size = 1; - int64_t batch_group_id = -1; // Batch group ID for force batch grouping, -1 means not set + int group_size = 1; + int64_t group_id = -1; + + bool isGroup() const { + return group_id != -1; + } }; struct AuxInfo { @@ -149,7 +153,7 @@ inline std::string StreamStateToString(StreamState state) { } } -// 事件集合:外部通过 reportEvent() 投递事件,状态机在 moveToNext() 中统一消费。 +// 事件集合:外部通过 reportEvent() 投递事件,生命周期方法中统一消费。 // 内部使用 bit flag 组合多个并发事件。 // 所有事件均为永久事件:一旦设置即保留,不会被自动清除。 class StreamEvents { diff --git a/rtp_llm/cpp/engine_base/stream/test/GenerateStreamStateTest.cc b/rtp_llm/cpp/engine_base/stream/test/GenerateStreamStateTest.cc index 151b8eb306..65e308dbfb 100644 --- a/rtp_llm/cpp/engine_base/stream/test/GenerateStreamStateTest.cc +++ b/rtp_llm/cpp/engine_base/stream/test/GenerateStreamStateTest.cc @@ -284,79 +284,52 @@ TEST_F(GenerateStreamStateTest, testStreamStateToString) { } // ============================================================================ -// 9. LoadInitiated event: Verify Decode mode cache load fix +// 9. Lifecycle method tests // ============================================================================ -TEST_F(GenerateStreamStateTest, testLoadInitiatedPreventsDuplicateInitKVBlock) { +TEST_F(GenerateStreamStateTest, testPrepareAllocatesKVBlocks) { auto stream = createStream(); ASSERT_EQ(stream->getStatus(), StreamState::WAITING); - - // Simulate DecodeRpcServer: call initKVBlock directly and set LoadInitiated - auto& resource = stream->streamCacheResource(); - ASSERT_TRUE(resource.initKVBlock().ok()); - stream->reportEvent(StreamEvents::LoadInitiated); - - // FIFOScheduler calls moveToNext, should skip initKVBlock and asyncLoadCache - auto new_state = stream->moveToNext(); - // Should stay in WAITING because CanRun is not set yet - ASSERT_EQ(new_state, StreamState::WAITING); - - // Now simulate FIFOScheduler setting CanRun - stream->reportEvent(StreamEvents::CanRun); - new_state = stream->moveToNext(); - ASSERT_EQ(new_state, StreamState::RUNNING); + bool needs_loading = stream->prepare(); + ASSERT_FALSE(needs_loading); + ASSERT_TRUE(stream->alive()); } -TEST_F(GenerateStreamStateTest, testLoadInitiatedSkipsAsyncLoadCache) { +TEST_F(GenerateStreamStateTest, testActivateTransitionsToRunning) { auto stream = createStream(); - ASSERT_EQ(stream->getStatus(), StreamState::WAITING); - - // Simulate DecodeRpcServer: only initKVBlock, no asyncLoadCache - auto& resource = stream->streamCacheResource(); - ASSERT_TRUE(resource.initKVBlock().ok()); - stream->reportEvent(StreamEvents::LoadInitiated); - - // Verify load_cache_context_ is null (no asyncLoadCache was called) - ASSERT_FALSE(resource.load_cache_context_); - - // moveToNext should not trigger asyncLoadCache because LoadInitiated is set - stream->reportEvent(StreamEvents::CanRun); - auto new_state = stream->moveToNext(); - ASSERT_EQ(new_state, StreamState::RUNNING); - - // Still no asyncLoadCache context - ASSERT_FALSE(resource.load_cache_context_); + stream->prepare(); + stream->activate(); + ASSERT_EQ(stream->getStatus(), StreamState::RUNNING); + ASSERT_TRUE(stream->alive()); } -TEST_F(GenerateStreamStateTest, testPrefillFallbackDecodeGrowsBlocksAfterContext) { - auto stream = createStream({1, 2}, /*reuse_cache=*/false, RoleType::PREFILL); - - stream->reportEvent(StreamEvents::CanRun); - ASSERT_EQ(stream->moveToNext(), StreamState::RUNNING); - ASSERT_TRUE(stream->isContextStream()); - ASSERT_EQ(stream->curBlocksNum(), 1u); - - // PD fallback can continue decoding in the PREFILL role. The next decode - // token is at absolute position 2, which crosses the 2-token test block - // boundary and therefore requires a second block-table column. - stream->setIsContextStream(false); - stream->setSeqLength(3); - - ASSERT_EQ(stream->moveToNext(), StreamState::RUNNING); - EXPECT_EQ(stream->curBlocksNum(), 2u); +TEST_F(GenerateStreamStateTest, testFinishReleasesResource) { + auto stream = createStream(); + stream->prepare(); + stream->activate(); + ASSERT_EQ(stream->getStatus(), StreamState::RUNNING); + stream->finish(); + ASSERT_EQ(stream->getStatus(), StreamState::FINISHED); + ASSERT_FALSE(stream->alive()); } -TEST_F(GenerateStreamStateTest, testNormalPathTriggersAsyncLoadCache) { - // Create stream with reuse_cache enabled to trigger asyncLoadCache - auto stream = createStream({1, 2, 3, 4, 5, 6}, /*reuse_cache=*/true); - ASSERT_EQ(stream->getStatus(), StreamState::WAITING); - - // Normal path: moveToNext should trigger initKVBlock + asyncLoadCache - auto new_state = stream->moveToNext(); +TEST_F(GenerateStreamStateTest, testFinishIsIdempotent) { + auto stream = createStream(); + stream->prepare(); + stream->activate(); + stream->finish(); + ASSERT_FALSE(stream->alive()); + stream->finish(); + ASSERT_FALSE(stream->alive()); +} - // Should transition to LOADING_CACHE if asyncLoadCache was initiated - // or stay in WAITING if no connectors are available - ASSERT_TRUE(new_state == StreamState::LOADING_CACHE || new_state == StreamState::WAITING); +TEST_F(GenerateStreamStateTest, testPrepareWithReuseCacheReturnsLoading) { + // Only test if reuse_cache streams exist in test helpers + // If createStream doesn't support reuse_cache, skip this test + auto stream = createStream(); + // Without reuse cache connector, prepare returns false (no loading needed) + bool needs_loading = stream->prepare(); + ASSERT_FALSE(needs_loading); } } // namespace rtp_llm diff --git a/rtp_llm/cpp/model_rpc/BUILD b/rtp_llm/cpp/model_rpc/BUILD index 1d6d35678d..c602d0e1f6 100644 --- a/rtp_llm/cpp/model_rpc/BUILD +++ b/rtp_llm/cpp/model_rpc/BUILD @@ -62,9 +62,11 @@ cc_library( ":rpc_error_code", ":tensor_pb_convert", "//rtp_llm/cpp/cache:cache_types", + "//rtp_llm/cpp/cache:recent_cache_key_window", "//rtp_llm/cpp/cache:batch_kv_cache_resource", "//rtp_llm/cpp/cache:kv_cache_transfer_planner", "//rtp_llm/cpp/disaggregate/cache_store", + "//rtp_llm/cpp/distribute:rpc_cpu_tp_broadcaster_hdr", "//rtp_llm/cpp/engine_base:profiler", "//rtp_llm/cpp/engine_base/stream:generate_types", "//rtp_llm/cpp/metrics", diff --git a/rtp_llm/cpp/model_rpc/DecodeRpcServer.cc b/rtp_llm/cpp/model_rpc/DecodeRpcServer.cc index 2b89b11093..0eacdd049b 100644 --- a/rtp_llm/cpp/model_rpc/DecodeRpcServer.cc +++ b/rtp_llm/cpp/model_rpc/DecodeRpcServer.cc @@ -140,13 +140,14 @@ void DecodeRpcServer::allocateResource(DecodeGenerateContext& decode_context) { generate_stream->reportEvent(StreamEvents::CanRun); decode_context.setStream(generate_stream); - // WAITING -> LOADING_CACHE -> WAITING, 直到load cache完成并移动到 WAITING 状态 - // NOTE: 此处的 busy-wait 是安全的,因为 stream 尚未 enqueue 到 scheduler, - // 不会与其他线程并发调用 moveToNext()。gRPC 线程独占驱动状态机直到 WAITING。 - while (!generate_stream->hasError() && generate_stream->moveToNext() == StreamState::LOADING_CACHE) { + // Prepare KV cache allocation, then wait until the stream is ready. + // This busy-wait is safe because the stream has not been enqueued to the + // scheduler yet -- the gRPC thread exclusively drives the state machine. + generate_stream->prepare(); + while (generate_stream->alive() && !generate_stream->isReady()) { this_thread::sleep_for(chrono::milliseconds(1)); } - if (generate_stream->hasError()) { + if (!generate_stream->alive() || generate_stream->hasError()) { auto stream_error = generate_stream->statusInfo(); string error_msg = stream_error.ToString(); if (error_msg.empty()) { @@ -765,6 +766,52 @@ ErrorInfo DecodeRpcServer::loadCache(const LoadKVCacheContext& load_context) { parts.size() == 1, "Dsv4 fixed/SWA opaque block expects one part when CP-sliced, got %zu", parts.size()); auto& block = parts[0]; RTP_LLM_CHECK_WITH_INFO(block.addr != nullptr, "null DSV4 fixed/SWA block addr while slicing"); + if (region_name == KVCacheRegionName::SWA_KV) { + RTP_LLM_CHECK_WITH_INFO(gid < cfg.cache_specs.size(), "group id out of range for cache_specs: %zu", gid); + const auto& spec = cfg.cache_specs[gid]; + RTP_LLM_CHECK_WITH_INFO(spec != nullptr, "null cache spec for group %zu", gid); + const auto* state_spec = dynamic_cast(spec.get()); + RTP_LLM_CHECK_WITH_INFO(state_spec != nullptr, + "CP-sliced SWA_KV group %zu expects DSV4StateSpec, got %s", + gid, + spec->debugString().c_str()); + const size_t cp_size = static_cast(load_context.prefill_cp_size); + RTP_LLM_CHECK_WITH_INFO(state_spec->state_dim == DSV4_FP8_KV_ENTRY_BYTES + && state_spec->store_dtype == DataType::TYPE_UINT8, + "SWA_KV expects uint8 %u-byte entries, got state_dim=%u dtype=%d", + DSV4_FP8_KV_ENTRY_BYTES, + state_spec->state_dim, + static_cast(state_spec->store_dtype)); + RTP_LLM_CHECK_WITH_INFO(state_spec->entries_per_block % cp_size == 0, + "CP-sliced SWA_KV entries %u not divisible by cp_size %zu", + state_spec->entries_per_block, + cp_size); + + constexpr size_t kSwaTokenDataBytes = DSV4_FP8_MLA_BLOCK_ALIGNMENT_BYTES; + constexpr size_t kSwaTokenScaleBytes = DSV4_FP8_KV_ENTRY_BYTES - kSwaTokenDataBytes; + const size_t local_entries = state_spec->entries_per_block / cp_size; + const size_t data_bytes = local_entries * kSwaTokenDataBytes; + const size_t scale_bytes = local_entries * kSwaTokenScaleBytes; + const size_t data_offset = data_bytes * static_cast(peer_idx); + const size_t scale_region_offset = static_cast(state_spec->entries_per_block) * kSwaTokenDataBytes; + const size_t scale_offset = scale_region_offset + scale_bytes * static_cast(peer_idx); + RTP_LLM_CHECK_WITH_INFO( + scale_offset + scale_bytes <= block.size_bytes, + "Dsv4 SWA_KV DATA/SCALE slice exceeds block bytes: data=[%zu,%zu) scale=[%zu,%zu) block=%zu gid=%zu", + data_offset, + data_offset + data_bytes, + scale_offset, + scale_offset + scale_bytes, + block.size_bytes, + gid); + BlockInfo data_block = block; + BlockInfo scale_block = block; + data_block.addr = static_cast(static_cast(block.addr) + data_offset); + data_block.size_bytes = data_bytes; + scale_block.addr = static_cast(static_cast(block.addr) + scale_offset); + scale_block.size_bytes = scale_bytes; + return std::vector{data_block, scale_block}; + } const size_t slice_bytes = cpFixedSliceBytes(cfg, gid); const size_t slice_offset = slice_bytes * static_cast(peer_idx); RTP_LLM_CHECK_WITH_INFO(slice_offset + slice_bytes <= block.size_bytes, @@ -787,7 +834,7 @@ ErrorInfo DecodeRpcServer::loadCache(const LoadKVCacheContext& load_context) { return group_tokens > 0 && group_tokens == cfg.seq_size_per_block * static_cast(load_context.prefill_cp_size); }; - auto blockPositionsForLoad = [&](size_t block_num, + auto blockPositionsForLoad = [&](size_t block_num, const CacheConfig& cfg, bool cfg_use_hybrid, CacheGroupType group_type, @@ -815,8 +862,8 @@ ErrorInfo DecodeRpcServer::loadCache(const LoadKVCacheContext& load_context) { const size_t cp_size = static_cast(load_context.prefill_cp_size); const size_t compact_blocks = (block_num + cp_size - 1) / cp_size; const size_t reuse_blocks = static_cast(std::max(load_context.reuse_block_size, 0)); - const size_t start = cfg_use_hybrid ? (compact_blocks > 2 ? compact_blocks - 2 : 0) : - std::min(reuse_blocks, compact_blocks); + const size_t start = + cfg_use_hybrid ? (compact_blocks > 2 ? compact_blocks - 2 : 0) : std::min(reuse_blocks, compact_blocks); block_pos_list.reserve(compact_blocks - start); for (size_t compact_pos = start; compact_pos < compact_blocks; ++compact_pos) { block_pos_list.push_back(std::min((compact_pos + 1) * cp_size - 1, block_num - 1)); @@ -824,11 +871,11 @@ ErrorInfo DecodeRpcServer::loadCache(const LoadKVCacheContext& load_context) { return block_pos_list; }; auto cacheKeyIndexForBlock = [&](const CacheConfig& cfg, - KVCacheRegionName region_name, - size_t gid, - size_t block_pos, - size_t cache_key_count, - size_t& cache_key_index) { + KVCacheRegionName region_name, + size_t gid, + size_t block_pos, + size_t cache_key_count, + size_t& cache_key_index) { if (cache_key_count == 0) { return false; } @@ -888,8 +935,12 @@ ErrorInfo DecodeRpcServer::loadCache(const LoadKVCacheContext& load_context) { continue; } size_t cache_key_index = 0; - if (!cacheKeyIndexForBlock( - cache_config, region_name, gid, block_pos, load_context.cache_keys.size(), cache_key_index)) { + if (!cacheKeyIndexForBlock(cache_config, + region_name, + gid, + block_pos, + load_context.cache_keys.size(), + cache_key_index)) { continue; } auto cache_key = makeCacheKey( @@ -1012,9 +1063,8 @@ ErrorInfo DecodeRpcServer::loadCache(const LoadKVCacheContext& load_context) { region_name = mtp_cache_cfg.group_region_names[gid]; } CacheGroupType group_type = groupType(mtp_cache_cfg, mtp_use_hybrid, gid); - auto block_pos_list = - blockPositionsForLoad( - block_num, mtp_cache_cfg, mtp_use_hybrid, group_type, region_name, gid); + auto block_pos_list = blockPositionsForLoad( + block_num, mtp_cache_cfg, mtp_use_hybrid, group_type, region_name, gid); if (!shouldLoadGroupFromPeer(group_type, region_name, i)) { continue; diff --git a/rtp_llm/cpp/model_rpc/GenerateContext.h b/rtp_llm/cpp/model_rpc/GenerateContext.h index b88a3d64ad..1c8be80275 100644 --- a/rtp_llm/cpp/model_rpc/GenerateContext.h +++ b/rtp_llm/cpp/model_rpc/GenerateContext.h @@ -76,19 +76,17 @@ class GenerateContext { ErrorCode::GENERATE_TIMEOUT, \ "request cost time is " + std::to_string(request_cost_time_ms) + " ms" + ", request timeout is " \ + std::to_string(generate_context.request_timeout_ms) + " ms"); \ - generate_context.error_status = serializeErrorMsg(generate_context.request_key, \ - generate_context.request_info, \ - generate_context.error_info); \ + generate_context.error_status = serializeErrorMsg( \ + generate_context.request_key, generate_context.request_info, generate_context.error_info); \ return generate_context.error_status; \ } \ } #define CHECK_REQUEST_CANCELLED(generate_context) \ - if (generate_context.server_context->IsCancelled()) { \ + if (generate_context.server_context && generate_context.server_context->IsCancelled()) { \ generate_context.error_info = ErrorInfo(ErrorCode::CANCELLED, "request is cancelled"); \ - generate_context.error_status = serializeErrorMsg(generate_context.request_key, \ - generate_context.request_info, \ - generate_context.error_info); \ + generate_context.error_status = serializeErrorMsg( \ + generate_context.request_key, generate_context.request_info, generate_context.error_info); \ return generate_context.error_status; \ } diff --git a/rtp_llm/cpp/model_rpc/LocalRpcServer.cc b/rtp_llm/cpp/model_rpc/LocalRpcServer.cc index 292fe02277..59e7d877b0 100644 --- a/rtp_llm/cpp/model_rpc/LocalRpcServer.cc +++ b/rtp_llm/cpp/model_rpc/LocalRpcServer.cc @@ -9,6 +9,7 @@ #include "rtp_llm/cpp/model_rpc/proto/model_rpc_service.pb.h" #include "rtp_llm/cpp/config/EplbConfig.h" #include "rtp_llm/cpp/cache/Types.h" +#include "rtp_llm/cpp/distribute/RpcCpuTpBroadcaster.h" using namespace std; @@ -84,10 +85,9 @@ grpc::Status LocalRpcServer::serializeErrorMsg(const string& request_key, ErrorI return serializeErrorMsg(request_key, RequestInfo(), error_info); } -grpc::Status LocalRpcServer::serializeErrorMsg(const string& request_key, - const RequestInfo& request_info, - ErrorInfo error_info) { - const auto& error_msg = error_info.ToString(); +grpc::Status +LocalRpcServer::serializeErrorMsg(const string& request_key, const RequestInfo& request_info, ErrorInfo error_info) { + const auto& error_msg = error_info.ToString(); const auto request_log_tag = formatRequestLogTag(request_key, request_info); RTP_LLM_LOG_WARNING("%s, error code [%s], error message [%s]", request_log_tag.c_str(), @@ -130,7 +130,7 @@ grpc::Status LocalRpcServer::pollStreamOutput(grpc::ServerContext* c stream->generateConfig()->aux_info, maga_init_params_.misc_config.aux_string, stream->specialTokens().eos_token_id); - if (context->IsCancelled()) { + if (context && context->IsCancelled()) { stream->reportError(ErrorCode::CANCELLED, "request cancelled by user"); RTP_LLM_LOG_WARNING("request [%s] cancelled by user", request_key.c_str()); return grpc::Status(grpc::StatusCode::CANCELLED, "request cancelled by user"); @@ -158,7 +158,7 @@ grpc::Status LocalRpcServer::GenerateStreamCall(grpc::ServerContext* RTP_LLM_LOG_DEBUG("receive request %ld", request_id); auto generate_context = GenerateContext(request_id, request->generate_config().timeout_ms(), context, metrics_reporter_, meta_); - auto input = QueryConverter::transQuery(request); + auto input = QueryConverter::transQuery(request); generate_context.request_info = input->request_info; if (applyTimelineGate(generate_context.request_key, input->generate_config->gen_timeline, @@ -244,7 +244,7 @@ grpc::Status LocalRpcServer::GetWorkerStatus(grpc::ServerContext* context, RTP_LLM_LOG_DEBUG("getWorkerStatusInfo took %ld us", request_after_ws_time_us - request_begin_time_us); const auto& engine_schedule_info = status_info.engine_schedule_info; - response->set_role(status_info.role); + response->set_role(static_cast(status_info.role)); for (const auto& task : engine_schedule_info.running_task_info_list) { TaskInfoPB* task_info = response->add_running_task_info(); @@ -255,7 +255,12 @@ grpc::Status LocalRpcServer::GetWorkerStatus(grpc::ServerContext* context, task_info->set_iterate_count(task.iterate_count); task_info->set_end_time_ms(task.end_time_ms); task_info->set_dp_rank(status_info.dp_rank); - task_info->set_is_waiting(task.is_waiting); + task_info->set_phase(static_cast<::TaskPhase>(task.phase)); + task_info->set_batch_id(task.batch_id); + if (task.error_code != 0) { + task_info->mutable_error_info()->set_error_code(task.error_code); + task_info->mutable_error_info()->set_error_message(task.error_message); + } } for (const auto& task : engine_schedule_info.finished_task_info_list) { @@ -267,40 +272,48 @@ grpc::Status LocalRpcServer::GetWorkerStatus(grpc::ServerContext* context, task_info->set_iterate_count(task.iterate_count); task_info->set_end_time_ms(task.end_time_ms); task_info->set_dp_rank(status_info.dp_rank); - task_info->set_is_waiting(task.is_waiting); + task_info->set_phase(static_cast<::TaskPhase>(task.phase)); + task_info->set_batch_id(task.batch_id); + if (task.error_code != 0) { + task_info->mutable_error_info()->set_error_code(task.error_code); + task_info->mutable_error_info()->set_error_message(task.error_message); + } + } + + // Debug: log finished tasks details + if (!engine_schedule_info.finished_task_info_list.empty()) { + std::string task_details; + for (const auto& task : engine_schedule_info.finished_task_info_list) { + task_details += + " req_id=" + std::to_string(task.request_id) + " batch_id=" + std::to_string(task.batch_id) + "\n"; + } + RTP_LLM_LOG_INFO("GetWorkerStatus response: request_latest_finished_version=%ld, " + "response_latest_finished_version=%ld, " + "finished_tasks_count=%ld\n%s", + latest_finished_version, + status_info.latest_finished_version, + engine_schedule_info.finished_task_info_list.size(), + task_details.c_str()); } + response->set_dp_size(status_info.dp_size); response->set_tp_size(status_info.tp_size); response->set_status_version(status_info.status_version); response->set_latest_finished_version(status_info.latest_finished_version); response->set_alive(status_info.alive); response->set_precision(status_info.precision); + response->set_dp_rank(status_info.dp_rank); + auto kv_info = engine_->getCacheStatusInfo(-1, false); + response->set_available_kv_cache(kv_info.available_kv_cache); + response->set_total_kv_cache(kv_info.total_kv_cache); reportWorkerStatusTime(request_begin_time_us, request_after_ws_time_us); return grpc::Status::OK; } WorkerStatusInfo LocalRpcServer::getWorkerStatusInfo(int64_t latest_finished_version) { WorkerStatusInfo status_info; - status_info.engine_schedule_info = getEngineScheduleInfo(latest_finished_version); - switch (maga_init_params_.pd_sep_config.role_type) { - case RoleType::PDFUSION: - status_info.role = "RoleType.PDFUSION"; - break; - case RoleType::PREFILL: - status_info.role = "RoleType.PREFILL"; - break; - case RoleType::DECODE: - status_info.role = "RoleType.DECODE"; - break; - case RoleType::VIT: - status_info.role = "RoleType.VIT"; - break; - case RoleType::FRONTEND: - status_info.role = "RoleType.FRONTEND"; - break; - default: - status_info.role = "RoleType.UNKNOWN"; - } + status_info.engine_schedule_info = getEngineScheduleInfo(latest_finished_version); + status_info.role = maga_init_params_.pd_sep_config.role_type; status_info.dp_size = maga_init_params_.parallelism_config.dp_size; status_info.tp_size = maga_init_params_.parallelism_config.tp_size; status_info.dp_rank = maga_init_params_.parallelism_config.dp_rank; @@ -359,17 +372,8 @@ size_t LocalRpcServer::onflightRequestNum() { } EngineScheduleInfo LocalRpcServer::getEngineScheduleInfo(int64_t latest_finished_version) { - EngineScheduleInfo info = meta_->getEngineScheduleInfo(latest_finished_version); - std::vector running_task_info_list = engine_->getScheduler().runningTaskList(); - for (auto& task_info : info.running_task_info_list) { - for (auto& running_task : running_task_info_list) { - if (task_info.request_id == running_task.request_id) { - task_info.is_waiting = false; - } - } - } - auto last_schedule_time = engine_->getLastScheduleTime(); - // in case last_schedule_delta is negative + EngineScheduleInfo info = meta_->getEngineScheduleInfo(latest_finished_version); + auto last_schedule_time = engine_->getLastScheduleTime(); info.last_schedule_delta = std::max((int64_t)0, autil::TimeUtility::currentTimeInMilliSeconds() - last_schedule_time); return info; @@ -541,6 +545,18 @@ ::grpc::Status LocalRpcServer::ExecuteFunction(::grpc::ServerContext* contex return grpc::Status::OK; } +::grpc::Status LocalRpcServer::CpuTpBroadcast(::grpc::ServerContext* context, + const ::CpuTpBroadcastRequestPB* request, + ::CpuTpBroadcastResponsePB* response) { + if (context->IsCancelled()) { + response->set_success(false); + response->set_error_message("request is cancelled"); + return grpc::Status(grpc::StatusCode::CANCELLED, "request is cancelled"); + } + (void)RpcCpuTpBroadcaster::instance().handleBroadcastRequest(*request, response); + return grpc::Status::OK; +} + grpc::Status LocalRpcServer::SetPause(grpc::ServerContext* context, const EmptyPB* request, EmptyPB* response) { RTP_LLM_LOG_DEBUG("receive cacheStatus rpc request from client: %s", context->peer().c_str()); engine_->pause(); diff --git a/rtp_llm/cpp/model_rpc/LocalRpcServer.h b/rtp_llm/cpp/model_rpc/LocalRpcServer.h index 473d1ac2a5..88ed6ef6cf 100644 --- a/rtp_llm/cpp/model_rpc/LocalRpcServer.h +++ b/rtp_llm/cpp/model_rpc/LocalRpcServer.h @@ -93,6 +93,9 @@ class LocalRpcServer { ::grpc::Status ExecuteFunction(::grpc::ServerContext* context, const ::FunctionRequestPB* request, ::FunctionResponsePB* response); + ::grpc::Status CpuTpBroadcast(::grpc::ServerContext* context, + const ::CpuTpBroadcastRequestPB* request, + ::CpuTpBroadcastResponsePB* response); public: typedef grpc::internal::WriterInterface WriterInterface; diff --git a/rtp_llm/cpp/model_rpc/LocalRpcServiceImpl.h b/rtp_llm/cpp/model_rpc/LocalRpcServiceImpl.h index 1e9b5aafca..38f57f0654 100644 --- a/rtp_llm/cpp/model_rpc/LocalRpcServiceImpl.h +++ b/rtp_llm/cpp/model_rpc/LocalRpcServiceImpl.h @@ -15,79 +15,155 @@ class LocalRpcServiceImpl: public RpcService::Service { public: LocalRpcServiceImpl() {} virtual ~LocalRpcServiceImpl() {} + void prepareLocalServer() { + if (!local_server_) { + local_server_ = std::make_shared(); + } + } virtual grpc::Status init(const EngineInitParams& maga_init_params, py::object mm_process_engine, std::unique_ptr propose_params) { - local_server_ = std::make_shared(); + prepareLocalServer(); return local_server_->init(maga_init_params, mm_process_engine, std::move(propose_params)); } grpc::Status init(const EngineInitParams& maga_init_params, py::object mm_process_engine, std::unique_ptr propose_params, py::object weight_manager) { - local_server_ = std::make_shared(); + (void)weight_manager; + prepareLocalServer(); return local_server_->init(maga_init_params, mm_process_engine, std::move(propose_params)); } grpc::Status GenerateStreamCall(grpc::ServerContext* context, const GenerateInputPB* request, grpc::ServerWriter* writer) override { + if (!readyForRegularRpc()) { + return notReadyStatus("GenerateStreamCall"); + } return local_server_->GenerateStreamCall(context, request, writer); } + grpc::Status EnqueueBatch(grpc::ServerContext* context, + const EnqueueBatchRequestPB* request, + EnqueueBatchResponsePB* response) override { + (void)context; + (void)request; + (void)response; + return grpc::Status(grpc::StatusCode::UNIMPLEMENTED, "EnqueueBatch not implemented on this role"); + } + + grpc::Status EnqueueGroup(grpc::ServerContext* context, + const EnqueueGroupRequestPB* request, + EnqueueBatchResponsePB* response) override { + (void)context; + (void)request; + (void)response; + return grpc::Status(grpc::StatusCode::UNIMPLEMENTED, "EnqueueGroup not implemented on this role"); + } + + grpc::Status FetchResponse(grpc::ServerContext* context, + const FetchRequestPB* request, + grpc::ServerWriter* writer) override { + (void)context; + (void)request; + (void)writer; + return grpc::Status(grpc::StatusCode::UNIMPLEMENTED, "FetchResponse not implemented on this role"); + } + + grpc::Status Cancel(grpc::ServerContext* context, const CancelRequestPB* request, EmptyPB* response) override { + (void)context; + (void)request; + (void)response; + return grpc::Status(grpc::StatusCode::UNIMPLEMENTED, "Cancel not implemented on this role"); + } + ::grpc::Status GetWorkerStatus(::grpc::ServerContext* context, const StatusVersionPB* request, WorkerStatusPB* response) override { + if (!readyForRegularRpc()) { + return notReadyStatus("GetWorkerStatus"); + } return local_server_->GetWorkerStatus(context, request, response); } ::grpc::Status UpdateWeights(::grpc::ServerContext* context, const UpdateWeightsRequestPB* request, EmptyPB* response) override { + if (!readyForRegularRpc()) { + return notReadyStatus("UpdateWeights"); + } return local_server_->UpdateWeights(context, request, response); } ::grpc::Status GetCacheStatus(::grpc::ServerContext* context, const CacheVersionPB* request, CacheStatusPB* response) override { + if (!readyForRegularRpc()) { + return notReadyStatus("GetCacheStatus"); + } return local_server_->GetCacheStatus(context, request, response); } ::grpc::Status UpdateSchedulerInfo(::grpc::ServerContext* context, const UpdateSchedulerInfoRequestPB* request, EmptyPB* response) override { + if (!readyForRegularRpc()) { + return notReadyStatus("UpdateSchedulerInfo"); + } return local_server_->UpdateSchedulerInfo(context, request, response); } ::grpc::Status SetLogLevel(::grpc::ServerContext* context, const SetLogLevelRequestPB* request, EmptyPB* response) override { + if (!readyForRegularRpc()) { + return notReadyStatus("SetLogLevel"); + } return local_server_->SetLogLevel(context, request, response); } ::grpc::Status StartProfile(::grpc::ServerContext* context, const StartProfileRequestPB* request, EmptyPB* response) override { + if (!readyForRegularRpc()) { + return notReadyStatus("StartProfile"); + } return local_server_->StartProfile(context, request, response); } ::grpc::Status StartProfileInternal(::grpc::ServerContext* context, const StartProfileInternalRequestPB* request, EmptyPB* response) override { + if (!readyForRegularRpc()) { + return notReadyStatus("StartProfileInternal"); + } return local_server_->StartProfileInternal(context, request, response); } ::grpc::Status CheckHealth(::grpc::ServerContext* context, const EmptyPB* request, CheckHealthResponsePB* response) override { + if (!readyForRegularRpc()) { + return notReadyStatus("CheckHealth"); + } return local_server_->CheckHealth(context, request, response); } ::grpc::Status UpdateEplbConfig(::grpc::ServerContext* context, const UpdateEplbConfigRequestPB* request, EmptyPB* response) override { + if (!readyForRegularRpc()) { + return notReadyStatus("UpdateEplbConfig"); + } return local_server_->UpdateEplbConfig(context, request, response); } ::grpc::Status SetPause(::grpc::ServerContext* context, const EmptyPB* request, EmptyPB* response) override { + if (!readyForRegularRpc()) { + return notReadyStatus("SetPause"); + } return local_server_->SetPause(context, request, response); } ::grpc::Status SetRestart(::grpc::ServerContext* context, const EmptyPB* request, EmptyPB* response) override { + if (!readyForRegularRpc()) { + return notReadyStatus("SetRestart"); + } return local_server_->SetRestart(context, request, response); } @@ -124,10 +200,30 @@ class LocalRpcServiceImpl: public RpcService::Service { ::grpc::Status ExecuteFunction(::grpc::ServerContext* context, const ::FunctionRequestPB* request, ::FunctionResponsePB* response) override { + if (!readyForRegularRpc()) { + return notReadyStatus("ExecuteFunction"); + } return local_server_->ExecuteFunction(context, request, response); } + ::grpc::Status CpuTpBroadcast(::grpc::ServerContext* context, + const ::CpuTpBroadcastRequestPB* request, + ::CpuTpBroadcastResponsePB* response) override { + if (!local_server_) { + return grpc::Status(grpc::StatusCode::UNAVAILABLE, "local rpc server is initializing"); + } + return local_server_->CpuTpBroadcast(context, request, response); + } + protected: + bool readyForRegularRpc() const { + return local_server_ && local_server_->getEngine(); + } + + grpc::Status notReadyStatus(const char* method) const { + return grpc::Status(grpc::StatusCode::UNAVAILABLE, std::string(method) + " rejected: engine is initializing"); + } + std::shared_ptr local_server_; }; diff --git a/rtp_llm/cpp/model_rpc/PrefillGenerateContext.h b/rtp_llm/cpp/model_rpc/PrefillGenerateContext.h index 8691db2f48..6c3bbda5a0 100644 --- a/rtp_llm/cpp/model_rpc/PrefillGenerateContext.h +++ b/rtp_llm/cpp/model_rpc/PrefillGenerateContext.h @@ -11,6 +11,8 @@ namespace rtp_llm { +struct AsyncProducerCancelState; + struct PrefillStatInfo { enum ExecuteStage { start = 0, @@ -48,8 +50,8 @@ struct RPCContext { return request->request_id(); } - const GenerateInputPB* request; - grpc::ServerWriter* writer; + const GenerateInputPB* request; + grpc::internal::WriterInterface* writer; }; class PrefillGenerateContext: public GenerateContext { @@ -79,19 +81,21 @@ class PrefillGenerateContext: public GenerateContext { public: typedef grpc::ClientReaderWriter ClientStream; - RemoteServerResource* resource; - RPCContext rpc_context; - std::shared_ptr generate_input; - std::string decode_addr; - std::vector prefill_worker_cache_store_addrs; - GrpcConnection grpc_connection; - std::shared_ptr stub; - std::shared_ptr client_context; - std::shared_ptr client_stream; - bool grpc_stream_closed = false; - grpc::Status last_grpc_stream_closed_status = grpc::Status::OK; - PrefillStatInfo stat_info; - int64_t loading_cache_requests = 0; + RemoteServerResource* resource; + RPCContext rpc_context; + std::shared_ptr generate_input; + std::string decode_addr; + std::vector prefill_worker_cache_store_addrs; + GrpcConnection grpc_connection; + std::shared_ptr stub; + std::shared_ptr client_context; + std::shared_ptr client_stream; + std::shared_ptr cancel_state; + bool grpc_stream_closed = false; + grpc::Status last_grpc_stream_closed_status = grpc::Status::OK; + PrefillStatInfo stat_info; + int64_t loading_cache_requests = 0; + bool recent_cache_key_metric_reported = false; }; } // namespace rtp_llm diff --git a/rtp_llm/cpp/model_rpc/PrefillRpcServer.cc b/rtp_llm/cpp/model_rpc/PrefillRpcServer.cc index 99d15026bb..a3e3377903 100644 --- a/rtp_llm/cpp/model_rpc/PrefillRpcServer.cc +++ b/rtp_llm/cpp/model_rpc/PrefillRpcServer.cc @@ -1,15 +1,38 @@ #include "autil/TimeUtility.h" #include "rtp_llm/cpp/model_rpc/QueryConverter.h" #include "rtp_llm/cpp/model_rpc/PrefillRpcServer.h" +#include "rtp_llm/cpp/cache/PrefillCacheHitMetricsReporter.h" #include "rtp_llm/cpp/utils/DebugUtils.h" +#include "rtp_llm/cpp/utils/HashUtil.h" #include "rtp_llm/cpp/config/ConfigModules.h" #include "rtp_llm/cpp/engine_base/Host.h" #include "rtp_llm/cpp/utils/ProfilingScope.h" #include "rtp_llm/cpp/models/logits_processor/LogitsProcessorFactory.h" -#include +#include +#include +#include +#include +#include #include +#include #include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include #include +#include +#include +#include using namespace std; using namespace autil::legacy; @@ -19,6 +42,23 @@ using grpc::ClientContext; namespace rtp_llm { +PrefillRpcServer::~PrefillRpcServer() { + stopAsyncResponseWorkers(); + stopResponseRegistryGc(); + if (enqueue_worker_pool_) { + enqueue_worker_pool_->stop(); + enqueue_worker_pool_.reset(); + } + if (worker_lambda_pool_) { + worker_lambda_pool_->stop(); + worker_lambda_pool_.reset(); + } + if (slot_worker_pool_) { + slot_worker_pool_->stop(); + slot_worker_pool_.reset(); + } +} + namespace { bool envValueIsTrue(const char* value) { @@ -27,6 +67,12 @@ bool envValueIsTrue(const char* value) { || strcasecmp(value, "yes") == 0); } +bool envValueIsFalse(const char* value) { + return value != nullptr + && (strcmp(value, "0") == 0 || strcasecmp(value, "false") == 0 || strcasecmp(value, "off") == 0 + || strcasecmp(value, "no") == 0); +} + bool prefillTraceLogEnabled() { static const bool enabled = []() { const char* value = std::getenv("PREFILL_TRACE_LOG_ENABLE"); @@ -41,6 +87,437 @@ bool prefillTraceLogEnabled() { return enabled; } +bool prefillCacheDebugLogEnabled() { + const bool enabled = prefillTraceLogEnabled(); + return enabled; +} + +bool prefillTheoryHitLogEnabled() { + static const bool enabled = []() { + const char* value = std::getenv("PREFILL_THEORY_HIT_LOG_ENABLED"); + if (value == nullptr || value[0] == 0) { + value = std::getenv("PREFILL_THEORY_HIT_LOG_ENABLE"); + } + if (value == nullptr || value[0] == 0) { + return true; + } + return !envValueIsFalse(value); + }(); + return enabled; +} + +const char* prefillTheoryHitLogPath() { + const char* value = std::getenv("PREFILL_THEORY_HIT_LOG_PATH"); + if (value == nullptr || value[0] == 0) { + return "/home/admin/logs/prefill_theory_hit.log"; + } + return value; +} + +double theoryHitRatio(int64_t hit_count, int64_t total_count) { + return total_count > 0 ? static_cast(hit_count) / static_cast(total_count) : 0.0; +} + +struct TheoryHitWindowSnapshot { + const char* label = ""; + int64_t window_ms = 0; + int64_t hit_count = 0; + int64_t total_count = 0; + double hit_ratio = 0.0; +}; + +struct TheoryHitStatsSnapshot { + int64_t now_ms = 0; + int64_t request_hit_count = 0; + int64_t request_total_count = 0; + double request_hit_ratio = 0.0; + int64_t all_hit_count = 0; + int64_t all_total_count = 0; + double all_hit_ratio = 0.0; + TheoryHitWindowSnapshot window_1m; + TheoryHitWindowSnapshot window_5m; + TheoryHitWindowSnapshot window_10m; + TheoryHitWindowSnapshot window_15m; +}; + +void markResponseEntryDone(const std::shared_ptr& entry, const grpc::Status& status) { + if (!entry) { + return; + } + { + std::lock_guard lock(entry->mu); + if (!status.ok()) { + entry->error_status = status; + } + entry->done.store(true); + entry->last_activity_us = currentTimeUs(); + entry->cancel_producer = nullptr; + } + entry->cv.notify_all(); +} + +grpc::Status statusFromErrorInfo(const ErrorInfo& error_info) { + if (!error_info.hasError()) { + return grpc::Status::OK; + } + return grpc::Status(grpc::StatusCode::INTERNAL, error_info.ToString()); +} + +void addBatchSuccess(EnqueueBatchResponsePB* response, int64_t request_id) { + auto* success = response->add_successes(); + success->set_request_id(request_id); +} + +void addBatchError(EnqueueBatchResponsePB* response, int64_t request_id, int64_t code, const std::string& msg) { + auto* error = response->add_errors(); + error->set_request_id(request_id); + auto* error_info = error->mutable_error_info(); + error_info->set_error_code(code); + error_info->set_error_message(msg); +} + +// Helper to detect whether a Future (std::future or autil Future) is ready. +// Uses SFINAE: std::future::wait_for returns std::future_status; autil +// Future::wait_for also returns std::future_status. +template +bool futureIsReady(FutureT& f) { + return f.valid() && f.wait_for(std::chrono::milliseconds(0)) == std::future_status::ready; +} + +template +void detachLeftoverFutures(std::vector& futures) { + // Clear leftover futures naturally; no longer creates a throwaway thread. + // Futures that are not yet ready are simply abandoned — the thread pool + // owns the actual work; the future object is just a handle. + futures.clear(); +} + +template +void drainReadyFutures(std::vector& futures, std::chrono::milliseconds timeout) { + auto deadline = std::chrono::steady_clock::now() + timeout; + while (std::chrono::steady_clock::now() < deadline) { + bool all_done = true; + bool any_ready = false; + for (auto& f : futures) { + if (f.valid()) { + if (futureIsReady(f)) { + try { + f.get(); + } catch (...) {} + any_ready = true; + } else { + all_done = false; + } + } + } + if (all_done) + break; + if (!any_ready) { + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + } + } +} + +template +void collectFutures(std::vector& futures, + std::chrono::steady_clock::time_point deadline, + OnReady&& on_ready, + OnTimeout&& on_timeout) { + std::vector collected(futures.size(), false); + size_t remaining = futures.size(); + while (remaining > 0 && std::chrono::steady_clock::now() < deadline) { + bool any_ready = false; + for (size_t i = 0; i < futures.size(); ++i) { + if (!collected[i] && futureIsReady(futures[i])) { + collected[i] = true; + --remaining; + any_ready = true; + on_ready(i); + } + } + if (remaining > 0 && !any_ready) { + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + } + } + for (size_t i = 0; i < futures.size(); ++i) { + if (!collected[i]) { + if (futureIsReady(futures[i])) { + on_ready(i); + } else { + on_timeout(i); + } + } + } +} + +} // namespace + +struct AsyncProducerCancelState { + std::atomic cancelled{false}; + std::mutex mu; + std::weak_ptr client_context; + std::weak_ptr stream; +}; + +std::function makeAsyncProducerCancelCallback(const std::shared_ptr& state) { + return [state] { + bool expected = false; + if (!state->cancelled.compare_exchange_strong(expected, true)) { + return; + } + + std::shared_ptr client_context; + std::shared_ptr stream; + { + std::lock_guard lock(state->mu); + client_context = state->client_context.lock(); + stream = state->stream.lock(); + } + if (client_context) { + client_context->TryCancel(); + } + if (stream) { + stream->reportError(ErrorCode::CANCELLED, "request cancelled"); + } + }; +} + +void refreshAsyncProducerCancelState(const std::shared_ptr& state, + const std::shared_ptr& client_context, + const std::shared_ptr& stream) { + bool should_cancel = false; + { + std::lock_guard lock(state->mu); + state->client_context = client_context; + state->stream = stream; + should_cancel = state->cancelled.load(); + } + if (should_cancel) { + if (client_context) { + client_context->TryCancel(); + } + if (stream) { + stream->reportError(ErrorCode::CANCELLED, "request cancelled"); + } + } +} + +class ScopeExit { +public: + explicit ScopeExit(std::function fn): fn_(std::move(fn)) {} + ~ScopeExit() { + if (fn_) { + fn_(); + } + } + ScopeExit(const ScopeExit&) = delete; + ScopeExit& operator=(const ScopeExit&) = delete; + +private: + std::function fn_; +}; + +void cancelResponseEntry(const std::shared_ptr& entry) { + if (!entry) { + return; + } + std::function cancel_producer; + { + std::lock_guard lock(entry->mu); + entry->cancelled.store(true); + entry->last_activity_us = currentTimeUs(); + cancel_producer = entry->cancel_producer; + entry->cancel_producer = nullptr; + } + if (cancel_producer) { + cancel_producer(); + } + entry->cv.notify_all(); +} + +namespace { + +class TheoryHitStats { +public: + TheoryHitStats() { + bucket_seconds_.fill(std::numeric_limits::min()); + bucket_hit_counts_.fill(0); + bucket_total_counts_.fill(0); + } + + TheoryHitStatsSnapshot record(int64_t hit_count, int64_t total_count) { + std::lock_guard lock(mutex_); + const int64_t now_ms = autil::TimeUtility::currentTimeInMilliSeconds(); + const int64_t current_second = now_ms / 1000; + const int64_t safe_hit = std::max(0, hit_count); + const int64_t safe_total = std::max(0, total_count); + + if (safe_total > 0) { + const size_t index = static_cast(current_second % kBucketCount); + if (bucket_seconds_[index] != current_second) { + bucket_seconds_[index] = current_second; + bucket_hit_counts_[index] = 0; + bucket_total_counts_[index] = 0; + } + bucket_hit_counts_[index] += safe_hit; + bucket_total_counts_[index] += safe_total; + all_hit_count_ += safe_hit; + all_total_count_ += safe_total; + } + + TheoryHitStatsSnapshot snapshot; + snapshot.now_ms = now_ms; + snapshot.request_hit_count = safe_hit; + snapshot.request_total_count = safe_total; + snapshot.request_hit_ratio = theoryHitRatio(safe_hit, safe_total); + snapshot.all_hit_count = all_hit_count_; + snapshot.all_total_count = all_total_count_; + snapshot.all_hit_ratio = theoryHitRatio(all_hit_count_, all_total_count_); + snapshot.window_1m = windowSnapshot("1m", 60 * 1000, current_second); + snapshot.window_5m = windowSnapshot("5m", 5 * 60 * 1000, current_second); + snapshot.window_10m = windowSnapshot("10m", 10 * 60 * 1000, current_second); + snapshot.window_15m = windowSnapshot("15m", 15 * 60 * 1000, current_second); + return snapshot; + } + +private: + static constexpr size_t kBucketCount = 15 * 60 + 2; + + TheoryHitWindowSnapshot windowSnapshot(const char* label, int64_t window_ms, int64_t current_second) const { + TheoryHitWindowSnapshot snapshot; + snapshot.label = label; + snapshot.window_ms = window_ms; + const int64_t window_seconds = window_ms / 1000; + for (size_t i = 0; i < kBucketCount; ++i) { + const int64_t age_seconds = current_second - bucket_seconds_[i]; + if (age_seconds >= 0 && age_seconds < window_seconds) { + snapshot.hit_count += bucket_hit_counts_[i]; + snapshot.total_count += bucket_total_counts_[i]; + } + } + snapshot.hit_ratio = theoryHitRatio(snapshot.hit_count, snapshot.total_count); + return snapshot; + } + +private: + std::array bucket_seconds_; + std::array bucket_hit_counts_; + std::array bucket_total_counts_; + int64_t all_hit_count_ = 0; + int64_t all_total_count_ = 0; + std::mutex mutex_; +}; + +std::string formatTheoryTimestampMs(int64_t timestamp_ms) { + const time_t seconds = static_cast(timestamp_ms / 1000); + struct tm local_time; + if (localtime_r(&seconds, &local_time) == nullptr) { + return std::to_string(timestamp_ms); + } + + char date_buffer[64]; + if (strftime(date_buffer, sizeof(date_buffer), "%Y-%m-%dT%H:%M:%S", &local_time) == 0) { + return std::to_string(timestamp_ms); + } + + char offset_buffer[16]; + if (strftime(offset_buffer, sizeof(offset_buffer), "%z", &local_time) == 0) { + offset_buffer[0] = '\0'; + } + + char output[96]; + snprintf(output, sizeof(output), "%s.%03ld%s", date_buffer, static_cast(timestamp_ms % 1000), offset_buffer); + return output; +} + +void appendPrefillTheoryHitLogLine(const std::string& line) { + if (!prefillTheoryHitLogEnabled()) { + return; + } + static std::mutex log_mutex; + static std::ofstream log_file; + static bool open_failed = false; + + std::lock_guard lock(log_mutex); + if (!log_file.is_open() && !open_failed) { + const char* path = prefillTheoryHitLogPath(); + log_file.open(path, std::ios::out | std::ios::app); + if (!log_file.is_open()) { + open_failed = true; + RTP_LLM_LOG_WARNING("Failed to open prefill theory hit log path: %s", path); + return; + } + RTP_LLM_LOG_INFO("Prefill theory hit log path: %s", path); + } + if (!log_file.is_open()) { + return; + } + log_file << line << '\n'; + log_file.flush(); +} + +std::string formatPrefillTheoryHitLogLine(PrefillGenerateContext& prefill_context, + int64_t token_num, + int seq_size_per_block, + const TheoryHitStatsSnapshot& snapshot) { + std::ostringstream oss; + oss << std::fixed << std::setprecision(6) << "time=" << formatTheoryTimestampMs(snapshot.now_ms) + << " ts_ms=" << snapshot.now_ms << " source=prefill" + << " request_id=" << prefill_context.request_id << " request_key=" << prefill_context.request_key + << " token_num=" << token_num << " seq_size_per_block=" << seq_size_per_block + << " request_hit=" << snapshot.request_hit_count << " request_total=" << snapshot.request_total_count + << " request_ratio=" << snapshot.request_hit_ratio << " all_hit=" << snapshot.all_hit_count + << " all_total=" << snapshot.all_total_count << " all_ratio=" << snapshot.all_hit_ratio + << " win1m_hit=" << snapshot.window_1m.hit_count << " win1m_total=" << snapshot.window_1m.total_count + << " win1m_ratio=" << snapshot.window_1m.hit_ratio << " win5m_hit=" << snapshot.window_5m.hit_count + << " win5m_total=" << snapshot.window_5m.total_count << " win5m_ratio=" << snapshot.window_5m.hit_ratio + << " win10m_hit=" << snapshot.window_10m.hit_count << " win10m_total=" << snapshot.window_10m.total_count + << " win10m_ratio=" << snapshot.window_10m.hit_ratio << " win15m_hit=" << snapshot.window_15m.hit_count + << " win15m_total=" << snapshot.window_15m.total_count << " win15m_ratio=" << snapshot.window_15m.hit_ratio; + return oss.str(); +} + +std::string cacheKeyPreview(const std::vector& keys, size_t limit = 6) { + std::ostringstream oss; + oss << "["; + for (size_t i = 0; i < keys.size() && i < limit; ++i) { + if (i != 0) { + oss << ","; + } + oss << keys[i]; + } + if (keys.size() > limit) { + oss << ",..."; + } + oss << "]"; + return oss.str(); +} + +std::string cacheKeysToString(const std::vector& keys) { + std::ostringstream oss; + oss << "["; + for (size_t i = 0; i < keys.size(); ++i) { + if (i != 0) { + oss << ","; + } + oss << keys[i]; + } + oss << "]"; + return oss.str(); +} + +std::string cacheKeyDigest(const std::vector& keys) { + uint64_t digest = 14695981039346656037ULL; + for (const auto cache_key : keys) { + uint64_t value = static_cast(cache_key); + digest ^= value; + digest *= 1099511628211ULL; + digest ^= value >> 32; + digest *= 1099511628211ULL; + } + return std::to_string(digest); +} + const char* prefillStageName(PrefillStatInfo::ExecuteStage stage) { switch (stage) { case PrefillStatInfo::start: @@ -91,6 +568,64 @@ void logPrefillFailureTrace(const char* event, PrefillGenerateContext& prefill_c prefill_context.error_info.ToString().c_str()); } +std::vector buildFullBlockCacheKeys(torch::Tensor input_ids, int seq_size_per_block) { + std::vector cache_keys; + if (seq_size_per_block <= 0 || !input_ids.defined() || input_ids.numel() <= 0) { + return cache_keys; + } + + if (!input_ids.device().is_cpu()) { + input_ids = input_ids.cpu(); + } + if (!input_ids.is_contiguous()) { + input_ids = input_ids.contiguous(); + } + if (input_ids.scalar_type() != torch::kInt32) { + input_ids = input_ids.to(torch::kInt32); + } + + const int64_t token_num = input_ids.numel(); + const int64_t block_count = token_num / seq_size_per_block; + if (block_count <= 0) { + return cache_keys; + } + cache_keys.reserve(static_cast(block_count)); + + auto* token_ids = input_ids.data_ptr(); + int64_t rolling_hash = 0; + for (int64_t block_idx = 0; block_idx < block_count; ++block_idx) { + const int64_t pos = block_idx * seq_size_per_block; + rolling_hash = rtp_llm::hashInt64Array( + rolling_hash, token_ids + pos, token_ids + pos + static_cast(seq_size_per_block)); + cache_keys.push_back(static_cast(rolling_hash)); + } + return cache_keys; +} + +void fillPrefillRecentCacheKeyMetricsCollector(PrefillRecentCacheKeyMetricsCollector& collector, + const RecentCacheKeyWindow::Snapshot& snapshot) { + collector.has_value = true; + collector.request_count = true; + collector.empty_request_count = snapshot.request_occurrences == 0; + collector.hit_count = snapshot.request_hit_occurrences; + collector.total_count = snapshot.request_occurrences; + collector.hit_ratio = snapshot.request_hit_ratio; + collector.retained_occurrences = snapshot.retained_occurrences; + collector.retained_unique_cache_keys = static_cast(snapshot.retained_unique_cache_keys); + collector.time_window_ms = snapshot.time_window_ms; +} + +void fillPrefillTheoryHitMetricsCollector(PrefillRecentCacheKeyMetricsCollector& collector, + const TheoryHitStatsSnapshot& snapshot) { + if (snapshot.all_total_count <= 0) { + return; + } + collector.theory_has_value = true; + collector.theory_all_hit_count = snapshot.all_hit_count; + collector.theory_all_total_count = snapshot.all_total_count; + collector.theory_all_hit_ratio = snapshot.all_hit_ratio; +} + } // namespace #define CLIENT_GRPC_RET_IF_ERROR(prefill_context, state, error_code_value) \ @@ -152,6 +687,96 @@ void logPrefillFailureTrace(const char* event, PrefillGenerateContext& prefill_c return; \ } +void PrefillRpcServer::startResponseRegistryGc() { + if (response_gc_thread_.joinable()) { + return; + } + response_gc_stop_.store(false); + response_gc_thread_ = std::thread([this] { + std::unique_lock lock(response_gc_mu_); + int gc_counter = 0; + while (!response_gc_stop_.load()) { + response_gc_cv_.wait_for(lock, std::chrono::seconds(10), [this] { return response_gc_stop_.load(); }); + if (response_gc_stop_.load()) { + break; + } + lock.unlock(); + reportPoolMetrics(); + gc_counter++; + if (gc_counter >= 6) { // GC every 60 seconds + response_registry_.gc(std::chrono::minutes(10)); + gc_counter = 0; + } + lock.lock(); + } + }); +} + +void PrefillRpcServer::stopResponseRegistryGc() { + response_gc_stop_.store(true); + response_gc_cv_.notify_all(); + if (response_gc_thread_.joinable()) { + response_gc_thread_.join(); + } +} + +bool PrefillRpcServer::tryStartAsyncResponseWorker() { + std::lock_guard lock(response_worker_mu_); + if (response_worker_stop_.load()) { + return false; + } + ++response_worker_count_; + return true; +} + +void PrefillRpcServer::finishAsyncResponseWorker() { + { + std::lock_guard lock(response_worker_mu_); + if (response_worker_count_ > 0) { + --response_worker_count_; + } + } + response_worker_cv_.notify_all(); +} + +void PrefillRpcServer::stopAsyncResponseWorkers() { + { + std::lock_guard lock(response_worker_mu_); + response_worker_stop_.store(true); + } + response_registry_.cancelAll(); + + static constexpr auto kStopTimeout = std::chrono::seconds(30); + std::unique_lock lock(response_worker_mu_); + bool all_done = response_worker_cv_.wait_for(lock, kStopTimeout, [this] { return response_worker_count_ == 0; }); + + if (!all_done) { + RTP_LLM_LOG_WARNING("stopAsyncResponseWorkers: timeout after %lds, still %zu workers active. Force resetting.", + kStopTimeout.count(), + response_worker_count_); + response_worker_count_ = 0; + // Notify other waiters that we've force-reset + response_worker_cv_.notify_all(); + } +} + +std::string PrefillRpcServer::batchTargetAddrForDpRank(int dp_rank) const { + if (dp_rank < 0 || dp_rank >= maga_init_params_.parallelism_config.dp_size) { + return ""; + } + const auto& all_workers = maga_init_params_.runtime_config.all_worker_grpc_addrs; + const int64_t tp_size = std::max(1, maga_init_params_.parallelism_config.tp_size); + const int64_t world_rank = static_cast(dp_rank) * tp_size; + if (world_rank >= 0 && world_rank < static_cast(all_workers.size())) { + return all_workers[world_rank]; + } + if (dp_rank == maga_init_params_.parallelism_config.dp_rank + && !maga_init_params_.runtime_config.worker_grpc_addrs.empty()) { + return maga_init_params_.runtime_config.worker_grpc_addrs.front(); + } + return ""; +} + grpc::Status PrefillRpcServer::init(const EngineInitParams& maga_init_params, py::object mm_process_engine, std::unique_ptr propose_params) { @@ -161,9 +786,93 @@ grpc::Status PrefillRpcServer::init(const EngineInitParams& if (!ret.ok()) { return ret; } + initThreadPools(); + if (PrefillCacheHitMetricsReporter::enabled()) { + prefill_recent_cache_key_window_ = std::make_unique(); + } else { + RTP_LLM_LOG_INFO("prefill recent-cache-key metrics disabled by PREFILL_CACHE_HIT_METRIC_ENABLE"); + } + startResponseRegistryGc(); return grpc::Status::OK; } +void PrefillRpcServer::initThreadPools() { + const auto& parallelism_config = maga_init_params_.parallelism_config; + const auto& scheduler_config = maga_init_params_.runtime_config.fifo_scheduler_config; + const auto& pd_sep_config = maga_init_params_.pd_sep_config; + const int dp_size = std::max(1, static_cast(parallelism_config.dp_size)); + const int max_context_batch = std::max(1, static_cast(scheduler_config.max_context_batch_size)); + + // enqueue pool: L1 DP dispatch only (fast, ms-level, must never block) + // Configurable via pd_sep_config.prefill_enqueue_pool_size (0 = use formula default) + const int enqueue_threads = pd_sep_config.prefill_enqueue_pool_size > 0 ? + static_cast(pd_sep_config.prefill_enqueue_pool_size) : + std::max(4, dp_size * (dp_size <= 4 ? 4 : 2)); + const int enqueue_queue = enqueue_threads * 2; + + enqueue_worker_pool_ = + std::make_shared(enqueue_threads, enqueue_queue, nullptr, "PrefillEnqueuePool"); + RTP_LLM_CHECK_WITH_INFO(enqueue_worker_pool_->start(), "PrefillRpcServer enqueue thread pool start failed"); + RTP_LLM_LOG_INFO("PrefillRpcServer enqueue pool started: threads=%d queue=%d", enqueue_threads, enqueue_queue); + + // worker lambda pool: heavy EnqueueGroup coordination (I/O-bound, ~12s per batch) + // Configurable via pd_sep_config.prefill_worker_lambda_pool_size (0 = use formula default) + const int worker_lambda_threads = pd_sep_config.prefill_worker_lambda_pool_size > 0 ? + static_cast(pd_sep_config.prefill_worker_lambda_pool_size) : + std::max(4, dp_size * max_context_batch * 4); + const int worker_lambda_queue = worker_lambda_threads * 4; + + worker_lambda_pool_ = std::make_shared( + worker_lambda_threads, worker_lambda_queue, nullptr, "PrefillWorkerPool"); + RTP_LLM_CHECK_WITH_INFO(worker_lambda_pool_->start(), "PrefillRpcServer worker lambda pool start failed"); + RTP_LLM_LOG_INFO( + "PrefillRpcServer worker lambda pool started: threads=%d queue=%d (dp_size=%d max_context_batch=%d)", + worker_lambda_threads, + worker_lambda_queue, + dp_size, + max_context_batch); + + // slot pool: L2 Prepare + L3 Load + L4 Finish + // Configurable via pd_sep_config.prefill_slot_pool_size (0 = use formula default) + const int slot_threads = pd_sep_config.prefill_slot_pool_size > 0 ? + static_cast(pd_sep_config.prefill_slot_pool_size) : + std::max(16, std::min(max_context_batch * 16, 128)); + const int slot_queue = slot_threads * 8; + + slot_worker_pool_ = + std::make_shared(slot_threads, slot_queue, nullptr, "PrefillSlotPool"); + RTP_LLM_CHECK_WITH_INFO(slot_worker_pool_->start(), "PrefillRpcServer slot thread pool start failed"); + RTP_LLM_LOG_INFO("PrefillRpcServer slot pool started: threads=%d queue=%d (dp_size=%d max_context_batch=%d)", + slot_threads, + slot_queue, + dp_size, + max_context_batch); +} + +void PrefillRpcServer::reportPoolMetrics() { + // Periodically log pool health (called every 10s from GC thread) + RTP_LLM_LOG_INFO("PoolMetrics enqueue: active=%zu queued=%zu completed=%zu rejected=%zu fallback=%zu", + enqueue_pool_metrics_.active.load(), + enqueue_pool_metrics_.queued.load(), + enqueue_pool_metrics_.completed.load(), + enqueue_pool_metrics_.rejected.load(), + enqueue_pool_metrics_.fallback.load()); + RTP_LLM_LOG_INFO("PoolMetrics worker_lambda: active=%zu queued=%zu completed=%zu rejected=%zu fallback=%zu", + worker_lambda_pool_metrics_.active.load(), + worker_lambda_pool_metrics_.queued.load(), + worker_lambda_pool_metrics_.completed.load(), + worker_lambda_pool_metrics_.rejected.load(), + worker_lambda_pool_metrics_.fallback.load()); + RTP_LLM_LOG_INFO( + "PoolMetrics slot: active=%zu queued=%zu completed=%zu rejected=%zu fallback=%zu response_workers=%zu", + slot_pool_metrics_.active.load(), + slot_pool_metrics_.queued.load(), + slot_pool_metrics_.completed.load(), + slot_pool_metrics_.rejected.load(), + slot_pool_metrics_.fallback.load(), + response_worker_count_); +} + ErrorInfo PrefillRpcServer::waitStreamBeforeRun(std::shared_ptr stream) { static int max_wait_timeout_us = maga_init_params_.pd_sep_config.prefill_max_wait_timeout_ms * 1000; auto begin_time_us = currentTimeUs(); @@ -285,6 +994,10 @@ void PrefillRpcServer::remoteAllocateResource(PrefillGenerateContext& prefill_co auto deadline = std::chrono::system_clock::now() + std::chrono::milliseconds(final_timeout_ms); prefill_context.client_context->set_deadline(deadline); } + if (prefill_context.cancel_state) { + refreshAsyncProducerCancelState( + prefill_context.cancel_state, prefill_context.client_context, prefill_context.getStream()); + } // final_timeout_ms <= 0: skip set_deadline; gRPC treats it as no deadline. prefill_context.client_stream = std::move(prefill_context.grpc_connection.stub->RemoteGenerate(prefill_context.client_context.get())); @@ -295,6 +1008,9 @@ void PrefillRpcServer::remoteAllocateResource(PrefillGenerateContext& prefill_co alloc_request.set_request_id(prefill_context.request_id); // TODO(xinfei.sxf) reduce copy GenerateInputPB* new_request = new GenerateInputPB(*prefill_context.rpc_context.request); + new_request->clear_group_size(); + new_request->clear_group_id(); + new_request->mutable_generate_config()->clear_group_timeout(); alloc_request.set_allocated_input(new_request); for (auto& addrs : prefill_context.prefill_worker_cache_store_addrs) { alloc_request.add_peer_addrs(addrs); @@ -329,6 +1045,10 @@ void PrefillRpcServer::enqueueRequest(PrefillGenerateContext& prefill_context) { RTP_LLM_LOG_DEBUG("request [%ld] trans to stream success", prefill_context.request_id); auto stream = engine_->enqueue(prefill_context.generate_input); prefill_context.setStream(stream); + if (prefill_context.cancel_state) { + refreshAsyncProducerCancelState( + prefill_context.cancel_state, prefill_context.client_context, prefill_context.getStream()); + } RTP_LLM_LOG_DEBUG("request [%ld] enqueue success", prefill_context.request_id); } @@ -459,7 +1179,7 @@ void PrefillRpcServer::pollRemoteOutput(PrefillGenerateContext& prefill_context) auto first_token_rt_us = prefill_context.getStream()->getTimeInfo().first_token_rt_us; while (prefill_context.client_stream->Read(&response)) { - if (prefill_context.server_context->IsCancelled()) { + if (prefill_context.server_context && prefill_context.server_context->IsCancelled()) { RTP_LLM_LOG_WARNING("request [%ld] cancel by user", request_id); prefill_context.error_status = grpc::Status(grpc::StatusCode::CANCELLED, "request cancelled"); return; @@ -519,6 +1239,99 @@ grpc::Status PrefillRpcServer::prepareAllocateResource(PrefillGenerateContext& p return grpc::Status::OK; } +void PrefillRpcServer::reportPrefillRecentCacheKeyMetricsOnce(PrefillGenerateContext& prefill_context) { + RTP_LLM_PROFILE_FUNCTION(); + if (prefill_context.recent_cache_key_metric_reported) { + return; + } + if (!PrefillCacheHitMetricsReporter::enabled()) { + return; + } + if (!prefill_recent_cache_key_window_) { + return; + } + if (!prefill_context.generate_input) { + return; + } + prefill_context.recent_cache_key_metric_reported = true; + + const int seq_size_per_block = maga_init_params_.kv_cache_config.seq_size_per_block; + auto cache_keys = buildFullBlockCacheKeys(prefill_context.generate_input->input_ids, seq_size_per_block); + auto snapshot = prefill_recent_cache_key_window_->record(cache_keys); + static TheoryHitStats theory_stats; + auto theory_snapshot = theory_stats.record(snapshot.request_hit_occurrences, snapshot.request_occurrences); + if (theory_snapshot.request_total_count > 0) { + appendPrefillTheoryHitLogLine(formatPrefillTheoryHitLogLine( + prefill_context, prefill_context.generate_input->input_ids.numel(), seq_size_per_block, theory_snapshot)); + } + + if (metrics_reporter_) { + PrefillRecentCacheKeyMetricsCollector collector; + fillPrefillRecentCacheKeyMetricsCollector(collector, snapshot); + fillPrefillTheoryHitMetricsCollector(collector, theory_snapshot); + metrics_reporter_->report(nullptr, + &collector); + } + + if (prefillCacheDebugLogEnabled()) { + auto key_digest = cacheKeyDigest(cache_keys); + auto key_text = cacheKeysToString(cache_keys); + RTP_LLM_LOG_INFO("Prefill cache-key trace: request_id=%ld request_key=%s token_num=%ld seq_size_per_block=%d " + "key_count=%zu hit_count=%ld total_count=%ld hit_ratio=%.6f cache_key_digest=%s " + "retained_occurrences=%ld retained_unique_cache_keys=%zu window_ms=%ld cache_keys=%s", + prefill_context.request_id, + prefill_context.request_key.c_str(), + prefill_context.generate_input->input_ids.numel(), + seq_size_per_block, + cache_keys.size(), + snapshot.request_hit_occurrences, + snapshot.request_occurrences, + snapshot.request_hit_ratio, + key_digest.c_str(), + snapshot.retained_occurrences, + snapshot.retained_unique_cache_keys, + snapshot.time_window_ms, + key_text.c_str()); + RTP_LLM_LOG_INFO("Prefill cache-key preview trace: request_id=%ld cache_key_digest=%s keys_preview=%s", + prefill_context.request_id, + key_digest.c_str(), + cacheKeyPreview(cache_keys).c_str()); + } +} + +grpc::Status PrefillRpcServer::syncPrefix(PrefillGenerateContext& prefill_context) { + auto max_retry_times = maga_init_params_.pd_sep_config.prefill_retry_times; + auto max_retry_timeout_ms = maga_init_params_.pd_sep_config.prefill_retry_timeout_ms; + int retry_interval_ms = 1; + + EXECUTE_WITH_RETRY( + prepareAllocateResource, prefill_context, max_retry_times, max_retry_timeout_ms, retry_interval_ms); + if (prefill_context.hasError()) { + logPrefillFailureTrace("prepare_allocate_failed", prefill_context); + RTP_LLM_LOG_WARNING( + "request [%ld] prepare allocate resource failed after retry [%d] times, cost time ms [%ld], " + "max retry time [%ld], max retry timeout ms [%ld]", + prefill_context.request_id, + prefill_context.retry_times, + prefill_context.retry_cost_time_ms, + max_retry_times + 1, + max_retry_timeout_ms); + return prefill_context.error_status; + } + EXECUTE_STAGE_FUNC(enqueueRequest, prefill_context); + EXECUTE_STAGE_FUNC(remoteLoadCacheStart, prefill_context); + return grpc::Status::OK; +} + +grpc::Status PrefillRpcServer::finishStream(PrefillGenerateContext& prefill_context) { + EXECUTE_STAGE_FUNC(pollLocalOutput, prefill_context); + EXECUTE_STAGE_FUNC(remoteLoadCacheEnd, prefill_context); + EXECUTE_STAGE_FUNC(remoteGenerate, prefill_context); + EXECUTE_STAGE_FUNC(pollRemoteOutput, prefill_context); + prefill_context.stat_info.nextStage(); + return grpc::Status::OK; +} + grpc::Status PrefillRpcServer::GenerateStreamCall(grpc::ServerContext* server_context, const GenerateInputPB* request, grpc::ServerWriter* writer) { @@ -561,32 +1374,15 @@ grpc::Status PrefillRpcServer::GenerateStreamCall(grpc::ServerContext* prefill_context.onflight_requests = onflight_requests_; prefill_context.loading_cache_requests = loading_cache_requests_; - auto max_retry_times = maga_init_params_.pd_sep_config.prefill_retry_times; - auto max_retry_timeout_ms = maga_init_params_.pd_sep_config.prefill_retry_timeout_ms; - int retry_interval_ms = 1; - try { - EXECUTE_WITH_RETRY( - prepareAllocateResource, prefill_context, max_retry_times, max_retry_timeout_ms, retry_interval_ms); - if (prefill_context.hasError()) { - logPrefillFailureTrace("prepare_allocate_failed", prefill_context); - RTP_LLM_LOG_WARNING( - "request [%ld] prepare allocate resource failed after retry [%d] times, cost time ms [%ld], " - "max retry time [%ld], max retry timeout ms [%ld]", - prefill_context.request_id, - prefill_context.retry_times, - prefill_context.retry_cost_time_ms, - max_retry_times + 1, - max_retry_timeout_ms); - return prefill_context.error_status; - } - EXECUTE_STAGE_FUNC(enqueueRequest, prefill_context); - EXECUTE_STAGE_FUNC(remoteLoadCacheStart, prefill_context); - EXECUTE_STAGE_FUNC(pollLocalOutput, prefill_context); - EXECUTE_STAGE_FUNC(remoteLoadCacheEnd, prefill_context); - EXECUTE_STAGE_FUNC(remoteGenerate, prefill_context); - EXECUTE_STAGE_FUNC(pollRemoteOutput, prefill_context); - prefill_context.stat_info.nextStage(); + auto status = syncPrefix(prefill_context); + if (!status.ok()) { + return status; + } + status = finishStream(prefill_context); + if (!status.ok()) { + return status; + } } catch (const std::exception& e) { auto error_msg = "request [" + prefill_context.request_key + "] catch exception [" + e.what() + "]"; prefill_context.error_status = grpc::Status(grpc::StatusCode::INTERNAL, error_msg); @@ -604,6 +1400,796 @@ grpc::Status PrefillRpcServer::GenerateStreamCall(grpc::ServerContext* return grpc::Status::OK; } +grpc::Status PrefillRpcServer::EnqueueBatch(grpc::ServerContext* context, + const EnqueueBatchRequestPB* request, + EnqueueBatchResponsePB* response) { + RTP_LLM_PROFILE_FUNCTION(); + response->set_batch_id(request->batch_id()); + + struct TargetBatch { + int dp_rank = 0; + std::vector inputs; + }; + + std::map targets; + std::vector all_inputs; + std::unordered_set seen_request_ids; + bool duplicate_request_id = false; + + for (const auto& slot : request->dp_slots()) { + auto& target = targets[slot.dp_rank()]; + target.dp_rank = slot.dp_rank(); + for (const auto& external_input : slot.requests()) { + if (!external_input.has_input()) { + addBatchError(response, + /*request_id=*/0, + grpc::StatusCode::INVALID_ARGUMENT, + "EnqueueBatch external request missing input"); + continue; + } + const auto& input = external_input.input(); + all_inputs.push_back(&input); + target.inputs.push_back(&input); + if (!seen_request_ids.insert(input.request_id()).second) { + duplicate_request_id = true; + } + } + } + + response->mutable_successes()->Reserve(static_cast(all_inputs.size())); + response->mutable_errors()->Reserve(static_cast(all_inputs.size())); + + auto add_error_for_inputs = [](EnqueueBatchResponsePB* response, + const std::vector& inputs, + int64_t code, + const std::string& message) { + for (const auto* input : inputs) { + if (input) { + addBatchError(response, input->request_id(), code, message); + } + } + }; + + if (duplicate_request_id) { + response->clear_errors(); + add_error_for_inputs( + response, all_inputs, grpc::StatusCode::ALREADY_EXISTS, "duplicate request_id in EnqueueBatch"); + return grpc::Status::OK; + } + + if (context && context->IsCancelled()) { + add_error_for_inputs(response, all_inputs, grpc::StatusCode::CANCELLED, "EnqueueBatch cancelled by caller"); + return grpc::Status(grpc::StatusCode::CANCELLED, "EnqueueBatch cancelled by caller"); + } + + struct DispatchTarget { + int dp_rank = 0; + std::string addr; + EnqueueGroupRequestPB request; + }; + + const int local_dp_rank = static_cast(maga_init_params_.parallelism_config.dp_rank); + std::vector dispatch_targets; + dispatch_targets.reserve(targets.size()); + for (const auto& pair : targets) { + const auto& target = pair.second; + if (target.inputs.empty()) { + continue; + } + DispatchTarget dispatch_target; + dispatch_target.dp_rank = target.dp_rank; + dispatch_target.request.set_batch_id(request->batch_id()); + dispatch_target.request.set_dp_rank(target.dp_rank); + for (const auto* input : target.inputs) { + auto* dp_input = dispatch_target.request.add_requests(); + dp_input->mutable_input()->CopyFrom(*input); + } + if (target.dp_rank != local_dp_rank) { + dispatch_target.addr = batchTargetAddrForDpRank(target.dp_rank); + if (dispatch_target.addr.empty()) { + add_error_for_inputs(response, + target.inputs, + grpc::StatusCode::INVALID_ARGUMENT, + "invalid EnqueueBatch dp_rank " + std::to_string(target.dp_rank)); + continue; + } + } + dispatch_targets.push_back(std::move(dispatch_target)); + } + + struct DispatchResult { + grpc::Status status; + EnqueueBatchResponsePB dp_response; + }; + + const auto dispatch_timeout_ms = maga_init_params_.pd_sep_config.batch_dispatch_timeout_ms; + const auto dispatch_deadline = std::chrono::steady_clock::now() + std::chrono::milliseconds(dispatch_timeout_ms); + std::vector> dispatch_futures; + dispatch_futures.reserve(dispatch_targets.size()); + for (auto& target : dispatch_targets) { + dispatch_futures.push_back(enqueue_worker_pool_->async([this, target = std::move(target)]() -> DispatchResult { + DispatchResult result; + if (target.dp_rank == static_cast(maga_init_params_.parallelism_config.dp_rank)) { + result.status = EnqueueGroup(/*context=*/nullptr, &target.request, &result.dp_response); + return result; + } + + try { + auto connect_status = resource_.rpc_pool.getConnection(target.addr); + if (!connect_status.ok()) { + result.status = grpc::Status(grpc::StatusCode::UNAVAILABLE, + "get EnqueueGroup connection failed: " + + std::string(connect_status.status().message())); + } else { + grpc::ClientContext client_context; + auto timeout_ms = maga_init_params_.pd_sep_config.max_rpc_timeout_ms; + if (timeout_ms > 0) { + client_context.set_deadline(std::chrono::system_clock::now() + + std::chrono::milliseconds(timeout_ms)); + } + result.status = + connect_status.value().stub->EnqueueGroup(&client_context, target.request, &result.dp_response); + } + } catch (const std::exception& e) { + result.status = grpc::Status(grpc::StatusCode::INTERNAL, + "EnqueueGroup forward exception: " + std::string(e.what())); + } catch (...) { + result.status = grpc::Status(grpc::StatusCode::INTERNAL, "EnqueueGroup forward unknown exception"); + } + return result; + })); + } + + auto merge_response = [&](const EnqueueGroupRequestPB& dp_request, + const grpc::Status& status, + const EnqueueBatchResponsePB& dp_response) { + if (!status.ok()) { + for (const auto& dp_input : dp_request.requests()) { + if (dp_input.has_input()) { + addBatchError(response, dp_input.input().request_id(), status.error_code(), status.error_message()); + } + } + return; + } + + std::unordered_set returned_request_ids; + std::unordered_set error_request_ids; + for (const auto& error : dp_response.errors()) { + addBatchError( + response, error.request_id(), error.error_info().error_code(), error.error_info().error_message()); + returned_request_ids.insert(error.request_id()); + error_request_ids.insert(error.request_id()); + } + for (const auto& success : dp_response.successes()) { + if (error_request_ids.find(success.request_id()) != error_request_ids.end()) { + continue; + } + addBatchSuccess(response, success.request_id()); + returned_request_ids.insert(success.request_id()); + } + for (const auto& dp_input : dp_request.requests()) { + if (!dp_input.has_input()) { + continue; + } + const auto request_id = dp_input.input().request_id(); + if (returned_request_ids.find(request_id) == returned_request_ids.end()) { + addBatchError( + response, request_id, grpc::StatusCode::INTERNAL, "EnqueueGroup missing result for request"); + } + } + }; + + collectFutures( + dispatch_futures, + dispatch_deadline, + [&](size_t i) { + auto result = dispatch_futures[i].get(); + merge_response(dispatch_targets[i].request, result.status, result.dp_response); + }, + [&](size_t i) { + merge_response(dispatch_targets[i].request, + grpc::Status(grpc::StatusCode::DEADLINE_EXCEEDED, + "EnqueueBatch dispatch timeout for dp_rank " + + std::to_string(dispatch_targets[i].dp_rank)), + EnqueueBatchResponsePB()); + for (const auto& dp_input : dispatch_targets[i].request.requests()) { + if (dp_input.has_input()) { + auto entry = response_registry_.get(dp_input.input().request_id()); + cancelResponseEntry(entry); + } + } + }); + detachLeftoverFutures(dispatch_futures); + + response_registry_.gc(std::chrono::minutes(10)); + return grpc::Status::OK; +} + +grpc::Status PrefillRpcServer::EnqueueGroup(grpc::ServerContext* context, + const EnqueueGroupRequestPB* request, + EnqueueBatchResponsePB* response) { + RTP_LLM_PROFILE_FUNCTION(); + response->set_batch_id(request->batch_id()); + + struct LocalSlot { + std::shared_ptr input; + std::shared_ptr entry; + std::shared_ptr rpc_context; + std::shared_ptr prefill_context; + std::shared_ptr cancel_state; + AtomicGuardPtr request_guard; + int64_t request_id = 0; + bool prepared = false; + grpc::Status stage_status = grpc::Status::OK; + }; + + std::vector all_inputs; + all_inputs.reserve(request->requests_size()); + std::unordered_set seen_request_ids; + bool duplicate_request_id = false; + for (const auto& dp_input : request->requests()) { + if (!dp_input.has_input()) { + addBatchError(response, + /*request_id=*/0, + grpc::StatusCode::INVALID_ARGUMENT, + "EnqueueGroup request missing input"); + continue; + } + all_inputs.push_back(&dp_input.input()); + if (!seen_request_ids.insert(dp_input.input().request_id()).second) { + duplicate_request_id = true; + } + } + + response->mutable_successes()->Reserve(static_cast(all_inputs.size())); + response->mutable_errors()->Reserve(static_cast(all_inputs.size())); + + auto add_error_for_all = [&](int64_t code, const std::string& message) { + for (const auto* input : all_inputs) { + addBatchError(response, input->request_id(), code, message); + } + }; + + const int local_dp_rank = static_cast(maga_init_params_.parallelism_config.dp_rank); + if (request->dp_rank() != local_dp_rank) { + add_error_for_all(grpc::StatusCode::INVALID_ARGUMENT, + "EnqueueGroup dp_rank mismatch, request dp_rank " + std::to_string(request->dp_rank()) + + ", local dp_rank " + std::to_string(local_dp_rank)); + return grpc::Status::OK; + } + if (duplicate_request_id) { + response->clear_errors(); + add_error_for_all(grpc::StatusCode::ALREADY_EXISTS, "duplicate request_id in EnqueueGroup"); + return grpc::Status::OK; + } + if (context && context->IsCancelled()) { + add_error_for_all(grpc::StatusCode::CANCELLED, "EnqueueGroup cancelled by caller"); + return grpc::Status(grpc::StatusCode::CANCELLED, "EnqueueGroup cancelled by caller"); + } + + std::vector slots; + slots.reserve(all_inputs.size()); + const int group_size = static_cast(all_inputs.size()); + for (const auto* input : all_inputs) { + auto input_copy = std::make_shared(*input); + input_copy->set_group_size(group_size); + input_copy->mutable_group_id()->set_value(request->batch_id()); + + auto entry = response_registry_.reserve(input_copy->request_id()); + if (!entry) { + addBatchError( + response, input_copy->request_id(), grpc::StatusCode::ALREADY_EXISTS, "request already enqueued"); + continue; + } + slots.push_back({input_copy, entry, nullptr, nullptr, nullptr, nullptr, input_copy->request_id()}); + } + + if (slots.empty()) { + return grpc::Status::OK; + } + + for (const auto& slot : slots) { + int64_t batch_id = (slot.input && slot.input->has_group_id()) ? slot.input->group_id().value() : -1; + RTP_LLM_LOG_DEBUG("request [%ld] EnqueueGroup: has_group_id=%d, batch_id=%ld, request_batch_id=%ld", + slot.request_id, + slot.input ? slot.input->has_group_id() : 0, + batch_id, + request->batch_id()); + meta_->enqueuePending(slot.request_id, slot.input ? slot.input->token_ids_size() : 0, batch_id); + } + + auto erase_reserved_slots = [this](const std::vector& slots) { + for (const auto& slot : slots) { + cancelResponseEntry(slot.entry); + response_registry_.erase(slot.request_id); + } + }; + + auto finish_pending_before_ack = [this](const LocalSlot& slot, const grpc::Status& status) { + meta_->finishTask(slot.request_id, + slot.input ? slot.input->token_ids_size() : 0, + /*prefix_length=*/0, + status.ok() ? 0 : static_cast(status.error_code()), + status.ok() ? "" : std::string(status.error_message())); + }; + + if (!tryStartAsyncResponseWorker()) { + auto status = grpc::Status(grpc::StatusCode::UNAVAILABLE, "EnqueueGroup server is stopping"); + for (const auto& slot : slots) { + finish_pending_before_ack(slot, status); + addBatchError(response, slot.request_id, status.error_code(), status.error_message()); + } + erase_reserved_slots(slots); + return grpc::Status::OK; + } + + std::vector accepted_request_ids; + accepted_request_ids.reserve(slots.size()); + for (const auto& slot : slots) { + accepted_request_ids.push_back(slot.request_id); + } + + auto slots_ptr = std::make_shared>(std::move(slots)); + try { + auto worker_error = worker_lambda_pool_->pushTask( + [this, + slots_ptr, + max_retry_times = maga_init_params_.pd_sep_config.prefill_retry_times, + max_retry_timeout_ms = maga_init_params_.pd_sep_config.prefill_retry_timeout_ms, + batch_prepare_timeout_ms = maga_init_params_.pd_sep_config.batch_prepare_timeout_ms, + batch_load_timeout_ms = maga_init_params_.pd_sep_config.batch_load_timeout_ms]() mutable { + ScopeExit controller_finish_guard([this] { finishAsyncResponseWorker(); }); + auto& slots = *slots_ptr; + + auto entry_cancelled = [](const LocalSlot& slot) { + return !slot.entry || slot.entry->cancelled.load(); + }; + auto grpc_status_to_stream_error = [](const grpc::Status& status) { + return status.error_code() == grpc::StatusCode::CANCELLED ? ErrorCode::CANCELLED : + ErrorCode::UNKNOWN_ERROR; + }; + auto fail_slot = [&](LocalSlot& slot, const grpc::Status& status) { + int64_t input_length = slot.input ? slot.input->token_ids_size() : 0; + int64_t prefix_length = 0; + if (slot.prefill_context && slot.prefill_context->getStream()) { + auto stream = slot.prefill_context->getStream(); + input_length = stream->inputLength(); + prefix_length = stream->prefixLength(); + if (!stream->hasError()) { + stream->reportError(grpc_status_to_stream_error(status), + std::string(status.error_message())); + } + } + meta_->finishTask(slot.request_id, + input_length, + prefix_length, + status.ok() ? 0 : static_cast(status.error_code()), + status.ok() ? "" : std::string(status.error_message())); + markResponseEntryDone(slot.entry, status); + slot.prefill_context.reset(); + slot.rpc_context.reset(); + slot.request_guard.reset(); + slot.cancel_state.reset(); + slot.input.reset(); + slot.entry.reset(); + }; + + auto start_finish_worker = [&](LocalSlot& slot) { + auto entry = slot.entry; + auto writer = std::make_shared(entry); + slot.prefill_context->rpc_context.writer = writer.get(); + + if (!tryStartAsyncResponseWorker()) { + fail_slot(slot, grpc::Status(grpc::StatusCode::UNAVAILABLE, "EnqueueGroup server is stopping")); + return; + } + + try { + auto finish_lambda = [this, + pfx_ctx = slot.prefill_context, + rpc_ctx = slot.rpc_context, + input = slot.input, + writer, + entry, + guard = slot.request_guard, + cancel_state = slot.cancel_state, + request_id = slot.request_id]() mutable { + (void)rpc_ctx; + (void)input; + (void)writer; + (void)guard; + (void)cancel_state; + ScopeExit worker_finish_guard([this] { finishAsyncResponseWorker(); }); + ScopeExit release_captures_guard([&] { + pfx_ctx.reset(); + rpc_ctx.reset(); + input.reset(); + writer.reset(); + entry.reset(); + guard.reset(); + cancel_state.reset(); + }); + grpc::Status finish_status; + try { + finish_status = finishStream(*pfx_ctx); + RTP_LLM_LOG_DEBUG("request [%ld] finishStream returned, ok=%d, has_stream=%d", + request_id, + finish_status.ok(), + pfx_ctx->getStream() ? 1 : 0); + // Record finished task for FlexLB calibration + if (finish_status.ok() && pfx_ctx->getStream()) { + RTP_LLM_LOG_DEBUG("request [%ld] calling dequeue for FlexLB calibration", + request_id); + meta_->dequeue(request_id, pfx_ctx->getStream()); + } else if (!finish_status.ok()) { + RTP_LLM_LOG_DEBUG("request [%ld] calling finishTask due to error, code=%d, msg=%s", + request_id, + static_cast(finish_status.error_code()), + finish_status.error_message().c_str()); + meta_->finishTask(request_id, + input ? input->token_ids_size() : 0, + /*prefix_length=*/0, + static_cast(finish_status.error_code()), + finish_status.error_message()); + } + } catch (const std::exception& e) { + auto error_msg = + "request [" + pfx_ctx->request_key + "] finishStream exception [" + e.what() + "]"; + finish_status = grpc::Status(grpc::StatusCode::INTERNAL, error_msg); + meta_->finishTask(request_id, + input ? input->token_ids_size() : 0, + /*prefix_length=*/0, + static_cast(finish_status.error_code()), + error_msg); + } catch (...) { + finish_status = + grpc::Status(grpc::StatusCode::INTERNAL, "finishStream unknown exception"); + meta_->finishTask(request_id, + input ? input->token_ids_size() : 0, + /*prefix_length=*/0, + static_cast(finish_status.error_code()), + "finishStream unknown exception"); + } + markResponseEntryDone(entry, finish_status); + RTP_LLM_LOG_DEBUG( + "EnqueueGroup request [%ld] finishStream done, ok=%d", request_id, finish_status.ok()); + }; + + // Non-blocking submit: if pool is full, fall back to detached thread + // to avoid deadlock (L3 → L4 on same pool, see plan §3.2). + auto error = slot_worker_pool_->pushTask(std::move(finish_lambda)); + if (error != autil::ThreadPoolBase::ERROR_NONE) { + slot_pool_metrics_.fallback++; + std::thread fallback_thread(std::move(finish_lambda)); + fallback_thread.detach(); + } + } catch (const std::exception& e) { + finishAsyncResponseWorker(); + fail_slot(slot, + grpc::Status(grpc::StatusCode::INTERNAL, + "start async response worker exception: " + std::string(e.what()))); + } catch (...) { + finishAsyncResponseWorker(); + fail_slot( + slot, + grpc::Status(grpc::StatusCode::INTERNAL, "start async response worker unknown exception")); + } + }; + + for (auto& slot : slots) { + auto rpc_ctx = std::make_shared(RPCContext{slot.input.get(), nullptr}); + auto pfx_ctx = std::make_shared(&this->resource(), + *rpc_ctx, + slot.input->generate_config().timeout_ms(), + /*server_context=*/nullptr, + metrics_reporter_, + meta_); + pfx_ctx->onflight_requests = onflight_requests_; + pfx_ctx->loading_cache_requests = loading_cache_requests_; + auto guard = std::make_shared(onflight_requests_); + auto cancel_state = std::make_shared(); + { + std::lock_guard lock(slot.entry->mu); + cancel_state->cancelled.store(slot.entry->cancelled.load()); + slot.entry->cancel_producer = makeAsyncProducerCancelCallback(cancel_state); + } + slot.rpc_context = rpc_ctx; + slot.prefill_context = pfx_ctx; + slot.cancel_state = cancel_state; + slot.request_guard = guard; + pfx_ctx->cancel_state = cancel_state; + } + + const auto prepare_deadline = + std::chrono::steady_clock::now() + std::chrono::milliseconds(batch_prepare_timeout_ms); + std::vector> prepare_futures; + prepare_futures.reserve(slots.size()); + for (auto& slot : slots) { + auto* slot_ptr = &slot; + prepare_futures.push_back(slot_worker_pool_->async( + [this, slot_ptr, slots_ptr, entry_cancelled, max_retry_times, max_retry_timeout_ms] { + auto& slot = *slot_ptr; + if (entry_cancelled(slot)) { + slot.stage_status = + grpc::Status(grpc::StatusCode::CANCELLED, "EnqueueGroup request cancelled"); + return; + } + try { + int64_t begin_time_us = currentTimeUs(); + auto stage = slot.prefill_context->stat_info.saveStage(); + for (int attempt = 0; attempt <= max_retry_times; ++attempt) { + if (entry_cancelled(slot)) { + slot.stage_status = + grpc::Status(grpc::StatusCode::CANCELLED, "EnqueueGroup request cancelled"); + return; + } + slot.prefill_context->reset(); + slot.prefill_context->stat_info.restoreStage(stage); + slot.prefill_context->retry_times++; + prepareAllocateResource(*slot.prefill_context); + if (slot.prefill_context->ok()) { + slot.prepared = true; + return; + } + auto cost_time_us = currentTimeUs() - begin_time_us; + slot.prefill_context->retry_cost_time_ms = cost_time_us / 1000; + if (max_retry_timeout_ms > 0 && cost_time_us >= max_retry_timeout_ms * 1000) { + break; + } + usleep(1000); + } + slot.stage_status = slot.prefill_context->error_status.ok() ? + statusFromErrorInfo(slot.prefill_context->error_info) : + slot.prefill_context->error_status; + if (slot.stage_status.ok()) { + slot.stage_status = + grpc::Status(grpc::StatusCode::INTERNAL, "prepareAllocateResource failed"); + } + } catch (const std::exception& e) { + slot.stage_status = + grpc::Status(grpc::StatusCode::INTERNAL, + "prepareAllocateResource exception: " + std::string(e.what())); + } catch (...) { + slot.stage_status = grpc::Status(grpc::StatusCode::INTERNAL, + "prepareAllocateResource unknown exception"); + } + })); + } + collectFutures( + prepare_futures, + prepare_deadline, + [&](size_t i) { prepare_futures[i].get(); }, + [&](size_t i) { + cancelResponseEntry(slots[i].entry); + slots[i].stage_status = + grpc::Status(grpc::StatusCode::DEADLINE_EXCEEDED, "EnqueueGroup prepare timeout"); + }); + drainReadyFutures(prepare_futures, std::chrono::milliseconds(2000)); + detachLeftoverFutures(prepare_futures); + + std::vector ready_slots; + ready_slots.reserve(slots.size()); + for (auto& slot : slots) { + if (entry_cancelled(slot)) { + fail_slot(slot, grpc::Status(grpc::StatusCode::CANCELLED, "EnqueueGroup request cancelled")); + } else if (!slot.prepared) { + fail_slot(slot, slot.stage_status); + } else { + ready_slots.push_back(&slot); + } + } + if (ready_slots.empty()) { + return; + } + + const int local_group_size = static_cast(ready_slots.size()); + std::vector> generate_inputs; + generate_inputs.reserve(ready_slots.size()); + for (auto* slot : ready_slots) { + slot->input->set_group_size(local_group_size); + slot->prefill_context->generate_input->group_size = local_group_size; + slot->prefill_context->stat_info.nextStage(); + generate_inputs.push_back(slot->prefill_context->generate_input); + } + + std::vector streams; + try { + streams = engine_->enqueueMultiple(generate_inputs); + } catch (const std::exception& e) { + for (auto* slot : ready_slots) { + fail_slot(*slot, + grpc::Status(grpc::StatusCode::INTERNAL, + "enqueueMultiple exception: " + std::string(e.what()))); + } + return; + } catch (...) { + for (auto* slot : ready_slots) { + fail_slot(*slot, grpc::Status(grpc::StatusCode::INTERNAL, "enqueueMultiple unknown exception")); + } + return; + } + + std::unordered_map stream_by_id; + for (auto& stream : streams) { + if (stream) { + stream_by_id[stream->streamId()] = stream; + } + } + std::vector stream_ready_slots; + stream_ready_slots.reserve(ready_slots.size()); + for (auto* slot : ready_slots) { + auto it = stream_by_id.find(slot->request_id); + if (it == stream_by_id.end()) { + fail_slot(*slot, grpc::Status(grpc::StatusCode::INTERNAL, "EnqueueGroup stream not enqueued")); + continue; + } + slot->prefill_context->setStream(it->second); + refreshAsyncProducerCancelState( + slot->cancel_state, slot->prefill_context->client_context, slot->prefill_context->getStream()); + stream_ready_slots.push_back(slot); + } + + const auto load_deadline = + std::chrono::steady_clock::now() + std::chrono::milliseconds(batch_load_timeout_ms); + std::vector> load_futures; + load_futures.reserve(stream_ready_slots.size()); + for (auto* slot : stream_ready_slots) { + load_futures.push_back(slot_worker_pool_->async( + [this, slot, slots_ptr, entry_cancelled, fail_slot, start_finish_worker] { + if (entry_cancelled(*slot)) { + fail_slot(*slot, + grpc::Status(grpc::StatusCode::CANCELLED, "EnqueueGroup request cancelled")); + return; + } + try { + slot->prefill_context->stat_info.nextStage(); + remoteLoadCacheStart(*slot->prefill_context); + refreshAsyncProducerCancelState(slot->cancel_state, + slot->prefill_context->client_context, + slot->prefill_context->getStream()); + if (entry_cancelled(*slot)) { + fail_slot( + *slot, + grpc::Status(grpc::StatusCode::CANCELLED, "EnqueueGroup request cancelled")); + return; + } + if (slot->prefill_context->hasError()) { + auto status = slot->prefill_context->error_status.ok() ? + statusFromErrorInfo(slot->prefill_context->error_info) : + slot->prefill_context->error_status; + fail_slot(*slot, status); + return; + } + start_finish_worker(*slot); + } catch (const std::exception& e) { + fail_slot(*slot, + grpc::Status(grpc::StatusCode::INTERNAL, + "remoteLoadCacheStart exception: " + std::string(e.what()))); + } catch (...) { + fail_slot( + *slot, + grpc::Status(grpc::StatusCode::INTERNAL, "remoteLoadCacheStart unknown exception")); + } + })); + } + collectFutures( + load_futures, + load_deadline, + [&](size_t i) { load_futures[i].get(); }, + [&](size_t i) { cancelResponseEntry(stream_ready_slots[i]->entry); }); + drainReadyFutures(load_futures, std::chrono::milliseconds(2000)); + detachLeftoverFutures(load_futures); + }); + + if (worker_error != autil::ThreadPoolBase::ERROR_NONE) { + worker_lambda_pool_metrics_.rejected++; + // Pool saturated: the lambda was NOT enqueued, so ScopeExit guards + // inside the lambda did not run. We must manually finish the worker. + finishAsyncResponseWorker(); + + auto status = grpc::Status(grpc::StatusCode::UNAVAILABLE, "EnqueueGroup enqueue pool saturated"); + for (auto& slot : *slots_ptr) { + finish_pending_before_ack(slot, status); + addBatchError(response, slot.request_id, status.error_code(), status.error_message()); + } + erase_reserved_slots(*slots_ptr); + return grpc::Status::OK; + } + } catch (const std::exception& e) { + finishAsyncResponseWorker(); + auto status = grpc::Status(grpc::StatusCode::INTERNAL, + "start EnqueueGroup accept worker exception: " + std::string(e.what())); + for (const auto& slot : *slots_ptr) { + finish_pending_before_ack(slot, status); + addBatchError(response, slot.request_id, status.error_code(), status.error_message()); + } + erase_reserved_slots(*slots_ptr); + return grpc::Status::OK; + } catch (...) { + finishAsyncResponseWorker(); + auto status = grpc::Status(grpc::StatusCode::INTERNAL, "start EnqueueGroup accept worker unknown exception"); + for (const auto& slot : *slots_ptr) { + finish_pending_before_ack(slot, status); + addBatchError(response, slot.request_id, status.error_code(), status.error_message()); + } + erase_reserved_slots(*slots_ptr); + return grpc::Status::OK; + } + + for (const auto request_id : accepted_request_ids) { + addBatchSuccess(response, request_id); + } + return grpc::Status::OK; +} + +grpc::Status PrefillRpcServer::FetchResponse(grpc::ServerContext* context, + const FetchRequestPB* request, + grpc::ServerWriter* writer) { + RTP_LLM_PROFILE_FUNCTION(); + const auto request_id = request->request_id(); + auto entry = response_registry_.get(request_id); + if (!entry) { + return grpc::Status(grpc::StatusCode::NOT_FOUND, + "request [" + std::to_string(request_id) + "] not found in response registry"); + } + + while (true) { + if (context && context->IsCancelled()) { + cancelResponseEntry(entry); + response_registry_.erase(request_id); + return grpc::Status(grpc::StatusCode::CANCELLED, "fetch response cancelled by client"); + } + + std::deque drained; + grpc::Status terminal_status = grpc::Status::OK; + bool terminal = false; + { + std::unique_lock lock(entry->mu); + entry->cv.wait_for(lock, std::chrono::milliseconds(100), [&] { + return !entry->queue.empty() || entry->done.load() || entry->cancelled.load() + || entry->error_status.has_value(); + }); + drained.swap(entry->queue); + if (entry->cancelled.load()) { + terminal = true; + terminal_status = grpc::Status(grpc::StatusCode::CANCELLED, "request cancelled"); + } else if (entry->error_status.has_value()) { + terminal = true; + terminal_status = *entry->error_status; + } else if (entry->done.load()) { + terminal = true; + } + } + + for (auto& output : drained) { + if (!writer->Write(output)) { + cancelResponseEntry(entry); + response_registry_.erase(request_id); + return grpc::Status(grpc::StatusCode::CANCELLED, "client writer closed"); + } + } + + if (terminal) { + response_registry_.erase(request_id); + return terminal_status; + } + } +} + +grpc::Status PrefillRpcServer::Cancel(grpc::ServerContext* context, const CancelRequestPB* request, EmptyPB* response) { + (void)context; + (void)response; + RTP_LLM_PROFILE_FUNCTION(); + const auto request_id = request->request_id(); + auto entry = response_registry_.get(request_id); + if (!entry) { + return grpc::Status::OK; + } + cancelResponseEntry(entry); + response_registry_.erase(request_id); + return grpc::Status::OK; +} + grpc::Status PrefillRpcServer::RemoteFinish(grpc::ServerContext* context, const RemoteFinishRequestPB* request, EmptyPB* response) { RTP_LLM_PROFILE_FUNCTION(); diff --git a/rtp_llm/cpp/model_rpc/PrefillRpcServer.h b/rtp_llm/cpp/model_rpc/PrefillRpcServer.h index 55a618eaef..2780a85e7a 100644 --- a/rtp_llm/cpp/model_rpc/PrefillRpcServer.h +++ b/rtp_llm/cpp/model_rpc/PrefillRpcServer.h @@ -1,16 +1,32 @@ #pragma once +#include +#include #include "grpc++/grpc++.h" +#include +#include +#include "autil/LockFreeThreadPool.h" #include "rtp_llm/cpp/model_rpc/RpcServerRuntimeMeta.h" #include "rtp_llm/cpp/model_rpc/RemoteRpcServer.h" #include "rtp_llm/cpp/model_rpc/PrefillGenerateContext.h" +#include "rtp_llm/cpp/cache/RecentCacheKeyWindow.h" +#include "rtp_llm/cpp/model_rpc/ResponseBuffer.h" namespace rtp_llm { +// Pool-level health metrics, reported periodically +struct PoolMetrics { + std::atomic active = 0; // currently executing tasks + std::atomic queued = 0; // tasks waiting in queue + std::atomic completed = 0; // total finished since creation + std::atomic rejected = 0; // pushTask refused (pool full) + std::atomic fallback = 0; // fallback to detached thread +}; + class PrefillRpcServer: public RemoteRpcServer { public: PrefillRpcServer() {} - ~PrefillRpcServer() {} + ~PrefillRpcServer() override; grpc::Status init(const EngineInitParams& maga_init_params, py::object mm_process_engine, std::unique_ptr propose_params) override; @@ -21,7 +37,21 @@ class PrefillRpcServer: public RemoteRpcServer { grpc::Status RemoteFinish(grpc::ServerContext* context, const RemoteFinishRequestPB* request, EmptyPB* response); + grpc::Status + EnqueueBatch(grpc::ServerContext* context, const EnqueueBatchRequestPB* request, EnqueueBatchResponsePB* response); + + grpc::Status + EnqueueGroup(grpc::ServerContext* context, const EnqueueGroupRequestPB* request, EnqueueBatchResponsePB* response); + + grpc::Status FetchResponse(grpc::ServerContext* context, + const FetchRequestPB* request, + grpc::ServerWriter* writer); + + grpc::Status Cancel(grpc::ServerContext* context, const CancelRequestPB* request, EmptyPB* response); + private: + grpc::Status syncPrefix(PrefillGenerateContext& prefill_context); + grpc::Status finishStream(PrefillGenerateContext& prefill_context); ErrorInfo waitStreamBeforeRun(std::shared_ptr stream); grpc::Status prepareAllocateResource(PrefillGenerateContext& prefill_context); void getRpcConnection(PrefillGenerateContext& prefill_context); @@ -33,9 +63,36 @@ class PrefillRpcServer: public RemoteRpcServer { void remoteLoadCacheEnd(PrefillGenerateContext& prefill_context); void remoteGenerate(PrefillGenerateContext& prefill_context); void pollRemoteOutput(PrefillGenerateContext& prefill_context); + void reportPrefillRecentCacheKeyMetricsOnce(PrefillGenerateContext& prefill_context); + void startResponseRegistryGc(); + void stopResponseRegistryGc(); + bool tryStartAsyncResponseWorker(); + void finishAsyncResponseWorker(); + void stopAsyncResponseWorkers(); + void initThreadPools(); + void reportPoolMetrics(); + std::string batchTargetAddrForDpRank(int dp_rank) const; private: - std::string decode_cluster_name_; + std::string decode_cluster_name_; + std::unique_ptr prefill_recent_cache_key_window_; + ResponseBufferRegistry response_registry_; + std::atomic response_gc_stop_{false}; + std::mutex response_gc_mu_; + std::condition_variable response_gc_cv_; + std::thread response_gc_thread_; + std::atomic response_worker_stop_{false}; + std::mutex response_worker_mu_; + std::condition_variable response_worker_cv_; + size_t response_worker_count_{0}; + + // Thread pools replacing std::async / std::thread::detach + autil::ThreadPoolBasePtr enqueue_worker_pool_; // Dispatch only (L1 DP dispatch, fast ms-level) + autil::ThreadPoolBasePtr worker_lambda_pool_; // Heavy worker lambdas (L2/L3 coordination, I/O-bound ~12s) + autil::ThreadPoolBasePtr slot_worker_pool_; // L2 Prep + L3 Load + L4 Finish + PoolMetrics enqueue_pool_metrics_; + PoolMetrics worker_lambda_pool_metrics_; + PoolMetrics slot_pool_metrics_; }; } // namespace rtp_llm diff --git a/rtp_llm/cpp/model_rpc/QueryConverter.cc b/rtp_llm/cpp/model_rpc/QueryConverter.cc index e31cdddcb8..cca7ca9aa4 100644 --- a/rtp_llm/cpp/model_rpc/QueryConverter.cc +++ b/rtp_llm/cpp/model_rpc/QueryConverter.cc @@ -96,8 +96,7 @@ std::shared_ptr QueryConverter::transGenerateConfig(const Genera generate_config->enable_memory_cache = config_proto->enable_memory_cache(); generate_config->enable_remote_cache = config_proto->enable_remote_cache(); TRANS_OPTIONAL(trace_id); - TRANS_OPTIONAL(batch_group_timeout); - TRANS_OPTIONAL(force_batch); + TRANS_OPTIONAL(group_timeout); return generate_config; } @@ -150,9 +149,9 @@ std::shared_ptr QueryConverter::transQuery(const GenerateInputPB* } generate_input->multimodal_inputs = std::move(mm_inputs); } - generate_input->batch_group_size = input->batch_group_size() > 0 ? input->batch_group_size() : 1; - if (input->has_batch_group_id()) { - generate_input->batch_group_id = input->batch_group_id().value(); + generate_input->group_size = input->group_size() > 0 ? input->group_size() : 1; + if (input->has_group_id()) { + generate_input->group_id = input->group_id().value(); } return generate_input; diff --git a/rtp_llm/cpp/model_rpc/RemoteRpcServiceImpl.h b/rtp_llm/cpp/model_rpc/RemoteRpcServiceImpl.h index 1ad3c9cd94..cb4e5e0767 100644 --- a/rtp_llm/cpp/model_rpc/RemoteRpcServiceImpl.h +++ b/rtp_llm/cpp/model_rpc/RemoteRpcServiceImpl.h @@ -36,6 +36,48 @@ class RemoteRpcServiceImpl: public LocalRpcServiceImpl { return prefill_server_->RemoteFinish(context, request, response); } + grpc::Status EnqueueBatch(grpc::ServerContext* context, + const EnqueueBatchRequestPB* request, + EnqueueBatchResponsePB* response) override { + if (!prefill_server_) { + auto error_msg = "server not implement EnqueueBatch"; + RTP_LLM_LOG_ERROR(error_msg); + return grpc::Status(grpc::StatusCode::UNIMPLEMENTED, error_msg); + } + return prefill_server_->EnqueueBatch(context, request, response); + } + + grpc::Status EnqueueGroup(grpc::ServerContext* context, + const EnqueueGroupRequestPB* request, + EnqueueBatchResponsePB* response) override { + if (!prefill_server_) { + auto error_msg = "server not implement EnqueueGroup"; + RTP_LLM_LOG_ERROR(error_msg); + return grpc::Status(grpc::StatusCode::UNIMPLEMENTED, error_msg); + } + return prefill_server_->EnqueueGroup(context, request, response); + } + + grpc::Status FetchResponse(grpc::ServerContext* context, + const FetchRequestPB* request, + grpc::ServerWriter* writer) override { + if (!prefill_server_) { + auto error_msg = "server not implement FetchResponse"; + RTP_LLM_LOG_ERROR(error_msg); + return grpc::Status(grpc::StatusCode::UNIMPLEMENTED, error_msg); + } + return prefill_server_->FetchResponse(context, request, writer); + } + + grpc::Status Cancel(grpc::ServerContext* context, const CancelRequestPB* request, EmptyPB* response) override { + if (!prefill_server_) { + auto error_msg = "server not implement Cancel"; + RTP_LLM_LOG_ERROR(error_msg); + return grpc::Status(grpc::StatusCode::UNIMPLEMENTED, error_msg); + } + return prefill_server_->Cancel(context, request, response); + } + grpc::Status RemoteLoad(grpc::ServerContext* context, const BroadcastLoadRequestPB* request, BroadcastLoadResponsePB* response) override { diff --git a/rtp_llm/cpp/model_rpc/ResponseBuffer.cc b/rtp_llm/cpp/model_rpc/ResponseBuffer.cc new file mode 100644 index 0000000000..76e7132bb3 --- /dev/null +++ b/rtp_llm/cpp/model_rpc/ResponseBuffer.cc @@ -0,0 +1,135 @@ +#include "rtp_llm/cpp/model_rpc/ResponseBuffer.h" + +#include "autil/TimeUtility.h" + +#include +#include + +namespace rtp_llm { + +std::shared_ptr ResponseBufferRegistry::createOrGet(int64_t request_id) { + std::lock_guard lock(mu_); + auto it = map_.find(request_id); + if (it != map_.end()) { + return it->second; + } + auto entry = std::make_shared(); + entry->last_activity_us = autil::TimeUtility::currentTimeInMicroSeconds(); + map_.emplace(request_id, entry); + return entry; +} + +std::shared_ptr ResponseBufferRegistry::reserve(int64_t request_id) { + std::lock_guard lock(mu_); + auto it = map_.find(request_id); + if (it != map_.end()) { + return nullptr; + } + auto entry = std::make_shared(); + entry->last_activity_us = autil::TimeUtility::currentTimeInMicroSeconds(); + map_.emplace(request_id, entry); + return entry; +} + +std::shared_ptr ResponseBufferRegistry::get(int64_t request_id) { + std::lock_guard lock(mu_); + auto it = map_.find(request_id); + if (it == map_.end()) { + return nullptr; + } + return it->second; +} + +void ResponseBufferRegistry::erase(int64_t request_id) { + std::lock_guard lock(mu_); + map_.erase(request_id); +} + +void ResponseBufferRegistry::cancelAll() { + const int64_t now_us = autil::TimeUtility::currentTimeInMicroSeconds(); + std::vector> entries; + std::vector> cancel_producers; + + { + std::lock_guard lock(mu_); + entries.reserve(map_.size()); + for (const auto& kv : map_) { + entries.push_back(kv.second); + } + } + + cancel_producers.reserve(entries.size()); + for (const auto& entry : entries) { + std::function cancel_producer; + { + std::lock_guard entry_lock(entry->mu); + entry->cancelled.store(true); + entry->last_activity_us = now_us; + cancel_producer = entry->cancel_producer; + entry->cancel_producer = nullptr; + } + if (cancel_producer) { + cancel_producers.push_back(std::move(cancel_producer)); + } + entry->cv.notify_all(); + } + + for (const auto& cancel_producer : cancel_producers) { + cancel_producer(); + } +} + +size_t ResponseBufferRegistry::gc(std::chrono::microseconds ttl) { + const int64_t now_us = autil::TimeUtility::currentTimeInMicroSeconds(); + const int64_t ttl_us = ttl.count(); + size_t swept = 0; + + std::lock_guard lock(mu_); + for (auto it = map_.begin(); it != map_.end();) { + const auto& entry = it->second; + bool terminal = false; + bool idle = false; + { + std::lock_guard entry_lock(entry->mu); + terminal = entry->done.load() || entry->cancelled.load() || entry->error_status.has_value(); + idle = (now_us - entry->last_activity_us) >= ttl_us; + if (terminal && idle) { + std::deque().swap(entry->queue); + entry->cancel_producer = nullptr; + } + } + if (terminal && idle) { + it = map_.erase(it); + ++swept; + } else { + ++it; + } + } + return swept; +} + +size_t ResponseBufferRegistry::size() const { + std::lock_guard lock(mu_); + return map_.size(); +} + +bool ResponseBufferWriter::Write(const GenerateOutputsPB& outputs, grpc::WriteOptions /*options*/) { + if (!entry_) { + return false; + } + { + std::lock_guard lock(entry_->mu); + if (entry_->cancelled.load()) { + return false; + } + if (entry_->queue.size() >= ResponseBufferEntry::kMaxQueueSize) { + entry_->queue.pop_front(); + } + entry_->queue.push_back(outputs); + entry_->last_activity_us = autil::TimeUtility::currentTimeInMicroSeconds(); + } + entry_->cv.notify_all(); + return true; +} + +} // namespace rtp_llm diff --git a/rtp_llm/cpp/model_rpc/ResponseBuffer.h b/rtp_llm/cpp/model_rpc/ResponseBuffer.h new file mode 100644 index 0000000000..2ea1a3d433 --- /dev/null +++ b/rtp_llm/cpp/model_rpc/ResponseBuffer.h @@ -0,0 +1,60 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include "rtp_llm/cpp/model_rpc/proto/model_rpc_service.pb.h" + +namespace rtp_llm { + +struct ResponseBufferEntry { + static constexpr size_t kMaxQueueSize = 1000; + + std::deque queue; + std::atomic done{false}; + std::atomic cancelled{false}; + std::optional error_status; + std::function cancel_producer; + std::mutex mu; + std::condition_variable cv; + int64_t last_activity_us{0}; +}; + +class ResponseBufferRegistry { +public: + ResponseBufferRegistry() = default; + + std::shared_ptr createOrGet(int64_t request_id); + std::shared_ptr reserve(int64_t request_id); + std::shared_ptr get(int64_t request_id); + void erase(int64_t request_id); + void cancelAll(); + size_t gc(std::chrono::microseconds ttl); + size_t size() const; + +private: + mutable std::mutex mu_; + std::unordered_map> map_; +}; + +class ResponseBufferWriter: public grpc::internal::WriterInterface { +public: + explicit ResponseBufferWriter(std::shared_ptr entry): entry_(std::move(entry)) {} + + bool Write(const GenerateOutputsPB& outputs, grpc::WriteOptions options) override; + +private: + std::shared_ptr entry_; +}; + +} // namespace rtp_llm diff --git a/rtp_llm/cpp/model_rpc/RpcServerRuntimeMeta.h b/rtp_llm/cpp/model_rpc/RpcServerRuntimeMeta.h index e479f581e3..09798df353 100644 --- a/rtp_llm/cpp/model_rpc/RpcServerRuntimeMeta.h +++ b/rtp_llm/cpp/model_rpc/RpcServerRuntimeMeta.h @@ -1,19 +1,38 @@ #pragma once #include +#include +#include +#include +#include #include "rtp_llm/cpp/engine_base/stream/GenerateStream.h" #include "rtp_llm/cpp/engine_base/schedulers/EngineScheduleInfo.h" namespace rtp_llm { + +struct RunningEntry { + EngineScheduleInfo::TaskInfo task_info; + GenerateStreamPtr stream; +}; + class RpcServerRuntimeMeta { public: + static TaskPhase derivePhase(const GenerateStreamPtr& stream) { + if (!stream) + return TaskPhase::PENDING; + if (stream->getStatus() == StreamState::RUNNING) + return TaskPhase::RUNNING; + if (stream->curBlocksNum() > 0) + return TaskPhase::KV_ALLOCATED; + return TaskPhase::RECEIVED; + } + EngineScheduleInfo getEngineScheduleInfo(int64_t latest_finished_version) { std::shared_lock lock(read_write_lock_); EngineScheduleInfo info; - for (auto& iter : running_streams_) { - info.running_task_info_list.push_back(iter.second); + for (auto& [id, entry] : running_streams_) { + entry.task_info.phase = derivePhase(entry.stream); + info.running_task_info_list.push_back(entry.task_info); } - // When no new finished tasks, keep client's version (monotonic). Otherwise client/flexlb - // would overwrite latestFinishedTaskVersion with 0 when finished queue is empty. int64_t version = latest_finished_version; for (auto& iter : finished_streams_) { if (iter.first > latest_finished_version) { @@ -29,8 +48,25 @@ class RpcServerRuntimeMeta { void enqueue(int64_t request_id, const GenerateStreamPtr& stream) { std::unique_lock lock(read_write_lock_); - running_streams_[request_id] = EngineScheduleInfo::TaskInfo( + auto new_task = EngineScheduleInfo::TaskInfo( {request_id, stream->prefixLength(), stream->inputLength(), stream->getTimeInfo().wait_time_us}); + auto it = running_streams_.find(request_id); + if (it != running_streams_.end()) { + new_task.batch_id = it->second.task_info.batch_id; + } + running_streams_[request_id] = RunningEntry{new_task, stream}; + } + + void enqueuePending(int64_t request_id, int64_t input_length, int64_t batch_id = -1) { + std::unique_lock lock(read_write_lock_); + auto task = EngineScheduleInfo::TaskInfo({request_id, + /*prefix_length=*/0, + input_length, + /*waiting_time_ms=*/0, + /*iterate_count=*/0, + /*end_time_ms=*/-1}); + task.batch_id = batch_id; + running_streams_[request_id] = RunningEntry{task, nullptr}; } void dequeue(int64_t request_id, const GenerateStreamPtr& stream) { @@ -39,7 +75,7 @@ class RpcServerRuntimeMeta { if (ptr == running_streams_.end()) { return; } - auto& task_info = ptr->second; + auto& task_info = ptr->second.task_info; if (finished_streams_.size() >= finished_capacity_) { finished_streams_.pop_front(); } @@ -49,17 +85,50 @@ class RpcServerRuntimeMeta { task_info.input_length = stream->inputLength(); task_info.waiting_time_ms = stream->getTimeInfo().wait_time_us / 1000; task_info.iterate_count = stream->iterCount(); + if (stream->hasError()) { + task_info.error_code = static_cast(stream->statusInfo().code()); + task_info.error_message = stream->statusInfo().ToString(); + } int64_t version = version_.fetch_add(1, std::memory_order_relaxed); finished_streams_.push_back(std::make_pair(version, task_info)); running_streams_.erase(ptr); } + void finishTask(int64_t request_id, + int64_t input_length = 0, + int64_t prefix_length = 0, + int64_t error_code = 0, + const std::string& error_message = "") { + std::unique_lock lock(read_write_lock_); + EngineScheduleInfo::TaskInfo task_info{request_id, + prefix_length, + input_length, + /*waiting_time_ms=*/0, + /*iterate_count=*/0, + /*end_time_ms=*/-1}; + auto ptr = running_streams_.find(request_id); + if (ptr != running_streams_.end()) { + task_info = ptr->second.task_info; + if (input_length > 0) { + task_info.input_length = input_length; + } + if (prefix_length > 0) { + task_info.prefix_length = prefix_length; + } + running_streams_.erase(ptr); + } + if (finished_streams_.size() >= finished_capacity_) { + finished_streams_.pop_front(); + } + task_info.end_time_ms = autil::TimeUtility::currentTimeInMilliSeconds(); + task_info.error_code = error_code; + task_info.error_message = error_message; + int64_t version = version_.fetch_add(1, std::memory_order_relaxed); + finished_streams_.push_back(std::make_pair(version, task_info)); + } + protected: - // Note: finished_streams_ pairs are (monotonic_version, TaskInfo). - // The list is ordered by insertion time (append-only via push_back), so end_time_ms - // is approximately monotonic. The break below relies on this approximate time ordering - // to stop early — it does NOT use pair.first (version) for trimming decisions. void trimFinishedStreams() { auto current = autil::TimeUtility::currentTimeInMilliSeconds(); auto iter = finished_streams_.begin(); @@ -78,7 +147,7 @@ class RpcServerRuntimeMeta { } } } - std::unordered_map running_streams_; + std::unordered_map running_streams_; std::list> finished_streams_; std::atomic version_{0}; mutable std::shared_mutex read_write_lock_; diff --git a/rtp_llm/cpp/model_rpc/model_rpc_client.py b/rtp_llm/cpp/model_rpc/model_rpc_client.py index d143addee2..d0542e506b 100644 --- a/rtp_llm/cpp/model_rpc/model_rpc_client.py +++ b/rtp_llm/cpp/model_rpc/model_rpc_client.py @@ -1,6 +1,7 @@ import functools import json import logging +import os from typing import AsyncGenerator import grpc @@ -9,11 +10,14 @@ from rtp_llm.config.exceptions import ExceptionType, FtRuntimeException from rtp_llm.config.generate_config import RoleType from rtp_llm.cpp.model_rpc.proto.model_rpc_service_pb2 import ( + CancelRequestPB, ErrorDetailsPB, + FetchRequestPB, GenerateInputPB, GenerateOutputsPB, MultimodalInputPB, RoleAddrPB, + RoleTypePB, ) from rtp_llm.cpp.model_rpc.proto.model_rpc_service_pb2_grpc import RpcServiceStub from rtp_llm.server.request_headers import ( @@ -36,17 +40,13 @@ def __init__(self): self.cached_logits_dict = {} -def trans_role_type(role_type: RoleType) -> RoleAddrPB.RoleType: - if role_type == RoleType.PDFUSION: - return RoleAddrPB.RoleType.PDFUSION - elif role_type == RoleType.PREFILL: - return RoleAddrPB.RoleType.PREFILL - elif role_type == RoleType.DECODE: - return RoleAddrPB.RoleType.DECODE - elif role_type == RoleType.VIT: - return RoleAddrPB.RoleType.VIT - elif role_type == RoleType.FRONTEND: - return RoleAddrPB.RoleType.FRONTEND +def _is_finished_response(outputs_pb: GenerateOutputsPB) -> bool: + finished = outputs_pb.flatten_output.finished + return bool(finished) and all(finished) + + +def trans_role_type(role_type: RoleType) -> int: + return role_type.value def _trans_jsonable_option(config_pb, config, field_name): @@ -64,9 +64,9 @@ def trans_input(input_py: GenerateInput): input_pb = GenerateInputPB() input_pb.request_id = input_py.request_id input_pb.token_ids.extend(input_py.token_ids.reshape(-1).tolist()) - input_pb.batch_group_size = input_py.batch_group_size - if hasattr(input_py, "batch_group_id") and input_py.batch_group_id != -1: - input_pb.batch_group_id.value = input_py.batch_group_id + input_pb.group_size = input_py.group_size + if hasattr(input_py, "group_id") and input_py.group_id != -1: + input_pb.group_id.value = input_py.group_id request_info = getattr(input_py, "request_info", None) if request_info is not None: @@ -199,8 +199,7 @@ def trans_input(input_py: GenerateInput): trans_option_cast( generate_config_pb, input_py.generate_config, "trace_id", functools.partial(str) ) - trans_option(generate_config_pb, input_py.generate_config, "batch_group_timeout") - trans_option(generate_config_pb, input_py.generate_config, "force_batch") + trans_option(generate_config_pb, input_py.generate_config, "group_timeout") for i in range(len(input_py.generate_config.stop_words_list)): stop_words = generate_config_pb.stop_words_list.rows.add() @@ -445,18 +444,32 @@ async def enqueue( input_pb = trans_input(input_py) response_iterator = None stream_state = StreamState() - - address_list = self._addresses - - for role_addr in input_py.generate_config.role_addrs: - if ( - (self._decode_entrance and role_addr.role == RoleType.DECODE) - or role_addr.role == RoleType.PDFUSION - or (not self._decode_entrance and role_addr.role == RoleType.PREFILL) - ): - if role_addr.ip != "": - address_list = [role_addr.ip + ":" + str(role_addr.grpc_port)] - break + use_fetch_response = bool(getattr(input_py, "enqueued_by_master", False)) + + if use_fetch_response: + address_list = [ + role_addr.ip + ":" + str(role_addr.grpc_port) + for role_addr in input_py.generate_config.role_addrs + if role_addr.role == RoleType.PREFILL and role_addr.ip + ] + if os.environ.get("FLEXLB_EXPECT_FETCH_RESPONSE") == "1": + logging.info( + "FLEXLB_EXPECT_FETCH_RESPONSE request_id=%s using FetchResponse", + input_pb.request_id, + ) + else: + address_list = self._addresses + for role_addr in input_py.generate_config.role_addrs: + if ( + (self._decode_entrance and role_addr.role == RoleType.DECODE) + or role_addr.role == RoleType.PDFUSION + or ( + not self._decode_entrance and role_addr.role == RoleType.PREFILL + ) + ): + if role_addr.ip != "": + address_list = [role_addr.ip + ":" + str(role_addr.grpc_port)] + break if not address_list: raise ValueError(f"No address found for request: {input_pb.request_id}") @@ -467,16 +480,28 @@ async def enqueue( logging.debug( f"request: [{input_pb.request_id}] send to address: {target_address}" ) + stub = None + stream_done = False + terminal_seen = False try: # Get channel from pool channel = await self._channel_pool.get(target_address) stub = RpcServiceStub(channel) grpc_kwargs = {"timeout": effective_ms / 1000.0} if effective_ms > 0 else {} - response_iterator = stub.GenerateStreamCall(input_pb, **grpc_kwargs) + if use_fetch_response: + response_iterator = stub.FetchResponse( + FetchRequestPB(request_id=input_pb.request_id), **grpc_kwargs + ) + else: + response_iterator = stub.GenerateStreamCall(input_pb, **grpc_kwargs) # 调用服务器方法并接收流式响应 async for response in response_iterator.__aiter__(): - yield trans_output(input_py, response, stream_state) + output = trans_output(input_py, response, stream_state) + if use_fetch_response and _is_finished_response(response): + terminal_seen = True + yield output + stream_done = True except grpc.RpcError as e: # TODO(xinfei.sxf) 非流式的请求无法取消了 if response_iterator: @@ -528,5 +553,20 @@ async def enqueue( ) raise e finally: - if response_iterator: + should_cancel = not stream_done and not ( + use_fetch_response and terminal_seen + ) + if response_iterator and should_cancel: response_iterator.cancel() + if use_fetch_response and stub is not None and should_cancel: + try: + await stub.Cancel( + CancelRequestPB(request_id=input_pb.request_id), + timeout=5.0, + ) + except Exception: + logging.debug( + "request: [%s] best-effort Cancel failed", + input_pb.request_id, + exc_info=True, + ) diff --git a/rtp_llm/cpp/model_rpc/proto/model_rpc_service.proto b/rtp_llm/cpp/model_rpc/proto/model_rpc_service.proto index 05db9c1617..f1445482e3 100644 --- a/rtp_llm/cpp/model_rpc/proto/model_rpc_service.proto +++ b/rtp_llm/cpp/model_rpc/proto/model_rpc_service.proto @@ -25,15 +25,16 @@ message IntMatrix { repeated IntVector rows = 1; } +enum RoleTypePB { + ROLE_TYPE_PDFUSION = 0; + ROLE_TYPE_PREFILL = 1; + ROLE_TYPE_DECODE = 2; + ROLE_TYPE_VIT = 3; + ROLE_TYPE_FRONTEND = 4; +} + message RoleAddrPB { - enum RoleType { - PDFUSION = 0; - PREFILL = 1; - DECODE = 2; - VIT = 3; - FRONTEND = 4; - } - RoleType role = 1; + RoleTypePB role = 1; string ip = 2; int32 http_port = 3; int32 grpc_port = 4; @@ -96,8 +97,8 @@ message GenerateConfigPB { bool enable_device_cache = 52; bool enable_remote_cache = 53; string unique_key = 54; - google.protobuf.Int32Value force_batch = 55; - google.protobuf.Int32Value batch_group_timeout = 56; + reserved 55; + google.protobuf.Int32Value group_timeout = 56; string profile_trace_name = 57; repeated int32 begin_think_token_ids = 58; google.protobuf.StringValue json_schema = 59; @@ -152,8 +153,8 @@ message GenerateInputPB { GenerateConfigPB generate_config = 4; string client_id = 5; int64 start_time = 6; - int32 batch_group_size = 7; - google.protobuf.Int64Value batch_group_id = 8; + int32 group_size = 7; + google.protobuf.Int64Value group_id = 8; RequestInfoPB request_info = 9; } @@ -395,6 +396,13 @@ message UpdateEplbConfigRequestPB { int32 update_time = 2; } +enum TaskPhase { + TASK_PHASE_PENDING = 0; + TASK_PHASE_RECEIVED = 1; + TASK_PHASE_KV_ALLOCATED = 2; + TASK_PHASE_RUNNING = 3; +} + message TaskInfoPB { int64 request_id = 1; reserved 2; @@ -405,7 +413,11 @@ message TaskInfoPB { int64 iterate_count = 6; int64 end_time_ms = 7; int64 dp_rank = 8; - bool is_waiting = 9; + reserved 9; + reserved "is_waiting"; + ErrorDetailsPB error_info = 10; + int64 batch_id = 11; + TaskPhase phase = 12; } message UpdateWeightsRequestPB { @@ -415,7 +427,7 @@ message UpdateWeightsRequestPB { } message WorkerStatusPB { - string role = 1; + RoleTypePB role = 1; int32 available_concurrency = 2; repeated TaskInfoPB running_task_info = 3; repeated TaskInfoPB finished_task_list = 4; @@ -429,6 +441,9 @@ message WorkerStatusPB { bool alive = 13; string precision = 14; int64 latest_finished_version = 15; + int64 dp_rank = 16; + int64 available_kv_cache = 17; + int64 total_kv_cache = 18; } message EmbeddingInputPB { @@ -623,6 +638,68 @@ message FunctionResponsePB { } } +message CpuTpBroadcastRequestPB { + string group_key = 1; + uint64 seq = 2; + int32 root = 3; + int32 src_tp_rank = 4; + int32 dst_tp_rank = 5; + uint64 nbytes = 6; + bytes payload = 7; +} + +message CpuTpBroadcastResponsePB { + bool success = 1; + string error_message = 2; +} + +message EnqueueBatchExternalInputPB { + GenerateInputPB input = 1; +} + +message EnqueueBatchDpSlotPB { + int32 dp_rank = 1; + repeated EnqueueBatchExternalInputPB requests = 2; +} + +message EnqueueBatchRequestPB { + int64 batch_id = 1; + repeated EnqueueBatchDpSlotPB dp_slots = 2; +} + +message EnqueueGroupInputPB { + GenerateInputPB input = 1; +} + +message EnqueueGroupRequestPB { + int64 batch_id = 1; + int32 dp_rank = 2; + repeated EnqueueGroupInputPB requests = 3; +} + +message EnqueueBatchSuccessPB { + int64 request_id = 1; +} + +message EnqueueBatchErrorPB { + int64 request_id = 1; + ErrorDetailsPB error_info = 2; +} + +message EnqueueBatchResponsePB { + int64 batch_id = 1; + repeated EnqueueBatchSuccessPB successes = 2; + repeated EnqueueBatchErrorPB errors = 3; +} + +message FetchRequestPB { + int64 request_id = 1; +} + +message CancelRequestPB { + int64 request_id = 1; +} + service RpcService { rpc GetWorkerStatus(StatusVersionPB) returns (WorkerStatusPB); rpc GetCacheStatus(CacheVersionPB) returns (CacheStatusPB); @@ -646,8 +723,13 @@ service RpcService { rpc RemoteStore(RemoteStoreRequestPB) returns (RemoteStoreResponsePB); rpc ExecuteFunction(FunctionRequestPB) returns (FunctionResponsePB); + rpc CpuTpBroadcast(CpuTpBroadcastRequestPB) returns (CpuTpBroadcastResponsePB); rpc StartLoad(P2PConnectorStartLoadRequestPB) returns (P2PConnectorStartLoadResponsePB); rpc GetPeerInfo(GetPeerInfoRequestPB) returns (GetPeerInfoResponsePB); + rpc EnqueueBatch(EnqueueBatchRequestPB) returns (EnqueueBatchResponsePB); + rpc EnqueueGroup(EnqueueGroupRequestPB) returns (EnqueueBatchResponsePB); + rpc FetchResponse(FetchRequestPB) returns (stream GenerateOutputsPB); + rpc Cancel(CancelRequestPB) returns (EmptyPB); } service MultimodalRpcService { @@ -655,3 +737,46 @@ service MultimodalRpcService { rpc GetWorkerStatus(StatusVersionPB) returns (WorkerStatusPB); rpc GetCacheStatus(CacheVersionPB) returns (CacheStatusPB); } + +enum FlexlbScheduleModePB { + FLEXLB_SCHEDULE_AUTO = 0; + FLEXLB_SCHEDULE_BATCH = 1; + FLEXLB_SCHEDULE_DIRECT = 2; +} + +message FlexlbScheduleRequestPB { + int64 request_id = 1; + GenerateInputPB generate_input = 2; + repeated int64 block_cache_keys = 3; + int64 seq_len = 4; + int64 generate_timeout = 5; + int64 request_time_ms = 6; + int32 max_new_tokens = 7; + int32 num_beams = 8; + bool force_disable_sp_run = 9; + string model = 10; + string api_key = 11; + FlexlbScheduleModePB schedule_mode = 12; +} + +message FlexlbServerStatusPB { + RoleTypePB role = 1; + string server_ip = 2; + int32 http_port = 3; + int32 grpc_port = 4; +} + +message FlexlbScheduleResponsePB { + bool success = 1; + int32 code = 2; + string error_message = 3; + repeated FlexlbServerStatusPB server_status = 4; + string real_master_host = 5; + int32 queue_length = 6; + bool enqueued_by_master = 7; +} + +service FlexlbService { + rpc Schedule(FlexlbScheduleRequestPB) returns (FlexlbScheduleResponsePB); + rpc Cancel(CancelRequestPB) returns (EmptyPB); +} diff --git a/rtp_llm/cpp/model_rpc/test/BUILD b/rtp_llm/cpp/model_rpc/test/BUILD index 88f780a10c..1e2a0f6d6d 100644 --- a/rtp_llm/cpp/model_rpc/test/BUILD +++ b/rtp_llm/cpp/model_rpc/test/BUILD @@ -91,3 +91,29 @@ cc_test( "@com_google_googletest//:gtest_main", ], ) + +cc_test( + name = "response_buffer_test", + srcs = [ + "ResponseBufferTest.cc", + ], + copts = test_copts, + deps = test_deps, + env = { + "TEST_USING_DEVICE": "CUDA", + }, + exec_properties = {'gpu':'H20'}, +) + +cc_test( + name = "rpc_server_runtime_meta_test", + srcs = [ + "RpcServerRuntimeMetaTest.cc", + ], + copts = test_copts, + deps = test_deps, + env = { + "TEST_USING_DEVICE": "CUDA", + }, + exec_properties = {'gpu':'H20'}, +) diff --git a/rtp_llm/cpp/model_rpc/test/ResponseBufferTest.cc b/rtp_llm/cpp/model_rpc/test/ResponseBufferTest.cc new file mode 100644 index 0000000000..34c464250d --- /dev/null +++ b/rtp_llm/cpp/model_rpc/test/ResponseBufferTest.cc @@ -0,0 +1,247 @@ +#include +#include +#include +#include + +#include + +#include "rtp_llm/cpp/model_rpc/ResponseBuffer.h" + +namespace rtp_llm::test { + +TEST(ResponseBufferRegistryTest, CreateReturnsSameEntryForDuplicateId) { + ResponseBufferRegistry registry; + auto first = registry.createOrGet(42); + auto second = registry.createOrGet(42); + EXPECT_EQ(first.get(), second.get()); + EXPECT_EQ(registry.size(), 1u); +} + +TEST(ResponseBufferRegistryTest, ReserveReturnsNullForDuplicateId) { + ResponseBufferRegistry registry; + auto first = registry.reserve(42); + auto second = registry.reserve(42); + EXPECT_NE(first, nullptr); + EXPECT_EQ(second, nullptr); + EXPECT_EQ(registry.size(), 1u); +} + +TEST(ResponseBufferRegistryTest, GetReturnsNullWhenMissing) { + ResponseBufferRegistry registry; + EXPECT_EQ(registry.get(99), nullptr); + registry.createOrGet(99); + EXPECT_NE(registry.get(99), nullptr); +} + +TEST(ResponseBufferRegistryTest, EraseRemovesEntry) { + ResponseBufferRegistry registry; + registry.createOrGet(1); + EXPECT_EQ(registry.size(), 1u); + registry.erase(1); + EXPECT_EQ(registry.size(), 0u); + EXPECT_EQ(registry.get(1), nullptr); +} + +TEST(ResponseBufferRegistryTest, GcSkipsLiveAndDrainsTerminalIdle) { + ResponseBufferRegistry registry; + + auto alive = registry.createOrGet(1); + (void)alive; + auto done = registry.createOrGet(2); + done->done.store(true); + auto cancelled = registry.createOrGet(3); + cancelled->cancelled.store(true); + + EXPECT_EQ(registry.gc(std::chrono::hours(1)), 0u); + EXPECT_EQ(registry.size(), 3u); + + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + EXPECT_EQ(registry.gc(std::chrono::microseconds(0)), 2u); + EXPECT_EQ(registry.size(), 1u); + EXPECT_NE(registry.get(1), nullptr); +} + +TEST(ResponseBufferRegistryTest, GcSweepsTerminalEntryWithPendingQueueAfterTtl) { + ResponseBufferRegistry registry; + auto entry = registry.createOrGet(7); + entry->done.store(true); + { + std::lock_guard lock(entry->mu); + entry->queue.emplace_back(); + } + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + EXPECT_EQ(registry.gc(std::chrono::microseconds(0)), 1u); + EXPECT_EQ(registry.size(), 0u); + EXPECT_EQ(registry.get(7), nullptr); +} + +TEST(ResponseBufferRegistryTest, CancelAllMarksEntriesAndInvokesProducers) { + ResponseBufferRegistry registry; + auto first = registry.createOrGet(1); + auto second = registry.createOrGet(2); + int first_cancel_count = 0; + int second_cancel_count = 0; + + { + std::lock_guard lock(first->mu); + first->cancel_producer = [&] { ++first_cancel_count; }; + } + { + std::lock_guard lock(second->mu); + second->cancel_producer = [&] { ++second_cancel_count; }; + } + + registry.cancelAll(); + registry.cancelAll(); + + EXPECT_TRUE(first->cancelled.load()); + EXPECT_TRUE(second->cancelled.load()); + EXPECT_EQ(first_cancel_count, 1); + EXPECT_EQ(second_cancel_count, 1); +} + +TEST(ResponseBufferWriterTest, WritePushesAndNotifies) { + auto entry = std::make_shared(); + ResponseBufferWriter writer(entry); + + GenerateOutputsPB output; + output.set_request_id(123); + EXPECT_TRUE(writer.Write(output, grpc::WriteOptions{})); + + std::lock_guard lock(entry->mu); + ASSERT_EQ(entry->queue.size(), 1u); + EXPECT_EQ(entry->queue.front().request_id(), 123); +} + +TEST(ResponseBufferWriterTest, WriteReturnsFalseWhenCancelled) { + auto entry = std::make_shared(); + entry->cancelled.store(true); + ResponseBufferWriter writer(entry); + + GenerateOutputsPB output; + EXPECT_FALSE(writer.Write(output, grpc::WriteOptions{})); + std::lock_guard lock(entry->mu); + EXPECT_TRUE(entry->queue.empty()); +} + +TEST(ResponseBufferEntryTest, CancelProducerCanBeInvokedByOwner) { + auto entry = std::make_shared(); + bool observed = false; + { + std::lock_guard lock(entry->mu); + entry->cancel_producer = [&] { observed = true; }; + } + + std::function cancel_producer; + { + std::lock_guard lock(entry->mu); + cancel_producer = entry->cancel_producer; + } + ASSERT_TRUE(static_cast(cancel_producer)); + cancel_producer(); + EXPECT_TRUE(observed); +} + +TEST(ResponseBufferWriterTest, WriteWakesBlockedConsumer) { + auto entry = std::make_shared(); + ResponseBufferWriter writer(entry); + + GenerateOutputsPB observed; + std::thread consumer([&] { + std::unique_lock lock(entry->mu); + entry->cv.wait(lock, [&] { return !entry->queue.empty(); }); + observed = entry->queue.front(); + entry->queue.pop_front(); + }); + + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + GenerateOutputsPB output; + output.set_request_id(42); + EXPECT_TRUE(writer.Write(output, grpc::WriteOptions{})); + consumer.join(); + EXPECT_EQ(observed.request_id(), 42); +} + +TEST(AsyncSubmitChainTest, ProducerConsumerDrainOrderAndTerminate) { + ResponseBufferRegistry registry; + const int64_t request_id = 1001; + auto entry = registry.createOrGet(request_id); + + constexpr int kOutputCount = 8; + std::thread producer([entry] { + ResponseBufferWriter writer(entry); + for (int i = 0; i < kOutputCount; ++i) { + GenerateOutputsPB output; + output.set_request_id(i); + std::this_thread::sleep_for(std::chrono::milliseconds(2)); + ASSERT_TRUE(writer.Write(output, grpc::WriteOptions{})); + } + entry->done.store(true); + entry->cv.notify_all(); + }); + + std::vector observed; + bool terminal = false; + while (!terminal) { + std::deque drained; + { + std::unique_lock lock(entry->mu); + entry->cv.wait_for(lock, std::chrono::milliseconds(100), [&] { + return !entry->queue.empty() || entry->done.load() || entry->cancelled.load() + || entry->error_status.has_value(); + }); + drained.swap(entry->queue); + terminal = entry->done.load() || entry->cancelled.load() || entry->error_status.has_value(); + } + for (auto& output : drained) { + observed.push_back(output.request_id()); + } + } + producer.join(); + + ASSERT_EQ(observed.size(), static_cast(kOutputCount)); + for (int i = 0; i < kOutputCount; ++i) { + EXPECT_EQ(observed[i], i); + } + registry.erase(request_id); + EXPECT_EQ(registry.get(request_id), nullptr); +} + +TEST(AsyncSubmitChainTest, ErrorStatusPropagatesToConsumer) { + ResponseBufferRegistry registry; + auto entry = registry.createOrGet(3003); + + std::thread producer([entry] { + std::this_thread::sleep_for(std::chrono::milliseconds(5)); + std::lock_guard lock(entry->mu); + entry->error_status = grpc::Status(grpc::StatusCode::INTERNAL, "simulated finishStream failure"); + entry->done.store(true); + entry->cv.notify_all(); + }); + + grpc::Status terminal_status = grpc::Status::OK; + while (true) { + std::unique_lock lock(entry->mu); + entry->cv.wait_for(lock, std::chrono::milliseconds(100), [&] { + return entry->done.load() || entry->error_status.has_value(); + }); + if (entry->error_status.has_value()) { + terminal_status = *entry->error_status; + break; + } + if (entry->done.load()) { + break; + } + } + producer.join(); + + EXPECT_EQ(terminal_status.error_code(), grpc::StatusCode::INTERNAL); + EXPECT_NE(terminal_status.error_message().find("simulated finishStream failure"), std::string::npos); +} + +} // namespace rtp_llm::test + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/rtp_llm/cpp/model_rpc/test/RpcServerRuntimeMetaTest.cc b/rtp_llm/cpp/model_rpc/test/RpcServerRuntimeMetaTest.cc new file mode 100644 index 0000000000..3eb2cf2d33 --- /dev/null +++ b/rtp_llm/cpp/model_rpc/test/RpcServerRuntimeMetaTest.cc @@ -0,0 +1,61 @@ +#include + +#include "rtp_llm/cpp/model_rpc/RpcServerRuntimeMeta.h" + +namespace rtp_llm::test { + +TEST(RpcServerRuntimeMetaTest, EnqueuePendingReportsPendingPhase) { + RpcServerRuntimeMeta meta; + + meta.enqueuePending(/*request_id=*/101, /*input_length=*/2048); + + auto info = meta.getEngineScheduleInfo(/*latest_finished_version=*/-1); + ASSERT_EQ(info.running_task_info_list.size(), 1); + EXPECT_TRUE(info.finished_task_info_list.empty()); + EXPECT_EQ(info.running_task_info_list[0].request_id, 101); + EXPECT_EQ(info.running_task_info_list[0].input_length, 2048); + EXPECT_EQ(info.running_task_info_list[0].prefix_length, 0); + EXPECT_EQ(info.running_task_info_list[0].phase, TaskPhase::PENDING); +} + +TEST(RpcServerRuntimeMetaTest, FinishTaskMovesPendingToFinishedWithErrorDetails) { + RpcServerRuntimeMeta meta; + + meta.enqueuePending(/*request_id=*/202, /*input_length=*/1024); + meta.finishTask(/*request_id=*/202, + /*input_length=*/1024, + /*prefix_length=*/128, + /*error_code=*/13, + /*error_message=*/"decode alloc failed"); + + auto info = meta.getEngineScheduleInfo(/*latest_finished_version=*/-1); + EXPECT_TRUE(info.running_task_info_list.empty()); + ASSERT_EQ(info.finished_task_info_list.size(), 1); + const auto& finished = info.finished_task_info_list[0]; + EXPECT_EQ(finished.request_id, 202); + EXPECT_EQ(finished.input_length, 1024); + EXPECT_EQ(finished.prefix_length, 128); + EXPECT_EQ(finished.error_code, 13); + EXPECT_EQ(finished.error_message, "decode alloc failed"); + EXPECT_GE(info.latest_finished_version, 0); +} + +TEST(RpcServerRuntimeMetaTest, FinishTaskWithoutPendingStillReportsFailure) { + RpcServerRuntimeMeta meta; + + meta.finishTask(/*request_id=*/303, + /*input_length=*/512, + /*prefix_length=*/0, + /*error_code=*/14, + /*error_message=*/"remote load failed"); + + auto info = meta.getEngineScheduleInfo(/*latest_finished_version=*/-1); + ASSERT_EQ(info.finished_task_info_list.size(), 1); + const auto& finished = info.finished_task_info_list[0]; + EXPECT_EQ(finished.request_id, 303); + EXPECT_EQ(finished.input_length, 512); + EXPECT_EQ(finished.error_code, 14); + EXPECT_EQ(finished.error_message, "remote load failed"); +} + +} // namespace rtp_llm::test diff --git a/rtp_llm/cpp/model_rpc/test/model_rpc_client_test.py b/rtp_llm/cpp/model_rpc/test/model_rpc_client_test.py index 0bcbf61a9f..93c0442d4e 100644 --- a/rtp_llm/cpp/model_rpc/test/model_rpc_client_test.py +++ b/rtp_llm/cpp/model_rpc/test/model_rpc_client_test.py @@ -1,7 +1,8 @@ import asyncio import struct import sys -from unittest.mock import MagicMock +from enum import Enum +from unittest.mock import MagicMock, patch # Mock the ops module to avoid CUDA dependency in this unit test # This MUST be at the very top before any other imports, even before unittest @@ -9,9 +10,20 @@ mock_comm = MagicMock() mock_nccl_op = MagicMock() mock_compute_ops = MagicMock() + + +class _FakeRoleType(Enum): + PDFUSION = 0 + PREFILL = 1 + DECODE = 2 + VIT = 3 + FRONTEND = 4 + + mock_comm.nccl_op = mock_nccl_op mock_ops.comm = mock_comm mock_ops.compute_ops = mock_compute_ops +mock_ops.RoleType = _FakeRoleType sys.modules["rtp_llm.ops"] = mock_ops sys.modules["rtp_llm.ops.comm"] = mock_comm sys.modules["rtp_llm.ops.compute_ops"] = mock_compute_ops @@ -25,7 +37,7 @@ import torch -from rtp_llm.config.generate_config import GenerateConfig +from rtp_llm.config.generate_config import GenerateConfig, RoleAddr, RoleType from rtp_llm.config.log_config import setup_logging from rtp_llm.cpp.model_rpc.model_rpc_client import ( ModelRpcClient, @@ -38,7 +50,11 @@ GenerateOutputsPB, TensorPB, ) -from rtp_llm.utils.base_model_datatypes import GenerateInput, GenerateOutputs, RequestInfo +from rtp_llm.utils.base_model_datatypes import ( + GenerateInput, + GenerateOutputs, + RequestInfo, +) class FakeStub: @@ -87,9 +103,9 @@ class FakeModelRpcClient(ModelRpcClient): def __init__(self): # Call parent __init__ with minimal required parameters super().__init__( - [], # addresses: empty list for fake client - {}, # client_config: empty dict for fake client - 0, # max_rpc_timeout_ms + [], # addresses: empty list for fake client + {}, # client_config: empty dict for fake client + 0, # max_rpc_timeout_ms False, # decode_entrance ) self.stub = FakeStub() @@ -104,6 +120,67 @@ async def enqueue( yield trans_output(input_py, response_pb, stream_state) +class _FakeResponseIterator: + def __init__(self, responses): + self._responses = iter(responses) + self.cancelled = False + + def __aiter__(self): + return self + + async def __anext__(self): + try: + return next(self._responses) + except StopIteration: + raise StopAsyncIteration + + def cancel(self): + self.cancelled = True + + +class _FakeChannelPool: + def __init__(self): + self.targets = [] + + async def get(self, target_address): + self.targets.append(target_address) + return object() + + +class _RoutingStub: + def __init__(self, fetch_responses=None, generate_responses=None): + self.fetch_iterator = _FakeResponseIterator(fetch_responses or []) + self.generate_iterator = _FakeResponseIterator(generate_responses or []) + self.fetch_calls = [] + self.generate_calls = [] + self.cancel_calls = [] + + def FetchResponse(self, request, **kwargs): + self.fetch_calls.append((request, kwargs)) + return self.fetch_iterator + + def GenerateStreamCall(self, request, **kwargs): + self.generate_calls.append((request, kwargs)) + return self.generate_iterator + + async def Cancel(self, request, **kwargs): + self.cancel_calls.append((request, kwargs)) + + +def _make_response(finished=True): + outputs_pb = GenerateOutputsPB() + outputs_pb.flatten_output.finished.extend([finished]) + return outputs_pb + + +def _prefill_role_addr(ip="prefill", grpc_port=9000): + return RoleAddr(role=RoleType.PREFILL, ip=ip, http_port=8000, grpc_port=grpc_port) + + +def _decode_role_addr(ip="decode", grpc_port=9001): + return RoleAddr(role=RoleType.DECODE, ip=ip, http_port=8001, grpc_port=grpc_port) + + class ModelRpcClientTest(TestCase): def __init__(self, methodName: str = "runTest") -> None: @@ -238,6 +315,184 @@ def test_trans_input_request_info_trace_header_fallback(self): input_pb.request_info.request_id, "4bf92f3577b34da6a3ce929d0e0e4736" ) + def test_enqueue_fetches_response_when_master_already_enqueued(self): + client = ModelRpcClient( + addresses=["worker:9000"], + client_config={}, + max_rpc_timeout_ms=0, + decode_entrance=False, + ) + client._channel_pool = _FakeChannelPool() + stub = _RoutingStub(fetch_responses=[_make_response(finished=True)]) + input_py = GenerateInput( + token_ids=torch.tensor([1, 2, 3]), + generate_config=GenerateConfig( + timeout_ms=1000, + role_addrs=[_prefill_role_addr("prefill-worker", 9000)], + ), + request_id=321, + mm_inputs=[], + enqueued_by_master=True, + ) + + with patch( + "rtp_llm.cpp.model_rpc.model_rpc_client.RpcServiceStub", + return_value=stub, + ): + responses = asyncio.run(self._run(client, input_py)) + + self.assertEqual(len(responses), 1) + self.assertEqual(client._channel_pool.targets, ["prefill-worker:9000"]) + self.assertEqual(len(stub.fetch_calls), 1) + self.assertEqual(stub.fetch_calls[0][0].request_id, 321) + self.assertEqual(stub.fetch_calls[0][1]["timeout"], 1.0) + self.assertEqual(stub.generate_calls, []) + self.assertEqual(stub.cancel_calls, []) + + def test_enqueue_uses_generate_stream_without_master_enqueue(self): + client = ModelRpcClient( + addresses=["worker:9000"], + client_config={}, + max_rpc_timeout_ms=0, + decode_entrance=False, + ) + client._channel_pool = _FakeChannelPool() + stub = _RoutingStub(generate_responses=[_make_response(finished=True)]) + input_py = GenerateInput( + token_ids=torch.tensor([1, 2, 3]), + generate_config=GenerateConfig(timeout_ms=1000), + request_id=322, + mm_inputs=[], + ) + + with patch( + "rtp_llm.cpp.model_rpc.model_rpc_client.RpcServiceStub", + return_value=stub, + ): + responses = asyncio.run(self._run(client, input_py)) + + self.assertEqual(len(responses), 1) + self.assertEqual(len(stub.generate_calls), 1) + self.assertEqual(stub.generate_calls[0][0].request_id, 322) + self.assertEqual(stub.fetch_calls, []) + self.assertEqual(stub.cancel_calls, []) + + def test_enqueue_cancels_master_enqueued_fetch_on_early_close(self): + async def run_and_close(): + gen = client.enqueue(input_py) + first = await gen.__anext__() + await gen.aclose() + return first + + client = ModelRpcClient( + addresses=["worker:9000"], + client_config={}, + max_rpc_timeout_ms=0, + decode_entrance=False, + ) + client._channel_pool = _FakeChannelPool() + stub = _RoutingStub( + fetch_responses=[ + _make_response(finished=False), + _make_response(finished=True), + ] + ) + input_py = GenerateInput( + token_ids=torch.tensor([1, 2, 3]), + generate_config=GenerateConfig( + timeout_ms=1000, + role_addrs=[_prefill_role_addr("prefill-worker", 9000)], + ), + request_id=323, + mm_inputs=[], + enqueued_by_master=True, + ) + + with patch( + "rtp_llm.cpp.model_rpc.model_rpc_client.RpcServiceStub", + return_value=stub, + ): + asyncio.run(run_and_close()) + + self.assertTrue(stub.fetch_iterator.cancelled) + self.assertEqual(len(stub.cancel_calls), 1) + self.assertEqual(stub.cancel_calls[0][0].request_id, 323) + self.assertEqual(stub.cancel_calls[0][1]["timeout"], 5.0) + + def test_enqueue_fetch_cancel_uses_prefill_when_decode_entrance(self): + async def run_and_close(): + gen = client.enqueue(input_py) + await gen.__anext__() + await gen.aclose() + + client = ModelRpcClient( + addresses=["worker:9000"], + client_config={}, + max_rpc_timeout_ms=0, + decode_entrance=True, + ) + client._channel_pool = _FakeChannelPool() + stub = _RoutingStub(fetch_responses=[_make_response(finished=False)]) + input_py = GenerateInput( + token_ids=torch.tensor([1, 2, 3]), + generate_config=GenerateConfig( + timeout_ms=1000, + role_addrs=[ + _prefill_role_addr("prefill-worker", 9000), + _decode_role_addr("decode-worker", 9001), + ], + ), + request_id=325, + mm_inputs=[], + enqueued_by_master=True, + ) + + with patch( + "rtp_llm.cpp.model_rpc.model_rpc_client.RpcServiceStub", + return_value=stub, + ): + asyncio.run(run_and_close()) + + self.assertEqual(client._channel_pool.targets, ["prefill-worker:9000"]) + self.assertEqual(len(stub.fetch_calls), 1) + self.assertEqual(len(stub.cancel_calls), 1) + self.assertEqual(stub.cancel_calls[0][0].request_id, 325) + + def test_enqueue_does_not_cancel_after_finished_response_is_seen(self): + async def run_and_close_after_finished(): + gen = client.enqueue(input_py) + first = await gen.__anext__() + self.assertTrue(first.generate_outputs[0].finished) + await gen.aclose() + + client = ModelRpcClient( + addresses=["worker:9000"], + client_config={}, + max_rpc_timeout_ms=0, + decode_entrance=False, + ) + client._channel_pool = _FakeChannelPool() + stub = _RoutingStub(fetch_responses=[_make_response(finished=True)]) + input_py = GenerateInput( + token_ids=torch.tensor([1, 2, 3]), + generate_config=GenerateConfig( + timeout_ms=1000, + role_addrs=[_prefill_role_addr("prefill-worker", 9000)], + ), + request_id=324, + mm_inputs=[], + enqueued_by_master=True, + ) + + with patch( + "rtp_llm.cpp.model_rpc.model_rpc_client.RpcServiceStub", + return_value=stub, + ): + asyncio.run(run_and_close_after_finished()) + + self.assertFalse(stub.fetch_iterator.cancelled) + self.assertEqual(stub.cancel_calls, []) + if __name__ == "__main__": setup_logging() diff --git a/rtp_llm/cpp/normal_engine/BUILD b/rtp_llm/cpp/normal_engine/BUILD index 1f8a4a0ab5..6c0507f32e 100644 --- a/rtp_llm/cpp/normal_engine/BUILD +++ b/rtp_llm/cpp/normal_engine/BUILD @@ -24,6 +24,7 @@ cc_library( "//rtp_llm/cpp/engine_base/system_prompt:system_prompt_constructor", "//rtp_llm/cpp/models:eplb", "//rtp_llm/cpp/cuda_graph:cuda_graph_base", + "//rtp_llm/cpp/distribute:rpc_cpu_tp_broadcaster_hdr", "//rtp_llm/cpp/utils:tensor_debug_utils", "//rtp_llm/cpp/utils:profiling_scope", ] + select({ diff --git a/rtp_llm/cpp/normal_engine/NormalEngine.cc b/rtp_llm/cpp/normal_engine/NormalEngine.cc index a616afcaa8..5df27f386f 100644 --- a/rtp_llm/cpp/normal_engine/NormalEngine.cc +++ b/rtp_llm/cpp/normal_engine/NormalEngine.cc @@ -520,7 +520,7 @@ std::shared_ptr NormalEngine::enqueue(const std::shared_ptr> -NormalEngine::batchEnqueue(const std::vector>& inputs) { +NormalEngine::enqueueMultiple(const std::vector>& inputs) { std::vector> streams; streams.reserve(inputs.size()); for (auto& inp : inputs) { @@ -529,7 +529,7 @@ NormalEngine::batchEnqueue(const std::vector>& in stream->setReserveStep(reserve_step_); streams.push_back(stream); } - return scheduler_->batchEnqueue(streams); + return scheduler_->enqueueGroup(streams); } absl::Status NormalEngine::step() { diff --git a/rtp_llm/cpp/normal_engine/NormalEngine.h b/rtp_llm/cpp/normal_engine/NormalEngine.h index 9354839ebe..142d22eb12 100644 --- a/rtp_llm/cpp/normal_engine/NormalEngine.h +++ b/rtp_llm/cpp/normal_engine/NormalEngine.h @@ -24,10 +24,10 @@ class NormalEngine: public EngineBase { NormalEngine(const EngineInitParams& params, std::unique_ptr propose_params); ~NormalEngine(); - std::shared_ptr makeStream(const std::shared_ptr& input) override; - std::shared_ptr enqueue(const std::shared_ptr& input) override; - std::vector batchEnqueue(const std::vector>& inputs) override; - void enqueue(std::shared_ptr& stream) override; + std::shared_ptr makeStream(const std::shared_ptr& input) override; + std::shared_ptr enqueue(const std::shared_ptr& input) override; + std::vector enqueueMultiple(const std::vector>& inputs) override; + void enqueue(std::shared_ptr& stream) override; absl::StatusOr preRun(const std::shared_ptr& generate_input, preRunMode mode) override; absl::Status stop() override; diff --git a/rtp_llm/cpp/normal_engine/NormalExecutor.cc b/rtp_llm/cpp/normal_engine/NormalExecutor.cc index ff3b15694a..4c0c690b26 100644 --- a/rtp_llm/cpp/normal_engine/NormalExecutor.cc +++ b/rtp_llm/cpp/normal_engine/NormalExecutor.cc @@ -13,6 +13,7 @@ #include "rtp_llm/cpp/models/Sampler.h" #include "rtp_llm/cpp/config/ModelConfig.h" #include "rtp_llm/cpp/models/logits_processor/LogitsProcessorFactory.h" +#include "rtp_llm/cpp/distribute/RpcCpuTpBroadcaster.h" using namespace std; @@ -27,6 +28,19 @@ bool readEnvFlagOnce(const char* env_name, const char* log_tag, const char* labe return on; } +int readEnvIntOnce(const char* env_name, int default_value, const char* log_tag) { + const char* env = std::getenv(env_name); + int value = default_value; + if (env != nullptr) { + value = std::atoi(env); + if (value <= 0) { + value = default_value; + } + } + RTP_LLM_LOG_INFO("[%s] %s=%s -> %d", log_tag, env_name, env ? env : "(unset)", value); + return value; +} + void holdSamplerInputHostBuffers(TensorHolder& holder, const SamplerInputs& inputs) { holder.hold_host(inputs.token_ids); holder.hold_host(inputs.input_lengths); @@ -89,6 +103,19 @@ NormalExecutor::NormalExecutor(const EngineInitParams& params, metrics_reporter_); } + const bool enable_cross_node_cpu_tp_broadcast = + readEnvFlagOnce("RTP_LLM_CROSS_NODE_CPU_TP_BROADCAST", "NormalExecutor", "cross_node_cpu_tp_broadcast"); + if (enable_cross_node_cpu_tp_broadcast && params.parallelism_config.tp_size > 1 + && params.parallelism_config.tp_size > params.parallelism_config.local_world_size) { + const int timeout_ms = readEnvIntOnce("RTP_LLM_CPU_TP_BROADCAST_TIMEOUT_MS", 30000, "NormalExecutor"); + RpcCpuTpBroadcaster::instance().initialize(static_cast(params.parallelism_config.tp_rank), + static_cast(params.parallelism_config.tp_size), + static_cast(params.parallelism_config.dp_rank), + static_cast(params.parallelism_config.world_size), + params.runtime_config.worker_grpc_addrs, + timeout_ms); + } + if (params.eplb_config.enable_eplb() && params.model_config_.moe_style != 0) { // use first moe layer weight as moe weight type int first_moe_layer = params.model_config_.moe_layer_index.front(); @@ -274,9 +301,54 @@ absl::Status NormalExecutor::process(const std::list& streams stream_groups.totalDecodeBatchSize(), stream_groups.modelExecuteTokenSize(), stream_groups.maxSeqLen()); + if (tp_rank_ == 0 && stream_groups.totalContextBatchSize() > 0) { + std::string details; + for (auto& s : stream_groups.contextStreams()) { + char buf[256]; + snprintf(buf, + sizeof(buf), + "{id=%ld trace_id=%s input=%d prefix=%d reuse=%d ctx=%d grp=%ld/%d tokens=%d} ", + s->streamId(), + s->traceId().empty() ? "-" : s->traceId().c_str(), + s->inputLength(), + s->prefixLength(), + s->reuseLength(), + s->contextLength(), + s->groupId(), + s->groupSize(), + s->currentExecuteTokenSize()); + details += buf; + } + RTP_LLM_LOG_INFO( + "prefill_batch_begin: ctx_batch=%zu gen_batch=%zu total_tokens=%zu max_seq=%zu streams=[%s]", + stream_groups.totalContextBatchSize(), + stream_groups.totalDecodeBatchSize(), + stream_groups.modelExecuteTokenSize(), + stream_groups.maxSeqLen(), + details.c_str()); + } int64_t start_time_us = autil::TimeUtility::currentTimeInMicroSeconds(); model_output = std::move(model_->forward(model_input)); executor_collector.model_forward_us = autil::TimeUtility::currentTimeInMicroSeconds() - start_time_us; + if (tp_rank_ == 0 && stream_groups.totalContextBatchSize() > 0) { + RTP_LLM_LOG_INFO("prefill_batch_end: ctx_batch=%zu total_tokens=%zu forward_us=%ld", + stream_groups.totalContextBatchSize(), + stream_groups.modelExecuteTokenSize(), + executor_collector.model_forward_us); + } + if (tp_rank_ == 0 && stream_groups.totalDecodeBatchSize() > 0) { + std::string details; + for (auto& s : stream_groups.decodeStreams()) { + char buf[256]; + snprintf(buf, sizeof(buf), "{id=%ld trace_id=%s seq=%d tokens=%d} ", + s->streamId(), s->traceId().empty() ? "-" : s->traceId().c_str(), + s->seqLength(), s->currentExecuteTokenSize()); + details += buf; + } + RTP_LLM_ACCESS_LOG_INFO("decode_step_begin: gen_batch=%zu total_tokens=%zu streams=[%s]", + stream_groups.totalDecodeBatchSize(), + stream_groups.modelExecuteTokenSize(), details.c_str()); + } } if (expert_balancer_) { int64_t start_time_us = autil::TimeUtility::currentTimeInMicroSeconds(); diff --git a/rtp_llm/cpp/normal_engine/NormalOutputDispatcher.cc b/rtp_llm/cpp/normal_engine/NormalOutputDispatcher.cc index e1192f09f7..6c7720f405 100644 --- a/rtp_llm/cpp/normal_engine/NormalOutputDispatcher.cc +++ b/rtp_llm/cpp/normal_engine/NormalOutputDispatcher.cc @@ -222,6 +222,7 @@ void NormalOutputDispatcher::dispatchSingleStream(GenerateStreamPtr stream, RTP_LLM_LOG_DEBUG("stream [%ld], new_tokens size = [%ld]", stream->streamId(), new_tokens.numel()); + size_t old_token_len = stream->outputTokenLen(); stream->update({has_beam_search ? batch_new_all_token_ids : new_tokens, 1, batch_hidden_states, @@ -232,6 +233,24 @@ void NormalOutputDispatcher::dispatchSingleStream(GenerateStreamPtr stream, loss, src_batch_indices, all_hidden_states}); + + if (old_token_len == 0 && stream->outputTokenLen() > 0) { + auto ti = stream->getTimeInfo(); + RTP_LLM_ACCESS_LOG_INFO("first_token: %s latency_us=%ld", + stream->streamLogTag().c_str(), + ti.first_token_time_us); + } + if (stream->isFinished() || stream->needFinish()) { + auto ti = stream->getTimeInfo(); + int64_t now_us = autil::TimeUtility::currentTimeInMicroSeconds(); + RTP_LLM_ACCESS_LOG_INFO( + "decode_finished: %s output_len=%ld first_token_latency_us=%ld " + "total_latency_us=%ld iter_count=%ld", + stream->streamLogTag().c_str(), stream->outputTokenLen(), + ti.first_token_time_us, + ti.begin_time_us > 0 ? (now_us - ti.begin_time_us) : 0, + stream->iterCount()); + } } } // namespace rtp_llm diff --git a/rtp_llm/cpp/normal_engine/speculative/MtpExecutor.cc b/rtp_llm/cpp/normal_engine/speculative/MtpExecutor.cc index 4dff6e19d1..e4ffe1f687 100644 --- a/rtp_llm/cpp/normal_engine/speculative/MtpExecutor.cc +++ b/rtp_llm/cpp/normal_engine/speculative/MtpExecutor.cc @@ -778,6 +778,31 @@ absl::Status MtpExecutor::prefillStep(const std::list& stream saved_input_lengths = toCudaWithHostHold(model_input.input_lengths, buffer_holder_); } + if (isTpRank0() && stream_groups.totalContextBatchSize() > 0) { + std::string details; + for (auto& s : stream_groups.contextStreams()) { + char buf[256]; + snprintf(buf, + sizeof(buf), + "{id=%ld input=%d prefix=%d reuse=%d ctx=%d grp=%ld/%d tokens=%d} ", + s->streamId(), + s->inputLength(), + s->prefixLength(), + s->reuseLength(), + s->contextLength(), + s->groupId(), + s->groupSize(), + s->currentExecuteTokenSize()); + details += buf; + } + RTP_LLM_LOG_INFO("prefill_batch_begin: ctx_batch=%zu gen_batch=%zu total_tokens=%zu max_seq=%zu streams=[%s]", + stream_groups.totalContextBatchSize(), + stream_groups.totalDecodeBatchSize(), + stream_groups.modelExecuteTokenSize(), + stream_groups.maxSeqLen(), + details.c_str()); + } + // target model prefill { RTP_LLM_PROFILE_SCOPE("executor.mtp.prefill_step(target_model_forward)"); @@ -844,6 +869,13 @@ absl::Status MtpExecutor::prefillStep(const std::list& stream model_forward_us += autil::TimeUtility::currentTimeInMicroSeconds() - start_time_us; } + if (isTpRank0() && stream_groups.totalContextBatchSize() > 0) { + RTP_LLM_LOG_INFO("prefill_batch_end: ctx_batch=%zu total_tokens=%zu forward_us=%ld", + stream_groups.totalContextBatchSize(), + stream_groups.modelExecuteTokenSize(), + model_forward_us); + } + if (!isTpRank0() || warm_up_ || streams.size() == 0 || model_input.is_fake_stream) { cudaSyncAndCheck(); return absl::OkStatus(); diff --git a/rtp_llm/cpp/pybind/BUILD b/rtp_llm/cpp/pybind/BUILD index 835abbe239..2a2ef0fa9b 100644 --- a/rtp_llm/cpp/pybind/BUILD +++ b/rtp_llm/cpp/pybind/BUILD @@ -59,6 +59,7 @@ cc_library( "//rtp_llm/cpp/cache:batch_kv_cache_resource", "//rtp_llm/cpp/cache:kv_cache_transfer_planner", "//rtp_llm/cpp/distribute:cpu_tp_broadcaster", + "//rtp_llm/cpp/distribute:rpc_cpu_tp_broadcaster", "//rtp_llm/cpp/utils:core_utils", "//rtp_llm/cpp/utils:debug_utils", "//rtp_llm/cpp/utils:device_perf_wrapper", diff --git a/rtp_llm/cpp/pybind/ConfigInit.cc b/rtp_llm/cpp/pybind/ConfigInit.cc index 5edac35785..f808593184 100644 --- a/rtp_llm/cpp/pybind/ConfigInit.cc +++ b/rtp_llm/cpp/pybind/ConfigInit.cc @@ -462,9 +462,8 @@ PYBIND11_MODULE(libth_transformer_config, m) { self.enable_dsv4_state_block_independent_eviction); }, [](py::tuple t) { - const bool has_disk_fields = - t.size() >= 50 && py::isinstance(t[9]); - const size_t min_size = has_disk_fields ? 50u : 45u; + const bool has_disk_fields = t.size() >= 50 && py::isinstance(t[9]); + const size_t min_size = has_disk_fields ? 50u : 45u; if (t.size() < min_size) throw std::runtime_error("Invalid state!"); KVCacheConfig c; @@ -486,43 +485,43 @@ PYBIND11_MODULE(libth_transformer_config, m) { c.memory_cache_disk_sync_timeout_ms = t[12].cast(); offset = 5; } - c.linear_step = t[8 + offset].cast(); - c.int8_kv_cache = t[9 + offset].cast(); - c.fp8_kv_cache = t[10 + offset].cast(); - c.kv_cache_mem_mb = t[11 + offset].cast(); - c.seq_size_per_block = t[12 + offset].cast(); - c.kernel_seq_size_per_block = t[13 + offset].cast(); - c.test_block_num = t[14 + offset].cast(); - c.use_block_cache = t[15 + offset].cast(); - c.enable_device_cache = t[16 + offset].cast(); - c.enable_memory_cache = t[17 + offset].cast(); - c.enable_memory_cache_sm_copy = t[18 + offset].cast(); - c.enable_remote_cache = t[19 + offset].cast(); - c.write_cache_sync = t[20 + offset].cast(); - c.enable_tiered_memory_cache = t[21 + offset].cast(); - c.device_cache_min_free_blocks = t[22 + offset].cast(); - c.reco_enable_vipserver = t[23 + offset].cast(); - c.reco_vipserver_domain = t[24 + offset].cast(); - c.reco_server_address = t[25 + offset].cast(); - c.reco_instance_group = t[26 + offset].cast(); - c.reco_meta_channel_retry_time = t[27 + offset].cast(); - c.reco_meta_channel_connection_timeout = t[28 + offset].cast(); - c.reco_meta_channel_call_timeout = t[29 + offset].cast(); - c.reco_storage_thread_num = t[30 + offset].cast(); - c.reco_storage_queue_size = t[31 + offset].cast(); - c.reco_put_timeout_ms = t[32 + offset].cast(); - c.reco_get_timeout_ms = t[33 + offset].cast(); - c.reco_model_sdk_config = t[34 + offset].cast(); - c.reco_model_user_data = t[35 + offset].cast(); - c.reco_model_extra_info = t[36 + offset].cast(); - c.reco_instance_id_salt = t[37 + offset].cast(); - c.reco_asyncwrapper_thread_num = t[38 + offset].cast(); - c.reco_asyncwrapper_queue_size = t[39 + offset].cast(); - c.reco_get_broadcast_timeout = t[40 + offset].cast(); - c.reco_put_broadcast_timeout = t[41 + offset].cast(); - c.reco_client_config = t[42 + offset].cast(); - c.ssm_state_dtype = t[43 + offset].cast(); - c.dsv4_fixed_pool_blocks = t[44 + offset].cast(); + c.linear_step = t[8 + offset].cast(); + c.int8_kv_cache = t[9 + offset].cast(); + c.fp8_kv_cache = t[10 + offset].cast(); + c.kv_cache_mem_mb = t[11 + offset].cast(); + c.seq_size_per_block = t[12 + offset].cast(); + c.kernel_seq_size_per_block = t[13 + offset].cast(); + c.test_block_num = t[14 + offset].cast(); + c.use_block_cache = t[15 + offset].cast(); + c.enable_device_cache = t[16 + offset].cast(); + c.enable_memory_cache = t[17 + offset].cast(); + c.enable_memory_cache_sm_copy = t[18 + offset].cast(); + c.enable_remote_cache = t[19 + offset].cast(); + c.write_cache_sync = t[20 + offset].cast(); + c.enable_tiered_memory_cache = t[21 + offset].cast(); + c.device_cache_min_free_blocks = t[22 + offset].cast(); + c.reco_enable_vipserver = t[23 + offset].cast(); + c.reco_vipserver_domain = t[24 + offset].cast(); + c.reco_server_address = t[25 + offset].cast(); + c.reco_instance_group = t[26 + offset].cast(); + c.reco_meta_channel_retry_time = t[27 + offset].cast(); + c.reco_meta_channel_connection_timeout = t[28 + offset].cast(); + c.reco_meta_channel_call_timeout = t[29 + offset].cast(); + c.reco_storage_thread_num = t[30 + offset].cast(); + c.reco_storage_queue_size = t[31 + offset].cast(); + c.reco_put_timeout_ms = t[32 + offset].cast(); + c.reco_get_timeout_ms = t[33 + offset].cast(); + c.reco_model_sdk_config = t[34 + offset].cast(); + c.reco_model_user_data = t[35 + offset].cast(); + c.reco_model_extra_info = t[36 + offset].cast(); + c.reco_instance_id_salt = t[37 + offset].cast(); + c.reco_asyncwrapper_thread_num = t[38 + offset].cast(); + c.reco_asyncwrapper_queue_size = t[39 + offset].cast(); + c.reco_get_broadcast_timeout = t[40 + offset].cast(); + c.reco_put_broadcast_timeout = t[41 + offset].cast(); + c.reco_client_config = t[42 + offset].cast(); + c.ssm_state_dtype = t[43 + offset].cast(); + c.dsv4_fixed_pool_blocks = t[44 + offset].cast(); const size_t expected_with_fixed_pool_memory = (has_disk_fields ? 51u : 46u); if (t.size() >= expected_with_fixed_pool_memory) { c.dsv4_fixed_pool_use_memory = t[45 + offset].cast(); @@ -542,9 +541,8 @@ PYBIND11_MODULE(libth_transformer_config, m) { c.enable_prefix_tree_memory_cache = t[extra_start + 2].cast(); c.enable_legacy_memory_connector_fallback = t[extra_start + 3].cast(); if (extra_count >= 6) { - c.prefix_tree_memory_state_swa_pool_ratio = t[extra_start + 4].cast(); - c.enable_dsv4_state_block_independent_eviction = - t[extra_start + 5].cast(); + c.prefix_tree_memory_state_swa_pool_ratio = t[extra_start + 4].cast(); + c.enable_dsv4_state_block_independent_eviction = t[extra_start + 5].cast(); } } } @@ -1216,9 +1214,9 @@ PYBIND11_MODULE(libth_transformer_config, m) { throw std::runtime_error("Invalid state!"); FIFOSchedulerConfig c; try { - c.max_context_batch_size = t[0].cast(); - c.max_batch_tokens_size = t[1].cast(); - c.cp_force_single_prefill = t.size() >= 3 ? t[2].cast() : true; + c.max_context_batch_size = t[0].cast(); + c.max_batch_tokens_size = t[1].cast(); + c.cp_force_single_prefill = t.size() >= 3 ? t[2].cast() : true; c.max_inited_kv_cache_streams = t.size() >= 4 ? t[3].cast() : 0; } catch (const std::exception& e) { throw std::runtime_error(std::string("FIFOSchedulerConfig unpickle error: ") + e.what()); @@ -1273,6 +1271,7 @@ PYBIND11_MODULE(libth_transformer_config, m) { .def_readwrite("model_name", &RuntimeConfig::model_name) .def_readwrite("worker_grpc_addrs", &RuntimeConfig::worker_grpc_addrs) .def_readwrite("worker_addrs", &RuntimeConfig::worker_addrs) + .def_readwrite("all_worker_grpc_addrs", &RuntimeConfig::all_worker_grpc_addrs) // Fields merged from PyDeviceResourceConfig .def_readwrite("specify_gpu_arch", &RuntimeConfig::specify_gpu_arch) // Add sub-configs as properties that return references @@ -1299,10 +1298,11 @@ PYBIND11_MODULE(libth_transformer_config, m) { self.model_name, self.worker_grpc_addrs, self.worker_addrs, + self.all_worker_grpc_addrs, self.specify_gpu_arch); }, [](py::tuple t) { - if (t.size() != 13) + if (t.size() < 14) throw std::runtime_error("Invalid state!"); RuntimeConfig c; try { @@ -1318,7 +1318,8 @@ PYBIND11_MODULE(libth_transformer_config, m) { c.model_name = t[9].cast(); c.worker_grpc_addrs = t[10].cast>(); c.worker_addrs = t[11].cast>(); - c.specify_gpu_arch = t[12].cast(); + c.all_worker_grpc_addrs = t[12].cast>(); + c.specify_gpu_arch = t[13].cast(); } catch (const std::exception& e) { throw std::runtime_error(std::string("RuntimeConfig unpickle error: ") + e.what()); } @@ -1715,6 +1716,12 @@ PYBIND11_MODULE(libth_transformer_config, m) { .def_readwrite("max_rpc_timeout_ms", &PDSepConfig::max_rpc_timeout_ms) .def_readwrite("worker_port_offset", &PDSepConfig::worker_port_offset) .def_readwrite("decode_entrance", &PDSepConfig::decode_entrance) + .def_readwrite("batch_dispatch_timeout_ms", &PDSepConfig::batch_dispatch_timeout_ms) + .def_readwrite("batch_prepare_timeout_ms", &PDSepConfig::batch_prepare_timeout_ms) + .def_readwrite("batch_load_timeout_ms", &PDSepConfig::batch_load_timeout_ms) + .def_readwrite("prefill_enqueue_pool_size", &PDSepConfig::prefill_enqueue_pool_size) + .def_readwrite("prefill_worker_lambda_pool_size", &PDSepConfig::prefill_worker_lambda_pool_size) + .def_readwrite("prefill_slot_pool_size", &PDSepConfig::prefill_slot_pool_size) .def("to_string", &PDSepConfig::to_string) .def(py::pickle( [](const PDSepConfig& self) { @@ -1737,10 +1744,16 @@ PYBIND11_MODULE(libth_transformer_config, m) { self.load_cache_timeout_ms, self.max_rpc_timeout_ms, self.worker_port_offset, - self.decode_entrance); + self.decode_entrance, + self.batch_dispatch_timeout_ms, + self.batch_prepare_timeout_ms, + self.batch_load_timeout_ms, + self.prefill_enqueue_pool_size, + self.prefill_worker_lambda_pool_size, + self.prefill_slot_pool_size); }, [](py::tuple t) { - if (t.size() != 20) + if (t.size() < 26) throw std::runtime_error("Invalid state!"); PDSepConfig c; try { @@ -1764,6 +1777,12 @@ PYBIND11_MODULE(libth_transformer_config, m) { c.max_rpc_timeout_ms = t[17].cast(); c.worker_port_offset = t[18].cast(); c.decode_entrance = t[19].cast(); + c.batch_dispatch_timeout_ms = t[20].cast(); + c.batch_prepare_timeout_ms = t[21].cast(); + c.batch_load_timeout_ms = t[22].cast(); + c.prefill_enqueue_pool_size = t[23].cast(); + c.prefill_worker_lambda_pool_size = t[24].cast(); + c.prefill_slot_pool_size = t[25].cast(); } catch (const std::exception& e) { throw std::runtime_error(std::string("PDSepConfig unpickle error: ") + e.what()); } diff --git a/rtp_llm/cpp/pybind/multi_gpu_gpt/RtpLLMOp.cc b/rtp_llm/cpp/pybind/multi_gpu_gpt/RtpLLMOp.cc index a9cb7a99c2..03ae4a35df 100644 --- a/rtp_llm/cpp/pybind/multi_gpu_gpt/RtpLLMOp.cc +++ b/rtp_llm/cpp/pybind/multi_gpu_gpt/RtpLLMOp.cc @@ -6,6 +6,7 @@ #include "c10/util/intrusive_ptr.h" #include #include +#include "autil/EnvUtil.h" #include "rtp_llm/cpp/metrics/RtpLLMMetrics.h" #include "rtp_llm/cpp/utils/AssertUtils.h" #include "rtp_llm/cpp/config/ConfigModules.h" @@ -305,11 +306,18 @@ void RtpLLMOp::initRPCServer(const EngineInitParams maga_ std::unique_ptr propose_params, py::object token_processor) { std::string server_address; + int64_t http_port = 0; + int64_t model_rpc_port = 0; + bool start_grpc_before_init = false; { pybind11::gil_scoped_acquire acquire; - int64_t http_port = maga_init_params.server_config.attr("http_port").cast(); - int64_t model_rpc_port = maga_init_params.server_config.attr("rpc_server_port").cast(); - auto role_type = maga_init_params.pd_sep_config.role_type; + http_port = maga_init_params.server_config.attr("http_port").cast(); + model_rpc_port = maga_init_params.server_config.attr("rpc_server_port").cast(); + auto role_type = maga_init_params.pd_sep_config.role_type; + start_grpc_before_init = model_rpc_port >= 0 + && autil::EnvUtil::getEnv("RTP_LLM_CROSS_NODE_CPU_TP_BROADCAST", false) + && maga_init_params.parallelism_config.tp_size + > maga_init_params.parallelism_config.local_world_size; // NOTE: ip/ip段可自定义为所需范围。 server_address = "0.0.0.0:" + std::to_string(model_rpc_port); if (role_type == RoleType::PREFILL || role_type == RoleType::DECODE) { @@ -317,22 +325,26 @@ void RtpLLMOp::initRPCServer(const EngineInitParams maga_ } else { model_rpc_service_.reset(new LocalRpcServiceImpl()); } - grpc::Status grpc_status = - model_rpc_service_->init(maga_init_params, std::move(mm_process_engine), std::move(propose_params)); - if (!grpc_status.ok()) { - RTP_LLM_FAIL("init rpc server failed, error msg: %s", grpc_status.error_message().c_str()); - } + if (start_grpc_before_init) { + model_rpc_service_->prepareLocalServer(); + } else { + grpc::Status grpc_status = + model_rpc_service_->init(maga_init_params, std::move(mm_process_engine), std::move(propose_params)); + if (!grpc_status.ok()) { + RTP_LLM_FAIL("init rpc server failed, error msg: %s", grpc_status.error_message().c_str()); + } - // NOTE: ip/ip段可自定义为所需范围。 - std::string http_server_address("tcp:0.0.0.0:" + std::to_string(http_port)); - http_server_.reset(new HttpApiServer(model_rpc_service_->getEngine(), - model_rpc_service_->getMultimodalProcessor(), - http_server_address, - maga_init_params, - token_processor)); - if (model_rpc_port < 0) { - is_server_ready_ = true; - return; + // NOTE: ip/ip段可自定义为所需范围。 + std::string http_server_address("tcp:0.0.0.0:" + std::to_string(http_port)); + http_server_.reset(new HttpApiServer(model_rpc_service_->getEngine(), + model_rpc_service_->getMultimodalProcessor(), + http_server_address, + maga_init_params, + token_processor)); + if (model_rpc_port < 0) { + is_server_ready_ = true; + return; + } } } grpc::ServerBuilder builder; @@ -353,6 +365,21 @@ void RtpLLMOp::initRPCServer(const EngineInitParams maga_ RTP_LLM_CHECK_WITH_INFO(grpc_server_ != nullptr, "grpc server start failed at address " + server_address); RTP_LLM_LOG_INFO("Server listening on %s", server_address.c_str()); + if (start_grpc_before_init) { + pybind11::gil_scoped_acquire acquire; + grpc::Status grpc_status = + model_rpc_service_->init(maga_init_params, std::move(mm_process_engine), std::move(propose_params)); + if (!grpc_status.ok()) { + RTP_LLM_FAIL("init rpc server failed, error msg: %s", grpc_status.error_message().c_str()); + } + + std::string http_server_address("tcp:0.0.0.0:" + std::to_string(http_port)); + http_server_.reset(new HttpApiServer(model_rpc_service_->getEngine(), + model_rpc_service_->getMultimodalProcessor(), + http_server_address, + maga_init_params, + token_processor)); + } is_server_ready_ = true; grpc_server_->Wait(); RTP_LLM_LOG_INFO("Server exit on %s", server_address.c_str()); diff --git a/rtp_llm/flexlb/CLAUDE.md b/rtp_llm/flexlb/CLAUDE.md index 5c7d0d5a2c..6782d983a6 100644 --- a/rtp_llm/flexlb/CLAUDE.md +++ b/rtp_llm/flexlb/CLAUDE.md @@ -169,8 +169,8 @@ The `DefaultRouter` orchestrates routing across these stages. If a later stage f Three strategies are available (registered with `LoadBalanceStrategyFactory`): - **RANDOM**: Random worker selection -- **SHORTEST_TTFT**: Select worker with shortest Time-To-First-Token -- **WEIGHTED_CACHE**: Cache-aware selection prioritizing workers with matching KV cache blocks +- **COST_BASED_PREFILL**: Select worker with lowest cost for prefill requests +- **COST_BASED_DECODE**: Select worker with lowest cost for decode requests Each `RoleType` can use a different strategy. See `LoadBalanceStrategyEnum` in flexlb-common. diff --git a/rtp_llm/flexlb/flexlb-api/src/main/java/org/flexlb/httpserver/FlexlbGrpcForwarder.java b/rtp_llm/flexlb/flexlb-api/src/main/java/org/flexlb/httpserver/FlexlbGrpcForwarder.java new file mode 100644 index 0000000000..10748be574 --- /dev/null +++ b/rtp_llm/flexlb/flexlb-api/src/main/java/org/flexlb/httpserver/FlexlbGrpcForwarder.java @@ -0,0 +1,100 @@ +package org.flexlb.httpserver; + +import io.grpc.ManagedChannel; +import io.grpc.StatusRuntimeException; +import io.grpc.netty.NettyChannelBuilder; +import io.netty.channel.EventLoopGroup; +import io.netty.channel.socket.nio.NioSocketChannel; +import org.flexlb.consistency.LBStatusConsistencyService; +import org.flexlb.engine.grpc.EngineRpcService; +import org.flexlb.engine.grpc.FlexlbServiceGrpc; +import org.flexlb.config.ConfigService; +import org.flexlb.service.monitor.EngineHealthReporter; +import org.flexlb.util.Logger; +import org.springframework.beans.factory.annotation.Qualifier; +import org.springframework.stereotype.Component; + +import javax.annotation.PreDestroy; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.TimeUnit; + +@Component +public class FlexlbGrpcForwarder { + + private final LBStatusConsistencyService lbStatusConsistencyService; + private final ConfigService configService; + private final EngineHealthReporter engineHealthReporter; + private final EventLoopGroup eventLoopGroup; + private final ConcurrentHashMap channels = new ConcurrentHashMap<>(); + + public FlexlbGrpcForwarder(LBStatusConsistencyService lbStatusConsistencyService, + ConfigService configService, + EngineHealthReporter engineHealthReporter, + @Qualifier("managedChannelEventLoopGroup") EventLoopGroup eventLoopGroup) { + this.lbStatusConsistencyService = lbStatusConsistencyService; + this.configService = configService; + this.engineHealthReporter = engineHealthReporter; + this.eventLoopGroup = eventLoopGroup; + } + + public EngineRpcService.FlexlbScheduleResponsePB forwardToMaster( + EngineRpcService.FlexlbScheduleRequestPB request) { + String masterHostIpPort = lbStatusConsistencyService.getMasterHostIpPort(); + if (masterHostIpPort == null) { + Logger.error("Master unreachable for gRPC forward, routing locally"); + engineHealthReporter.reportForwardToMasterResult("LOCAL", "MASTER_NULL"); + return null; + } + + int grpcPort = resolveGrpcPort(masterHostIpPort); + String ip = masterHostIpPort.split(":")[0]; + String channelKey = ip + ":" + grpcPort; + + try { + ManagedChannel channel = channels.computeIfAbsent(channelKey, k -> createChannel(ip, grpcPort)); + FlexlbServiceGrpc.FlexlbServiceBlockingStub stub = FlexlbServiceGrpc.newBlockingStub(channel) + .withDeadlineAfter(configService.loadBalanceConfig().getPrefillLbTimeoutMs(), TimeUnit.MILLISECONDS); + EngineRpcService.FlexlbScheduleResponsePB response = stub.schedule(request); + engineHealthReporter.reportForwardToMasterResult(ip, String.valueOf(response.getCode())); + return response; + } catch (StatusRuntimeException e) { + Logger.error("[Fallback] gRPC forward to master failed: {}, routing locally", e.getMessage()); + engineHealthReporter.reportForwardToMasterResult("LOCAL", "GRPC_FAILED"); + channels.remove(channelKey); + return null; + } catch (Exception e) { + Logger.error("[Fallback] gRPC forward to master error, routing locally", e); + engineHealthReporter.reportForwardToMasterResult("LOCAL", "CONNECT_FAILED"); + channels.remove(channelKey); + return null; + } + } + + private int resolveGrpcPort(String masterHostIpPort) { + // Always derive gRPC port from HTTP port using the same offset as FlexlbGrpcServer. + String[] parts = masterHostIpPort.split(":"); + if (parts.length >= 2) { + return Integer.parseInt(parts[1]) + FlexlbGrpcServer.FLEXLB_GRPC_PORT_OFFSET; + } + return 7001 + FlexlbGrpcServer.FLEXLB_GRPC_PORT_OFFSET; + } + + private ManagedChannel createChannel(String ip, int port) { + return NettyChannelBuilder.forAddress(ip, port) + .channelType(NioSocketChannel.class) + .eventLoopGroup(eventLoopGroup) + .usePlaintext() + .keepAliveTime(30, TimeUnit.SECONDS) + .keepAliveTimeout(10, TimeUnit.SECONDS) + .maxInboundMessageSize(16 * 1024 * 1024) + .build(); + } + + @PreDestroy + public void shutdown() { + for (ManagedChannel channel : channels.values()) { + channel.shutdownNow(); + } + channels.clear(); + } +} diff --git a/rtp_llm/flexlb/flexlb-api/src/main/java/org/flexlb/httpserver/FlexlbGrpcServer.java b/rtp_llm/flexlb/flexlb-api/src/main/java/org/flexlb/httpserver/FlexlbGrpcServer.java new file mode 100644 index 0000000000..05a6d774fd --- /dev/null +++ b/rtp_llm/flexlb/flexlb-api/src/main/java/org/flexlb/httpserver/FlexlbGrpcServer.java @@ -0,0 +1,88 @@ +package org.flexlb.httpserver; + +import io.grpc.Server; +import io.grpc.netty.NettyServerBuilder; +import io.netty.channel.EventLoopGroup; +import io.netty.channel.nio.NioEventLoopGroup; +import io.netty.channel.socket.nio.NioServerSocketChannel; +import org.flexlb.config.ConfigService; +import org.flexlb.util.Logger; +import org.springframework.beans.factory.annotation.Qualifier; +import org.springframework.core.env.Environment; +import org.springframework.stereotype.Component; + +import javax.annotation.PostConstruct; +import javax.annotation.PreDestroy; +import java.io.IOException; +import java.util.concurrent.TimeUnit; + +@Component +public class FlexlbGrpcServer { + + /** + * Offset from HTTP port to gRPC port for FlexLB's own servers. + * This is separate from CommonConstants.GRPC_PORT_OFFSET which applies + * to backend inference engine ports (HTTP+1→gRPC). + */ + static final int FLEXLB_GRPC_PORT_OFFSET = 2; + private static final int DEFAULT_HTTP_PORT = 7001; + + private final FlexlbServiceImpl flexlbServiceImpl; + private final ConfigService configService; + private final EventLoopGroup workerGroup; + private final Environment environment; + private Server server; + private NioEventLoopGroup bossGroup; + + public FlexlbGrpcServer(FlexlbServiceImpl flexlbServiceImpl, + ConfigService configService, + @Qualifier("managedChannelEventLoopGroup") EventLoopGroup workerGroup, + Environment environment) { + this.flexlbServiceImpl = flexlbServiceImpl; + this.configService = configService; + this.workerGroup = workerGroup; + this.environment = environment; + } + + @PostConstruct + public void start() throws IOException { + // Always derive gRPC port from HTTP port. + // server.port may come from --server.port CLI arg (Spring Environment only) + // or from -Dserver.port JVM property; check both. + String portStr = environment.getProperty("server.port"); + if (portStr == null) { + portStr = System.getProperty("server.port", String.valueOf(DEFAULT_HTTP_PORT)); + } + int httpPort = Integer.parseInt(portStr); + int port = httpPort + FLEXLB_GRPC_PORT_OFFSET; + + this.bossGroup = new NioEventLoopGroup(1); + + server = NettyServerBuilder.forPort(port) + .channelType(NioServerSocketChannel.class) + .bossEventLoopGroup(bossGroup) + .workerEventLoopGroup(workerGroup) + .addService(flexlbServiceImpl) + .maxInboundMessageSize(16 * 1024 * 1024) + .build() + .start(); + + Logger.info("FlexLB gRPC server started on port {}", port); + } + + @PreDestroy + public void shutdown() { + if (server != null) { + server.shutdown(); + try { + server.awaitTermination(5, TimeUnit.SECONDS); + } catch (InterruptedException e) { + server.shutdownNow(); + Thread.currentThread().interrupt(); + } + } + if (bossGroup != null) { + bossGroup.shutdownGracefully(); + } + } +} diff --git a/rtp_llm/flexlb/flexlb-api/src/main/java/org/flexlb/httpserver/FlexlbServiceImpl.java b/rtp_llm/flexlb/flexlb-api/src/main/java/org/flexlb/httpserver/FlexlbServiceImpl.java new file mode 100644 index 0000000000..e2e80662a2 --- /dev/null +++ b/rtp_llm/flexlb/flexlb-api/src/main/java/org/flexlb/httpserver/FlexlbServiceImpl.java @@ -0,0 +1,165 @@ +package org.flexlb.httpserver; + +import io.grpc.stub.StreamObserver; +import org.flexlb.consistency.LBStatusConsistencyService; +import org.flexlb.dao.BalanceContext; +import org.flexlb.dao.loadbalance.Request; +import org.flexlb.dao.loadbalance.Response; +import org.flexlb.dao.loadbalance.ServerStatus; +import org.flexlb.engine.grpc.EngineRpcService; +import org.flexlb.engine.grpc.FlexlbServiceGrpc; +import org.flexlb.engine.grpc.RoleTypeProtoConverter; +import org.flexlb.enums.ScheduleModeEnum; +import org.flexlb.service.RouteService; +import org.flexlb.service.grace.ActiveRequestCounter; +import org.flexlb.service.monitor.EngineHealthReporter; +import org.flexlb.config.ConfigService; +import org.flexlb.config.FlexlbConfig; +import org.flexlb.util.Logger; +import org.springframework.stereotype.Component; + +@Component +public class FlexlbServiceImpl extends FlexlbServiceGrpc.FlexlbServiceImplBase { + + private final RouteService routeService; + private final LBStatusConsistencyService lbStatusConsistencyService; + private final EngineHealthReporter engineHealthReporter; + private final ActiveRequestCounter activeRequestCounter; + private final FlexlbGrpcForwarder grpcForwarder; + private final ConfigService configService; + + public FlexlbServiceImpl(RouteService routeService, + LBStatusConsistencyService lbStatusConsistencyService, + EngineHealthReporter engineHealthReporter, + ActiveRequestCounter activeRequestCounter, + FlexlbGrpcForwarder grpcForwarder, + ConfigService configService) { + this.routeService = routeService; + this.lbStatusConsistencyService = lbStatusConsistencyService; + this.engineHealthReporter = engineHealthReporter; + this.activeRequestCounter = activeRequestCounter; + this.grpcForwarder = grpcForwarder; + this.configService = configService; + } + + @Override + public void schedule(EngineRpcService.FlexlbScheduleRequestPB request, + StreamObserver responseObserver) { + ActiveRequestCounter.RequestToken token = activeRequestCounter.acquire(); + try { + BalanceContext ctx = buildContext(request); + engineHealthReporter.reportArriveDelayTime(ctx); + + EngineRpcService.FlexlbScheduleResponsePB response; + if (lbStatusConsistencyService.isNeedConsistency() && !lbStatusConsistencyService.isMaster()) { + response = grpcForwarder.forwardToMaster(request); + if (response == null) { + response = routeLocally(ctx); + } + } else { + response = routeLocally(ctx); + } + + responseObserver.onNext(response); + responseObserver.onCompleted(); + ctx.setSuccess(response.getSuccess()); + if (!response.getSuccess()) { + ctx.setErrorMessage(response.getErrorMessage()); + } + engineHealthReporter.reportBalancingService(ctx); + } catch (Exception e) { + Logger.error("FlexlbService.schedule error, request_id={}", request.getRequestId(), e); + EngineRpcService.FlexlbScheduleResponsePB errorResp = EngineRpcService.FlexlbScheduleResponsePB.newBuilder() + .setSuccess(false) + .setCode(500) + .setErrorMessage(e.getMessage() != null ? e.getMessage() : "internal error") + .build(); + responseObserver.onNext(errorResp); + responseObserver.onCompleted(); + } finally { + token.close(); + } + } + + @Override + public void cancel(EngineRpcService.CancelRequestPB request, + StreamObserver responseObserver) { + try { + routeService.cancelByRequestId(request.getRequestId()); + responseObserver.onNext(EngineRpcService.EmptyPB.getDefaultInstance()); + responseObserver.onCompleted(); + } catch (Exception e) { + Logger.error("FlexlbService.cancel error, request_id={}", request.getRequestId(), e); + responseObserver.onError(io.grpc.Status.INTERNAL + .withDescription(e.getMessage()) + .asRuntimeException()); + } + } + + private EngineRpcService.FlexlbScheduleResponsePB routeLocally(BalanceContext ctx) { + Response response = routeService.route(ctx).block(); + return toProtoResponse(response); + } + + private BalanceContext buildContext(EngineRpcService.FlexlbScheduleRequestPB pb) { + BalanceContext ctx = new BalanceContext(); + + Request request = new Request(); + request.setRequestId(pb.getRequestId()); + request.setBlockCacheKeys(pb.getBlockCacheKeysList()); + request.setSeqLen(pb.getSeqLen()); + request.setGenerateTimeout(pb.getGenerateTimeout()); + request.setRequestTimeMs(pb.getRequestTimeMs()); + request.setMaxNewTokens(pb.getMaxNewTokens()); + request.setNumBeams(pb.getNumBeams()); + request.setForceDisableSpRun(pb.getForceDisableSpRun()); + request.setModel(pb.getModel()); + request.setApiKey(pb.getApiKey()); + ctx.setRequest(request); + + if (pb.hasGenerateInput()) { + ctx.setGenerateInputPbBytes(pb.getGenerateInput().toByteArray()); + } + + ctx.setScheduleMode(resolveScheduleMode(pb.getScheduleMode(), configService.loadBalanceConfig())); + return ctx; + } + + private static ScheduleModeEnum resolveScheduleMode(EngineRpcService.FlexlbScheduleModePB mode, + FlexlbConfig config) { + return switch (mode) { + case FLEXLB_SCHEDULE_BATCH -> ScheduleModeEnum.BATCH; + case FLEXLB_SCHEDULE_DIRECT -> ScheduleModeEnum.DIRECT; + default -> config.getDefaultScheduleModeEnum(); + }; + } + + private EngineRpcService.FlexlbScheduleResponsePB toProtoResponse(Response response) { + EngineRpcService.FlexlbScheduleResponsePB.Builder builder = EngineRpcService.FlexlbScheduleResponsePB.newBuilder(); + if (response == null) { + return builder.setSuccess(false).setCode(500).setErrorMessage("null response").build(); + } + builder.setSuccess(response.isSuccess()); + builder.setCode(response.getCode()); + if (response.getErrorMessage() != null) { + builder.setErrorMessage(response.getErrorMessage()); + } + if (response.getRealMasterHost() != null) { + builder.setRealMasterHost(response.getRealMasterHost()); + } + builder.setQueueLength(response.getQueueLength() != null ? response.getQueueLength() : 0); + builder.setEnqueuedByMaster(response.isEnqueuedByMaster()); + + if (response.getServerStatus() != null) { + for (ServerStatus ss : response.getServerStatus()) { + builder.addServerStatus(EngineRpcService.FlexlbServerStatusPB.newBuilder() + .setRole(RoleTypeProtoConverter.toProto(ss.getRole())) + .setServerIp(ss.getServerIp() != null ? ss.getServerIp() : "") + .setHttpPort(ss.getHttpPort()) + .setGrpcPort(ss.getGrpcPort()) + .build()); + } + } + return builder.build(); + } +} diff --git a/rtp_llm/flexlb/flexlb-api/src/main/java/org/flexlb/httpserver/HttpLoadBalanceServer.java b/rtp_llm/flexlb/flexlb-api/src/main/java/org/flexlb/httpserver/HttpLoadBalanceServer.java index 59312e3247..52237da130 100644 --- a/rtp_llm/flexlb/flexlb-api/src/main/java/org/flexlb/httpserver/HttpLoadBalanceServer.java +++ b/rtp_llm/flexlb/flexlb-api/src/main/java/org/flexlb/httpserver/HttpLoadBalanceServer.java @@ -1,29 +1,22 @@ package org.flexlb.httpserver; -import lombok.extern.slf4j.Slf4j; -import org.apache.commons.lang3.StringUtils; import org.flexlb.balance.scheduler.QueueManager; import org.flexlb.config.ConfigService; import org.flexlb.config.TrafficPolicyConfig; import org.flexlb.consistency.LBStatusConsistencyService; -import org.flexlb.dao.BalanceContext; import org.flexlb.dao.loadbalance.LogLevelUpdateRequest; import org.flexlb.dao.loadbalance.QueueSnapshotResponse; import org.flexlb.dao.loadbalance.Request; import org.flexlb.dao.loadbalance.Response; -import org.flexlb.dao.loadbalance.StrategyErrorType; -import org.flexlb.dao.pv.PvLogData; +import org.flexlb.dao.master.WorkerStatus; +import org.flexlb.dao.route.RoleType; import org.flexlb.domain.consistency.MasterChangeNotifyReq; import org.flexlb.domain.consistency.MasterChangeNotifyResp; import org.flexlb.domain.consistency.SyncLBStatusReq; import org.flexlb.domain.consistency.SyncLBStatusResp; -import org.flexlb.service.RouteService; -import org.flexlb.service.grace.ActiveRequestCounter; -import org.flexlb.service.monitor.EngineHealthReporter; -import org.flexlb.transport.GeneralHttpNettyService; -import org.flexlb.util.JsonUtils; +import org.flexlb.sync.status.EngineWorkerStatus; +import org.flexlb.sync.status.ModelWorkerStatus; import org.flexlb.util.Logger; -import org.slf4j.LoggerFactory; import org.springframework.context.annotation.Bean; import org.springframework.http.MediaType; import org.springframework.stereotype.Component; @@ -32,50 +25,30 @@ import org.springframework.web.reactive.function.server.ServerResponse; import reactor.core.publisher.Mono; -import java.net.URI; -import java.util.concurrent.TimeoutException; +import java.util.LinkedHashMap; +import java.util.Map; import java.util.function.Function; import static org.springframework.web.reactive.function.server.RequestPredicates.accept; import static org.springframework.web.reactive.function.server.RouterFunctions.route; -@Slf4j @Component public class HttpLoadBalanceServer { - private static final org.slf4j.Logger pvLogger = LoggerFactory.getLogger("pvLogger"); - private static final String AUTHORIZATION_HEADER = "Authorization"; - private static final String X_API_KEY_HEADER = "X-Api-Key"; - private static final String API_KEY_HEADER = "Api-Key"; - private static final String BEARER_PREFIX = "Bearer "; - private final GeneralHttpNettyService generalHttpNettyService; - private final RouteService routeService; private final LBStatusConsistencyService lbStatusConsistencyService; - private final EngineHealthReporter engineHealthReporter; private final QueueManager queueManager; - private final ActiveRequestCounter activeRequestCounter; private final ConfigService configService; - public HttpLoadBalanceServer(GeneralHttpNettyService generalHttpNettyService, - RouteService routeService, - LBStatusConsistencyService lbStatusConsistencyService, - EngineHealthReporter engineHealthReporter, + public HttpLoadBalanceServer(LBStatusConsistencyService lbStatusConsistencyService, QueueManager queueManager, - ActiveRequestCounter activeRequestCounter, ConfigService configService) { - this.generalHttpNettyService = generalHttpNettyService; - this.routeService = routeService; this.lbStatusConsistencyService = lbStatusConsistencyService; - this.engineHealthReporter = engineHealthReporter; this.queueManager = queueManager; - this.activeRequestCounter = activeRequestCounter; this.configService = configService; } @Bean public RouterFunction loadBalancePrefill() { return route() - .POST("/rtp_llm/schedule", accept(MediaType.APPLICATION_JSON), - this::scheduleRequest) .POST("/rtp_llm/master/info", accept(MediaType.APPLICATION_JSON), this::responseMasterInfo) .POST("/rtp_llm/schedule_snapshot", accept(MediaType.APPLICATION_JSON), @@ -91,81 +64,13 @@ public RouterFunction loadBalancePrefill() { .build(); } - /** - * Handles load balancing request scheduling. - * - * @param request the HTTP request containing the model inference request - * @return a reactive response containing the load balancing result - */ - public Mono scheduleRequest(ServerRequest request) { - BalanceContext ctx = new BalanceContext(); - return request.bodyToMono(Request.class) - .flatMap(req -> { - if (req.getRequestId() == 0) { - throw new IllegalArgumentException("requestId is 0"); - } - populateApiKeyFromHeaders(req, request); - ctx.setRequest(req); - return Mono.using( - activeRequestCounter::acquire, - ignored -> processScheduledRequest(ctx, req), - ActiveRequestCounter.RequestToken::close); - }) - .onErrorResume(e -> handleRequestError(ctx, e)) - .doFinally(signal -> finalizeRequestContext(ctx)); - } - - private void populateApiKeyFromHeaders(Request req, ServerRequest serverRequest) { - if (StringUtils.isNotBlank(req.getApiKey())) { - return; - } - - String apiKey = firstNonBlank( - serverRequest.headers().firstHeader(X_API_KEY_HEADER), - serverRequest.headers().firstHeader(API_KEY_HEADER), - extractBearerToken(serverRequest.headers().firstHeader(AUTHORIZATION_HEADER))); - req.setApiKey(apiKey); - } - - private String extractBearerToken(String authorization) { - if (StringUtils.isBlank(authorization) || !authorization.startsWith(BEARER_PREFIX)) { - return null; - } - return authorization.substring(BEARER_PREFIX.length()).trim(); - } - - private String firstNonBlank(String... values) { - for (String value : values) { - if (StringUtils.isNotBlank(value)) { - return value; - } - } - return null; - } - - private Mono processScheduledRequest(BalanceContext ctx, Request req) { - engineHealthReporter.reportArriveDelayTime(ctx); - - if (lbStatusConsistencyService.isNeedConsistency() && !lbStatusConsistencyService.isMaster()) { - return forwardRequestToMaster(ctx, req); - } - - return routeService.route(ctx) - .flatMap(response -> handleRoutingResult(ctx, response)) - .doOnCancel(() -> { - ctx.setSuccess(false); - ctx.setErrorMessage("REQUEST_CANCELLED"); - routeService.cancel(ctx); - }); - } - private Mono debugMode(ServerRequest serverRequest) { return serverRequest.bodyToMono(LogLevelUpdateRequest.class) .flatMap(logLevelUpdateRequest -> { - Logger.setGlobalLogLevel(logLevelUpdateRequest.getLogLevel()); + Logger.setLevel(logLevelUpdateRequest.getLogLevel()); return ServerResponse.ok() .contentType(MediaType.APPLICATION_JSON) - .body(Mono.just("Success! logLevel=" + Logger.getGlobalLogLevel()), String.class); + .body(Mono.just("Success! logLevel=" + Logger.getLevel()), String.class); }).onErrorResume(e -> { Logger.error("update logLevel error", e); return ServerResponse.status(500) @@ -189,6 +94,26 @@ private Mono updateTrafficPolicy(ServerRequest serverRequest) { }); } + private Map buildWorkerSummary() { + ModelWorkerStatus modelStatus = EngineWorkerStatus.MODEL_ROLE_WORKER_STATUS; + Map summary = new LinkedHashMap<>(); + for (RoleType role : RoleType.values()) { + Map statusMap = modelStatus.getRoleStatusMap(role); + if (statusMap == null || statusMap.isEmpty()) { + continue; + } + Response.WorkerRoleSummary rs = new Response.WorkerRoleSummary(); + rs.setDiscovered(statusMap.size()); + for (WorkerStatus ws : statusMap.values()) { + if (ws.isAlive()) { + rs.setAlive(rs.getAlive() + 1); + } + } + summary.put(role.getCode(), rs); + } + return summary.isEmpty() ? null : summary; + } + private Mono responseMasterInfo(ServerRequest request) { return request.bodyToMono(Request.class) .flatMap((Function>) req -> { @@ -197,6 +122,7 @@ private Mono responseMasterInfo(ServerRequest request) { result.setQueueLength(queueManager.getQueue().size()); result.setCode(200); result.setSuccess(true); + result.setWorkerSummary(buildWorkerSummary()); return ServerResponse.ok() .contentType(MediaType.APPLICATION_JSON) .body(Mono.just(result), Response.class); @@ -255,142 +181,4 @@ public Mono queueSnapshot(ServerRequest request) { .body(Mono.just(e.getMessage()), String.class); } } - - /** - * Forwards the request to the active master node. - * - * @param ctx the balance context - * @param request the master request to forward - * @return response from the master node - */ - private Mono forwardRequestToMaster(BalanceContext ctx, Request request) { - String master = lbStatusConsistencyService.getMasterHostIpPort(); - if (master == null) { - Logger.error("Master unreachable, routing locally"); - engineHealthReporter.reportForwardToMasterResult("LOCAL", "MASTER_NULL"); - return fallbackToLocalRouting(ctx); - } - Logger.info("Forwarding request to master: {}, request: {}", master, request); - URI uri = URI.create("http://" + master); - return generalHttpNettyService.request(request, uri, "/rtp_llm/schedule", Response.class) - .flatMap(resp -> { - engineHealthReporter.reportForwardToMasterResult(uri.getHost(), String.valueOf(resp.getCode())); - return ServerResponse.ok() - .contentType(MediaType.APPLICATION_JSON) - .bodyValue(resp); - } - ) - .onErrorResume(e -> { - String errorCode = e instanceof TimeoutException ? "TIMEOUT" : "CONNECT_FAILED"; - Logger.error("[Fallback] Master unreachable, routing locally: {}, errorCode: {}", e.getMessage(), errorCode); - engineHealthReporter.reportForwardToMasterResult("LOCAL", errorCode); - return fallbackToLocalRouting(ctx); - }); - } - - private Mono fallbackToLocalRouting(BalanceContext ctx) { - return routeService.route(ctx) - .flatMap(response -> handleRoutingResult(ctx, response)) - .onErrorResume(e -> { - Logger.error("[Fallback] Local routing failed", e); - Response errorResponse = Response.error(StrategyErrorType.NO_AVAILABLE_WORKER); - return ServerResponse.status(500) - .contentType(MediaType.APPLICATION_JSON) - .bodyValue(errorResponse); - }); - } - - /** - * Processes the routing response and builds the appropriate HTTP response. - * - * @param ctx the balance context - * @param response the routing response - * @return HTTP response based on routing success or failure - */ - private Mono handleRoutingResult(BalanceContext ctx, Response response) { - - response.setRealMasterHost(lbStatusConsistencyService.getMasterHostIpPort()); - - if (response.isSuccess()) { - return buildSuccessResponse(response); - } else { - Logger.error("Routing failed with error code: {}", response.getErrorMessage()); - ctx.setSuccess(false); - ctx.setErrorMessage("error_code:" + response.getErrorMessage()); - return buildErrorResponse(response); - } - } - - /** - * Builds a successful HTTP response. - * - * @param result the master response containing the result - * @return successful HTTP response - */ - private Mono buildSuccessResponse(Response result) { - return ServerResponse.ok() - .contentType(MediaType.APPLICATION_JSON) - .body(Mono.just(result), Response.class); - } - - /** - * Builds an error HTTP response. - * - * @param result the master response containing the error - * @return error HTTP response - */ - private Mono buildErrorResponse(Response result) { - return ServerResponse.status(500) - .contentType(MediaType.APPLICATION_JSON) - .body(Mono.just(result), Response.class); - } - - /** - * Handles global request errors. - * - * @param ctx the balance context - * @param throwable the error that occurred - * @return error response - */ - private Mono handleRequestError(BalanceContext ctx, Throwable throwable) { - Logger.error("Request processing error", throwable); - ctx.setSuccess(false); - ctx.setErrorMessage(throwable.getMessage()); - - return ServerResponse.status(500) - .contentType(MediaType.APPLICATION_JSON) - .body(Mono.just(throwable.getMessage()), String.class); - } - - /** - * Finalizes the request context by reporting metrics. - * - * @param ctx the balance context to finalize - */ - private void finalizeRequestContext(BalanceContext ctx) { - engineHealthReporter.reportBalancingService(ctx); - logPvRecord(ctx); - } - - /** - * Logs the PV record with appropriate log level based on success status. - * - * @param ctx the balance context containing PV log data - */ - private void logPvRecord(BalanceContext ctx) { - - PvLogData pvLogData = new PvLogData(ctx); - - try { - String jsonLog = JsonUtils.toStringOrEmpty(pvLogData); - if (pvLogData.isSuccess()) { - pvLogger.info(jsonLog); - } else { - pvLogger.error(jsonLog); - } - } catch (Exception ex) { - Logger.error("Failed to serialize PV log data", ex); - } - } - } diff --git a/rtp_llm/flexlb/flexlb-api/src/main/resources/logback-spring.xml b/rtp_llm/flexlb/flexlb-api/src/main/resources/logback-spring.xml index 294c40eaee..7e9483a874 100644 --- a/rtp_llm/flexlb/flexlb-api/src/main/resources/logback-spring.xml +++ b/rtp_llm/flexlb/flexlb-api/src/main/resources/logback-spring.xml @@ -4,7 +4,7 @@ - + @@ -141,6 +141,7 @@ + @@ -165,4 +166,4 @@ - \ No newline at end of file + diff --git a/rtp_llm/flexlb/flexlb-api/src/test/java/org/flexlb/ReuseSpringContextIntegrationTest.java b/rtp_llm/flexlb/flexlb-api/src/test/java/org/flexlb/ReuseSpringContextIntegrationTest.java index 2c52d345b8..fb9590a375 100644 --- a/rtp_llm/flexlb/flexlb-api/src/test/java/org/flexlb/ReuseSpringContextIntegrationTest.java +++ b/rtp_llm/flexlb/flexlb-api/src/test/java/org/flexlb/ReuseSpringContextIntegrationTest.java @@ -1,28 +1,21 @@ package org.flexlb; import lombok.extern.slf4j.Slf4j; -import org.flexlb.balance.scheduler.QueueManager; -import org.flexlb.cases.QueueStressTest; -import org.flexlb.cases.RequestCancelTest; import org.flexlb.config.ConfigService; -import org.flexlb.service.RouteService; import org.junit.jupiter.api.BeforeAll; -import org.junit.jupiter.api.DisplayName; -import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; -import org.springframework.beans.factory.annotation.Autowired; import org.springframework.boot.test.autoconfigure.web.reactive.AutoConfigureWebTestClient; import org.springframework.boot.test.autoconfigure.web.servlet.AutoConfigureMockMvc; import org.springframework.boot.test.context.SpringBootTest; import org.springframework.boot.test.mock.mockito.SpyBean; import org.springframework.test.context.ActiveProfiles; -import org.springframework.test.web.reactive.server.WebTestClient; import uk.org.webcompere.systemstubs.environment.EnvironmentVariables; import uk.org.webcompere.systemstubs.jupiter.SystemStub; import uk.org.webcompere.systemstubs.jupiter.SystemStubsExtension; /** - * Integration test with reusable Spring context + * Integration test with reusable Spring context. + * Schedule/cancel tests moved to gRPC layer (FlexlbServiceImpl). */ @Slf4j @ActiveProfiles("test") @@ -38,14 +31,6 @@ public class ReuseSpringContextIntegrationTest { @SpyBean private ConfigService configService; - @SpyBean - private RouteService routeService; - @Autowired - private QueueManager queueManager; - @Autowired - private WebTestClient webTestClient; - - //========================= Integration Test Class ======================// @BeforeAll public static void setUp() { @@ -74,32 +59,4 @@ public static void setUp() { environmentVariables.set("HIPPO_ROLE", "TEST_HIPPO_ROLE"); environmentVariables.set("OTEL_EXPORTER_OTLP_ENDPOINT", "http://search-uniagent-trace-na61.vip.tbsite.net:4317"); } - - private WebTestClient createWebClient() { - return webTestClient.mutate() - .baseUrl("http://localhost:7001") - .build(); - } - - @Test - @DisplayName("Request cancellation test") - public void requestCancelTest() { - RequestCancelTest.init(environmentVariables, configService, routeService).run(); - } - - @Test - @DisplayName("Queue full rejection test") - public void queueFullRejectionTest() { - QueueStressTest.init(createWebClient(), environmentVariables, configService) - .resetQueue(queueManager, 10) - .testQueueFullRejection(); - } - - @Test - @DisplayName("Concurrent enqueue thread safety test") - public void concurrentEnqueueTest() { - QueueStressTest.init(createWebClient(), environmentVariables, configService) - .resetQueue(queueManager, 500) - .testConcurrentEnqueue(); - } } diff --git a/rtp_llm/flexlb/flexlb-api/src/test/java/org/flexlb/cases/QueueStressTest.java b/rtp_llm/flexlb/flexlb-api/src/test/java/org/flexlb/cases/QueueStressTest.java index 069b22d7c3..0d50489cdb 100644 --- a/rtp_llm/flexlb/flexlb-api/src/test/java/org/flexlb/cases/QueueStressTest.java +++ b/rtp_llm/flexlb/flexlb-api/src/test/java/org/flexlb/cases/QueueStressTest.java @@ -269,7 +269,6 @@ private void setupLimitedWorkerResources() { WorkerStatus workerStatus = new WorkerStatus(); workerStatus.setAlive(true); - workerStatus.setUsedKvCacheTokens(new AtomicLong(990L)); // High usage, simulating resource constraints workerStatus.setAvailableKvCacheTokens(new AtomicLong(10L)); // Very small resources, force queuing // Configure multiple Prefill Workers diff --git a/rtp_llm/flexlb/flexlb-api/src/test/java/org/flexlb/cases/RequestCancelTest.java b/rtp_llm/flexlb/flexlb-api/src/test/java/org/flexlb/cases/RequestCancelTest.java index e4d3a56f3e..6d367a129c 100644 --- a/rtp_llm/flexlb/flexlb-api/src/test/java/org/flexlb/cases/RequestCancelTest.java +++ b/rtp_llm/flexlb/flexlb-api/src/test/java/org/flexlb/cases/RequestCancelTest.java @@ -68,7 +68,6 @@ public void run() { WorkerStatus workerStatus = new WorkerStatus(); workerStatus.setAlive(true); - workerStatus.setUsedKvCacheTokens(new AtomicLong(990L)); // High usage, simulating resource constraints workerStatus.setAvailableKvCacheTokens(new AtomicLong(10L)); // Set very small remaining memory, simulating decode resource shortage EngineWorkerStatus.MODEL_ROLE_WORKER_STATUS.getPrefillStatusMap().put("127.0.0.100:8080", workerStatus); diff --git a/rtp_llm/flexlb/flexlb-api/src/test/java/org/flexlb/httpserver/FlexlbServiceImplTest.java b/rtp_llm/flexlb/flexlb-api/src/test/java/org/flexlb/httpserver/FlexlbServiceImplTest.java new file mode 100644 index 0000000000..bdf6ee2929 --- /dev/null +++ b/rtp_llm/flexlb/flexlb-api/src/test/java/org/flexlb/httpserver/FlexlbServiceImplTest.java @@ -0,0 +1,223 @@ +package org.flexlb.httpserver; + +import io.grpc.stub.StreamObserver; +import org.flexlb.consistency.LBStatusConsistencyService; +import org.flexlb.dao.BalanceContext; +import org.flexlb.dao.loadbalance.Response; +import org.flexlb.engine.grpc.EngineRpcService; +import org.flexlb.service.RouteService; +import org.flexlb.service.grace.ActiveRequestCounter; +import org.flexlb.service.monitor.EngineHealthReporter; +import org.flexlb.config.ConfigService; +import org.flexlb.config.FlexlbConfig; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; +import reactor.core.publisher.Mono; + +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.*; + +class FlexlbServiceImplTest { + + private RouteService routeService; + private LBStatusConsistencyService lbStatusConsistencyService; + private EngineHealthReporter engineHealthReporter; + private ActiveRequestCounter activeRequestCounter; + private FlexlbGrpcForwarder grpcForwarder; + private ConfigService configService; + private FlexlbServiceImpl service; + + @BeforeEach + void setUp() { + routeService = mock(RouteService.class); + lbStatusConsistencyService = mock(LBStatusConsistencyService.class); + engineHealthReporter = mock(EngineHealthReporter.class); + activeRequestCounter = mock(ActiveRequestCounter.class); + grpcForwarder = mock(FlexlbGrpcForwarder.class); + + configService = mock(ConfigService.class); + FlexlbConfig flexlbConfig = new FlexlbConfig(); + when(configService.loadBalanceConfig()).thenReturn(flexlbConfig); + + ActiveRequestCounter.RequestToken token = mock(ActiveRequestCounter.RequestToken.class); + when(activeRequestCounter.acquire()).thenReturn(token); + + service = new FlexlbServiceImpl( + routeService, + lbStatusConsistencyService, + engineHealthReporter, + activeRequestCounter, + grpcForwarder, + configService + ); + } + + @Test + void testSchedule_localRouting() { + // Given: not master, no consistency needed + when(lbStatusConsistencyService.isNeedConsistency()).thenReturn(false); + + Response response = new Response(); + response.setSuccess(true); + response.setCode(200); + when(routeService.route(any(BalanceContext.class))).thenReturn(Mono.just(response)); + + EngineRpcService.FlexlbScheduleRequestPB request = EngineRpcService.FlexlbScheduleRequestPB.newBuilder() + .setRequestId(12345L) + .setSeqLen(100) + .build(); + + StreamObserver observer = mock(StreamObserver.class); + + // When + service.schedule(request, observer); + + // Then + ArgumentCaptor captor = + ArgumentCaptor.forClass(EngineRpcService.FlexlbScheduleResponsePB.class); + verify(observer).onNext(captor.capture()); + verify(observer).onCompleted(); + verify(observer, never()).onError(any()); + + EngineRpcService.FlexlbScheduleResponsePB resp = captor.getValue(); + assertTrue(resp.getSuccess()); + assertEquals(200, resp.getCode()); + } + + @Test + void testSchedule_forwardToMaster_success() { + // Given: consistency needed, not master, forward succeeds + when(lbStatusConsistencyService.isNeedConsistency()).thenReturn(true); + when(lbStatusConsistencyService.isMaster()).thenReturn(false); + + EngineRpcService.FlexlbScheduleResponsePB masterResponse = EngineRpcService.FlexlbScheduleResponsePB.newBuilder() + .setSuccess(true) + .setCode(200) + .build(); + when(grpcForwarder.forwardToMaster(any())).thenReturn(masterResponse); + + EngineRpcService.FlexlbScheduleRequestPB request = EngineRpcService.FlexlbScheduleRequestPB.newBuilder() + .setRequestId(12345L) + .build(); + + StreamObserver observer = mock(StreamObserver.class); + + // When + service.schedule(request, observer); + + // Then + verify(grpcForwarder).forwardToMaster(request); + verify(routeService, never()).route(any()); + + ArgumentCaptor captor = + ArgumentCaptor.forClass(EngineRpcService.FlexlbScheduleResponsePB.class); + verify(observer).onNext(captor.capture()); + + EngineRpcService.FlexlbScheduleResponsePB resp = captor.getValue(); + assertTrue(resp.getSuccess()); + } + + @Test + void testSchedule_forwardToMaster_fallbackToLocal() { + // Given: consistency needed, not master, forward fails (returns null) + when(lbStatusConsistencyService.isNeedConsistency()).thenReturn(true); + when(lbStatusConsistencyService.isMaster()).thenReturn(false); + when(grpcForwarder.forwardToMaster(any())).thenReturn(null); + + Response localResponse = new Response(); + localResponse.setSuccess(true); + localResponse.setCode(200); + when(routeService.route(any(BalanceContext.class))).thenReturn(Mono.just(localResponse)); + + EngineRpcService.FlexlbScheduleRequestPB request = EngineRpcService.FlexlbScheduleRequestPB.newBuilder() + .setRequestId(12345L) + .build(); + + StreamObserver observer = mock(StreamObserver.class); + + // When + service.schedule(request, observer); + + // Then + verify(grpcForwarder).forwardToMaster(request); + verify(routeService).route(any(BalanceContext.class)); + + ArgumentCaptor captor = + ArgumentCaptor.forClass(EngineRpcService.FlexlbScheduleResponsePB.class); + verify(observer).onNext(captor.capture()); + + EngineRpcService.FlexlbScheduleResponsePB resp = captor.getValue(); + assertTrue(resp.getSuccess()); + } + + @Test + void testSchedule_exceptionHandling() { + // Given: route throws exception + when(lbStatusConsistencyService.isNeedConsistency()).thenReturn(false); + when(routeService.route(any(BalanceContext.class))).thenThrow(new RuntimeException("test error")); + + EngineRpcService.FlexlbScheduleRequestPB request = EngineRpcService.FlexlbScheduleRequestPB.newBuilder() + .setRequestId(12345L) + .build(); + + StreamObserver observer = mock(StreamObserver.class); + + // When + service.schedule(request, observer); + + // Then + ArgumentCaptor captor = + ArgumentCaptor.forClass(EngineRpcService.FlexlbScheduleResponsePB.class); + verify(observer).onNext(captor.capture()); + verify(observer).onCompleted(); + + EngineRpcService.FlexlbScheduleResponsePB resp = captor.getValue(); + assertFalse(resp.getSuccess()); + assertEquals(500, resp.getCode()); + assertTrue(resp.getErrorMessage().contains("test error")); + } + + @Test + void testCancel_success() { + // Given + doNothing().when(routeService).cancelByRequestId(12345L); + + EngineRpcService.CancelRequestPB request = EngineRpcService.CancelRequestPB.newBuilder() + .setRequestId(12345L) + .build(); + + StreamObserver observer = mock(StreamObserver.class); + + // When + service.cancel(request, observer); + + // Then + verify(routeService).cancelByRequestId(12345L); + verify(observer).onNext(any(EngineRpcService.EmptyPB.class)); + verify(observer).onCompleted(); + verify(observer, never()).onError(any()); + } + + @Test + void testCancel_exceptionHandling() { + // Given + doThrow(new RuntimeException("cancel error")).when(routeService).cancelByRequestId(12345L); + + EngineRpcService.CancelRequestPB request = EngineRpcService.CancelRequestPB.newBuilder() + .setRequestId(12345L) + .build(); + + StreamObserver observer = mock(StreamObserver.class); + + // When + service.cancel(request, observer); + + // Then + verify(routeService).cancelByRequestId(12345L); + verify(observer).onError(any()); + verify(observer, never()).onNext(any()); + verify(observer, never()).onCompleted(); + } +} diff --git a/rtp_llm/flexlb/flexlb-cache/src/main/java/org/flexlb/cache/service/impl/DefaultCacheAwareService.java b/rtp_llm/flexlb/flexlb-cache/src/main/java/org/flexlb/cache/service/impl/DefaultCacheAwareService.java index f1a9ecfa68..422121eb5d 100644 --- a/rtp_llm/flexlb/flexlb-cache/src/main/java/org/flexlb/cache/service/impl/DefaultCacheAwareService.java +++ b/rtp_llm/flexlb/flexlb-cache/src/main/java/org/flexlb/cache/service/impl/DefaultCacheAwareService.java @@ -25,13 +25,13 @@ @Slf4j @Service public class DefaultCacheAwareService implements CacheAwareService { - + @Autowired private KvCacheManager kvCacheManager; - + @Autowired private CacheMetricsReporter cacheMetricsReporter; - + @Override public Map findMatchingEngines(List blockCacheKeys, RoleType roleType, String group) { @@ -55,12 +55,12 @@ public Map findMatchingEngines(List blockCacheKeys, return Collections.emptyMap(); } } - + @Override public WorkerCacheUpdateResult updateEngineBlockCache(WorkerStatus workerStatus) { long startTime = System.nanoTime() / 1000; String engineIpPort = workerStatus.getIpPort(); - String role = workerStatus.getRole(); + String role = workerStatus.getRole().getCode(); try { if (workerStatus.getCacheStatus() == null) { @@ -78,23 +78,23 @@ public WorkerCacheUpdateResult updateEngineBlockCache(WorkerStatus workerStatus) } Set cachedKeys = cacheStatus.getCachedKeys(); - + // Update cache kvCacheManager.updateEngineCache(ipPort, role, cachedKeys); - + WorkerCacheUpdateResult result = buildSuccessResult(workerStatus, cacheStatus); cacheMetricsReporter.reportUpdateEngineBlockCacheRT(ipPort, role, startTime, "1"); - + return result; - + } catch (Throwable e) { log.error("Error updating worker cache for: {}", engineIpPort, e); - + WorkerCacheUpdateResult result = buildFailureResult(engineIpPort, e.getMessage()); cacheMetricsReporter.reportUpdateEngineBlockCacheRT(engineIpPort, role, startTime, "0"); - + return result; } } @@ -112,7 +112,7 @@ private WorkerCacheUpdateResult buildSuccessResult(WorkerStatus workerStatus, Ca .cacheVersion(cacheStatus.getVersion()) .build(); } - + /** * Build failure result */ diff --git a/rtp_llm/flexlb/flexlb-common/src/main/java/org/flexlb/config/ConfigService.java b/rtp_llm/flexlb/flexlb-common/src/main/java/org/flexlb/config/ConfigService.java index 462b335768..34a6694e17 100644 --- a/rtp_llm/flexlb/flexlb-common/src/main/java/org/flexlb/config/ConfigService.java +++ b/rtp_llm/flexlb/flexlb-common/src/main/java/org/flexlb/config/ConfigService.java @@ -19,12 +19,11 @@ public class ConfigService { private static final String FLEXLB_CONFIG_ENV = "FLEXLB_CONFIG"; - private static final String STRATEGY_CONFIGS_ENV = "STRATEGY_CONFIGS"; + private static final String PREFILL_COEFFICIENTS_ENV = "PREFILL_COEFFICIENTS"; private static final String TRAFFIC_POLICY_CONFIG_ENV = "TRAFFIC_POLICY_CONFIG"; private static final String TRAFFIC_POLICY_CONFIG_FILE_ENV = "TRAFFIC_POLICY_CONFIG_FILE"; private final FlexlbConfig flexlbConfig; - private final StrategyConfigs strategyConfigs; public ConfigService() { this(System.getenv()); @@ -35,7 +34,12 @@ public ConfigService() { log.warn("FLEXLB_CONFIG = {}", lbConfigStr); FlexlbConfig config; if (lbConfigStr != null) { - config = JsonUtils.toObject(lbConfigStr, FlexlbConfig.class); + try { + config = JsonUtils.toObject(lbConfigStr, FlexlbConfig.class); + } catch (Exception e) { + log.error("Failed to parse FLEXLB_CONFIG, use default config", e); + config = new FlexlbConfig(); + } } else { config = new FlexlbConfig(); } @@ -43,35 +47,15 @@ public ConfigService() { // If corresponding advanced environment variables exist, override and update applyEnvironmentOverrides(config, environment); applyTrafficPolicyOverride(config, environment); + applyPrefillCoefficientsOverride(config, environment); this.flexlbConfig = config; - this.strategyConfigs = loadStrategyConfigs(environment); } public FlexlbConfig loadBalanceConfig() { return flexlbConfig; } - private StrategyConfigs loadStrategyConfigs(Map environment) { - String strategyConfigsStr = environment.get(STRATEGY_CONFIGS_ENV); - - StrategyConfigs configs; - if (StringUtils.isNotBlank(strategyConfigsStr)) { - log.warn("STRATEGY_CONFIGS = {}", strategyConfigsStr); - configs = JsonUtils.toObjectOrNull(strategyConfigsStr, StrategyConfigs.class); - if (configs == null) { - log.warn("Failed to parse STRATEGY_CONFIGS, use default strategy configs"); - configs = new StrategyConfigs(); - } - } else { - log.debug("STRATEGY_CONFIGS is not set, use default strategy configs"); - configs = new StrategyConfigs(); - } - - configs.normalize(); - return configs; - } - public synchronized void updateTrafficPolicy(TrafficPolicyConfig trafficPolicy) { if (trafficPolicy == null) { throw new IllegalArgumentException("trafficPolicy cannot be null"); @@ -136,9 +120,22 @@ private void applyTrafficPolicyOverride(FlexlbConfig config, Map return; } - TrafficPolicyConfig trafficPolicy = JsonUtils.toObject(trafficPolicyConfig, TrafficPolicyConfig.class); - config.setTrafficPolicy(trafficPolicy); - log.warn("Traffic policy loaded from standalone config: {}", JsonUtils.toStringOrEmpty(trafficPolicy)); + try { + TrafficPolicyConfig trafficPolicy = JsonUtils.toObject(trafficPolicyConfig, TrafficPolicyConfig.class); + config.setTrafficPolicy(trafficPolicy); + log.warn("Traffic policy loaded from standalone config: {}", JsonUtils.toStringOrEmpty(trafficPolicy)); + } catch (Exception e) { + log.error("Failed to parse traffic policy config, skipping.", e); + } + } + + private void applyPrefillCoefficientsOverride(FlexlbConfig config, Map environment) { + String csv = environment.get(PREFILL_COEFFICIENTS_ENV); + if (StringUtils.isBlank(csv)) { + return; + } + config.setPrefillCoefficients(csv); + log.warn("Prefill coefficients loaded from {}: {}", PREFILL_COEFFICIENTS_ENV, csv); } private String readConfigFile(String filePath) { @@ -161,6 +158,7 @@ private boolean isSupportedType(Class type) { || type == Double.class || type == boolean.class || type == Boolean.class + || type == String.class || type.isEnum(); } @@ -185,7 +183,9 @@ private String camelToUpperSnakeCase(String camelCase) { */ @SuppressWarnings({"unchecked", "rawtypes"}) private Object parseValue(String value, Class targetType) { - if (targetType == int.class || targetType == Integer.class) { + if (targetType == String.class) { + return value; + } else if (targetType == int.class || targetType == Integer.class) { return Integer.parseInt(value); } else if (targetType == long.class || targetType == Long.class) { return Long.parseLong(value); @@ -194,7 +194,7 @@ private Object parseValue(String value, Class targetType) { } else if (targetType == boolean.class || targetType == Boolean.class) { return Boolean.parseBoolean(value); } else if (targetType.isEnum()) { - return Enum.valueOf((Class) targetType, value); + return JsonUtils.toObject("\"" + value + "\"", targetType); } throw new IllegalArgumentException("Unsupported type: " + targetType); } diff --git a/rtp_llm/flexlb/flexlb-common/src/main/java/org/flexlb/config/FlexlbConfig.java b/rtp_llm/flexlb/flexlb-common/src/main/java/org/flexlb/config/FlexlbConfig.java index 6ecb09920e..9e03381677 100644 --- a/rtp_llm/flexlb/flexlb-common/src/main/java/org/flexlb/config/FlexlbConfig.java +++ b/rtp_llm/flexlb/flexlb-common/src/main/java/org/flexlb/config/FlexlbConfig.java @@ -5,10 +5,15 @@ import org.flexlb.dao.route.RoleType; import org.flexlb.enums.LoadBalanceStrategyEnum; import org.flexlb.enums.ResourceMeasureIndicatorEnum; +import org.flexlb.enums.ScheduleModeEnum; +import java.util.ArrayList; +import java.util.Comparator; +import java.util.List; + +import static org.flexlb.enums.LoadBalanceStrategyEnum.COST_BASED_DECODE; +import static org.flexlb.enums.LoadBalanceStrategyEnum.COST_BASED_PREFILL; import static org.flexlb.enums.LoadBalanceStrategyEnum.RANDOM; -import static org.flexlb.enums.LoadBalanceStrategyEnum.SHORTEST_TTFT; -import static org.flexlb.enums.LoadBalanceStrategyEnum.WEIGHTED_CACHE; import static org.flexlb.enums.ResourceMeasureIndicatorEnum.REMAINING_KV_CACHE; import static org.flexlb.enums.ResourceMeasureIndicatorEnum.WAIT_TIME; @@ -24,12 +29,12 @@ public class FlexlbConfig { /** * Load balancing strategy */ - private LoadBalanceStrategyEnum loadBalanceStrategy = LoadBalanceStrategyEnum.SHORTEST_TTFT; + private LoadBalanceStrategyEnum loadBalanceStrategy = LoadBalanceStrategyEnum.COST_BASED_PREFILL; /** * Load balancing strategy for DECODE role */ - private LoadBalanceStrategyEnum decodeLoadBalanceStrategy = LoadBalanceStrategyEnum.WEIGHTED_CACHE; + private LoadBalanceStrategyEnum decodeLoadBalanceStrategy = LoadBalanceStrategyEnum.COST_BASED_DECODE; /** * Load balancing strategy for VIT role @@ -101,7 +106,7 @@ public class FlexlbConfig { * Prefill role queuing threshold * When below this threshold, the Worker is considered available */ - private long prefillQueueSizeThreshold = 3; + private long prefillQueueSizeThreshold = 64; /** * KV cache available threshold for DECODE role (percentage) @@ -176,6 +181,240 @@ public class FlexlbConfig { */ private volatile TrafficPolicyConfig trafficPolicy = new TrafficPolicyConfig(); + // ========== FlexLB Batch Configuration ========== + + /** + * Enables master-side request coalescing. Requests carrying a full + * GenerateInputPB are routed once, grouped by Prefill worker, + * and submitted through EnqueueBatch. + */ + private boolean flexlbBatchEnabled = true; + + /** + * Default schedule mode when the frontend doesn't specify one in the gRPC request. + * Environment variable: DEFAULT_SCHEDULE_MODE (values: AUTO, BATCH, DIRECT). + */ + private String defaultScheduleMode = "AUTO"; + + /** + * Maximum real requests in one EnqueueBatch request. + */ + private int flexlbBatchSizeMax = 8; + + /** + * Remaining-budget window in milliseconds. Outside this window the batcher + * keeps collecting unless the batch reaches flexlbBatchSizeMax. Inside this + * window it can dispatch once the batch has enough requests and another + * arrival is unlikely before the latest safe dispatch point. + */ + private long flexlbBatchWindowMs = 300; + + /** + * Minimum useful batch size. This is not a hard immediate-dispatch trigger: + * the batcher may keep waiting if the remaining SLO slack can likely buy + * one more request. + */ + private int flexlbBatchMinSize = 3; + + /** + * Upper bound for deadline-protection dispatch. The effective guard is + * min(flexlbBatchEmergencyBudgetMs, incrementalBatchCost + flexlbBatchDispatchGuardMs). + */ + private long flexlbBatchEmergencyBudgetMs = 150; + + /** + * Safety guard left before the computed SLO deadline when dispatching a batch. + * Covers master loop jitter, gRPC enqueue overhead, and predictor error. + */ + private long flexlbBatchDispatchGuardMs = 40; + + /** + * EMA alpha used to estimate per-worker request inter-arrival time for batching. + */ + private double flexlbBatchArrivalEmaAlpha = 0.2; + + /** + * Extra slack that must remain after the next expected request arrival before + * the latest safe dispatch point. Larger values dispatch earlier and reduce + * deadline pressure; smaller values favor bigger batches. + */ + private long flexlbBatchArrivalWaitGuardMs = 20; + + /** + * Maximum in-flight prefill batches allowed per worker before the batcher + * stops dispatching new batches and keeps requests in the master-side queue. + * Values <= 0 disable this backpressure gate. + */ + private int flexlbBatchSloMaxInflightBatches = 2; + + /** + * Maximum in-flight prefill batches per worker for the fixed_window batcher. + * When the engine already has this many batches inflight, the batcher parks + * instead of dispatching new batches. Default 0 disables backpressure — + * the fixed_window batcher dispatches regardless of engine load. + * + *

Set to a small value (e.g. 2–3) to prevent engine overload when + * using fixed_window; set to 0 to keep the original always-dispatch behavior. + */ + private int flexlbBatchFixedMaxInflightBatches = 0; + + /** + * Deadline in milliseconds for EnqueueBatch. + */ + private long flexlbBatchEnqueueDeadlineMs = 5000; + + /** + * TTL for inflight entries before eviction (used by all routing paths). + * Only a safety net — calibrate() cleans up normally. This catches stale + * entries left by engine crashes, lost status reports, or bugs. + * 5 min is generous for network/engine-report jitter but short enough + * that stale inflight won't distort realWaitTimeMs for long. + */ + private long flexlbInflightTtlMs = 300_000L; + + /** + * Maximum threads in the batch dispatch executor pool. + */ + private int flexlbBatchDispatchPoolSize = 64; + + /** + * Maximum pending tasks in the batch dispatch executor queue. + * Tasks submitted when both the pool and queue are full are rejected + * and fail immediately with QUEUE_FULL. + */ + private int flexlbBatchDispatchQueueSize = 256; + + // ========== CostBasedPrefill Strategy Configuration ========== + + /** + * Whether to enable SLO time-budget hard filter during prefill worker selection. + * When enabled, workers whose (waitMs + predictedPrefillMs) exceeds + * (SLO - riskMargin) are excluded. Default false because the filter is + * too aggressive in practice. + */ + private boolean costSloFilterEnabled = false; + + private long costSloMs = 500; + + private long costSloRiskMarginMs = 100; + + private String costSloBuckets = ""; + + private transient volatile List parsedSloBuckets; + + public void setCostSloBuckets(String costSloBuckets) { + this.costSloBuckets = costSloBuckets; + this.parsedSloBuckets = null; + } + + private double costHotspotMultiplier = 3.0; + + private double costImbalanceMultiplier = 3.0; + + private double costAlpha0 = 0; + private double costAlpha1 = 1.0; + private double costAlpha2 = 0; + private double costAlpha3 = 0; + private double costAlpha4 = 0.3; + private double costAlpha5 = 0; + + /** + * Comma-separated shorthand for the 6 predictor coefficients. + * Accepts 3 values (α₀,α₁,α₂) or 6 values (α₀–α₅). + * Overrides the individual costAlpha* fields when set. + * Example: "290,0.0116,1.21e-8" or "290,0.0116,1.21e-8,1.21e-8,0,0" + */ + public void setPrefillCoefficients(String csv) { + if (csv == null || csv.isBlank()) { + return; + } + String[] parts = csv.split(","); + try { + if (parts.length >= 3) { + costAlpha0 = Double.parseDouble(parts[0].trim()); + costAlpha1 = Double.parseDouble(parts[1].trim()); + costAlpha2 = Double.parseDouble(parts[2].trim()); + } + if (parts.length >= 6) { + costAlpha3 = Double.parseDouble(parts[3].trim()); + costAlpha4 = Double.parseDouble(parts[4].trim()); + costAlpha5 = Double.parseDouble(parts[5].trim()); + } else if (parts.length >= 3) { + costAlpha3 = 0; + costAlpha4 = 0; + costAlpha5 = 0; + } + } catch (NumberFormatException e) { + // Keep existing default values on parse failure. + } + } + + // ========== SLO-Budget Batcher Configuration ========== + + private double flexlbBatchFillThreshold = 0.5; + + private int flexlbBatchMaxCapacity = 1048576; + + private int flexlbBatchSearchIter = 10; + + private int flexlbBatchScanAhead = 64; + + /** + * Maximum queue depth per WorkerBatcher. Requests beyond this limit are + * rejected with QUEUE_FULL. + */ + private int flexlbBatchQueueMaxSize = 1024; + + /** + * Maximum total in-flight requests across all batchers. Acts as a global + * admission control gate at the FlexlbBatchScheduler entry. + */ + private int flexlbBatchMaxInflight = 100000; + + // ========== Batcher Algorithm Selection ========== + + /** + * Batcher algorithm name. Supported values: + *

    + *
  • {@code fixed_window} — Fixed time window batching with optional + * predictor-based early dispatch. No SLO deadline tracking, no EMA, + * no request dropping (default).
  • + *
  • {@code slo_budget} — SLO-deadline-aware batching with EMA arrival + * rate estimation, budget-based greedy fill, and deadline-gated dispatch.
  • + *
+ */ + private String flexlbBatchAlgorithm = "fixed_window"; + + /** + * Fixed wait time in milliseconds for the {@code fixed_window} batcher + * algorithm. After a request has waited this long, the batcher dispatches + * whatever has accumulated regardless of batch size. + * + *

Only used when {@link #flexlbBatchAlgorithm} is {@code fixed_window}. + */ + private long flexlbBatchFixedWaitMs = 300; + + /** + * Predicted batch execution time threshold in milliseconds for the + * {@code fixed_window} batcher algorithm. If the predictor estimates + * the accumulated batch will take at least this long, the batcher + * dispatches immediately rather than waiting for {@link #flexlbBatchFixedWaitMs}. + * + *

Set to 0 to disable predictor-based early dispatch (default). + * Only used when {@link #flexlbBatchAlgorithm} is {@code fixed_window}. + */ + private long flexlbBatchPredictThresholdMs = 0; + + // ========== gRPC Configuration ========== + + private long prefillLbTimeoutMs = 5000; + + // ========== Decode Load Balance Hard Filter Configuration ========== + + private double decodeHotspotMultiplier = 3.0; + + private double decodeImbalanceMultiplier = 3.0; + /** * Get load balancing strategy for a role type * This method handles the logic of selecting the appropriate strategy based on role type and configuration @@ -186,13 +425,13 @@ public class FlexlbConfig { public LoadBalanceStrategyEnum getStrategyForRoleType(RoleType roleType) { switch (roleType) { case PDFUSION -> { - return this.loadBalanceStrategy != null ? loadBalanceStrategy : SHORTEST_TTFT; + return this.loadBalanceStrategy != null ? loadBalanceStrategy : COST_BASED_PREFILL; } case PREFILL -> { - return this.loadBalanceStrategy != null ? loadBalanceStrategy : SHORTEST_TTFT; + return this.loadBalanceStrategy != null ? loadBalanceStrategy : COST_BASED_PREFILL; } case DECODE -> { - return this.decodeLoadBalanceStrategy != null ? decodeLoadBalanceStrategy : WEIGHTED_CACHE; + return this.decodeLoadBalanceStrategy != null ? decodeLoadBalanceStrategy : COST_BASED_DECODE; } case VIT -> { return this.vitLoadBalanceStrategy != null ? vitLoadBalanceStrategy : RANDOM; @@ -229,4 +468,50 @@ public ResourceMeasureIndicatorEnum getResourceMeasureIndicator(RoleType roleTyp } } } + + public long resolveSloMs(long seqLen) { + List buckets = getParsedSloBuckets(); + if (buckets == null || buckets.isEmpty()) { + return costSloMs; + } + for (long[] bucket : buckets) { + if (seqLen <= bucket[0]) { + return bucket[1]; + } + } + return buckets.get(buckets.size() - 1)[1]; + } + + private List getParsedSloBuckets() { + if (parsedSloBuckets != null) { + return parsedSloBuckets; + } + if (costSloBuckets == null || costSloBuckets.isBlank()) { + return null; + } + List result = new ArrayList<>(); + for (String entry : costSloBuckets.split(",")) { + String[] kv = entry.trim().split(":"); + if (kv.length == 2) { + try { + result.add(new long[]{Long.parseLong(kv[0].trim()), Long.parseLong(kv[1].trim())}); + } catch (NumberFormatException ignored) { + } + } + } + result.sort(Comparator.comparingLong(a -> a[0])); + parsedSloBuckets = result; + return result; + } + + /** + * Returns the configured default schedule mode as an enum. + */ + public ScheduleModeEnum getDefaultScheduleModeEnum() { + try { + return ScheduleModeEnum.valueOf(defaultScheduleMode.toUpperCase()); + } catch (IllegalArgumentException e) { + return ScheduleModeEnum.AUTO; + } + } } diff --git a/rtp_llm/flexlb/flexlb-common/src/main/java/org/flexlb/config/StrategyConfigs.java b/rtp_llm/flexlb/flexlb-common/src/main/java/org/flexlb/config/StrategyConfigs.java deleted file mode 100644 index 8f6055298c..0000000000 --- a/rtp_llm/flexlb/flexlb-common/src/main/java/org/flexlb/config/StrategyConfigs.java +++ /dev/null @@ -1,114 +0,0 @@ -package org.flexlb.config; - -import com.fasterxml.jackson.annotation.JsonCreator; -import com.fasterxml.jackson.annotation.JsonIgnoreProperties; -import lombok.Data; -import lombok.extern.slf4j.Slf4j; - -import java.util.Locale; - -@Data -@Slf4j -@JsonIgnoreProperties(ignoreUnknown = true) -public class StrategyConfigs { - - private ShortestTtftStrategyConfig shortestTtft = new ShortestTtftStrategyConfig(); - - public void normalize() { - if (shortestTtft == null) { - shortestTtft = new ShortestTtftStrategyConfig(); - } - shortestTtft.normalize(); - } - - @Data - @JsonIgnoreProperties(ignoreUnknown = true) - public static class ShortestTtftStrategyConfig { - - private CandidatePoolConfig candidatePool = new CandidatePoolConfig(); - - private void normalize() { - if (candidatePool == null) { - candidatePool = new CandidatePoolConfig(); - } - candidatePool.normalize(); - } - } - - @Data - @JsonIgnoreProperties(ignoreUnknown = true) - public static class CandidatePoolConfig { - - public static final CandidatePoolMode DEFAULT_MODE = CandidatePoolMode.RATIO; - public static final double DEFAULT_RATIO = 0.3; - public static final int DEFAULT_MIN_SIZE = 1; - public static final int DEFAULT_SIZE = 1; - - private CandidatePoolMode mode = DEFAULT_MODE; - private double ratio = DEFAULT_RATIO; - private int minSize = DEFAULT_MIN_SIZE; - private int size = DEFAULT_SIZE; - - public int resolveCandidateCount(int workerCount) { - if (workerCount <= 0) { - return 0; - } - - CandidatePoolMode resolvedMode = mode != null ? mode : DEFAULT_MODE; - double resolvedRatio = isValidRatio(ratio) ? ratio : DEFAULT_RATIO; - int resolvedMinSize = Math.max(DEFAULT_MIN_SIZE, minSize); - int resolvedSize = Math.max(DEFAULT_SIZE, size); - - int candidateCount; - if (resolvedMode == CandidatePoolMode.FIXED) { - candidateCount = resolvedSize; - } else { - candidateCount = Math.max(resolvedMinSize, (int) (workerCount * resolvedRatio)); - } - return Math.min(workerCount, Math.max(DEFAULT_MIN_SIZE, candidateCount)); - } - - private void normalize() { - if (mode == null) { - log.warn("Invalid shortestTtft candidatePool mode: null, fallback to default: {}", DEFAULT_MODE); - mode = DEFAULT_MODE; - } - - if (!isValidRatio(ratio)) { - log.warn("Invalid shortestTtft candidatePool ratio: {}, fallback to default: {}", ratio, DEFAULT_RATIO); - ratio = DEFAULT_RATIO; - } - - if (minSize < DEFAULT_MIN_SIZE) { - log.warn("Invalid shortestTtft candidatePool minSize: {}, fallback to default: {}", minSize, DEFAULT_MIN_SIZE); - minSize = DEFAULT_MIN_SIZE; - } - - if (size < DEFAULT_MIN_SIZE) { - log.warn("Invalid shortestTtft candidatePool size: {}, fallback to default: {}", size, DEFAULT_SIZE); - size = DEFAULT_SIZE; - } - } - - private boolean isValidRatio(double value) { - return Double.isFinite(value) && value > 0.0 && value <= 1.0; - } - } - - public enum CandidatePoolMode { - RATIO, - FIXED; - - @JsonCreator - public static CandidatePoolMode fromString(String value) { - if (value == null || value.trim().isEmpty()) { - return null; - } - try { - return CandidatePoolMode.valueOf(value.trim().toUpperCase(Locale.ROOT)); - } catch (IllegalArgumentException e) { - return null; - } - } - } -} diff --git a/rtp_llm/flexlb/flexlb-common/src/main/java/org/flexlb/constant/MetricConstant.java b/rtp_llm/flexlb/flexlb-common/src/main/java/org/flexlb/constant/MetricConstant.java index 8cbaebb8c4..4584a0ee27 100644 --- a/rtp_llm/flexlb/flexlb-common/src/main/java/org/flexlb/constant/MetricConstant.java +++ b/rtp_llm/flexlb/flexlb-common/src/main/java/org/flexlb/constant/MetricConstant.java @@ -51,12 +51,12 @@ public class MetricConstant { public static final String ENGINE_BALANCING_MASTER_SELECT_DETAIL = "app.engine.balancing.master.select.detail"; /** - * Engine queue wait time + * Engine running queue time (from EP authoritative value) */ public static final String ENGINE_RUNNING_QUEUE_TIME = "app.engine.health.check.running.queue.time"; /** - * Engine local task map size + * Engine local task map size (from EP authoritative value) */ public static final String ENGINE_LOCAL_TASK_MAP_SIZE = "app.engine.health.check.local.task.map.size"; @@ -95,6 +95,9 @@ public class MetricConstant { */ public static final String ENGINE_WORKER_INFO_STEP_LATENCY_VAR = "app.engine.worker.info.step.latency.var"; + /** + * Engine worker info running query length variance + */ public static final String ENGINE_WORKER_INFO_RUNNING_QUERY_LEN_VAR = "app.engine.worker.info.running.query.len.var"; /* ------------------------ Cache Health Monitoring -------------------------- */ diff --git a/rtp_llm/flexlb/flexlb-common/src/main/java/org/flexlb/dao/BalanceContext.java b/rtp_llm/flexlb/flexlb-common/src/main/java/org/flexlb/dao/BalanceContext.java index 863982d8dd..38062fd524 100644 --- a/rtp_llm/flexlb/flexlb-common/src/main/java/org/flexlb/dao/BalanceContext.java +++ b/rtp_llm/flexlb/flexlb-common/src/main/java/org/flexlb/dao/BalanceContext.java @@ -5,6 +5,7 @@ import org.flexlb.config.FlexlbConfig; import org.flexlb.dao.loadbalance.Request; import org.flexlb.dao.loadbalance.Response; +import org.flexlb.enums.ScheduleModeEnum; import java.util.concurrent.CompletableFuture; import java.util.concurrent.atomic.AtomicBoolean; @@ -27,6 +28,11 @@ public class BalanceContext { private Response response; + @ToString.Exclude + private byte[] generateInputPbBytes; + + private ScheduleModeEnum scheduleMode = ScheduleModeEnum.AUTO; + //======================== Queue ========================// private CompletableFuture future; diff --git a/rtp_llm/flexlb/flexlb-common/src/main/java/org/flexlb/dao/loadbalance/Request.java b/rtp_llm/flexlb/flexlb-common/src/main/java/org/flexlb/dao/loadbalance/Request.java index 8f493f563a..b733f9b767 100644 --- a/rtp_llm/flexlb/flexlb-common/src/main/java/org/flexlb/dao/loadbalance/Request.java +++ b/rtp_llm/flexlb/flexlb-common/src/main/java/org/flexlb/dao/loadbalance/Request.java @@ -37,4 +37,16 @@ public class Request { @JsonAlias({"apikey", "apiKey"}) @ToString.Exclude private String apiKey; + + @JsonProperty("max_new_tokens") + private int maxNewTokens = 1; + + @JsonProperty("num_beams") + private int numBeams = 1; + + @JsonProperty("force_disable_sp_run") + private boolean forceDisableSpRun = false; + + @JsonProperty("model") + private String model = ""; } diff --git a/rtp_llm/flexlb/flexlb-common/src/main/java/org/flexlb/dao/loadbalance/Response.java b/rtp_llm/flexlb/flexlb-common/src/main/java/org/flexlb/dao/loadbalance/Response.java index 6c967bd7cd..23867f56cb 100644 --- a/rtp_llm/flexlb/flexlb-common/src/main/java/org/flexlb/dao/loadbalance/Response.java +++ b/rtp_llm/flexlb/flexlb-common/src/main/java/org/flexlb/dao/loadbalance/Response.java @@ -5,6 +5,7 @@ import lombok.Data; import java.util.List; +import java.util.Map; @JsonIgnoreProperties(ignoreUnknown = true) @Data @@ -28,6 +29,12 @@ public class Response { @JsonProperty("queue_length") private Integer queueLength; + @JsonProperty("enqueued_by_master") + private boolean enqueuedByMaster = false; + + @JsonProperty("worker_summary") + private Map workerSummary; + public static Response error(StrategyErrorType strategyErrorType) { Response result = new Response(); result.setSuccess(false); @@ -35,4 +42,11 @@ public static Response error(StrategyErrorType strategyErrorType) { result.setErrorMessage(strategyErrorType.getErrorMsg()); return result; } + + @Data + public static class WorkerRoleSummary { + private int discovered; + private int alive; + private long maxQueueTokens; + } } diff --git a/rtp_llm/flexlb/flexlb-common/src/main/java/org/flexlb/dao/loadbalance/ServerStatus.java b/rtp_llm/flexlb/flexlb-common/src/main/java/org/flexlb/dao/loadbalance/ServerStatus.java index c6cb4ca579..44ecf3c935 100644 --- a/rtp_llm/flexlb/flexlb-common/src/main/java/org/flexlb/dao/loadbalance/ServerStatus.java +++ b/rtp_llm/flexlb/flexlb-common/src/main/java/org/flexlb/dao/loadbalance/ServerStatus.java @@ -20,6 +20,9 @@ public class ServerStatus { @JsonProperty("grpc_port") private int grpcPort; + @JsonProperty("dp_rank") + private long dpRank; + @JsonProperty("prefill_time") private long prefillTime; diff --git a/rtp_llm/flexlb/flexlb-common/src/main/java/org/flexlb/dao/loadbalance/StrategyErrorType.java b/rtp_llm/flexlb/flexlb-common/src/main/java/org/flexlb/dao/loadbalance/StrategyErrorType.java index 10be29c676..f9cee8e5db 100644 --- a/rtp_llm/flexlb/flexlb-common/src/main/java/org/flexlb/dao/loadbalance/StrategyErrorType.java +++ b/rtp_llm/flexlb/flexlb-common/src/main/java/org/flexlb/dao/loadbalance/StrategyErrorType.java @@ -22,7 +22,12 @@ public enum StrategyErrorType { // queue error QUEUE_FULL(8502, false), QUEUE_TIMEOUT(8503, false), - REQUEST_CANCELLED(8504, false); + REQUEST_CANCELLED(8504, false), + + // batch dispatch error + BATCH_DISPATCH_FAILED(8510, true), + BATCH_SLO_EXPIRED(8511, false), + BATCH_BUILD_FAILED(8512, false); private final int errorCode; private final String errorMsg; diff --git a/rtp_llm/flexlb/flexlb-common/src/main/java/org/flexlb/dao/master/TaskInfo.java b/rtp_llm/flexlb/flexlb-common/src/main/java/org/flexlb/dao/master/TaskInfo.java index 656f16a014..5b975529bc 100644 --- a/rtp_llm/flexlb/flexlb-common/src/main/java/org/flexlb/dao/master/TaskInfo.java +++ b/rtp_llm/flexlb/flexlb-common/src/main/java/org/flexlb/dao/master/TaskInfo.java @@ -4,6 +4,7 @@ import com.fasterxml.jackson.annotation.JsonProperty; import lombok.Data; import lombok.extern.slf4j.Slf4j; +import org.flexlb.enums.TaskPhase; import org.flexlb.enums.TaskStateEnum; import java.util.Map; @@ -32,6 +33,16 @@ public class TaskInfo { private long endTimeMs; @JsonProperty("dp_rank") private long dpRank; + @JsonProperty("error_code") + private long errorCode; + @JsonProperty("error_message") + private String errorMessage; + @JsonProperty("batch_id") + private long batchId = -1; + @JsonProperty("phase") + private TaskPhase phase; + + private long predictedMs; // Task state related fields private TaskStateEnum taskState = TaskStateEnum.CREATED; diff --git a/rtp_llm/flexlb/flexlb-common/src/main/java/org/flexlb/dao/master/WorkerStatus.java b/rtp_llm/flexlb/flexlb-common/src/main/java/org/flexlb/dao/master/WorkerStatus.java index 3dc2ab8325..9916147087 100644 --- a/rtp_llm/flexlb/flexlb-common/src/main/java/org/flexlb/dao/master/WorkerStatus.java +++ b/rtp_llm/flexlb/flexlb-common/src/main/java/org/flexlb/dao/master/WorkerStatus.java @@ -2,16 +2,10 @@ import lombok.Data; import lombok.extern.slf4j.Slf4j; -import org.apache.commons.collections4.MapUtils; import org.flexlb.dao.route.RoleType; -import org.flexlb.enums.TaskStateEnum; -import org.flexlb.util.Logger; import org.slf4j.LoggerFactory; -import java.util.Iterator; import java.util.Map; -import java.util.Map.Entry; -import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicReference; @@ -22,261 +16,68 @@ public class WorkerStatus { private static final org.slf4j.Logger logger = LoggerFactory.getLogger("syncLogger"); public final transient ReentrantLock lock = new ReentrantLock(); - private String role; + private RoleType role; private String group; private String ip; private int port; + private int grpcPort; private String site; private Long availableConcurrency; private boolean alive; private AtomicLong availableKvCacheTokens = new AtomicLong(); - private AtomicLong usedKvCacheTokens = new AtomicLong(); + private AtomicLong totalKvCacheTokens = new AtomicLong(); private CacheStatus cacheStatus; - private AtomicLong runningQueueTime = new AtomicLong(); - private Map waitingTaskList; private Map runningTaskList; private AtomicLong latestFinishedTaskVersion = new AtomicLong(-1L); - private ConcurrentHashMap localTaskMap = new ConcurrentHashMap<>(); private double stepLatencyMs; private long iterateCount; private long dpSize; private long tpSize; - - private AtomicLong statusLastUpdateTime = new AtomicLong(-1); // Last status update time (microseconds) - private AtomicLong statusUpdateIntervalUs = new AtomicLong(0); // Actual interval between last two status updates (microseconds) - private AtomicLong cacheLastUpdateTime = new AtomicLong(-1); // Last cache status update time - private AtomicLong lastSelectedTime = new AtomicLong(-1); // Last selection time - private AtomicBoolean resourceAvailable = new AtomicBoolean(true); // Resource availability state - private AtomicBoolean statusCheckInProgress = new AtomicBoolean(false); // Status check in progress flag - private AtomicBoolean cacheCheckInProgress = new AtomicBoolean(false); // Cache check in progress flag + private long dpRank; + + private AtomicLong statusLastUpdateTime = new AtomicLong(-1); + private AtomicLong statusUpdateIntervalUs = new AtomicLong(0); + private AtomicLong cacheLastUpdateTime = new AtomicLong(-1); + private AtomicLong lastSelectedTime = new AtomicLong(-1); + private AtomicBoolean resourceAvailable = new AtomicBoolean(true); + private AtomicBoolean statusCheckInProgress = new AtomicBoolean(false); + private AtomicBoolean cacheCheckInProgress = new AtomicBoolean(false); private AtomicLong statusVersion = new AtomicLong(-1L); + private AtomicLong consecutiveFailures = new AtomicLong(0); /** - * Add task to local running queue - * @param requestId Request ID - * @param taskInfo Task information - */ - public void putLocalTask(Long requestId, TaskInfo taskInfo) { - taskInfo.updateTaskState(TaskStateEnum.IN_TRANSIT); - AtomicReference replacedTaskRef = new AtomicReference<>(); - localTaskMap.compute(requestId, (id, previousTask) -> { - if (previousTask != null) { - releaseLocalTaskResources(previousTask); - replacedTaskRef.set(previousTask); - } - reserveLocalTaskResources(taskInfo); - return taskInfo; - }); - - lastSelectedTime.set(System.nanoTime() / 1000); - Logger.debug("Task {} added to local queue with state: {}", requestId, TaskStateEnum.IN_TRANSIT); - } - - /** - * Remove task from local running queue - * @param requestId Request ID - */ - public void removeLocalTask(Long requestId) { - localTaskMap.computeIfPresent(requestId, (id, taskInfo) -> { - releaseLocalTaskResources(taskInfo); - return null; - }); - } - - private void reserveLocalTaskResources(TaskInfo taskInfo) { - this.addRunningQueueTime(taskInfo.estimatePrefillTime()); - long needNewKvCacheLen = taskInfo.getInputLength() - taskInfo.getPrefixLength(); - this.decKvCacheFree(needNewKvCacheLen); - this.addKvCacheUsed(needNewKvCacheLen); - } - - private void releaseLocalTaskResources(TaskInfo taskInfo) { - safeDecrementQueueTime(runningQueueTime, taskInfo.estimatePrefillTime()); - long needNewKvCacheLen = taskInfo.getInputLength() - taskInfo.getPrefixLength(); - decKvCacheFree(-needNewKvCacheLen); - addKvCacheUsed(-needNewKvCacheLen); - } - - /** - * Add estimated execution time to running queue - * @param len Estimated execution time to add - */ - public void addRunningQueueTime(long len) { - runningQueueTime.addAndGet(len); - } - - public void addKvCacheUsed(long len) { - usedKvCacheTokens.addAndGet(len); - } - - public void decKvCacheFree(long len) { - availableKvCacheTokens.accumulateAndGet(len, (current, decrement) -> - Math.max(0, current - decrement)); - } - - /** - * Update task states - * Check for lost tasks, update running/waiting tasks, and clean up finished tasks - */ - public void updateTaskStates(Map waitingTaskInfo, Map runningTaskInfo, Map finishedTaskInfo) { - Iterator> iterator = localTaskMap.entrySet().iterator(); - while (iterator.hasNext()) { - Map.Entry entry = iterator.next(); - Long requestId = entry.getKey(); - TaskInfo localTask = entry.getValue(); - String requestIdStr = String.valueOf(requestId); - - TaskInfo finishedTask = finishedTaskInfo != null ? finishedTaskInfo.get(requestIdStr) : null; - if (finishedTask != null) { - if (localTask.getTaskState() == TaskStateEnum.IN_TRANSIT) { - localTask.updateTaskState(TaskStateEnum.CONFIRMED); - Logger.debug("Task {} first confirmed by worker", requestId); - } - localTask.updateTaskState(TaskStateEnum.FINISHED); - - if (RoleType.PREFILL.matches(role) || RoleType.PDFUSION.matches(role)) { - long delta = finishedTask.estimatePrefillTime(); - safeDecrementQueueTime(runningQueueTime, delta); - } - Logger.debug("Task {} finished and removed", requestId); - iterator.remove(); - continue; - } - - TaskInfo runningTask = runningTaskInfo != null ? runningTaskInfo.get(requestIdStr) : null; - if (runningTask != null) { - localTask.setLastActiveTimeUs(System.nanoTime() / 1000); - - if (localTask.getTaskState() == TaskStateEnum.IN_TRANSIT) { - localTask.updateTaskState(TaskStateEnum.CONFIRMED); - Logger.debug("Task {} first confirmed by worker", requestId); - } - if (localTask.getTaskState() != TaskStateEnum.RUNNING) { - localTask.updateTaskState(TaskStateEnum.RUNNING); - } - - localTask.setPrefixLength(runningTask.getPrefixLength()); - localTask.setPrefillTime(runningTask.getPrefillTime()); - localTask.setInputLength(runningTask.getInputLength()); - localTask.setWaitingTime(runningTask.getWaitingTime()); - localTask.setIterateCount(runningTask.getIterateCount()); - localTask.setEndTimeMs(runningTask.getEndTimeMs()); - localTask.setDpRank(runningTask.getDpRank()); - - continue; - } - - TaskInfo waitingTask = waitingTaskInfo != null ? waitingTaskInfo.get(requestIdStr) : null; - if (waitingTask != null) { - localTask.setLastActiveTimeUs(System.nanoTime() / 1000); - - if (localTask.getTaskState() == TaskStateEnum.IN_TRANSIT) { - localTask.updateTaskState(TaskStateEnum.CONFIRMED); - Logger.debug("Task {} first confirmed by worker (waiting)", requestId); - } - - localTask.setPrefixLength(waitingTask.getPrefixLength()); - localTask.setInputLength(waitingTask.getInputLength()); - localTask.setWaitingTime(waitingTask.getWaitingTime()); - localTask.setDpRank(waitingTask.getDpRank()); - - continue; - } - - if (localTask.getTaskState() == TaskStateEnum.CONFIRMED || localTask.getTaskState() == TaskStateEnum.RUNNING) { - localTask.updateTaskState(TaskStateEnum.LOST); - logger.warn("Task {} marked as LOST - not in waiting, running or finished list", requestId); - } - } - } - - /** - * Update total queue time for running queue + * Absorb all dynamic engine fields from a gRPC status response. + * Topology labels ({@code site}, {@code group}) are NOT set here — + * they are managed externally by the sync runner. */ - public void updateRunningQueueTime() { - int localTaskMapSize = localTaskMap.size(); - if (localTaskMapSize == 0) { - runningQueueTime.getAndSet(0); + public void updateFromResponse(WorkerStatusResponse resp) { + if (resp == null) { return; } - long rectifiedEstimateRunningTime = 0; - for (Entry entry : localTaskMap.entrySet()) { - TaskInfo taskInfo = entry.getValue(); - // Recalculate based on accurate cache hit count, rectify local task running queue time - rectifiedEstimateRunningTime += taskInfo.estimatePrefillTime(); - } - if (RoleType.PREFILL.matches(role) || RoleType.PDFUSION.matches(role)) { - // Only update when rectified time is less than estimated time, because engine layer returned running_list may include queuing tasks where prefixLength=0 - if (runningQueueTime.get() > rectifiedEstimateRunningTime) { - runningQueueTime.getAndSet(rectifiedEstimateRunningTime); - } + this.role = resp.getRole(); + this.alive = resp.isAlive(); + this.availableConcurrency = resp.getAvailableConcurrency(); + this.stepLatencyMs = resp.getStepLatencyMs(); + this.iterateCount = resp.getIterateCount(); + this.dpSize = resp.getDpSize(); + this.tpSize = resp.getTpSize(); + this.dpRank = resp.getDpRank(); + this.availableKvCacheTokens.set(resp.getAvailableKvCacheTokens()); + this.totalKvCacheTokens.set(resp.getTotalKvCacheTokens()); + this.cacheStatus = resp.getCacheStatus(); + this.runningTaskList = resp.getRunningTaskInfo(); + this.statusVersion.set(resp.getStatusVersion()); + this.latestFinishedTaskVersion.set(resp.getLatestFinishedVersion()); + + long nowUs = System.nanoTime() / 1000; + long prev = this.statusLastUpdateTime.get(); + if (prev > 0) { + this.statusUpdateIntervalUs.set(nowUs - prev); } + this.statusLastUpdateTime.set(nowUs); } - public void updateKvCacheTokens(long latestUsedKvCacheTokens, long latestAvailableKvCacheTokens) { - - int localTaskMapSize = localTaskMap.size(); - if (localTaskMapSize == 0) { - usedKvCacheTokens.getAndSet(latestUsedKvCacheTokens); - availableKvCacheTokens.getAndSet(latestAvailableKvCacheTokens); - return; - } - - long inTransitTaskCacheUsed = 0; - for (Map.Entry entry : localTaskMap.entrySet()) { - TaskInfo taskInfo = entry.getValue(); - // Calculate tokens occupied by in-transit task cache miss portion - if (taskInfo.getTaskState() == TaskStateEnum.IN_TRANSIT) { - inTransitTaskCacheUsed = inTransitTaskCacheUsed + taskInfo.getInputLength() - taskInfo.getPrefixLength(); - } - } - // Rectify KV cache tokens affected by in-transit tasks - latestUsedKvCacheTokens += inTransitTaskCacheUsed; - latestAvailableKvCacheTokens -= inTransitTaskCacheUsed; - - usedKvCacheTokens.getAndSet(latestUsedKvCacheTokens); - availableKvCacheTokens.getAndSet(latestAvailableKvCacheTokens); - - } - - public long getLocalPendingTaskCount() { - long pendingTaskCount = 0; - for (TaskInfo taskInfo : localTaskMap.values()) { - if (isLocalPendingTask(taskInfo)) { - pendingTaskCount++; - } - } - return pendingTaskCount; - } - - private boolean isLocalPendingTask(TaskInfo taskInfo) { - if (taskInfo == null) { - return false; - } - TaskStateEnum taskState = taskInfo.getTaskState(); - return taskState == TaskStateEnum.IN_TRANSIT || taskState == TaskStateEnum.CONFIRMED; - } - - /** - * Safely decrement total queue time for running queue, ensuring it never becomes negative - * - * @param runningQueueTime Total queue time for running queue - * @param timeToReduce Time to reduce - */ - public static void safeDecrementQueueTime(AtomicLong runningQueueTime, long timeToReduce) { - if (timeToReduce <= 0) { - logger.warn("Invalid tokens to reduce: {}", timeToReduce); - return; - } - runningQueueTime.accumulateAndGet(timeToReduce, (currentRunningQueueTime, reductionAmount) -> { - // Ensure reduction amount is positive, calculate new value, but not less than 0 - long newRunningQueueTime = currentRunningQueueTime - reductionAmount; - - // If result is negative, set to 0, ensuring token count never goes below 0 - return Math.max(newRunningQueueTime, 0L); - }); - } /** * Update resource availability with hysteresis to prevent state oscillation. diff --git a/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/domain/worker/WorkerStatusResponse.java b/rtp_llm/flexlb/flexlb-common/src/main/java/org/flexlb/dao/master/WorkerStatusResponse.java similarity index 82% rename from rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/domain/worker/WorkerStatusResponse.java rename to rtp_llm/flexlb/flexlb-common/src/main/java/org/flexlb/dao/master/WorkerStatusResponse.java index 25570ee4de..abf8fde13e 100644 --- a/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/domain/worker/WorkerStatusResponse.java +++ b/rtp_llm/flexlb/flexlb-common/src/main/java/org/flexlb/dao/master/WorkerStatusResponse.java @@ -1,10 +1,9 @@ -package org.flexlb.domain.worker; +package org.flexlb.dao.master; import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import com.fasterxml.jackson.annotation.JsonProperty; import lombok.Data; -import org.flexlb.dao.master.CacheStatus; -import org.flexlb.dao.master.TaskInfo; +import org.flexlb.dao.route.RoleType; import java.util.Map; @@ -18,7 +17,7 @@ public class WorkerStatusResponse { @JsonProperty("role") - private String role; + private RoleType role; @JsonProperty("available_concurrency") private long availableConcurrency; @@ -32,9 +31,6 @@ public class WorkerStatusResponse { @JsonProperty("running_task_info") private Map runningTaskInfo; - @JsonProperty("waiting_task_info") - private Map waitingTaskInfo; - @JsonProperty("finished_task_info") private Map finishedTaskInfo; @@ -59,13 +55,22 @@ public class WorkerStatusResponse { @JsonProperty("tpSize") private long tpSize; + @JsonProperty("dpRank") + private long dpRank; + @JsonProperty("alive") private boolean alive; + @JsonProperty("available_kv_cache") + private long availableKvCacheTokens; + + @JsonProperty("total_kv_cache") + private long totalKvCacheTokens; + @JsonProperty("version") private long version; @JsonProperty("message") private String message; -} \ No newline at end of file +} diff --git a/rtp_llm/flexlb/flexlb-common/src/main/java/org/flexlb/dao/route/RoleType.java b/rtp_llm/flexlb/flexlb-common/src/main/java/org/flexlb/dao/route/RoleType.java index 13090f76b2..97acfdf874 100644 --- a/rtp_llm/flexlb/flexlb-common/src/main/java/org/flexlb/dao/route/RoleType.java +++ b/rtp_llm/flexlb/flexlb-common/src/main/java/org/flexlb/dao/route/RoleType.java @@ -1,5 +1,7 @@ package org.flexlb.dao.route; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonValue; import lombok.Getter; import org.flexlb.dao.loadbalance.StrategyErrorType; @@ -8,11 +10,12 @@ @Getter public enum RoleType { - PDFUSION("RoleType.PDFUSION", "Prefill-Decode Fusion"), - PREFILL("RoleType.PREFILL", "Prefill"), - DECODE("RoleType.DECODE", "Decode"), - VIT("RoleType.VIT", "Vision Transformer"); + PDFUSION("PDFUSION", "Prefill-Decode Fusion"), + PREFILL("PREFILL", "Prefill"), + DECODE("DECODE", "Decode"), + VIT("VIT", "Vision Transformer"); + @JsonValue private final String code; private final String description; @@ -30,14 +33,38 @@ public enum RoleType { } /** - * Check if string matches current role type + * Deserialize from JSON string. Accepts short name ("PREFILL") or proto-prefixed name ("ROLE_TYPE_PREFILL"). */ + @JsonCreator + public static RoleType fromString(String value) { + if (value == null) { + return null; + } + // Try code first ("PREFILL") + RoleType byCode = CODE_MAP.get(value); + if (byCode != null) { + return byCode; + } + // Compat: strip proto prefix ("ROLE_TYPE_PREFILL" -> "PREFILL") + if (value.startsWith("ROLE_TYPE_")) { + return RoleType.valueOf(value.substring(10)); + } + // Fallback: try enum name + return RoleType.valueOf(value); + } + + /** + * Check if string matches current role type. + * + * @deprecated Use {@code roleType == RoleType.PREFILL} or enum comparison instead. + */ + @Deprecated public boolean matches(String code) { return this.code.equals(code); } /** - * Get corresponding error type based on role type + * Get corresponding error type based on role type. * * @return Corresponding error type */ @@ -49,4 +76,15 @@ public StrategyErrorType getErrorType() { case VIT -> StrategyErrorType.NO_VIT_WORKER; }; } + + /** + * Get the proto enum constant name (ROLE_TYPE_XXX). + * + * @deprecated Use {@link org.flexlb.engine.grpc.RoleTypeProtoConverter#toProto(RoleType)} + * for direct proto enum mapping. + */ + @Deprecated + public String getProtoName() { + return "ROLE_TYPE_" + this.name(); + } } diff --git a/rtp_llm/flexlb/flexlb-common/src/main/java/org/flexlb/enums/LoadBalanceStrategyEnum.java b/rtp_llm/flexlb/flexlb-common/src/main/java/org/flexlb/enums/LoadBalanceStrategyEnum.java index 4c6b763d1a..5f9558f517 100644 --- a/rtp_llm/flexlb/flexlb-common/src/main/java/org/flexlb/enums/LoadBalanceStrategyEnum.java +++ b/rtp_llm/flexlb/flexlb-common/src/main/java/org/flexlb/enums/LoadBalanceStrategyEnum.java @@ -1,15 +1,17 @@ package org.flexlb.enums; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonValue; import lombok.Getter; @Getter public enum LoadBalanceStrategyEnum { - RANDOM("Random"), // Random assignment + RANDOM("Random"), - SHORTEST_TTFT("ShortestTTFT"), // Shortest Time-To-First-Token + COST_BASED_PREFILL("CostBasedPrefill"), - WEIGHTED_CACHE("WeightedCache") // Lowest cache usage strategy + COST_BASED_DECODE("CostBasedDecode") ; private final String name; @@ -18,4 +20,18 @@ public enum LoadBalanceStrategyEnum { this.name = name; } + @JsonValue + public String getName() { + return name; + } + + @JsonCreator + public static LoadBalanceStrategyEnum fromName(String value) { + for (LoadBalanceStrategyEnum e : values()) { + if (e.name.equals(value) || e.name().equals(value)) { + return e; + } + } + throw new IllegalArgumentException("Unknown strategy: " + value); + } } diff --git a/rtp_llm/flexlb/flexlb-common/src/main/java/org/flexlb/enums/ScheduleModeEnum.java b/rtp_llm/flexlb/flexlb-common/src/main/java/org/flexlb/enums/ScheduleModeEnum.java new file mode 100644 index 0000000000..d3ab6100aa --- /dev/null +++ b/rtp_llm/flexlb/flexlb-common/src/main/java/org/flexlb/enums/ScheduleModeEnum.java @@ -0,0 +1,7 @@ +package org.flexlb.enums; + +public enum ScheduleModeEnum { + AUTO, + BATCH, + DIRECT +} diff --git a/rtp_llm/flexlb/flexlb-common/src/main/java/org/flexlb/enums/TaskPhase.java b/rtp_llm/flexlb/flexlb-common/src/main/java/org/flexlb/enums/TaskPhase.java new file mode 100644 index 0000000000..969bb51e8a --- /dev/null +++ b/rtp_llm/flexlb/flexlb-common/src/main/java/org/flexlb/enums/TaskPhase.java @@ -0,0 +1,34 @@ +package org.flexlb.enums; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonValue; +import lombok.Getter; + +@Getter +public enum TaskPhase { + + PENDING("pending"), + RECEIVED("received"), + KV_ALLOCATED("kv_allocated"), + RUNNING("running"); + + @JsonValue + private final String value; + + TaskPhase(String value) { + this.value = value; + } + + @JsonCreator + public static TaskPhase fromValue(String value) { + if (value == null) { + return null; + } + for (TaskPhase phase : values()) { + if (phase.value.equalsIgnoreCase(value)) { + return phase; + } + } + return null; + } +} diff --git a/rtp_llm/flexlb/flexlb-common/src/main/java/org/flexlb/util/Logger.java b/rtp_llm/flexlb/flexlb-common/src/main/java/org/flexlb/util/Logger.java index 683beed421..47cda3a358 100644 --- a/rtp_llm/flexlb/flexlb-common/src/main/java/org/flexlb/util/Logger.java +++ b/rtp_llm/flexlb/flexlb-common/src/main/java/org/flexlb/util/Logger.java @@ -1,76 +1,86 @@ package org.flexlb.util; -import lombok.Getter; -import lombok.Setter; +import ch.qos.logback.classic.Level; import org.flexlb.enums.LogLevel; import org.slf4j.LoggerFactory; /** - * Logging utility class, in order to log when enable global switch or set log level in master request + * Logging utility wrapping SLF4J with runtime log-level control via logback. * - *

The {@code info} {@code warn} and {@code error} level in enabled by default.

+ *

All filtering is delegated to logback — there is no custom gating. + * {@link #setLevel(LogLevel)} directly updates the {@code flexlbLogger} logback logger, + * so the {@code update_log_level} API and the logback configuration stay in sync.

* - * @see LogLevel + *

{@code INFO}, {@code WARN} and {@code ERROR} are enabled by default (logback + * {@code flexlbLogger} starts at {@code INFO}).

*/ public class Logger { private static final org.slf4j.Logger log = LoggerFactory.getLogger("flexlbLogger"); - @Getter - @Setter - private static LogLevel globalLogLevel; - static { String logLevelStr = System.getenv("LOG_LEVEL"); if (logLevelStr != null) { try { - globalLogLevel = LogLevel.valueOf(logLevelStr.toUpperCase().trim()); + setLevel(LogLevel.valueOf(logLevelStr.toUpperCase().trim())); } catch (IllegalArgumentException e) { log.warn("Invalid LOG_LEVEL value: '{}'. Valid values are: TRACE, DEBUG, INFO, WARN, ERROR.", logLevelStr); } } } + // ---- Logging methods (delegate directly to SLF4J) ---- + public static void trace(String format, Object... args) { - log(LogLevel.TRACE, () -> log.trace(format, args)); + log.trace(format, args); } public static void debug(String format, Object... args) { - log(LogLevel.DEBUG, () -> log.debug(format, args)); + log.debug(format, args); } public static void info(String format, Object... args) { - log(LogLevel.INFO, () -> log.info(format, args), false); + log.info(format, args); } public static void warn(String format, Object... args) { - log(LogLevel.WARN, () -> log.warn(format, args), false); + log.warn(format, args); } public static void error(String format, Object... args) { - log(LogLevel.ERROR, () -> log.error(format, args), false); - } - - private static void log(LogLevel targetLevel, Runnable logAction) { - log(targetLevel, logAction, true); + log.error(format, args); } - private static void log(LogLevel targetLevel, Runnable logAction, boolean checkGlobalLevel) { - boolean shouldLog = shouldLog(targetLevel, checkGlobalLevel); + // ---- Runtime level control ---- - if (shouldLog) { - logAction.run(); + /** + * Returns the effective log level of the underlying logback logger. + */ + public static LogLevel getLevel() { + ch.qos.logback.classic.Logger lbLogger = logbackLogger(); + Level level = lbLogger.getLevel(); + if (level == null) { + return null; + } + try { + return LogLevel.valueOf(level.levelStr.toUpperCase()); + } catch (IllegalArgumentException e) { + return null; } } - private static boolean shouldLog(LogLevel targetLevel, boolean checkGlobalLevel) { - boolean shouldLog; - if (checkGlobalLevel) { - shouldLog = globalLogLevel != null && globalLogLevel.compareTo(targetLevel) <= 0; - } else { - // warn and error are enabled by default - shouldLog = globalLogLevel == null || globalLogLevel.compareTo(targetLevel) <= 0; - } - return shouldLog; + /** + * Sets the log level of the underlying {@code flexlbLogger}. + * A {@code null} level resets the logger to {@code INFO} (the safe production default). + */ + public static void setLevel(LogLevel level) { + Level lbLevel = level != null + ? Level.toLevel(level.name(), Level.INFO) + : Level.INFO; + logbackLogger().setLevel(lbLevel); + } + + private static ch.qos.logback.classic.Logger logbackLogger() { + return (ch.qos.logback.classic.Logger) LoggerFactory.getLogger("flexlbLogger"); } } diff --git a/rtp_llm/flexlb/flexlb-common/src/test/java/org/flexlb/config/ConfigServiceTest.java b/rtp_llm/flexlb/flexlb-common/src/test/java/org/flexlb/config/ConfigServiceTest.java index fee0b865a7..cc873c4dab 100644 --- a/rtp_llm/flexlb/flexlb-common/src/test/java/org/flexlb/config/ConfigServiceTest.java +++ b/rtp_llm/flexlb/flexlb-common/src/test/java/org/flexlb/config/ConfigServiceTest.java @@ -108,147 +108,6 @@ void should_override_cache_hit_switches_with_environment() { assertFalse(configService.loadBalanceConfig().isCacheHitTheoryLogEnabled()); } - @Test - void should_use_default_strategy_configs_without_environment() { - ConfigService configService = new ConfigService(Map.of()); - - StrategyConfigs.CandidatePoolConfig candidatePool = configService.getStrategyConfigs() - .getShortestTtft() - .getCandidatePool(); - assertEquals(StrategyConfigs.CandidatePoolMode.RATIO, candidatePool.getMode()); - assertEquals(0.3, candidatePool.getRatio()); - assertEquals(1, candidatePool.getMinSize()); - assertEquals(1, candidatePool.getSize()); - assertEquals(3, candidatePool.resolveCandidateCount(10)); - } - - @Test - void should_load_shortest_ttft_strategy_configs_from_environment() { - ConfigService configService = new ConfigService(Map.of( - "STRATEGY_CONFIGS", """ - { - "shortestTtft": { - "candidatePool": { - "mode": "FIXED", - "size": 1 - } - } - } - """)); - - StrategyConfigs.CandidatePoolConfig candidatePool = configService.getStrategyConfigs() - .getShortestTtft() - .getCandidatePool(); - assertEquals(StrategyConfigs.CandidatePoolMode.FIXED, candidatePool.getMode()); - assertEquals(1, candidatePool.getSize()); - assertEquals(0.3, candidatePool.getRatio()); - assertEquals(1, candidatePool.getMinSize()); - } - - @Test - void should_load_strategy_config_enum_case_insensitively() { - ConfigService configService = new ConfigService(Map.of( - "STRATEGY_CONFIGS", """ - { - "shortestTtft": { - "candidatePool": { - "mode": "fixed", - "size": 2 - } - } - } - """)); - - StrategyConfigs.CandidatePoolConfig candidatePool = configService.getStrategyConfigs() - .getShortestTtft() - .getCandidatePool(); - assertEquals(StrategyConfigs.CandidatePoolMode.FIXED, candidatePool.getMode()); - assertEquals(2, candidatePool.getSize()); - } - - @Test - void should_keep_strategy_config_defaults_for_missing_fields() { - ConfigService configService = new ConfigService(Map.of( - "STRATEGY_CONFIGS", """ - { - "shortestTtft": { - "candidatePool": { - "ratio": 0.5 - } - } - } - """)); - - StrategyConfigs.CandidatePoolConfig candidatePool = configService.getStrategyConfigs() - .getShortestTtft() - .getCandidatePool(); - assertEquals(StrategyConfigs.CandidatePoolMode.RATIO, candidatePool.getMode()); - assertEquals(0.5, candidatePool.getRatio()); - assertEquals(1, candidatePool.getMinSize()); - assertEquals(1, candidatePool.getSize()); - assertEquals(2, candidatePool.resolveCandidateCount(4)); - } - - @Test - void should_normalize_invalid_strategy_candidate_pool_values() { - ConfigService configService = new ConfigService(Map.of( - "STRATEGY_CONFIGS", """ - { - "shortestTtft": { - "candidatePool": { - "mode": "FIXED", - "size": 0, - "ratio": 2.0, - "minSize": 0 - } - } - } - """)); - - StrategyConfigs.CandidatePoolConfig candidatePool = configService.getStrategyConfigs() - .getShortestTtft() - .getCandidatePool(); - assertEquals(StrategyConfigs.CandidatePoolMode.FIXED, candidatePool.getMode()); - assertEquals(0.3, candidatePool.getRatio()); - assertEquals(1, candidatePool.getMinSize()); - assertEquals(1, candidatePool.getSize()); - } - - @Test - void should_normalize_invalid_strategy_candidate_pool_mode() { - ConfigService configService = new ConfigService(Map.of( - "STRATEGY_CONFIGS", """ - { - "shortestTtft": { - "candidatePool": { - "mode": "BAD", - "size": 2 - } - } - } - """)); - - StrategyConfigs.CandidatePoolConfig candidatePool = configService.getStrategyConfigs() - .getShortestTtft() - .getCandidatePool(); - assertEquals(StrategyConfigs.CandidatePoolMode.RATIO, candidatePool.getMode()); - assertEquals(2, candidatePool.getSize()); - } - - @Test - void should_fallback_default_strategy_configs_when_environment_json_is_malformed() { - ConfigService configService = new ConfigService(Map.of( - "STRATEGY_CONFIGS", "{\"shortestTtft\":")); - - StrategyConfigs.CandidatePoolConfig candidatePool = configService.getStrategyConfigs() - .getShortestTtft() - .getCandidatePool(); - assertEquals(StrategyConfigs.CandidatePoolMode.RATIO, candidatePool.getMode()); - assertEquals(0.3, candidatePool.getRatio()); - assertEquals(1, candidatePool.getMinSize()); - assertEquals(1, candidatePool.getSize()); - } - private Request request() { Request request = new Request(); request.setRequestId(12345L); diff --git a/rtp_llm/flexlb/flexlb-common/src/test/java/org/flexlb/dao/loadbalance/RequestTest.java b/rtp_llm/flexlb/flexlb-common/src/test/java/org/flexlb/dao/loadbalance/RequestTest.java index 879dc585aa..ed36565b4c 100644 --- a/rtp_llm/flexlb/flexlb-common/src/test/java/org/flexlb/dao/loadbalance/RequestTest.java +++ b/rtp_llm/flexlb/flexlb-common/src/test/java/org/flexlb/dao/loadbalance/RequestTest.java @@ -18,6 +18,9 @@ void should_deserialize_frontend_schedule_payload() throws Exception { "block_cache_keys": [1, 2, 3], "cache_key_block_size": 1024, "seq_len": 8192, + "max_new_tokens": 64, + "num_beams": 1, + "force_disable_sp_run": false, "debug": false, "request_priority": 100, "generate_timeout": 5000, @@ -31,6 +34,9 @@ void should_deserialize_frontend_schedule_payload() throws Exception { assertEquals(1024L, request.getCacheKeyBlockSize()); assertEquals(3, request.getBlockCacheKeys().size()); assertEquals(5000L, request.getGenerateTimeout()); + assertEquals(64, request.getMaxNewTokens()); + assertEquals(1, request.getNumBeams()); + assertEquals("engine_service", request.getModel()); } @Test diff --git a/rtp_llm/flexlb/flexlb-common/src/test/java/org/flexlb/dao/master/WorkerStatusTest.java b/rtp_llm/flexlb/flexlb-common/src/test/java/org/flexlb/dao/master/WorkerStatusTest.java index 95f7c6705a..abaf62bd3f 100644 --- a/rtp_llm/flexlb/flexlb-common/src/test/java/org/flexlb/dao/master/WorkerStatusTest.java +++ b/rtp_llm/flexlb/flexlb-common/src/test/java/org/flexlb/dao/master/WorkerStatusTest.java @@ -1,7 +1,5 @@ package org.flexlb.dao.master; -import org.flexlb.dao.route.RoleType; -import org.flexlb.enums.TaskStateEnum; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.Nested; @@ -9,14 +7,8 @@ import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.CsvSource; -import java.util.HashMap; -import java.util.Map; - import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; -import static org.junit.jupiter.api.Assertions.assertNotNull; -import static org.junit.jupiter.api.Assertions.assertNull; -import static org.junit.jupiter.api.Assertions.assertSame; import static org.junit.jupiter.api.Assertions.assertTrue; @DisplayName("WorkerStatus Hysteresis Tests") @@ -31,71 +23,6 @@ void setUp() { workerStatus.getResourceAvailable().set(true); } - @Test - @DisplayName("Duplicate putLocalTask should not double count local resources") - void duplicatePutLocalTask_shouldNotDoubleCountLocalResources() { - workerStatus.setRole(RoleType.PREFILL.getCode()); - workerStatus.getAvailableKvCacheTokens().set(1000L); - TaskInfo firstTask = createTaskInfo(1000L, 100L, 20L); - TaskInfo duplicateTask = createTaskInfo(1000L, 100L, 20L); - - workerStatus.putLocalTask(1000L, firstTask); - workerStatus.putLocalTask(1000L, duplicateTask); - - long expectedMissTokens = duplicateTask.getInputLength() - duplicateTask.getPrefixLength(); - assertEquals(1, workerStatus.getLocalTaskMap().size()); - assertSame(duplicateTask, workerStatus.getLocalTaskMap().get(1000L)); - assertEquals(duplicateTask.estimatePrefillTime(), workerStatus.getRunningQueueTime().get()); - assertEquals(1000L - expectedMissTokens, workerStatus.getAvailableKvCacheTokens().get()); - assertEquals(expectedMissTokens, workerStatus.getUsedKvCacheTokens().get()); - } - - @Test - @DisplayName("Replacing a local task should adjust resources to the new task") - void replacingLocalTask_shouldAdjustResourcesToNewTask() { - workerStatus.setRole(RoleType.PREFILL.getCode()); - workerStatus.getAvailableKvCacheTokens().set(1000L); - TaskInfo firstTask = createTaskInfo(1000L, 100L, 20L); - TaskInfo replacementTask = createTaskInfo(1000L, 200L, 50L); - - workerStatus.putLocalTask(1000L, firstTask); - workerStatus.putLocalTask(1000L, replacementTask); - - long replacementMissTokens = replacementTask.getInputLength() - replacementTask.getPrefixLength(); - assertEquals(1, workerStatus.getLocalTaskMap().size()); - assertSame(replacementTask, workerStatus.getLocalTaskMap().get(1000L)); - assertEquals(replacementTask.estimatePrefillTime(), workerStatus.getRunningQueueTime().get()); - assertEquals(1000L - replacementMissTokens, workerStatus.getAvailableKvCacheTokens().get()); - assertEquals(replacementMissTokens, workerStatus.getUsedKvCacheTokens().get()); - - workerStatus.removeLocalTask(1000L); - - assertEquals(0, workerStatus.getLocalTaskMap().size()); - assertEquals(0L, workerStatus.getRunningQueueTime().get()); - assertEquals(1000L, workerStatus.getAvailableKvCacheTokens().get()); - assertEquals(0L, workerStatus.getUsedKvCacheTokens().get()); - } - - @Test - @DisplayName("Local pending task count should include in-transit and confirmed tasks only") - void localPendingTaskCount_shouldIncludeOnlyPendingStates() { - TaskInfo inTransitTask = createTaskInfo(1000L, 100L, 0L); - inTransitTask.updateTaskState(TaskStateEnum.IN_TRANSIT); - TaskInfo confirmedTask = createTaskInfo(1001L, 100L, 0L); - confirmedTask.updateTaskState(TaskStateEnum.CONFIRMED); - TaskInfo runningTask = createTaskInfo(1002L, 100L, 0L); - runningTask.updateTaskState(TaskStateEnum.RUNNING); - TaskInfo lostTask = createTaskInfo(1003L, 100L, 0L); - lostTask.updateTaskState(TaskStateEnum.LOST); - - workerStatus.getLocalTaskMap().put(1000L, inTransitTask); - workerStatus.getLocalTaskMap().put(1001L, confirmedTask); - workerStatus.getLocalTaskMap().put(1002L, runningTask); - workerStatus.getLocalTaskMap().put(1003L, lostTask); - - assertEquals(2L, workerStatus.getLocalPendingTaskCount()); - } - @ParameterizedTest @CsvSource({ // currentState, currentMetric, upperThreshold, hysteresisBias, expectedResult @@ -407,177 +334,4 @@ void zeroCurrentMetric() { } } - @Nested - @DisplayName("updateTaskStates - waiting task handling") - class UpdateTaskStatesTests { - - private static final Long REQUEST_ID = 1000L; - - @BeforeEach - void setUpWorkerStatus() { - workerStatus.setRole(RoleType.PREFILL.getCode()); - } - - @Test - @DisplayName("Task in waiting list only: IN_TRANSIT becomes CONFIRMED and fields updated from waiting task") - void taskInWaitingOnly_shouldBecomeConfirmedAndSyncFields() { - TaskInfo localTask = new TaskInfo(); - localTask.setRequestId(REQUEST_ID); - localTask.setInputLength(200); - localTask.setPrefixLength(0); - workerStatus.putLocalTask(REQUEST_ID, localTask); - - TaskInfo waitingTask = new TaskInfo(); - waitingTask.setRequestId(REQUEST_ID); - waitingTask.setPrefixLength(50); - waitingTask.setInputLength(200); - waitingTask.setWaitingTime(100); - waitingTask.setDpRank(1); - Map waitingTaskInfo = new HashMap<>(); - waitingTaskInfo.put(String.valueOf(REQUEST_ID), waitingTask); - - workerStatus.updateTaskStates(waitingTaskInfo, new HashMap<>(), new HashMap<>()); - - TaskInfo updated = workerStatus.getLocalTaskMap().get(REQUEST_ID); - assertNotNull(updated, "Task should remain in local map"); - assertEquals(TaskStateEnum.CONFIRMED, updated.getTaskState()); - assertEquals(50, updated.getPrefixLength()); - assertEquals(200, updated.getInputLength()); - assertEquals(100, updated.getWaitingTime()); - assertEquals(1, updated.getDpRank()); - } - - @Test - @DisplayName("Task in waiting list with null running and finished maps should not NPE") - void taskInWaitingWithNullMaps_shouldNotThrow() { - TaskInfo localTask = new TaskInfo(); - localTask.setRequestId(REQUEST_ID); - workerStatus.putLocalTask(REQUEST_ID, localTask); - - Map waitingTaskInfo = new HashMap<>(); - waitingTaskInfo.put(String.valueOf(REQUEST_ID), new TaskInfo()); - - workerStatus.updateTaskStates(waitingTaskInfo, null, null); - - TaskInfo updated = workerStatus.getLocalTaskMap().get(REQUEST_ID); - assertNotNull(updated); - assertEquals(TaskStateEnum.CONFIRMED, updated.getTaskState()); - } - - @Test - @DisplayName("Task CONFIRMED but not in waiting/running/finished should be marked LOST") - void taskConfirmedButNotInAnyList_shouldBeMarkedLost() { - TaskInfo localTask = new TaskInfo(); - localTask.setRequestId(REQUEST_ID); - localTask.updateTaskState(TaskStateEnum.CONFIRMED); - workerStatus.getLocalTaskMap().put(REQUEST_ID, localTask); - - workerStatus.updateTaskStates(new HashMap<>(), new HashMap<>(), new HashMap<>()); - - TaskInfo updated = workerStatus.getLocalTaskMap().get(REQUEST_ID); - assertNotNull(updated); - assertTrue(updated.isLost()); - } - - @Test - @DisplayName("Task in finished list should be removed from local map") - void taskInFinishedList_shouldBeRemoved() { - TaskInfo localTask = new TaskInfo(); - localTask.setRequestId(REQUEST_ID); - localTask.setInputLength(100); - localTask.setPrefixLength(0); - workerStatus.putLocalTask(REQUEST_ID, localTask); - - TaskInfo finishedTask = new TaskInfo(); - finishedTask.setRequestId(REQUEST_ID); - finishedTask.setEndTimeMs(System.currentTimeMillis()); - Map finishedTaskInfo = new HashMap<>(); - finishedTaskInfo.put(String.valueOf(REQUEST_ID), finishedTask); - - workerStatus.updateTaskStates(new HashMap<>(), new HashMap<>(), finishedTaskInfo); - - assertNull(workerStatus.getLocalTaskMap().get(REQUEST_ID)); - } - - @Test - @DisplayName("Task in running list should become RUNNING and sync fields") - void taskInRunningList_shouldBecomeRunningAndSyncFields() { - TaskInfo localTask = new TaskInfo(); - localTask.setRequestId(REQUEST_ID); - workerStatus.putLocalTask(REQUEST_ID, localTask); - - TaskInfo runningTask = new TaskInfo(); - runningTask.setRequestId(REQUEST_ID); - runningTask.setPrefixLength(100); - runningTask.setInputLength(200); - runningTask.setPrefillTime(50); - runningTask.setIterateCount(2); - runningTask.setEndTimeMs(12345L); - runningTask.setDpRank(0); - Map runningTaskInfo = new HashMap<>(); - runningTaskInfo.put(String.valueOf(REQUEST_ID), runningTask); - - workerStatus.updateTaskStates(new HashMap<>(), runningTaskInfo, new HashMap<>()); - - TaskInfo updated = workerStatus.getLocalTaskMap().get(REQUEST_ID); - assertNotNull(updated); - assertEquals(TaskStateEnum.RUNNING, updated.getTaskState()); - assertEquals(100, updated.getPrefixLength()); - assertEquals(200, updated.getInputLength()); - assertEquals(50, updated.getPrefillTime()); - assertEquals(2, updated.getIterateCount()); - assertEquals(12345L, updated.getEndTimeMs()); - } - - @Test - @DisplayName("Task in waiting then in running on next call should be RUNNING") - void taskInWaitingThenInRunning_shouldBeRunning() { - TaskInfo localTask = new TaskInfo(); - localTask.setRequestId(REQUEST_ID); - workerStatus.putLocalTask(REQUEST_ID, localTask); - - Map waitingTaskInfo = new HashMap<>(); - waitingTaskInfo.put(String.valueOf(REQUEST_ID), new TaskInfo()); - workerStatus.updateTaskStates(waitingTaskInfo, new HashMap<>(), new HashMap<>()); - assertEquals(TaskStateEnum.CONFIRMED, workerStatus.getLocalTaskMap().get(REQUEST_ID).getTaskState()); - - Map runningTaskInfo = new HashMap<>(); - TaskInfo runningTask = new TaskInfo(); - runningTask.setRequestId(REQUEST_ID); - runningTaskInfo.put(String.valueOf(REQUEST_ID), runningTask); - workerStatus.updateTaskStates(new HashMap<>(), runningTaskInfo, new HashMap<>()); - - assertEquals(TaskStateEnum.RUNNING, workerStatus.getLocalTaskMap().get(REQUEST_ID).getTaskState()); - } - - @Test - @DisplayName("Finished takes precedence over waiting when task in both") - void taskInFinishedAndWaiting_shouldBeRemovedAsFinished() { - TaskInfo localTask = new TaskInfo(); - localTask.setRequestId(REQUEST_ID); - workerStatus.putLocalTask(REQUEST_ID, localTask); - - TaskInfo finishedTask = new TaskInfo(); - finishedTask.setRequestId(REQUEST_ID); - finishedTask.setEndTimeMs(1); - TaskInfo waitingTask = new TaskInfo(); - waitingTask.setRequestId(REQUEST_ID); - Map finishedTaskInfo = new HashMap<>(); - finishedTaskInfo.put(String.valueOf(REQUEST_ID), finishedTask); - Map waitingTaskInfo = new HashMap<>(); - waitingTaskInfo.put(String.valueOf(REQUEST_ID), waitingTask); - - workerStatus.updateTaskStates(waitingTaskInfo, new HashMap<>(), finishedTaskInfo); - - assertNull(workerStatus.getLocalTaskMap().get(REQUEST_ID)); - } - } - - private TaskInfo createTaskInfo(long requestId, long inputLength, long prefixLength) { - TaskInfo taskInfo = new TaskInfo(); - taskInfo.setRequestId(requestId); - taskInfo.setInputLength(inputLength); - taskInfo.setPrefixLength(prefixLength); - return taskInfo; - } } diff --git a/rtp_llm/flexlb/flexlb-common/src/test/java/org/flexlb/util/LoggerTest.java b/rtp_llm/flexlb/flexlb-common/src/test/java/org/flexlb/util/LoggerTest.java index 34cb38b463..25fe9ec70d 100644 --- a/rtp_llm/flexlb/flexlb-common/src/test/java/org/flexlb/util/LoggerTest.java +++ b/rtp_llm/flexlb/flexlb-common/src/test/java/org/flexlb/util/LoggerTest.java @@ -1,15 +1,15 @@ package org.flexlb.util; +import ch.qos.logback.classic.Level; import org.flexlb.enums.LogLevel; -import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; +import org.slf4j.LoggerFactory; -import java.lang.reflect.Method; import java.util.stream.Stream; import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; @@ -23,197 +23,75 @@ class LoggerTest { @BeforeEach void setUp() { - // Reset globalLogLevel before each test - Logger.setGlobalLogLevel(null); + Logger.setLevel(null); } - @Test - @DisplayName("shouldLog - when checkGlobalLevel=true and globalLogLevel is null") - void shouldLog_checkGlobalLevel_true_globalLogLevel_null() throws Exception { - // Arrange - Logger.setGlobalLogLevel(null); - Method shouldLogMethod = Logger.class.getDeclaredMethod("shouldLog", LogLevel.class, boolean.class); - shouldLogMethod.setAccessible(true); - - // Act & Assert - // When globalLogLevel is null and checkGlobalLevel is true, should return false for any level - for (LogLevel level : LogLevel.values()) { - boolean result = (boolean) shouldLogMethod.invoke(null, level, true); - Assertions.assertFalse(result, "Should return false for " + level + " when globalLogLevel is null and checkGlobalLevel=true"); - } - } + // ---- setLevel / getLevel sync tests ---- @Test - @DisplayName("shouldLog - when checkGlobalLevel=false and globalLogLevel is null") - void shouldLog_checkGlobalLevel_false_globalLogLevel_null() throws Exception { - // Arrange - Logger.setGlobalLogLevel(null); - Method shouldLogMethod = Logger.class.getDeclaredMethod("shouldLog", LogLevel.class, boolean.class); - shouldLogMethod.setAccessible(true); - - // Act & Assert - // When globalLogLevel is null and checkGlobalLevel is false, should return true for any level (warn and error enabled by default) - for (LogLevel level : LogLevel.values()) { - boolean result = (boolean) shouldLogMethod.invoke(null, level, false); - assertTrue(result, "Should return true for " + level + " when globalLogLevel is null and checkGlobalLevel=false"); - } - } - - @ParameterizedTest - @MethodSource("provideLogLevelCombinationsForCheckGlobalLevelTrue") - @DisplayName("shouldLog - checkGlobalLevel=true with different log level combinations") - void shouldLog_checkGlobalLevel_true_with_levels(LogLevel globalLevel, LogLevel targetLevel, boolean expected) throws Exception { - // Arrange - Logger.setGlobalLogLevel(globalLevel); - Method shouldLogMethod = Logger.class.getDeclaredMethod("shouldLog", LogLevel.class, boolean.class); - shouldLogMethod.setAccessible(true); + @DisplayName("setLevel(null) resets logback level to INFO") + void setLevel_null_resetsLogbackToInfo() { + Logger.setLevel(LogLevel.DEBUG); + assertEquals(Level.DEBUG, logbackLevel()); - // Act - boolean result = (boolean) shouldLogMethod.invoke(null, targetLevel, true); + Logger.setLevel(null); - // Assert - assertEquals(expected, result, - String.format("For globalLevel=%s, targetLevel=%s, expected=%s", - globalLevel, targetLevel, expected)); + assertEquals(LogLevel.INFO, Logger.getLevel()); + assertEquals(Level.INFO, logbackLevel(), + "logback level should reset to INFO when passed null"); } @ParameterizedTest - @MethodSource("provideLogLevelCombinationsForCheckGlobalLevelFalse") - @DisplayName("shouldLog - checkGlobalLevel=false with different log level combinations") - void shouldLog_checkGlobalLevel_false_with_levels(LogLevel globalLevel, LogLevel targetLevel, boolean expected) throws Exception { - // Arrange - Logger.setGlobalLogLevel(globalLevel); - Method shouldLogMethod = Logger.class.getDeclaredMethod("shouldLog", LogLevel.class, boolean.class); - shouldLogMethod.setAccessible(true); - - // Act - boolean result = (boolean) shouldLogMethod.invoke(null, targetLevel, false); - - // Assert - assertEquals(expected, result, - String.format("For globalLevel=%s, targetLevel=%s, expected=%s", - globalLevel, targetLevel, expected)); - } - - static Stream provideLogLevelCombinationsForCheckGlobalLevelTrue() { - return Stream.of( - // globalLogLevel=TRACE should allow all levels - Arguments.of(LogLevel.TRACE, LogLevel.TRACE, true), - Arguments.of(LogLevel.TRACE, LogLevel.DEBUG, true), - Arguments.of(LogLevel.TRACE, LogLevel.INFO, true), - Arguments.of(LogLevel.TRACE, LogLevel.WARN, true), - Arguments.of(LogLevel.TRACE, LogLevel.ERROR, true), - - // globalLogLevel=DEBUG should allow DEBUG and above - Arguments.of(LogLevel.DEBUG, LogLevel.TRACE, false), - Arguments.of(LogLevel.DEBUG, LogLevel.DEBUG, true), - Arguments.of(LogLevel.DEBUG, LogLevel.INFO, true), - Arguments.of(LogLevel.DEBUG, LogLevel.WARN, true), - Arguments.of(LogLevel.DEBUG, LogLevel.ERROR, true), - - // globalLogLevel=INFO should allow INFO and above - Arguments.of(LogLevel.INFO, LogLevel.TRACE, false), - Arguments.of(LogLevel.INFO, LogLevel.DEBUG, false), - Arguments.of(LogLevel.INFO, LogLevel.INFO, true), - Arguments.of(LogLevel.INFO, LogLevel.WARN, true), - Arguments.of(LogLevel.INFO, LogLevel.ERROR, true), - - // globalLogLevel=WARN should allow TO WARN and above - Arguments.of(LogLevel.WARN, LogLevel.TRACE, false), - Arguments.of(LogLevel.WARN, LogLevel.DEBUG, false), - Arguments.of(LogLevel.WARN, LogLevel.INFO, false), - Arguments.of(LogLevel.WARN, LogLevel.WARN, true), - Arguments.of(LogLevel.WARN, LogLevel.ERROR, true), - - // globalLogLevel=ERROR should only allow ERROR - Arguments.of(LogLevel.ERROR, LogLevel.TRACE, false), - Arguments.of(LogLevel.ERROR, LogLevel.DEBUG, false), - Arguments.of(LogLevel.ERROR, LogLevel.INFO, false), - Arguments.of(LogLevel.ERROR, LogLevel.WARN, false), - Arguments.of(LogLevel.ERROR, LogLevel.ERROR, true) - ); + @MethodSource("provideLogLevelToLogbackMappings") + @DisplayName("setLevel syncs logback level correctly") + void setLevel_syncsLogbackLevel(LogLevel inputLevel, Level expectedLogbackLevel) { + Logger.setLevel(inputLevel); + + assertEquals(inputLevel, Logger.getLevel()); + assertEquals(expectedLogbackLevel, logbackLevel(), + "logback level should match after setLevel(" + inputLevel + ")"); } - static Stream provideLogLevelCombinationsForCheckGlobalLevelFalse() { + static Stream provideLogLevelToLogbackMappings() { return Stream.of( - // When checkGlobalLevel=false, behavior should be same as checkGlobalLevel=true when globalLogLevel is set - // globalLogLevel=TRACE should allow all levels - Arguments.of(LogLevel.TRACE, LogLevel.TRACE, true), - Arguments.of(LogLevel.TRACE, LogLevel.DEBUG, true), - Arguments.of(LogLevel.TRACE, LogLevel.INFO, true), - Arguments.of(LogLevel.TRACE, LogLevel.WARN, true), - Arguments.of(LogLevel.TRACE, LogLevel.ERROR, true), - - // globalLogLevel=DEBUG should allow DEBUG and above - Arguments.of(LogLevel.DEBUG, LogLevel.TRACE, false), - Arguments.of(LogLevel.DEBUG, LogLevel.DEBUG, true), - Arguments.of(LogLevel.DEBUG, LogLevel.INFO, true), - Arguments.of(LogLevel.DEBUG, LogLevel.WARN, true), - Arguments.of(LogLevel.DEBUG, LogLevel.ERROR, true), - - // globalLogLevel=INFO should allow INFO and above - Arguments.of(LogLevel.INFO, LogLevel.TRACE, false), - Arguments.of(LogLevel.INFO, LogLevel.DEBUG, false), - Arguments.of(LogLevel.INFO, LogLevel.INFO, true), - Arguments.of(LogLevel.INFO, LogLevel.WARN, true), - Arguments.of(LogLevel.INFO, LogLevel.ERROR, true), - - // globalLogLevel=WARN should allow TO WARN and above - Arguments.of(LogLevel.WARN, LogLevel.TRACE, false), - Arguments.of(LogLevel.WARN, LogLevel.DEBUG, false), - Arguments.of(LogLevel.WARN, LogLevel.INFO, false), - Arguments.of(LogLevel.WARN, LogLevel.WARN, true), - Arguments.of(LogLevel.WARN, LogLevel.ERROR, true), - - // globalLogLevel=ERROR should only allow ERROR - Arguments.of(LogLevel.ERROR, LogLevel.TRACE, false), - Arguments.of(LogLevel.ERROR, LogLevel.DEBUG, false), - Arguments.of(LogLevel.ERROR, LogLevel.INFO, false), - Arguments.of(LogLevel.ERROR, LogLevel.WARN, false), - Arguments.of(LogLevel.ERROR, LogLevel.ERROR, true) + Arguments.of(LogLevel.TRACE, Level.TRACE), + Arguments.of(LogLevel.DEBUG, Level.DEBUG), + Arguments.of(LogLevel.INFO, Level.INFO), + Arguments.of(LogLevel.WARN, Level.WARN), + Arguments.of(LogLevel.ERROR, Level.ERROR) ); } @Test - @DisplayName("globalLogLevel getter and setter work correctly") - void globalLogLevel_getterSetter() { - // Test setter and getter - assertNull(Logger.getGlobalLogLevel()); + @DisplayName("getLevel reads back the logback level") + void getLevel_readsLogbackLevel() { + assertEquals(LogLevel.INFO, Logger.getLevel()); - Logger.setGlobalLogLevel(LogLevel.INFO); - assertEquals(LogLevel.INFO, Logger.getGlobalLogLevel()); + Logger.setLevel(LogLevel.INFO); + assertEquals(LogLevel.INFO, Logger.getLevel()); - Logger.setGlobalLogLevel(LogLevel.DEBUG); - assertEquals(LogLevel.DEBUG, Logger.getGlobalLogLevel()); + Logger.setLevel(LogLevel.DEBUG); + assertEquals(LogLevel.DEBUG, Logger.getLevel()); - Logger.setGlobalLogLevel(null); - assertNull(Logger.getGlobalLogLevel()); + Logger.setLevel(null); + assertEquals(LogLevel.INFO, Logger.getLevel()); } @Test @DisplayName("Static block - reads LOG_LEVEL environment variable on class loading") void staticBlock_readsEnvVar() { - // This test verifies that the static block has processed the LOG_LEVEL environment variable - // Since the static block runs when the class is first loaded, we can only verify the current state - String currentLogLevel = System.getenv("LOG_LEVEL"); if (currentLogLevel == null) { - // If no LOG_LEVEL is set, globalLogLevel should be null (default) - // Note: We can't easily test this since the class is already loaded when test runs - // But we can verify the current behavior assertTrue(true, "No LOG_LEVEL environment variable set"); } else { - // If LOG_LEVEL is set, verify it was processed correctly try { LogLevel expectedLevel = LogLevel.valueOf(currentLogLevel.toUpperCase().trim()); - assertEquals(expectedLevel, Logger.getGlobalLogLevel(), - "Static block should have processed LOG_LEVEL environment variable: " + currentLogLevel); + assertEquals(expectedLevel, Logger.getLevel(), + "Static block should have processed LOG_LEVEL: " + currentLogLevel); } catch (IllegalArgumentException e) { - // If current LOG_LEVEL is invalid, globalLogLevel should be null - assertNull( - Logger.getGlobalLogLevel(), - "Invalid LOG_LEVEL should result in null globalLogLevel: " + currentLogLevel); + assertNull(Logger.getLevel(), + "Invalid LOG_LEVEL should result in null: " + currentLogLevel); } } } @@ -222,9 +100,6 @@ void staticBlock_readsEnvVar() { @MethodSource("provideCaseInsensitiveLogLevels") @DisplayName("LogLevel.valueOf with case-insensitive processing") void logLevelValueOf_caseInsensitive(String input, LogLevel expected) { - // Test the logic that the constructor uses internally - // This verifies that toUpperCase().trim() works correctly for LogLevel.valueOf() - LogLevel result = LogLevel.valueOf(input.toUpperCase().trim()); assertEquals(expected, result); } @@ -232,49 +107,39 @@ void logLevelValueOf_caseInsensitive(String input, LogLevel expected) { @Test @DisplayName("LogLevel.valueOf throws IllegalArgumentException for invalid values") void logLevelValueOf_invalidValue() { - assertThrows(IllegalArgumentException.class, () -> LogLevel.valueOf("INVALID".toUpperCase().trim()), "Should throw IllegalArgumentException for invalid log level"); + assertThrows(IllegalArgumentException.class, + () -> LogLevel.valueOf("INVALID".toUpperCase().trim()), + "Should throw for invalid log level"); - assertThrows(IllegalArgumentException.class, () -> LogLevel.valueOf("".toUpperCase().trim()), "Should throw IllegalArgumentException for empty string"); + assertThrows(IllegalArgumentException.class, + () -> LogLevel.valueOf("".toUpperCase().trim()), + "Should throw for empty string"); } @Test @DisplayName("Static method calls work without creating instances") void staticMethods_workWithoutInstances() { - // Verify that static methods can be called without creating instances - // This demonstrates that the static block initialization works correctly - - // These calls should work fine and use the globalLogLevel set by static block assertDoesNotThrow(() -> { Logger.info("Test info message"); Logger.debug("Test debug message"); Logger.warn("Test warn message"); Logger.error("Test error message"); Logger.trace("Test trace message"); - }, "Static logging methods should work without creating Logger instances"); + }, "Static logging methods should work"); - // Verify that globalLogLevel is accessible assertNotNull(Logger.class, "Logger class should be loaded"); - // The getter should work (globalLogLevel might be null, which is fine) - assertDoesNotThrow( - Logger::getGlobalLogLevel, - "getGlobalLogLevel should work without creating instances"); + assertDoesNotThrow(Logger::getLevel, "getLevel should work"); } @Test @DisplayName("Static block logic handles case-insensitive and whitespace correctly") void staticBlock_logicVerification() { - // Test the internal logic that static block uses - // We'll test the case-insensitive and trim logic directly - - // Test valid values with different cases and whitespace String[] testInputs = {"debug", "DEBUG", "Debug", " INFO ", "warn", "ERROR"}; LogLevel[] expectedOutputs = {LogLevel.DEBUG, LogLevel.DEBUG, LogLevel.DEBUG, LogLevel.INFO, LogLevel.WARN, LogLevel.ERROR}; for (int i = 0; i < testInputs.length; i++) { String input = testInputs[i]; LogLevel expected = expectedOutputs[i]; - - // This tests the exact logic used in the static block LogLevel result = LogLevel.valueOf(input.toUpperCase().trim()); assertEquals(expected, result, "Failed for input: '" + input + "'"); } @@ -282,36 +147,41 @@ void staticBlock_logicVerification() { static Stream provideCaseInsensitiveLogLevels() { return Stream.of( - // Test different cases for each log level - Arguments.of("trace", LogLevel.TRACE), - Arguments.of("TRACE", LogLevel.TRACE), - Arguments.of("Trace", LogLevel.TRACE), - Arguments.of("TrAcE", LogLevel.TRACE), - Arguments.of(" trace ", LogLevel.TRACE), - - Arguments.of("debug", LogLevel.DEBUG), - Arguments.of("DEBUG", LogLevel.DEBUG), - Arguments.of("Debug", LogLevel.DEBUG), - Arguments.of("DeBuG", LogLevel.DEBUG), - Arguments.of(" DEBUG ", LogLevel.DEBUG), - - Arguments.of("info", LogLevel.INFO), - Arguments.of("INFO", LogLevel.INFO), - Arguments.of("Info", LogLevel.INFO), - Arguments.of("InFo", LogLevel.INFO), - Arguments.of(" info ", LogLevel.INFO), - - Arguments.of("warn", LogLevel.WARN), - Arguments.of("WARN", LogLevel.WARN), - Arguments.of("Warn", LogLevel.WARN), - Arguments.of("WaRn", LogLevel.WARN), - Arguments.of(" WARN ", LogLevel.WARN), - - Arguments.of("error", LogLevel.ERROR), - Arguments.of("ERROR", LogLevel.ERROR), - Arguments.of("Error", LogLevel.ERROR), - Arguments.of("ErRoR", LogLevel.ERROR), - Arguments.of(" ERROR ", LogLevel.ERROR) + Arguments.of("trace", LogLevel.TRACE), + Arguments.of("TRACE", LogLevel.TRACE), + Arguments.of("Trace", LogLevel.TRACE), + Arguments.of("TrAcE", LogLevel.TRACE), + Arguments.of(" trace ", LogLevel.TRACE), + + Arguments.of("debug", LogLevel.DEBUG), + Arguments.of("DEBUG", LogLevel.DEBUG), + Arguments.of("Debug", LogLevel.DEBUG), + Arguments.of("DeBuG", LogLevel.DEBUG), + Arguments.of(" DEBUG ", LogLevel.DEBUG), + + Arguments.of("info", LogLevel.INFO), + Arguments.of("INFO", LogLevel.INFO), + Arguments.of("Info", LogLevel.INFO), + Arguments.of("InFo", LogLevel.INFO), + Arguments.of(" info ", LogLevel.INFO), + + Arguments.of("warn", LogLevel.WARN), + Arguments.of("WARN", LogLevel.WARN), + Arguments.of("Warn", LogLevel.WARN), + Arguments.of("WaRn", LogLevel.WARN), + Arguments.of(" WARN ", LogLevel.WARN), + + Arguments.of("error", LogLevel.ERROR), + Arguments.of("ERROR", LogLevel.ERROR), + Arguments.of("Error", LogLevel.ERROR), + Arguments.of("ErRoR", LogLevel.ERROR), + Arguments.of(" ERROR ", LogLevel.ERROR) ); } -} \ No newline at end of file + + private static ch.qos.logback.classic.Level logbackLevel() { + ch.qos.logback.classic.Logger lbLogger = + (ch.qos.logback.classic.Logger) LoggerFactory.getLogger("flexlbLogger"); + return lbLogger.getLevel(); + } +} diff --git a/rtp_llm/flexlb/flexlb-grpc/src/main/java/org/flexlb/engine/grpc/AbstractGrpcClient.java b/rtp_llm/flexlb/flexlb-grpc/src/main/java/org/flexlb/engine/grpc/AbstractGrpcClient.java index 972dc822d1..d45a2c0025 100644 --- a/rtp_llm/flexlb/flexlb-grpc/src/main/java/org/flexlb/engine/grpc/AbstractGrpcClient.java +++ b/rtp_llm/flexlb/flexlb-grpc/src/main/java/org/flexlb/engine/grpc/AbstractGrpcClient.java @@ -85,18 +85,12 @@ private void updateGrpcChannelPool(List ipPortList) { int httpPort = Integer.parseInt(parts[1]); int grpcPort = CommonUtils.toGrpcPort(httpPort); - String workerStatusKey = createKey(ip, grpcPort, ServiceType.WORKER_STATUS); - String cacheStatusKey = createKey(ip, grpcPort, ServiceType.CACHE_STATUS); - String multimodalWorkerStatusKey = createKey(ip, grpcPort, ServiceType.MULTIMODAL_WORKER_STATUS); - String multimodalCacheStatusKey = createKey(ip, grpcPort, ServiceType.MULTIMODAL_CACHE_STATUS); - boolean contained = currentKeys.remove(workerStatusKey) && currentKeys.remove(cacheStatusKey) - && currentKeys.remove(multimodalWorkerStatusKey) && currentKeys.remove(multimodalCacheStatusKey); - - if (!contained) { - addedKeys.add(workerStatusKey); - addedKeys.add(cacheStatusKey); - addedKeys.add(multimodalWorkerStatusKey); - addedKeys.add(multimodalCacheStatusKey); + for (ServiceType serviceType : ServiceType.values()) { + String serviceKey = createKey(ip, grpcPort, serviceType); + boolean contained = currentKeys.remove(serviceKey); + if (!contained) { + addedKeys.add(serviceKey); + } } } @@ -327,7 +321,9 @@ public enum ServiceType { WORKER_STATUS("worker", "GetWorkerStatus"), CACHE_STATUS("cache", "GetCacheStatus"), MULTIMODAL_WORKER_STATUS("multimodal_worker", "GetWorkerStatus"), - MULTIMODAL_CACHE_STATUS("multimodal_cache", "GetCacheStatus"); + MULTIMODAL_CACHE_STATUS("multimodal_cache", "GetCacheStatus"), + BATCH_ENQUEUE("batch_enqueue", "EnqueueBatch"), + CANCEL("cancel", "Cancel"); @Getter private final String suffix; diff --git a/rtp_llm/flexlb/flexlb-grpc/src/main/java/org/flexlb/engine/grpc/EngineGrpcClient.java b/rtp_llm/flexlb/flexlb-grpc/src/main/java/org/flexlb/engine/grpc/EngineGrpcClient.java index 2fb3ba5f3e..a25900e482 100644 --- a/rtp_llm/flexlb/flexlb-grpc/src/main/java/org/flexlb/engine/grpc/EngineGrpcClient.java +++ b/rtp_llm/flexlb/flexlb-grpc/src/main/java/org/flexlb/engine/grpc/EngineGrpcClient.java @@ -177,6 +177,26 @@ public EngineRpcService.CacheStatusPB getMultimodalCacheStatus(String ip, int po return executeGrpcCall(ip, port, stub -> stub.getMultimodalRpcServiceStub().getCacheStatus(request), requestTimeoutMs, ServiceType.MULTIMODAL_CACHE_STATUS); } + /** + * Submit a batch of already-routed requests to a Prefill worker. + */ + public EngineRpcService.EnqueueBatchResponsePB batchEnqueue(String ip, + int port, + EngineRpcService.EnqueueBatchRequestPB request, + long requestTimeoutMs) { + return executeGrpcCall(ip, port, stub -> stub.getRpcServiceStub().enqueueBatch(request), requestTimeoutMs, ServiceType.BATCH_ENQUEUE); + } + + /** + * Cancel a request previously submitted through EnqueueBatch. + */ + public EngineRpcService.EmptyPB cancel(String ip, int port, long requestId, long requestTimeoutMs) { + EngineRpcService.CancelRequestPB request = EngineRpcService.CancelRequestPB.newBuilder() + .setRequestId(requestId) + .build(); + return executeGrpcCall(ip, port, stub -> stub.getRpcServiceStub().cancel(request), requestTimeoutMs, ServiceType.CANCEL); + } + @Override protected ManagedChannel createChannel(String channelKey) { String[] parts = parseServiceKey(channelKey); @@ -218,4 +238,4 @@ protected GrpcStubWrapper createStub(ManagedChannel channel) { MultimodalRpcServiceGrpc.newBlockingStub(channel) ); } -} \ No newline at end of file +} diff --git a/rtp_llm/flexlb/flexlb-grpc/src/main/java/org/flexlb/engine/grpc/RoleTypeProtoConverter.java b/rtp_llm/flexlb/flexlb-grpc/src/main/java/org/flexlb/engine/grpc/RoleTypeProtoConverter.java new file mode 100644 index 0000000000..ad41387a80 --- /dev/null +++ b/rtp_llm/flexlb/flexlb-grpc/src/main/java/org/flexlb/engine/grpc/RoleTypeProtoConverter.java @@ -0,0 +1,41 @@ +package org.flexlb.engine.grpc; + +import org.flexlb.dao.route.RoleType; + +/** + * Bidirectional converter between {@link RoleType} (domain enum) and + * {@link EngineRpcService.RoleTypePB} (proto-generated enum). + * + *

Lives in flexlb-grpc (not flexlb-common) to avoid a reverse dependency: + * flexlb-common must not depend on flexlb-grpc.

+ */ +public final class RoleTypeProtoConverter { + + private RoleTypeProtoConverter() { + } + + /** + * Convert proto enum to domain {@link RoleType}. + */ + public static RoleType fromProto(EngineRpcService.RoleTypePB proto) { + return switch (proto) { + case ROLE_TYPE_PDFUSION -> RoleType.PDFUSION; + case ROLE_TYPE_PREFILL -> RoleType.PREFILL; + case ROLE_TYPE_DECODE -> RoleType.DECODE; + case ROLE_TYPE_VIT -> RoleType.VIT; + default -> null; + }; + } + + /** + * Convert domain {@link RoleType} to proto enum. + */ + public static EngineRpcService.RoleTypePB toProto(RoleType role) { + return switch (role) { + case PDFUSION -> EngineRpcService.RoleTypePB.ROLE_TYPE_PDFUSION; + case PREFILL -> EngineRpcService.RoleTypePB.ROLE_TYPE_PREFILL; + case DECODE -> EngineRpcService.RoleTypePB.ROLE_TYPE_DECODE; + case VIT -> EngineRpcService.RoleTypePB.ROLE_TYPE_VIT; + }; + } +} diff --git a/rtp_llm/flexlb/flexlb-grpc/src/test/java/org/flexlb/engine/grpc/AbstractGrpcClientTest.java b/rtp_llm/flexlb/flexlb-grpc/src/test/java/org/flexlb/engine/grpc/AbstractGrpcClientTest.java new file mode 100644 index 0000000000..9b372695d4 --- /dev/null +++ b/rtp_llm/flexlb/flexlb-grpc/src/test/java/org/flexlb/engine/grpc/AbstractGrpcClientTest.java @@ -0,0 +1,57 @@ +package org.flexlb.engine.grpc; + +import io.grpc.ManagedChannel; +import org.flexlb.cache.core.EngineLocalView; +import org.flexlb.cache.core.GlobalCacheIndex; +import org.flexlb.engine.grpc.monitor.GrpcReporter; +import org.junit.jupiter.api.Test; + +import java.util.List; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.Mockito.mock; + +class AbstractGrpcClientTest { + + @Test + void address_update_manages_all_rpc_service_types() { + TestGrpcClient client = new TestGrpcClient(); + + client.onAddressUpdate(List.of("10.0.0.1:8080")); + + for (AbstractGrpcClient.ServiceType serviceType : AbstractGrpcClient.ServiceType.values()) { + assertTrue(client.hasChannel("10.0.0.1", 8081, serviceType)); + } + assertEquals(AbstractGrpcClient.ServiceType.values().length, client.channelCount()); + + client.onAddressUpdate(List.of("10.0.0.1:8080")); + + assertEquals(AbstractGrpcClient.ServiceType.values().length, client.channelCount()); + } + + private static final class TestGrpcClient extends AbstractGrpcClient { + + private TestGrpcClient() { + super(mock(EngineLocalView.class), mock(GlobalCacheIndex.class), mock(GrpcReporter.class)); + } + + @Override + protected ManagedChannel createChannel(String channelKey) { + return mock(ManagedChannel.class); + } + + @Override + protected AbstractGrpcClient.GrpcStubWrapper createStub(ManagedChannel channel) { + return new AbstractGrpcClient.GrpcStubWrapper(null, null); + } + + private boolean hasChannel(String ip, int port, ServiceType serviceType) { + return channelPool.containsKey(createKey(ip, port, serviceType)); + } + + private int channelCount() { + return channelPool.size(); + } + } +} diff --git a/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/balance/endpoint/BatchInflight.java b/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/balance/endpoint/BatchInflight.java new file mode 100644 index 0000000000..45cf1951a0 --- /dev/null +++ b/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/balance/endpoint/BatchInflight.java @@ -0,0 +1,91 @@ +package org.flexlb.balance.endpoint; + +import org.flexlb.balance.scheduler.BatchItem; +import org.flexlb.balance.scheduler.InflightEvictor; + +import java.util.List; +import java.util.concurrent.atomic.AtomicLong; + +public final class BatchInflight implements InflightEvictor.TtlTracked { + + private final long batchId; + private final long predictTimeMs; + private final List requests; + private final long createdAtMs; + private final AtomicLong progressBaseMs; + private volatile boolean running; + private volatile long lastSeenMs; + + public BatchInflight(long batchId, long predictTimeMs, List requests) { + this(batchId, predictTimeMs, requests, System.currentTimeMillis()); + } + + private BatchInflight(long batchId, long predictTimeMs, List requests, long nowMs) { + this(batchId, predictTimeMs, requests, nowMs, nowMs, false, nowMs); + } + + private BatchInflight(long batchId, + long predictTimeMs, + List requests, + long createdAtMs, + long progressBaseMs, + boolean running, + long lastSeenMs) { + this.batchId = batchId; + this.predictTimeMs = predictTimeMs; + this.requests = requests; + this.createdAtMs = createdAtMs; + this.progressBaseMs = new AtomicLong(progressBaseMs); + this.running = running; + this.lastSeenMs = lastSeenMs; + } + + public long batchId() { + return batchId; + } + + public long predictTimeMs() { + return predictTimeMs; + } + + public List requests() { + return requests; + } + + @Override + public long createdAtMs() { + return createdAtMs; + } + + public long progressBaseMs() { + return progressBaseMs.get(); + } + + public boolean running() { + return running; + } + + public long lastSeenMs() { + return lastSeenMs; + } + + public void markQueued(long statusMs) { + if (!running) { + progressBaseMs.updateAndGet(base -> Math.max(base, statusMs)); + } + lastSeenMs = Math.max(lastSeenMs, statusMs); + } + + public void markRunning(long statusMs) { + if (!running) { + progressBaseMs.updateAndGet(base -> Math.max(base, statusMs)); + running = true; + } + lastSeenMs = Math.max(lastSeenMs, statusMs); + } + + public BatchInflight repack(long newPredictTimeMs, List newRequests) { + return new BatchInflight(batchId, newPredictTimeMs, newRequests, + createdAtMs, progressBaseMs(), running, lastSeenMs); + } +} diff --git a/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/balance/endpoint/DecodeEndpoint.java b/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/balance/endpoint/DecodeEndpoint.java new file mode 100644 index 0000000000..5cd7081b84 --- /dev/null +++ b/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/balance/endpoint/DecodeEndpoint.java @@ -0,0 +1,169 @@ +package org.flexlb.balance.endpoint; + +import org.flexlb.balance.scheduler.InflightEvictor; +import org.flexlb.dao.master.TaskInfo; +import org.flexlb.dao.master.WorkerStatus; +import org.flexlb.dao.master.WorkerStatusResponse; +import org.flexlb.enums.TaskPhase; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicLong; + +public class DecodeEndpoint extends WorkerEndpoint { + + private static final Logger logger = LoggerFactory.getLogger("syncLogger"); + + private final ConcurrentHashMap inflightRequests = new ConcurrentHashMap<>(); + private final AtomicLong reportedKvAvailable = new AtomicLong(); + private volatile int confirmedRunningCount; + private final InflightEvictor requestEvictor; + + public DecodeEndpoint(WorkerStatus status) { + super(status); + this.requestEvictor = new InflightEvictor<>(inflightRequests, req -> {}); + } + + public void reserve(long requestId, long kvTokens) { + inflightRequests.put(requestId, new RequestInflight(requestId, kvTokens)); + } + + public void release(long requestId) { + inflightRequests.remove(requestId); + } + + @Override + public void onWorkerStatusUpdate(WorkerStatus ws, WorkerStatusResponse resp) { + super.onWorkerStatusUpdate(ws, resp); + calibrate(resp.getRunningTaskInfo(), resp.getFinishedTaskInfo(), + status.getAvailableKvCacheTokens().get()); + } + + /** + * Full calibration against worker status report. + */ + public void calibrate(Map runningTaskInfo, Map finishedTaskInfo, + long latestAvailableKvCacheTokens) { + this.reportedKvAvailable.set(latestAvailableKvCacheTokens); + + // Phase 1: process running requests — KV_ALLOCATED or RUNNING means the engine + // has taken ownership, so we can release our inflight reservation. + int kvAllocatedRequests = 0; + if (runningTaskInfo != null) { + for (TaskInfo task : runningTaskInfo.values()) { + TaskPhase phase = task.getPhase(); + if (phase == TaskPhase.KV_ALLOCATED || phase == TaskPhase.RUNNING) { + inflightRequests.remove(task.getRequestId()); + kvAllocatedRequests++; + } + } + } + + // Phase 2: process finished non-success requests + if (finishedTaskInfo != null) { + for (TaskInfo task : finishedTaskInfo.values()) { + if (task.getErrorCode() != 0) { + RequestInflight removed = inflightRequests.remove(task.getRequestId()); + if (removed == null && !isCancelError(task)) { + logger.warn("Decode calibrate: finished failed request reqId={} not in inflight, error={}", + task.getRequestId(), task.getErrorMessage()); + } + } + } + + // Phase 3: process finished success requests + for (TaskInfo task : finishedTaskInfo.values()) { + if (task.getErrorCode() == 0) { + RequestInflight removed = inflightRequests.remove(task.getRequestId()); + if (removed != null) { + logger.debug("Decode calibrate: success request reqId={} still in inflight, " + + "KV_ALLOCATED detection may have been skipped", task.getRequestId()); + } + } + } + } + + this.confirmedRunningCount = kvAllocatedRequests; + } + + // ==================== KV Cache 三视图 ==================== + + /** + * Local inflight KV reservation not yet confirmed by the engine. + * Computed on demand from the inflight map — no separate counter needed. + */ + public long inflightKvReserved() { + long sum = 0; + for (RequestInflight ri : inflightRequests.values()) { + sum += ri.kvTokens(); + } + return sum; + } + + /** + * Real KV used: engine-reported used (total - available) + local inflight reservations. + */ + public long realKvUsed() { + long totalCap = status.getTotalKvCacheTokens().get(); + long avail = status.getAvailableKvCacheTokens().get(); + long reportedUsed = totalCap > 0 ? Math.max(0, totalCap - avail) : 0; + return reportedUsed + inflightKvReserved(); + } + + /** + * Real KV available: engine-reported available - local inflight reservations. + * + *

Approximate: reads {@code reportedKvAvailable} and + * computes {@code inflightKvReserved()} non-atomically — the returned value may reflect a + * slightly inconsistent snapshot. This is acceptable for scheduling decisions. + */ + public long realKvAvailable() { + return Math.max(0, reportedKvAvailable.get() - inflightKvReserved()); + } + + /** + * Real KV total capacity reported by the engine. + */ + public long realKvTotal() { + return status.getTotalKvCacheTokens().get(); + } + + public int getInflightCount() { + return inflightRequests.size(); + } + + /** + * Evict inflight requests older than {@code ttlMs}. + * Called periodically by the scheduler to clean up stale decode entries. + * + * @return number of entries evicted + */ + public int evictExpiredRequests(long ttlMs) { + return requestEvictor.evictExpired(ttlMs); + } + + public int getTotalLoad() { + return confirmedRunningCount + inflightRequests.size(); + } + + @Override + public long getLoadMetric() { + return getTotalLoad(); + } + + @Override + public int getLocalTaskCount() { + return getInflightCount(); + } + + ConcurrentHashMap getInflightRequests() { + return inflightRequests; + } + + private static boolean isCancelError(TaskInfo task) { + return task.getErrorMessage() != null && task.getErrorMessage().toLowerCase().contains("cancel"); + } + +} diff --git a/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/balance/endpoint/EndpointRegistry.java b/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/balance/endpoint/EndpointRegistry.java new file mode 100644 index 0000000000..033fdc6522 --- /dev/null +++ b/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/balance/endpoint/EndpointRegistry.java @@ -0,0 +1,112 @@ +package org.flexlb.balance.endpoint; + +import org.flexlb.balance.scheduler.FlexlbBatchScheduler; +import org.flexlb.config.ConfigService; +import org.flexlb.dao.master.WorkerStatus; +import org.flexlb.service.monitor.BatchSchedulerReporter; +import org.springframework.context.annotation.Lazy; +import org.springframework.scheduling.annotation.Scheduled; +import org.springframework.stereotype.Component; + +import java.util.concurrent.ConcurrentHashMap; + +@Component +public class EndpointRegistry { + + private final ConcurrentHashMap prefillEndpoints = new ConcurrentHashMap<>(); + private final ConcurrentHashMap decodeEndpoints = new ConcurrentHashMap<>(); + private final ConfigService configService; + private final FlexlbBatchScheduler batchScheduler; + private final BatchSchedulerReporter reporter; + + public EndpointRegistry(ConfigService configService, + @Lazy FlexlbBatchScheduler batchScheduler, + BatchSchedulerReporter reporter) { + this.configService = configService; + this.batchScheduler = batchScheduler; + this.reporter = reporter; + } + + public WorkerEndpoint get(String ipPort) { + WorkerEndpoint ep = prefillEndpoints.get(ipPort); + if (ep != null) { + return ep; + } + return decodeEndpoints.get(ipPort); + } + + public PrefillEndpoint getPrefill(String ipPort) { + return prefillEndpoints.get(ipPort); + } + + public DecodeEndpoint getDecode(String ipPort) { + return decodeEndpoints.get(ipPort); + } + + public PrefillEndpoint ensurePrefillEndpoint(String ipPort, WorkerStatus status) { + return prefillEndpoints.computeIfAbsent(ipPort, + k -> new PrefillEndpoint(status, configService.loadBalanceConfig(), batchScheduler, reporter)); + } + + public DecodeEndpoint ensureDecodeEndpoint(String ipPort, WorkerStatus status) { + return decodeEndpoints.computeIfAbsent(ipPort, + k -> new DecodeEndpoint(status)); + } + + /** + * Replace prefill endpoint at given key. Closes old endpoint if present. + * Note: This is primarily used in tests. Production code should use ensurePrefillEndpoint(). + */ + public void putPrefill(String ipPort, PrefillEndpoint endpoint) { + PrefillEndpoint old = prefillEndpoints.put(ipPort, endpoint); + if (old != null && old != endpoint) { + old.close(); + } + } + + /** + * Replace decode endpoint at given key. Closes old endpoint if present. + * Note: This is primarily used in tests. Production code should use ensureDecodeEndpoint(). + */ + public void putDecode(String ipPort, DecodeEndpoint endpoint) { + DecodeEndpoint old = decodeEndpoints.put(ipPort, endpoint); + if (old != null && old != endpoint) { + old.close(); + } + } + + public void close() { + prefillEndpoints.values().forEach(WorkerEndpoint::close); + decodeEndpoints.values().forEach(WorkerEndpoint::close); + } + + /** + * Expose all prefill endpoints for per-worker metrics reporting. + */ + public ConcurrentHashMap getPrefillEndpoints() { + return prefillEndpoints; + } + + /** + * Trigger TTL eviction on all prefill and decode endpoints. + * + * @param ttlMs max age before eviction + */ + public void evictExpiredAll(long ttlMs) { + prefillEndpoints.values().forEach(ep -> ep.evictExpiredBatches(ttlMs)); + decodeEndpoints.values().forEach(ep -> ep.evictExpiredRequests(ttlMs)); + } + + /** + * Periodic TTL eviction for all endpoints. + *

Each endpoint is responsible for its own inflight lifecycle. + * This scheduled method provides a safety-net fallback for entries + * that were not cleaned up by {@code calibrate()} (e.g., engine crash, + * network partition, status report delay). + */ + @Scheduled(fixedRate = 60000L) + public void scheduledEviction() { + long ttlMs = configService.loadBalanceConfig().getFlexlbInflightTtlMs(); + evictExpiredAll(ttlMs); + } +} diff --git a/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/balance/endpoint/PrefillEndpoint.java b/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/balance/endpoint/PrefillEndpoint.java new file mode 100644 index 0000000000..c13e4ba977 --- /dev/null +++ b/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/balance/endpoint/PrefillEndpoint.java @@ -0,0 +1,333 @@ +package org.flexlb.balance.endpoint; + +import org.flexlb.balance.scheduler.BatchDecisionHandler; +import org.flexlb.balance.scheduler.BatchItem; +import org.flexlb.balance.scheduler.InflightEvictor; +import org.flexlb.balance.scheduler.WorkerBatcher; +import org.flexlb.balance.strategy.PrefillTimePredictor; +import org.flexlb.config.FlexlbConfig; +import org.flexlb.dao.master.TaskInfo; +import org.flexlb.dao.master.WorkerStatus; +import org.flexlb.dao.master.WorkerStatusResponse; +import org.flexlb.enums.TaskPhase; +import org.flexlb.service.monitor.BatchSchedulerReporter; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicReference; + +public class PrefillEndpoint extends WorkerEndpoint { + + private static final Logger logger = LoggerFactory.getLogger("syncLogger"); + + private volatile PrefillTimePredictor predictor; + private final ConcurrentHashMap inflightBatches = new ConcurrentHashMap<>(); + private final AtomicReference estimatedWaitingTimeMs = new AtomicReference<>(0L); + private volatile WorkerBatcher batcher; + private final InflightEvictor batchEvictor; + + public PrefillEndpoint(WorkerStatus status, FlexlbConfig config, + BatchDecisionHandler handler, + BatchSchedulerReporter reporter) { + super(status); + this.predictor = createPredictor(config); + this.batcher = createBatcher(config, handler, reporter); + this.batchEvictor = new InflightEvictor<>(inflightBatches, + batch -> refreshEstimatedWaitingTime()); + this.batcher.start(); + } + + private WorkerBatcher createBatcher(FlexlbConfig config, BatchDecisionHandler handler, + BatchSchedulerReporter reporter) { + return new WorkerBatcher(status.getIpPort(), this, config, handler, reporter); + } + + public WorkerBatcher getBatcher() { + return batcher; + } + + @Override + public void close() { + batcher.shutdown(); + } + + public int getBatcherQueueSize() { + return batcher.queueSize(); + } + + public long getBatcherHeadSortKey() { + return batcher.headSortKey(); + } + + public long getBatcherHeadWaitMs() { + return batcher.headWaitMs(); + } + + public long batcherWaitMs() { + return batcher.queueWaitMs(); + } + + private static PrefillTimePredictor createPredictor(FlexlbConfig cfg) { + return new PrefillTimePredictor( + cfg.getCostAlpha0(), cfg.getCostAlpha1(), cfg.getCostAlpha2(), + cfg.getCostAlpha3(), cfg.getCostAlpha4(), cfg.getCostAlpha5()); + } + + public void commitBatch(long batchId, long predictMs, List requests) { + inflightBatches.put(batchId, new BatchInflight(batchId, predictMs, requests)); + refreshEstimatedWaitingTime(); + } + + public void releaseBatch(long batchId) { + BatchInflight removed = inflightBatches.remove(batchId); + if (removed != null) { + refreshEstimatedWaitingTime(); + } + } + + /** + * Handle partial batch failure: remove failed requests from a batch and recompute prediction. + * + * @return the new BatchInflight if survivors remain, null if the entire batch was removed + */ + public BatchInflight repackBatch(long batchId, Set failedRequestIds) { + BatchInflight result = inflightBatches.computeIfPresent(batchId, (id, old) -> { + List survivors = old.requests().stream() + .filter(r -> !failedRequestIds.contains(r.requestId())) + .toList(); + if (survivors.isEmpty()) { + return null; // removes entry from map + } + long newPredMs = predictor != null ? predictor.predictBatchMs(survivors) : 0; + return old.repack(newPredMs, survivors); + }); + refreshEstimatedWaitingTime(); + return result; + } + + @Override + public void onWorkerStatusUpdate(WorkerStatus ws, WorkerStatusResponse resp) { + super.onWorkerStatusUpdate(ws, resp); + calibrate(resp.getFinishedTaskInfo(), resp.getRunningTaskInfo()); + } + + /** + * Full calibration against worker status report. + */ + public void calibrate(Map finishedTaskInfo, Map runningTaskInfo) { + if (predictor == null) { + return; + } + long statusMs = System.currentTimeMillis(); + + int finishedSize = finishedTaskInfo != null ? finishedTaskInfo.size() : 0; + int runningSize = runningTaskInfo != null ? runningTaskInfo.size() : 0; + if (finishedSize > 0 || !inflightBatches.isEmpty()) { + logger.info("Prefill calibrate: finishedTasks={}, runningTasks={}, inflightBatches={}", + finishedSize, runningSize, inflightBatches.size()); + } + + // Phase 1: classify finished requests and clean up non-batch inflight. + // Non-batch requests use requestId as the inflight key (engine reports + // them with batch_id=-1). Remove them immediately to keep + // realWaitTimeMs() accurate; warn if a finished non-batch request was + // not tracked in inflight (indicates a bug or stale engine report). + Set batchesWithSuccess = new HashSet<>(); + Map> failedByBatch = new HashMap<>(); + + if (finishedTaskInfo != null) { + for (TaskInfo task : finishedTaskInfo.values()) { + long batchId = task.getBatchId(); + if (batchId < 0) { + BatchInflight removed = inflightBatches.remove(task.getRequestId()); + if (removed == null) { + logger.warn("Prefill calibrate: finished non-batch request reqId={} not in inflight", task.getRequestId()); + } + continue; + } + if (task.getErrorCode() == 0) { + batchesWithSuccess.add(batchId); + } else { + failedByBatch.computeIfAbsent(batchId, k -> new ArrayList<>()).add(task); + } + } + } + + // Phase 2: any success request → remove entire batch + for (long batchId : batchesWithSuccess) { + inflightBatches.remove(batchId); + } + + // Phase 3: fail-only batches → repack survivors + for (Map.Entry> entry : failedByBatch.entrySet()) { + long batchId = entry.getKey(); + if (batchesWithSuccess.contains(batchId)) { + continue; + } + if (!inflightBatches.containsKey(batchId)) { + continue; + } + + List failedTasks = entry.getValue(); + Set failedIds = new HashSet<>(); + for (TaskInfo t : failedTasks) { + if (!isCancelError(t)) { + logger.warn("Prefill calibrate: batch failure batchId={} reqId={} error={}", + batchId, t.getRequestId(), t.getErrorMessage()); + } + failedIds.add(t.getRequestId()); + } + repackBatch(batchId, failedIds); + } + + // Phase 4: update progress anchors. A queued batch cannot spend + // predicted forward time until the worker reports it as RUNNING. + Map activeBatchRunning = new HashMap<>(); + if (runningTaskInfo != null) { + for (TaskInfo task : runningTaskInfo.values()) { + long batchId = task.getBatchId(); + if (batchId < 0 || !inflightBatches.containsKey(batchId)) { + continue; + } + boolean running = task.getPhase() == TaskPhase.RUNNING; + activeBatchRunning.merge(batchId, running, Boolean::logicalOr); + } + } + for (Map.Entry entry : activeBatchRunning.entrySet()) { + BatchInflight batch = inflightBatches.get(entry.getKey()); + if (batch == null) { + continue; + } + if (Boolean.TRUE.equals(entry.getValue())) { + batch.markRunning(statusMs); + } else { + batch.markQueued(statusMs); + } + } + + // Phase 5: check running requests for anomalies + if (runningTaskInfo != null) { + for (TaskInfo task : runningTaskInfo.values()) { + long batchId = task.getBatchId(); + if (batchId < 0) { + continue; + } + if (!inflightBatches.containsKey(batchId)) { + logger.warn("Prefill calibrate: running request reqId={} batchId={} not in inflight", + task.getRequestId(), batchId); + } + } + } + + // Phase 6: refresh waiting time snapshot + refreshEstimatedWaitingTime(); + } + + // ==================== Pending Count ==================== + + /** + * Real pending count: total requests the engine will face. + */ + public long realPendingCount() { + return getInflightRequestCount() + batcher.queueSize(); + } + + // ==================== Wait Time ==================== + + /** + * Real wait time: estimated time to drain current inflight batches. + */ + public long realWaitTimeMs() { + long waitMs = estimateWaitingTimeMs(System.currentTimeMillis()); + estimatedWaitingTimeMs.set(waitMs); + return waitMs; + } + + public int getInflightBatchCount() { + return inflightBatches.size(); + } + + public int getInflightRequestCount() { + int count = 0; + for (BatchInflight batch : inflightBatches.values()) { + count += batch.requests().size(); + } + return count; + } + + /** + * Evict inflight batches older than {@code ttlMs}. + * Called periodically by the scheduler to clean up stale prefill entries. + * + * @return number of batches evicted + */ + public int evictExpiredBatches(long ttlMs) { + int evicted = batchEvictor.evictExpired(ttlMs); + if (evicted > 0) { + refreshEstimatedWaitingTime(); + } + return evicted; + } + + @Override + public long getLoadMetric() { + return realWaitTimeMs(); + } + + @Override + public int getLocalTaskCount() { + return getInflightRequestCount(); + } + + public PrefillTimePredictor getPredictor() { + return predictor; + } + + ConcurrentHashMap getInflightBatches() { + return inflightBatches; + } + + // ==================== Metrics ==================== + + /** + * Report per-worker batch metrics via the given reporter. + * Called periodically by {@link org.flexlb.balance.scheduler.FlexlbBatchScheduler}. + */ + public void reportBatchMetrics(BatchSchedulerReporter reporter) { + reporter.reportBatcherQueueDepth("prefill", getIp(), getBatcherQueueSize()); + reporter.reportPrefillInflightBatchCount("prefill", getIp(), getInflightBatchCount()); + } + + private void refreshEstimatedWaitingTime() { + estimatedWaitingTimeMs.set(estimateWaitingTimeMs(System.currentTimeMillis())); + } + + private long estimateWaitingTimeMs(long nowMs) { + if (inflightBatches.isEmpty()) { + return 0; + } + long totalPredMs = 0; + long earliestProgressBaseMs = Long.MAX_VALUE; + for (BatchInflight batch : inflightBatches.values()) { + totalPredMs += Math.max(0, batch.predictTimeMs()); + earliestProgressBaseMs = Math.min(earliestProgressBaseMs, batch.progressBaseMs()); + } + if (earliestProgressBaseMs == Long.MAX_VALUE) { + return 0; + } + long elapsedMs = Math.max(0, nowMs - earliestProgressBaseMs); + return Math.max(0, totalPredMs - elapsedMs); + } + + private static boolean isCancelError(TaskInfo task) { + return task.getErrorMessage() != null && task.getErrorMessage().toLowerCase().contains("cancel"); + } + +} diff --git a/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/balance/endpoint/RequestInflight.java b/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/balance/endpoint/RequestInflight.java new file mode 100644 index 0000000000..3a437e58a1 --- /dev/null +++ b/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/balance/endpoint/RequestInflight.java @@ -0,0 +1,13 @@ +package org.flexlb.balance.endpoint; + +import org.flexlb.balance.scheduler.InflightEvictor; + +public record RequestInflight( + long requestId, + long kvTokens, + long createdAtMs +) implements InflightEvictor.TtlTracked { + public RequestInflight(long requestId, long kvTokens) { + this(requestId, kvTokens, System.currentTimeMillis()); + } +} diff --git a/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/balance/endpoint/WorkerEndpoint.java b/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/balance/endpoint/WorkerEndpoint.java new file mode 100644 index 0000000000..b58093d57e --- /dev/null +++ b/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/balance/endpoint/WorkerEndpoint.java @@ -0,0 +1,88 @@ +package org.flexlb.balance.endpoint; + +import org.flexlb.dao.master.WorkerStatus; +import org.flexlb.dao.master.WorkerStatusResponse; + +/** + * Primary abstraction for a remote inference worker. + * Holds only a mutable {@link WorkerStatus} reference — all state + * (identity, engine metrics, topology) is carried by the status object. + * + *

Callers read dynamic engine state via {@link #getStatus()} and + * operate on it directly. + */ +public abstract class WorkerEndpoint { + + // ---- the sole mutable holding ---- + protected volatile WorkerStatus status; + + protected WorkerEndpoint(WorkerStatus status) { + this.status = status; + } + + // ==================== identity (delegated to status) ==================== + + public String ipPort() { + return status.getIpPort(); + } + + public String getIp() { + return status.getIp(); + } + + public int getHttpPort() { + return status.getPort(); + } + + public int getGrpcPort() { + return status.getGrpcPort(); + } + + // ==================== status ==================== + + /** + * Returns the underlying {@link WorkerStatus} reference. + * Callers read dynamic engine state from it; sync logic mutates + * it in-place via {@link WorkerStatus#updateFromResponse}. + */ + public WorkerStatus getStatus() { + return status; + } + + // ==================== gRPC sync entry point ==================== + + /** + * Replaces the endpoint's internal {@link #status} reference with + * the one already updated by + * {@link WorkerStatus#updateFromResponse(WorkerStatusResponse)}. + * Triggers role-specific calibration (inflight reconciliation) via + * subclass overrides. + * + *

Topology labels ({@code site}, {@code group}) are already + * part of the incoming status — they belong to + * {@link WorkerStatus}, not to {@link WorkerEndpoint}. + * + * @param ws the updated status (replaces {@link #status}) + * @param resp the raw gRPC response (used by subclasses for task info) + */ + public void onWorkerStatusUpdate(WorkerStatus ws, WorkerStatusResponse resp) { + this.status = ws; + } + + public void close() { + } + + // ==================== monitoring (EP-authoritative) ==================== + + /** + * Role-specific load metric for monitoring. + *

Prefill: estimated queue wait time (ms). + *

Decode: total active task count (confirmed running + inflight). + */ + public abstract long getLoadMetric(); + + /** + * EP-authoritative local task count, replacing raw gRPC fields. + */ + public abstract int getLocalTaskCount(); +} diff --git a/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/balance/resource/DecodeResourceMeasure.java b/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/balance/resource/DecodeResourceMeasure.java index 7484ee2eee..c4bf788b5e 100644 --- a/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/balance/resource/DecodeResourceMeasure.java +++ b/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/balance/resource/DecodeResourceMeasure.java @@ -1,15 +1,16 @@ package org.flexlb.balance.resource; import org.apache.commons.collections4.MapUtils; +import org.flexlb.balance.endpoint.DecodeEndpoint; +import org.flexlb.balance.endpoint.WorkerEndpoint; import org.flexlb.config.ConfigService; import org.flexlb.config.FlexlbConfig; import org.flexlb.dao.master.WorkerStatus; import org.flexlb.enums.ResourceMeasureIndicatorEnum; +import org.flexlb.util.Logger; import org.springframework.stereotype.Component; -import java.util.HashSet; import java.util.Map; -import java.util.Set; /** * Decode role resource measure @@ -36,26 +37,36 @@ public DecodeResourceMeasure(ConfigService configService) { } @Override - public boolean isResourceAvailable(WorkerStatus workerStatus) { - if (workerStatus == null || !workerStatus.isAlive()) { - return false; + public boolean isResourceAvailable(WorkerEndpoint endpoint) { + if (endpoint instanceof DecodeEndpoint) { + return isResourceAvailable((DecodeEndpoint) endpoint); } + return ResourceMeasure.super.isResourceAvailable(endpoint); + } - if (isConcurrencyLimitReached(workerStatus)) { + public boolean isResourceAvailable(DecodeEndpoint endpoint) { + if (endpoint == null || !endpoint.getStatus().isAlive()) { return false; } - - long used = workerStatus.getUsedKvCacheTokens().get(); - long available = workerStatus.getAvailableKvCacheTokens().get(); - long total = used + available; - + long concurrency = calculateDecodeConcurrency(endpoint.getStatus()); + if (concurrencyLimit > 0 && concurrency >= concurrencyLimit) { + Logger.warn("Decode worker {} resource unavailable: concurrency={}, limit={}", + endpoint.getIp(), concurrency, concurrencyLimit); + return false; + } + long used = endpoint.realKvUsed(); + long total = endpoint.realKvTotal(); if (total == 0) { - workerStatus.getResourceAvailable().set(true); + endpoint.getStatus().getResourceAvailable().set(true); return true; } - long usagePercentage = (long) ((used * 100.0) / total); - return workerStatus.updateResourceAvailabilityWithHysteresis(usagePercentage, availableThreshold, hysteresisBiasPercent); + boolean available = endpoint.getStatus().updateResourceAvailabilityWithHysteresis(usagePercentage, availableThreshold, hysteresisBiasPercent); + if (!available) { + Logger.warn("Decode worker {} resource unavailable: kvUsage={}%, threshold={}%, used={}, total={}", + endpoint.getIp(), usagePercentage, availableThreshold, used, total); + } + return available; } @Override @@ -90,9 +101,9 @@ private double calculateWaterLevel(WorkerStatus workerStatus) { } private double calculateKvCacheWaterLevel(WorkerStatus workerStatus) { - long used = workerStatus.getUsedKvCacheTokens().get(); + long total = workerStatus.getTotalKvCacheTokens().get(); long available = workerStatus.getAvailableKvCacheTokens().get(); - long total = used + available; + long used = total - available; if (total == 0) { return 0.0; @@ -122,23 +133,10 @@ private double calculateConcurrencyWaterLevel(WorkerStatus workerStatus) { return Math.min(100.0, currentConcurrency * 100.0 / concurrencyLimit); } - private boolean isConcurrencyLimitReached(WorkerStatus workerStatus) { - return concurrencyLimit > 0 && calculateDecodeConcurrency(workerStatus) >= concurrencyLimit; - } - private long calculateDecodeConcurrency(WorkerStatus workerStatus) { - Set requestIds = new HashSet<>(); - if (MapUtils.isNotEmpty(workerStatus.getWaitingTaskList())) { - requestIds.addAll(workerStatus.getWaitingTaskList().keySet()); - } if (MapUtils.isNotEmpty(workerStatus.getRunningTaskList())) { - requestIds.addAll(workerStatus.getRunningTaskList().keySet()); - } - if (MapUtils.isNotEmpty(workerStatus.getLocalTaskMap())) { - workerStatus.getLocalTaskMap().keySet().stream() - .map(String::valueOf) - .forEach(requestIds::add); + return workerStatus.getRunningTaskList().size(); } - return requestIds.size(); + return 0; } } diff --git a/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/balance/resource/PrefillResourceMeasure.java b/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/balance/resource/PrefillResourceMeasure.java index 577a9e8873..876518af19 100644 --- a/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/balance/resource/PrefillResourceMeasure.java +++ b/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/balance/resource/PrefillResourceMeasure.java @@ -1,13 +1,17 @@ package org.flexlb.balance.resource; import org.apache.commons.collections4.MapUtils; +import org.flexlb.balance.endpoint.PrefillEndpoint; +import org.flexlb.balance.endpoint.WorkerEndpoint; import org.flexlb.config.ConfigService; import org.flexlb.config.FlexlbConfig; import org.flexlb.dao.master.WorkerStatus; import org.flexlb.enums.ResourceMeasureIndicatorEnum; -import org.flexlb.sync.status.EngineWorkerStatus; +import org.flexlb.enums.TaskPhase; import org.springframework.stereotype.Component; +import org.flexlb.util.Logger; + import java.util.Map; /** @@ -31,13 +35,24 @@ public PrefillResourceMeasure(ConfigService configService) { } @Override - public boolean isResourceAvailable(WorkerStatus workerStatus) { - if (workerStatus == null || !workerStatus.isAlive()) { - return false; + public boolean isResourceAvailable(WorkerEndpoint endpoint) { + if (endpoint instanceof PrefillEndpoint) { + return isResourceAvailable((PrefillEndpoint) endpoint); } + return ResourceMeasure.super.isResourceAvailable(endpoint); + } - long queueSize = effectiveQueueSize(workerStatus); - return workerStatus.updateResourceAvailabilityWithHysteresis(queueSize, queueSizeThreshold, hysteresisBiasPercent); + public boolean isResourceAvailable(PrefillEndpoint endpoint) { + if (endpoint == null || !endpoint.getStatus().isAlive()) { + return false; + } + long queueSize = endpoint.realPendingCount(); + boolean available = endpoint.getStatus().updateResourceAvailabilityWithHysteresis(queueSize, queueSizeThreshold, hysteresisBiasPercent); + if (!available) { + Logger.warn("Prefill worker {} resource unavailable: queueSize={}, threshold={}, alive={}", + endpoint.getIp(), queueSize, queueSizeThreshold, endpoint.getStatus().isAlive()); + } + return available; } @Override @@ -68,7 +83,7 @@ private double calculateWaterLevel(WorkerStatus workerStatus) { return 0.0; } - long queueSize = effectiveQueueSize(workerStatus); + long queueSize = countWaitingTasks(workerStatus); if (queueSize <= 0) { return 0.0; @@ -79,8 +94,12 @@ private double calculateWaterLevel(WorkerStatus workerStatus) { } } - private long effectiveQueueSize(WorkerStatus workerStatus) { - long waitingTaskCount = workerStatus.getWaitingTaskList() == null ? 0 : workerStatus.getWaitingTaskList().size(); - return Math.max(waitingTaskCount, workerStatus.getLocalPendingTaskCount()); + private static long countWaitingTasks(WorkerStatus workerStatus) { + if (MapUtils.isEmpty(workerStatus.getRunningTaskList())) { + return 0; + } + return workerStatus.getRunningTaskList().values().stream() + .filter(t -> t.getPhase() != TaskPhase.RUNNING).count(); } -} + +} \ No newline at end of file diff --git a/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/balance/resource/ResourceMeasure.java b/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/balance/resource/ResourceMeasure.java index 5cab1f2bdc..93f3f9a9ed 100644 --- a/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/balance/resource/ResourceMeasure.java +++ b/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/balance/resource/ResourceMeasure.java @@ -1,5 +1,6 @@ package org.flexlb.balance.resource; +import org.flexlb.balance.endpoint.WorkerEndpoint; import org.flexlb.dao.master.WorkerStatus; import org.flexlb.enums.ResourceMeasureIndicatorEnum; @@ -15,12 +16,15 @@ public interface ResourceMeasure { /** - * Check if specified worker has available resources - * - * @param workerStatus Individual worker status - * @return true if worker has available resources, false otherwise + * Check if specified endpoint has available resources. + * Concrete implementations should override with type-specific logic. */ - boolean isResourceAvailable(WorkerStatus workerStatus); + default boolean isResourceAvailable(WorkerEndpoint endpoint) { + if (endpoint == null) { + return false; + } + return endpoint.getStatus().isAlive(); + } /** * Get resource evaluation indicator diff --git a/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/balance/scheduler/BatchDecisionHandler.java b/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/balance/scheduler/BatchDecisionHandler.java new file mode 100644 index 0000000000..8470f37836 --- /dev/null +++ b/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/balance/scheduler/BatchDecisionHandler.java @@ -0,0 +1,41 @@ +package org.flexlb.balance.scheduler; + +import java.util.List; + +/** + * Callback interface for {@link WorkerBatcher} to notify the scheduler of batching decisions. + *

+ * Each method corresponds to a decision the batcher makes during its run loop: + *

    + *
  • {@link #onExpired} — head item's deadline has passed, must be dropped
  • + *
  • {@link #onUrgent} — head item is within the risk margin, dispatch alone immediately
  • + *
  • {@link #onBatchReady} — a batch has been assembled and is ready for gRPC dispatch
  • + *
  • {@link #onOfferFailure} — a new item could not be enqueued (batcher stopped or queue full)
  • + *
+ */ +public interface BatchDecisionHandler { + + /** + * Called when the head item's SLO deadline has expired. + * The scheduler removes it from inflight, rolls back the route, and fails the future. + */ + void onExpired(BatchItem head); + + /** + * Called when the head item is within the risk margin — must be dispatched alone immediately. + */ + void onUrgent(BatchItem head, DispatchMeta meta); + + /** + * Called when the batcher has assembled a batch ready for dispatch. + */ + void onBatchReady(List items, DispatchMeta meta); + + /** + * Called when {@link WorkerBatcher#offer} fails — batcher is stopped or queue is full. + * + * @param item the item that could not be enqueued + * @param error non-null if the batcher is stopped; null if the queue is full + */ + void onOfferFailure(BatchItem item, Throwable error); +} diff --git a/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/balance/scheduler/BatchDispatcher.java b/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/balance/scheduler/BatchDispatcher.java new file mode 100644 index 0000000000..a47115e3e2 --- /dev/null +++ b/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/balance/scheduler/BatchDispatcher.java @@ -0,0 +1,43 @@ +package org.flexlb.balance.scheduler; + +import org.flexlb.balance.endpoint.PrefillEndpoint; + +import java.util.List; + +/** + * Dispatches pre-assembled batches of items to the prefill engine via gRPC. + *

+ * The dispatcher is a pure network layer — it does NOT manage inflight + * state, cancellation, or rollback. Those concerns belong to the scheduler. + * Per-item results are reported through {@link DispatchCallback}. + * + *

Contract

+ *
    + *
  • All items in {@code items} are guaranteed non-cancelled and + * have incomplete futures when {@code dispatch} is called.
  • + *
  • On batch-level failure (build error, network error), the + * dispatcher releases the PrefillEndpoint batch before calling + * {@link DispatchCallback#onFailure} for each item.
  • + *
  • Exactly one callback method is invoked per item.
  • + *
+ */ +public interface BatchDispatcher { + + /** + * Dispatch a batch of active items to the prefill engine. + * + * @param items non-empty list of non-cancelled items + * @param prefillEp target prefill endpoint (already resolved by caller) + * @param batchId pre-allocated unique batch identifier + * @param predMs predicted batch execution time (for logging) + * @param reason dispatch trigger reason (for logging): "fill_rate", "emergency", + * "batch_size_max", etc. + * @param callback receives per-item success or failure + */ + void dispatch(List items, + PrefillEndpoint prefillEp, + long batchId, + long predMs, + String reason, + DispatchCallback callback); +} diff --git a/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/balance/scheduler/BatchItem.java b/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/balance/scheduler/BatchItem.java new file mode 100644 index 0000000000..ccaf41d0da --- /dev/null +++ b/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/balance/scheduler/BatchItem.java @@ -0,0 +1,135 @@ +package org.flexlb.balance.scheduler; + +import org.flexlb.balance.endpoint.DecodeEndpoint; +import org.flexlb.balance.endpoint.PrefillEndpoint; +import org.flexlb.dao.BalanceContext; +import org.flexlb.dao.loadbalance.Response; +import org.flexlb.dao.loadbalance.ServerStatus; + +import java.util.Objects; +import java.util.concurrent.CompletableFuture; + +/** + * A single inference request queued for batch dispatch. + * + *

Extracted from {@link FlexlbBatchScheduler} to reduce coupling + * with {@link WorkerBatcher}. + * + *

Carries direct {@link PrefillEndpoint} / {@link DecodeEndpoint} references + * so downstream operations (commit, rollback, ack, cancel) avoid repeated + * {@code EndpointRegistry} lookups by ip+port. + * + *

{@link #sortKey} is mutable — the {@link BatcherAlgorithm} computes it + * inside {@link WorkerBatcher#offer(BatchItem)} via {@link BatcherAlgorithm#computeSortKey}. + */ +public final class BatchItem { + + private final BalanceContext ctx; + private final CompletableFuture future; + private final Response routeResponse; + private final ServerStatus prefill; + private final ServerStatus decode; + private final PrefillEndpoint prefillEp; + private final DecodeEndpoint decodeEp; + private final long enqueuedAtMs; + + /** Mutable sort key set by the batcher algorithm at offer time. */ + private volatile long sortKey; + + public BatchItem(BalanceContext ctx, + CompletableFuture future, + Response routeResponse, + ServerStatus prefill, + ServerStatus decode, + PrefillEndpoint prefillEp, + DecodeEndpoint decodeEp, + long sortKey, + long enqueuedAtMs) { + this.ctx = ctx; + this.future = future; + this.routeResponse = routeResponse; + this.prefill = prefill; + this.decode = decode; + this.prefillEp = prefillEp; + this.decodeEp = decodeEp; + this.sortKey = sortKey; + this.enqueuedAtMs = enqueuedAtMs; + } + + // -- accessors -- + + public BalanceContext ctx() { return ctx; } + public CompletableFuture future() { return future; } + public Response routeResponse() { return routeResponse; } + public ServerStatus prefill() { return prefill; } + public ServerStatus decode() { return decode; } + public PrefillEndpoint prefillEp() { return prefillEp; } + public DecodeEndpoint decodeEp() { return decodeEp; } + public long enqueuedAtMs() { return enqueuedAtMs; } + + /** Priority queue sort key. */ + public long sortKey() { return sortKey; } + + /** Set by {@link WorkerBatcher#offer} after {@link BatcherAlgorithm#computeSortKey}. */ + public void setSortKey(long sortKey) { this.sortKey = sortKey; } + + /** @deprecated use {@link #sortKey()} instead; kept for SLO-budget references. */ + @Deprecated + public long deadlineMs() { return sortKey; } + + // -- derived accessors -- + + public long requestId() { + return ctx != null && ctx.getRequest() != null + ? ctx.getRequest().getRequestId() : 0; + } + + /** Total sequence length of this request. */ + public long seqLen() { + return ctx != null && ctx.getRequest() != null + ? ctx.getRequest().getSeqLen() : 0; + } + + /** Cache-hit tokens on the assigned prefill endpoint. */ + public long hitCache() { + return hitCacheOf(prefill); + } + + /** Compute tokens = seqLen - hitCache (floor at 0). */ + public long computeTokens() { + return Math.max(0, seqLen() - hitCache()); + } + + /** Extract cache-hit length from a {@link ServerStatus} debug info. */ + public static long hitCacheOf(ServerStatus ss) { + return ss != null && ss.getDebugInfo() != null + ? ss.getDebugInfo().getHitCacheLen() : 0; + } + + // -- Object -- + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (!(o instanceof BatchItem that)) return false; + return sortKey == that.sortKey && enqueuedAtMs == that.enqueuedAtMs + && Objects.equals(ctx, that.ctx) && Objects.equals(future, that.future) + && Objects.equals(routeResponse, that.routeResponse) + && Objects.equals(prefill, that.prefill) + && Objects.equals(decode, that.decode) + && Objects.equals(prefillEp, that.prefillEp) + && Objects.equals(decodeEp, that.decodeEp); + } + + @Override + public int hashCode() { + return Objects.hash(ctx, future, routeResponse, prefill, decode, + prefillEp, decodeEp, sortKey, enqueuedAtMs); + } + + @Override + public String toString() { + return "BatchItem{requestId=" + requestId() + ", seqLen=" + seqLen() + + ", sortKey=" + sortKey + ", enqueuedAtMs=" + enqueuedAtMs + '}'; + } +} diff --git a/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/balance/scheduler/BatcherAlgorithm.java b/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/balance/scheduler/BatcherAlgorithm.java new file mode 100644 index 0000000000..5a4fa193d1 --- /dev/null +++ b/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/balance/scheduler/BatcherAlgorithm.java @@ -0,0 +1,101 @@ +package org.flexlb.balance.scheduler; + +/** + * Batching algorithm contract. One instance per {@link WorkerBatcher}. + * + *

Implementations encapsulate dispatch decision logic — when to + * assemble a batch, how many items to pick, and when to wait. + */ +public interface BatcherAlgorithm { + + /** + * Core decision loop. Called by {@link WorkerBatcher#runLoop()} each + * iteration when the queue is non-empty. + * + *

On each call the implementation should make one of: + *

    + *
  • Dispatch items via {@link BatcherContext#dispatch}
  • + *
  • Drop the head item via {@link BatcherContext#dropHead} + * (only for algorithms that support expiry)
  • + *
  • Park briefly (e.g. {@code TimeUnit.MILLISECONDS.sleep(1)}) + * and return, letting the caller re-invoke
  • + *
+ */ + void processQueue(BatcherContext ctx) throws InterruptedException; + + /** + * Compute the sort key used to order items in the per-worker + * priority queue. Called by {@link WorkerBatcher#offer} before the + * item is enqueued; the result is stored via + * {@link BatchItem#setSortKey(long)}. + * + *

Different algorithms compute this differently: + *

    + *
  • SLO-budget: SLO deadline = now + (sloMs - predMs - workerQueueMs)
  • + *
  • Fixed-window: FIFO (arrival timestamp)
  • + *
+ * + *

The implementation can resolve SLO, predictor coefficients, + * and worker queue depth from {@link BatcherContext#cfg()} and + * {@link BatcherContext#prefillEp()}. + */ + long computeSortKey(BatcherContext ctx, BatchItem item); + + /** + * Hook called by {@link WorkerBatcher#offer} after the sort key is + * computed and set. Gives the algorithm a chance to update arrival + * statistics or perform lightweight bookkeeping. + */ + default void onOffer(BatcherContext ctx, BatchItem item, long nowMs) { + } + + /** + * Estimated remaining wait time of the head request in the batcher + * queue. Used by load-balancing strategies to compare workers without + * leaking sort-key semantics. + * + *

Each algorithm computes this according to its own dispatch model: + *

    + *
  • SLO-budget: remaining SLO slack = {@code sortKey - now}
  • + *
  • Fixed-window: remaining fixed window = {@code fixedWaitMs - elapsedMs}
  • + *
+ * + *

The default implementation treats {@link BatchItem#sortKey()} as + * a future deadline, which is correct for deadline-based algorithms. + */ + default long headWaitMs(BatcherContext ctx) { + BatchItem head = ctx.peek(); + if (head == null) { + return 0; + } + return Math.max(0, head.sortKey() - ctx.now()); + } + + /** + * Estimated time a new request would wait in the batcher queue before + * its batch is dispatched to the engine. Used by load-balancing + * strategies for worker selection scoring. + * + *

Unlike {@link #headWaitMs}, this accounts for the empty-queue + * case: when the queue is empty, a new request starts a fresh batch + * cycle and must wait for the dispatch trigger (e.g. fixed window + * timeout). + * + *

Each algorithm defines this according to its dispatch model: + *

    + *
  • Fixed-window: queue non-empty → remaining window; + * empty → full {@code fixedWaitMs} (new cycle).
  • + *
  • SLO-budget: remaining SLO slack of the head request.
  • + *
+ */ + default long queueWaitMs(BatcherContext ctx) { + return headWaitMs(ctx); + } + + /** + * Hook called by {@link WorkerBatcher#shutdown} before the queue is drained. + * Gives the algorithm a chance to clean up internal state. + */ + default void onShutdown(BatcherContext ctx) { + } +} diff --git a/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/balance/scheduler/BatcherContext.java b/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/balance/scheduler/BatcherContext.java new file mode 100644 index 0000000000..b1afd418d7 --- /dev/null +++ b/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/balance/scheduler/BatcherContext.java @@ -0,0 +1,126 @@ +package org.flexlb.balance.scheduler; + +import org.flexlb.balance.endpoint.PrefillEndpoint; +import org.flexlb.config.FlexlbConfig; +import org.flexlb.service.monitor.BatchSchedulerReporter; + +import java.util.ArrayList; +import java.util.Comparator; +import java.util.List; +import java.util.concurrent.PriorityBlockingQueue; + +/** + * Controlled access to shared {@link WorkerBatcher} infrastructure. + * + *

Passed to {@link BatcherAlgorithm} methods so algorithms can + * inspect and mutate the queue, read config, and invoke callbacks + * without directly depending on WorkerBatcher internals. + */ +public class BatcherContext { + + private final String key; + private final PrefillEndpoint prefillEp; + private final FlexlbConfig cfg; + private final BatchDecisionHandler handler; + private final PriorityBlockingQueue queue; + private final BatchSchedulerReporter reporter; + BatcherContext(String key, PrefillEndpoint prefillEp, FlexlbConfig cfg, + BatchDecisionHandler handler, + PriorityBlockingQueue queue, + BatchSchedulerReporter reporter) { + this.key = key; + this.prefillEp = prefillEp; + this.cfg = cfg; + this.handler = handler; + this.queue = queue; + this.reporter = reporter; + } + + // ---- accessors ---- + + String key() { + return key; + } + + PrefillEndpoint prefillEp() { + return prefillEp; + } + + FlexlbConfig cfg() { + return cfg; + } + + BatchDecisionHandler handler() { + return handler; + } + + BatchSchedulerReporter reporter() { + return reporter; + } + + long now() { + return System.currentTimeMillis(); + } + + // ---- queue inspection ---- + + BatchItem peek() { + return queue.peek(); + } + + boolean isEmpty() { + return queue.isEmpty(); + } + + int size() { + return queue.size(); + } + + // ---- queue mutation ---- + + BatchItem poll() { + return queue.poll(); + } + + boolean remove(BatchItem item) { + return queue.remove(item); + } + + void drainTo(List dst) { + queue.drainTo(dst); + } + + /** + * Items sorted by {@link BatchItem#sortKey()}, suitable for + * greedy-fill iteration in dispatch algorithms. + */ + List sortedItems() { + List candidates = new ArrayList<>(queue); + candidates.sort(Comparator.comparingLong(BatchItem::sortKey)); + return candidates; + } + + // ---- dispatch helpers (shared infrastructure) ---- + + /** + * Remove items from queue and notify handler. + * Caller is responsible for algorithm-specific logging and state cleanup + * (e.g. {@code lastParkByRequest.remove()}) before calling this. + */ + void dispatch(List items, DispatchMeta meta) { + for (BatchItem item : items) { + remove(item); + } + handler.onBatchReady(items, meta); + } + + /** + * Remove head from queue and notify handler of expiry. + * Only called by algorithms that support deadline-based expiry. + * Caller is responsible for algorithm-specific logging and state cleanup. + */ + void dropHead(BatchItem head) { + remove(head); + handler.onExpired(head); + } +} diff --git a/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/balance/scheduler/DefaultBatchDispatcher.java b/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/balance/scheduler/DefaultBatchDispatcher.java new file mode 100644 index 0000000000..00014a1fa2 --- /dev/null +++ b/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/balance/scheduler/DefaultBatchDispatcher.java @@ -0,0 +1,300 @@ +package org.flexlb.balance.scheduler; + +import com.google.protobuf.Int64Value; +import com.google.protobuf.InvalidProtocolBufferException; +import org.flexlb.balance.endpoint.PrefillEndpoint; +import org.flexlb.config.ConfigService; +import org.flexlb.dao.loadbalance.ServerStatus; +import org.flexlb.engine.grpc.EngineGrpcClient; +import org.flexlb.engine.grpc.EngineRpcService; +import org.flexlb.util.Logger; + +import org.springframework.stereotype.Component; + +import javax.annotation.PreDestroy; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.RejectedExecutionException; +import java.util.concurrent.ThreadPoolExecutor; +import java.util.concurrent.TimeUnit; + +/** + * Default implementation of {@link BatchDispatcher}. + *

+ * Owns its own thread pool for asynchronous gRPC dispatch. + * Handles the full pipeline: build request → send → parse response → callback. + * Does NOT manage inflight state — results are reported via {@link DispatchCallback}. + */ +@Component +public class DefaultBatchDispatcher implements BatchDispatcher { + + private final EngineGrpcClient grpcClient; + private final ConfigService configService; + private final ExecutorService dispatchExecutor; + + public DefaultBatchDispatcher(EngineGrpcClient grpcClient, ConfigService configService) { + this.grpcClient = grpcClient; + this.configService = configService; + int poolSize = configService.loadBalanceConfig().getFlexlbBatchDispatchPoolSize(); + int queueSize = configService.loadBalanceConfig().getFlexlbBatchDispatchQueueSize(); + this.dispatchExecutor = new ThreadPoolExecutor( + poolSize, poolSize, + 60L, TimeUnit.SECONDS, + new LinkedBlockingQueue<>(queueSize), + new ThreadPoolExecutor.AbortPolicy()); + } + + @Override + public void dispatch(List items, PrefillEndpoint prefillEp, + long batchId, long predMs, String reason, DispatchCallback callback) { + try { + dispatchExecutor.execute(() -> doDispatch(items, prefillEp, batchId, predMs, reason, callback)); + } catch (RejectedExecutionException e) { + Logger.warn("FlexLB batch dispatch rejected (executor shutdown), failing {} items", items.size()); + prefillEp.releaseBatch(batchId); + for (BatchItem item : items) { + callback.onFailure(item, e); + } + } + } + + @PreDestroy + public void shutdown() { + dispatchExecutor.shutdownNow(); + } + + // ==================== Internal: dispatch pipeline (runs on executor thread) ==================== + + private void doDispatch(List items, PrefillEndpoint prefillEp, + long batchId, long predMs, String reason, DispatchCallback callback) { + try { + doDispatchInternal(items, prefillEp, batchId, predMs, reason, callback); + } catch (Throwable t) { + // Safety net: ensure callbacks are always invoked even for unexpected errors + Logger.error("Unexpected error in doDispatch batchId={}", batchId, t); + for (BatchItem item : items) { + try { + callback.onFailure(item, t); + } catch (Throwable ignored) { + // best-effort + } + } + } + } + + private void doDispatchInternal(List items, PrefillEndpoint prefillEp, + long batchId, long predMs, String reason, DispatchCallback callback) { + // Filter out items that were cancelled before dispatch + List active = new ArrayList<>(); + for (BatchItem item : items) { + if (!item.future().isDone() && !item.ctx().isCancelled()) { + active.add(item); + } else { + Logger.debug("Skipping cancelled item in dispatch: request_id={}, batch_id={}", + item.requestId(), batchId); + } + } + + if (active.isEmpty()) { + Logger.debug("All items cancelled before dispatch, batch_id={}", batchId); + prefillEp.releaseBatch(batchId); + return; + } + + // 1. Build gRPC request + EngineRpcService.EnqueueBatchRequestPB request; + try { + request = buildBatchRequest(batchId, active); + } catch (Exception e) { + Logger.error("Failed to build FlexLB batch request batchId: {}", batchId, e); + failItems(active, prefillEp, batchId, "Batch request build failed: " + e.getMessage(), callback); + return; + } + + // 2. Log dispatch + logDispatch(batchId, active, prefillEp, predMs, reason); + + // 3. Send gRPC + try { + long deadlineMs = configService.loadBalanceConfig().getFlexlbBatchEnqueueDeadlineMs(); + EngineRpcService.EnqueueBatchResponsePB response = + grpcClient.batchEnqueue(prefillEp.getIp(), prefillEp.getGrpcPort(), + request, deadlineMs); + if (response == null) { + failItems(active, prefillEp, batchId, "EnqueueBatch returned null response", callback); + return; + } + handleResponse(batchId, active, response, callback); + } catch (Throwable t) { + Logger.warn("EnqueueBatch failed batchId: {}, entrypoint: {}:{}, err: {}", + batchId, prefillEp.getIp(), prefillEp.getGrpcPort(), t.getMessage()); + failItems(active, prefillEp, batchId, "gRPC dispatch failed: " + t.getMessage(), callback); + } + } + + private void failItems(List items, PrefillEndpoint prefillEp, + long batchId, String message, DispatchCallback callback) { + prefillEp.releaseBatch(batchId); + RuntimeException error = new RuntimeException(message); + for (BatchItem item : items) { + callback.onFailure(item, error); + } + } + + // ==================== Response parsing ==================== + + private void handleResponse(long batchId, List items, + EngineRpcService.EnqueueBatchResponsePB response, + DispatchCallback callback) { + Map errorByRequestId = new HashMap<>(); + for (EngineRpcService.EnqueueBatchErrorPB error : response.getErrorsList()) { + errorByRequestId.put(error.getRequestId(), error); + } + Set successIds = new HashSet<>(); + for (EngineRpcService.EnqueueBatchSuccessPB success : response.getSuccessesList()) { + successIds.add(success.getRequestId()); + } + + for (BatchItem item : items) { + if (successIds.contains(item.requestId())) { + callback.onSuccess(item, batchId); + } else if (errorByRequestId.containsKey(item.requestId())) { + EngineRpcService.EnqueueBatchErrorPB error = errorByRequestId.get(item.requestId()); + String errorMessage = error.hasErrorInfo() + ? error.getErrorInfo().getErrorMessage() + : "missing error_info"; + callback.onFailure(item, new RuntimeException( + "EnqueueBatch rejected request " + item.requestId() + ": " + errorMessage)); + } else { + callback.onFailure(item, new RuntimeException( + "EnqueueBatch missing ack for request " + item.requestId())); + } + } + } + + // ==================== gRPC request building ==================== + + private EngineRpcService.EnqueueBatchRequestPB buildBatchRequest(long batchId, List items) + throws InvalidProtocolBufferException { + EngineRpcService.EnqueueBatchRequestPB.Builder builder = + EngineRpcService.EnqueueBatchRequestPB.newBuilder().setBatchId(batchId); + Map> byDpRank = new HashMap<>(); + for (BatchItem item : items) { + byDpRank.computeIfAbsent(item.prefill().getDpRank(), ignored -> new ArrayList<>()).add(item); + } + try { + byDpRank.entrySet().stream() + .sorted(Map.Entry.comparingByKey()) + .forEach(entry -> { + EngineRpcService.EnqueueBatchDpSlotPB.Builder slot = + EngineRpcService.EnqueueBatchDpSlotPB.newBuilder() + .setDpRank(entry.getKey().intValue()); + int groupSize = entry.getValue().size(); + for (BatchItem item : entry.getValue()) { + try { + slot.addRequests(EngineRpcService.EnqueueBatchExternalInputPB.newBuilder() + .setInput(buildInput(batchId, groupSize, item)) + .build()); + } catch (InvalidProtocolBufferException e) { + throw new BatchRequestBuildException(e); + } + } + builder.addDpSlots(slot.build()); + }); + } catch (BatchRequestBuildException e) { + throw (InvalidProtocolBufferException) e.getCause(); + } + return builder.build(); + } + + private EngineRpcService.GenerateInputPB buildInput(long batchId, int groupSize, BatchItem item) + throws InvalidProtocolBufferException { + byte[] bytes = item.ctx().getGenerateInputPbBytes(); + if (bytes == null || bytes.length == 0) { + throw new IllegalArgumentException("generateInputPbBytes is missing for request " + item.requestId()); + } + EngineRpcService.GenerateInputPB.Builder input = + EngineRpcService.GenerateInputPB.parseFrom(bytes).toBuilder(); + if (input.getRequestId() != item.requestId()) { + throw new IllegalArgumentException("request_id mismatch between schedule request and GenerateInputPB"); + } + input.setGroupId(Int64Value.of(batchId)); + input.setGroupSize(groupSize); + + EngineRpcService.GenerateConfigPB.Builder config = input.getGenerateConfigBuilder(); + config.clearRoleAddrs(); + addRoleAddr(config, item.prefill()); + addRoleAddr(config, item.decode()); + return input.build(); + } + + private void addRoleAddr(EngineRpcService.GenerateConfigPB.Builder config, ServerStatus serverStatus) { + if (serverStatus == null) { + return; + } + EngineRpcService.RoleTypePB role = switch (serverStatus.getRole()) { + case PREFILL -> EngineRpcService.RoleTypePB.ROLE_TYPE_PREFILL; + case DECODE -> EngineRpcService.RoleTypePB.ROLE_TYPE_DECODE; + case PDFUSION -> EngineRpcService.RoleTypePB.ROLE_TYPE_PDFUSION; + case VIT -> EngineRpcService.RoleTypePB.ROLE_TYPE_VIT; + }; + config.addRoleAddrs(EngineRpcService.RoleAddrPB.newBuilder() + .setRole(role) + .setIp(serverStatus.getServerIp()) + .setHttpPort(serverStatus.getHttpPort()) + .setGrpcPort(serverStatus.getGrpcPort()) + .build()); + } + + // ==================== Logging ==================== + + private void logDispatch(long batchId, List items, PrefillEndpoint prefillEp, long predMs, String reason) { + long totalTokens = 0; + long totalHit = 0; + StringBuilder itemDetail = new StringBuilder(); + for (int i = 0; i < items.size(); i++) { + BatchItem item = items.get(i); + long seqLen = item.seqLen(); + long hitCache = item.hitCache(); + totalTokens += seqLen; + totalHit += hitCache; + if (i > 0) { + itemDetail.append(", "); + } + itemDetail.append("{req_id=").append(item.requestId()) + .append(" seq_len=").append(seqLen) + .append(" hit_cache=").append(hitCache).append('}'); + } + + BatchItem head = items.get(0); + long now = System.currentTimeMillis(); + long waitMs = now - head.enqueuedAtMs(); + long budgetMs = head.deadlineMs() - now; + + Logger.info("flexlb_batch_dispatch batch_id={} batch_size={} total_tokens={} total_hit={} " + + "pred_ms={} reason={} wait_ms={} budget_ms={} " + + "prefill={}:{} items=[{}]", + batchId, items.size(), totalTokens, totalHit, predMs, reason, + waitMs, budgetMs, + prefillEp.getIp(), prefillEp.getHttpPort(), + itemDetail); + } + + // ==================== Internal exception wrapper ==================== + + /** + * Wraps checked {@link InvalidProtocolBufferException} to propagate through + * stream lambdas in {@link #buildBatchRequest}. + */ + private static final class BatchRequestBuildException extends RuntimeException { + private BatchRequestBuildException(Throwable cause) { + super(cause); + } + } +} diff --git a/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/balance/scheduler/DefaultRouter.java b/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/balance/scheduler/DefaultRouter.java index 1244470152..3101ee22c2 100644 --- a/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/balance/scheduler/DefaultRouter.java +++ b/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/balance/scheduler/DefaultRouter.java @@ -29,7 +29,7 @@ import static org.flexlb.dao.loadbalance.StrategyErrorType.NO_AVAILABLE_WORKER; @Component -@DependsOn({"randomStrategy", "weightedCacheStrategy", "shortestTTFTStrategy"}) +@DependsOn({"randomStrategy", "costBasedDecodeStrategy", "costBasedPrefillStrategy"}) public class DefaultRouter implements Router { private final Map loadBalancerMap; @@ -69,6 +69,7 @@ public Response route(BalanceContext balanceContext) { ModelWorkerStatus workerStatus = EngineWorkerStatus.MODEL_ROLE_WORKER_STATUS; List roleTypeList = workerStatus.getRoleTypeList(); if (CollectionUtils.isEmpty(roleTypeList)) { + Logger.warn("No worker roles registered yet (total workers: {})", workerStatus.getWorkerTotalCount()); return Response.error(NO_AVAILABLE_WORKER); } diff --git a/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/balance/scheduler/DispatchCallback.java b/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/balance/scheduler/DispatchCallback.java new file mode 100644 index 0000000000..5e3d4767a6 --- /dev/null +++ b/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/balance/scheduler/DispatchCallback.java @@ -0,0 +1,36 @@ +package org.flexlb.balance.scheduler; + +/** + * Receives per-item dispatch results from {@link BatchDispatcher}. + *

+ * Implemented by the scheduler to manage inflight state in response to + * engine acknowledgements. The dispatcher guarantees exactly one call + * per item — either {@link #onSuccess} or {@link #onFailure}. + */ +public interface DispatchCallback { + + /** + * Engine successfully accepted this item. + * Called once per item that appears in the gRPC success list. + * + * @param item the dispatched item + * @param batchId the batch it was dispatched in + */ + void onSuccess(BatchItem item, long batchId); + + /** + * Item failed to be enqueued. Possible causes: + *

    + *
  • gRPC request build failure (protobuf parsing)
  • + *
  • Engine rejected via error list in response
  • + *
  • Item missing from ack response (protocol error)
  • + *
  • Network error on the entire batch call
  • + *
+ * When called due to a batch-level failure, the dispatcher has + * already released the PrefillEndpoint batch before calling this. + * + * @param item the failed item + * @param error the underlying error + */ + void onFailure(BatchItem item, Throwable error); +} diff --git a/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/balance/scheduler/DispatchMeta.java b/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/balance/scheduler/DispatchMeta.java new file mode 100644 index 0000000000..c1e4b1fa3a --- /dev/null +++ b/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/balance/scheduler/DispatchMeta.java @@ -0,0 +1,10 @@ +package org.flexlb.balance.scheduler; + +/** + * Metadata describing why and how a batch was dispatched. + * + *

Extracted from {@link FlexlbBatchScheduler} to reduce coupling + * with {@link WorkerBatcher}. + */ +public record DispatchMeta(String reason, double fillRatio, long batchMaxTokens, int queueDepth) { +} diff --git a/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/balance/scheduler/FixedWindowBatcherAlgorithm.java b/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/balance/scheduler/FixedWindowBatcherAlgorithm.java new file mode 100644 index 0000000000..746b95ff1f --- /dev/null +++ b/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/balance/scheduler/FixedWindowBatcherAlgorithm.java @@ -0,0 +1,153 @@ +package org.flexlb.balance.scheduler; + +import org.flexlb.balance.strategy.PrefillTimePredictor; +import org.flexlb.util.Logger; +import org.springframework.stereotype.Component; + +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.TimeUnit; + +/** + * Fixed-window batching algorithm with optional predictor-based early dispatch. + * + *

Algorithm

+ *
    + *
  1. If the head request has waited {@code flexlbBatchFixedWaitMs} or longer, + * dispatch whatever has accumulated (up to batch size / capacity limits).
  2. + *
  3. Otherwise, if {@code flexlbBatchPredictThresholdMs > 0} and the + * predictor estimates the accumulated batch will take at least that long, + * dispatch immediately ("early dispatch").
  4. + *
  5. Otherwise park briefly and retry.
  6. + *
+ * + *

Key differences from {@link SloBudgetBatcherAlgorithm}

+ *
    + *
  • No SLO deadline tracking — does not read {@code BatchItem.deadlineMs()}.
  • + *
  • No EMA arrival rate estimation.
  • + *
  • No request dropping — oversized requests are skipped, never expired.
  • + *
  • No inflight-batch backpressure check.
  • + *
+ */ +@Component +public class FixedWindowBatcherAlgorithm implements BatcherAlgorithm { + + @Override + public long computeSortKey(BatcherContext ctx, BatchItem item) { + // FIFO: arrival timestamp as sort key; no SLO deadline tracking + return item.enqueuedAtMs(); + } + + @Override + public long headWaitMs(BatcherContext ctx) { + BatchItem head = ctx.peek(); + if (head == null) { + return 0; + } + long elapsedMs = ctx.now() - head.enqueuedAtMs(); + return Math.max(0, ctx.cfg().getFlexlbBatchFixedWaitMs() - elapsedMs); + } + + @Override + public long queueWaitMs(BatcherContext ctx) { + if (!ctx.isEmpty()) { + return headWaitMs(ctx); + } + return ctx.cfg().getFlexlbBatchFixedWaitMs(); + } + + @Override + public void processQueue(BatcherContext ctx) throws InterruptedException { + if (ctx.isEmpty()) { + return; + } + + BatchItem head = ctx.peek(); + if (head == null) { + return; + } + + long elapsedMs = ctx.now() - head.enqueuedAtMs(); + long fixedWaitMs = ctx.cfg().getFlexlbBatchFixedWaitMs(); + int batchMaxCount = Math.max(1, ctx.cfg().getFlexlbBatchSizeMax()); + long batchMaxTokens = ctx.cfg().getFlexlbBatchMaxCapacity(); + long predictThresholdMs = ctx.cfg().getFlexlbBatchPredictThresholdMs(); + + // 0. Engine backpressure: park if the prefill worker already has too + // many batches inflight, to prevent overloading the engine. + // Default 0 disables this gate — the batcher always dispatches. + int maxInflightBatches = ctx.cfg().getFlexlbBatchFixedMaxInflightBatches(); + if (maxInflightBatches > 0 && ctx.prefillEp().getInflightBatchCount() >= maxInflightBatches) { + TimeUnit.MILLISECONDS.sleep(1); + return; + } + + // 1. Fixed window timeout → must dispatch + if (elapsedMs >= fixedWaitMs) { + List picked = pickUpTo(ctx, batchMaxCount, batchMaxTokens); + if (picked.isEmpty() && !ctx.isEmpty()) { + // All items exceed maxTokens — force-dispatch the head to avoid busy-wait + BatchItem forced = ctx.peek(); + if (forced != null) { + picked = List.of(forced); + } + } + if (!picked.isEmpty()) { + dispatch(ctx, picked, "fixed_window_timeout"); + } + return; + } + + // 2. Predictor-based early dispatch + if (predictThresholdMs > 0) { + PrefillTimePredictor predictor = ctx.prefillEp().getPredictor(); + List candidates = pickUpTo(ctx, batchMaxCount, batchMaxTokens); + if (!candidates.isEmpty() && predictor.predictBatchMs(candidates) >= predictThresholdMs) { + dispatch(ctx, candidates, "predict_threshold"); + return; + } + } + + // 3. Park + TimeUnit.MILLISECONDS.sleep(1); + } + + // ==================== Internal helpers ==================== + + /** + * Pick up to {@code maxCount} items from the queue, respecting + * {@code maxTokens} (total token) limit. Items that would exceed + * the capacity are skipped, not dropped. + */ + private static List pickUpTo(BatcherContext ctx, int maxCount, long maxTokens) { + List picked = new ArrayList<>(); + long sumTokens = 0; + + for (BatchItem item : ctx.sortedItems()) { + if (picked.size() >= maxCount) { + break; + } + long nextTokens = sumTokens + item.seqLen(); + if (nextTokens > maxTokens) { + continue; // skip, don't drop + } + picked.add(item); + sumTokens = nextTokens; + } + return picked; + } + + private static void dispatch(BatcherContext ctx, List picked, String reason) { + BatchItem head = picked.get(0); + long waitMs = ctx.now() - head.enqueuedAtMs(); + + ctx.reporter().reportDispatchReason("prefill", ctx.prefillEp().getIp(), reason); + + Logger.info("flexlb_batch_decision reason={} picked_size={} " + + "wait_ms={} queue_before={} worker={} head_req_id={}", + reason, picked.size(), waitMs, ctx.size(), ctx.key(), head.requestId()); + + ctx.dispatch(picked, + new DispatchMeta(reason, 1.0, ctx.cfg().getFlexlbBatchMaxCapacity(), ctx.size() - picked.size())); + } +} diff --git a/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/balance/scheduler/FlexlbBatchScheduler.java b/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/balance/scheduler/FlexlbBatchScheduler.java new file mode 100644 index 0000000000..5ff5846840 --- /dev/null +++ b/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/balance/scheduler/FlexlbBatchScheduler.java @@ -0,0 +1,588 @@ +package org.flexlb.balance.scheduler; + +import org.flexlb.balance.endpoint.DecodeEndpoint; +import org.flexlb.balance.endpoint.EndpointRegistry; +import org.flexlb.balance.endpoint.PrefillEndpoint; +import org.flexlb.balance.strategy.PrefillTimePredictor; +import org.flexlb.config.ConfigService; +import org.flexlb.dao.BalanceContext; +import org.flexlb.dao.loadbalance.DebugInfo; +import org.flexlb.dao.loadbalance.Response; +import org.flexlb.dao.loadbalance.ServerStatus; +import org.flexlb.dao.loadbalance.StrategyErrorType; +import org.flexlb.dao.master.TaskInfo; +import org.flexlb.dao.master.WorkerStatus; +import org.flexlb.dao.route.RoleType; +import org.flexlb.dao.master.WorkerStatusResponse; +import org.flexlb.engine.grpc.EngineGrpcClient; +import org.flexlb.service.monitor.BatchSchedulerReporter; +import org.flexlb.sync.status.EngineWorkerStatus; +import org.flexlb.util.Logger; +import org.springframework.context.annotation.Lazy; +import org.springframework.scheduling.annotation.Scheduled; +import org.springframework.stereotype.Component; + +import javax.annotation.PreDestroy; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicLong; + +/** + * Coordinates batch scheduling for FlexLB disaggregated inference. + * + *

Responsibilities: + *

    + *
  • Request admission and routing (submit, cancel)
  • + *
  • Inflight lifecycle management (inflight map, TTL cleanup)
  • + *
  • Batch assembly coordination — commits to PrefillEndpoint, + * filters cancelled items, delegates gRPC dispatch to {@link BatchDispatcher}
  • + *
  • Resource rollback on failure or completion
  • + *
+ * + *

The actual gRPC dispatch (build protobuf, send, parse response) is + * delegated to {@link BatchDispatcher}. Per-item results come back through + * {@link DispatchCallback} which this class implements. + */ +@Component +public class FlexlbBatchScheduler implements BatchDecisionHandler, DispatchCallback { + + public final ConfigService configService; + private final Router router; + final EngineGrpcClient grpcClient; + final EngineWorkerStatus engineWorkerStatus; + final EndpointRegistry endpointRegistry; + final BatchDispatcher dispatcher; + final BatchSchedulerReporter reporter; + final Map inflight = new ConcurrentHashMap<>(); + final AtomicLong batchIdGenerator = new AtomicLong(0); + private final InflightEvictor inflightEvictor + = new InflightEvictor<>(inflight, entry -> { + synchronized (entry) { + rollbackOnce(entry); + completeCancelled(entry); + } + }); + + public FlexlbBatchScheduler(ConfigService configService, + @Lazy Router router, + EngineGrpcClient grpcClient, + EngineWorkerStatus engineWorkerStatus, + EndpointRegistry endpointRegistry, + BatchDispatcher dispatcher, + BatchSchedulerReporter reporter) { + this.configService = configService; + this.router = router; + this.grpcClient = grpcClient; + this.engineWorkerStatus = engineWorkerStatus; + this.endpointRegistry = endpointRegistry; + this.dispatcher = dispatcher; + this.reporter = reporter; + } + + // ==================== Request submission ==================== + + public CompletableFuture submit(BalanceContext ctx) { + CompletableFuture future = new CompletableFuture<>(); + try { + if (ctx == null || ctx.getRequest() == null) { + future.complete(Response.error(StrategyErrorType.INVALID_REQUEST)); + return future; + } + + int maxInflight = configService.loadBalanceConfig().getFlexlbBatchMaxInflight(); + if (maxInflight > 0 && inflight.size() >= maxInflight) { + Response resp = Response.error(StrategyErrorType.QUEUE_FULL); + future.complete(resp); + return future; + } + + Response routeResponse = router.route(ctx); + if (routeResponse == null || !routeResponse.isSuccess()) { + future.complete(routeResponse != null + ? routeResponse + : Response.error(StrategyErrorType.NO_AVAILABLE_WORKER)); + return future; + } + + ServerStatus prefill = findServer(routeResponse, RoleType.PREFILL); + ServerStatus decode = findServer(routeResponse, RoleType.DECODE); + if (prefill == null) { + rollback(routeResponse); + Response resp = Response.error(StrategyErrorType.NO_PREFILL_WORKER); + future.complete(resp); + return future; + } + + String prefillIpPort = prefill.getServerIp() + ":" + prefill.getHttpPort(); + PrefillEndpoint prefillEp = endpointRegistry.getPrefill(prefillIpPort); + if (prefillEp == null) { + rollback(routeResponse); + Response resp = Response.error(StrategyErrorType.NO_PREFILL_WORKER); + future.complete(resp); + return future; + } + + DecodeEndpoint decodeEp = null; + if (decode != null) { + String decodeIpPort = decode.getServerIp() + ":" + decode.getHttpPort(); + decodeEp = endpointRegistry.getDecode(decodeIpPort); + } + + BatchItem item = new BatchItem(ctx, future, routeResponse, copyOf(prefill), copyOf(decode), + prefillEp, decodeEp, /* sortKey set by batcher */ 0, System.currentTimeMillis()); + inflight.put(ctx.getRequestId(), new InflightEntry(item)); + WorkerBatcher batcher = prefillEp.getBatcher(); + batcher.offer(item); + } catch (Throwable t) { + if (ctx != null) { + inflight.remove(ctx.getRequestId()); + } + Logger.error("FlexlbBatchScheduler submit failed for request id: {}", + ctx == null ? null : ctx.getRequestId(), t); + Response errorResp = new Response(); + errorResp.setSuccess(false); + errorResp.setCode(StrategyErrorType.BATCH_DISPATCH_FAILED.getErrorCode()); + errorResp.setErrorMessage("Submit failed: " + t.getMessage()); + future.complete(errorResp); + } + return future; + } + + // ==================== Cancellation ==================== + + public void cancel(long requestId) { + InflightEntry entry = inflight.remove(requestId); + if (entry == null) { + Logger.debug("flexlb batch cancel ignored; request {} not found in inflight", requestId); + return; + } + + synchronized (entry) { + entry.cancelled.set(true); + if (!entry.ackFinished) { + completeCancelled(entry); + } else if (!entry.item.future().isDone()) { + rollbackOnce(entry); + } + // If ackFinished and future already done (success), skip rollback + } + cancelPrefill(entry); + } + + // ==================== Completion from worker status ==================== + + public void onWorkerStatusUpdate(WorkerStatus ws, WorkerStatusResponse response) { + if (response == null) { + return; + } + Map finishedTaskInfo = response.getFinishedTaskInfo(); + if (finishedTaskInfo == null || finishedTaskInfo.isEmpty()) { + return; + } + + boolean isPrefill = response.getRole() == RoleType.PREFILL; + + for (TaskInfo task : finishedTaskInfo.values()) { + long requestId = task.getRequestId(); + + // Prefill success: decode is still running, keep scheduler inflight entry + if (isPrefill && task.getErrorCode() == 0) { + continue; + } + + // Remove from scheduler inflight (prefill error, or any decode completion) + InflightEntry entry = inflight.remove(requestId); + + // Prefill error: rollback decode KV reservation since decode will never run + if (isPrefill && entry != null) { + rollbackOnce(entry); + } + // Decode completion (success or error): scheduler only cleans its own map. + // DecodeEndpoint.calibrate() independently handles its own inflightRequests cleanup. + } + } + + public void removeInflight(long requestId) { + inflight.remove(requestId); + } + + // ==================== Inflight TTL cleanup ==================== + + @Scheduled(fixedRate = 60000L) + public void cleanupInflight() { + long ttlMs = configService.loadBalanceConfig().getFlexlbInflightTtlMs(); + inflightEvictor.evictExpired(ttlMs); + } + + // ==================== BatchDecisionHandler callbacks (from WorkerBatcher) ==================== + + @Override + public void onExpired(BatchItem head) { + removeInflight(head.requestId()); + rollback(head); + if (!head.future().isDone()) { + Response errorResp = new Response(); + errorResp.setSuccess(false); + errorResp.setCode(StrategyErrorType.BATCH_SLO_EXPIRED.getErrorCode()); + errorResp.setErrorMessage("FlexLB request deadline expired — cannot meet TTFT SLO"); + head.future().complete(errorResp); + } + } + + @Override + public void onUrgent(BatchItem head, DispatchMeta meta) { + flushItems(List.of(head), meta.reason()); + } + + @Override + public void onBatchReady(List items, DispatchMeta meta) { + flushItems(items, meta.reason()); + } + + @Override + public void onOfferFailure(BatchItem item, Throwable error) { + removeInflight(item.requestId()); + rollback(item); + if (!item.future().isDone()) { + Response errorResp = new Response(); + errorResp.setSuccess(false); + errorResp.setCode(StrategyErrorType.BATCH_DISPATCH_FAILED.getErrorCode()); + errorResp.setErrorMessage("Batcher offer failed: " + error.getMessage()); + item.future().complete(errorResp); + } + } + + // ==================== Dispatch pipeline ==================== + + /** + * Commit batch to PrefillEndpoint, filter cancelled items, then delegate + * to {@link BatchDispatcher} for asynchronous gRPC dispatch. + *

+ * Filtering is done synchronously — it only reads inflight (ConcurrentHashMap) + * and performs fast in-memory operations. The heavy gRPC I/O is handled + * asynchronously by the dispatcher's own thread pool. + */ + private void flushItems(List items, String reason) { + PrefillEndpoint prefillEp = items.get(0).prefillEp(); + + // [SYNC] Filter cancelled/done items first — avoid committing them to the endpoint + List active = items.stream() + .filter(item -> !isCancelled(item) && !item.future().isDone()) + .toList(); + + // Complete items that were cancelled before dispatch + for (BatchItem item : items) { + if (!active.contains(item)) { + completeCancelled(item); + } + } + + if (active.isEmpty()) { + return; + } + + // [SYNC] Compute prediction and commit only active items to endpoint + long predMs = 0; + long batchId = batchIdGenerator.incrementAndGet(); + if (prefillEp != null) { + PrefillTimePredictor predictor = prefillEp.getPredictor(); + predMs = predictor.predictBatchMs(active); + prefillEp.commitBatch(batchId, predMs, active); + } + + // [ASYNC] Delegate gRPC dispatch — dispatcher owns its own thread pool + long waitMs = System.currentTimeMillis() - items.get(0).enqueuedAtMs(); + reporter.reportBatchWaitTimeMs("prefill", prefillEp != null ? prefillEp.getIp() : "", waitMs); + dispatcher.dispatch(active, prefillEp, batchId, predMs, reason, this); + } + + // ==================== DispatchCallback implementation ==================== + + @Override + public void onSuccess(BatchItem item, long batchId) { + InflightEntry entry = inflight.get(item.requestId()); + if (entry == null) { + // cancel() already removed entry and handled cleanup + return; + } + + boolean cancelAfterAck = false; + synchronized (entry) { + entry.ackFinished = true; + if (entry.cancelled.get()) { + cancelAfterAck = true; + } else if (!item.future().isDone()) { + Response success = copyResponse(item.routeResponse()); + success.setSuccess(true); + success.setCode(200); + success.setEnqueuedByMaster(true); + success.setQueueLength(inflight.size()); + item.future().complete(success); + Logger.debug("FlexLB batch enqueued request {} in batch {}", item.requestId(), batchId); + } + } + + if (cancelAfterAck) { + inflight.remove(item.requestId()); + cancelPrefill(entry); + completeCancelled(entry); + } + } + + @Override + public void onFailure(BatchItem item, Throwable error) { + failAck(item, error); + } + + // ==================== Internal: inflight state management ==================== + + void failAck(BatchItem item, Throwable error) { + InflightEntry entry = inflight.remove(item.requestId()); + if (entry != null) { + synchronized (entry) { + entry.ackFinished = true; + rollbackOnce(entry); + if (!item.future().isDone()) { + Response errorResp = new Response(); + errorResp.setSuccess(false); + errorResp.setCode(StrategyErrorType.BATCH_DISPATCH_FAILED.getErrorCode()); + errorResp.setErrorMessage(error.getMessage()); + item.future().complete(errorResp); + } + } + return; + } + rollback(item); + if (!item.future().isDone()) { + Response errorResp = new Response(); + errorResp.setSuccess(false); + errorResp.setCode(StrategyErrorType.BATCH_DISPATCH_FAILED.getErrorCode()); + errorResp.setErrorMessage(error.getMessage()); + item.future().complete(errorResp); + } + } + + private void completeCancelled(BatchItem item) { + InflightEntry entry = inflight.remove(item.requestId()); + if (entry != null) { + synchronized (entry) { + completeCancelled(entry); + } + return; + } + item.ctx().cancel(); + rollback(item); + if (!item.future().isDone()) { + Response errorResp = new Response(); + errorResp.setSuccess(false); + errorResp.setCode(StrategyErrorType.REQUEST_CANCELLED.getErrorCode()); + errorResp.setErrorMessage("Request cancelled by client"); + item.future().complete(errorResp); + } + } + + private void completeCancelled(InflightEntry entry) { + entry.item.ctx().cancel(); + rollbackOnce(entry); + if (!entry.item.future().isDone()) { + Response errorResp = new Response(); + errorResp.setSuccess(false); + errorResp.setCode(StrategyErrorType.REQUEST_CANCELLED.getErrorCode()); + errorResp.setErrorMessage("Request cancelled by client"); + entry.item.future().complete(errorResp); + } + } + + // ==================== Internal: resource rollback ==================== + + private void rollbackOnce(InflightEntry entry) { + if (entry.rolledBack.compareAndSet(false, true)) { + rollback(entry.item); + } + } + + /** Rollback using endpoint references already held by the item (no registry lookup). */ + private void rollback(BatchItem item) { + DecodeEndpoint decodeEp = item.decodeEp(); + if (decodeEp != null && item.decode() != null) { + decodeEp.release(item.decode().getRequestId()); + } + } + + /** + * Rollback using route response — used only in submit() early-return paths + * where BatchItem has not been created yet. + */ + private void rollback(Response routeResponse) { + if (routeResponse == null || routeResponse.getServerStatus() == null) { + return; + } + for (ServerStatus serverStatus : routeResponse.getServerStatus()) { + rollback(serverStatus); + } + } + + private void rollback(ServerStatus serverStatus) { + if (serverStatus == null) { + return; + } + if (serverStatus.getRole() == RoleType.DECODE) { + String ipPort = serverStatus.getServerIp() + ":" + serverStatus.getHttpPort(); + DecodeEndpoint ep = endpointRegistry.getDecode(ipPort); + if (ep != null) { + ep.release(serverStatus.getRequestId()); + } + } + } + + // ==================== Internal: inflight queries ==================== + + private boolean isCancelled(BatchItem item) { + InflightEntry entry = inflight.get(item.requestId()); + return item.ctx().isCancelled() || (entry != null && entry.cancelled.get()); + } + + // ==================== Internal: engine cancel ==================== + + /** + * Cancel request on the prefill engine via gRPC. + *

+ * Only prefill needs an explicit cancel — there is no symmetric {@code cancelDecode()}. + * The prefill engine owns the full request lifecycle in PD-separated architecture: + * {@code PrefillRpcServer::Cancel()} cancels the response entry, which cascades + * internally to interrupt the prefill→decode flow. + */ + private void cancelPrefill(InflightEntry entry) { + PrefillEndpoint prefillEp = entry.item.prefillEp(); + if (prefillEp == null) { + return; + } + try { + long deadlineMs = configService.loadBalanceConfig().getFlexlbBatchEnqueueDeadlineMs(); + grpcClient.cancel(prefillEp.getIp(), + prefillEp.getGrpcPort(), + entry.item.requestId(), + deadlineMs); + } catch (RuntimeException e) { + Logger.warn("FlexLB batch cancel failed for request {}", entry.item.requestId(), e); + } + } + + // ==================== Internal: static utilities ==================== + + private static ServerStatus findServer(Response response, RoleType roleType) { + if (response.getServerStatus() == null) { + return null; + } + for (ServerStatus serverStatus : response.getServerStatus()) { + if (serverStatus != null && roleType == serverStatus.getRole()) { + return serverStatus; + } + } + return null; + } + + private static Response copyResponse(Response src) { + Response response = new Response(); + response.setServerStatus(copyServerList(src.getServerStatus())); + response.setSuccess(src.isSuccess()); + response.setCode(src.getCode()); + response.setErrorMessage(src.getErrorMessage()); + response.setRealMasterHost(src.getRealMasterHost()); + response.setQueueLength(src.getQueueLength()); + response.setEnqueuedByMaster(src.isEnqueuedByMaster()); + return response; + } + + private static List copyServerList(List src) { + if (src == null) { + return null; + } + List result = new ArrayList<>(src.size()); + for (ServerStatus serverStatus : src) { + result.add(copyOf(serverStatus)); + } + return result; + } + + private static ServerStatus copyOf(ServerStatus src) { + if (src == null) { + return null; + } + ServerStatus status = new ServerStatus(); + status.setRole(src.getRole()); + status.setServerIp(src.getServerIp()); + status.setHttpPort(src.getHttpPort()); + status.setGrpcPort(src.getGrpcPort()); + status.setDpRank(src.getDpRank()); + status.setPrefillTime(src.getPrefillTime()); + status.setGroup(src.getGroup()); + status.setDebugInfo(copyOf(src.getDebugInfo())); + status.setRequestId(src.getRequestId()); + status.setSuccess(src.isSuccess()); + status.setCode(src.getCode()); + status.setMessage(src.getMessage()); + return status; + } + + private static DebugInfo copyOf(DebugInfo src) { + if (src == null) { + return null; + } + DebugInfo info = new DebugInfo(); + info.setRunningBatchSize(src.getRunningBatchSize()); + info.setQueueSize(src.getQueueSize()); + info.setWaitingTimeMs(src.getWaitingTimeMs()); + info.setAvailableKvCacheLen(src.getAvailableKvCacheLen()); + info.setEstimateTtftMs(src.getEstimateTtftMs()); + info.setEstimateTpotMs(src.getEstimateTpotMs()); + info.setHitCacheLen(src.getHitCacheLen()); + return info; + } + + // ==================== Lifecycle ==================== + + public BatchSchedulerReporter getReporter() { + return reporter; + } + + @Scheduled(fixedRate = 20000L) + public void reportBatchMetrics() { + reporter.reportSchedulerInflightSize(inflight.size()); + + // Per-worker metrics: delegated to each PrefillEndpoint + for (Map.Entry entry : endpointRegistry.getPrefillEndpoints().entrySet()) { + entry.getValue().reportBatchMetrics(reporter); + } + } + + @PreDestroy + public void shutdown() { + endpointRegistry.close(); + } + + // ==================== Inflight entry ==================== + + static final class InflightEntry implements InflightEvictor.TtlTracked { + final BatchItem item; + private final long createdAtMs = System.currentTimeMillis(); + final AtomicBoolean cancelled = new AtomicBoolean(false); + final AtomicBoolean rolledBack = new AtomicBoolean(false); + boolean ackFinished; + + InflightEntry(BatchItem item) { + this.item = Objects.requireNonNull(item); + Objects.requireNonNull(item.prefill(), "BatchItem.prefill must not be null"); + } + + @Override + public long createdAtMs() { + return createdAtMs; + } + } +} diff --git a/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/balance/scheduler/InflightEvictor.java b/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/balance/scheduler/InflightEvictor.java new file mode 100644 index 0000000000..b04705fc51 --- /dev/null +++ b/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/balance/scheduler/InflightEvictor.java @@ -0,0 +1,59 @@ +package org.flexlb.balance.scheduler; + +import java.util.Iterator; +import java.util.Map; +import java.util.function.Consumer; + +/** + * Generic TTL eviction manager for inflight maps across all scheduling layers. + * + *

Does NOT own the map — works on any {@link Map} whose values + * implement {@link TtlTracked}. Callers invoke {@link #evictExpired(long)} from + * their own {@code @Scheduled} cleanup methods. + * + * @param key type + * @param value type, must implement {@link TtlTracked} + */ +public class InflightEvictor { + + /** Interface required for inflight entries to be evictable by age. */ + public interface TtlTracked { + /** @return epoch-millis timestamp when this entry was created */ + long createdAtMs(); + } + + private final Map map; + private final Consumer onEvict; + + /** + * @param map the map to evict from (not owned by this evictor) + * @param onEvict called for each evicted entry (e.g. to adjust counters); + * may be null if no side effects are needed + */ + public InflightEvictor(Map map, Consumer onEvict) { + this.map = map; + this.onEvict = onEvict; + } + + /** + * Remove all entries older than {@code ttlMs} milliseconds. + * + * @param ttlMs max age before eviction + * @return number of entries evicted + */ + public int evictExpired(long ttlMs) { + long now = System.currentTimeMillis(); + int count = 0; + for (Iterator> it = map.entrySet().iterator(); it.hasNext(); ) { + Map.Entry entry = it.next(); + if (now - entry.getValue().createdAtMs() > ttlMs) { + it.remove(); + count++; + if (onEvict != null) { + onEvict.accept(entry.getValue()); + } + } + } + return count; + } +} diff --git a/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/balance/scheduler/SloBudgetBatcherAlgorithm.java b/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/balance/scheduler/SloBudgetBatcherAlgorithm.java new file mode 100644 index 0000000000..ab4c8ea560 --- /dev/null +++ b/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/balance/scheduler/SloBudgetBatcherAlgorithm.java @@ -0,0 +1,396 @@ +package org.flexlb.balance.scheduler; + +import org.flexlb.balance.endpoint.PrefillEndpoint; +import org.flexlb.balance.strategy.PrefillTimePredictor; +import org.flexlb.util.Logger; + +import java.util.ArrayList; +import java.util.Comparator; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.TimeUnit; + +/** + * SLO-deadline-aware batching algorithm with EMA arrival rate estimation, + * budget-based greedy fill, and deadline-gated dispatch. + * + *

This is the original algorithm migrated from the now-refactored + * {@link WorkerBatcher}. All mutable algorithm-specific state lives here. + */ +public class SloBudgetBatcherAlgorithm implements BatcherAlgorithm { + + // ==================== Algorithm-specific mutable state ==================== + + private volatile long lastOfferMs; + private volatile double interArrivalEmaMs; + private final Map lastParkByRequest = new ConcurrentHashMap<>(); + + // ==================== BatcherAlgorithm implementation ==================== + + @Override + public long computeSortKey(BatcherContext ctx, BatchItem item) { + long sloMs = ctx.cfg().resolveSloMs(item.seqLen()); + PrefillTimePredictor predictor = ctx.prefillEp().getPredictor(); + long predMs = predictor.estimateMs(item.seqLen(), item.hitCache()); + long workerQueueMs = ctx.prefillEp().realWaitTimeMs(); + return System.currentTimeMillis() + Math.max(0, sloMs - predMs - workerQueueMs); + } + + @Override + public void processQueue(BatcherContext ctx) throws InterruptedException { + if (ctx.isEmpty()) { + return; + } + + long windowMs = ctx.cfg().getFlexlbBatchWindowMs(); + int minBatchSize = ctx.cfg().getFlexlbBatchMinSize(); + long emergencyBudgetMs = ctx.cfg().getFlexlbBatchEmergencyBudgetMs(); + int maxScan = ctx.cfg().getFlexlbBatchScanAhead(); + int batchMaxTokens = ctx.cfg().getFlexlbBatchMaxCapacity(); + int batchMaxCount = Math.max(1, ctx.cfg().getFlexlbBatchSizeMax()); + + BatchItem head = ctx.peek(); + if (head == null) { + return; + } + long now = ctx.now(); + long budgetMs = head.deadlineMs() - now; + + // 1. expired → drop + if (budgetMs < 0) { + dropHead(ctx, head, now, budgetMs, "deadline_expired"); + return; + } + + int maxInflightBatches = ctx.cfg().getFlexlbBatchSloMaxInflightBatches(); + if (maxInflightBatches > 0 && ctx.prefillEp().getInflightBatchCount() >= maxInflightBatches) { + long inflightGuardMs = dispatchGuardMs(ctx, emergencyBudgetMs); + if (budgetMs <= inflightGuardMs) { + dropHead(ctx, head, now, budgetMs, "inflight_full_guard"); + return; + } + recordPark(ctx, head, "inflight_full", budgetMs, now); + parkBriefly(); + return; + } + + PrefillTimePredictor predictor = ctx.prefillEp().getPredictor(); + long effectiveTokens = Math.max(head.seqLen(), batchMaxTokens); + long baseGuardMs = dispatchGuardMs(ctx, emergencyBudgetMs); + BatchPick pick = pickWithinIncrementalBudget( + ctx, head, predictor, Math.max(0, budgetMs - baseGuardMs), maxScan, batchMaxCount, effectiveTokens); + List picked = pick.items(); + long incrementalCostMs = Math.max(0, pick.predMs() - pick.headPredMs()); + long latestDispatchBudgetMs = latestDispatchBudgetMs(baseGuardMs, emergencyBudgetMs, incrementalCostMs); + boolean insideWindow = budgetMs <= windowMs; + int targetBatchSize = insideWindow + ? targetBatchSize(ctx, minBatchSize, batchMaxCount, budgetMs, latestDispatchBudgetMs, now) + : batchMaxCount; + double fillRatio = targetBatchSize > 0 ? (double) picked.size() / targetBatchSize : 1.0; + boolean reachesMaxSize = picked.size() >= batchMaxCount; + boolean reachesTarget = picked.size() >= targetBatchSize; + boolean mustDispatch = budgetMs <= latestDispatchBudgetMs; + boolean shouldWaitForMore = shouldWaitForMore(ctx, + picked.size(), minBatchSize, batchMaxCount, targetBatchSize, budgetMs, latestDispatchBudgetMs, now); + DecisionTrace trace = new DecisionTrace( + targetBatchSize, + budgetMs, + latestDispatchBudgetMs, + Math.max(0, budgetMs - latestDispatchBudgetMs), + estimatedInterArrivalMs(ctx), + estimatedTimeToNextArrivalMs(ctx, now), + arrivalWaitGuardMs(ctx), + ctx.prefillEp().getInflightBatchCount(), + now); + + // 2. Dispatch decision. Predictor is used for admission and deadline + // protection; request count and arrival rate decide whether to keep + // waiting for a more efficient batch. + if (reachesMaxSize) { + dispatchBatch(ctx, picked, "batch_size_max", fillRatio, batchMaxTokens, trace); + } else if (mustDispatch) { + dispatchBatch(ctx, picked, "deadline_guard", fillRatio, batchMaxTokens, trace); + } else if (insideWindow && reachesTarget && !shouldWaitForMore) { + dispatchBatch(ctx, picked, "target_batch_size", fillRatio, batchMaxTokens, trace); + } else if (insideWindow && !shouldWaitForMore) { + dispatchBatch(ctx, picked, "arrival_guard", fillRatio, batchMaxTokens, trace); + } else { + recordPark(ctx, head, parkReason(insideWindow, picked.size(), minBatchSize, batchMaxCount, + targetBatchSize, shouldWaitForMore), budgetMs, now); + parkBriefly(); + } + } + + @Override + public void onOffer(BatcherContext ctx, BatchItem item, long nowMs) { + recordArrival(ctx, nowMs); + } + + @Override + public void onShutdown(BatcherContext ctx) { + lastParkByRequest.clear(); + } + + // ==================== Batch pick ==================== + + private BatchPick pickWithinIncrementalBudget(BatcherContext ctx, + BatchItem head, + PrefillTimePredictor predictor, + long budgetMs, + int maxScan, + int batchMaxCount, + long batchMaxTokens) { + List picked = new ArrayList<>(); + picked.add(head); + + long sumTokens = head.seqLen(); + long headPredMs = Math.max(0, predictor.predictBatchMs(picked)); + long maxPredMs = headPredMs + Math.max(0, budgetMs); + int scanned = 0; + + for (BatchItem c : ctx.sortedItems()) { + if (c == head) { + continue; + } + if (scanned >= maxScan || picked.size() >= batchMaxCount) { + break; + } + scanned++; + + long nextTokens = sumTokens + c.seqLen(); + if (nextTokens > batchMaxTokens) { + continue; + } + + List trial = new ArrayList<>(picked.size() + 1); + trial.addAll(picked); + trial.add(c); + long trialPredMs = Math.max(0, predictor.predictBatchMs(trial)); + if (trialPredMs <= maxPredMs) { + picked.add(c); + sumTokens = nextTokens; + } + } + return new BatchPick(picked, headPredMs, Math.max(headPredMs, predictor.predictBatchMs(picked))); + } + + // ==================== Target batch size ==================== + + private static int minTargetBatchSize(int minBatchSize, int batchMaxCount) { + return Math.max(1, Math.min(minBatchSize, batchMaxCount)); + } + + private int targetBatchSize(BatcherContext ctx, + int minBatchSize, + int batchMaxCount, + long budgetMs, + long latestDispatchBudgetMs, + long nowMs) { + int minTarget = minTargetBatchSize(minBatchSize, batchMaxCount); + if (batchMaxCount <= minTarget) { + return batchMaxCount; + } + long slackMs = Math.max(0, budgetMs - latestDispatchBudgetMs); + long usableSlackMs = Math.max(0, slackMs - arrivalWaitGuardMs(ctx)); + long arrivalMs = estimatedInterArrivalMs(ctx); + long nextArrivalMs = estimatedTimeToNextArrivalMs(ctx, nowMs); + if (arrivalMs <= 0 || nextArrivalMs > usableSlackMs) { + return minTarget; + } + long expectedMore = 1 + (usableSlackMs - nextArrivalMs) / Math.max(1, arrivalMs); + long target = (long) minTarget + expectedMore; + return (int) Math.max(minTarget, Math.min(batchMaxCount, target)); + } + + // ==================== Wait decision ==================== + + private boolean shouldWaitForMore(BatcherContext ctx, + int pickedSize, + int minBatchSize, + int batchMaxCount, + int targetBatchSize, + long budgetMs, + long latestDispatchBudgetMs, + long nowMs) { + if (pickedSize >= batchMaxCount) { + return false; + } + long slackMs = budgetMs - latestDispatchBudgetMs; + if (slackMs <= 1) { + return false; + } + long nextArrivalMs = estimatedTimeToNextArrivalMs(ctx, nowMs); + if (nextArrivalMs + arrivalWaitGuardMs(ctx) > slackMs) { + return false; + } + if (pickedSize < minTargetBatchSize(minBatchSize, batchMaxCount)) { + return true; + } + return pickedSize < targetBatchSize; + } + + // ==================== Budget guards ==================== + + private static long dispatchGuardMs(BatcherContext ctx, long emergencyBudgetMs) { + long configured = Math.max(1, ctx.cfg().getFlexlbBatchDispatchGuardMs()); + return emergencyBudgetMs > 0 ? Math.min(configured, emergencyBudgetMs) : configured; + } + + private static long latestDispatchBudgetMs(long baseGuardMs, long emergencyBudgetMs, long incrementalCostMs) { + long latest = Math.max(baseGuardMs, baseGuardMs + incrementalCostMs); + return emergencyBudgetMs > 0 ? Math.min(latest, emergencyBudgetMs) : latest; + } + + // ==================== Arrival rate estimation (EMA) ==================== + + private synchronized void recordArrival(BatcherContext ctx, long nowMs) { + if (lastOfferMs > 0 && nowMs > lastOfferMs) { + long intervalMs = Math.min(nowMs - lastOfferMs, + Math.max(1, ctx.cfg().getFlexlbBatchWindowMs())); + double alpha = Math.max(0.01, Math.min(1.0, ctx.cfg().getFlexlbBatchArrivalEmaAlpha())); + interArrivalEmaMs = interArrivalEmaMs <= 0 + ? intervalMs + : alpha * intervalMs + (1.0 - alpha) * interArrivalEmaMs; + } + lastOfferMs = nowMs; + } + + private long estimatedInterArrivalMs(BatcherContext ctx) { + double ema = interArrivalEmaMs; + if (ema > 0) { + return Math.max(1, Math.round(ema)); + } + long windowMs = Math.max(1, ctx.cfg().getFlexlbBatchWindowMs()); + int minBatchSize = Math.max(1, ctx.cfg().getFlexlbBatchMinSize()); + return Math.max(1, Math.round((double) windowMs / minBatchSize)); + } + + private long estimatedTimeToNextArrivalMs(BatcherContext ctx, long nowMs) { + long intervalMs = estimatedInterArrivalMs(ctx); + long lastMs = lastOfferMs; + if (lastMs <= 0 || nowMs <= lastMs) { + return intervalMs; + } + long elapsedMs = nowMs - lastMs; + if (interArrivalEmaMs <= 0 || elapsedMs >= intervalMs * 2) { + return intervalMs; + } + long remainderMs = elapsedMs % intervalMs; + return remainderMs == 0 ? 1 : Math.max(1, intervalMs - remainderMs); + } + + private static long arrivalWaitGuardMs(BatcherContext ctx) { + return Math.max(0, ctx.cfg().getFlexlbBatchArrivalWaitGuardMs()); + } + + // ==================== Park tracking ==================== + + private static String parkReason(boolean insideWindow, + int pickedSize, + int minBatchSize, + int batchMaxCount, + int targetBatchSize, + boolean shouldWaitForMore) { + if (!insideWindow) { + return "outside_window"; + } + if (!shouldWaitForMore) { + return "unknown"; + } + int minTarget = minTargetBatchSize(minBatchSize, batchMaxCount); + if (pickedSize < minTarget) { + return "wait_for_min_batch"; + } + if (pickedSize < targetBatchSize) { + return "wait_for_target_batch"; + } + return "wait_for_more"; + } + + private void recordPark(BatcherContext ctx, BatchItem head, String reason, long budgetMs, long nowMs) { + lastParkByRequest.put(head.requestId(), new ParkTrace( + reason, + budgetMs, + nowMs - head.enqueuedAtMs(), + ctx.size(), + ctx.prefillEp().getInflightBatchCount())); + } + + // ==================== Drop ==================== + + private void dropHead(BatcherContext ctx, BatchItem head, long nowMs, long budgetMs, String dropReason) { + int queueBefore = ctx.size(); + int inflightBatches = ctx.prefillEp().getInflightBatchCount(); + long waitMs = nowMs - head.enqueuedAtMs(); + long initialBudgetMs = head.deadlineMs() - head.enqueuedAtMs(); + ParkTrace parkTrace = lastParkByRequest.remove(head.requestId()); + if (parkTrace == null) { + parkTrace = ParkTrace.EMPTY; + } + Logger.warn("flexlb_batch_drop req_id={} seq_len={} wait_ms={} budget_ms={} worker={} " + + "drop_reason={} initial_budget_ms={} deadline_ms={} enqueued_at_ms={} queue_size={} " + + "inflight_batches={} last_park_reason={} last_park_budget_ms={} " + + "last_park_wait_ms={} last_park_queue_size={} last_park_inflight_batches={}", + head.requestId(), head.seqLen(), waitMs, budgetMs, ctx.key(), + dropReason, initialBudgetMs, head.deadlineMs(), head.enqueuedAtMs(), queueBefore, + inflightBatches, parkTrace.reason(), parkTrace.budgetMs(), + parkTrace.waitMs(), parkTrace.queueSize(), parkTrace.inflightBatches()); + ctx.dropHead(head); + } + + // ==================== Dispatch ==================== + + private void dispatchBatch(BatcherContext ctx, + List picked, + String reason, + double fillRatio, + long batchMaxTokens, + DecisionTrace trace) { + BatchItem head = picked.get(0); + Logger.info("flexlb_batch_decision reason={} picked_size={} target_batch_size={} " + + "fill_ratio={} wait_ms={} budget_ms={} slack_ms={} latest_dispatch_budget_ms={} " + + "arrival_ema_ms={} next_arrival_ms={} arrival_wait_guard_ms={} " + + "inflight_batches={} queue_before={} worker={} head_req_id={}", + reason, picked.size(), trace.targetBatchSize(), fillRatio, + trace.nowMs() - head.enqueuedAtMs(), trace.budgetMs(), trace.slackMs(), + trace.latestDispatchBudgetMs(), trace.arrivalEmaMs(), trace.nextArrivalMs(), + trace.arrivalWaitGuardMs(), trace.inflightBatches(), ctx.size(), ctx.key(), + head.requestId()); + for (BatchItem item : picked) { + lastParkByRequest.remove(item.requestId()); + } + ctx.dispatch(picked, + new DispatchMeta(reason, fillRatio, batchMaxTokens, ctx.size() - picked.size())); + } + + // ==================== Park ==================== + + private static void parkBriefly() throws InterruptedException { + TimeUnit.MILLISECONDS.sleep(1); + } + + // ==================== Inner records ==================== + + private record BatchPick(List items, long headPredMs, long predMs) { + } + + private record DecisionTrace(int targetBatchSize, + long budgetMs, + long latestDispatchBudgetMs, + long slackMs, + long arrivalEmaMs, + long nextArrivalMs, + long arrivalWaitGuardMs, + int inflightBatches, + long nowMs) { + } + + private record ParkTrace(String reason, + long budgetMs, + long waitMs, + int queueSize, + int inflightBatches) { + private static final ParkTrace EMPTY = new ParkTrace("none", -1, -1, -1, -1); + } +} diff --git a/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/balance/scheduler/WorkerBatcher.java b/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/balance/scheduler/WorkerBatcher.java new file mode 100644 index 0000000000..a83e70760a --- /dev/null +++ b/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/balance/scheduler/WorkerBatcher.java @@ -0,0 +1,136 @@ +package org.flexlb.balance.scheduler; + +import org.flexlb.balance.endpoint.PrefillEndpoint; +import org.flexlb.config.FlexlbConfig; +import org.flexlb.service.monitor.BatchSchedulerReporter; +import org.flexlb.util.Logger; + +import java.util.ArrayList; +import java.util.Comparator; +import java.util.List; +import java.util.concurrent.CancellationException; +import java.util.concurrent.PriorityBlockingQueue; + +/** + * Per-worker request batcher that owns the queue and lifecycle, delegating + * dispatch decision logic to a pluggable {@link BatcherAlgorithm}. + * + *

One instance per Prefill worker. Requests are submitted via + * {@link #offer(BatchItem)} and batched by the configured algorithm. + */ +public class WorkerBatcher { + + private final String key; + private final PrefillEndpoint prefillEp; + private final FlexlbConfig cfg; + private final BatchDecisionHandler handler; + private final PriorityBlockingQueue queue = + new PriorityBlockingQueue<>(11, Comparator.comparingLong(BatchItem::sortKey)); + private final Thread workerThread; + private volatile boolean stopped; + private final BatcherAlgorithm algorithm; + private final BatcherContext ctx; + + public WorkerBatcher(String key, PrefillEndpoint prefillEp, FlexlbConfig cfg, + BatchDecisionHandler handler, + BatchSchedulerReporter reporter) { + this.key = key; + this.prefillEp = prefillEp; + this.cfg = cfg; + this.handler = handler; + this.algorithm = createAlgorithm(cfg); + this.ctx = new BatcherContext(key, prefillEp, cfg, handler, queue, reporter); + this.workerThread = new Thread(this::runLoop, "flexlb-batcher-" + key); + this.workerThread.setDaemon(true); + this.workerThread.setUncaughtExceptionHandler((t, e) -> + Logger.error("WorkerBatcher[{}] thread died unexpectedly", key, e)); + } + + private static BatcherAlgorithm createAlgorithm(FlexlbConfig config) { + String algoName = config.getFlexlbBatchAlgorithm(); + if ("fixed_window".equalsIgnoreCase(algoName)) { + return new FixedWindowBatcherAlgorithm(); + } + // Fallback: slo_budget for any unrecognized value + return new SloBudgetBatcherAlgorithm(); + } + + public void start() { + workerThread.start(); + } + + public void offer(BatchItem item) { + if (stopped) { + handler.onOfferFailure(item, new IllegalStateException("FlexLB batcher stopped")); + return; + } + int maxSize = cfg.getFlexlbBatchQueueMaxSize(); + if (maxSize > 0 && queue.size() >= maxSize) { + handler.onOfferFailure(item, + new IllegalStateException("FlexLB batcher queue full, maxSize=" + maxSize)); + return; + } + long sortKey = algorithm.computeSortKey(ctx, item); + item.setSortKey(sortKey); + algorithm.onOffer(ctx, item, System.currentTimeMillis()); + queue.add(item); + } + + public int queueSize() { + return queue.size(); + } + + public long headSortKey() { + BatchItem head = queue.peek(); + return head != null ? head.sortKey() : 0; + } + + /** + * Estimated remaining wait time of the head request. + * Delegates to the algorithm-specific {@link BatcherAlgorithm#headWaitMs}. + */ + public long headWaitMs() { + return algorithm.headWaitMs(ctx); + } + + /** + * Estimated time a new request would wait in the queue before dispatch. + * Delegates to the algorithm-specific {@link BatcherAlgorithm#queueWaitMs}. + */ + public long queueWaitMs() { + return algorithm.queueWaitMs(ctx); + } + + public void shutdown() { + stopped = true; + workerThread.interrupt(); + algorithm.onShutdown(ctx); + List remaining = new ArrayList<>(); + queue.drainTo(remaining); + for (BatchItem item : remaining) { + handler.onOfferFailure(item, + new CancellationException("FlexLB batcher stopped: " + key)); + } + } + + // ==================== Internal: Run loop ==================== + + private void runLoop() { + while (!stopped && !Thread.currentThread().isInterrupted()) { + try { + waitForNonEmpty(); + algorithm.processQueue(ctx); + } catch (InterruptedException ie) { + Thread.currentThread().interrupt(); + return; + } catch (Throwable t) { + Logger.error("WorkerBatcher[{}] loop failed", key, t); + } + } + } + + private void waitForNonEmpty() throws InterruptedException { + BatchItem item = queue.take(); + queue.put(item); + } +} diff --git a/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/balance/strategy/CostBasedDecodeStrategy.java b/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/balance/strategy/CostBasedDecodeStrategy.java new file mode 100644 index 0000000000..0966f89f8e --- /dev/null +++ b/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/balance/strategy/CostBasedDecodeStrategy.java @@ -0,0 +1,276 @@ +package org.flexlb.balance.strategy; + +import org.apache.commons.collections4.CollectionUtils; +import org.apache.commons.collections4.MapUtils; +import org.flexlb.balance.endpoint.DecodeEndpoint; +import org.flexlb.balance.endpoint.EndpointRegistry; +import org.flexlb.balance.resource.DecodeResourceMeasure; +import org.flexlb.balance.resource.ResourceMeasureFactory; +import org.flexlb.config.ConfigService; +import org.flexlb.config.FlexlbConfig; +import org.flexlb.dao.BalanceContext; +import org.flexlb.dao.loadbalance.Request; +import org.flexlb.dao.loadbalance.ServerStatus; +import org.flexlb.dao.loadbalance.StrategyErrorType; +import org.flexlb.balance.endpoint.WorkerEndpoint; +import org.flexlb.dao.master.CacheStatus; +import org.flexlb.dao.route.RoleType; +import org.flexlb.enums.LoadBalanceStrategyEnum; +import org.flexlb.enums.ResourceMeasureIndicatorEnum; +import org.flexlb.sync.status.EngineWorkerStatus; +import org.flexlb.util.CommonUtils; +import org.flexlb.util.Logger; +import org.springframework.stereotype.Component; + +import java.util.ArrayList; +import java.util.Comparator; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.ThreadLocalRandom; + +@Component("costBasedDecodeStrategy") +public class CostBasedDecodeStrategy implements LoadBalancer { + + private final EngineWorkerStatus engineWorkerStatus; + private final double decayFactor; + private final ResourceMeasureFactory resourceMeasureFactory; + private final EndpointRegistry endpointRegistry; + + public CostBasedDecodeStrategy(ConfigService configService, + EngineWorkerStatus engineWorkerStatus, + ResourceMeasureFactory resourceMeasureFactory, + EndpointRegistry endpointRegistry) { + this.engineWorkerStatus = engineWorkerStatus; + FlexlbConfig config = configService.loadBalanceConfig(); + this.decayFactor = config.getWeightedCacheDecayFactor(); + this.resourceMeasureFactory = resourceMeasureFactory; + this.endpointRegistry = endpointRegistry; + LoadBalanceStrategyFactory.register(LoadBalanceStrategyEnum.COST_BASED_DECODE, this); + } + + private record WeightedWorker(DecodeEndpoint endpoint, long normalizedCacheUsed, double weight) { + } + + @Override + public ServerStatus select(BalanceContext balanceContext, RoleType roleType, String group) { + Request request = balanceContext.getRequest(); + long seqLen = request.getSeqLen(); + FlexlbConfig config = balanceContext.getConfig(); + + EndpointFilterResult filterResult = getAvailableEndpoints(roleType, group, config.getResourceMeasureIndicator(roleType)); + List eligible = filterResult.endpoints(); + if (CollectionUtils.isEmpty(eligible)) { + Logger.warn("Decode select failed: no available endpoints, request_id={}, rejections={}", + balanceContext.getRequestId(), filterResult.rejections()); + return ServerStatus.code(StrategyErrorType.NO_AVAILABLE_WORKER); + } + + FilterResult hardFilterResult = applyHardFilters(eligible, seqLen, config); + List survivors = hardFilterResult.endpoints(); + + DecodeEndpoint selectedEndpoint = weightedRandomSelection(survivors); + + if (selectedEndpoint != null) { + long prefixLength = calcPrefixMatchLength(selectedEndpoint.getStatus().getCacheStatus(), balanceContext.getRequest().getBlockCacheKeys()); + return buildServerStatus(selectedEndpoint, seqLen, prefixLength, roleType, balanceContext.getRequestId()); + } + + Map merged = new java.util.HashMap<>(filterResult.rejections()); + hardFilterResult.rejections().forEach((k, v) -> merged.merge(k, v, Integer::sum)); + Logger.warn("Decode select failed: all filtered out, request_id={}, rejections={}", + balanceContext.getRequestId(), merged); + return ServerStatus.code(StrategyErrorType.NO_AVAILABLE_WORKER); + } + + private record EndpointFilterResult(List endpoints, Map rejections) {} + private record FilterResult(List endpoints, Map rejections) {} + + private EndpointFilterResult getAvailableEndpoints(RoleType roleType, String group, ResourceMeasureIndicatorEnum indicator) { + Map workerEndpointMap = engineWorkerStatus.selectModelWorkerStatus(roleType, group); + if (MapUtils.isEmpty(workerEndpointMap)) { + return new EndpointFilterResult(new ArrayList<>(), Map.of("NO_REGISTERED", 1)); + } + DecodeResourceMeasure measure = (DecodeResourceMeasure) resourceMeasureFactory.getMeasure(indicator); + if (measure == null) { + return new EndpointFilterResult(new ArrayList<>(), Map.of("NO_REGISTERED", 1)); + } + List result = new ArrayList<>(); + Map rejections = new java.util.HashMap<>(); + for (WorkerEndpoint ep : workerEndpointMap.values()) { + if (!(ep instanceof DecodeEndpoint de)) { + continue; + } + if (!de.getStatus().isAlive()) { + rejections.merge("NOT_ALIVE", 1, Integer::sum); + continue; + } + if (!measure.isResourceAvailable(de)) { + rejections.merge("RESOURCE_UNAVAILABLE", 1, Integer::sum); + continue; + } + result.add(de); + } + return new EndpointFilterResult(result, rejections); + } + + @Override + public void rollBack(String ipPort, long requestId) { + Logger.debug("Decode rollBack - ip: {}, requestId: {}", ipPort, requestId); + + DecodeEndpoint ep = endpointRegistry.getDecode(ipPort); + if (ep != null) { + ep.release(requestId); + } + } + + private FilterResult applyHardFilters(List eligible, long seqLen, FlexlbConfig config) { + double hotspotMultiplier = config.getDecodeHotspotMultiplier(); + double imbalanceMultiplier = config.getDecodeImbalanceMultiplier(); + + long sumLoad = 0; + long sumCacheUsed = 0; + for (DecodeEndpoint ep : eligible) { + sumLoad += ep.getTotalLoad(); + sumCacheUsed += ep.realKvUsed(); + } + long avgLoad = sumLoad / eligible.size(); + long avgCacheUsed = sumCacheUsed / eligible.size(); + + List survivors = new ArrayList<>(eligible.size()); + Map rejections = new java.util.HashMap<>(); + for (DecodeEndpoint ep : eligible) { + long availableKv = ep.realKvAvailable(); + long totalKv = ep.realKvTotal(); + if (totalKv > 0 && availableKv < seqLen) { + rejections.merge("KV_CAPACITY", 1, Integer::sum); + continue; + } + long load = ep.getTotalLoad(); + if (hotspotMultiplier > 0 && avgLoad > 0 + && load > avgLoad * hotspotMultiplier) { + rejections.merge("HOTSPOT_FILTERED", 1, Integer::sum); + continue; + } + if (imbalanceMultiplier > 0 && avgCacheUsed > 0 + && ep.realKvUsed() > avgCacheUsed * imbalanceMultiplier) { + rejections.merge("IMBALANCE_FILTERED", 1, Integer::sum); + continue; + } + survivors.add(ep); + } + + if (survivors.isEmpty()) { + DecodeEndpoint leastUsed = eligible.stream() + .min(Comparator.comparingLong(DecodeEndpoint::realKvUsed)) + .orElse(null); + if (leastUsed != null) { + survivors.add(leastUsed); + } + } + return new FilterResult(survivors, rejections); + } + + private long calcPrefixMatchLength(CacheStatus cacheStatus, List promptCacheKeys) { + + if (cacheStatus == null || promptCacheKeys == null) { + return 0; + } + long blockSize = cacheStatus.getBlockSize(); + Set cachePrefixHash = cacheStatus.getCachedKeys(); + if (cachePrefixHash == null) { + return 0; + } + + for (int index = 0; index < promptCacheKeys.size(); index++) { + long hash = promptCacheKeys.get(index); + if (!cachePrefixHash.contains(hash)) { + return blockSize * index; + } + } + + return blockSize * promptCacheKeys.size(); + } + + private DecodeEndpoint weightedRandomSelection(List candidateEndpoints) { + int workerCount = candidateEndpoints.size(); + if (workerCount == 0) { + return null; + } + + long totalCacheUsed = 0; + for (DecodeEndpoint ep : candidateEndpoints) { + totalCacheUsed += ep.realKvUsed(); + } + double avgCacheUsed = (double) totalCacheUsed / workerCount; + + List weightedEndpoints = new ArrayList<>(); + boolean allSameUsage = true; + double totalWeight = 0; + Long firstCacheUsed = null; + + for (DecodeEndpoint ep : candidateEndpoints) { + long cacheUsed = ep.realKvUsed(); + double normalizedValue = cacheUsed - avgCacheUsed; + + if (firstCacheUsed == null) { + firstCacheUsed = cacheUsed; + } else if (cacheUsed != firstCacheUsed) { + allSameUsage = false; + } + + double weight = Math.exp(-decayFactor * normalizedValue); + + weightedEndpoints.add(new WeightedWorker(ep, (long) normalizedValue, weight)); + totalWeight += weight; + } + + if (totalWeight <= 0) { + Logger.warn("Total weight is zero or negative: {}, using uniform random selection", totalWeight); + int randomIndex = ThreadLocalRandom.current().nextInt(workerCount); + return candidateEndpoints.get(randomIndex); + } + + if (allSameUsage) { + int randomIndex = ThreadLocalRandom.current().nextInt(workerCount); + return candidateEndpoints.get(randomIndex); + } + + double randomValue = ThreadLocalRandom.current().nextDouble() * totalWeight; + double cumulativeWeight = 0; + + for (WeightedWorker weightedEndpoint : weightedEndpoints) { + cumulativeWeight += weightedEndpoint.weight; + if (Double.compare(randomValue, cumulativeWeight) <= 0) { + return weightedEndpoint.endpoint; + } + } + + return weightedEndpoints.stream() + .min(Comparator.comparingLong(w -> w.endpoint.realKvUsed())) + .map(w -> w.endpoint) + .orElse(null); + } + + private ServerStatus buildServerStatus(DecodeEndpoint optimalEndpoint, long seqLen, long prefixLength, RoleType roleType, long requestId) { + ServerStatus result = new ServerStatus(); + try { + optimalEndpoint.reserve(requestId, seqLen); + + result.setSuccess(true); + result.setRole(roleType); + result.setServerIp(optimalEndpoint.getIp()); + result.setHttpPort(optimalEndpoint.getHttpPort()); + result.setGrpcPort(CommonUtils.toGrpcPort(optimalEndpoint.getHttpPort())); + result.setDpRank(optimalEndpoint.getStatus().getDpRank()); + result.setGroup(optimalEndpoint.getStatus().getGroup()); + result.setRequestId(requestId); + } catch (Exception e) { + Logger.error("buildServerStatus error", e); + result.setSuccess(false); + result.setCode(StrategyErrorType.NO_AVAILABLE_WORKER.getErrorCode()); + result.setMessage(StrategyErrorType.NO_AVAILABLE_WORKER.getErrorMsg()); + } + return result; + } +} diff --git a/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/balance/strategy/CostBasedPrefillStrategy.java b/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/balance/strategy/CostBasedPrefillStrategy.java new file mode 100644 index 0000000000..327896ac35 --- /dev/null +++ b/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/balance/strategy/CostBasedPrefillStrategy.java @@ -0,0 +1,277 @@ +package org.flexlb.balance.strategy; + +import org.apache.commons.collections4.CollectionUtils; +import org.apache.commons.collections4.MapUtils; +import org.flexlb.balance.endpoint.EndpointRegistry; +import org.flexlb.balance.endpoint.PrefillEndpoint; +import org.flexlb.balance.resource.PrefillResourceMeasure; +import org.flexlb.balance.resource.ResourceMeasureFactory; +import org.flexlb.cache.service.CacheAwareService; +import org.flexlb.config.FlexlbConfig; +import org.flexlb.dao.BalanceContext; +import org.flexlb.dao.loadbalance.ServerStatus; +import org.flexlb.dao.loadbalance.StrategyErrorType; +import org.flexlb.balance.endpoint.WorkerEndpoint; +import org.flexlb.dao.route.RoleType; +import org.flexlb.enums.LoadBalanceStrategyEnum; +import org.flexlb.enums.ResourceMeasureIndicatorEnum; + +import org.flexlb.service.monitor.EngineHealthReporter; +import org.flexlb.sync.status.EngineWorkerStatus; +import org.flexlb.util.CommonUtils; +import org.flexlb.util.Logger; +import org.springframework.stereotype.Component; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.Comparator; +import java.util.List; +import java.util.Map; + +@Component("costBasedPrefillStrategy") +public class CostBasedPrefillStrategy implements LoadBalancer { + + private final EngineWorkerStatus engineWorkerStatus; + private final CacheAwareService cacheAwareService; + private final ResourceMeasureFactory resourceMeasureFactory; + private final EngineHealthReporter engineHealthReporter; + private final EndpointRegistry endpointRegistry; + + public CostBasedPrefillStrategy(EngineWorkerStatus engineWorkerStatus, + CacheAwareService cacheAwareService, + ResourceMeasureFactory resourceMeasureFactory, + EngineHealthReporter engineHealthReporter, + EndpointRegistry endpointRegistry) { + this.engineWorkerStatus = engineWorkerStatus; + this.cacheAwareService = cacheAwareService; + this.resourceMeasureFactory = resourceMeasureFactory; + this.engineHealthReporter = engineHealthReporter; + this.endpointRegistry = endpointRegistry; + LoadBalanceStrategyFactory.register(LoadBalanceStrategyEnum.COST_BASED_PREFILL, this); + } + + @Override + public ServerStatus select(BalanceContext balanceContext, RoleType roleType, String group) { + try { + return doSelect(balanceContext, roleType, group); + } catch (Exception e) { + Logger.warn("CostBasedPrefillStrategy select failed", e); + return ServerStatus.code(StrategyErrorType.NO_AVAILABLE_WORKER); + } + } + + @Override + public void rollBack(String ipPort, long requestId) { + // Release non-batch prefill inflight reservation on routing failure. + // Batch path inflight is managed by FlexlbBatchScheduler — no-op here. + PrefillEndpoint ep = endpointRegistry.getPrefill(ipPort); + if (ep != null) { + ep.releaseBatch(requestId); + } + } + + private ServerStatus doSelect(BalanceContext balanceContext, RoleType roleType, String group) { + long requestId = balanceContext.getRequestId(); + long seqLen = balanceContext.getRequest().getSeqLen(); + FlexlbConfig config = balanceContext.getConfig(); + + EndpointFilterResult filterResult = getAvailableEndpoints(roleType, group, config.getResourceMeasureIndicator(roleType)); + List eligible = filterResult.endpoints(); + if (CollectionUtils.isEmpty(eligible)) { + Logger.warn("Prefill select failed: no available endpoints, request_id={}, rejections={}", + requestId, filterResult.rejections()); + return ServerStatus.code(StrategyErrorType.NO_AVAILABLE_WORKER); + } + + Map cacheMatchResults = getCacheMatchResults(balanceContext, roleType, group); + + FilterResult hardFilterResult = applyHardFilters(eligible, seqLen, config, cacheMatchResults); + List survivors = hardFilterResult.endpoints(); + + PrefillEndpoint best = null; + long bestScore = Long.MAX_VALUE; + long bestCacheHit = 0; + + for (PrefillEndpoint ep : survivors) { + long cacheHit = calculateCacheHit(ep, cacheMatchResults); + long score = computeScore(ep); + + if (score < bestScore) { + bestScore = score; + best = ep; + bestCacheHit = cacheHit; + } + } + + if (best == null) { + Map merged = new java.util.HashMap<>(filterResult.rejections()); + hardFilterResult.rejections().forEach((k, v) -> merged.merge(k, v, Integer::sum)); + Logger.warn("Prefill select failed: all filtered out, request_id={}, rejections={}", + requestId, merged); + return ServerStatus.code(StrategyErrorType.NO_AVAILABLE_WORKER); + } + + reportCacheHitMetrics(roleType, best.getIp(), bestCacheHit, seqLen); + + return buildServerStatus(best, roleType, requestId, bestScore, config, balanceContext); + } + + private record EndpointFilterResult(List endpoints, Map rejections) {} + private record FilterResult(List endpoints, Map rejections) {} + + private FilterResult applyHardFilters(List eligible, long seqLen, + FlexlbConfig config, Map cacheMatchResults) { + long sloMs = config.resolveSloMs(seqLen); + long sloRiskMarginMs = config.getCostSloRiskMarginMs(); + boolean sloFilterEnabled = config.isCostSloFilterEnabled(); + double hotspotMultiplier = config.getCostHotspotMultiplier(); + double imbalanceMultiplier = config.getCostImbalanceMultiplier(); + + List feasible = new ArrayList<>(eligible.size()); + Map rejections = new java.util.HashMap<>(); + for (PrefillEndpoint ep : eligible) { + PrefillTimePredictor predictor = ep.getPredictor(); + if (predictor == null) { + rejections.merge("PREDICTOR_MISSING", 1, Integer::sum); + continue; + } + + long cacheHit = calculateCacheHit(ep, cacheMatchResults); + long singlePrefillMs = predictor.estimateMs(seqLen, cacheHit); + + long endpointWaitMs = ep.realWaitTimeMs(); + + if (sloFilterEnabled && endpointWaitMs + singlePrefillMs > sloMs - sloRiskMarginMs) { + rejections.merge("SLO_VIOLATION", 1, Integer::sum); + continue; + } + + feasible.add(ep); + } + + if (feasible.isEmpty()) { + return new FilterResult(feasible, rejections); + } + + long sumWaitMs = 0; + long sumPendingCount = 0; + for (PrefillEndpoint ep : feasible) { + sumWaitMs += ep.realWaitTimeMs(); + sumPendingCount += ep.realPendingCount(); + } + long avgWaitMs = sumWaitMs / feasible.size(); + long avgPendingCount = sumPendingCount / feasible.size(); + + List survivors = new ArrayList<>(feasible.size()); + for (PrefillEndpoint ep : feasible) { + long endpointWaitMs = ep.realWaitTimeMs(); + long pendingCount = ep.realPendingCount(); + + if (hotspotMultiplier > 0 && avgPendingCount > 0 && pendingCount > avgPendingCount * hotspotMultiplier) { + rejections.merge("HOTSPOT_FILTERED", 1, Integer::sum); + continue; + } + if (imbalanceMultiplier > 0 && avgWaitMs > 0 && endpointWaitMs > avgWaitMs * imbalanceMultiplier) { + rejections.merge("IMBALANCE_FILTERED", 1, Integer::sum); + continue; + } + + survivors.add(ep); + } + + if (survivors.isEmpty()) { + PrefillEndpoint leastLoaded = feasible.stream() + .min(Comparator.comparingLong(PrefillEndpoint::realWaitTimeMs)) + .orElse(null); + if (leastLoaded != null) { + survivors.add(leastLoaded); + } + } + + return new FilterResult(survivors, rejections); + } + + private long computeScore(PrefillEndpoint ep) { + return ep.batcherWaitMs() + ep.realWaitTimeMs(); + } + + private EndpointFilterResult getAvailableEndpoints(RoleType roleType, String group, ResourceMeasureIndicatorEnum indicator) { + Map workerEndpointMap = engineWorkerStatus.selectModelWorkerStatus(roleType, group); + if (MapUtils.isEmpty(workerEndpointMap)) { + return new EndpointFilterResult(new ArrayList<>(), Map.of("NO_REGISTERED", 1)); + } + PrefillResourceMeasure measure = (PrefillResourceMeasure) resourceMeasureFactory.getMeasure(indicator); + if (measure == null) { + return new EndpointFilterResult(new ArrayList<>(), Map.of("NO_REGISTERED", 1)); + } + List result = new ArrayList<>(); + Map rejections = new java.util.HashMap<>(); + + for (WorkerEndpoint ep : workerEndpointMap.values()) { + if (!(ep instanceof PrefillEndpoint pe)) { + continue; + } + if (!pe.getStatus().isAlive()) { + rejections.merge("NOT_ALIVE", 1, Integer::sum); + continue; + } + if (!measure.isResourceAvailable(pe)) { + rejections.merge("RESOURCE_UNAVAILABLE", 1, Integer::sum); + continue; + } + result.add(pe); + } + return new EndpointFilterResult(result, rejections); + } + + private Map getCacheMatchResults(BalanceContext balanceContext, RoleType roleType, String group) { + List blockCacheKeys = balanceContext.getRequest().getBlockCacheKeys(); + return cacheAwareService.findMatchingEngines(blockCacheKeys, roleType, group); + } + + private long calculateCacheHit(PrefillEndpoint ep, Map cacheMatchResults) { + if (ep.getStatus().getCacheStatus() == null || cacheMatchResults == null) { + return 0L; + } + Integer prefixMatchLength = cacheMatchResults.get(ep.ipPort()); + if (prefixMatchLength == null) { + return 0L; + } + return ep.getStatus().getCacheStatus().getBlockSize() * prefixMatchLength; + } + + private void reportCacheHitMetrics(RoleType roleType, String ip, long hitCacheTokens, long seqLen) { + double hitRate = seqLen > 0 ? hitCacheTokens / (double) seqLen : 0.0; + engineHealthReporter.reportCacheHitMetrics(roleType, ip, hitCacheTokens, hitRate); + } + + private ServerStatus buildServerStatus(PrefillEndpoint ep, RoleType roleType, long requestId, long score, + FlexlbConfig config, BalanceContext balanceContext) { + // Non-batch path: reserve prefill inflight for load-aware scoring. + // Batch path uses FlexlbBatchScheduler.commitBatch() instead — skip here to avoid double-counting. + if (isNonBatchPath(config, balanceContext)) { + ep.commitBatch(requestId, score, Collections.emptyList()); + } + + ServerStatus result = new ServerStatus(); + result.setSuccess(true); + result.setRole(roleType); + result.setRequestId(requestId); + result.setPrefillTime(score); + result.setGroup(ep.getStatus().getGroup()); + result.setServerIp(ep.getIp()); + result.setHttpPort(ep.getHttpPort()); + result.setGrpcPort(CommonUtils.toGrpcPort(ep.getHttpPort())); + result.setDpRank(ep.getStatus().getDpRank()); + return result; + } + + /** + * Whether batch dispatching is globally disabled. + *

When batch is enabled, FlexlbBatchScheduler handles all inflight tracking; + * placeholders are only needed when batch is fully off ({@code flexlbBatchEnabled=false}). + */ + private static boolean isNonBatchPath(FlexlbConfig config, BalanceContext ctx) { + return !config.isFlexlbBatchEnabled(); + } +} diff --git a/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/balance/strategy/PrefillTimePredictor.java b/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/balance/strategy/PrefillTimePredictor.java new file mode 100644 index 0000000000..b6c490a5e7 --- /dev/null +++ b/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/balance/strategy/PrefillTimePredictor.java @@ -0,0 +1,51 @@ +package org.flexlb.balance.strategy; + +import org.flexlb.balance.scheduler.BatchItem; + +import java.util.List; + +/** + * α₀ + α₁·Σcᵢ + α₂·Σcᵢ² + α₃·Σ(cᵢ·pᵢ) + α₄·Σpᵢ + α₅·bs + * + * where cᵢ = inputLen - hitCacheTokens (compute tokens), pᵢ = hitCacheTokens, bs = batch size. + */ +public class PrefillTimePredictor { + + private final double a0, a1, a2, a3, a4, a5; + + public PrefillTimePredictor(double a0, double a1, double a2, double a3, double a4, double a5) { + this.a0 = a0; + this.a1 = a1; + this.a2 = a2; + this.a3 = a3; + this.a4 = a4; + this.a5 = a5; + } + + /** Estimate prefill time for a single request from raw token counts. */ + public long estimateMs(long totalTokens, long hitTokens) { + long c = Math.max(0, totalTokens - hitTokens); + return (long) (a0 + a1 * c + a2 * c * c + a3 * c * hitTokens + a4 * hitTokens + a5); + } + + /** Estimate prefill time for a batch of {@link BatchItem}s. */ + public long predictBatchMs(List items) { + if (items.isEmpty()) { + return 0; + } + int bs = items.size(); + long sumC = 0; + double sumQuadratic = 0; + long sumP = 0; + + for (BatchItem item : items) { + long c = item.computeTokens(); + long p = item.hitCache(); + sumC += c; + sumQuadratic += a2 * c * c + a3 * c * p; + sumP += p; + } + + return (long) (a0 + a1 * sumC + sumQuadratic + a4 * sumP + a5 * bs); + } +} diff --git a/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/balance/strategy/RandomStrategy.java b/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/balance/strategy/RandomStrategy.java index 0882a91946..3a709336a0 100644 --- a/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/balance/strategy/RandomStrategy.java +++ b/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/balance/strategy/RandomStrategy.java @@ -2,6 +2,8 @@ import org.apache.commons.collections4.CollectionUtils; import org.apache.commons.collections4.MapUtils; +import org.flexlb.balance.endpoint.EndpointRegistry; +import org.flexlb.balance.endpoint.WorkerEndpoint; import org.flexlb.balance.resource.ResourceMeasure; import org.flexlb.balance.resource.ResourceMeasureFactory; import org.flexlb.config.ConfigService; @@ -10,8 +12,6 @@ import org.flexlb.dao.loadbalance.Request; import org.flexlb.dao.loadbalance.ServerStatus; import org.flexlb.dao.loadbalance.StrategyErrorType; -import org.flexlb.dao.master.TaskInfo; -import org.flexlb.dao.master.WorkerStatus; import org.flexlb.dao.route.RoleType; import org.flexlb.enums.LoadBalanceStrategyEnum; import org.flexlb.enums.ResourceMeasureIndicatorEnum; @@ -33,23 +33,21 @@ public class RandomStrategy implements LoadBalancer { private final EngineWorkerStatus engineWorkerStatus; private final ConfigService configService; private final ResourceMeasureFactory resourceMeasureFactory; + private final EndpointRegistry endpointRegistry; public RandomStrategy(EngineWorkerStatus engineWorkerStatus, ConfigService configService, - ResourceMeasureFactory resourceMeasureFactory) { + ResourceMeasureFactory resourceMeasureFactory, + EndpointRegistry endpointRegistry) { this.engineWorkerStatus = engineWorkerStatus; this.configService = configService; this.resourceMeasureFactory = resourceMeasureFactory; + this.endpointRegistry = endpointRegistry; LoadBalanceStrategyFactory.register(LoadBalanceStrategyEnum.RANDOM, this); } @Override public void rollBack(String ipPort, long requestId) { - Map workerStatusMap = engineWorkerStatus.selectModelWorkerStatus(RoleType.DECODE, null); - WorkerStatus workerStatus = workerStatusMap.get(ipPort); - if (workerStatus != null) { - workerStatus.removeLocalTask(requestId); - } } @Override @@ -57,26 +55,26 @@ public ServerStatus select(BalanceContext balanceContext, RoleType roleType, Str Request request = balanceContext.getRequest(); logger.debug("Selecting worker for , role: {}, group: {}", roleType, group); - Map workerStatusMap = engineWorkerStatus.selectModelWorkerStatus(roleType, group); + Map workerEndpointMap = engineWorkerStatus.selectModelWorkerStatus(roleType, group); - if (MapUtils.isEmpty(workerStatusMap)) { + if (MapUtils.isEmpty(workerEndpointMap)) { logger.warn("No worker status map found"); return ServerStatus.code(StrategyErrorType.NO_AVAILABLE_WORKER); } - List workerStatuses = new ArrayList<>(workerStatusMap.values()); - if (CollectionUtils.isEmpty(workerStatuses)) { + List endpoints = new ArrayList<>(workerEndpointMap.values()); + if (CollectionUtils.isEmpty(endpoints)) { logger.warn("No available workers"); return ServerStatus.code(StrategyErrorType.NO_AVAILABLE_WORKER); } // Random select with wrap-around to skip dead workers, no extra allocation - int size = workerStatuses.size(); + int size = endpoints.size(); int startIndex = ThreadLocalRandom.current().nextInt(size); - WorkerStatus selectedWorker = null; + WorkerEndpoint selectedWorker = null; for (int i = 0; i < size; i++) { - WorkerStatus ws = workerStatuses.get((startIndex + i) % size); - if (isWorkerAvailable(balanceContext, roleType, ws)) { - selectedWorker = ws; + WorkerEndpoint ep = endpoints.get((startIndex + i) % size); + if (isWorkerAvailable(balanceContext, roleType, ep)) { + selectedWorker = ep; break; } } @@ -85,12 +83,12 @@ public ServerStatus select(BalanceContext balanceContext, RoleType roleType, Str return ServerStatus.code(StrategyErrorType.NO_AVAILABLE_WORKER); } - logger.debug("Selected worker ip: {}, httpPort: {}", selectedWorker.getIp(), selectedWorker.getPort()); + logger.debug("Selected worker ip: {}, httpPort: {}", selectedWorker.getIp(), selectedWorker.getHttpPort()); return buildServerStatus(selectedWorker, roleType, balanceContext.getRequestId(), request); } - private boolean isWorkerAvailable(BalanceContext balanceContext, RoleType roleType, WorkerStatus workerStatus) { - if (workerStatus == null || !workerStatus.isAlive()) { + private boolean isWorkerAvailable(BalanceContext balanceContext, RoleType roleType, WorkerEndpoint ep) { + if (ep == null || !ep.getStatus().isAlive()) { return false; } @@ -99,25 +97,19 @@ private boolean isWorkerAvailable(BalanceContext balanceContext, RoleType roleTy : configService.loadBalanceConfig(); ResourceMeasureIndicatorEnum indicator = config.getResourceMeasureIndicator(roleType); ResourceMeasure resourceMeasure = resourceMeasureFactory.getMeasure(indicator); - return resourceMeasure == null || resourceMeasure.isResourceAvailable(workerStatus); + return resourceMeasure == null || resourceMeasure.isResourceAvailable(ep); } - private ServerStatus buildServerStatus(WorkerStatus worker, RoleType roleType, long requestId, Request request) { + private ServerStatus buildServerStatus(WorkerEndpoint ep, RoleType roleType, long requestId, Request request) { ServerStatus result = new ServerStatus(); try { - if (RoleType.DECODE == roleType) { - TaskInfo taskInfo = new TaskInfo(); - taskInfo.setRequestId(requestId); - taskInfo.setInputLength(request == null ? 0 : request.getSeqLen()); - taskInfo.setPrefixLength(0); - worker.putLocalTask(requestId, taskInfo); - } result.setSuccess(true); - result.setServerIp(worker.getIp()); - result.setHttpPort(worker.getPort()); - result.setGrpcPort(CommonUtils.toGrpcPort(worker.getPort())); + result.setServerIp(ep.getIp()); + result.setHttpPort(ep.getHttpPort()); + result.setGrpcPort(CommonUtils.toGrpcPort(ep.getHttpPort())); + result.setDpRank(ep.getStatus().getDpRank()); result.setRole(roleType); - result.setGroup(worker.getGroup()); + result.setGroup(ep.getStatus().getGroup()); result.setRequestId(requestId); } catch (Exception e) { Logger.error("buildServerStatus error", e); diff --git a/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/balance/strategy/ShortestTTFTStrategy.java b/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/balance/strategy/ShortestTTFTStrategy.java deleted file mode 100644 index 514e5dcda2..0000000000 --- a/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/balance/strategy/ShortestTTFTStrategy.java +++ /dev/null @@ -1,548 +0,0 @@ -package org.flexlb.balance.strategy; - -import org.apache.commons.collections4.CollectionUtils; -import org.apache.commons.collections4.MapUtils; -import org.flexlb.balance.resource.ResourceMeasure; -import org.flexlb.balance.resource.ResourceMeasureFactory; -import org.flexlb.cache.service.CacheAwareService; -import org.flexlb.config.ConfigService; -import org.flexlb.config.FlexlbConfig; -import org.flexlb.config.StrategyConfigs; -import org.flexlb.dao.BalanceContext; -import org.flexlb.dao.loadbalance.Request; -import org.flexlb.dao.loadbalance.ServerStatus; -import org.flexlb.dao.loadbalance.StrategyErrorType; -import org.flexlb.dao.master.TaskInfo; -import org.flexlb.dao.master.WorkerStatus; -import org.flexlb.dao.route.RoleType; -import org.flexlb.domain.worker.ScoredWorker; -import org.flexlb.enums.LoadBalanceStrategyEnum; -import org.flexlb.enums.ResourceMeasureIndicatorEnum; -import org.flexlb.service.monitor.EngineHealthReporter; -import org.flexlb.sync.status.EngineWorkerStatus; -import org.flexlb.util.CommonUtils; -import org.flexlb.util.Logger; -import org.springframework.stereotype.Component; - -import java.util.ArrayList; -import java.util.Comparator; -import java.util.List; -import java.util.Map; -import java.util.stream.Collectors; - -/** - * Load balancing strategy based on shortest Time-To-First-Token (TTFT) - * - *

This strategy selects the optimal worker by considering the following factors: - * 1. KV-Cache hit rate: Prioritize workers with higher cache hit rates - * 2. Queue time: Consider the current task queue status of workers - * 3. Scheduling fairness: Achieve load balancing among workers with similar performance - * - * @author saichen.sm - * @since 2025/3/10 - */ -@Component("shortestTTFTStrategy") -public class ShortestTTFTStrategy implements LoadBalancer { - - private final EngineWorkerStatus engineWorkerStatus; - private final EngineHealthReporter engineHealthReporter; - private final CacheAwareService cacheAwareService; - private final ResourceMeasureFactory resourceMeasureFactory; - private final ConfigService configService; - - private static final double TTFT_THRESHOLD_PERCENTAGE = 0.1; - private static final double STDDEV_THRESHOLD_FACTOR = 0.5; - - public ShortestTTFTStrategy(EngineWorkerStatus engineWorkerStatus, - EngineHealthReporter engineHealthReporter, - CacheAwareService cacheAwareService, - ResourceMeasureFactory resourceMeasureFactory, - ConfigService configService) { - this.engineWorkerStatus = engineWorkerStatus; - this.engineHealthReporter = engineHealthReporter; - this.cacheAwareService = cacheAwareService; - this.resourceMeasureFactory = resourceMeasureFactory; - this.configService = configService; - LoadBalanceStrategyFactory.register(LoadBalanceStrategyEnum.SHORTEST_TTFT, this); - } - - /** - * Select optimal worker to execute task - * - * @param balanceContext Load balancing context - * @param roleType Worker role type - * @param group Worker group - * @return Selected server status - */ - @Override - public ServerStatus select(BalanceContext balanceContext, RoleType roleType, String group) { - try { - return doSelect(balanceContext, roleType, group); - } catch (Exception e) { - Logger.warn("Failed to select worker", e); - return ServerStatus.code(StrategyErrorType.NO_AVAILABLE_WORKER); - } - } - - /** - * Release local cached tasks on the specified worker - * - * @param ipPort Worker IP address - * @param requestId Request ID - */ - @Override - public void rollBack(String ipPort, long requestId) { - - Map workerStatusMap = engineWorkerStatus.selectModelWorkerStatus(RoleType.PREFILL, null); - Logger.debug("Prefill rollBack - ipPort: {}, requestId: {}", ipPort, requestId); - - WorkerStatus workerStatus = workerStatusMap.get(ipPort); - if (workerStatus != null) { - workerStatus.removeLocalTask(requestId); - } - } - - /** - * Core logic for worker selection - * - * @param balanceContext Load balancing context - * @param roleType Worker role type - * @param group Worker group - * @return Selected server status - */ - private ServerStatus doSelect(BalanceContext balanceContext, RoleType roleType, String group) { - long requestId = balanceContext.getRequestId(); - long seqLen = balanceContext.getRequest().getSeqLen(); - - Logger.debug("Starting shortest TTFT selection for role: {}", roleType); - - // Get available worker list - FlexlbConfig config = balanceContext.getConfig(); - List availableWorkers = getAvailableWorkers(roleType, group, config.getResourceMeasureIndicator(roleType)); - if (CollectionUtils.isEmpty(availableWorkers)) { - Logger.warn("No available workers for role: {}", roleType.getCode()); - return ServerStatus.code(StrategyErrorType.NO_AVAILABLE_WORKER); - } - - // Calculate cache match results for each engine - Map cacheMatchResults = getCacheMatchResults(balanceContext, roleType, group); - reportCandidateRoutingCacheMatchMetrics(roleType, availableWorkers, cacheMatchResults, balanceContext.getRequest()); - - List scoredWorkers = scoreWorkers(availableWorkers, cacheMatchResults, seqLen); - - StrategyConfigs.CandidatePoolConfig candidatePoolConfig = configService.getStrategyConfigs() - .getShortestTtft() - .getCandidatePool(); - ScoredWorker bestWorker = selectBestWorker(scoredWorkers, candidatePoolConfig); - if (bestWorker == null) { - Logger.warn("Failed to find best worker for role: {}", roleType); - return ServerStatus.code(StrategyErrorType.NO_AVAILABLE_WORKER); - } - - return finalizeWorkerSelection(bestWorker, balanceContext, roleType, requestId, seqLen, cacheMatchResults); - } - - /** - * Get available worker list - * - * @param roleType Worker role type - * @param group Worker group - * @param indicator ResourceMeasureIndicatorEnum - * @return Available worker list - */ - private List getAvailableWorkers(RoleType roleType, String group, ResourceMeasureIndicatorEnum indicator) { - - Map workerStatusMap = engineWorkerStatus.selectModelWorkerStatus(roleType, group); - if (MapUtils.isEmpty(workerStatusMap)) { - return new ArrayList<>(); - } - - ResourceMeasure resourceMeasure = resourceMeasureFactory.getMeasure(indicator); - if (resourceMeasure == null) { - Logger.warn("No ResourceMeasure registered for indicator: {}", indicator); - return new ArrayList<>(); - } - - return new ArrayList<>(workerStatusMap.values()).stream() - .filter(WorkerStatus::isAlive) - .filter(resourceMeasure::isResourceAvailable) - .toList(); - } - - /** - * Get cache match results - * - * @param balanceContext Load balancing context - * @param roleType Worker role type - * @param group Worker group - * @return Cache match results: key: engineIpPort, value: prefixMatchLength - */ - private Map getCacheMatchResults(BalanceContext balanceContext, - RoleType roleType, - String group) { - List blockCacheKeys = balanceContext.getRequest().getBlockCacheKeys(); - return cacheAwareService.findMatchingEngines(blockCacheKeys, roleType, group); - } - - /** - * Calculate TTFT scores for all active workers - * - * @param workers Worker list - * @param cacheMatchResults Cache match results - * @param seqLen Sequence length - * @return List of scored workers - */ - private List scoreWorkers(List workers, Map cacheMatchResults, long seqLen) { - return workers.stream() - .filter(WorkerStatus::isAlive) - .map(workerStatus -> { - long hitCacheTokens = calculatePrefixMatchLength(workerStatus, cacheMatchResults); - long prefillTime = TaskInfo.estimatePrefillTimeMs(seqLen, hitCacheTokens); - long queueTime = workerStatus.getRunningQueueTime().get(); - long newTTFT = prefillTime + queueTime; - long lastSelectedTime = workerStatus.getLastSelectedTime().get(); - Logger.debug("Calculate TTFT for worker - ip: {}, port: {}, hitCacheTokens: {}, prefillTime: {}, queueTime: {}, newTTFT: {}", - workerStatus.getIp(), - workerStatus.getPort(), - hitCacheTokens, - prefillTime, - queueTime, - newTTFT); - return new ScoredWorker(workerStatus, newTTFT, hitCacheTokens, lastSelectedTime); - }) - .collect(Collectors.toList()); - } - - /** - * Finalize worker selection and update status - * - * @param selectedWorker Selected worker - * @param balanceContext Load balancing context - * @param roleType Worker role type - * @param requestId Request ID - * @param seqLen Sequence length - * @return Server status - */ - private ServerStatus finalizeWorkerSelection(ScoredWorker selectedWorker, - BalanceContext balanceContext, - RoleType roleType, - long requestId, - long seqLen, - Map cacheMatchResults) { - WorkerStatus workerStatus = selectedWorker.worker(); - - logWorkerSelection(selectedWorker, roleType); - reportCacheHitMetrics(roleType, workerStatus.getIp(), selectedWorker.hitCacheTokens(), seqLen); - reportSelectedRoutingCacheMatchMetrics(roleType, workerStatus, cacheMatchResults, balanceContext.getRequest()); - - TaskInfo task = createTaskInfo(requestId, balanceContext.getRequest().getSeqLen(), selectedWorker.hitCacheTokens()); - workerStatus.putLocalTask(requestId, task); - - return buildServerStatus(selectedWorker, roleType, requestId); - } - - /** - * Log worker selection - * - * @param selectedWorker Selected worker - * @param roleType Worker role type - */ - private void logWorkerSelection(ScoredWorker selectedWorker, RoleType roleType) { - WorkerStatus workerStatus = selectedWorker.worker(); - Logger.debug("Selected {} worker - ip: {}, port: {}, hitCacheTokens: {}, ttft: {}", - roleType, - workerStatus.getIp(), - workerStatus.getPort(), - selectedWorker.hitCacheTokens(), - selectedWorker.ttft()); - } - - /** - * Report cache hit metrics - * - * @param roleType Worker role type - * @param ip Worker IP address - * @param hitCacheTokens Number of cached tokens hit - * @param seqLen Sequence length - */ - private void reportCacheHitMetrics(RoleType roleType, String ip, long hitCacheTokens, long seqLen) { - double hitRate = seqLen > 0 ? hitCacheTokens / (double) seqLen : 0.0; - engineHealthReporter.reportCacheHitMetrics(roleType, ip, hitCacheTokens, hitRate); - } - - private void reportCandidateRoutingCacheMatchMetrics(RoleType roleType, - List availableWorkers, - Map cacheMatchResults, - Request request) { - if (MapUtils.isEmpty(cacheMatchResults) || request == null || request.getSeqLen() <= 0L) { - return; - } - - Map workerStatusByIpPort = availableWorkers.stream() - .collect(Collectors.toMap(WorkerStatus::getIpPort, workerStatus -> workerStatus, (left, right) -> left)); - for (Map.Entry entry : cacheMatchResults.entrySet()) { - String engineIpPort = entry.getKey(); - long hitTokens = calculateRoutingCacheMatchTokens( - entry.getValue(), - request, - workerStatusByIpPort.get(engineIpPort)); - engineHealthReporter.reportRoutingCandidateCacheMatchMetrics( - roleType, - engineIp(engineIpPort), - hitTokens, - request.getSeqLen()); - } - } - - private void reportSelectedRoutingCacheMatchMetrics(RoleType roleType, - WorkerStatus selectedWorker, - Map cacheMatchResults, - Request request) { - if (selectedWorker == null || cacheMatchResults == null || request == null || request.getSeqLen() <= 0L) { - return; - } - - long hitTokens = calculateRoutingCacheMatchTokens( - cacheMatchResults.get(selectedWorker.getIpPort()), - request, - selectedWorker); - engineHealthReporter.reportRoutingSelectedCacheMatchMetrics( - roleType, - selectedWorker.getIp(), - hitTokens, - request.getSeqLen()); - } - - /** - * Create task information - * - * @param requestId Request ID - * @param inputLength Input length - * @param prefixLength Prefix length - * @return Task information - */ - private TaskInfo createTaskInfo(long requestId, long inputLength, long prefixLength) { - TaskInfo task = new TaskInfo(); - task.setRequestId(requestId); - task.setInputLength(inputLength); - task.setPrefixLength(prefixLength); - return task; - } - - /** - * Select best worker considering TTFT and scheduling fairness - * - *

Algorithm: 1. Sort workers by TTFT 2. Select strategy-configured candidates 3. Among candidates with similar TTFT, prioritize recently unscheduled workers - * - * @param scoredWorkers List of scored workers - * @param candidatePoolConfig candidate pool config - * @return Best worker - */ - private ScoredWorker selectBestWorker(List scoredWorkers, - StrategyConfigs.CandidatePoolConfig candidatePoolConfig) { - if (scoredWorkers.isEmpty()) { - return null; - } - - List sortedWorkers = sortByTTFT(scoredWorkers); - List candidates = selectTopCandidates(sortedWorkers, candidatePoolConfig); - Logger.debug("Select best worker, sortedWorkers size: {}, candidates size: {}", sortedWorkers.size(), candidates.size()); - - if (candidates.isEmpty()) { - return null; - } - - if (candidates.size() == 1) { - Logger.debug("Select best worker with single candidate shortcut, sortedWorkers size: {}", sortedWorkers.size()); - return candidates.getFirst(); - } - - long minTTFT = candidates.getFirst().ttft(); - double threshold = calculateTTFTThreshold(candidates, minTTFT); - - List similarWorkers = filterSimilarWorkers(candidates, minTTFT, threshold); - - return selectWorkerByScheduleFairness(similarWorkers, candidates); - } - - /** - * Sort workers by TTFT - * - * @param workers Worker list - * @return Sorted worker list in ascending order - */ - private List sortByTTFT(List workers) { - // Two-level sorting - // 1. Primary sort: by TTFT (Time-To-First-Token) in ascending order - // 2. Secondary sort: when TTFT is equal, by lastSelectedTime in ascending order - return workers.stream() - .sorted(Comparator.comparingLong(ScoredWorker::ttft) - .thenComparingLong(ScoredWorker::lastSelectedTime)) - .toList(); - } - - /** - * Select top N candidate workers - * - * @param sortedWorkers Sorted worker list - * @return Candidate worker list - */ - private List selectTopCandidates(List sortedWorkers, - StrategyConfigs.CandidatePoolConfig candidatePoolConfig) { - int candidateCount = candidatePoolConfig.resolveCandidateCount(sortedWorkers.size()); - return sortedWorkers.stream().limit(candidateCount).toList(); - } - - /** - * Calculate TTFT similarity threshold - * - * @param candidates Candidate worker list - * @return TTFT threshold - */ - private double calculateTTFTThreshold(List candidates, long minTTFT) { - double avgTTFT = candidates.stream().mapToLong(ScoredWorker::ttft).average().orElse(0.0); - - double stdDev = Math.sqrt( - candidates.stream() - .mapToLong(ScoredWorker::ttft) - .mapToDouble(v -> Math.pow(v - avgTTFT, 2)) - .average() - .orElse(0.0)); - double percentageMinTTFT = minTTFT * TTFT_THRESHOLD_PERCENTAGE; - double factoredStdDev = stdDev * STDDEV_THRESHOLD_FACTOR; - Logger.debug("Calculate TTFT threshold, minTTFT: {}, avgTTFT: {}, stdDev: {}, percentageMinTTFT: {}, factoredStdDev: {}", - minTTFT, avgTTFT, stdDev, percentageMinTTFT, factoredStdDev); - return Math.max(percentageMinTTFT, factoredStdDev); - } - - /** - * Filter workers with similar TTFT - * - * @param candidates Candidate worker list - * @param minTTFT Minimum TTFT value - * @param threshold Threshold - * @return List of workers with similar TTFT - */ - private List filterSimilarWorkers(List candidates, long minTTFT, double threshold) { - List scoredWorkers = candidates.stream() - .filter(worker -> Math.abs(worker.ttft() - minTTFT) <= threshold) - .toList(); - Logger.debug("Filter similar workers, minTTFT: {}, threshold: {}, candidates size: {}", minTTFT, threshold, scoredWorkers.size()); - return scoredWorkers; - } - - /** - * Select worker based on scheduling fairness. - * Among workers with similar TTFT, prefer the least recently scheduled one. - * CAS on lastSelectedTime ensures concurrent requests are spread across different workers - * rather than all landing on the same one. - * - * @param similarWorkers workers with similar TTFT - * @param fallbackCandidates fallback candidate list - * @return selected worker - */ - private ScoredWorker selectWorkerByScheduleFairness(List similarWorkers, List fallbackCandidates) { - if (similarWorkers.isEmpty()) { - return fallbackCandidates.getFirst(); - } - - // Sort ascending by lastSelectedTime so the least recently used worker is tried first - List sorted = similarWorkers.stream() - .sorted(Comparator.comparingLong(ScoredWorker::lastSelectedTime)) - .toList(); - - long now = System.nanoTime() / 1000; - for (ScoredWorker candidate : sorted) { - long expected = candidate.lastSelectedTime(); - // CAS: claim this worker only if lastSelectedTime hasn't changed since we read it. - // A failed CAS means another concurrent request already claimed this worker. - if (candidate.worker().getLastSelectedTime().compareAndSet(expected, now)) { - return candidate; - } - // Another request claimed this worker; try the next candidate - } - - // All candidates were claimed concurrently; fall back to the first candidate - return fallbackCandidates.getFirst(); - } - - /** - * Build server status response - * - * @param selectedWorker Selected worker - * @param roleType Worker role type - * @param requestId Request ID - * @return Server status - */ - private ServerStatus buildServerStatus(ScoredWorker selectedWorker, RoleType roleType, long requestId) { - WorkerStatus workerStatus = selectedWorker.worker(); - ServerStatus result = new ServerStatus(); - try { - result.setSuccess(true); - result.setRole(roleType); - result.setRequestId(requestId); - result.setPrefillTime(selectedWorker.ttft()); - result.setGroup(workerStatus.getGroup()); - result.setServerIp(workerStatus.getIp()); - result.setHttpPort(workerStatus.getPort()); - result.setGrpcPort(CommonUtils.toGrpcPort(workerStatus.getPort())); - } catch (Exception e) { - Logger.error("Failed to build server status for requestId: {}", requestId, e); - result.setCode(StrategyErrorType.NO_AVAILABLE_WORKER.getErrorCode()); - result.setMessage(StrategyErrorType.NO_AVAILABLE_WORKER.getErrorMsg()); - result.setSuccess(false); - } - return result; - } - - /** - * Calculate prefix match length (number of cached tokens hit) - * - * @param workerStatus Worker status - * @param cacheMatchResults Cache match results - * @return Number of tokens hit - */ - private long calculatePrefixMatchLength(WorkerStatus workerStatus, Map cacheMatchResults) { - if (workerStatus.getCacheStatus() == null || cacheMatchResults == null) { - return 0L; - } - - Integer prefixMatchLength = cacheMatchResults.get(workerStatus.getIpPort()); - if (prefixMatchLength == null) { - return 0L; - } - - long blockSize = workerStatus.getCacheStatus().getBlockSize(); - return blockSize * prefixMatchLength; - } - - private long calculateRoutingCacheMatchTokens(Integer prefixMatchLength, Request request, WorkerStatus workerStatus) { - if (prefixMatchLength == null || prefixMatchLength <= 0 || request == null || request.getSeqLen() <= 0L) { - return 0L; - } - - // Page-RR routes one canonical key per virtual block, so the frontend sends - // cache_key_block_size as seq_size_per_block * cp_size in that mode. - long blockSize = request.getCacheKeyBlockSize(); - if (blockSize <= 0L && workerStatus != null && workerStatus.getCacheStatus() != null) { - blockSize = workerStatus.getCacheStatus().getBlockSize(); - } - if (blockSize <= 0L) { - return 0L; - } - - long hitTokens = blockSize * prefixMatchLength; - if (hitTokens < 0L) { - return request.getSeqLen(); - } - return Math.min(request.getSeqLen(), hitTokens); - } - - private String engineIp(String engineIpPort) { - if (engineIpPort == null) { - return ""; - } - int delimiter = engineIpPort.indexOf(':'); - return delimiter < 0 ? engineIpPort : engineIpPort.substring(0, delimiter); - } -} diff --git a/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/balance/strategy/WeightedCacheLoadBalancer.java b/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/balance/strategy/WeightedCacheLoadBalancer.java deleted file mode 100644 index 18b0ebfd44..0000000000 --- a/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/balance/strategy/WeightedCacheLoadBalancer.java +++ /dev/null @@ -1,239 +0,0 @@ -package org.flexlb.balance.strategy; - -import org.apache.commons.collections4.CollectionUtils; -import org.apache.commons.collections4.MapUtils; -import org.flexlb.balance.resource.ResourceMeasure; -import org.flexlb.balance.resource.ResourceMeasureFactory; -import org.flexlb.config.ConfigService; -import org.flexlb.config.FlexlbConfig; -import org.flexlb.dao.BalanceContext; -import org.flexlb.dao.loadbalance.Request; -import org.flexlb.dao.loadbalance.ServerStatus; -import org.flexlb.dao.loadbalance.StrategyErrorType; -import org.flexlb.dao.master.CacheStatus; -import org.flexlb.dao.master.TaskInfo; -import org.flexlb.dao.master.WorkerStatus; -import org.flexlb.dao.route.RoleType; -import org.flexlb.enums.LoadBalanceStrategyEnum; -import org.flexlb.enums.ResourceMeasureIndicatorEnum; -import org.flexlb.sync.status.EngineWorkerStatus; -import org.flexlb.util.CommonUtils; -import org.flexlb.util.Logger; -import org.springframework.stereotype.Component; - -import java.util.ArrayList; -import java.util.Comparator; -import java.util.List; -import java.util.Map; -import java.util.Set; -import java.util.concurrent.ThreadLocalRandom; - -/** - * @author saichen.sm - * description: Weighted random load balancing strategy based on normalized cache usage - * Performs weighted random selection by normalizing cache usage across all workers - * date: 2025/3/21 - */ -@Component("weightedCacheStrategy") -public class WeightedCacheLoadBalancer implements LoadBalancer { - - private final EngineWorkerStatus engineWorkerStatus; - private final double decayFactor; - private final ResourceMeasureFactory resourceMeasureFactory; - - public WeightedCacheLoadBalancer(ConfigService configService, - EngineWorkerStatus engineWorkerStatus, - ResourceMeasureFactory resourceMeasureFactory) { - this.engineWorkerStatus = engineWorkerStatus; - FlexlbConfig config = configService.loadBalanceConfig(); - this.decayFactor = config.getWeightedCacheDecayFactor(); - this.resourceMeasureFactory = resourceMeasureFactory; - LoadBalanceStrategyFactory.register(LoadBalanceStrategyEnum.WEIGHTED_CACHE, this); - } - - private record WeightedWorker(WorkerStatus worker, long normalizedCacheUsed, double weight) { - } - - @Override - public ServerStatus select(BalanceContext balanceContext, RoleType roleType, String group) { - Request request = balanceContext.getRequest(); - long seqLen = request.getSeqLen(); - Map workerStatusMap = engineWorkerStatus.selectModelWorkerStatus(roleType, group); - if (MapUtils.isEmpty(workerStatusMap)) { - Logger.warn("select ROLE: {} failed, workerStatusMap is empty", roleType.getCode()); - return ServerStatus.code(StrategyErrorType.NO_AVAILABLE_WORKER); - } - FlexlbConfig config = balanceContext.getConfig(); - ResourceMeasureIndicatorEnum indicator = config.getResourceMeasureIndicator(roleType); - ResourceMeasure resourceMeasure = resourceMeasureFactory.getMeasure(indicator); - if (resourceMeasure == null) { - Logger.warn("No ResourceMeasure registered for indicator: {}, roleType: {}", indicator, roleType); - return ServerStatus.code(StrategyErrorType.NO_AVAILABLE_WORKER); - } - List workerStatusList = new ArrayList<>(workerStatusMap.values()).stream() - .filter(WorkerStatus::isAlive) - .filter(resourceMeasure::isResourceAvailable) - .toList(); - if (CollectionUtils.isEmpty(workerStatusList)) { - Logger.warn("select ROLE: {} failed, workerStatusList is empty", roleType.getCode()); - return ServerStatus.code(StrategyErrorType.NO_AVAILABLE_WORKER); - } - - // Implement weighted random selection algorithm - WorkerStatus selectedWorker = weightedRandomSelection(workerStatusList); - - if (selectedWorker != null) { - long prefixLength = calcPrefixMatchLength(selectedWorker.getCacheStatus(), balanceContext.getRequest().getBlockCacheKeys()); - // Update local task state - return buildServerStatus(selectedWorker, seqLen, prefixLength, roleType, balanceContext.getRequestId()); - } - - // Return failure if no suitable worker found - Logger.warn("Failed to select worker, no suitable worker available"); - return ServerStatus.code(StrategyErrorType.NO_AVAILABLE_WORKER); - } - - /** - * Release local cached tasks on the specified worker - * - * @param ipPort Worker IP address - * @param requestId Request ID - */ - @Override - public void rollBack(String ipPort, long requestId) { - - Map workerStatusMap = engineWorkerStatus.selectModelWorkerStatus(RoleType.DECODE, null); - Logger.debug("Decode rollBack - ip: {}, requestId: {}", - ipPort, requestId); - - WorkerStatus workerStatus = workerStatusMap.get(ipPort); - if (workerStatus != null) { - workerStatus.removeLocalTask(requestId); - } - } - - private long calcPrefixMatchLength(CacheStatus cacheStatus, List promptCacheKeys) { - - if (cacheStatus == null || promptCacheKeys == null) { - return 0; - } - long blockSize = cacheStatus.getBlockSize(); - Set cachePrefixHash = cacheStatus.getCachedKeys(); - if (cachePrefixHash == null) { - return 0; - } - - // Iterate from beginning to find first mismatch position - for (int index = 0; index < promptCacheKeys.size(); index++) { - long hash = promptCacheKeys.get(index); - if (!cachePrefixHash.contains(hash)) { - // Return matching prefix length (matched block count * block size) - return blockSize * index; - } - } - - // Return total length if all match - return blockSize * promptCacheKeys.size(); - } - - /** - * Weighted random selection algorithm: performs weighted random selection based on normalized cache usage - * - * @param candidateWorkers Candidate worker list - * @return Selected WorkerStatus, or null if no suitable worker found - */ - private WorkerStatus weightedRandomSelection(List candidateWorkers) { - int workerCount = candidateWorkers.size(); - if (workerCount == 0) { - return null; - } - - // 1. Calculate sum and average of cacheUsed - long totalCacheUsed = 0; - for (WorkerStatus worker : candidateWorkers) { - totalCacheUsed += worker.getUsedKvCacheTokens().get(); - } - double avgCacheUsed = (double) totalCacheUsed / workerCount; - - // 2. Normalize cacheUsed and calculate weights - List weightedWorkers = new ArrayList<>(); - boolean allSameUsage = true; - double totalWeight = 0; - Long firstCacheUsed = null; - - for (WorkerStatus worker : candidateWorkers) { - long cacheUsed = worker.getUsedKvCacheTokens().get(); - double normalizedValue = cacheUsed - avgCacheUsed; - - if (firstCacheUsed == null) { - firstCacheUsed = cacheUsed; - } else if (cacheUsed != firstCacheUsed) { - allSameUsage = false; - } - - double weight = Math.exp(-decayFactor * normalizedValue); - - weightedWorkers.add(new WeightedWorker(worker, (long) normalizedValue, weight)); - totalWeight += weight; - } - - // Check if total weight is valid - if (totalWeight <= 0) { - Logger.warn("Total weight is zero or negative: {}, using uniform random selection", totalWeight); - int randomIndex = ThreadLocalRandom.current().nextInt(workerCount); - return candidateWorkers.get(randomIndex); - } - - // If all workers have same cache usage, use uniform random - if (allSameUsage) { - int randomIndex = ThreadLocalRandom.current().nextInt(workerCount); - return candidateWorkers.get(randomIndex); - } - - // 3. Perform weighted random selection using roulette wheel algorithm - double randomValue = ThreadLocalRandom.current().nextDouble() * totalWeight; - double cumulativeWeight = 0; - - for (WeightedWorker weightedWorker : weightedWorkers) { - cumulativeWeight += weightedWorker.weight; - if (Double.compare(randomValue, cumulativeWeight) <= 0) { - return weightedWorker.worker; - } - } - - // Fallback: select worker with minimum cacheUsed - return weightedWorkers.stream() - .min(Comparator.comparingLong(w -> w.worker.getUsedKvCacheTokens().get())) - .map(w -> w.worker) - .orElse(null); - } - - private ServerStatus buildServerStatus(WorkerStatus optimalWorker, long seqLen, long prefixLength, RoleType roleType, long requestId) { - ServerStatus result = new ServerStatus(); - try { - TaskInfo taskInfo = new TaskInfo(); - taskInfo.setPrefillTime(0); - taskInfo.setWaitingTime(0); - taskInfo.setInputLength(seqLen); - taskInfo.setPrefixLength(prefixLength); - taskInfo.setRequestId(requestId); - - // Update local task state - optimalWorker.putLocalTask(requestId, taskInfo); - - result.setSuccess(true); - result.setRole(roleType); - result.setServerIp(optimalWorker.getIp()); - result.setHttpPort(optimalWorker.getPort()); - result.setGrpcPort(CommonUtils.toGrpcPort(optimalWorker.getPort())); - result.setGroup(optimalWorker.getGroup()); - result.setRequestId(requestId); - } catch (Exception e) { - Logger.error("buildServerStatus error", e); - result.setSuccess(false); - result.setCode(StrategyErrorType.NO_AVAILABLE_WORKER.getErrorCode()); - result.setMessage(StrategyErrorType.NO_AVAILABLE_WORKER.getErrorMsg()); - } - return result; - } -} diff --git a/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/consistency/LBStatusConsistencyService.java b/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/consistency/LBStatusConsistencyService.java index 381cab011b..a891a3bda9 100644 --- a/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/consistency/LBStatusConsistencyService.java +++ b/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/consistency/LBStatusConsistencyService.java @@ -9,8 +9,11 @@ import org.flexlb.domain.consistency.SyncLBStatusResp; import org.flexlb.util.JsonUtils; import org.flexlb.util.Logger; +import org.springframework.core.env.Environment; import org.springframework.stereotype.Component; +import javax.annotation.PreDestroy; + import java.net.InetAddress; import java.net.UnknownHostException; import java.util.concurrent.ScheduledExecutorService; @@ -30,12 +33,15 @@ public class LBStatusConsistencyService implements MasterElectService { ); private final ZookeeperMasterElectService zookeeperMasterElectService; + private final Environment environment; private LBConsistencyConfig lbConsistencyConfig; private String serverPort; private String roleId; - public LBStatusConsistencyService(ZookeeperMasterElectService zookeeperMasterElectService) { + public LBStatusConsistencyService(ZookeeperMasterElectService zookeeperMasterElectService, + Environment environment) { this.zookeeperMasterElectService = zookeeperMasterElectService; + this.environment = environment; this.init(); } @@ -47,7 +53,12 @@ public void init() { } catch (UnknownHostException e) { throw new RuntimeException(e); } - serverPort = System.getProperty("server.port", "7001"); + // Read from Spring Environment to respect --server.port= CLI args; + // fall back to JVM system property. + serverPort = environment.getProperty("server.port"); + if (serverPort == null) { + serverPort = System.getProperty("server.port", "7001"); + } log.info("hostIp:{}, serverPort:{}.", hostIp, serverPort); roleId = System.getenv("HIPPO_ROLE"); if (StringUtils.isBlank(roleId)) { @@ -161,4 +172,18 @@ public SyncLBStatusResp dumpLBStatus() { private void syncLBStatusFromMaster() { // TODO Get master status } + + @PreDestroy + public void shutdown() { + log.info("Shutting down LBStatusConsistencyService executor."); + SCHEDULED_EXECUTOR_SERVICE.shutdown(); + try { + if (!SCHEDULED_EXECUTOR_SERVICE.awaitTermination(5, TimeUnit.SECONDS)) { + SCHEDULED_EXECUTOR_SERVICE.shutdownNow(); + } + } catch (InterruptedException e) { + SCHEDULED_EXECUTOR_SERVICE.shutdownNow(); + Thread.currentThread().interrupt(); + } + } } diff --git a/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/consistency/ZookeeperMasterElectService.java b/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/consistency/ZookeeperMasterElectService.java index 82512c5f4d..f69f196e1d 100644 --- a/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/consistency/ZookeeperMasterElectService.java +++ b/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/consistency/ZookeeperMasterElectService.java @@ -23,6 +23,7 @@ import org.flexlb.util.JsonUtils; import org.flexlb.util.Logger; import org.slf4j.LoggerFactory; +import org.springframework.core.env.Environment; import org.springframework.scheduling.annotation.Scheduled; import org.springframework.stereotype.Component; import reactor.core.publisher.Mono; @@ -50,6 +51,7 @@ public class ZookeeperMasterElectService implements LeaderSelectorListener { private LBConsistencyConfig lbConsistencyConfig; private final GeneralHttpNettyService generalHttpNettyService; private final EngineHealthReporter engineHealthReporter; + private final Environment environment; @Setter private String roleId; @Setter @@ -69,12 +71,14 @@ public class ZookeeperMasterElectService implements LeaderSelectorListener { private final AtomicReference leaderCloseLatchRef = new AtomicReference<>(); public ZookeeperMasterElectService(GeneralHttpNettyService generalHttpNettyService, - EngineHealthReporter engineHealthReporter) { + EngineHealthReporter engineHealthReporter, + Environment environment) { Logger.warn("Initializing ZookeeperMasterElectService..."); this.generalHttpNettyService = generalHttpNettyService; this.engineHealthReporter = engineHealthReporter; + this.environment = environment; init(); } @@ -105,7 +109,13 @@ private void initializeIpAndPort() { } catch (UnknownHostException e) { throw new RuntimeException("Failed to retrieve local host address", e); } - port = Integer.parseInt(System.getProperty("server.port", "7001")); + // Read from Spring Environment to respect --server.port= CLI args; + // fall back to JVM system property. + String portStr = environment.getProperty("server.port"); + if (portStr == null) { + portStr = System.getProperty("server.port", "7001"); + } + port = Integer.parseInt(portStr); } private void initializeLBConsistencyConfig() { @@ -203,7 +213,8 @@ private void waitForLeadershipTransfer() { LOGGER.warn("ZKMasterElector roleId:{} currentHost:{} waiting for leadership transfer to complete.", roleId, localIp); int waitCount = 0; - while (true) { + final int MAX_WAIT_COUNT = 30; // 30 seconds max + while (waitCount < MAX_WAIT_COUNT) { try { if (!isStillMaster()) { LOGGER.warn("ZKMasterElector roleId:{} currentHost:{} leadership transferred to {}, waitCount: {}.", @@ -231,6 +242,8 @@ private void waitForLeadershipTransfer() { } } } + LOGGER.warn("ZKMasterElector roleId:{} currentHost:{} leadership transfer timeout after {} seconds, forcing exit.", + roleId, localIp, MAX_WAIT_COUNT); } private boolean isStillMaster() { @@ -358,7 +371,7 @@ private void activelyNotifyParticipants() { Collection participants = leaderSelector.getParticipants(); for (Participant participant : participants) { // Only notify non-master participants - if (!participant.isLeader() && localIp.equals(participant.getId())) { + if (!participant.isLeader() && !localIp.equals(participant.getId())) { notifyParticipant(participant.getId()); } } diff --git a/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/domain/worker/ScoredWorker.java b/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/domain/worker/ScoredWorker.java deleted file mode 100644 index 8e03ceebe7..0000000000 --- a/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/domain/worker/ScoredWorker.java +++ /dev/null @@ -1,5 +0,0 @@ -package org.flexlb.domain.worker; - -import org.flexlb.dao.master.WorkerStatus; - -public record ScoredWorker(WorkerStatus worker, long ttft, long hitCacheTokens, long lastSelectedTime) {} diff --git a/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/service/RouteService.java b/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/service/RouteService.java index 1855736b8d..7071ca3c31 100644 --- a/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/service/RouteService.java +++ b/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/service/RouteService.java @@ -4,12 +4,17 @@ import java.util.concurrent.CompletableFuture; import org.flexlb.balance.scheduler.DefaultRouter; +import org.flexlb.balance.scheduler.FlexlbBatchScheduler; import org.flexlb.balance.scheduler.QueueManager; import org.flexlb.balance.scheduler.Router; import org.flexlb.config.ConfigService; import org.flexlb.config.FlexlbConfig; import org.flexlb.dao.BalanceContext; +import org.flexlb.dao.loadbalance.Request; import org.flexlb.dao.loadbalance.Response; +import org.flexlb.enums.ScheduleModeEnum; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.context.annotation.Lazy; import org.springframework.stereotype.Component; import reactor.core.publisher.Mono; @@ -19,15 +24,18 @@ public class RouteService { private final ConfigService configService; private final Router router; private final QueueManager queueManager; + private final FlexlbBatchScheduler flexlbBatchScheduler; private final RecentCacheKeyTraceReporter recentCacheKeyTraceReporter; public RouteService(ConfigService configService, DefaultRouter defaultScheduler, QueueManager queueManager, + @Lazy @Autowired(required = false) FlexlbBatchScheduler flexlbBatchScheduler, RecentCacheKeyTraceReporter recentCacheKeyTraceReporter) { this.configService = configService; this.router = defaultScheduler; this.queueManager = queueManager; + this.flexlbBatchScheduler = flexlbBatchScheduler; this.recentCacheKeyTraceReporter = recentCacheKeyTraceReporter; } @@ -41,7 +49,11 @@ public Mono route(BalanceContext balanceContext) { balanceContext.setConfig(flexlbConfig); Mono resultMono; - if (flexlbConfig.isEnableQueueing()) { + if (shouldUseFlexlbBatch(balanceContext, flexlbConfig)) { + CompletableFuture future = flexlbBatchScheduler.submit(balanceContext); + balanceContext.setFuture(future); + resultMono = Mono.fromFuture(future); + } else if (flexlbConfig.isEnableQueueing()) { resultMono = queueManager.tryRouteAsync(balanceContext); // Use async queuing mechanism } else { resultMono = Mono.fromCallable(() -> router.route(balanceContext)); // Direct routing without queuing @@ -61,14 +73,47 @@ public Mono route(BalanceContext balanceContext) { */ public void cancel(BalanceContext balanceContext) { FlexlbConfig flexlbConfig = configService.loadBalanceConfig(); + balanceContext.cancel(); if (flexlbConfig.isEnableQueueing()) { - balanceContext.cancel(); CompletableFuture future = balanceContext.getFuture(); if (future != null) { future.completeExceptionally(new CancellationException("Request cancelled by client")); } } + if (flexlbBatchScheduler != null && balanceContext.getRequest() != null) { + flexlbBatchScheduler.cancel(balanceContext.getRequest().getRequestId()); + } balanceContext.setSuccess(false); balanceContext.setErrorMessage("request cancelled"); } + + public void cancelByRequestId(long requestId) { + if (flexlbBatchScheduler != null) { + flexlbBatchScheduler.cancel(requestId); + } + } + + boolean shouldUseFlexlbBatch(BalanceContext ctx, FlexlbConfig config) { + if (flexlbBatchScheduler == null || config == null) { + return false; + } + ScheduleModeEnum mode = ctx.getScheduleMode(); + if (mode == ScheduleModeEnum.BATCH) { + return true; + } + if (mode == ScheduleModeEnum.DIRECT) { + return false; + } + // AUTO: use batch when config enables it and request characteristics match + if (!config.isFlexlbBatchEnabled()) { + return false; + } + Request request = ctx.getRequest(); + return request != null + && request.getMaxNewTokens() > 1 + && request.getNumBeams() <= 1 + && !request.isForceDisableSpRun() + && ctx.getGenerateInputPbBytes() != null + && ctx.getGenerateInputPbBytes().length > 0; + } } diff --git a/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/service/grpc/EngineGrpcService.java b/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/service/grpc/EngineGrpcService.java index aebd24229d..59fe3f3f4b 100644 --- a/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/service/grpc/EngineGrpcService.java +++ b/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/service/grpc/EngineGrpcService.java @@ -86,7 +86,7 @@ public EngineRpcService.CacheStatusPB getCacheStatus( throw new RuntimeException("EngineGrpcService not initialized"); } // Only need cacheKeys for Prefill nodes in PD-separated mode and non PD-separated mode - boolean needCacheKeys = RoleType.PREFILL.matches(workerStatus.getRole()) || RoleType.PDFUSION.matches(workerStatus.getRole()); + boolean needCacheKeys = workerStatus.getRole() == RoleType.PREFILL || workerStatus.getRole() == RoleType.PDFUSION; EngineRpcService.CacheVersionPB request = EngineRpcService.CacheVersionPB.newBuilder() .setLatestCacheVersion((int) cacheVersion) .setNeedCacheKeys(needCacheKeys) @@ -111,7 +111,7 @@ public EngineRpcService.CacheStatusPB getCacheStatus( if (engineGrpcClient == null) { throw new RuntimeException("EngineGrpcService not initialized"); } - boolean needCacheKeys = RoleType.PREFILL.matches(workerStatus.getRole()) || RoleType.PDFUSION.matches(workerStatus.getRole()); + boolean needCacheKeys = workerStatus.getRole() == RoleType.PREFILL || workerStatus.getRole() == RoleType.PDFUSION; EngineRpcService.CacheVersionPB request = EngineRpcService.CacheVersionPB.newBuilder() .setLatestCacheVersion((int) cacheVersion) .setNeedCacheKeys(needCacheKeys) diff --git a/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/service/grpc/EngineStatusConverter.java b/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/service/grpc/EngineStatusConverter.java index 342d6c4772..e258a210cb 100644 --- a/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/service/grpc/EngineStatusConverter.java +++ b/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/service/grpc/EngineStatusConverter.java @@ -2,8 +2,10 @@ import org.flexlb.dao.master.CacheStatus; import org.flexlb.dao.master.TaskInfo; -import org.flexlb.domain.worker.WorkerStatusResponse; +import org.flexlb.dao.master.WorkerStatusResponse; import org.flexlb.engine.grpc.EngineRpcService; +import org.flexlb.engine.grpc.RoleTypeProtoConverter; +import org.flexlb.enums.TaskPhase; import java.util.HashMap; import java.util.List; @@ -21,8 +23,8 @@ public class EngineStatusConverter { public static WorkerStatusResponse convertToWorkerStatusResponse(EngineRpcService.WorkerStatusPB workerStatusPB) { WorkerStatusResponse response = new WorkerStatusResponse(); - // Set role directly as string - response.setRole(workerStatusPB.getRole()); + // Convert proto enum to RoleType + response.setRole(RoleTypeProtoConverter.fromProto(workerStatusPB.getRole())); response.setAvailableConcurrency(workerStatusPB.getAvailableConcurrency()); response.setRunningQueryLen(workerStatusPB.getRunningQueryLen()); response.setWaitingQueryLen(workerStatusPB.getWaitingQueryLen()); @@ -30,19 +32,14 @@ public static WorkerStatusResponse convertToWorkerStatusResponse(EngineRpcServic response.setIterateCount(workerStatusPB.getIterateCount()); response.setDpSize(workerStatusPB.getDpSize()); response.setTpSize(workerStatusPB.getTpSize()); + response.setDpRank(workerStatusPB.getDpRank()); response.setStatusVersion(workerStatusPB.getStatusVersion()); response.setLatestFinishedVersion(workerStatusPB.getLatestFinishedVersion()); response.setAlive(workerStatusPB.getAlive()); + response.setAvailableKvCacheTokens(workerStatusPB.getAvailableKvCache()); + response.setTotalKvCacheTokens(workerStatusPB.getTotalKvCache()); - List srcRunningTaskInfoList = workerStatusPB.getRunningTaskInfoList(); - List waitingTaskInfoList = srcRunningTaskInfoList.stream().filter(taskInfoPB -> taskInfoPB.getIsWaiting()).toList(); - List runningTaskInfoList = srcRunningTaskInfoList.stream().filter(taskInfoPB -> !taskInfoPB.getIsWaiting()).toList(); - - // Convert waiting task info - response.setWaitingTaskInfo(convertToTaskInfoList(waitingTaskInfoList)); - - // Convert running task info - response.setRunningTaskInfo(convertToTaskInfoList(runningTaskInfoList)); + response.setRunningTaskInfo(convertToTaskInfoList(workerStatusPB.getRunningTaskInfoList())); // Convert finished task list response.setFinishedTaskInfo(convertToTaskInfoList(workerStatusPB.getFinishedTaskListList())); @@ -85,10 +82,29 @@ private static Map convertToTaskInfoList(ListReuses existing dashboard metric keys so batch-path data appears + * on the same Grafana panels as the non-batch path: + * queue (routing.queue.length + routing.queue.wait.time.ms), + * dispatch reason (engine.balancing.master.select.detail), + * inflight (health.check.local.task.map.size + health.check.running.task.info.size). + */ +@Slf4j +@Component +public class BatchSchedulerReporter { + + private final FlexMonitor monitor; + + @Autowired + public BatchSchedulerReporter(FlexMonitor monitor) { + this.monitor = monitor; + } + + @PostConstruct + public void init() { + // Queue — same type as RoutingQueueReporter + monitor.register(ROUTING_QUEUE_LENGTH, FlexMetricType.GAUGE, FlexPriorityType.PRECISE); + monitor.register(ROUTING_QUEUE_WAIT_TIME_MS, FlexMetricType.GAUGE, FlexPriorityType.PRECISE); + + // Dispatch reason — same type as EngineHealthReporter + monitor.register(ENGINE_BALANCING_MASTER_SELECT_DETAIL, FlexMetricType.QPS, FlexPriorityType.PRECISE); + + // Inflight — same type as EngineHealthReporter + monitor.register(ENGINE_LOCAL_TASK_MAP_SIZE, FlexMetricType.GAUGE, FlexPriorityType.PRECISE); + monitor.register(ENGINE_RUNNING_TASK_INFO_SIZE, FlexMetricType.GAUGE, FlexPriorityType.PRECISE); + + log.info("BatchSchedulerReporter initialized (5 metrics reusing existing dashboard keys)"); + } + + // ==================== Queue metrics ==================== + + /** + * Report per-worker batcher queue depth via {@code routing.queue.length}. + */ + public void reportBatcherQueueDepth(String role, String engineIp, int depth) { + FlexMetricTags tags = FlexMetricTags.of( + "type", "batchQueue", + "role", role, + "engineIp", engineIp); + monitor.report(ROUTING_QUEUE_LENGTH, tags, depth); + } + + /** + * Report batch wait time (enqueue to dispatch) via {@code routing.queue.wait.time.ms}. + */ + public void reportBatchWaitTimeMs(String role, String engineIp, long waitMs) { + FlexMetricTags tags = FlexMetricTags.of( + "role", role, + "engineIp", engineIp); + monitor.report(ROUTING_QUEUE_WAIT_TIME_MS, tags, waitMs); + } + + // ==================== Dispatch reason metrics ==================== + + /** + * Report batch dispatch reason via {@code engine.balancing.master.select.detail}. + */ + public void reportDispatchReason(String role, String engineIp, String reason) { + FlexMetricTags tags = FlexMetricTags.of( + "role", role, + "engineIp", engineIp, + "reason", reason); + monitor.report(ENGINE_BALANCING_MASTER_SELECT_DETAIL, tags, 1.0); + } + + // ==================== Inflight metrics ==================== + + /** + * Report scheduler inflight map size via {@code health.check.local.task.map.size}. + * Uses role=prefill + engineIp=scheduler tags to match the Grafana panel filter. + */ + public void reportSchedulerInflightSize(int size) { + FlexMetricTags tags = FlexMetricTags.of( + "role", "prefill", + "engineIp", "scheduler"); + monitor.report(ENGINE_LOCAL_TASK_MAP_SIZE, tags, size); + } + + /** + * Report per-worker prefilled endpoint inflight batch count via {@code health.check.running.task.info.size}. + */ + public void reportPrefillInflightBatchCount(String role, String engineIp, int count) { + FlexMetricTags tags = FlexMetricTags.of( + "role", role, + "engineIp", engineIp); + monitor.report(ENGINE_RUNNING_TASK_INFO_SIZE, tags, count); + } +} diff --git a/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/service/monitor/EngineHealthReporter.java b/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/service/monitor/EngineHealthReporter.java index 5a14dd730f..0506383df8 100644 --- a/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/service/monitor/EngineHealthReporter.java +++ b/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/service/monitor/EngineHealthReporter.java @@ -5,6 +5,7 @@ import io.netty.util.concurrent.SingleThreadEventExecutor; import lombok.Data; import org.apache.commons.collections4.CollectionUtils; +import org.flexlb.balance.endpoint.WorkerEndpoint; import org.flexlb.cache.monitor.CacheMetricsReporter; import org.flexlb.constant.ZkMasterEvent; import org.flexlb.dao.BalanceContext; @@ -219,6 +220,7 @@ public void reportCacheStatusCheckerFail(String modelName, String engineIp, Bala public void reportStatusCheckerSuccess(String modelName, WorkerStatus workerStatus, + WorkerEndpoint ep, int runningTaskInfoSize, int finishedTaskListSize) { @@ -226,7 +228,7 @@ public void reportStatusCheckerSuccess(String modelName, "model", modelName, "code", "0", "engineIp", workerStatus.getIp(), - "role", workerStatus.getRole()); + "role", workerStatus.getRole().name()); Long availableConcurrency = workerStatus.getAvailableConcurrency(); if (availableConcurrency != null) { @@ -236,15 +238,14 @@ public void reportStatusCheckerSuccess(String modelName, if (lastUpdateTime > 0) { monitor.report(ENGINE_STATUS_CHECK_SUCCESS_PERIOD, metricTags, (double) System.nanoTime() / 1000 - lastUpdateTime); } - monitor.report(ENGINE_RUNNING_QUEUE_TIME, metricTags, workerStatus.getRunningQueueTime().get()); + monitor.report(ENGINE_RUNNING_QUEUE_TIME, metricTags, ep != null ? ep.getLoadMetric() : 0); // Report local task cache size - int localTaskMapSize = workerStatus.getLocalTaskMap() != null ? workerStatus.getLocalTaskMap().size() : 0; - monitor.report(ENGINE_LOCAL_TASK_MAP_SIZE, metricTags, localTaskMapSize); + monitor.report(ENGINE_LOCAL_TASK_MAP_SIZE, metricTags, ep != null ? ep.getLocalTaskCount() : 0); metricTags = FlexMetricTags.of( "engineIp", workerStatus.getIp(), - "role", workerStatus.getRole()); + "role", workerStatus.getRole().name()); monitor.report(ENGINE_FINISHED_TASK_LIST_SIZE, metricTags, finishedTaskListSize); monitor.report(ENGINE_RUNNING_TASK_INFO_SIZE, metricTags, runningTaskInfoSize); @@ -257,7 +258,7 @@ public void reportCacheStatusCheckerSuccess(String modelName, WorkerStatus worke "model", modelName, "code", "0", "engineIp", workerStatus.getIp(), - "role", workerStatus.getRole()); + "role", workerStatus.getRole().name()); monitor.report(CACHE_STATUS_CHECK_SUCCESS_PERIOD, metricTags, (double) System.nanoTime() / 1000 - cacheLastUpdateTime); } if (workerStatus.getCacheStatus() != null) { @@ -266,19 +267,19 @@ public void reportCacheStatusCheckerSuccess(String modelName, WorkerStatus worke FlexMetricTags metricTags = FlexMetricTags.of( "model", modelName, "engineIp", workerStatus.getIp(), - "role", workerStatus.getRole()); + "role", workerStatus.getRole().name()); monitor.report(CACHE_BLOCK_SIZE, metricTags, blockSize); monitor.report(CACHE_KEY_SIZE, metricTags, cacheKeySize); } - long usedKvCacheTokens = workerStatus.getUsedKvCacheTokens().get(); + long totalKvCacheTokens = workerStatus.getTotalKvCacheTokens().get(); long availableKvCacheTokens = workerStatus.getAvailableKvCacheTokens().get(); - long totalKvCacheTokens = usedKvCacheTokens + availableKvCacheTokens; + long usedKvCacheTokens = totalKvCacheTokens - availableKvCacheTokens; FlexMetricTags kvCacheMetricTags = FlexMetricTags.of( "model", modelName, "engineIp", workerStatus.getIp(), - "role", workerStatus.getRole()); + "role", workerStatus.getRole().name()); monitor.report(CACHE_USED_KV_CACHE_TOKENS, kvCacheMetricTags, usedKvCacheTokens); monitor.report(CACHE_AVAILABLE_KV_CACHE_TOKENS, kvCacheMetricTags, availableKvCacheTokens); diff --git a/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/sync/runner/EngineSyncRunner.java b/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/sync/runner/EngineSyncRunner.java index d8b2c52825..5eefa63a5f 100644 --- a/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/sync/runner/EngineSyncRunner.java +++ b/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/sync/runner/EngineSyncRunner.java @@ -1,10 +1,14 @@ package org.flexlb.sync.runner; +import org.flexlb.balance.endpoint.EndpointRegistry; +import org.flexlb.balance.endpoint.WorkerEndpoint; +import org.flexlb.balance.scheduler.FlexlbBatchScheduler; import org.flexlb.cache.service.CacheAwareService; import org.flexlb.dao.master.WorkerHost; import org.flexlb.dao.master.WorkerStatus; import org.flexlb.dao.route.RoleType; import org.flexlb.enums.BalanceStatusEnum; +import org.flexlb.util.CommonUtils; import org.flexlb.service.address.WorkerAddressService; import org.flexlb.service.grpc.EngineGrpcService; import org.flexlb.service.monitor.EngineHealthReporter; @@ -45,6 +49,10 @@ public class EngineSyncRunner implements Runnable { private final Long syncEngineStatusInterval; + private final FlexlbBatchScheduler batchScheduler; + + private final EndpointRegistry endpointRegistry; + public EngineSyncRunner(String modelName, Map workerStatusMap, WorkerAddressService workerAddressService, @@ -55,7 +63,9 @@ public EngineSyncRunner(String modelName, CacheAwareService localKvCacheAwareManager, long syncRequestTimeoutMs, LongAdder syncCount, - Long syncEngineStatusInterval) { + Long syncEngineStatusInterval, + FlexlbBatchScheduler batchScheduler, + EndpointRegistry endpointRegistry) { this.modelName = modelName; this.workerAddressService = workerAddressService; @@ -68,6 +78,8 @@ public EngineSyncRunner(String modelName, this.syncRequestTimeoutMs = syncRequestTimeoutMs; this.syncCount = syncCount; this.syncEngineStatusInterval = syncEngineStatusInterval; + this.batchScheduler = batchScheduler; + this.endpointRegistry = endpointRegistry; } @Override @@ -127,7 +139,7 @@ public void run() { GrpcWorkerStatusRunner grpcWorkerStatusRunner = new GrpcWorkerStatusRunner(modelName, workerIpPort, site, roleType, host.getGroup(), workerStatus, engineHealthReporter, engineGrpcService, - syncRequestTimeoutMs); + syncRequestTimeoutMs, batchScheduler, endpointRegistry); statusCheckExecutor.submit(grpcWorkerStatusRunner); } else { logger.info("Skip status check for worker: {}, previous request in progress", workerIpPort); @@ -157,25 +169,29 @@ public void run() { if (size >= 2) { double sumStepLatency = 0.0; - double sumRunningQueryTime = 0.0; - for (WorkerStatus workerStatus : workerStatusMap.values()) { + double sumRunningLoad = 0.0; + for (Map.Entry entry : workerStatusMap.entrySet()) { + WorkerStatus workerStatus = entry.getValue(); sumStepLatency += workerStatus.getStepLatencyMs(); - sumRunningQueryTime += workerStatus.getRunningQueueTime().get(); + WorkerEndpoint ep = endpointRegistry != null ? endpointRegistry.get(entry.getKey()) : null; + sumRunningLoad += ep != null ? ep.getLoadMetric() : 0; } double meanStepLatency = sumStepLatency / size; - double meanRunningQueryLen = sumRunningQueryTime / size; + double meanRunningLoad = sumRunningLoad / size; // Calculate variance (sample variance using Bessel correction) double sumStepLatencyOfSquaredDiffs = 0.0; - double sumRunningQueryLenOfSquaredDiffs = 0.0; - for (WorkerStatus workerStatus : workerStatusMap.values()) { + double sumRunningLoadOfSquaredDiffs = 0.0; + for (Map.Entry entry : workerStatusMap.entrySet()) { + WorkerStatus workerStatus = entry.getValue(); double diff = workerStatus.getStepLatencyMs() - meanStepLatency; - double diff2 = workerStatus.getRunningQueueTime().get() - meanRunningQueryLen; + WorkerEndpoint ep = endpointRegistry != null ? endpointRegistry.get(entry.getKey()) : null; + double diff2 = (ep != null ? ep.getLoadMetric() : 0) - meanRunningLoad; sumStepLatencyOfSquaredDiffs += diff * diff; - sumRunningQueryLenOfSquaredDiffs += diff2 * diff2; + sumRunningLoadOfSquaredDiffs += diff2 * diff2; } double variance = sumStepLatencyOfSquaredDiffs / (size - 1); // Sample variance - double variance2 = sumRunningQueryLenOfSquaredDiffs / (size - 1); + double variance2 = sumRunningLoadOfSquaredDiffs / (size - 1); engineHealthReporter.reportLatencyMetric(modelName, this.roleType.toString(), variance, variance2); logger.info("EngineSyncRunner finished for model: {}, role: {}", modelName, roleType); @@ -195,6 +211,32 @@ private WorkerStatus getOrCreateWorkerStatus(Map workerSta workerStatuses.put(workerIpPort, workerStatus); logger.info("Created new WorkerStatus for worker: {}", workerIpPort); } + if (endpointRegistry != null) { + ensureEndpoint(workerIpPort, workerStatus); + } return workerStatus; } + + private void ensureEndpoint(String ipPort, WorkerStatus workerStatus) { + String ip = workerStatus.getIp(); + int httpPort = workerStatus.getPort(); + int grpcPort = CommonUtils.toGrpcPort(httpPort); + workerStatus.setGrpcPort(grpcPort); + + if (roleType == RoleType.PREFILL) { + long dpSize = workerStatus.getDpSize(); + if (dpSize > 1) { + String message = String.format( + "Prefill DP group endpoint not yet supported: model=%s, ipPort=%s, dp_size=%d", + modelName, ipPort, dpSize); + logger.error(message); + throw new UnsupportedOperationException(message); + } + endpointRegistry.ensurePrefillEndpoint(ipPort, workerStatus); + } else if (roleType == RoleType.DECODE) { + endpointRegistry.ensureDecodeEndpoint(ipPort, workerStatus); + } else { + throw new IllegalArgumentException("Unsupported role type: " + roleType); + } + } } diff --git a/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/sync/runner/GrpcCacheStatusCheckRunner.java b/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/sync/runner/GrpcCacheStatusCheckRunner.java index 7aba169615..8fb08e3f40 100644 --- a/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/sync/runner/GrpcCacheStatusCheckRunner.java +++ b/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/sync/runner/GrpcCacheStatusCheckRunner.java @@ -131,14 +131,6 @@ private void handleCacheStatusResponse(CacheStatus newCacheStatus, long startTim engineHealthReporter.reportCacheStatusCheckRemoteInfo(modelName, ipPort, roleType.name(), startTime); - // Latest available KvCache tokens - long latestAvailableKvCacheTokens = newCacheStatus.getAvailableKvCache(); - // Latest used KvCache tokens - long latestUsedKvCacheTokens = newCacheStatus.getTotalKvCache() - latestAvailableKvCacheTokens; - - // Update KvCache tokens - workerStatus.updateKvCacheTokens(latestUsedKvCacheTokens, latestAvailableKvCacheTokens); - if (validateCacheStatusResponse(workerStatus, newCacheStatus)) { workerStatus.setCacheStatus(newCacheStatus); diff --git a/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/sync/runner/GrpcWorkerStatusRunner.java b/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/sync/runner/GrpcWorkerStatusRunner.java index a6c815eb4e..b45c567134 100644 --- a/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/sync/runner/GrpcWorkerStatusRunner.java +++ b/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/sync/runner/GrpcWorkerStatusRunner.java @@ -1,9 +1,12 @@ package org.flexlb.sync.runner; -import org.flexlb.dao.master.TaskInfo; +import org.flexlb.balance.endpoint.EndpointRegistry; +import org.flexlb.balance.endpoint.WorkerEndpoint; +import org.flexlb.balance.scheduler.FlexlbBatchScheduler; import org.flexlb.dao.master.WorkerStatus; import org.flexlb.dao.route.RoleType; -import org.flexlb.domain.worker.WorkerStatusResponse; +import org.flexlb.dao.master.WorkerStatusResponse; +import org.flexlb.dao.master.TaskInfo; import org.flexlb.engine.grpc.EngineRpcService; import org.flexlb.enums.BalanceStatusEnum; import org.flexlb.service.grpc.EngineGrpcService; @@ -16,7 +19,6 @@ import java.util.Map; import java.util.Optional; -import java.util.concurrent.atomic.AtomicLong; import static org.flexlb.constant.CommonConstants.DEADLINE_EXCEEDED_MESSAGE; @@ -32,17 +34,22 @@ public class GrpcWorkerStatusRunner implements Runnable { private final WorkerStatus workerStatus; private final EngineHealthReporter engineHealthReporter; private final EngineGrpcService engineGrpcService; + private final FlexlbBatchScheduler batchScheduler; private final String ip; private final int grpcPort; private final long createTimeUs = System.nanoTime() / 1000; private final String id = IdUtils.fastUuid(); private final long syncRequestTimeoutMs; + private static final int MAX_CONSECUTIVE_FAILURES = 3; + private final EndpointRegistry endpointRegistry; public GrpcWorkerStatusRunner(String modelName, String ipPort, String site, RoleType roleType, String group, WorkerStatus workerStatus, EngineHealthReporter engineHealthReporter, EngineGrpcService engineGrpcService, - long syncRequestTimeoutMs) { + long syncRequestTimeoutMs, + FlexlbBatchScheduler batchScheduler, + EndpointRegistry endpointRegistry) { this.ipPort = ipPort; String[] split = ipPort.split(":"); this.ip = split[0]; @@ -55,6 +62,8 @@ public GrpcWorkerStatusRunner(String modelName, String ipPort, String site, Role this.engineHealthReporter = engineHealthReporter; this.engineGrpcService = engineGrpcService; this.syncRequestTimeoutMs = syncRequestTimeoutMs; + this.batchScheduler = batchScheduler; + this.endpointRegistry = endpointRegistry; } @Override @@ -78,9 +87,14 @@ private WorkerStatusResponse launchGrpcStatusCheck(String ip, int grpcPort, long return EngineStatusConverter.convertToWorkerStatusResponse(workerStatusPB); } catch (Throwable throwable) { handleException(throwable); - WorkerStatusResponse errorResponse = new WorkerStatusResponse(); - errorResponse.setMessage("Worker status gRPC call failed: " + throwable.getMessage()); - return errorResponse; + long failures = workerStatus.getConsecutiveFailures().incrementAndGet(); + logger.error("gRPC status check failed, consecutiveFailures={}/{}, msg={}", + failures, MAX_CONSECUTIVE_FAILURES, throwable.getMessage()); + if (failures >= MAX_CONSECUTIVE_FAILURES) { + workerStatus.setAlive(false); + logger.error("worker {} marked dead after {} consecutive gRPC failures", ipPort, failures); + } + return null; } } @@ -92,14 +106,10 @@ private void handleStatusResponse(WorkerStatusResponse newWorkerStatus, long sta return; } - if (newWorkerStatus.getMessage() != null) { - workerStatus.setAlive(false); - logger.error("query engine worker status via gRPC, msg={}", newWorkerStatus.getMessage()); - return; - } + engineHealthReporter.reportStatusCheckRemoteInfo(modelName, ipPort, newWorkerStatus.getRole() != null ? newWorkerStatus.getRole().name() : "UNKNOWN", startTime); - // Only report success worker status check info - engineHealthReporter.reportStatusCheckRemoteInfo(modelName, ipPort, newWorkerStatus.getRole(), startTime); + // Reset consecutive failure counter on successful response + workerStatus.getConsecutiveFailures().set(0); Long responseVersion = newWorkerStatus.getStatusVersion(); if (responseVersion == 0L) { @@ -109,76 +119,55 @@ private void handleStatusResponse(WorkerStatusResponse newWorkerStatus, long sta workerStatus.setSite(site); workerStatus.setGroup(group); - workerStatus.setRole(newWorkerStatus.getRole()); - long currentVersion = workerStatus.getStatusVersion().get(); - if (currentVersion >= responseVersion) { - logger.info("query engine worker status via gRPC, version is not updated, currentVersion: {}, responseVersion: {}", - currentVersion, responseVersion); - // Update basic worker status even when version is not updated - workerStatus.setAlive(newWorkerStatus.isAlive()); - workerStatus.setDpSize(newWorkerStatus.getDpSize()); - workerStatus.setTpSize(newWorkerStatus.getTpSize()); - - // Update status timestamp and record actual sync interval - long nowUs = System.nanoTime() / 1000; - long prevUpdateTime = workerStatus.getStatusLastUpdateTime().get(); - if (prevUpdateTime > 0) { - workerStatus.getStatusUpdateIntervalUs().set(nowUs - prevUpdateTime); + // Debug: log received finished tasks details + Map finishedTaskInfo = newWorkerStatus.getFinishedTaskInfo(); + if (finishedTaskInfo != null && !finishedTaskInfo.isEmpty()) { + StringBuilder taskDetails = new StringBuilder(); + for (TaskInfo task : finishedTaskInfo.values()) { + taskDetails.append(" req_id=").append(task.getRequestId()) + .append(" batch_id=").append(task.getBatchId()) + .append(" error_code=").append(task.getErrorCode()) + .append("\n"); } - workerStatus.getStatusLastUpdateTime().set(nowUs); - - // Update task state - Map waitingTaskInfo = newWorkerStatus.getWaitingTaskInfo(); - Map runningTaskInfo = newWorkerStatus.getRunningTaskInfo(); - Map finishedTaskInfo = newWorkerStatus.getFinishedTaskInfo(); - workerStatus.setWaitingTaskList(waitingTaskInfo); - workerStatus.setRunningTaskList(runningTaskInfo); - workerStatus.updateTaskStates(waitingTaskInfo, runningTaskInfo, finishedTaskInfo); - workerStatus.updateRunningQueueTime(); - - // Report success even when version is not updated - engineHealthReporter.reportStatusCheckerSuccess(modelName, workerStatus, - Optional.ofNullable(runningTaskInfo).map(Map::size).orElse(0), - Optional.ofNullable(finishedTaskInfo).map(Map::size).orElse(0)); - - logWorkerStatusUpdate(startTime, workerStatus); - return; + logger.info("GetWorkerStatus received: latestFinishedVersion={}, finishedTasksCount={}\n{}", + newWorkerStatus.getLatestFinishedVersion(), + finishedTaskInfo.size(), + taskDetails); } - // Update worker status from gRPC response - workerStatus.setAvailableConcurrency(newWorkerStatus.getAvailableConcurrency()); - workerStatus.setStepLatencyMs(newWorkerStatus.getStepLatencyMs()); - workerStatus.setIterateCount(newWorkerStatus.getIterateCount()); - workerStatus.setDpSize(newWorkerStatus.getDpSize()); - workerStatus.setTpSize(newWorkerStatus.getTpSize()); - workerStatus.setAlive(newWorkerStatus.isAlive()); - workerStatus.getStatusVersion().set(responseVersion != null ? responseVersion : -1L); - workerStatus.getLatestFinishedTaskVersion().set(newWorkerStatus.getLatestFinishedVersion() != null ? newWorkerStatus.getLatestFinishedVersion() : -1L); - - Map waitingTaskInfo = newWorkerStatus.getWaitingTaskInfo(); - Map runningTaskInfo = newWorkerStatus.getRunningTaskInfo(); - Map finishedTaskInfo = newWorkerStatus.getFinishedTaskInfo(); - workerStatus.setWaitingTaskList(waitingTaskInfo); - workerStatus.setRunningTaskList(runningTaskInfo); + long currentVersion = workerStatus.getStatusVersion().get(); + WorkerEndpoint ep = endpointRegistry != null ? endpointRegistry.get(ipPort) : null; + if (currentVersion < responseVersion) { + // 1. WorkerStatusResponse directly updates WorkerStatus + workerStatus.updateFromResponse(newWorkerStatus); + + // 2. Notify EP (calibration) — passes both updated status and raw response + if (ep != null) { + ep.onWorkerStatusUpdate(workerStatus, newWorkerStatus); + } - // Update local task state (including checking lost, updating running, and cleaning completed) - workerStatus.updateTaskStates(waitingTaskInfo, runningTaskInfo, finishedTaskInfo); + // 3. Notify scheduler (cleanup finished requests) + if (batchScheduler != null) { + batchScheduler.onWorkerStatusUpdate(workerStatus, newWorkerStatus); + } + } else { + logger.info("query engine worker status via gRPC, version is not updated, " + + "currentVersion: {}, responseVersion: {}", + currentVersion, responseVersion); + } - // Correct running queue total wait time - workerStatus.updateRunningQueueTime(); + // 4. Update latestFinishedVersion if remote is ahead (always, regardless of status version) + Long latestFinishedVersion = newWorkerStatus.getLatestFinishedVersion(); + if (latestFinishedVersion != null + && latestFinishedVersion > workerStatus.getLatestFinishedTaskVersion().get()) { + workerStatus.getLatestFinishedTaskVersion().set(latestFinishedVersion); + } - engineHealthReporter.reportStatusCheckerSuccess(modelName, workerStatus, - Optional.ofNullable(runningTaskInfo).map(Map::size).orElse(0), - Optional.ofNullable(finishedTaskInfo).map(Map::size).orElse(0)); + engineHealthReporter.reportStatusCheckerSuccess(modelName, workerStatus, ep, + Optional.ofNullable(newWorkerStatus.getRunningTaskInfo()).map(Map::size).orElse(0), + Optional.ofNullable(newWorkerStatus.getFinishedTaskInfo()).map(Map::size).orElse(0)); - // Update status timestamp and record actual sync interval - long nowUs = System.nanoTime() / 1000; - long prevUpdateTime = workerStatus.getStatusLastUpdateTime().get(); - if (prevUpdateTime > 0) { - workerStatus.getStatusUpdateIntervalUs().set(nowUs - prevUpdateTime); - } - workerStatus.getStatusLastUpdateTime().set(nowUs); logWorkerStatusUpdate(startTime, workerStatus); } catch (Throwable e) { @@ -188,10 +177,26 @@ private void handleStatusResponse(WorkerStatusResponse newWorkerStatus, long sta } private void logWorkerStatusUpdate(long startTime, WorkerStatus workerStatus) { - logger.info("gRPC Worker Status - {}, role:{}, running_queue_tokens:{}, cost:{}", + logger.info("gRPC Worker Status - {}, role:{}, alive:{}, concurrency:{}, " + + "step_latency_ms:{}, iterate_count:{}, " + + "dp_rank:{}, dp_size:{}, tp_size:{}, " + + "avail_kv_tokens:{}, used_kv_tokens:{}, " + + "waiting_tasks:{}, running_tasks:{}, " + + "version:{}, sync_cost_us:{}", ipPort, workerStatus.getRole(), - workerStatus.getRunningQueueTime(), + workerStatus.isAlive(), + workerStatus.getAvailableConcurrency(), + workerStatus.getStepLatencyMs(), + workerStatus.getIterateCount(), + workerStatus.getDpRank(), + workerStatus.getDpSize(), + workerStatus.getTpSize(), + workerStatus.getAvailableKvCacheTokens(), + workerStatus.getTotalKvCacheTokens().get() - workerStatus.getAvailableKvCacheTokens().get(), + workerStatus.getRunningTaskList() != null ? workerStatus.getRunningTaskList().values().stream().filter(t -> t.getPhase() != org.flexlb.enums.TaskPhase.RUNNING).count() : 0, + workerStatus.getRunningTaskList() != null ? workerStatus.getRunningTaskList().size() : 0, + workerStatus.getStatusVersion(), System.nanoTime() / 1000 - startTime); } diff --git a/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/sync/schedule/ExpirationCleaner.java b/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/sync/schedule/ExpirationCleaner.java index e596fe8297..1dc1ac5384 100644 --- a/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/sync/schedule/ExpirationCleaner.java +++ b/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/sync/schedule/ExpirationCleaner.java @@ -1,44 +1,27 @@ package org.flexlb.sync.schedule; import org.apache.commons.collections4.MapUtils; -import org.flexlb.dao.master.TaskInfo; import org.flexlb.dao.master.WorkerStatus; import org.flexlb.dao.route.RoleType; -import org.flexlb.enums.FlexMetricType; -import org.flexlb.enums.FlexPriorityType; -import org.flexlb.enums.TaskStateEnum; -import org.flexlb.metric.FlexMetricTags; -import org.flexlb.metric.FlexMonitor; import org.flexlb.sync.status.EngineWorkerStatus; import org.flexlb.sync.status.ModelWorkerStatus; -import org.flexlb.util.Logger; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import org.springframework.scheduling.annotation.Scheduled; import org.springframework.stereotype.Component; -import javax.annotation.PostConstruct; import java.util.Iterator; import java.util.Map; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.atomic.AtomicLong; @Component public class ExpirationCleaner { - private static final String TASK_REMOVED = "task.removed"; + private static final Logger logger = LoggerFactory.getLogger("syncLogger"); - private final long taskTimeoutUs; private final long workerTimeoutUs; - private final FlexMonitor monitor; - public ExpirationCleaner(FlexMonitor monitor) { - this.monitor = monitor; - this.taskTimeoutUs = Long.parseLong(System.getenv().getOrDefault("TASK_TIMEOUT_US", "3000000")); // Default 3s - this.workerTimeoutUs = Long.parseLong(System.getenv().getOrDefault("WORKER_TIMEOUT_US", "3000000")); // Default 3s - } - - @PostConstruct - public void init() { - this.monitor.register(TASK_REMOVED, FlexMetricType.QPS, FlexPriorityType.PRECISE); + public ExpirationCleaner() { + this.workerTimeoutUs = Long.parseLong(System.getenv().getOrDefault("WORKER_TIMEOUT_US", "3000000")); } @Scheduled(fixedRate = 3000) @@ -59,61 +42,12 @@ public void doClean(Map workerStatusMap, RoleType role) { Map.Entry item = it.next(); WorkerStatus workerStatus = item.getValue(); - // 1. Check if worker needs cleanup long expirationTime = workerStatus.getStatusLastUpdateTime().get() + workerTimeoutUs; long currentTime = System.nanoTime() / 1000; if (currentTime > expirationTime) { + logger.info("Removing expired worker: {}, role: {}", item.getKey(), role); it.remove(); - continue; - } - - // 2. Check if tasks within worker need cleanup: lost tasks and long-timeout tasks - ConcurrentHashMap localTaskMap = workerStatus.getLocalTaskMap(); - Iterator> taskIterator = localTaskMap.entrySet().iterator(); - while (taskIterator.hasNext()) { - Map.Entry entry = taskIterator.next(); - Long requestId = entry.getKey(); - TaskInfo task = entry.getValue(); - - boolean shouldRemove = false; - - // Check if task is lost - if (task.isLost()) { - Logger.warn("Cleaning lost task: {}, state: {}, role: {}, worker: {}", requestId, task.getTaskState(), role, workerStatus.getIp()); - reportTaskRemoved(workerStatus.getRole(), workerStatus.getIp(), "lost"); - task.updateTaskState(TaskStateEnum.CLEANED); - shouldRemove = true; - } - // Check if task is timed out - else if (task.isTimeout(currentTime, taskTimeoutUs)) { - Logger.warn("Removing timeout task: {}, state: {}, age: {}ms, role: {}, worker: {}", requestId, task.getTaskState(), - (currentTime - task.getLastActiveTimeUs()) / 1000, role, workerStatus.getIp()); - reportTaskRemoved(workerStatus.getRole(), workerStatus.getIp(), "timeout"); - task.updateTaskState(TaskStateEnum.CLEANED); - shouldRemove = true; - } - - if (shouldRemove) { - decrementQueueTime(workerStatus.getRunningQueueTime(), task, workerStatus.getRole()); - taskIterator.remove(); - } } } } - - private void reportTaskRemoved(String role, String ip, String type) { - FlexMetricTags tags = FlexMetricTags.of( - "role", role, - "ip", ip, - "type", type - ); - monitor.report(TASK_REMOVED, tags, 1); - } - - private static void decrementQueueTime(AtomicLong runningQueueTime, TaskInfo task, String role) { - if (RoleType.PREFILL.matches(role) || RoleType.PDFUSION.matches(role)) { - long delta = task.estimatePrefillTime(); - WorkerStatus.safeDecrementQueueTime(runningQueueTime, delta); - } - } -} \ No newline at end of file +} diff --git a/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/sync/status/EngineWorkerStatus.java b/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/sync/status/EngineWorkerStatus.java index 980e2e68b4..cb52975b0c 100644 --- a/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/sync/status/EngineWorkerStatus.java +++ b/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/sync/status/EngineWorkerStatus.java @@ -2,13 +2,15 @@ import lombok.Data; import lombok.extern.slf4j.Slf4j; +import org.flexlb.balance.endpoint.EndpointRegistry; +import org.flexlb.balance.endpoint.WorkerEndpoint; import org.flexlb.config.ModelMetaConfig; import org.flexlb.dao.master.WorkerStatus; import org.flexlb.dao.route.RoleType; import org.springframework.stereotype.Component; +import java.util.LinkedHashMap; import java.util.Map; -import java.util.stream.Collectors; @Slf4j @Data @@ -18,12 +20,19 @@ public class EngineWorkerStatus { public static final ModelWorkerStatus MODEL_ROLE_WORKER_STATUS = new ModelWorkerStatus(); public final ModelMetaConfig modelMetaConfig; + private final EndpointRegistry endpointRegistry; - public EngineWorkerStatus(ModelMetaConfig modelMetaConfig) { + public EngineWorkerStatus(ModelMetaConfig modelMetaConfig, EndpointRegistry endpointRegistry) { this.modelMetaConfig = modelMetaConfig; + this.endpointRegistry = endpointRegistry; } - public Map selectModelWorkerStatus(RoleType roleType, String group) { + /** + * Select workers for a given role and group, returning + * {@link WorkerEndpoint} instances so callers can access both + * engine status and endpoint-local methods (reserve / release / …). + */ + public Map selectModelWorkerStatus(RoleType roleType, String group) { Map roleStatusMap = MODEL_ROLE_WORKER_STATUS.getRoleStatusMap(roleType); @@ -31,17 +40,21 @@ public EngineWorkerStatus(ModelMetaConfig modelMetaConfig) { return Map.of(); } - if (group == null) { - return roleStatusMap; + Map result = new LinkedHashMap<>(); + for (Map.Entry entry : roleStatusMap.entrySet()) { + WorkerStatus ws = entry.getValue(); + if (ws == null) { + continue; + } + if (group != null && ws.getGroup() != null && !ws.getGroup().equals(group)) { + continue; + } + WorkerEndpoint ep = endpointRegistry.get(entry.getKey()); + if (ep != null) { + result.put(entry.getKey(), ep); + } } - - return roleStatusMap.entrySet() - .stream() - .filter(entry -> { - WorkerStatus workerStatus = entry.getValue(); - return workerStatus != null && workerStatus.getGroup() != null && workerStatus.getGroup().equals(group); - }) - .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); + return result; } } diff --git a/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/sync/status/EngineWorkerStatusProvider.java b/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/sync/status/EngineWorkerStatusProvider.java index 46ea27afe5..4985254a92 100644 --- a/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/sync/status/EngineWorkerStatusProvider.java +++ b/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/sync/status/EngineWorkerStatusProvider.java @@ -1,7 +1,7 @@ package org.flexlb.sync.status; import lombok.extern.slf4j.Slf4j; -import org.flexlb.dao.master.WorkerStatus; +import org.flexlb.balance.endpoint.WorkerEndpoint; import org.flexlb.dao.master.WorkerStatusProvider; import org.flexlb.dao.route.RoleType; import org.springframework.beans.factory.annotation.Autowired; @@ -14,16 +14,16 @@ @Slf4j @Service public class EngineWorkerStatusProvider implements WorkerStatusProvider { - + @Autowired private EngineWorkerStatus engineWorkerStatus; - + @Override public List getWorkerIpPorts(RoleType roleType, String group) { - Map workerStatusMap + Map workerEndpointMap = engineWorkerStatus.selectModelWorkerStatus(roleType, group); - return new ArrayList<>(workerStatusMap.keySet()); + return new ArrayList<>(workerEndpointMap.keySet()); } } \ No newline at end of file diff --git a/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/sync/synchronizer/MasterEngineSynchronizer.java b/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/sync/synchronizer/MasterEngineSynchronizer.java index dc8fb82c76..f9d9840c91 100644 --- a/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/sync/synchronizer/MasterEngineSynchronizer.java +++ b/rtp_llm/flexlb/flexlb-sync/src/main/java/org/flexlb/sync/synchronizer/MasterEngineSynchronizer.java @@ -3,6 +3,8 @@ import com.fasterxml.jackson.core.type.TypeReference; import io.micrometer.core.instrument.util.NamedThreadFactory; import org.apache.commons.lang3.StringUtils; +import org.flexlb.balance.endpoint.EndpointRegistry; +import org.flexlb.balance.scheduler.FlexlbBatchScheduler; import org.flexlb.cache.service.CacheAwareService; import org.flexlb.config.ModelMetaConfig; import org.flexlb.dao.route.Endpoint; @@ -35,6 +37,8 @@ public class MasterEngineSynchronizer extends AbstractEngineStatusSynchronizer { private final List modelNames = new ArrayList<>(); private final EngineGrpcService engineGrpcService; private final CacheAwareService localKvCacheAwareManager; + private final FlexlbBatchScheduler batchScheduler; + private final EndpointRegistry endpointRegistry; private final long syncRequestTimeoutMs; private final LongAdder syncCount = new LongAdder(); private final Long syncEngineStatusInterval; @@ -44,12 +48,17 @@ public MasterEngineSynchronizer(WorkerAddressService workerAddressService, EngineWorkerStatus engineWorkerStatus, EngineGrpcService engineGrpcService, ModelMetaConfig modelMetaConfig, - CacheAwareService localKvCacheAwareManager) { + CacheAwareService localKvCacheAwareManager, + @org.springframework.beans.factory.annotation.Autowired(required = false) + FlexlbBatchScheduler batchScheduler, + EndpointRegistry endpointRegistry) { super(workerAddressService, engineHealthReporter, engineWorkerStatus, modelMetaConfig); this.engineGrpcService = engineGrpcService; this.localKvCacheAwareManager = localKvCacheAwareManager; + this.batchScheduler = batchScheduler; + this.endpointRegistry = endpointRegistry; this.syncEngineStatusInterval = System.getenv("SYNC_STATUS_INTERVAL") != null ? Long.parseLong(System.getenv("SYNC_STATUS_INTERVAL")) @@ -98,7 +107,8 @@ public void syncEngineStatus() { modelName, modelWorkerStatus.getRoleStatusMap(roleType), workerAddressService, statusCheckExecutor, engineHealthReporter, engineGrpcService, roleType, localKvCacheAwareManager, - syncRequestTimeoutMs, syncCount, syncEngineStatusInterval + syncRequestTimeoutMs, syncCount, syncEngineStatusInterval, + batchScheduler, endpointRegistry )); } else { logger.error("roleEndpoints is null, by roleType : {}", roleType); diff --git a/rtp_llm/flexlb/flexlb-sync/src/test/java/org/flexlb/balance/endpoint/DecodeEndpointTest.java b/rtp_llm/flexlb/flexlb-sync/src/test/java/org/flexlb/balance/endpoint/DecodeEndpointTest.java new file mode 100644 index 0000000000..4409158c53 --- /dev/null +++ b/rtp_llm/flexlb/flexlb-sync/src/test/java/org/flexlb/balance/endpoint/DecodeEndpointTest.java @@ -0,0 +1,123 @@ +package org.flexlb.balance.endpoint; + +import org.flexlb.dao.master.TaskInfo; +import org.flexlb.dao.master.WorkerStatus; +import org.flexlb.enums.TaskPhase; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.*; + +class DecodeEndpointTest { + + private WorkerStatus status; + private DecodeEndpoint endpoint; + + @BeforeEach + void setUp() { + status = new WorkerStatus(); + status.setIp("10.0.0.1"); + status.setPort(8080); + status.setGrpcPort(8081); + endpoint = new DecodeEndpoint(status); + } + + @Test + void reserve_updatesSnapshotAndInflight() { + endpoint.calibrate(null, null, 10000); + endpoint.reserve(100L, 500); + assertEquals(1, endpoint.getInflightCount()); + assertEquals(9500, endpoint.realKvAvailable()); + } + + @Test + void release_decrementsInflight() { + endpoint.reserve(100L, 500); + endpoint.reserve(101L, 300); + endpoint.release(100L); + + assertEquals(1, endpoint.getInflightCount()); + } + + @Test + void release_unknownRequestId_noEffect() { + endpoint.reserve(100L, 500); + endpoint.release(999L); + assertEquals(1, endpoint.getInflightCount()); + } + + @Test + void release_neverGoesNegative() { + endpoint.reserve(100L, 100); + endpoint.release(100L); + endpoint.release(100L); + assertEquals(0, endpoint.getInflightCount()); + assertEquals(0, endpoint.realKvAvailable()); + } + + @Test + void calibrate_kvAllocatedReleasesFromInflight() { + endpoint.reserve(100L, 500); + + TaskInfo running = task(100L); + running.setPhase(TaskPhase.KV_ALLOCATED); + endpoint.calibrate(Map.of("100", running), null, 10000); + + assertEquals(0, endpoint.getInflightCount()); + assertEquals(10000, endpoint.realKvAvailable()); + } + + @Test + void calibrate_finishedFailureReleasesFromInflight() { + endpoint.reserve(100L, 500); + + TaskInfo failed = task(100L); + failed.setErrorCode(1); + failed.setErrorMessage("timeout"); + endpoint.calibrate(null, Map.of("100", failed), 10000); + + assertEquals(0, endpoint.getInflightCount()); + } + + @Test + void calibrate_finishedSuccessReleasesIfStillPresent() { + endpoint.reserve(100L, 500); + + TaskInfo success = task(100L); + success.setErrorCode(0); + endpoint.calibrate(null, Map.of("100", success), 10000); + + assertEquals(0, endpoint.getInflightCount()); + } + + @Test + void calibrate_updatesReportedKvAvailable() { + endpoint.reserve(100L, 500); + endpoint.calibrate(null, null, 10000); + + assertEquals(9500, endpoint.realKvAvailable()); + } + + @Test + void availableKvTokens_accountsForReservations() { + endpoint.calibrate(null, null, 10000); + + endpoint.reserve(100L, 3000); + endpoint.reserve(101L, 2000); + + assertEquals(5000, endpoint.realKvAvailable()); + } + + @Test + void ipPort_format() { + assertEquals("10.0.0.1:8080", endpoint.ipPort()); + } + + private TaskInfo task(long requestId) { + TaskInfo task = new TaskInfo(); + task.setRequestId(requestId); + return task; + } +} diff --git a/rtp_llm/flexlb/flexlb-sync/src/test/java/org/flexlb/balance/endpoint/PrefillEndpointTest.java b/rtp_llm/flexlb/flexlb-sync/src/test/java/org/flexlb/balance/endpoint/PrefillEndpointTest.java new file mode 100644 index 0000000000..2a7975ea59 --- /dev/null +++ b/rtp_llm/flexlb/flexlb-sync/src/test/java/org/flexlb/balance/endpoint/PrefillEndpointTest.java @@ -0,0 +1,337 @@ +package org.flexlb.balance.endpoint; + +import org.flexlb.balance.scheduler.BatchDecisionHandler; +import org.flexlb.balance.scheduler.BatchItem; +import org.flexlb.balance.scheduler.DispatchMeta; +import org.flexlb.config.FlexlbConfig; +import org.flexlb.dao.BalanceContext; +import org.flexlb.dao.loadbalance.DebugInfo; +import org.flexlb.dao.loadbalance.Request; +import org.flexlb.dao.loadbalance.ServerStatus; +import org.flexlb.dao.master.TaskInfo; +import org.flexlb.dao.master.WorkerStatus; +import org.flexlb.dao.master.WorkerStatusResponse; +import org.flexlb.dao.route.RoleType; +import org.flexlb.enums.TaskPhase; +import org.flexlb.service.monitor.BatchSchedulerReporter; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.Mockito.mock; + +class PrefillEndpointTest { + + private PrefillEndpoint endpoint; + private FlexlbConfig config; + + @BeforeEach + void setUp() { + WorkerStatus status = new WorkerStatus(); + status.setIp("127.0.0.1"); + status.setPort(8080); + status.setGrpcPort(8090); + status.setRole(RoleType.PREFILL); + + config = new FlexlbConfig(); + config.setFlexlbBatchQueueMaxSize(100); + config.setFlexlbBatchFixedWaitMs(300); + config.setCostAlpha0(10); + config.setCostAlpha1(0.1); + config.setCostAlpha2(0); + config.setCostAlpha3(0); + config.setCostAlpha4(0); + config.setCostAlpha5(5); + + endpoint = new PrefillEndpoint(status, config, noopHandler(), mock(BatchSchedulerReporter.class)); + } + + // ---- batch commit / release ---- + + @Test + void commitBatchIncreasesInflightCount() { + assertEquals(0, endpoint.getInflightBatchCount()); + + BatchItem item = createBatchItem(1L, 500, 200); + endpoint.commitBatch(1L, 100, List.of(item)); + + assertEquals(1, endpoint.getInflightBatchCount()); + assertEquals(1, endpoint.getInflightRequestCount()); + } + + @Test + void releaseBatchDecreasesInflightCount() { + BatchItem item = createBatchItem(1L, 500, 200); + endpoint.commitBatch(1L, 100, List.of(item)); + endpoint.releaseBatch(1L); + + assertEquals(0, endpoint.getInflightBatchCount()); + } + + @Test + void releaseBatchNonExistentDoesNotThrow() { + endpoint.releaseBatch(999L); // should not throw + assertEquals(0, endpoint.getInflightBatchCount()); + } + + @Test + void commitMultipleBatches() { + BatchItem item1 = createBatchItem(1L, 500, 200); + BatchItem item2 = createBatchItem(2L, 300, 100); + BatchItem item3 = createBatchItem(3L, 400, 0); + + endpoint.commitBatch(1L, 100, List.of(item1, item2)); + endpoint.commitBatch(2L, 50, List.of(item3)); + + assertEquals(2, endpoint.getInflightBatchCount()); + assertEquals(3, endpoint.getInflightRequestCount()); + } + + // ---- repack batch ---- + + @Test + void repackBatchRemovesFailedRequests() { + BatchItem item1 = createBatchItem(1L, 500, 200); + BatchItem item2 = createBatchItem(2L, 300, 100); + endpoint.commitBatch(1L, 100, List.of(item1, item2)); + + BatchInflight result = endpoint.repackBatch(1L, Set.of(2L)); + assertNotNull(result); + assertEquals(1, result.requests().size()); + assertEquals(1L, result.requests().get(0).requestId()); + } + + @Test + void repackBatchAllFailedReturnsNull() { + BatchItem item1 = createBatchItem(1L, 500, 200); + endpoint.commitBatch(1L, 100, List.of(item1)); + + BatchInflight result = endpoint.repackBatch(1L, Set.of(1L)); + assertNull(result); + assertEquals(0, endpoint.getInflightBatchCount()); + } + + // ---- calibrate ---- + + @Test + void calibrateRemovesBatchOnSuccess() { + BatchItem item = createBatchItem(1L, 500, 200); + endpoint.commitBatch(1L, 100, List.of(item)); + + Map finished = new HashMap<>(); + TaskInfo successTask = new TaskInfo(); + successTask.setRequestId(1L); + successTask.setBatchId(1L); + successTask.setErrorCode(0); + finished.put("1", successTask); + + endpoint.calibrate(finished, Map.of()); + + assertEquals(0, endpoint.getInflightBatchCount()); + } + + @Test + void calibrateRepacksOnPartialFailure() { + BatchItem item1 = createBatchItem(1L, 500, 200); + BatchItem item2 = createBatchItem(2L, 300, 100); + endpoint.commitBatch(1L, 100, List.of(item1, item2)); + + Map finished = new HashMap<>(); + TaskInfo failedTask = new TaskInfo(); + failedTask.setRequestId(2L); + failedTask.setBatchId(1L); + failedTask.setErrorCode(500); + failedTask.setErrorMessage("engine error"); + finished.put("2", failedTask); + + endpoint.calibrate(finished, Map.of()); + + assertEquals(1, endpoint.getInflightBatchCount()); + BatchInflight remaining = endpoint.getInflightBatches().get(1L); + assertNotNull(remaining); + assertEquals(1, remaining.requests().size()); + assertEquals(1L, remaining.requests().get(0).requestId()); + } + + @Test + void calibrateHandlesTaskWithNoBatchId() { + BatchItem item = createBatchItem(1L, 500, 200); + endpoint.commitBatch(1L, 100, List.of(item)); + + Map finished = new HashMap<>(); + TaskInfo badTask = new TaskInfo(); + badTask.setRequestId(999L); // non-colliding: won't match batchId=1 + badTask.setBatchId(-1); + badTask.setErrorCode(0); + finished.put("1", badTask); + + // should not throw, just log a warning for missing non-batch inflight + endpoint.calibrate(finished, Map.of()); + assertEquals(1, endpoint.getInflightBatchCount()); + } + + @Test + void calibrateUpdatesProgressAnchorsForRunningBatches() { + BatchItem item = createBatchItem(1L, 500, 200); + endpoint.commitBatch(1L, 100, List.of(item)); + + Map running = new HashMap<>(); + TaskInfo runningTask = new TaskInfo(); + runningTask.setRequestId(1L); + runningTask.setBatchId(1L); + runningTask.setPhase(TaskPhase.RUNNING); + running.put("1", runningTask); + + endpoint.calibrate(Map.of(), running); + + BatchInflight batch = endpoint.getInflightBatches().get(1L); + assertNotNull(batch); + assertTrue(batch.progressBaseMs() > 0, "Running batch should have its progress anchor updated"); + } + + // ---- estimated waiting time ---- + + @Test + void realWaitTimeMsZeroWhenNoInflight() { + assertEquals(0, endpoint.realWaitTimeMs()); + } + + @Test + void realWaitTimeMsPositiveWithInflight() { + BatchItem item = createBatchItem(1L, 500, 200); + endpoint.commitBatch(1L, 5000, List.of(item)); // 5s prediction + + long waitMs = endpoint.realWaitTimeMs(); + assertTrue(waitMs > 0, "Should have non-zero wait time with inflight batch"); + assertTrue(waitMs <= 5000, "Wait time should not exceed prediction"); + } + + @Test + void realWaitTimeMsDecreasesOverTime() throws InterruptedException { + BatchItem item = createBatchItem(1L, 500, 200); + endpoint.commitBatch(1L, 5000, List.of(item)); + + long waitBefore = endpoint.realWaitTimeMs(); + + // Mark the batch as running so elapsed time counts + Map running = new HashMap<>(); + TaskInfo runningTask = new TaskInfo(); + runningTask.setRequestId(1L); + runningTask.setBatchId(1L); + runningTask.setPhase(TaskPhase.RUNNING); + running.put("1", runningTask); + endpoint.calibrate(Map.of(), running); + + Thread.sleep(50); + + long waitAfter = endpoint.realWaitTimeMs(); + assertTrue(waitAfter <= waitBefore, "Wait time should decrease after progress"); + } + + // ---- eviction ---- + + @Test + void evictExpiredBatchesCleansUpStaleEntries() throws InterruptedException { + BatchItem item = createBatchItem(1L, 500, 200); + endpoint.commitBatch(1L, 100, List.of(item)); + + assertEquals(1, endpoint.getInflightBatchCount()); + + // Wait a bit so the batch ages + Thread.sleep(10); + + int evicted = endpoint.evictExpiredBatches(1); // 1ms TTL — should evict + assertEquals(1, evicted); + assertEquals(0, endpoint.getInflightBatchCount()); + } + + @Test + void evictExpiredBatchesFreshEntriesSurvive() { + BatchItem item = createBatchItem(1L, 500, 200); + endpoint.commitBatch(1L, 100, List.of(item)); + + int evicted = endpoint.evictExpiredBatches(60_000); // 60s TTL — fresh entry survives + assertEquals(0, evicted); + assertEquals(1, endpoint.getInflightBatchCount()); + } + + // ---- realPendingCount ---- + + @Test + void realPendingCountIncludesBatcherQueue() { + // Initially, batcher queue is empty + assertEquals(0, endpoint.realPendingCount()); + + BatchItem item = createBatchItem(1L, 500, 200); + endpoint.getBatcher().offer(item); + + assertTrue(endpoint.realPendingCount() > 0, "Pending count should include batcher queue"); + } + + // ---- WorkerEndpoint inherited behavior ---- + + @Test + void onWorkerStatusUpdateUpdatesAliveStatus() { + WorkerStatusResponse response = new WorkerStatusResponse(); + response.setRole(RoleType.PREFILL); + WorkerStatus status = new WorkerStatus(); + status.setIp("127.0.0.1"); + status.setPort(8080); + status.setAlive(true); + + endpoint.onWorkerStatusUpdate(status, response); + + assertTrue(endpoint.getStatus().isAlive()); + } + + // ---- close ---- + + @Test + void closeShutsDownBatcher() { + assertNotNull(endpoint.getBatcher()); + endpoint.close(); + // After close, offering should fail (batcher is stopped) + BatchItem item = createBatchItem(1L, 500, 200); + endpoint.getBatcher().offer(item); + // Should not throw — batcher handles stopped state + } + + // ---- helpers ---- + + private BatchItem createBatchItem(long requestId, long seqLen, long hitCacheLen) { + Request request = new Request(); + request.setRequestId(requestId); + request.setSeqLen(seqLen); + + BalanceContext ctx = new BalanceContext(); + ctx.setRequest(request); + + ServerStatus prefill = new ServerStatus(); + prefill.setRole(RoleType.PREFILL); + prefill.setServerIp("127.0.0.1"); + prefill.setHttpPort(8080); + prefill.setGrpcPort(8090); + DebugInfo debugInfo = new DebugInfo(); + debugInfo.setHitCacheLen(hitCacheLen); + prefill.setDebugInfo(debugInfo); + + return new BatchItem(ctx, null, null, prefill, null, endpoint, null, 0, System.currentTimeMillis()); + } + + private static BatchDecisionHandler noopHandler() { + return new BatchDecisionHandler() { + @Override public void onExpired(BatchItem head) {} + @Override public void onUrgent(BatchItem head, DispatchMeta meta) {} + @Override public void onBatchReady(List items, DispatchMeta meta) {} + @Override public void onOfferFailure(BatchItem item, Throwable error) {} + }; + } +} diff --git a/rtp_llm/flexlb/flexlb-sync/src/test/java/org/flexlb/balance/endpoint/WorkerEndpointTest.java b/rtp_llm/flexlb/flexlb-sync/src/test/java/org/flexlb/balance/endpoint/WorkerEndpointTest.java new file mode 100644 index 0000000000..79111d2543 --- /dev/null +++ b/rtp_llm/flexlb/flexlb-sync/src/test/java/org/flexlb/balance/endpoint/WorkerEndpointTest.java @@ -0,0 +1,284 @@ +package org.flexlb.balance.endpoint; + +import org.flexlb.balance.scheduler.BatchDecisionHandler; +import org.flexlb.balance.scheduler.BatchItem; +import org.flexlb.config.FlexlbConfig; +import org.flexlb.dao.BalanceContext; +import org.flexlb.dao.loadbalance.Request; +import org.flexlb.dao.master.TaskInfo; +import org.flexlb.dao.master.WorkerStatus; +import org.flexlb.dao.master.WorkerStatusResponse; +import org.flexlb.dao.route.RoleType; +import org.flexlb.service.monitor.BatchSchedulerReporter; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.Mockito; + +import java.util.List; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.*; + +class WorkerEndpointTest { + + private WorkerStatus status; + private PrefillEndpoint endpoint; + + @BeforeEach + void setUp() { + status = new WorkerStatus(); + status.setIp("10.0.0.1"); + status.setPort(8080); + status.setGrpcPort(8081); + FlexlbConfig config = new FlexlbConfig(); + config.setCostAlpha0(0); + config.setCostAlpha1(1); + BatchDecisionHandler handler = Mockito.mock(BatchDecisionHandler.class); + endpoint = new PrefillEndpoint(status, config, handler, Mockito.mock(BatchSchedulerReporter.class)); + } + + @Test + void commitBatch_incrementsEstimate() { + endpoint.commitBatch(1L, 500, List.of(new BatchItem(ctx(100L, 1000), null, null, null, null, null, null, 0, 0))); + assertEquals(500, endpoint.realWaitTimeMs()); + + endpoint.commitBatch(2L, 300, List.of(new BatchItem(ctx(101L, 500), null, null, null, null, null, null, 0, 0))); + assertEquals(800, endpoint.realWaitTimeMs()); + } + + @Test + void releaseBatch_decrementsEstimate() { + endpoint.commitBatch(1L, 500, List.of(new BatchItem(ctx(100L, 1000), null, null, null, null, null, null, 0, 0))); + endpoint.commitBatch(2L, 300, List.of(new BatchItem(ctx(101L, 500), null, null, null, null, null, null, 0, 0))); + + endpoint.releaseBatch(1L); + assertEquals(300, endpoint.realWaitTimeMs()); + } + + @Test + void releaseBatch_unknownBatchId_noEffect() { + endpoint.commitBatch(1L, 500, List.of(new BatchItem(ctx(100L, 1000), null, null, null, null, null, null, 0, 0))); + endpoint.releaseBatch(999L); + assertEquals(500, endpoint.realWaitTimeMs()); + } + + @Test + void releaseBatch_neverGoesNegative() { + endpoint.commitBatch(1L, 100, List.of(new BatchItem(ctx(100L, 1000), null, null, null, null, null, null, 0, 0))); + endpoint.releaseBatch(1L); + endpoint.releaseBatch(1L); + assertEquals(0, endpoint.realWaitTimeMs()); + } + + @Test + void calibrate_noInflight_resetsToZero() { + endpoint.commitBatch(1L, 500, List.of(new BatchItem(ctx(100L, 1000), null, null, null, null, null, null, 0, 0))); + + TaskInfo finished = task(100L, 1000, 0, 1L); + finished.setErrorCode(0); + endpoint.calibrate(Map.of("100", finished), null); + + assertEquals(0, endpoint.realWaitTimeMs()); + assertTrue(endpoint.getInflightBatches().isEmpty()); + } + + @Test + void calibrate_finishedBatch_removedFromInflight() { + endpoint.commitBatch(5L, 9999, List.of( + new BatchItem(ctx(100L, 1000), null, null, null, null, null, null, 0, 0), new BatchItem(ctx(101L, 2000), null, null, null, null, null, null, 0, 0))); + + TaskInfo t1 = task(100L, 1000, 0, 5L); + t1.setErrorCode(0); + TaskInfo t2 = task(101L, 2000, 0, 5L); + t2.setErrorCode(0); + endpoint.calibrate(Map.of("100", t1, "101", t2), null); + + assertEquals(0, endpoint.realWaitTimeMs()); + assertFalse(endpoint.getInflightBatches().containsKey(5L)); + } + + @Test + void calibrate_partialBatchFailure_repacks() { + endpoint.commitBatch(5L, 9999, List.of( + new BatchItem(ctx(100L, 1000), null, null, null, null, null, null, 0, 0), new BatchItem(ctx(101L, 2000), null, null, null, null, null, null, 0, 0))); + + TaskInfo failed = task(100L, 1000, 0, 5L); + failed.setErrorCode(1); + failed.setErrorMessage("timeout"); + TaskInfo success = task(101L, 2000, 0, 5L); + success.setErrorCode(0); + endpoint.calibrate(Map.of("100", failed, "101", success), null); + + assertFalse(endpoint.getInflightBatches().containsKey(5L)); + assertEquals(0, endpoint.realWaitTimeMs()); + } + + @Test + void calibrate_inflightUnconfirmedBatchesSurvive() { + endpoint.commitBatch(5L, 1000, List.of(new BatchItem(ctx(100L, 500), null, null, null, null, null, null, 0, 0))); + endpoint.commitBatch(7L, 2000, List.of(new BatchItem(ctx(200L, 1000), null, null, null, null, null, null, 0, 0))); + + TaskInfo finished = task(100L, 500, 0, 5L); + finished.setErrorCode(0); + endpoint.calibrate(Map.of("100", finished), null); + + assertFalse(endpoint.getInflightBatches().containsKey(5L)); + assertTrue(endpoint.getInflightBatches().containsKey(7L)); + // realWaitTimeMs = predictMs - elapsedMs; allow small timing delta + assertTrue(Math.abs(endpoint.realWaitTimeMs() - 2000) < 50, + "Expected ~2000ms but got " + endpoint.realWaitTimeMs()); + } + + @Test + void repackBatch_removesFailedRequests() { + endpoint.commitBatch(5L, 9999, List.of( + new BatchItem(ctx(100L, 1000), null, null, null, null, null, null, 0, 0), + new BatchItem(ctx(101L, 2000), null, null, null, null, null, null, 0, 0), + new BatchItem(ctx(102L, 3000), null, null, null, null, null, null, 0, 0))); + BatchInflight repacked = endpoint.repackBatch(5L, java.util.Set.of(101L)); + + assertNotNull(repacked); + assertEquals(2, repacked.requests().size()); + assertTrue(repacked.requests().stream().anyMatch(r -> r.requestId() == 100L)); + assertTrue(repacked.requests().stream().anyMatch(r -> r.requestId() == 102L)); + assertFalse(repacked.requests().stream().anyMatch(r -> r.requestId() == 101L)); + } + + @Test + void repackBatch_allFailed_removesBatch() { + endpoint.commitBatch(5L, 500, List.of(new BatchItem(ctx(100L, 1000), null, null, null, null, null, null, 0, 0))); + + BatchInflight repacked = endpoint.repackBatch(5L, java.util.Set.of(100L)); + + assertNull(repacked); + assertFalse(endpoint.getInflightBatches().containsKey(5L)); + assertEquals(0, endpoint.realWaitTimeMs()); + } + + @Test + void ipPort_format() { + assertEquals("10.0.0.1:8080", endpoint.ipPort()); + } + + // ==================== getStatus() returns live reference ==================== + + @Test + void getStatus_returns_live_reference() { + status.setAlive(true); + status.setAvailableConcurrency(42L); + status.setDpRank(3); + + WorkerStatus liveStatus = endpoint.getStatus(); + assertSame(status, liveStatus); + assertTrue(liveStatus.isAlive()); + assertEquals(42L, (long) liveStatus.getAvailableConcurrency()); + assertEquals(3L, liveStatus.getDpRank()); + } + + // ==================== WorkerStatus.updateFromResponse ==================== + + @Test + void updateFromResponse_applies_all_engine_fields() { + WorkerStatusResponse resp = new WorkerStatusResponse(); + resp.setRole(RoleType.DECODE); + resp.setAlive(true); + resp.setAvailableConcurrency(8L); + resp.setStepLatencyMs(25.0); + resp.setIterateCount(100L); + resp.setDpSize(4); + resp.setTpSize(2); + resp.setDpRank(1); + resp.setAvailableKvCacheTokens(10000L); + resp.setStatusVersion(5L); + resp.setLatestFinishedVersion(3L); + + status.updateFromResponse(resp); + + assertEquals(RoleType.DECODE, status.getRole()); + assertTrue(status.isAlive()); + assertEquals(8L, (long) status.getAvailableConcurrency()); + assertEquals(25.0, status.getStepLatencyMs(), 0.001); + assertEquals(100L, status.getIterateCount()); + assertEquals(4L, status.getDpSize()); + assertEquals(2L, status.getTpSize()); + assertEquals(1L, status.getDpRank()); + assertEquals(10000L, status.getAvailableKvCacheTokens().get()); + assertEquals(5L, status.getStatusVersion().get()); + assertEquals(3L, status.getLatestFinishedTaskVersion().get()); + } + + @Test + void updateFromResponse_null_is_noop() { + status.setAlive(true); + status.setAvailableConcurrency(10L); + + status.updateFromResponse(null); + + assertTrue(status.isAlive()); + assertEquals(10L, (long) status.getAvailableConcurrency()); + } + + // ==================== onWorkerStatusUpdate ==================== + + @Test + void onWorkerStatusUpdate_replaces_status_reference() { + WorkerStatusResponse resp = new WorkerStatusResponse(); + WorkerStatus newStatus = new WorkerStatus(); + newStatus.setSite("site-a"); + newStatus.setGroup("group-b"); + newStatus.setAlive(true); + + assertNotSame(newStatus, endpoint.getStatus()); + + endpoint.onWorkerStatusUpdate(newStatus, resp); + + assertSame(newStatus, endpoint.getStatus()); + assertEquals("site-a", endpoint.getStatus().getSite()); + assertEquals("group-b", endpoint.getStatus().getGroup()); + } + + @Test + void onWorkerStatusUpdate_calibrates_prefill() { + WorkerStatusResponse resp = new WorkerStatusResponse(); + resp.setFinishedTaskInfo(Map.of("100", task(100L, 1000, 0, 1L))); + + // PrefillEndpoint calibrates even when runningTaskInfo is null + endpoint.onWorkerStatusUpdate(status, resp); + // No exception = calibrate handled null gracefully + } + + @Test + void onWorkerStatusUpdate_preserves_engine_state_from_ws() { + WorkerStatusResponse resp = new WorkerStatusResponse(); + WorkerStatus ws = new WorkerStatus(); + ws.setSite("site-x"); + ws.setGroup("group-x"); + ws.setDpRank(5); + ws.setAlive(true); + + endpoint.onWorkerStatusUpdate(ws, resp); + + assertEquals("site-x", endpoint.getStatus().getSite()); + assertEquals("group-x", endpoint.getStatus().getGroup()); + assertEquals(5L, endpoint.getStatus().getDpRank()); + assertTrue(endpoint.getStatus().isAlive()); + } + + private BalanceContext ctx(long requestId, long seqLen) { + Request req = new Request(); + req.setRequestId(requestId); + req.setSeqLen(seqLen); + BalanceContext ctx = new BalanceContext(); + ctx.setRequest(req); + return ctx; + } + + private TaskInfo task(long requestId, long inputLength, long prefixLength, long batchId) { + TaskInfo task = new TaskInfo(); + task.setRequestId(requestId); + task.setInputLength(inputLength); + task.setPrefixLength(prefixLength); + task.setBatchId(batchId); + return task; + } +} diff --git a/rtp_llm/flexlb/flexlb-sync/src/test/java/org/flexlb/balance/resource/DecodeResourceMeasureTest.java b/rtp_llm/flexlb/flexlb-sync/src/test/java/org/flexlb/balance/resource/DecodeResourceMeasureTest.java index b83f15d66e..4fab3953e8 100644 --- a/rtp_llm/flexlb/flexlb-sync/src/test/java/org/flexlb/balance/resource/DecodeResourceMeasureTest.java +++ b/rtp_llm/flexlb/flexlb-sync/src/test/java/org/flexlb/balance/resource/DecodeResourceMeasureTest.java @@ -1,5 +1,6 @@ package org.flexlb.balance.resource; +import org.flexlb.balance.endpoint.DecodeEndpoint; import org.flexlb.config.ConfigService; import org.flexlb.config.FlexlbConfig; import org.flexlb.dao.master.TaskInfo; @@ -37,45 +38,29 @@ void setUp() { void concurrency_limit_disabled_should_not_affect_decode_availability() { config.setDecodeConcurrencyLimit(0); DecodeResourceMeasure measure = new DecodeResourceMeasure(configService); - WorkerStatus worker = createAliveDecodeWorker(); - worker.setWaitingTaskList(taskMap(1L, 2L)); - worker.setRunningTaskList(taskMap(3L, 4L)); - worker.getLocalTaskMap().put(5L, taskInfo(5L)); + DecodeEndpoint endpoint = createAliveDecodeEndpoint(); + endpoint.getStatus().setRunningTaskList(taskMap(1L, 2L, 3L, 4L)); - assertTrue(measure.isResourceAvailable(worker)); - assertEquals(0.0, measure.calculateAverageWaterLevel(Map.of("worker", worker))); + assertTrue(measure.isResourceAvailable(endpoint)); + assertEquals(0.0, measure.calculateAverageWaterLevel(Map.of("worker", endpoint.getStatus()))); } @Test void worker_should_be_unavailable_when_decode_concurrency_limit_reached() { config.setDecodeConcurrencyLimit(2); DecodeResourceMeasure measure = new DecodeResourceMeasure(configService); - WorkerStatus worker = createAliveDecodeWorker(); - worker.setRunningTaskList(taskMap(1L)); - worker.getLocalTaskMap().put(2L, taskInfo(2L)); - - assertFalse(measure.isResourceAvailable(worker)); - } - - @Test - void concurrency_count_should_deduplicate_reported_and_local_request_ids() { - config.setDecodeConcurrencyLimit(2); - DecodeResourceMeasure measure = new DecodeResourceMeasure(configService); - WorkerStatus worker = createAliveDecodeWorker(); - worker.setRunningTaskList(taskMap(1L)); - worker.getLocalTaskMap().put(1L, taskInfo(1L)); + DecodeEndpoint endpoint = createAliveDecodeEndpoint(); + endpoint.getStatus().setRunningTaskList(taskMap(1L, 2L)); - assertTrue(measure.isResourceAvailable(worker)); + assertFalse(measure.isResourceAvailable(endpoint)); } @Test void concurrency_water_level_should_contribute_to_serviceability() { config.setDecodeConcurrencyLimit(4); DecodeResourceMeasure measure = new DecodeResourceMeasure(configService); - WorkerStatus worker = createAliveDecodeWorker(); - worker.setWaitingTaskList(taskMap(1L)); - worker.setRunningTaskList(taskMap(2L)); - worker.getLocalTaskMap().put(3L, taskInfo(3L)); + WorkerStatus worker = createAliveWorkerStatus(); + worker.setRunningTaskList(taskMap(1L, 2L, 3L)); assertEquals(75.0, measure.calculateAverageWaterLevel(Map.of("worker", worker))); } @@ -84,18 +69,23 @@ void concurrency_water_level_should_contribute_to_serviceability() { void water_level_should_use_higher_value_between_kv_cache_and_concurrency() { config.setDecodeConcurrencyLimit(4); DecodeResourceMeasure measure = new DecodeResourceMeasure(configService); - WorkerStatus worker = createAliveDecodeWorker(); - worker.getUsedKvCacheTokens().set(70); + WorkerStatus worker = createAliveWorkerStatus(); + worker.getTotalKvCacheTokens().set(100); worker.getAvailableKvCacheTokens().set(30); worker.setRunningTaskList(taskMap(1L)); assertEquals(75.0, measure.calculateAverageWaterLevel(Map.of("worker", worker))); } - private WorkerStatus createAliveDecodeWorker() { + private DecodeEndpoint createAliveDecodeEndpoint() { + WorkerStatus status = createAliveWorkerStatus(); + return new DecodeEndpoint(status); + } + + private WorkerStatus createAliveWorkerStatus() { WorkerStatus worker = new WorkerStatus(); worker.setAlive(true); - worker.getUsedKvCacheTokens().set(0); + worker.getTotalKvCacheTokens().set(100); worker.getAvailableKvCacheTokens().set(100); return worker; } diff --git a/rtp_llm/flexlb/flexlb-sync/src/test/java/org/flexlb/balance/resource/DynamicWorkerManagerTest.java b/rtp_llm/flexlb/flexlb-sync/src/test/java/org/flexlb/balance/resource/DynamicWorkerManagerTest.java index 9ccd083606..907202f300 100644 --- a/rtp_llm/flexlb/flexlb-sync/src/test/java/org/flexlb/balance/resource/DynamicWorkerManagerTest.java +++ b/rtp_llm/flexlb/flexlb-sync/src/test/java/org/flexlb/balance/resource/DynamicWorkerManagerTest.java @@ -115,7 +115,6 @@ private void setupWorkers(double waterLevel) { WorkerStatus ws = new WorkerStatus(); ws.setAlive(true); ws.setAvailableKvCacheTokens(new AtomicLong(100)); - ws.setUsedKvCacheTokens(new AtomicLong(0)); EngineWorkerStatus.MODEL_ROLE_WORKER_STATUS.getPrefillStatusMap().put("127.0.0.1:8080", ws); lenient().when(resourceMeasureFactory.getMeasure(any())).thenReturn(resourceMeasure); diff --git a/rtp_llm/flexlb/flexlb-sync/src/test/java/org/flexlb/balance/resource/PrefillResourceMeasureTest.java b/rtp_llm/flexlb/flexlb-sync/src/test/java/org/flexlb/balance/resource/PrefillResourceMeasureTest.java index b087f7c5fe..8aa77b84b1 100644 --- a/rtp_llm/flexlb/flexlb-sync/src/test/java/org/flexlb/balance/resource/PrefillResourceMeasureTest.java +++ b/rtp_llm/flexlb/flexlb-sync/src/test/java/org/flexlb/balance/resource/PrefillResourceMeasureTest.java @@ -5,6 +5,7 @@ import org.flexlb.dao.master.TaskInfo; import org.flexlb.dao.master.WorkerStatus; import org.flexlb.dao.route.RoleType; +import org.flexlb.enums.TaskPhase; import org.flexlb.enums.TaskStateEnum; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -16,10 +17,18 @@ import java.util.Map; import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.Mockito.when; +/** + * Tests for {@link PrefillResourceMeasure}. + * + *

Since {@code isResourceAvailable(PrefillEndpoint)} depends on endpoint-level + * state ({@code realPendingCount()}), it is tested in integration tests + * ({@code FlexlbBatchSchedulerTest}). This unit test focuses on + * {@code calculateAverageWaterLevel} using {@code WorkerStatus.runningTaskList} + * with {@link TaskPhase} to distinguish running vs waiting tasks. + */ @ExtendWith(MockitoExtension.class) class PrefillResourceMeasureTest { @@ -37,50 +46,107 @@ void setUp() { } @Test - void local_pending_tasks_should_make_prefill_worker_unavailable() { + void pending_and_received_tasks_contribute_to_water_level() { + // Non-RUNNING tasks (PENDING, RECEIVED, KV_ALLOCATED) are counted as waiting. + // With maxQueueSize=10, 3 waiting tasks → water level = 30% PrefillResourceMeasure measure = new PrefillResourceMeasure(configService); WorkerStatus worker = createAlivePrefillWorker(); - worker.setWaitingTaskList(new HashMap<>()); - worker.getLocalTaskMap().put(1L, taskInfo(1L, TaskStateEnum.IN_TRANSIT)); - worker.getLocalTaskMap().put(2L, taskInfo(2L, TaskStateEnum.CONFIRMED)); + Map runningTaskList = new HashMap<>(); + runningTaskList.put("1", taskInfo(1L, TaskPhase.PENDING)); + runningTaskList.put("2", taskInfo(2L, TaskPhase.RECEIVED)); + runningTaskList.put("3", taskInfo(3L, TaskPhase.KV_ALLOCATED)); + worker.setRunningTaskList(runningTaskList); - assertFalse(measure.isResourceAvailable(worker)); + assertEquals(30.0, measure.calculateAverageWaterLevel(Map.of("worker", worker))); } @Test - void running_local_tasks_should_not_count_as_prefill_queue() { + void running_tasks_do_not_count_as_prefill_queue() { + // Only RUNNING tasks → water level = 0% PrefillResourceMeasure measure = new PrefillResourceMeasure(configService); WorkerStatus worker = createAlivePrefillWorker(); - worker.setWaitingTaskList(new HashMap<>()); - worker.getLocalTaskMap().put(1L, taskInfo(1L, TaskStateEnum.RUNNING)); - worker.getLocalTaskMap().put(2L, taskInfo(2L, TaskStateEnum.RUNNING)); + Map runningTaskList = new HashMap<>(); + runningTaskList.put("1", taskInfo(1L, TaskPhase.RUNNING)); + runningTaskList.put("2", taskInfo(2L, TaskPhase.RUNNING)); + worker.setRunningTaskList(runningTaskList); - assertTrue(measure.isResourceAvailable(worker)); + assertEquals(0.0, measure.calculateAverageWaterLevel(Map.of("worker", worker))); } @Test - void water_level_should_use_higher_value_between_engine_waiting_and_local_pending() { + void water_level_counts_all_non_running_tasks_from_engine_reported_list() { + // Engine reports a unified runningTaskList; + // tasks with phase != RUNNING are counted as waiting. + // PENDING + RECEIVED + KV_ALLOCATED = 3 waiting → 30% with maxQueueSize=10 PrefillResourceMeasure measure = new PrefillResourceMeasure(configService); WorkerStatus worker = createAlivePrefillWorker(); - worker.setWaitingTaskList(Map.of("1", taskInfo(1L, TaskStateEnum.CONFIRMED))); - worker.getLocalTaskMap().put(1L, taskInfo(1L, TaskStateEnum.IN_TRANSIT)); - worker.getLocalTaskMap().put(2L, taskInfo(2L, TaskStateEnum.CONFIRMED)); - worker.getLocalTaskMap().put(3L, taskInfo(3L, TaskStateEnum.CONFIRMED)); + Map runningTaskList = new HashMap<>(); + runningTaskList.put("1", taskInfo(1L, TaskPhase.PENDING)); + runningTaskList.put("2", taskInfo(2L, TaskPhase.RECEIVED)); + runningTaskList.put("3", taskInfo(3L, TaskPhase.KV_ALLOCATED)); + runningTaskList.put("4", taskInfo(4L, TaskPhase.RUNNING)); + worker.setRunningTaskList(runningTaskList); + // 3 waiting out of max 10 = 30% assertEquals(30.0, measure.calculateAverageWaterLevel(Map.of("worker", worker))); } + @Test + void water_level_capped_at_100_when_queue_full() { + PrefillResourceMeasure measure = new PrefillResourceMeasure(configService); + WorkerStatus worker = createAlivePrefillWorker(); + Map runningTaskList = new HashMap<>(); + for (int i = 1; i <= 12; i++) { + runningTaskList.put(String.valueOf(i), taskInfo(i, TaskPhase.PENDING)); + } + worker.setRunningTaskList(runningTaskList); + + // 12 waiting > maxQueueSize=10 → capped at 100% + assertEquals(100.0, measure.calculateAverageWaterLevel(Map.of("worker", worker))); + } + + @Test + void empty_task_list_gives_zero_water_level() { + PrefillResourceMeasure measure = new PrefillResourceMeasure(configService); + WorkerStatus worker = createAlivePrefillWorker(); + worker.setRunningTaskList(new HashMap<>()); + + assertEquals(0.0, measure.calculateAverageWaterLevel(Map.of("worker", worker))); + } + + @Test + void null_task_list_gives_zero_water_level() { + PrefillResourceMeasure measure = new PrefillResourceMeasure(configService); + WorkerStatus worker = createAlivePrefillWorker(); + worker.setRunningTaskList(null); + + assertEquals(0.0, measure.calculateAverageWaterLevel(Map.of("worker", worker))); + } + + @Test + void all_running_tasks_gives_zero_water_level() { + PrefillResourceMeasure measure = new PrefillResourceMeasure(configService); + WorkerStatus worker = createAlivePrefillWorker(); + Map runningTaskList = new HashMap<>(); + runningTaskList.put("1", taskInfo(1L, TaskPhase.RUNNING)); + runningTaskList.put("2", taskInfo(2L, TaskPhase.RUNNING)); + runningTaskList.put("3", taskInfo(3L, TaskPhase.RUNNING)); + worker.setRunningTaskList(runningTaskList); + + assertEquals(0.0, measure.calculateAverageWaterLevel(Map.of("worker", worker))); + } + private WorkerStatus createAlivePrefillWorker() { WorkerStatus worker = new WorkerStatus(); worker.setAlive(true); - worker.setRole(RoleType.PREFILL.getCode()); + worker.setRole(RoleType.PREFILL); return worker; } - private TaskInfo taskInfo(long requestId, TaskStateEnum taskState) { + private TaskInfo taskInfo(long requestId, TaskPhase phase) { TaskInfo taskInfo = new TaskInfo(); taskInfo.setRequestId(requestId); - taskInfo.updateTaskState(taskState); + taskInfo.setPhase(phase); return taskInfo; } } diff --git a/rtp_llm/flexlb/flexlb-sync/src/test/java/org/flexlb/balance/scheduler/DefaultBatchDispatcherTest.java b/rtp_llm/flexlb/flexlb-sync/src/test/java/org/flexlb/balance/scheduler/DefaultBatchDispatcherTest.java new file mode 100644 index 0000000000..3872ba8b6d --- /dev/null +++ b/rtp_llm/flexlb/flexlb-sync/src/test/java/org/flexlb/balance/scheduler/DefaultBatchDispatcherTest.java @@ -0,0 +1,296 @@ +package org.flexlb.balance.scheduler; + +import com.google.protobuf.ByteString; +import com.google.protobuf.Int64Value; +import org.flexlb.balance.endpoint.PrefillEndpoint; +import org.flexlb.config.ConfigService; +import org.flexlb.config.FlexlbConfig; +import org.flexlb.dao.BalanceContext; +import org.flexlb.dao.loadbalance.DebugInfo; +import org.flexlb.dao.loadbalance.Request; +import org.flexlb.dao.loadbalance.Response; +import org.flexlb.dao.loadbalance.ServerStatus; +import org.flexlb.dao.master.WorkerStatus; +import org.flexlb.dao.route.RoleType; +import org.flexlb.engine.grpc.EngineGrpcClient; +import org.flexlb.engine.grpc.EngineRpcService; +import org.flexlb.service.monitor.BatchSchedulerReporter; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.util.List; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.ArgumentMatchers.anyLong; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +class DefaultBatchDispatcherTest { + + private ConfigService configService; + private EngineGrpcClient grpcClient; + private BatchSchedulerReporter reporter; + private FlexlbConfig config; + private DefaultBatchDispatcher dispatcher; + private TestCallback callback; + + @BeforeEach + void setUp() { + configService = mock(ConfigService.class); + grpcClient = mock(EngineGrpcClient.class); + reporter = mock(BatchSchedulerReporter.class); + config = new FlexlbConfig(); + config.setFlexlbBatchDispatchPoolSize(2); + config.setFlexlbBatchDispatchQueueSize(10); + config.setFlexlbBatchEnqueueDeadlineMs(5000); + when(configService.loadBalanceConfig()).thenReturn(config); + + dispatcher = new DefaultBatchDispatcher(grpcClient, configService); + callback = new TestCallback(); + } + + @Test + void dispatchSendsItemsToGrpcAndReceivesAck() throws Exception { + PrefillEndpoint prefillEp = createPrefillEndpoint(); + BatchItem item = createBatchItem(1L, 500, 200, prefillEp); + + EngineRpcService.EnqueueBatchResponsePB response = ackResponse(List.of(1L)); + when(grpcClient.batchEnqueue(anyString(), anyInt(), any(EngineRpcService.EnqueueBatchRequestPB.class), anyLong())) + .thenReturn(response); + + dispatcher.dispatch(List.of(item), prefillEp, 1L, 100, "test_reason", callback); + + assertTrue(callback.successLatch.await(5, TimeUnit.SECONDS), "onSuccess should be called"); + assertEquals(1, callback.successCount.get()); + assertEquals(0, callback.failureCount.get()); + } + + @Test + void dispatchHandlesGrpcError() throws Exception { + PrefillEndpoint prefillEp = createPrefillEndpoint(); + BatchItem item = createBatchItem(1L, 500, 200, prefillEp); + + when(grpcClient.batchEnqueue(anyString(), anyInt(), any(EngineRpcService.EnqueueBatchRequestPB.class), anyLong())) + .thenThrow(new RuntimeException("gRPC connection refused")); + + dispatcher.dispatch(List.of(item), prefillEp, 1L, 100, "test_reason", callback); + + assertTrue(callback.failureLatch.await(5, TimeUnit.SECONDS), "onFailure should be called"); + assertEquals(1, callback.failureCount.get()); + assertEquals(0, callback.successCount.get()); + } + + @Test + void dispatchHandlesNullGrpcResponse() throws Exception { + PrefillEndpoint prefillEp = createPrefillEndpoint(); + BatchItem item = createBatchItem(1L, 500, 200, prefillEp); + + when(grpcClient.batchEnqueue(anyString(), anyInt(), any(EngineRpcService.EnqueueBatchRequestPB.class), anyLong())) + .thenReturn(null); + + dispatcher.dispatch(List.of(item), prefillEp, 1L, 100, "test_reason", callback); + + assertTrue(callback.failureLatch.await(5, TimeUnit.SECONDS)); + assertEquals(1, callback.failureCount.get()); + } + + @Test + void dispatchFiltersCancelledItemsBeforeSend() throws Exception { + PrefillEndpoint prefillEp = createPrefillEndpoint(); + BatchItem active = createBatchItem(1L, 500, 200, prefillEp); + BatchItem cancelled = createBatchItem(2L, 300, 100, prefillEp); + cancelled.ctx().cancel(); // mark as cancelled + + AtomicReference captured = new AtomicReference<>(); + when(grpcClient.batchEnqueue(anyString(), anyInt(), any(EngineRpcService.EnqueueBatchRequestPB.class), anyLong())) + .thenAnswer(inv -> { + captured.set(inv.getArgument(2)); + return ackResponse(List.of(1L)); + }); + + dispatcher.dispatch(List.of(active, cancelled), prefillEp, 1L, 100, "test", callback); + + assertTrue(callback.successLatch.await(5, TimeUnit.SECONDS)); + EngineRpcService.EnqueueBatchRequestPB sent = captured.get(); + assertNotNull(sent); + // Only 1 request should be in the batch (the cancelled one filtered out) + long sentCount = sent.getDpSlotsList().stream() + .mapToLong(slot -> slot.getRequestsCount()) + .sum(); + assertEquals(1, sentCount); + } + + @Test + void dispatchHandlesRejectedExecutionAfterShutdown() { + dispatcher.shutdown(); + + PrefillEndpoint prefillEp = createPrefillEndpoint(); + BatchItem item = createBatchItem(1L, 500, 200, prefillEp); + + dispatcher.dispatch(List.of(item), prefillEp, 1L, 100, "test", callback); + + // Should fail synchronously when executor is shut down + assertEquals(1, callback.failureCount.get()); + } + + @Test + void dispatchHandlesResponseWithErrors() throws Exception { + PrefillEndpoint prefillEp = createPrefillEndpoint(); + BatchItem item = createBatchItem(1L, 500, 200, prefillEp); + + EngineRpcService.EnqueueBatchResponsePB response = + EngineRpcService.EnqueueBatchResponsePB.newBuilder() + .addErrors(EngineRpcService.EnqueueBatchErrorPB.newBuilder() + .setRequestId(1L) + .setErrorInfo(EngineRpcService.ErrorDetailsPB.newBuilder() + .setErrorCode(500) + .setErrorMessage("engine busy") + .build()) + .build()) + .build(); + when(grpcClient.batchEnqueue(anyString(), anyInt(), any(EngineRpcService.EnqueueBatchRequestPB.class), anyLong())) + .thenReturn(response); + + dispatcher.dispatch(List.of(item), prefillEp, 1L, 100, "test", callback); + + assertTrue(callback.failureLatch.await(5, TimeUnit.SECONDS)); + assertEquals(1, callback.failureCount.get()); + } + + @Test + void dispatchHandlesMissingAck() throws Exception { + PrefillEndpoint prefillEp = createPrefillEndpoint(); + BatchItem item = createBatchItem(1L, 500, 200, prefillEp); + + EngineRpcService.EnqueueBatchResponsePB response = + EngineRpcService.EnqueueBatchResponsePB.newBuilder().build(); // no success, no error + when(grpcClient.batchEnqueue(anyString(), anyInt(), any(EngineRpcService.EnqueueBatchRequestPB.class), anyLong())) + .thenReturn(response); + + dispatcher.dispatch(List.of(item), prefillEp, 1L, 100, "test", callback); + + assertTrue(callback.failureLatch.await(5, TimeUnit.SECONDS)); + assertEquals(1, callback.failureCount.get()); + } + + @Test + void shutdownDrainsExecutor() throws Exception { + PrefillEndpoint prefillEp = createPrefillEndpoint(); + + // Submit tasks so executor has work in flight + CountDownLatch started = new CountDownLatch(1); + when(grpcClient.batchEnqueue(anyString(), anyInt(), any(EngineRpcService.EnqueueBatchRequestPB.class), anyLong())) + .thenAnswer(inv -> { + started.countDown(); + return ackResponse(List.of(1L)); + }); + + BatchItem item = createBatchItem(1L, 500, 200, prefillEp); + dispatcher.dispatch(List.of(item), prefillEp, 1L, 100, "test", callback); + + // Wait for at least one task to start, then shutdown + assertTrue(started.await(5, TimeUnit.SECONDS)); + dispatcher.shutdown(); + + // Post-shutdown dispatch should be rejected immediately + int failuresBefore = callback.failureCount.get(); + BatchItem extra = createBatchItem(99L, 500, 200, prefillEp); + dispatcher.dispatch(List.of(extra), prefillEp, 99L, 100, "test", callback); + assertEquals(failuresBefore + 1, callback.failureCount.get(), "Post-shutdown dispatch should add exactly 1 failure"); + } + + // ---- helpers ---- + + private PrefillEndpoint createPrefillEndpoint() { + WorkerStatus status = new WorkerStatus(); + status.setIp("127.0.0.1"); + status.setPort(8080); + status.setGrpcPort(8090); + status.setRole(RoleType.PREFILL); + FlexlbConfig epConfig = new FlexlbConfig(); + epConfig.setFlexlbBatchQueueMaxSize(100); + epConfig.setFlexlbBatchFixedWaitMs(300); + return new PrefillEndpoint(status, epConfig, noopHandler(), reporter); + } + + private static BatchDecisionHandler noopHandler() { + return new BatchDecisionHandler() { + @Override public void onExpired(BatchItem head) {} + @Override public void onUrgent(BatchItem head, DispatchMeta meta) {} + @Override public void onBatchReady(List items, DispatchMeta meta) {} + @Override public void onOfferFailure(BatchItem item, Throwable error) {} + }; + } + + private BatchItem createBatchItem(long requestId, long seqLen, long hitCacheLen, PrefillEndpoint prefillEp) { + Request request = new Request(); + request.setRequestId(requestId); + request.setSeqLen(seqLen); + + BalanceContext ctx = new BalanceContext(); + ctx.setRequest(request); + + // Provide a valid GenerateInputPB bytes (minimum: requestId + empty config) + EngineRpcService.GenerateInputPB input = EngineRpcService.GenerateInputPB.newBuilder() + .setRequestId(requestId) + .setGroupId(Int64Value.of(1L)) + .setGroupSize(1) + .setGenerateConfig(EngineRpcService.GenerateConfigPB.newBuilder().build()) + .build(); + ctx.setGenerateInputPbBytes(input.toByteArray()); + + ServerStatus prefill = new ServerStatus(); + prefill.setRole(RoleType.PREFILL); + prefill.setServerIp("127.0.0.1"); + prefill.setHttpPort(8080); + prefill.setGrpcPort(8090); + prefill.setDpRank(0L); + DebugInfo debugInfo = new DebugInfo(); + debugInfo.setHitCacheLen(hitCacheLen); + prefill.setDebugInfo(debugInfo); + + return new BatchItem(ctx, new CompletableFuture<>(), null, prefill, null, prefillEp, null, 0, System.currentTimeMillis()); + } + + private EngineRpcService.EnqueueBatchResponsePB ackResponse(List successIds) { + EngineRpcService.EnqueueBatchResponsePB.Builder builder = + EngineRpcService.EnqueueBatchResponsePB.newBuilder(); + for (long id : successIds) { + builder.addSuccesses(EngineRpcService.EnqueueBatchSuccessPB.newBuilder() + .setRequestId(id) + .build()); + } + return builder.build(); + } + + // ---- Test callback ---- + + private static class TestCallback implements DispatchCallback { + final AtomicInteger successCount = new AtomicInteger(0); + final AtomicInteger failureCount = new AtomicInteger(0); + final CountDownLatch successLatch = new CountDownLatch(1); + final CountDownLatch failureLatch = new CountDownLatch(1); + + @Override + public void onSuccess(BatchItem item, long batchId) { + successCount.incrementAndGet(); + successLatch.countDown(); + } + + @Override + public void onFailure(BatchItem item, Throwable error) { + failureCount.incrementAndGet(); + failureLatch.countDown(); + } + } +} diff --git a/rtp_llm/flexlb/flexlb-sync/src/test/java/org/flexlb/balance/scheduler/DefaultRouterTest.java b/rtp_llm/flexlb/flexlb-sync/src/test/java/org/flexlb/balance/scheduler/DefaultRouterTest.java index c1cead6968..fb2c963ff7 100644 --- a/rtp_llm/flexlb/flexlb-sync/src/test/java/org/flexlb/balance/scheduler/DefaultRouterTest.java +++ b/rtp_llm/flexlb/flexlb-sync/src/test/java/org/flexlb/balance/scheduler/DefaultRouterTest.java @@ -21,7 +21,6 @@ import org.mockito.junit.jupiter.MockitoExtension; import java.lang.reflect.Field; -import java.util.List; import java.util.Map; import static org.junit.jupiter.api.Assertions.assertEquals; @@ -79,21 +78,21 @@ void setUp() { // Mock config service when(configService.loadBalanceConfig()).thenReturn(loadBalanceConfig); - lenient().when(loadBalanceConfig.getLoadBalanceStrategy()).thenReturn(LoadBalanceStrategyEnum.SHORTEST_TTFT); + lenient().when(loadBalanceConfig.getLoadBalanceStrategy()).thenReturn(LoadBalanceStrategyEnum.COST_BASED_PREFILL); when(loadBalanceConfig.getStrategyForRoleType(any(RoleType.class))).thenAnswer(inv -> { RoleType roleType = inv.getArgument(0); if (roleType == RoleType.DECODE) { - return LoadBalanceStrategyEnum.WEIGHTED_CACHE; + return LoadBalanceStrategyEnum.COST_BASED_DECODE; } if (roleType == RoleType.PDFUSION) { return LoadBalanceStrategyEnum.RANDOM; } - return LoadBalanceStrategyEnum.SHORTEST_TTFT; + return LoadBalanceStrategyEnum.COST_BASED_PREFILL; }); - LoadBalanceStrategyFactory.register(LoadBalanceStrategyEnum.SHORTEST_TTFT, prefillLoadBalancer); - LoadBalanceStrategyFactory.register(LoadBalanceStrategyEnum.WEIGHTED_CACHE, decodeLoadBalancer); - LoadBalanceStrategyFactory.register(LoadBalanceStrategyEnum.SHORTEST_TTFT, vitLoadBalancer); + LoadBalanceStrategyFactory.register(LoadBalanceStrategyEnum.COST_BASED_PREFILL, prefillLoadBalancer); + LoadBalanceStrategyFactory.register(LoadBalanceStrategyEnum.COST_BASED_DECODE, decodeLoadBalancer); + LoadBalanceStrategyFactory.register(LoadBalanceStrategyEnum.COST_BASED_PREFILL, vitLoadBalancer); LoadBalanceStrategyFactory.register(LoadBalanceStrategyEnum.RANDOM, fusionLoadBalancer); // Create scheduler instance diff --git a/rtp_llm/flexlb/flexlb-sync/src/test/java/org/flexlb/balance/scheduler/FixedWindowBatcherAlgorithmTest.java b/rtp_llm/flexlb/flexlb-sync/src/test/java/org/flexlb/balance/scheduler/FixedWindowBatcherAlgorithmTest.java new file mode 100644 index 0000000000..df26bc876b --- /dev/null +++ b/rtp_llm/flexlb/flexlb-sync/src/test/java/org/flexlb/balance/scheduler/FixedWindowBatcherAlgorithmTest.java @@ -0,0 +1,108 @@ +package org.flexlb.balance.scheduler; + +import org.flexlb.config.FlexlbConfig; +import org.flexlb.service.monitor.BatchSchedulerReporter; +import org.junit.jupiter.api.Test; + +import java.util.Comparator; +import java.util.concurrent.PriorityBlockingQueue; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.Mockito.mock; + +/** + * Tests for {@link FixedWindowBatcherAlgorithm#headWaitMs(BatcherContext)}. + * + *

Verifies that the algorithm computes head wait time using the + * fixed-window semantics ({@code fixedWaitMs - elapsedMs}) rather + * than leaking sortKey-as-deadline assumptions. + */ +class FixedWindowBatcherAlgorithmTest { + + @Test + void headWaitMsReturnsRemainingWindowTime() { + FixedWindowBatcherAlgorithm algo = new FixedWindowBatcherAlgorithm(); + + FlexlbConfig config = new FlexlbConfig(); + config.setFlexlbBatchFixedWaitMs(300L); + + long now = System.currentTimeMillis(); + BatchItem head = enqueuedItem(1L, now - 200); + PriorityBlockingQueue queue = queueWith(head); + + BatcherContext ctx = new BatcherContext("test", null, config, null, queue, mock(BatchSchedulerReporter.class)); + + // Head enqueued 200ms ago, window=300ms → ~100ms remaining (±5ms tolerance for timing) + long waitMs = algo.headWaitMs(ctx); + assertTrue(waitMs >= 95 && waitMs <= 100, + "Expected ~100ms remaining, got " + waitMs); + } + + @Test + void headWaitMsReturnsZeroWhenQueueEmpty() { + FixedWindowBatcherAlgorithm algo = new FixedWindowBatcherAlgorithm(); + + FlexlbConfig config = new FlexlbConfig(); + PriorityBlockingQueue queue = new PriorityBlockingQueue<>(11, Comparator.comparingLong(BatchItem::sortKey)); + + BatcherContext ctx = new BatcherContext("test", null, config, null, queue, mock(BatchSchedulerReporter.class)); + + assertEquals(0, algo.headWaitMs(ctx)); + } + + @Test + void headWaitMsReturnsZeroWhenElapsedExceedsWindow() { + FixedWindowBatcherAlgorithm algo = new FixedWindowBatcherAlgorithm(); + + FlexlbConfig config = new FlexlbConfig(); + config.setFlexlbBatchFixedWaitMs(300L); + + long now = System.currentTimeMillis(); + // Enqueued 500ms ago → elapsed > fixedWaitMs + BatchItem head = enqueuedItem(1L, now - 500); + PriorityBlockingQueue queue = queueWith(head); + + BatcherContext ctx = new BatcherContext("test", null, config, null, queue, mock(BatchSchedulerReporter.class)); + + assertEquals(0, algo.headWaitMs(ctx)); + } + + @Test + void headWaitMsDiffersFromSortKeyBasedDefault() { + // The default BatcherAlgorithm.headWaitMs() treats sortKey as deadline. + // FixedWindow must NOT follow that pattern — its sortKey is enqueuedAtMs (past). + FixedWindowBatcherAlgorithm algo = new FixedWindowBatcherAlgorithm(); + + FlexlbConfig config = new FlexlbConfig(); + config.setFlexlbBatchFixedWaitMs(300L); + + long now = System.currentTimeMillis(); + BatchItem head = enqueuedItem(1L, now - 200); + // Default headWaitMs would compute: sortKey - now = (now-200) - now = -200 → 0 + // FixedWindow headWaitMs computes: fixedWaitMs - elapsed = 300 - 200 = 100 + PriorityBlockingQueue queue = queueWith(head); + + BatcherContext ctx = new BatcherContext("test", null, config, null, queue, mock(BatchSchedulerReporter.class)); + + long fixedWaitMs = algo.headWaitMs(ctx); + assertTrue(fixedWaitMs >= 95 && fixedWaitMs <= 100, + "FixedWindow should return ~100ms remaining, got " + fixedWaitMs); + } + + // ---- helpers ---- + + private static BatchItem enqueuedItem(long requestId, long enqueuedAtMs) { + BatchItem item = new BatchItem(null, null, null, null, null, null, null, 0, enqueuedAtMs); + item.setSortKey(enqueuedAtMs); // FixedWindow: sortKey = enqueuedAtMs + return item; + } + + private static PriorityBlockingQueue queueWith(BatchItem... items) { + PriorityBlockingQueue queue = new PriorityBlockingQueue<>(11, Comparator.comparingLong(BatchItem::sortKey)); + for (BatchItem item : items) { + queue.add(item); + } + return queue; + } +} diff --git a/rtp_llm/flexlb/flexlb-sync/src/test/java/org/flexlb/balance/scheduler/FlexlbBatchSchedulerTest.java b/rtp_llm/flexlb/flexlb-sync/src/test/java/org/flexlb/balance/scheduler/FlexlbBatchSchedulerTest.java new file mode 100644 index 0000000000..5e1a35e72b --- /dev/null +++ b/rtp_llm/flexlb/flexlb-sync/src/test/java/org/flexlb/balance/scheduler/FlexlbBatchSchedulerTest.java @@ -0,0 +1,624 @@ +package org.flexlb.balance.scheduler; + +import org.flexlb.balance.endpoint.DecodeEndpoint; +import org.flexlb.balance.endpoint.EndpointRegistry; +import org.flexlb.balance.endpoint.PrefillEndpoint; +import org.flexlb.config.ConfigService; +import org.flexlb.config.FlexlbConfig; +import org.flexlb.dao.BalanceContext; +import org.flexlb.dao.loadbalance.Request; +import org.flexlb.dao.loadbalance.Response; +import org.flexlb.dao.loadbalance.ServerStatus; +import org.flexlb.dao.loadbalance.StrategyErrorType; +import org.flexlb.dao.master.WorkerStatus; +import org.flexlb.dao.route.RoleType; +import org.flexlb.engine.grpc.EngineGrpcClient; +import org.flexlb.engine.grpc.EngineRpcService; +import org.flexlb.service.monitor.BatchSchedulerReporter; +import org.flexlb.sync.status.EngineWorkerStatus; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.util.List; +import java.util.Map; +import java.util.concurrent.CancellationException; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CopyOnWriteArrayList; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.ArgumentMatchers.anyLong; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.atLeastOnce; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +class FlexlbBatchSchedulerTest { + + private ConfigService configService; + private Router router; + private EngineGrpcClient grpcClient; + private EngineWorkerStatus engineWorkerStatus; + private BatchSchedulerReporter reporter; + private FlexlbBatchScheduler scheduler; + private FlexlbConfig config; + private final List sentBatches = new CopyOnWriteArrayList<>(); + private final List sentEndpoints = new CopyOnWriteArrayList<>(); + + @BeforeEach + void setUp() { + configService = mock(ConfigService.class); + router = mock(Router.class); + grpcClient = mock(EngineGrpcClient.class); + engineWorkerStatus = mock(EngineWorkerStatus.class); + reporter = mock(BatchSchedulerReporter.class); + + config = new FlexlbConfig(); + config.setScheduleWorkerSize(1); + config.setFlexlbBatchSizeMax(2); + config.setFlexlbBatchWindowMs(10_000); + config.setCostSloMs(50000L); + config.setCostSloRiskMarginMs(50L); + config.setFlexlbBatchFillThreshold(1.0); + when(configService.loadBalanceConfig()).thenReturn(config); + + when(router.route(any(BalanceContext.class))).thenAnswer(inv -> { + BalanceContext ctx = inv.getArgument(0); + return successRoute(ctx.getRequestId()); + }); + when(grpcClient.batchEnqueue(anyString(), anyInt(), any(EngineRpcService.EnqueueBatchRequestPB.class), anyLong())) + .thenAnswer(inv -> { + sentEndpoints.add(inv.getArgument(0) + ":" + inv.getArgument(1)); + EngineRpcService.EnqueueBatchRequestPB request = inv.getArgument(2); + sentBatches.add(request); + return ackFor(request); + }); + when(grpcClient.cancel(anyString(), anyInt(), anyLong(), anyLong())) + .thenReturn(EngineRpcService.EmptyPB.getDefaultInstance()); + + EndpointRegistry endpointRegistry = new EndpointRegistry(configService, null, reporter); + BatchDispatcher dispatcher = new DefaultBatchDispatcher(grpcClient, configService); + scheduler = new FlexlbBatchScheduler(configService, router, grpcClient, engineWorkerStatus, + endpointRegistry, dispatcher, reporter); + + // Create endpoint and batcher for the worker that successRoute() returns + String ipPort = "10.0.0.1:8080"; + WorkerStatus ws = new WorkerStatus(); + ws.setIp("10.0.0.1"); + ws.setPort(8080); + ws.setGrpcPort(9080); + PrefillEndpoint endpoint = new PrefillEndpoint(ws, config, scheduler, reporter); + ServerStatus prefill = new ServerStatus(); + prefill.setServerIp("10.0.0.1"); + prefill.setHttpPort(8080); + prefill.setGrpcPort(9080); + prefill.setRole(RoleType.PREFILL); + endpointRegistry.putPrefill(ipPort, endpoint); + } + + @AfterEach + void tearDown() { + scheduler.shutdown(); + } + + @Test + void submit_flushes_grouped_requests_with_force_batch_payload() throws Exception { + CompletableFuture first = scheduler.submit(context(1)); + assertFalse(first.isDone()); + + CompletableFuture second = scheduler.submit(context(2)); + + Response firstResponse = first.get(2, TimeUnit.SECONDS); + Response secondResponse = second.get(2, TimeUnit.SECONDS); + assertTrue(firstResponse.isSuccess()); + assertTrue(secondResponse.isSuccess()); + assertTrue(firstResponse.isEnqueuedByMaster()); + assertTrue(secondResponse.isEnqueuedByMaster()); + + assertEquals(1, sentBatches.size()); + EngineRpcService.EnqueueBatchRequestPB batch = sentBatches.getFirst(); + List inputs = batchInputs(batch); + assertEquals(1, batch.getDpSlotsCount()); + assertEquals(0, batch.getDpSlots(0).getDpRank()); + assertEquals(2, batch.getDpSlots(0).getRequestsCount()); + assertEquals(2, inputs.size()); + assertEquals(2, inputs.get(0).getGroupSize()); + assertEquals(batch.getBatchId(), inputs.get(0).getGroupId().getValue()); + assertEquals(batch.getBatchId(), inputs.get(1).getGroupId().getValue()); + assertEquals(77, inputs.get(0).getGenerateConfig().getGroupTimeout().getValue()); + assertEquals(2, inputs.get(0).getGenerateConfig().getRoleAddrsCount()); + assertEquals(EngineRpcService.RoleTypePB.ROLE_TYPE_PREFILL, + inputs.get(0).getGenerateConfig().getRoleAddrs(0).getRole()); + assertEquals(EngineRpcService.RoleTypePB.ROLE_TYPE_DECODE, + inputs.get(0).getGenerateConfig().getRoleAddrs(1).getRole()); + } + + @Test + void submit_groups_batch_payload_by_dp_rank() throws Exception { + when(router.route(any(BalanceContext.class))).thenAnswer(inv -> { + BalanceContext ctx = inv.getArgument(0); + long requestId = ctx.getRequestId(); + return successRouteWithPrefillDp(requestId, requestId == 71L ? 0 : 1); + }); + + CompletableFuture first = scheduler.submit(context(71)); + CompletableFuture second = scheduler.submit(context(72)); + + assertTrue(first.get(2, TimeUnit.SECONDS).isSuccess()); + assertTrue(second.get(2, TimeUnit.SECONDS).isSuccess()); + + assertEquals(1, sentBatches.size()); + EngineRpcService.EnqueueBatchRequestPB batch = sentBatches.getFirst(); + assertEquals(2, batch.getDpSlotsCount()); + assertEquals(0, batch.getDpSlots(0).getDpRank()); + assertEquals(1, batch.getDpSlots(1).getDpRank()); + assertEquals(1, batch.getDpSlots(0).getRequestsCount()); + assertEquals(1, batch.getDpSlots(1).getRequestsCount()); + } + + @Test + void batch_enqueue_error_list_fails_only_rejected_request() throws Exception { + // Use request IDs to match, not input positions + when(grpcClient.batchEnqueue(anyString(), anyInt(), any(EngineRpcService.EnqueueBatchRequestPB.class), anyLong())) + .thenAnswer(inv -> { + sentEndpoints.add(inv.getArgument(0) + ":" + inv.getArgument(1)); + EngineRpcService.EnqueueBatchRequestPB request = inv.getArgument(2); + sentBatches.add(request); + + EngineRpcService.EnqueueBatchResponsePB.Builder response = + EngineRpcService.EnqueueBatchResponsePB.newBuilder().setBatchId(request.getBatchId()); + + for (EngineRpcService.GenerateInputPB input : batchInputs(request)) { + long reqId = input.getRequestId(); + if (reqId == 81) { + response.addSuccesses(EngineRpcService.EnqueueBatchSuccessPB.newBuilder() + .setRequestId(reqId).build()); + } else { + response.addErrors(EngineRpcService.EnqueueBatchErrorPB.newBuilder() + .setRequestId(reqId) + .setErrorInfo(EngineRpcService.ErrorDetailsPB.newBuilder() + .setErrorCode(13) + .setErrorMessage("decode alloc failed") + .build()) + .build()); + } + } + return response.build(); + }); + + CompletableFuture first = scheduler.submit(context(81)); + CompletableFuture second = scheduler.submit(context(82)); + + assertTrue(first.get(2, TimeUnit.SECONDS).isSuccess()); + assertFalse(second.get(2, TimeUnit.SECONDS).isSuccess()); + } + + @Test + void batch_enqueue_missing_success_fails_missing_request() throws Exception { + // Only return success for request 83, missing ack for 84 + when(grpcClient.batchEnqueue(anyString(), anyInt(), any(EngineRpcService.EnqueueBatchRequestPB.class), anyLong())) + .thenAnswer(inv -> { + sentEndpoints.add(inv.getArgument(0) + ":" + inv.getArgument(1)); + EngineRpcService.EnqueueBatchRequestPB request = inv.getArgument(2); + sentBatches.add(request); + + EngineRpcService.EnqueueBatchResponsePB.Builder response = + EngineRpcService.EnqueueBatchResponsePB.newBuilder().setBatchId(request.getBatchId()); + + for (EngineRpcService.GenerateInputPB input : batchInputs(request)) { + if (input.getRequestId() == 83) { + response.addSuccesses(EngineRpcService.EnqueueBatchSuccessPB.newBuilder() + .setRequestId(83).build()); + } + } + return response.build(); + }); + + CompletableFuture first = scheduler.submit(context(83)); + CompletableFuture second = scheduler.submit(context(84)); + + assertTrue(first.get(2, TimeUnit.SECONDS).isSuccess()); + Response secondResp = second.get(2, TimeUnit.SECONDS); + assertFalse(secondResp.isSuccess()); + assertTrue(secondResp.getErrorMessage().contains("EnqueueBatch missing ack for request 84")); + } + + @Test + void dispatch_falls_back_to_selected_prefill_when_dp0_status_not_synced() throws Exception { + config.setFlexlbBatchSizeMax(1); + when(router.route(any(BalanceContext.class))).thenAnswer(inv -> successRouteWithPrefillDp(92, 1)); + + WorkerStatus unsyncedDp0 = new WorkerStatus(); + unsyncedDp0.setIp("10.0.0.9"); + unsyncedDp0.setPort(8090); + unsyncedDp0.setGrpcPort(9090); + unsyncedDp0.setDpRank(0); + PrefillEndpoint unsyncedEp = new PrefillEndpoint(unsyncedDp0, config, scheduler, reporter); + when(engineWorkerStatus.selectModelWorkerStatus(RoleType.PREFILL, "g1")) + .thenReturn(Map.of("10.0.0.9:8090", unsyncedEp)); + + Response response = scheduler.submit(context(92)).get(2, TimeUnit.SECONDS); + + assertTrue(response.isSuccess()); + assertEquals("10.0.0.1:9080", sentEndpoints.getFirst()); + assertEquals(1, sentBatches.getFirst().getDpSlots(0).getDpRank()); + } + + @Test + void cancel_removes_request_before_batch_enqueue() throws Exception { + CompletableFuture future = scheduler.submit(context(11)); + + scheduler.cancel(11L); + + assertTrue(future.isDone()); + Response response = future.get(1, TimeUnit.SECONDS); + assertFalse(response.isSuccess()); + assertEquals(StrategyErrorType.REQUEST_CANCELLED.getErrorCode(), response.getCode()); + verify(grpcClient, never()).batchEnqueue(anyString(), anyInt(), any(), anyLong()); + } + + @Test + void cancel_inflight_before_ack_completes_cancelled_and_sends_engine_cancel() throws Exception { + config.setFlexlbBatchSizeMax(1); + CountDownLatch batchStarted = new CountDownLatch(1); + CountDownLatch cancelSeen = new CountDownLatch(1); + + when(grpcClient.batchEnqueue(anyString(), anyInt(), any(EngineRpcService.EnqueueBatchRequestPB.class), anyLong())) + .thenAnswer(inv -> { + EngineRpcService.EnqueueBatchRequestPB request = inv.getArgument(2); + sentBatches.add(request); + batchStarted.countDown(); + assertTrue(cancelSeen.await(2, TimeUnit.SECONDS)); + return ackFor(request); + }); + when(grpcClient.cancel(anyString(), anyInt(), anyLong(), anyLong())) + .thenAnswer(inv -> { + cancelSeen.countDown(); + return EngineRpcService.EmptyPB.getDefaultInstance(); + }); + + CompletableFuture future = scheduler.submit(context(12)); + + assertTrue(batchStarted.await(2, TimeUnit.SECONDS)); + scheduler.cancel(12L); + + Response response = future.get(2, TimeUnit.SECONDS); + assertFalse(response.isSuccess()); + assertEquals(StrategyErrorType.REQUEST_CANCELLED.getErrorCode(), response.getCode()); + verify(grpcClient, atLeastOnce()).cancel(anyString(), anyInt(), anyLong(), anyLong()); + } + + @Test + void route_failure_completes_without_batch_enqueue() throws Exception { + Response failure = Response.error(StrategyErrorType.NO_PREFILL_WORKER); + when(router.route(any(BalanceContext.class))).thenReturn(failure); + + Response response = scheduler.submit(context(21)).get(1, TimeUnit.SECONDS); + + assertFalse(response.isSuccess()); + assertEquals(StrategyErrorType.NO_PREFILL_WORKER.getErrorCode(), response.getCode()); + verify(grpcClient, never()).batchEnqueue(anyString(), anyInt(), any(), anyLong()); + } + + @Test + void submit_rejects_when_global_inflight_limit_reached() throws Exception { + config.setFlexlbBatchSizeMax(1); + config.setFlexlbBatchMaxInflight(1); + + CountDownLatch batchBlocked = new CountDownLatch(1); + CountDownLatch releaseBlock = new CountDownLatch(1); + when(grpcClient.batchEnqueue(anyString(), anyInt(), any(EngineRpcService.EnqueueBatchRequestPB.class), anyLong())) + .thenAnswer(inv -> { + batchBlocked.countDown(); + assertTrue(releaseBlock.await(5, TimeUnit.SECONDS)); + EngineRpcService.EnqueueBatchRequestPB request = inv.getArgument(2); + return ackFor(request); + }); + + scheduler.submit(context(41)); + assertTrue(batchBlocked.await(2, TimeUnit.SECONDS)); + + Response rejected = scheduler.submit(context(42)).get(1, TimeUnit.SECONDS); + assertFalse(rejected.isSuccess()); + assertEquals(StrategyErrorType.QUEUE_FULL.getErrorCode(), rejected.getCode()); + + releaseBlock.countDown(); + } + + @Test + void batcher_rejects_when_queue_full() throws Exception { + config.setFlexlbBatchQueueMaxSize(1); + config.setFlexlbBatchFillThreshold(1.0); + + CompletableFuture first = scheduler.submit(context(51)); + assertFalse(first.isDone()); + + // Second submit should fail because queue is full (maxSize=1) + CompletableFuture second = scheduler.submit(context(52)); + Response response = second.get(1, TimeUnit.SECONDS); + assertFalse(response.isSuccess()); + } + + @Test + void processQueue_park_converges_to_urgent_dispatch() throws Exception { + // budget = sloMs(300) - predMs(128) = 172ms, margin = 100ms + // fillThreshold=2.0 → fillRatio can never reach it (max 1.0) + // batchSizeMax=1000 → single request can't trigger size condition + // So request parks, budget shrinks each 1ms iteration, after ~72ms budget < margin → urgent dispatch + config.setCostSloMs(300L); + config.setCostSloRiskMarginMs(100L); + config.setFlexlbBatchFillThreshold(2.0); + config.setFlexlbBatchSizeMax(1000); + + CompletableFuture future = scheduler.submit(context(901)); + + assertTrue(future.get(2, TimeUnit.SECONDS).isSuccess()); + assertEquals(1, sentBatches.size()); + assertEquals(1, batchInputs(sentBatches.getFirst()).size()); + } + + @Test + void processQueue_fillRatio_triggers_dispatch() throws Exception { + // budget = sloMs(500) - predMs(128) = 372ms, margin = 50ms + // fillRatio = 128/322 ≈ 0.40 >= threshold(0.3) → dispatches immediately via fillRatio + // batchSizeMax=1000 ensures size condition is NOT the trigger + config.setCostSloMs(500L); + config.setCostSloRiskMarginMs(50L); + config.setFlexlbBatchMaxCapacity(500); + config.setFlexlbBatchFillThreshold(0.3); + config.setFlexlbBatchSizeMax(1000); + + CompletableFuture future = scheduler.submit(context(1001)); + + assertTrue(future.get(1, TimeUnit.SECONDS).isSuccess()); + assertEquals(1, sentBatches.size()); + assertEquals(1, batchInputs(sentBatches.getFirst()).size()); + } + + @Test + void processQueue_bsIter_exhaustion_uses_conservative_bound() throws Exception { + // With slo_budget batcher (default), two 100-token requests each have + // budget ≈ 350ms (slo=500, margin=50, pred≈100). Both fit within the + // incremental budget and are dispatched together in a single batch. + // flexlbBatchSearchIter is NOT used by slo_budget; flexlbBatchScanAhead + // (default 64) determines how many candidates are scanned per iteration. + config.setCostSloMs(500L); + config.setCostSloRiskMarginMs(50L); + config.setFlexlbBatchMaxCapacity(100000); + config.setFlexlbBatchFillThreshold(0.5); + config.setFlexlbBatchSizeMax(100); + + CompletableFuture f1 = scheduler.submit(contextWithSeqLen(1401, 100)); + CompletableFuture f2 = scheduler.submit(contextWithSeqLen(1402, 100)); + + assertTrue(f1.get(2, TimeUnit.SECONDS).isSuccess()); + assertTrue(f2.get(2, TimeUnit.SECONDS).isSuccess()); + + // Both requests fit within the incremental budget → 1 combined batch + assertEquals(1, sentBatches.size(), + "slo_budget dispatches both requests together when they fit within budget"); + assertEquals(2, batchInputs(sentBatches.get(0)).size()); + } + + @Test + void resolveSloMs_uses_buckets_when_configured() { + FlexlbConfig cfg = new FlexlbConfig(); + cfg.setCostSloMs(500L); + cfg.setCostSloBuckets("4096:2000,32768:10000,131072:30000,524288:60000"); + + assertEquals(2000L, cfg.resolveSloMs(100)); + assertEquals(2000L, cfg.resolveSloMs(4096)); + assertEquals(10000L, cfg.resolveSloMs(4097)); + assertEquals(10000L, cfg.resolveSloMs(32768)); + assertEquals(30000L, cfg.resolveSloMs(32769)); + assertEquals(30000L, cfg.resolveSloMs(131072)); + assertEquals(60000L, cfg.resolveSloMs(131073)); + assertEquals(60000L, cfg.resolveSloMs(1000000)); + } + + @Test + void resolveSloMs_falls_back_to_costSloMs_when_no_buckets() { + FlexlbConfig cfg = new FlexlbConfig(); + cfg.setCostSloMs(500L); + cfg.setCostSloBuckets(""); + + assertEquals(500L, cfg.resolveSloMs(100)); + assertEquals(500L, cfg.resolveSloMs(100000)); + } + + @Test + void resolveSloMs_handles_unsorted_bucket_input() { + FlexlbConfig cfg = new FlexlbConfig(); + cfg.setCostSloBuckets("131072:30000,4096:2000,32768:10000"); + + assertEquals(2000L, cfg.resolveSloMs(1000)); + assertEquals(10000L, cfg.resolveSloMs(5000)); + assertEquals(30000L, cfg.resolveSloMs(50000)); + } + + @Test + void dynamic_slo_prevents_drop_for_requests_exceeding_fixed_slo() throws Exception { + // With default costSloMs=500 and alpha1=1.0, a 600-token request has + // predMs=600 > sloMs=500 → budget=0 → immediate drop. + // With buckets "1000:5000,...", sloMs=5000 → budget=4400 → enough to batch. + config.setCostSloBuckets("1000:5000,100000:50000"); + config.setCostSloRiskMarginMs(50L); + config.setFlexlbBatchSizeMax(2); + config.setFlexlbBatchFillThreshold(1.0); + + CompletableFuture f1 = scheduler.submit(contextWithSeqLen(601, 600)); + CompletableFuture f2 = scheduler.submit(contextWithSeqLen(602, 600)); + + assertTrue(f1.get(3, TimeUnit.SECONDS).isSuccess()); + assertTrue(f2.get(3, TimeUnit.SECONDS).isSuccess()); + + assertEquals(1, sentBatches.size()); + assertEquals(2, batchInputs(sentBatches.getFirst()).size()); + } + + @Test + void mismatched_generate_input_request_id_fails_before_batch_enqueue() throws Exception { + config.setFlexlbBatchSizeMax(1); + + CompletableFuture future = scheduler.submit(context(31, 999)); + + Response response = future.get(2, TimeUnit.SECONDS); + assertFalse(response.isSuccess()); + verify(grpcClient, never()).batchEnqueue(anyString(), anyInt(), any(), anyLong()); + } + + // ==================== cancel / onRequestsFinished → Decode endpoint release ==================== + + @Test + void cancel_releases_decode_endpoint_resource() { + // Register a DecodeEndpoint at the address the router returns for DECODE + WorkerStatus decodeStatus = new WorkerStatus(); + decodeStatus.setIp("10.0.0.2"); + decodeStatus.setPort(8081); + decodeStatus.setGrpcPort(9081); + DecodeEndpoint decodeEp = scheduler.endpointRegistry.ensureDecodeEndpoint( + "10.0.0.2:8081", decodeStatus); + + // Simulate strategy having reserved resources on the decode endpoint + decodeEp.calibrate(null, null, 10000); + decodeEp.reserve(17L, 500); + assertEquals(1, decodeEp.getInflightCount()); + + // Submit a request — router returns decode at 10.0.0.2:8081 + CompletableFuture future = scheduler.submit(context(17)); + + // Cancel before batch is dispatched + scheduler.cancel(17L); + + // Decode endpoint resource should be released + assertEquals(0, decodeEp.getInflightCount(), + "Cancel should propagate to DecodeEndpoint.release()"); + assertTrue(future.isDone()); + assertFalse(future.getNow(null).isSuccess()); + } + + + @Test + void cancel_with_decode_endpoint_not_registered_is_noop() throws Exception { + // No DecodeEndpoint registered at 10.0.0.2:8081 — cancel should not throw + CompletableFuture future = scheduler.submit(context(20)); + + scheduler.cancel(20L); + + assertTrue(future.isDone()); + Response response = future.get(1, TimeUnit.SECONDS); + assertFalse(response.isSuccess()); + // No exception thrown — passes + } + + private static EngineRpcService.EnqueueBatchResponsePB ackFor(EngineRpcService.EnqueueBatchRequestPB request) { + EngineRpcService.EnqueueBatchResponsePB.Builder response = + EngineRpcService.EnqueueBatchResponsePB.newBuilder().setBatchId(request.getBatchId()); + for (EngineRpcService.GenerateInputPB input : batchInputs(request)) { + response.addSuccesses(EngineRpcService.EnqueueBatchSuccessPB.newBuilder() + .setRequestId(input.getRequestId()) + .build()); + } + return response.build(); + } + + private static List batchInputs( + EngineRpcService.EnqueueBatchRequestPB request) { + return request.getDpSlotsList().stream() + .flatMap(slot -> slot.getRequestsList().stream()) + .map(EngineRpcService.EnqueueBatchExternalInputPB::getInput) + .toList(); + } + + private static BalanceContext context(long requestId) { + return context(requestId, requestId); + } + + private static BalanceContext context(long requestId, long generateInputRequestId) { + Request request = new Request(); + request.setRequestId(requestId); + request.setSeqLen(128); + request.setMaxNewTokens(8); + request.setNumBeams(1); + request.setModel("test-model"); + + BalanceContext ctx = new BalanceContext(); + ctx.setRequest(request); + ctx.setConfig(new FlexlbConfig()); + ctx.setGenerateInputPbBytes(generateInputBytes(generateInputRequestId)); + return ctx; + } + + private static BalanceContext contextWithSeqLen(long requestId, long seqLen) { + Request request = new Request(); + request.setRequestId(requestId); + request.setSeqLen(seqLen); + request.setMaxNewTokens(8); + request.setNumBeams(1); + request.setModel("test-model"); + + BalanceContext ctx = new BalanceContext(); + ctx.setRequest(request); + ctx.setConfig(new FlexlbConfig()); + ctx.setGenerateInputPbBytes(generateInputBytes(requestId)); + return ctx; + } + + private static byte[] generateInputBytes(long requestId) { + EngineRpcService.GenerateInputPB input = EngineRpcService.GenerateInputPB.newBuilder() + .setRequestId(requestId) + .addTokenIds(101) + .addTokenIds(102) + .setGenerateConfig(EngineRpcService.GenerateConfigPB.newBuilder() + .setMaxNewTokens(8) + .setGroupTimeout(com.google.protobuf.Int32Value.of(77)) + .build()) + .build(); + return input.toByteArray(); + } + + private static Response successRoute(long requestId) { + return successRouteWithPrefillDp(requestId, 0); + } + + private static Response successRouteWithPrefillDp(long requestId, long dpRank) { + Response response = new Response(); + response.setSuccess(true); + response.setServerStatus(List.of( + server(RoleType.PREFILL, "10.0.0.1", 8080, 9080, requestId, dpRank), + server(RoleType.DECODE, "10.0.0.2", 8081, 9081, requestId) + )); + return response; + } + + private static ServerStatus server(RoleType role, String ip, int httpPort, int grpcPort, long requestId) { + return server(role, ip, httpPort, grpcPort, requestId, 0); + } + + private static ServerStatus server(RoleType role, + String ip, + int httpPort, + int grpcPort, + long requestId, + long dpRank) { + ServerStatus status = new ServerStatus(); + status.setSuccess(true); + status.setRole(role); + status.setServerIp(ip); + status.setHttpPort(httpPort); + status.setGrpcPort(grpcPort); + status.setDpRank(dpRank); + status.setGroup("g1"); + status.setRequestId(requestId); + return status; + } +} diff --git a/rtp_llm/flexlb/flexlb-sync/src/test/java/org/flexlb/balance/scheduler/InflightEvictorTest.java b/rtp_llm/flexlb/flexlb-sync/src/test/java/org/flexlb/balance/scheduler/InflightEvictorTest.java new file mode 100644 index 0000000000..0e83b51bbe --- /dev/null +++ b/rtp_llm/flexlb/flexlb-sync/src/test/java/org/flexlb/balance/scheduler/InflightEvictorTest.java @@ -0,0 +1,121 @@ +package org.flexlb.balance.scheduler; + +import org.junit.jupiter.api.Test; + +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicInteger; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +class InflightEvictorTest { + + private static final class TestEntry implements InflightEvictor.TtlTracked { + private final long createdAtMs; + TestEntry(long createdAtMs) { this.createdAtMs = createdAtMs; } + @Override public long createdAtMs() { return createdAtMs; } + } + + @Test + void evictExpiredRemovesOldEntries() { + Map map = new ConcurrentHashMap<>(); + long now = System.currentTimeMillis(); + map.put(1L, new TestEntry(now - 100_000)); // 100s old + map.put(2L, new TestEntry(now - 50_000)); // 50s old + map.put(3L, new TestEntry(now)); // just now + + InflightEvictor evictor = new InflightEvictor<>(map, null); + int evicted = evictor.evictExpired(60_000); // TTL = 60s + + assertEquals(1, evicted); + assertEquals(2, map.size()); + assertTrue(map.containsKey(2L)); + assertTrue(map.containsKey(3L)); + } + + @Test + void evictExpiredEmptyMapReturnsZero() { + Map map = new ConcurrentHashMap<>(); + InflightEvictor evictor = new InflightEvictor<>(map, null); + assertEquals(0, evictor.evictExpired(60_000)); + } + + @Test + void evictExpiredAllFreshReturnsZero() { + Map map = new ConcurrentHashMap<>(); + long now = System.currentTimeMillis(); + map.put(1L, new TestEntry(now)); + map.put(2L, new TestEntry(now - 10_000)); + + InflightEvictor evictor = new InflightEvictor<>(map, null); + assertEquals(0, evictor.evictExpired(60_000)); + assertEquals(2, map.size()); + } + + @Test + void evictExpiredAllExpiredReturnsAll() { + Map map = new ConcurrentHashMap<>(); + long now = System.currentTimeMillis(); + map.put(1L, new TestEntry(now - 200_000)); + map.put(2L, new TestEntry(now - 150_000)); + + InflightEvictor evictor = new InflightEvictor<>(map, null); + assertEquals(2, evictor.evictExpired(60_000)); + assertEquals(0, map.size()); + } + + @Test + void evictExpiredCallsOnEvictCallback() { + Map map = new ConcurrentHashMap<>(); + long now = System.currentTimeMillis(); + map.put(1L, new TestEntry(now - 100_000)); + map.put(2L, new TestEntry(now - 100_000)); + + AtomicInteger callbackCount = new AtomicInteger(0); + InflightEvictor evictor = new InflightEvictor<>(map, entry -> callbackCount.incrementAndGet()); + + evictor.evictExpired(60_000); + assertEquals(2, callbackCount.get()); + } + + @Test + void evictExpiredPartialExpiryCallsCallbackOnlyForEvicted() { + Map map = new ConcurrentHashMap<>(); + long now = System.currentTimeMillis(); + map.put(1L, new TestEntry(now - 100_000)); // expired + map.put(2L, new TestEntry(now)); // fresh + + AtomicInteger callbackCount = new AtomicInteger(0); + InflightEvictor evictor = new InflightEvictor<>(map, entry -> callbackCount.incrementAndGet()); + + int evicted = evictor.evictExpired(60_000); + assertEquals(1, evicted); + assertEquals(1, callbackCount.get()); + } + + @Test + void evictExpiredNullOnEvictDoesNotThrow() { + Map map = new ConcurrentHashMap<>(); + long now = System.currentTimeMillis(); + map.put(1L, new TestEntry(now - 100_000)); + + InflightEvictor evictor = new InflightEvictor<>(map, null); + assertEquals(1, evictor.evictExpired(60_000)); // should not throw NPE + } + + @Test + void evictExpiredLargeMap() { + Map map = new ConcurrentHashMap<>(); + long now = System.currentTimeMillis(); + for (long i = 0; i < 1000; i++) { + map.put(i, new TestEntry(i % 2 == 0 ? now - 100_000 : now)); + } + + InflightEvictor evictor = new InflightEvictor<>(map, null); + int evicted = evictor.evictExpired(60_000); + + assertEquals(500, evicted); + assertEquals(500, map.size()); + } +} diff --git a/rtp_llm/flexlb/flexlb-sync/src/test/java/org/flexlb/balance/strategy/WeightedCacheLoadBalancerTest.java b/rtp_llm/flexlb/flexlb-sync/src/test/java/org/flexlb/balance/strategy/CostBasedDecodeStrategyTest.java similarity index 50% rename from rtp_llm/flexlb/flexlb-sync/src/test/java/org/flexlb/balance/strategy/WeightedCacheLoadBalancerTest.java rename to rtp_llm/flexlb/flexlb-sync/src/test/java/org/flexlb/balance/strategy/CostBasedDecodeStrategyTest.java index 84af162692..280cadd87c 100644 --- a/rtp_llm/flexlb/flexlb-sync/src/test/java/org/flexlb/balance/strategy/WeightedCacheLoadBalancerTest.java +++ b/rtp_llm/flexlb/flexlb-sync/src/test/java/org/flexlb/balance/strategy/CostBasedDecodeStrategyTest.java @@ -1,6 +1,8 @@ package org.flexlb.balance.strategy; import lombok.extern.slf4j.Slf4j; +import org.flexlb.balance.endpoint.DecodeEndpoint; +import org.flexlb.balance.endpoint.EndpointRegistry; import org.flexlb.balance.resource.DecodeResourceMeasure; import org.flexlb.balance.resource.ResourceMeasureFactory; import org.flexlb.config.ConfigService; @@ -10,6 +12,7 @@ import org.flexlb.dao.loadbalance.ServerStatus; import org.flexlb.dao.master.WorkerStatus; import org.flexlb.dao.route.RoleType; +import org.flexlb.service.monitor.BatchSchedulerReporter; import org.flexlb.sync.status.EngineWorkerStatus; import org.flexlb.sync.status.ModelWorkerStatus; import org.junit.jupiter.api.Assertions; @@ -20,8 +23,10 @@ import java.util.HashMap; import java.util.Map; +import static org.mockito.ArgumentMatchers.any; + @Slf4j -class WeightedCacheLoadBalancerTest { +class CostBasedDecodeStrategyTest { private ConfigService configService; @@ -46,13 +51,27 @@ WorkerStatus createWorkerStatus(String ip) { return workerStatus; } + /** Create an EndpointRegistry with DecodeEndpoints registered for each WorkerStatus entry. */ + private EndpointRegistry createDecodeRegistry(Map workerMap) { + EndpointRegistry registry = new EndpointRegistry(configService, null, Mockito.mock(BatchSchedulerReporter.class)); + for (Map.Entry entry : workerMap.entrySet()) { + WorkerStatus ws = entry.getValue(); + ws.setGrpcPort(9090); + DecodeEndpoint ep = registry.ensureDecodeEndpoint(entry.getKey(), ws); + // Initialize reported KV cache from status + ep.calibrate(null, null, ws.getAvailableKvCacheTokens().get()); + } + return registry; + } + @Test void should_handle_empty_worker_map_when_no_workers_available() { - EngineWorkerStatus engineWorkerStatus = new EngineWorkerStatus(new ModelMetaConfig()); + EndpointRegistry emptyRegistry = new EndpointRegistry(configService, null, Mockito.mock(BatchSchedulerReporter.class)); + EngineWorkerStatus engineWorkerStatus = new EngineWorkerStatus(new ModelMetaConfig(), emptyRegistry); ResourceMeasureFactory resourceMeasureFactory = Mockito.mock(ResourceMeasureFactory.class); DecodeResourceMeasure decodeResourceMeasure = new DecodeResourceMeasure(configService); Mockito.when(resourceMeasureFactory.getMeasure(Mockito.any())).thenReturn(decodeResourceMeasure); - WeightedCacheLoadBalancer weightedCacheLoadBalancer = new WeightedCacheLoadBalancer(configService, engineWorkerStatus, resourceMeasureFactory); + CostBasedDecodeStrategy costBasedDecodeStrategy = new CostBasedDecodeStrategy(configService, engineWorkerStatus, resourceMeasureFactory, emptyRegistry); Request req = new Request(); req.setSeqLen(1000); @@ -60,8 +79,9 @@ void should_handle_empty_worker_map_when_no_workers_available() { BalanceContext balanceContext = new BalanceContext(); balanceContext.setRequest(req); + balanceContext.setConfig(configService.loadBalanceConfig()); - ServerStatus status = weightedCacheLoadBalancer.select(balanceContext, RoleType.DECODE, null); + ServerStatus status = costBasedDecodeStrategy.select(balanceContext, RoleType.DECODE, null); Assertions.assertFalse(status.isSuccess()); Assertions.assertNotNull(status.getMessage()); @@ -69,20 +89,25 @@ void should_handle_empty_worker_map_when_no_workers_available() { @Test void should_use_uniform_distribution_when_all_cache_usages_are_equal() { - EngineWorkerStatus engineWorkerStatus = new EngineWorkerStatus(new ModelMetaConfig()); Map decodeMap = EngineWorkerStatus.MODEL_ROLE_WORKER_STATUS.getDecodeStatusMap(); WorkerStatus worker1 = createWorkerStatus("127.0.0.1"); - worker1.getUsedKvCacheTokens().set(1000); + worker1.getTotalKvCacheTokens().set(10000); + worker1.getAvailableKvCacheTokens().set(9000); WorkerStatus worker2 = createWorkerStatus("127.0.0.2"); - worker2.getUsedKvCacheTokens().set(1000); + worker2.getTotalKvCacheTokens().set(10000); + worker2.getAvailableKvCacheTokens().set(9000); WorkerStatus worker3 = createWorkerStatus("127.0.0.3"); - worker3.getUsedKvCacheTokens().set(1000); + worker3.getTotalKvCacheTokens().set(10000); + worker3.getAvailableKvCacheTokens().set(9000); decodeMap.put("127.0.0.1:8080", worker1); decodeMap.put("127.0.0.2:8080", worker2); decodeMap.put("127.0.0.3:8080", worker3); + EndpointRegistry registry = createDecodeRegistry(decodeMap); + EngineWorkerStatus engineWorkerStatus = new EngineWorkerStatus(new ModelMetaConfig(), registry); + Request req = new Request(); req.setSeqLen(1000); req.setRequestId(1000L); @@ -90,14 +115,14 @@ void should_use_uniform_distribution_when_all_cache_usages_are_equal() { ResourceMeasureFactory resourceMeasureFactory = Mockito.mock(ResourceMeasureFactory.class); DecodeResourceMeasure decodeResourceMeasure = Mockito.mock(DecodeResourceMeasure.class); Mockito.when(resourceMeasureFactory.getMeasure(Mockito.any())).thenReturn(decodeResourceMeasure); - Mockito.when(decodeResourceMeasure.isResourceAvailable(Mockito.any())).thenReturn(true); - WeightedCacheLoadBalancer weightedCacheLoadBalancer = new WeightedCacheLoadBalancer(configService, engineWorkerStatus, resourceMeasureFactory); + Mockito.when(decodeResourceMeasure.isResourceAvailable(any())).thenReturn(true); + CostBasedDecodeStrategy costBasedDecodeStrategy = new CostBasedDecodeStrategy(configService, engineWorkerStatus, resourceMeasureFactory, registry); BalanceContext balanceContext = new BalanceContext(); balanceContext.setRequest(req); balanceContext.setConfig(configService.loadBalanceConfig()); - ServerStatus status = weightedCacheLoadBalancer.select(balanceContext, RoleType.DECODE, null); + ServerStatus status = costBasedDecodeStrategy.select(balanceContext, RoleType.DECODE, null); Assertions.assertTrue(status.isSuccess()); Assertions.assertNotNull(status.getServerIp()); @@ -105,25 +130,27 @@ void should_use_uniform_distribution_when_all_cache_usages_are_equal() { @Test void should_prioritize_workers_with_lower_cache_usage_when_normalized_values_negative() { - EngineWorkerStatus engineWorkerStatus = new EngineWorkerStatus(new ModelMetaConfig()); Map decodeMap = EngineWorkerStatus.MODEL_ROLE_WORKER_STATUS.getDecodeStatusMap(); - // Worker1: cacheUsed = 500 (well below average) WorkerStatus worker1 = createWorkerStatus("127.0.0.1"); - worker1.getUsedKvCacheTokens().set(500); + worker1.getTotalKvCacheTokens().set(10000); + worker1.getAvailableKvCacheTokens().set(9500); - // Worker2: cacheUsed = 1500 (above average) WorkerStatus worker2 = createWorkerStatus("127.0.0.2"); - worker2.getUsedKvCacheTokens().set(1500); + worker2.getTotalKvCacheTokens().set(10000); + worker2.getAvailableKvCacheTokens().set(8500); - // Worker3: cacheUsed = 1000 (average) WorkerStatus worker3 = createWorkerStatus("127.0.0.3"); - worker3.getUsedKvCacheTokens().set(1000); + worker3.getTotalKvCacheTokens().set(10000); + worker3.getAvailableKvCacheTokens().set(9000); decodeMap.put("127.0.0.1:8080", worker1); decodeMap.put("127.0.0.2:8080", worker2); decodeMap.put("127.0.0.3:8080", worker3); + EndpointRegistry registry = createDecodeRegistry(decodeMap); + EngineWorkerStatus engineWorkerStatus = new EngineWorkerStatus(new ModelMetaConfig(), registry); + Request req = new Request(); req.setSeqLen(1000); req.setRequestId(1000L); @@ -131,14 +158,14 @@ void should_prioritize_workers_with_lower_cache_usage_when_normalized_values_neg ResourceMeasureFactory resourceMeasureFactory = Mockito.mock(ResourceMeasureFactory.class); DecodeResourceMeasure decodeResourceMeasure = Mockito.mock(DecodeResourceMeasure.class); Mockito.when(resourceMeasureFactory.getMeasure(Mockito.any())).thenReturn(decodeResourceMeasure); - Mockito.when(decodeResourceMeasure.isResourceAvailable(Mockito.any())).thenReturn(true); - WeightedCacheLoadBalancer weightedCacheLoadBalancer = new WeightedCacheLoadBalancer(configService, engineWorkerStatus, resourceMeasureFactory); + Mockito.when(decodeResourceMeasure.isResourceAvailable(any())).thenReturn(true); + CostBasedDecodeStrategy costBasedDecodeStrategy = new CostBasedDecodeStrategy(configService, engineWorkerStatus, resourceMeasureFactory, registry); BalanceContext balanceContext = new BalanceContext(); balanceContext.setRequest(req); balanceContext.setConfig(configService.loadBalanceConfig()); - ServerStatus status = weightedCacheLoadBalancer.select(balanceContext, RoleType.DECODE, null); + ServerStatus status = costBasedDecodeStrategy.select(balanceContext, RoleType.DECODE, null); Assertions.assertTrue(status.isSuccess()); Assertions.assertNotNull(status.getServerIp()); @@ -146,16 +173,16 @@ void should_prioritize_workers_with_lower_cache_usage_when_normalized_values_neg @Test void should_handle_group_selection_when_group_parameter_provided() { - EngineWorkerStatus engineWorkerStatus = new EngineWorkerStatus(new ModelMetaConfig()); ModelWorkerStatus modelStatus = EngineWorkerStatus.MODEL_ROLE_WORKER_STATUS; - // Create workers for specific group WorkerStatus worker1 = createWorkerStatus("127.0.0.1"); worker1.setGroup("group-a"); - worker1.getUsedKvCacheTokens().set(1000); modelStatus.getDecodeStatusMap().put("127.0.0.1:8080", worker1); + EndpointRegistry registry = createDecodeRegistry(modelStatus.getDecodeStatusMap()); + EngineWorkerStatus engineWorkerStatus = new EngineWorkerStatus(new ModelMetaConfig(), registry); + Request req = new Request(); req.setSeqLen(1000); req.setRequestId(1000L); @@ -163,14 +190,14 @@ void should_handle_group_selection_when_group_parameter_provided() { ResourceMeasureFactory resourceMeasureFactory = Mockito.mock(ResourceMeasureFactory.class); DecodeResourceMeasure decodeResourceMeasure = Mockito.mock(DecodeResourceMeasure.class); Mockito.when(resourceMeasureFactory.getMeasure(Mockito.any())).thenReturn(decodeResourceMeasure); - Mockito.when(decodeResourceMeasure.isResourceAvailable(Mockito.any())).thenReturn(true); - WeightedCacheLoadBalancer weightedCacheLoadBalancer = new WeightedCacheLoadBalancer(configService, engineWorkerStatus, resourceMeasureFactory); + Mockito.when(decodeResourceMeasure.isResourceAvailable(any())).thenReturn(true); + CostBasedDecodeStrategy costBasedDecodeStrategy = new CostBasedDecodeStrategy(configService, engineWorkerStatus, resourceMeasureFactory, registry); BalanceContext balanceContext = new BalanceContext(); balanceContext.setRequest(req); balanceContext.setConfig(configService.loadBalanceConfig()); - ServerStatus status = weightedCacheLoadBalancer.select(balanceContext, RoleType.DECODE, "group-a"); + ServerStatus status = costBasedDecodeStrategy.select(balanceContext, RoleType.DECODE, "group-a"); Assertions.assertTrue(status.isSuccess()); Assertions.assertEquals("127.0.0.1", status.getServerIp()); @@ -178,46 +205,46 @@ void should_handle_group_selection_when_group_parameter_provided() { @Test void should_use_exponential_decay_for_balanced_weight_distribution_when_cache_usage_differs() { - EngineWorkerStatus engineWorkerStatus = new EngineWorkerStatus(new ModelMetaConfig()); Map decodeMap = EngineWorkerStatus.MODEL_ROLE_WORKER_STATUS.getDecodeStatusMap(); - // Create two workers to test exponential decay weight distribution - // Normalized values are -500 and +500 WorkerStatus worker1 = createWorkerStatus("127.0.0.1"); - worker1.getUsedKvCacheTokens().set(500); // Below average 1000, normalizedValue = -500 + worker1.getTotalKvCacheTokens().set(10000); + worker1.getAvailableKvCacheTokens().set(9500); WorkerStatus worker2 = createWorkerStatus("127.0.0.2"); - worker2.getUsedKvCacheTokens().set(1500); // Above average 1000, normalizedValue = +500 + worker2.getTotalKvCacheTokens().set(10000); + worker2.getAvailableKvCacheTokens().set(8500); decodeMap.put("127.0.0.1:8080", worker1); decodeMap.put("127.0.0.2:8080", worker2); + EndpointRegistry registry = createDecodeRegistry(decodeMap); + EngineWorkerStatus engineWorkerStatus = new EngineWorkerStatus(new ModelMetaConfig(), registry); + Request req = new Request(); req.setSeqLen(1000); ResourceMeasureFactory resourceMeasureFactory = Mockito.mock(ResourceMeasureFactory.class); DecodeResourceMeasure decodeResourceMeasure = Mockito.mock(DecodeResourceMeasure.class); Mockito.when(resourceMeasureFactory.getMeasure(Mockito.any())).thenReturn(decodeResourceMeasure); - Mockito.when(decodeResourceMeasure.isResourceAvailable(Mockito.any())).thenReturn(true); - WeightedCacheLoadBalancer weightedCacheLoadBalancer = new WeightedCacheLoadBalancer(configService, engineWorkerStatus, resourceMeasureFactory); + Mockito.when(decodeResourceMeasure.isResourceAvailable(any())).thenReturn(true); + CostBasedDecodeStrategy costBasedDecodeStrategy = new CostBasedDecodeStrategy(configService, engineWorkerStatus, resourceMeasureFactory, registry); BalanceContext balanceContext = new BalanceContext(); balanceContext.setRequest(req); balanceContext.setConfig(configService.loadBalanceConfig()); - // Run multiple iterations to verify weight distribution int totalRuns = 10000; Map selectionCount = new HashMap<>(); for (int i = 0; i < totalRuns; i++) { balanceContext.getRequest().setRequestId(1000L + i); - ServerStatus status = weightedCacheLoadBalancer.select(balanceContext, RoleType.DECODE, null); + ServerStatus status = costBasedDecodeStrategy.select(balanceContext, RoleType.DECODE, null); if (status.isSuccess()) { String selectedIp = status.getServerIp(); selectionCount.put(selectedIp, selectionCount.getOrDefault(selectedIp, 0) + 1); - // Rollback to reset local tasks and cache usage - weightedCacheLoadBalancer.rollBack(selectedIp + ":8080", 1000L + i); + costBasedDecodeStrategy.rollBack(selectedIp + ":8080", 1000L + i); } } @@ -226,19 +253,87 @@ void should_use_exponential_decay_for_balanced_weight_distribution_when_cache_us log.info("Exponential decay weight distribution verification: worker1={} ({}%), worker2={} ({}%)", worker1Count, worker1Count * 100.0 / totalRuns, worker2Count, worker2Count * 100.0 / totalRuns); - // Verify worker1 (lower cache usage) is selected more frequently Assertions.assertTrue(worker1Count > worker2Count, "Worker with lower cache usage should be selected more frequently"); - // Verify weight ratio is more balanced (improvement from exponential decay algorithm) double ratio = (double) worker1Count / worker2Count; Assertions.assertTrue(ratio >= 1.5 && ratio <= 3.0, "Weight ratio should be between 1.5-3.0, actual ratio: %.2f".formatted(ratio)); + } - double worker1Ratio = (double) worker1Count / totalRuns; - double worker2Ratio = (double) worker2Count / totalRuns; + @Test + void should_skip_worker_with_insufficient_kv_cache_capacity() { + Map decodeMap = EngineWorkerStatus.MODEL_ROLE_WORKER_STATUS.getDecodeStatusMap(); - log.info("Exponential decay weight distribution verification: worker1={} ({}%), worker2={} ({}%), weight ratio: {}", - worker1Count, worker1Ratio * 100, worker2Count, worker2Ratio * 100, "%.2f".formatted(ratio)); + WorkerStatus worker1 = createWorkerStatus("127.0.0.1"); + worker1.getTotalKvCacheTokens().set(1000); + worker1.getAvailableKvCacheTokens().set(100); + + WorkerStatus worker2 = createWorkerStatus("127.0.0.2"); + worker2.getTotalKvCacheTokens().set(1000); + worker2.getAvailableKvCacheTokens().set(800); + + decodeMap.put("127.0.0.1:8080", worker1); + decodeMap.put("127.0.0.2:8080", worker2); + + EndpointRegistry registry = createDecodeRegistry(decodeMap); + EngineWorkerStatus engineWorkerStatus = new EngineWorkerStatus(new ModelMetaConfig(), registry); + + Request req = new Request(); + req.setSeqLen(500); + req.setRequestId(2000L); + + ResourceMeasureFactory resourceMeasureFactory = Mockito.mock(ResourceMeasureFactory.class); + DecodeResourceMeasure decodeResourceMeasure = Mockito.mock(DecodeResourceMeasure.class); + Mockito.when(resourceMeasureFactory.getMeasure(Mockito.any())).thenReturn(decodeResourceMeasure); + Mockito.when(decodeResourceMeasure.isResourceAvailable(any())).thenReturn(true); + CostBasedDecodeStrategy costBasedDecodeStrategy = new CostBasedDecodeStrategy(configService, engineWorkerStatus, resourceMeasureFactory, registry); + + BalanceContext balanceContext = new BalanceContext(); + balanceContext.setRequest(req); + balanceContext.setConfig(configService.loadBalanceConfig()); + + ServerStatus status = costBasedDecodeStrategy.select(balanceContext, RoleType.DECODE, null); + + Assertions.assertTrue(status.isSuccess()); + Assertions.assertEquals("127.0.0.2", status.getServerIp()); + } + + @Test + void should_fallback_to_least_used_when_all_workers_filtered() { + Map decodeMap = EngineWorkerStatus.MODEL_ROLE_WORKER_STATUS.getDecodeStatusMap(); + + WorkerStatus worker1 = createWorkerStatus("127.0.0.1"); + worker1.getTotalKvCacheTokens().set(1000); + worker1.getAvailableKvCacheTokens().set(50); + + WorkerStatus worker2 = createWorkerStatus("127.0.0.2"); + worker2.getTotalKvCacheTokens().set(1000); + worker2.getAvailableKvCacheTokens().set(100); + + decodeMap.put("127.0.0.1:8080", worker1); + decodeMap.put("127.0.0.2:8080", worker2); + + EndpointRegistry registry = createDecodeRegistry(decodeMap); + EngineWorkerStatus engineWorkerStatus = new EngineWorkerStatus(new ModelMetaConfig(), registry); + + Request req = new Request(); + req.setSeqLen(200); + req.setRequestId(3000L); + + ResourceMeasureFactory resourceMeasureFactory = Mockito.mock(ResourceMeasureFactory.class); + DecodeResourceMeasure decodeResourceMeasure = Mockito.mock(DecodeResourceMeasure.class); + Mockito.when(resourceMeasureFactory.getMeasure(Mockito.any())).thenReturn(decodeResourceMeasure); + Mockito.when(decodeResourceMeasure.isResourceAvailable(any())).thenReturn(true); + CostBasedDecodeStrategy costBasedDecodeStrategy = new CostBasedDecodeStrategy(configService, engineWorkerStatus, resourceMeasureFactory, registry); + + BalanceContext balanceContext = new BalanceContext(); + balanceContext.setRequest(req); + balanceContext.setConfig(configService.loadBalanceConfig()); + + ServerStatus status = costBasedDecodeStrategy.select(balanceContext, RoleType.DECODE, null); + + Assertions.assertTrue(status.isSuccess()); + Assertions.assertEquals("127.0.0.2", status.getServerIp()); } } diff --git a/rtp_llm/flexlb/flexlb-sync/src/test/java/org/flexlb/balance/strategy/CostBasedPrefillStrategyTest.java b/rtp_llm/flexlb/flexlb-sync/src/test/java/org/flexlb/balance/strategy/CostBasedPrefillStrategyTest.java new file mode 100644 index 0000000000..4174e6883f --- /dev/null +++ b/rtp_llm/flexlb/flexlb-sync/src/test/java/org/flexlb/balance/strategy/CostBasedPrefillStrategyTest.java @@ -0,0 +1,298 @@ +package org.flexlb.balance.strategy; + +import org.flexlb.balance.endpoint.EndpointRegistry; +import org.flexlb.balance.endpoint.PrefillEndpoint; +import org.flexlb.balance.endpoint.WorkerEndpoint; +import org.flexlb.balance.resource.PrefillResourceMeasure; +import org.flexlb.balance.resource.ResourceMeasureFactory; +import org.flexlb.balance.scheduler.BatchItem; +import org.flexlb.balance.scheduler.FlexlbBatchScheduler; +import org.flexlb.cache.service.CacheAwareService; +import org.flexlb.config.ConfigService; +import org.flexlb.config.FlexlbConfig; +import org.flexlb.config.ModelMetaConfig; +import org.flexlb.dao.BalanceContext; +import org.flexlb.dao.loadbalance.Request; +import org.flexlb.dao.loadbalance.ServerStatus; +import org.flexlb.dao.master.CacheStatus; +import org.flexlb.dao.master.WorkerStatus; +import org.flexlb.dao.route.RoleType; +import org.flexlb.service.monitor.BatchSchedulerReporter; +import org.flexlb.service.monitor.EngineHealthReporter; +import org.flexlb.sync.status.EngineWorkerStatus; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.Mockito; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.ArgumentMatchers.*; + +class CostBasedPrefillStrategyTest { + + private EngineWorkerStatus engineWorkerStatus; + private CacheAwareService cacheAwareService; + private ResourceMeasureFactory resourceMeasureFactory; + private EngineHealthReporter engineHealthReporter; + private FlexlbBatchScheduler batchScheduler; + private EndpointRegistry endpointRegistry; + private CostBasedPrefillStrategy strategy; + + @BeforeEach + void setUp() { + EngineWorkerStatus.MODEL_ROLE_WORKER_STATUS.getPrefillStatusMap().clear(); + ConfigService configService = Mockito.mock(ConfigService.class); + Mockito.when(configService.loadBalanceConfig()).thenReturn(new FlexlbConfig()); + cacheAwareService = Mockito.mock(CacheAwareService.class); + resourceMeasureFactory = Mockito.mock(ResourceMeasureFactory.class); + engineHealthReporter = Mockito.mock(EngineHealthReporter.class); + batchScheduler = Mockito.mock(FlexlbBatchScheduler.class); + + // Create registry first to break circular dependency + endpointRegistry = new EndpointRegistry(configService, batchScheduler, Mockito.mock(BatchSchedulerReporter.class)); + engineWorkerStatus = new EngineWorkerStatus(new ModelMetaConfig(), endpointRegistry); + + PrefillResourceMeasure prefillResourceMeasure = Mockito.mock(PrefillResourceMeasure.class); + Mockito.when(resourceMeasureFactory.getMeasure(any())).thenReturn(prefillResourceMeasure); + Mockito.when(prefillResourceMeasure.isResourceAvailable(any())).thenReturn(true); + Mockito.when(cacheAwareService.findMatchingEngines(anyList(), any(), any())).thenReturn(new HashMap<>()); + + strategy = new CostBasedPrefillStrategy( + engineWorkerStatus, cacheAwareService, resourceMeasureFactory, + engineHealthReporter, endpointRegistry); + } + + /** Helper: register PrefillEndpoints for all entries in the given worker map. */ + private void registerPrefillEndpoints(Map workerMap) { + for (Map.Entry entry : workerMap.entrySet()) { + WorkerStatus ws = entry.getValue(); + ws.setGrpcPort(9090); + endpointRegistry.ensurePrefillEndpoint(entry.getKey(), ws); + } + } + + @Test + void selectsWorkerWithLowestCostScore() { + Map prefillMap = EngineWorkerStatus.MODEL_ROLE_WORKER_STATUS.getPrefillStatusMap(); + prefillMap.put("10.0.0.1:8080", createWorker("10.0.0.1", 100)); + prefillMap.put("10.0.0.2:8080", createWorker("10.0.0.2", 50)); + + ServerStatus result = strategy.select(buildContext(1000, 1L), RoleType.PREFILL, null); + + assertTrue(result.isSuccess()); + assertEquals("10.0.0.2", result.getServerIp()); + } + + @Test + void batcherQueueReducesWaitCost() { + Map prefillMap = EngineWorkerStatus.MODEL_ROLE_WORKER_STATUS.getPrefillStatusMap(); + WorkerStatus w1 = createWorker("10.0.0.1", 0); + WorkerStatus w2 = createWorker("10.0.0.2", 0); + prefillMap.put("10.0.0.1:8080", w1); + prefillMap.put("10.0.0.2:8080", w2); + + ServerStatus result = strategy.select(buildContext(500, 2L), RoleType.PREFILL, null); + + assertTrue(result.isSuccess()); + } + + @Test + void deltaPrefillCostFavorsCacheHitWorker() { + Map prefillMap = EngineWorkerStatus.MODEL_ROLE_WORKER_STATUS.getPrefillStatusMap(); + prefillMap.put("10.0.0.1:8080", createWorker("10.0.0.1", 0)); + prefillMap.put("10.0.0.2:8080", createWorker("10.0.0.2", 0)); + + Map cacheResults = new HashMap<>(); + cacheResults.put("10.0.0.2:8080", 3); // 3 blocks * 256 = 768 tokens + Mockito.when(cacheAwareService.findMatchingEngines(anyList(), any(), any())).thenReturn(cacheResults); + + ServerStatus result = strategy.select(buildContext(1000, 3L), RoleType.PREFILL, null); + + assertTrue(result.isSuccess()); + assertEquals("10.0.0.2", result.getServerIp()); + } + + @Test + void sloRiskFilterExcludesOverloadedWorker() { + Map prefillMap = EngineWorkerStatus.MODEL_ROLE_WORKER_STATUS.getPrefillStatusMap(); + prefillMap.put("10.0.0.1:8080", createWorker("10.0.0.1", 2000)); + prefillMap.put("10.0.0.2:8080", createWorker("10.0.0.2", 10)); + + ServerStatus result = strategy.select(buildContext(500, 4L), RoleType.PREFILL, null); + + assertTrue(result.isSuccess()); + assertEquals("10.0.0.2", result.getServerIp()); + } + + @Test + void allFilteredFallsBackToLeastLoaded() { + Map prefillMap = EngineWorkerStatus.MODEL_ROLE_WORKER_STATUS.getPrefillStatusMap(); + prefillMap.put("10.0.0.1:8080", createWorker("10.0.0.1", 5000)); + prefillMap.put("10.0.0.2:8080", createWorker("10.0.0.2", 3000)); + + ServerStatus result = strategy.select(buildContext(500, 5L), RoleType.PREFILL, null); + + assertTrue(result.isSuccess()); + assertEquals("10.0.0.2", result.getServerIp()); + } + + @Test + void hotspotFilterExcludesBatcherOverloadedWorker() { + Map prefillMap = EngineWorkerStatus.MODEL_ROLE_WORKER_STATUS.getPrefillStatusMap(); + prefillMap.put("10.0.0.1:8080", createWorker("10.0.0.1", 0)); + prefillMap.put("10.0.0.2:8080", createWorker("10.0.0.2", 0)); + prefillMap.put("10.0.0.3:8080", createWorker("10.0.0.3", 0)); + + ServerStatus result = strategy.select(buildContext(500, 6L), RoleType.PREFILL, null); + + assertTrue(result.isSuccess()); + assertNotEquals("10.0.0.1", result.getServerIp()); + } + + @Test + void imbalanceFilterExcludesOverloadedEngineQueue() { + Map prefillMap = EngineWorkerStatus.MODEL_ROLE_WORKER_STATUS.getPrefillStatusMap(); + prefillMap.clear(); + prefillMap.put("10.0.0.1:8080", createWorker("10.0.0.1", 1000)); + for (int i = 2; i <= 10; i++) { + String ip = "10.0.0." + i; + prefillMap.put(ip + ":8080", createWorker(ip, 10)); + } + + FlexlbConfig config = new FlexlbConfig(); + config.setCostSloMs(50000L); + + ServerStatus result = strategy.select(buildContext(500, 7L, config), RoleType.PREFILL, null); + + assertTrue(result.isSuccess()); + assertNotEquals("10.0.0.1", result.getServerIp()); + } + + @Test + void noAvailableWorkersReturnsError() { + EngineWorkerStatus.MODEL_ROLE_WORKER_STATUS.getPrefillStatusMap().clear(); + + ServerStatus result = strategy.select(buildContext(500, 8L), RoleType.PREFILL, null); + + assertFalse(result.isSuccess()); + } + + @Test + void rollBackDoesNotThrow() { + Map prefillMap = EngineWorkerStatus.MODEL_ROLE_WORKER_STATUS.getPrefillStatusMap(); + WorkerStatus w = createWorker("10.0.0.1", 0); + prefillMap.put("10.0.0.1:8080", w); + + ServerStatus result = strategy.select(buildContext(500, 9L), RoleType.PREFILL, null); + assertTrue(result.isSuccess()); + + assertDoesNotThrow(() -> strategy.rollBack("10.0.0.1:8080", 9L)); + } + + @Test + void endpointWaitMsFavorsEndpointWithLowerEstimate() { + WorkerStatus w1 = createWorker("10.0.0.1", 0); + WorkerStatus w2 = createWorker("10.0.0.2", 0); + w1.setGrpcPort(8081); + w2.setGrpcPort(8081); + + Map prefillMap = EngineWorkerStatus.MODEL_ROLE_WORKER_STATUS.getPrefillStatusMap(); + prefillMap.put("10.0.0.1:8080", w1); + prefillMap.put("10.0.0.2:8080", w2); + + PrefillEndpoint ep1 = endpointRegistry.ensurePrefillEndpoint("10.0.0.1:8080", w1); + endpointRegistry.ensurePrefillEndpoint("10.0.0.2:8080", w2); + ep1.commitBatch(1L, 4000, List.of(batchItem(1L, 1000, 0))); + + ServerStatus result = strategy.select(buildContext(500, 10L), RoleType.PREFILL, null); + + assertTrue(result.isSuccess()); + assertEquals("10.0.0.2", result.getServerIp()); + } + + @Test + void predictorUsesPolynomialFormula() { + PrefillTimePredictor predictor = new PrefillTimePredictor(10, 0.5, 0.001, 0.0005, 0.2, 5); + + // Single request: n=1000, p=200 → c=800, bs=1 + // = 10 + 0.5*800 + (0.001*640000 + 0.0005*160000) + 0.2*200 + 5*1 + // = 10 + 400 + (640 + 80) + 40 + 5 = 1175 + long single = predictor.predictBatchMs(List.of(batchItem(0, 1000, 200))); + assertEquals(1175, single); + + // Batch of 2: req1=(1000,200) req2=(500,100) + // c1=800, p1=200, c2=400, p2=100 + // Σc=1200, Σ(640+80, 160+20)=900, Σp=300, bs=2 + // = 10 + 0.5*1200 + 900 + 0.2*300 + 5*2 = 1580 + long batch = predictor.predictBatchMs(List.of( + batchItem(0, 1000, 200), + batchItem(1, 500, 100))); + assertEquals(1580, batch); + + assertEquals(0, predictor.predictBatchMs(List.of())); + } + + private WorkerStatus createWorker(String ip, long estimatedWaitMs) { + WorkerStatus w = new WorkerStatus(); + w.setIp(ip); + w.setPort(8080); + w.setAlive(true); + w.setRole(RoleType.PREFILL); + CacheStatus cacheStatus = new CacheStatus(); + cacheStatus.setAvailableKvCache(10000); + cacheStatus.setBlockSize(256); + w.setCacheStatus(cacheStatus); + w.setRunningTaskList(new HashMap<>()); + + String ipPort = ip + ":8080"; + w.setGrpcPort(8081); + PrefillEndpoint ep = endpointRegistry.ensurePrefillEndpoint(ipPort, w); + if (estimatedWaitMs > 0) { + ep.commitBatch(900000L + ip.hashCode(), estimatedWaitMs, + List.of(batchItem(900000L + ip.hashCode(), estimatedWaitMs, 0))); + } + return w; + } + + private BatchItem batchItem(long requestId, long seqLen, long hitCache) { + Request req = new Request(); + req.setRequestId(requestId); + req.setSeqLen(seqLen); + BalanceContext ctx = new BalanceContext(); + ctx.setRequest(req); + // For prediction, hitCache comes from prefill.debugInfo. Use null prefill → 0, + // but the caller's hitCache parameter is what matters for prediction — we set it + // via the constructor as a convenience; the predictor will call item.hitCache() + // which reads prefill.debugInfo.hitCacheLen, so we must build a real ServerStatus. + if (hitCache > 0) { + org.flexlb.dao.loadbalance.DebugInfo di = new org.flexlb.dao.loadbalance.DebugInfo(); + di.setHitCacheLen(hitCache); + org.flexlb.dao.loadbalance.ServerStatus ss = new org.flexlb.dao.loadbalance.ServerStatus(); + ss.setDebugInfo(di); + return new BatchItem(ctx, null, null, ss, null, null, null, 0, 0); + } + return new BatchItem(ctx, null, null, null, null, null, null, 0, 0); + } + + private BalanceContext buildContext(long seqLen, long requestId) { + FlexlbConfig config = new FlexlbConfig(); + config.setCostSloMs(50000L); + config.setCostSloRiskMarginMs(50L); + return buildContext(seqLen, requestId, config); + } + + private BalanceContext buildContext(long seqLen, long requestId, FlexlbConfig config) { + Request req = new Request(); + req.setSeqLen(seqLen); + req.setRequestId(requestId); + req.setBlockCacheKeys(new ArrayList<>(List.of(1L, 2L))); + BalanceContext ctx = new BalanceContext(); + ctx.setRequest(req); + ctx.setConfig(config); + return ctx; + } +} diff --git a/rtp_llm/flexlb/flexlb-sync/src/test/java/org/flexlb/balance/strategy/PrefillTimePredictorTest.java b/rtp_llm/flexlb/flexlb-sync/src/test/java/org/flexlb/balance/strategy/PrefillTimePredictorTest.java new file mode 100644 index 0000000000..ac2aa925d2 --- /dev/null +++ b/rtp_llm/flexlb/flexlb-sync/src/test/java/org/flexlb/balance/strategy/PrefillTimePredictorTest.java @@ -0,0 +1,195 @@ +package org.flexlb.balance.strategy; + +import org.flexlb.balance.scheduler.BatchItem; +import org.flexlb.dao.BalanceContext; +import org.flexlb.dao.loadbalance.DebugInfo; +import org.flexlb.dao.loadbalance.Request; +import org.flexlb.dao.loadbalance.ServerStatus; +import org.junit.jupiter.api.Test; + +import java.util.ArrayList; +import java.util.List; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +class PrefillTimePredictorTest { + + // ---- estimateMs (single request) ---- + + @Test + void estimateMsZeroCoefficientsReturnsZero() { + PrefillTimePredictor predictor = new PrefillTimePredictor(0, 0, 0, 0, 0, 0); + assertEquals(0, predictor.estimateMs(1000, 0)); + assertEquals(0, predictor.estimateMs(1000, 500)); + } + + @Test + void estimateMsConstantTermOnly() { + // α₀ = 50, others 0 → always 50 + PrefillTimePredictor predictor = new PrefillTimePredictor(50, 0, 0, 0, 0, 0); + assertEquals(50, predictor.estimateMs(100, 0)); + assertEquals(50, predictor.estimateMs(0, 0)); + } + + @Test + void estimateMsLinearInComputeTokens() { + // α₁ = 2, others 0 → time = 2 * c + // c = totalTokens - hitTokens + PrefillTimePredictor predictor = new PrefillTimePredictor(0, 2, 0, 0, 0, 0); + assertEquals(2000, predictor.estimateMs(1500, 500)); // c=1000 → 2*1000 + assertEquals(600, predictor.estimateMs(300, 0)); // c=300 → 2*300 + } + + @Test + void estimateMsQuadraticInComputeTokens() { + // α₂ = 0.1, others 0 → time = 0.1 * c² + PrefillTimePredictor predictor = new PrefillTimePredictor(0, 0, 0.1, 0, 0, 0); + assertEquals(1000, predictor.estimateMs(100, 0)); // c=100 → 0.1*10000 + } + + @Test + void estimateMsInteractionTerm() { + // α₃ = 0.5, others 0 → time = 0.5 * c * hitTokens + PrefillTimePredictor predictor = new PrefillTimePredictor(0, 0, 0, 0.5, 0, 0); + assertEquals(40000, predictor.estimateMs(600, 400)); // c=200, hit=400 → 0.5*200*400 + } + + @Test + void estimateMsLinearInHitTokens() { + // α₄ = 3, others 0 → time = 3 * hitTokens + PrefillTimePredictor predictor = new PrefillTimePredictor(0, 0, 0, 0, 3, 0); + assertEquals(300, predictor.estimateMs(500, 100)); // hitTokens never negative + } + + @Test + void estimateMsBatchSizeTerm() { + // α₅ = 10, others 0 → time = 10 (bs=1 in single mode) + PrefillTimePredictor predictor = new PrefillTimePredictor(0, 0, 0, 0, 0, 10); + assertEquals(10, predictor.estimateMs(1000, 0)); + } + + @Test + void estimateMsFullFormula() { + // α₀=10, α₁=0.1, α₂=0.01, α₃=0.001, α₄=0.5, α₅=5 + // total=500, hit=200 → c=300 + // result = 10 + 0.1*300 + 0.01*90000 + 0.001*300*200 + 0.5*200 + 5 + // = 10 + 30 + 900 + 60 + 100 + 5 = 1105 + PrefillTimePredictor predictor = new PrefillTimePredictor(10, 0.1, 0.01, 0.001, 0.5, 5); + assertEquals(1105, predictor.estimateMs(500, 200)); + } + + @Test + void estimateMsHitTokensCannotExceedTotal() { + // c = max(0, total - hit), so if hit > total, c = 0 + PrefillTimePredictor predictor = new PrefillTimePredictor(0, 2, 0, 0, 0, 0); + assertEquals(0, predictor.estimateMs(100, 500)); // c = max(0, 100-500) = 0 + } + + @Test + void estimateMsLargeValuesNoOverflow() { + PrefillTimePredictor predictor = new PrefillTimePredictor(100, 1, 0.001, 0.0001, 0.5, 10); + long result = predictor.estimateMs(100_000, 50_000); + assertTrue(result >= 0, "Should not overflow or produce negative values"); + } + + // ---- predictBatchMs ---- + + @Test + void predictBatchMsEmptyListReturnsZero() { + PrefillTimePredictor predictor = new PrefillTimePredictor(10, 1, 0, 0, 0, 5); + assertEquals(0, predictor.predictBatchMs(List.of())); + } + + @Test + void predictBatchMsSingleItemMatchesEstimateMs() { + // For a single item, predictBatchMs should be close to estimateMs + // (α₅ contributes per-item vs per-call, so α₅ * bs differs: estimateMs uses α₅*1, + // predictBatchMs uses α₅*bs=α₅*1 — same for single item) + PrefillTimePredictor predictor = new PrefillTimePredictor(10, 0.1, 0.01, 0.001, 0.5, 5); + long single = predictor.estimateMs(500, 200); + + BatchItem item = batchItem(500, 200); + long batch = predictor.predictBatchMs(List.of(item)); + + assertEquals(single, batch); + } + + @Test + void predictBatchMsMultipleItems() { + // α₀=10, α₁=0.1, α₂=0.01, α₃=0.001, α₄=0.5, α₅=5 + // item1: seq=500, hit=200 → c=300 + // item2: seq=300, hit=100 → c=200 + // Σc=500, Σc²*p need to be recomputed + // Let me compute: for item1: a2*c²=0.01*90000=900, a3*c*p=0.001*300*200=60 + // for item2: a2*c²=0.01*40000=400, a3*c*p=0.001*200*100=20 + // Σquadratic = 900+60+400+20 = 1380 + // Σp = 200+100 = 300 + // result = 10 + 0.1*500 + 1380 + 0.5*300 + 5*2 + // = 10 + 50 + 1380 + 150 + 10 = 1600 + PrefillTimePredictor predictor = new PrefillTimePredictor(10, 0.1, 0.01, 0.001, 0.5, 5); + + BatchItem item1 = batchItem(500, 200); + BatchItem item2 = batchItem(300, 100); + long result = predictor.predictBatchMs(List.of(item1, item2)); + + assertEquals(1600, result); + } + + @Test + void predictBatchMsBatchSizeAffectsResult() { + // α₅=10, others 0 → time depends only on batch size + PrefillTimePredictor predictor = new PrefillTimePredictor(0, 0, 0, 0, 0, 10); + + BatchItem item = batchItem(100, 0); + assertEquals(10, predictor.predictBatchMs(List.of(item))); + assertEquals(20, predictor.predictBatchMs(List.of(item, item))); + assertEquals(30, predictor.predictBatchMs(List.of(item, item, item))); + } + + @Test + void predictBatchMsZeroCacheHits() { + PrefillTimePredictor predictor = new PrefillTimePredictor(0, 1, 0, 0, 0, 0); + // c = seqLen - 0 = seqLen, so time = seqLen + BatchItem item = batchItem(500, 0); + assertEquals(500, predictor.predictBatchMs(List.of(item))); + } + + @Test + void predictBatchMsAllCached() { + // α₁ = 1 (linear in compute tokens) + PrefillTimePredictor predictor = new PrefillTimePredictor(0, 1, 0, 0, 0, 0); + // seq=500, hit=500 → c=0 + BatchItem item = batchItem(500, 500); + assertEquals(0, predictor.predictBatchMs(List.of(item))); + } + + @Test + void predictBatchMsLargeBatch() { + PrefillTimePredictor predictor = new PrefillTimePredictor(100, 0.5, 0, 0, 0.1, 3); + List items = new ArrayList<>(); + for (int i = 0; i < 100; i++) { + items.add(batchItem(1000, 200)); + } + long result = predictor.predictBatchMs(items); + assertTrue(result > 0, "Large batch should produce positive prediction"); + } + + // ---- helpers ---- + + private static BatchItem batchItem(long seqLen, long hitCacheLen) { + Request request = new Request(); + request.setRequestId(1L); + request.setSeqLen(seqLen); + + BalanceContext ctx = new BalanceContext(); + ctx.setRequest(request); + + ServerStatus prefill = new ServerStatus(); + DebugInfo debugInfo = new DebugInfo(); + debugInfo.setHitCacheLen(hitCacheLen); + prefill.setDebugInfo(debugInfo); + + return new BatchItem(ctx, null, null, prefill, null, null, null, 0, System.currentTimeMillis()); + } +} diff --git a/rtp_llm/flexlb/flexlb-sync/src/test/java/org/flexlb/balance/strategy/RandomStrategyTest.java b/rtp_llm/flexlb/flexlb-sync/src/test/java/org/flexlb/balance/strategy/RandomStrategyTest.java index 334c46d44d..11a1ed57a8 100644 --- a/rtp_llm/flexlb/flexlb-sync/src/test/java/org/flexlb/balance/strategy/RandomStrategyTest.java +++ b/rtp_llm/flexlb/flexlb-sync/src/test/java/org/flexlb/balance/strategy/RandomStrategyTest.java @@ -1,6 +1,10 @@ package org.flexlb.balance.strategy; import lombok.extern.slf4j.Slf4j; +import org.flexlb.balance.endpoint.EndpointRegistry; +import org.flexlb.balance.endpoint.PrefillEndpoint; +import org.flexlb.balance.endpoint.DecodeEndpoint; +import org.flexlb.balance.endpoint.WorkerEndpoint; import org.flexlb.balance.resource.ResourceMeasure; import org.flexlb.balance.resource.ResourceMeasureFactory; import org.flexlb.config.ConfigService; @@ -13,6 +17,7 @@ import org.flexlb.dao.master.WorkerStatus; import org.flexlb.dao.route.RoleType; import org.flexlb.enums.LoadBalanceStrategyEnum; +import org.flexlb.service.monitor.BatchSchedulerReporter; import org.flexlb.sync.status.EngineWorkerStatus; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; @@ -38,19 +43,22 @@ class RandomStrategyTest { private RandomStrategy randomStrategy; private ResourceMeasure resourceMeasure; + private EndpointRegistry endpointRegistry; @BeforeEach void setUp() { ConfigService configService = Mockito.mock(ConfigService.class); ResourceMeasureFactory resourceMeasureFactory = Mockito.mock(ResourceMeasureFactory.class); + endpointRegistry = new EndpointRegistry(configService, null, Mockito.mock(BatchSchedulerReporter.class)); resourceMeasure = Mockito.mock(ResourceMeasure.class); Mockito.when(configService.loadBalanceConfig()).thenReturn(new FlexlbConfig()); Mockito.when(resourceMeasureFactory.getMeasure(Mockito.any())).thenReturn(resourceMeasure); - Mockito.when(resourceMeasure.isResourceAvailable(Mockito.any())).thenReturn(true); + Mockito.when(resourceMeasure.isResourceAvailable(Mockito.any(WorkerEndpoint.class))).thenReturn(true); randomStrategy = new RandomStrategy( - new EngineWorkerStatus(new ModelMetaConfig()), + new EngineWorkerStatus(new ModelMetaConfig(), endpointRegistry), configService, - resourceMeasureFactory); + resourceMeasureFactory, + endpointRegistry); } @AfterEach @@ -61,19 +69,37 @@ void tearDown() { EngineWorkerStatus.MODEL_ROLE_WORKER_STATUS.getVitStatusMap().clear(); } + /** Register a mock PrefillEndpoint for the given ipPort and WorkerStatus. */ + private void registerPrefill(String ipPort, WorkerStatus ws) { + PrefillEndpoint ep = Mockito.mock(PrefillEndpoint.class); + Mockito.when(ep.getIp()).thenReturn(ws.getIp()); + Mockito.when(ep.getHttpPort()).thenReturn(ws.getPort()); + Mockito.when(ep.getGrpcPort()).thenReturn(ws.getGrpcPort()); + Mockito.when(ep.getStatus()).thenReturn(ws); + Mockito.when(ep.ipPort()).thenReturn(ipPort); + endpointRegistry.putPrefill(ipPort, ep); + } + + /** Register a mock DecodeEndpoint for the given ipPort and WorkerStatus. */ + private void registerDecode(String ipPort, WorkerStatus ws) { + DecodeEndpoint ep = Mockito.mock(DecodeEndpoint.class); + Mockito.when(ep.getIp()).thenReturn(ws.getIp()); + Mockito.when(ep.getHttpPort()).thenReturn(ws.getPort()); + Mockito.when(ep.getGrpcPort()).thenReturn(ws.getGrpcPort()); + Mockito.when(ep.getStatus()).thenReturn(ws); + Mockito.when(ep.ipPort()).thenReturn(ipPort); + endpointRegistry.putDecode(ipPort, ep); + } + @Test void should_return_error_when_no_workers_available() { - // Given: No workers registered for the model Request req = new Request(); - BalanceContext balanceContext = new BalanceContext(); balanceContext.setRequest(req); - // When: Select a worker ServerStatus result = randomStrategy.select(balanceContext, RoleType.PREFILL, null); - // Then: Should return error status assertFalse(result.isSuccess()); assertEquals(StrategyErrorType.NO_AVAILABLE_WORKER.getErrorCode(), result.getCode()); assertEquals(StrategyErrorType.NO_AVAILABLE_WORKER.getErrorMsg(), result.getMessage()); @@ -81,19 +107,15 @@ void should_return_error_when_no_workers_available() { @Test void should_return_error_when_worker_map_is_empty() { - // Given: Model exists but no workers EngineWorkerStatus.MODEL_ROLE_WORKER_STATUS.getPrefillStatusMap().clear(); Request req = new Request(); - BalanceContext balanceContext = new BalanceContext(); balanceContext.setRequest(req); - // When: Select a worker ServerStatus result = randomStrategy.select(balanceContext, RoleType.PREFILL, null); - // Then: Should return error status assertFalse(result.isSuccess()); assertEquals(StrategyErrorType.NO_AVAILABLE_WORKER.getErrorCode(), result.getCode()); assertEquals(StrategyErrorType.NO_AVAILABLE_WORKER.getErrorMsg(), result.getMessage()); @@ -101,32 +123,26 @@ void should_return_error_when_worker_map_is_empty() { @Test void should_return_success_when_workers_available() { - // Given: Model with available workers Map prefillStatusMap = EngineWorkerStatus.MODEL_ROLE_WORKER_STATUS.getPrefillStatusMap(); - // Add a worker WorkerStatus workerStatus = createWorkerStatus("127.0.0.1"); prefillStatusMap.put("127.0.0.1:8080", workerStatus); + registerPrefill("127.0.0.1:8080", workerStatus); Request req = new Request(); - BalanceContext balanceContext = new BalanceContext(); balanceContext.setRequest(req); - // When: Select a worker ServerStatus result = randomStrategy.select(balanceContext, RoleType.PREFILL, null); - // Then: Should return success status with batchId assertTrue(result.isSuccess()); } @Test void should_select_randomly_from_available_workers() { - // Given: Model with multiple available workers Map prefillStatusMap = EngineWorkerStatus.MODEL_ROLE_WORKER_STATUS.getPrefillStatusMap(); - // Add multiple workers WorkerStatus worker1 = createWorkerStatus("127.0.0.1"); WorkerStatus worker2 = createWorkerStatus("127.0.0.2"); WorkerStatus worker3 = createWorkerStatus("127.0.0.3"); @@ -134,19 +150,19 @@ void should_select_randomly_from_available_workers() { prefillStatusMap.put("127.0.0.1:8080", worker1); prefillStatusMap.put("127.0.0.2:8080", worker2); prefillStatusMap.put("127.0.0.3:8080", worker3); + registerPrefill("127.0.0.1:8080", worker1); + registerPrefill("127.0.0.2:8080", worker2); + registerPrefill("127.0.0.3:8080", worker3); Request req = new Request(); - BalanceContext balanceContext = new BalanceContext(); balanceContext.setRequest(req); - // When: Select a worker multiple times ServerStatus result1 = randomStrategy.select(balanceContext, RoleType.PREFILL, null); ServerStatus result2 = randomStrategy.select(balanceContext, RoleType.PREFILL, null); ServerStatus result3 = randomStrategy.select(balanceContext, RoleType.PREFILL, null); - // Then: All should be successful (random selection is working) assertTrue(result1.isSuccess()); assertTrue(result2.isSuccess()); assertTrue(result3.isSuccess()); @@ -154,71 +170,57 @@ void should_select_randomly_from_available_workers() { @Test void should_work_with_different_role_types() { - // Given: Model with workers for different roles - - // Add workers for different roles WorkerStatus prefillWorker = createWorkerStatus("127.0.0.1"); EngineWorkerStatus.MODEL_ROLE_WORKER_STATUS.getPrefillStatusMap().put("127.0.0.1:8080", prefillWorker); + registerPrefill("127.0.0.1:8080", prefillWorker); WorkerStatus decodeWorker = createWorkerStatus("127.0.0.2"); EngineWorkerStatus.MODEL_ROLE_WORKER_STATUS.getDecodeStatusMap().put("127.0.0.2:8080", decodeWorker); + registerDecode("127.0.0.2:8080", decodeWorker); Request req = new Request(); - BalanceContext balanceContext = new BalanceContext(); balanceContext.setRequest(req); - // When: Select workers for different roles ServerStatus prefillResult = randomStrategy.select(balanceContext, RoleType.PREFILL, null); ServerStatus decodeResult = randomStrategy.select(balanceContext, RoleType.DECODE, null); - // Then: Both should be successful assertTrue(prefillResult.isSuccess()); assertTrue(decodeResult.isSuccess()); } @Test void should_work_with_group_parameter() { - // Given: Model with workers in specific groups - - // Add worker with specific group WorkerStatus worker = createWorkerStatus("127.0.0.1"); worker.setGroup("group-a"); EngineWorkerStatus.MODEL_ROLE_WORKER_STATUS.getPrefillStatusMap().put("127.0.0.1:8080", worker); + registerPrefill("127.0.0.1:8080", worker); Request req = new Request(); - BalanceContext balanceContext = new BalanceContext(); balanceContext.setRequest(req); - // When: Select worker with group parameter ServerStatus result = randomStrategy.select(balanceContext, RoleType.PREFILL, "group-a"); - // Then: Should be successful assertTrue(result.isSuccess()); } @Test void should_return_error_when_no_workers_in_specified_group() { - // Given: Model with workers but none in the specified group - - // Add worker with different group WorkerStatus worker = createWorkerStatus("127.0.0.1"); worker.setGroup("group-a"); EngineWorkerStatus.MODEL_ROLE_WORKER_STATUS.getPrefillStatusMap().put("127.0.0.1:8080", worker); + registerPrefill("127.0.0.1:8080", worker); Request req = new Request(); - BalanceContext balanceContext = new BalanceContext(); balanceContext.setRequest(req); - // When: Select worker with different group parameter ServerStatus result = randomStrategy.select(balanceContext, RoleType.PREFILL, "group-b"); - // Then: Should return error status assertFalse(result.isSuccess()); assertEquals(StrategyErrorType.NO_AVAILABLE_WORKER.getErrorCode(), result.getCode()); assertEquals(StrategyErrorType.NO_AVAILABLE_WORKER.getErrorMsg(), result.getMessage()); @@ -226,9 +228,6 @@ void should_return_error_when_no_workers_in_specified_group() { @Test void should_register_strategy_in_factory() { - // Given: RandomStrategy is instantiated - // When: Check if it's registered in the factory - // Then: Should be able to get it from the factory RandomStrategy strategyFromFactory = (RandomStrategy) LoadBalanceStrategyFactory.getLoadBalancer(LoadBalanceStrategyEnum.RANDOM); assertNotNull(strategyFromFactory); assertSame(randomStrategy, strategyFromFactory); @@ -236,10 +235,8 @@ void should_register_strategy_in_factory() { @Test void should_distribute_requests_uniformly_across_workers() { - // Given: Model with multiple available workers Map prefillStatusMap = EngineWorkerStatus.MODEL_ROLE_WORKER_STATUS.getPrefillStatusMap(); - // Add multiple workers WorkerStatus worker1 = createWorkerStatus("127.0.0.1"); WorkerStatus worker2 = createWorkerStatus("127.0.0.2"); WorkerStatus worker3 = createWorkerStatus("127.0.0.3"); @@ -247,14 +244,15 @@ void should_distribute_requests_uniformly_across_workers() { prefillStatusMap.put("127.0.0.1:8080", worker1); prefillStatusMap.put("127.0.0.2:8080", worker2); prefillStatusMap.put("127.0.0.3:8080", worker3); + registerPrefill("127.0.0.1:8080", worker1); + registerPrefill("127.0.0.2:8080", worker2); + registerPrefill("127.0.0.3:8080", worker3); Request req = new Request(); - BalanceContext balanceContext = new BalanceContext(); balanceContext.setRequest(req); - // When: Select workers many times int totalRuns = 10000; Map selectionCount = new HashMap<>(); @@ -268,7 +266,6 @@ void should_distribute_requests_uniformly_across_workers() { } } - // Then: Each worker should be selected approximately 33% of the time (within 10% tolerance) int expectedCountPerWorker = totalRuns / 3; double tolerance = 0.10; @@ -285,25 +282,22 @@ void should_distribute_requests_uniformly_across_workers() { @Test void should_skip_dead_workers() { - // Given: Model with mix of alive and dead workers Map prefillStatusMap = EngineWorkerStatus.MODEL_ROLE_WORKER_STATUS.getPrefillStatusMap(); - // Add dead worker WorkerStatus deadWorker = createWorkerStatus("127.0.0.1"); deadWorker.setAlive(false); prefillStatusMap.put("127.0.0.1:8080", deadWorker); + registerPrefill("127.0.0.1:8080", deadWorker); - // Add alive worker WorkerStatus aliveWorker = createWorkerStatus("127.0.0.2"); prefillStatusMap.put("127.0.0.2:8080", aliveWorker); + registerPrefill("127.0.0.2:8080", aliveWorker); Request req = new Request(); - BalanceContext balanceContext = new BalanceContext(); balanceContext.setRequest(req); - // When: Select worker multiple times int totalRuns = 100; Map selectionCount = new HashMap<>(); @@ -317,23 +311,25 @@ void should_skip_dead_workers() { } } - // Then: Only alive workers should be selected assertFalse(selectionCount.containsKey("127.0.0.1")); assertEquals(totalRuns, selectionCount.getOrDefault("127.0.0.2", 0)); } @Test void should_skip_workers_rejected_by_resource_measure() { - // Given: Model with one resource-unavailable worker and one available worker Map decodeStatusMap = EngineWorkerStatus.MODEL_ROLE_WORKER_STATUS.getDecodeStatusMap(); WorkerStatus unavailableWorker = createWorkerStatus("127.0.0.1"); WorkerStatus availableWorker = createWorkerStatus("127.0.0.2"); decodeStatusMap.put("127.0.0.1:8080", unavailableWorker); decodeStatusMap.put("127.0.0.2:8080", availableWorker); + registerDecode("127.0.0.1:8080", unavailableWorker); + registerDecode("127.0.0.2:8080", availableWorker); - Mockito.when(resourceMeasure.isResourceAvailable(unavailableWorker)).thenReturn(false); - Mockito.when(resourceMeasure.isResourceAvailable(availableWorker)).thenReturn(true); + Mockito.when(resourceMeasure.isResourceAvailable( + Mockito.argThat(ep -> ep != null && "127.0.0.1".equals(ep.getIp())))).thenReturn(false); + Mockito.when(resourceMeasure.isResourceAvailable( + Mockito.argThat(ep -> ep != null && "127.0.0.2".equals(ep.getIp())))).thenReturn(true); Request req = new Request(); req.setSeqLen(1000); @@ -342,35 +338,29 @@ void should_skip_workers_rejected_by_resource_measure() { BalanceContext balanceContext = new BalanceContext(); balanceContext.setRequest(req); - // When: Select a decode worker ServerStatus result = randomStrategy.select(balanceContext, RoleType.DECODE, null); - // Then: RandomStrategy should honor serviceability filtering assertTrue(result.isSuccess()); assertEquals("127.0.0.2", result.getServerIp()); - assertTrue(availableWorker.getLocalTaskMap().containsKey(12345L)); } @Test void should_properly_set_server_status_fields() { - // Given: Model with a worker Map prefillStatusMap = EngineWorkerStatus.MODEL_ROLE_WORKER_STATUS.getPrefillStatusMap(); WorkerStatus worker = createWorkerStatus("127.0.0.1"); worker.setGroup("group-x"); prefillStatusMap.put("127.0.0.1:8080", worker); + registerPrefill("127.0.0.1:8080", worker); Request req = new Request(); - req.setSeqLen(1000); BalanceContext balanceContext = new BalanceContext(); balanceContext.setRequest(req); - // When: Select a worker ServerStatus result = randomStrategy.select(balanceContext, RoleType.PREFILL, null); - // Then: All server status fields should be properly set assertTrue(result.isSuccess()); assertEquals("127.0.0.1", result.getServerIp()); assertEquals(8080, result.getHttpPort()); @@ -380,32 +370,26 @@ void should_properly_set_server_status_fields() { @Test void should_handle_null_request_id() { - // Given: Model with a worker Map prefillStatusMap = EngineWorkerStatus.MODEL_ROLE_WORKER_STATUS.getPrefillStatusMap(); prefillStatusMap.clear(); WorkerStatus worker = createWorkerStatus("127.0.0.1"); prefillStatusMap.put("127.0.0.1:8080", worker); + registerPrefill("127.0.0.1:8080", worker); Request req = new Request(); - BalanceContext balanceContext = new BalanceContext(); balanceContext.setRequest(req); - // When: Select a worker with null requestId ServerStatus result = randomStrategy.select(balanceContext, RoleType.PREFILL, null); - // Then: Should still return success (RandomStrategy doesn't require requestId) assertTrue(result.isSuccess()); assertEquals("127.0.0.1", result.getServerIp()); } @Test void should_handle_rollback_without_error() { - // Given: Rollback is called - // When: Rollback is called (RandomStrategy has empty implementation) - // Then: Should not throw any exception randomStrategy.rollBack("127.0.0.1:8080", 0); } diff --git a/rtp_llm/flexlb/flexlb-sync/src/test/java/org/flexlb/balance/strategy/ShortestTTFTStrategyTest.java b/rtp_llm/flexlb/flexlb-sync/src/test/java/org/flexlb/balance/strategy/ShortestTTFTStrategyTest.java deleted file mode 100644 index f566e4cd95..0000000000 --- a/rtp_llm/flexlb/flexlb-sync/src/test/java/org/flexlb/balance/strategy/ShortestTTFTStrategyTest.java +++ /dev/null @@ -1,294 +0,0 @@ -package org.flexlb.balance.strategy; - -import org.flexlb.balance.resource.ResourceMeasureFactory; -import org.flexlb.cache.service.CacheAwareService; -import org.flexlb.config.ConfigService; -import org.flexlb.config.FlexlbConfig; -import org.flexlb.config.ModelMetaConfig; -import org.flexlb.config.StrategyConfigs; -import org.flexlb.dao.BalanceContext; -import org.flexlb.dao.loadbalance.Request; -import org.flexlb.dao.loadbalance.ServerStatus; -import org.flexlb.dao.master.CacheStatus; -import org.flexlb.dao.master.TaskInfo; -import org.flexlb.dao.master.WorkerStatus; -import org.flexlb.dao.route.RoleType; -import org.flexlb.service.monitor.EngineHealthReporter; -import org.flexlb.sync.status.EngineWorkerStatus; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.Assertions; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; -import org.mockito.Mockito; - -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.concurrent.ConcurrentHashMap; - -/** - * @author zjw - * description: - * date: 2025/3/11 - */ -class ShortestTTFTStrategyTest { - - @BeforeEach - void setUp() { - EngineWorkerStatus.MODEL_ROLE_WORKER_STATUS.getPrefillStatusMap().clear(); - } - - @AfterEach - void cleanUp() { - EngineWorkerStatus.MODEL_ROLE_WORKER_STATUS.getPrefillStatusMap().clear(); - } - - @Test - void test() { - - EngineWorkerStatus engineWorkerStatus = new EngineWorkerStatus(new ModelMetaConfig()); - Map prefillStatusMap = EngineWorkerStatus.MODEL_ROLE_WORKER_STATUS.getPrefillStatusMap(); - Map waitingTaskList = new HashMap<>(); - Map runningTaskList = new HashMap<>(); - Map finishedTaskList = new HashMap<>(); - ConcurrentHashMap localTaskList = new ConcurrentHashMap<>(); - WorkerStatus workerStatus = createWorkerStatus("127.0.0.1", 200, waitingTaskList, runningTaskList, finishedTaskList, localTaskList); - - Map waitingTaskList1 = new HashMap<>(); - Map runningTaskList1 = new HashMap<>(); - Map finishedTaskList1 = new HashMap<>(); - ConcurrentHashMap localTaskList1 = new ConcurrentHashMap<>(); - WorkerStatus workerStatus1 = createWorkerStatus("127.0.0.2", 100, waitingTaskList1, runningTaskList1, finishedTaskList1, localTaskList1); - - prefillStatusMap.put("127.0.0.1:8080", workerStatus); - prefillStatusMap.put("127.0.0.2:8080", workerStatus1); - Request req = new Request(); - req.setSeqLen(1000); - req.setRequestId(12345L); - List blockCacheKeys = new ArrayList<>(); - blockCacheKeys.add(1L); - blockCacheKeys.add(2L); - req.setBlockCacheKeys(blockCacheKeys); - - EngineHealthReporter engineHealthReporter = Mockito.mock(EngineHealthReporter.class); - CacheAwareService cacheAwareService = Mockito.mock(CacheAwareService.class); - ResourceMeasureFactory resourceMeasureFactory = Mockito.mock(ResourceMeasureFactory.class); - org.flexlb.balance.resource.ResourceMeasure resourceMeasure = Mockito.mock(org.flexlb.balance.resource.ResourceMeasure.class); - ConfigService configService = Mockito.mock(ConfigService.class); - Mockito.when(configService.loadBalanceConfig()).thenReturn(new FlexlbConfig()); - Mockito.when(configService.getStrategyConfigs()).thenReturn(new StrategyConfigs()); - Mockito.when(resourceMeasureFactory.getMeasure(Mockito.any())).thenReturn(resourceMeasure); - Mockito.when(resourceMeasure.isResourceAvailable(Mockito.any())).thenReturn(true); - Mockito.when(cacheAwareService.findMatchingEngines(Mockito.anyList(), Mockito.any(), Mockito.any())).thenReturn(new HashMap<>()); - - ShortestTTFTStrategy staticCacheLoadBalancer = - new ShortestTTFTStrategy(engineWorkerStatus, engineHealthReporter, cacheAwareService, resourceMeasureFactory, configService); - - BalanceContext balanceContext = new BalanceContext(); - balanceContext.setConfig(new FlexlbConfig()); - balanceContext.setRequest(req); - ServerStatus result = staticCacheLoadBalancer.select(balanceContext, RoleType.PREFILL, null); - if (!result.isSuccess()) { - System.out.println("Result not successful - code: " + result.getCode() + ", message: " + result.getMessage()); - } - Assertions.assertTrue(result.isSuccess(), "Result should be successful but got: " + result.getMessage()); - Assertions.assertEquals("127.0.0.2", result.getServerIp()); - } - - @Test - void should_short_circuit_to_lowest_ttft_when_fixed_candidate_pool_size_is_one() { - EngineWorkerStatus engineWorkerStatus = new EngineWorkerStatus(new ModelMetaConfig()); - Map prefillStatusMap = EngineWorkerStatus.MODEL_ROLE_WORKER_STATUS.getPrefillStatusMap(); - - WorkerStatus lowestTtftWorker = createWorkerStatus("127.0.0.1", 0); - lowestTtftWorker.getLastSelectedTime().set(1000L); - WorkerStatus olderWorkerWithHigherTtft = createWorkerStatus("127.0.0.2", 5); - olderWorkerWithHigherTtft.getLastSelectedTime().set(1L); - - prefillStatusMap.put(lowestTtftWorker.getIpPort(), lowestTtftWorker); - prefillStatusMap.put(olderWorkerWithHigherTtft.getIpPort(), olderWorkerWithHigherTtft); - - ShortestTTFTStrategy strategy = createStrategy(engineWorkerStatus, fixedCandidatePoolConfigService(1)); - - ServerStatus result = strategy.select(createBalanceContext(100L), RoleType.PREFILL, null); - - Assertions.assertTrue(result.isSuccess(), "Result should be successful but got: " + result.getMessage()); - Assertions.assertEquals("127.0.0.1", result.getServerIp()); - } - - @Test - void should_apply_fairness_when_fixed_candidate_pool_size_is_larger_than_one() { - EngineWorkerStatus engineWorkerStatus = new EngineWorkerStatus(new ModelMetaConfig()); - Map prefillStatusMap = EngineWorkerStatus.MODEL_ROLE_WORKER_STATUS.getPrefillStatusMap(); - - WorkerStatus lowestTtftWorker = createWorkerStatus("127.0.0.1", 0); - lowestTtftWorker.getLastSelectedTime().set(1000L); - WorkerStatus olderWorkerWithSimilarTtft = createWorkerStatus("127.0.0.2", 0); - olderWorkerWithSimilarTtft.getLastSelectedTime().set(1L); - - prefillStatusMap.put(lowestTtftWorker.getIpPort(), lowestTtftWorker); - prefillStatusMap.put(olderWorkerWithSimilarTtft.getIpPort(), olderWorkerWithSimilarTtft); - - ShortestTTFTStrategy strategy = createStrategy(engineWorkerStatus, fixedCandidatePoolConfigService(2)); - - ServerStatus result = strategy.select(createBalanceContext(100L), RoleType.PREFILL, null); - - Assertions.assertTrue(result.isSuccess(), "Result should be successful but got: " + result.getMessage()); - Assertions.assertEquals("127.0.0.2", result.getServerIp()); - } - - @Test - void should_apply_fairness_when_ratio_candidate_pool_has_multiple_workers() { - EngineWorkerStatus engineWorkerStatus = new EngineWorkerStatus(new ModelMetaConfig()); - Map prefillStatusMap = EngineWorkerStatus.MODEL_ROLE_WORKER_STATUS.getPrefillStatusMap(); - - WorkerStatus recentlySelectedWorker = createWorkerStatus("127.0.0.1", 0); - recentlySelectedWorker.getLastSelectedTime().set(1000L); - WorkerStatus olderWorker = createWorkerStatus("127.0.0.2", 0); - olderWorker.getLastSelectedTime().set(1L); - - prefillStatusMap.put(recentlySelectedWorker.getIpPort(), recentlySelectedWorker); - prefillStatusMap.put(olderWorker.getIpPort(), olderWorker); - - ShortestTTFTStrategy strategy = createStrategy(engineWorkerStatus, ratioCandidatePoolConfigService(1.0)); - - ServerStatus result = strategy.select(createBalanceContext(100L), RoleType.PREFILL, null); - - Assertions.assertTrue(result.isSuccess(), "Result should be successful but got: " + result.getMessage()); - Assertions.assertEquals("127.0.0.2", result.getServerIp()); - } - - @Test - void should_report_candidate_and_selected_routing_cache_match_tokens() { - EngineWorkerStatus engineWorkerStatus = new EngineWorkerStatus(new ModelMetaConfig()); - Map prefillStatusMap = EngineWorkerStatus.MODEL_ROLE_WORKER_STATUS.getPrefillStatusMap(); - - WorkerStatus longestMatchWorker = createWorkerStatus("127.0.0.1", 1000); - WorkerStatus selectedWorker = createWorkerStatus("127.0.0.2", 0); - prefillStatusMap.put(longestMatchWorker.getIpPort(), longestMatchWorker); - prefillStatusMap.put(selectedWorker.getIpPort(), selectedWorker); - - EngineHealthReporter engineHealthReporter = Mockito.mock(EngineHealthReporter.class); - CacheAwareService cacheAwareService = Mockito.mock(CacheAwareService.class); - ResourceMeasureFactory resourceMeasureFactory = Mockito.mock(ResourceMeasureFactory.class); - org.flexlb.balance.resource.ResourceMeasure resourceMeasure = Mockito.mock(org.flexlb.balance.resource.ResourceMeasure.class); - Mockito.when(resourceMeasureFactory.getMeasure(Mockito.any())).thenReturn(resourceMeasure); - Mockito.when(resourceMeasure.isResourceAvailable(Mockito.any())).thenReturn(true); - Mockito.when(cacheAwareService.findMatchingEngines(Mockito.anyList(), Mockito.any(), Mockito.any())) - .thenReturn(Map.of( - longestMatchWorker.getIpPort(), 4, - selectedWorker.getIpPort(), 1)); - - ShortestTTFTStrategy strategy = new ShortestTTFTStrategy( - engineWorkerStatus, - engineHealthReporter, - cacheAwareService, - resourceMeasureFactory, - fixedCandidatePoolConfigService(1)); - - BalanceContext balanceContext = createBalanceContext(4096L); - // Page-RR sends virtual block tokens here: seq_size_per_block * cp_size. - balanceContext.getRequest().setCacheKeyBlockSize(1024L); - - ServerStatus result = strategy.select(balanceContext, RoleType.PREFILL, null); - - Assertions.assertTrue(result.isSuccess(), "Result should be successful but got: " + result.getMessage()); - Assertions.assertEquals("127.0.0.2", result.getServerIp()); - Mockito.verify(engineHealthReporter).reportRoutingCandidateCacheMatchMetrics( - RoleType.PREFILL, "127.0.0.1", 4096L, 4096L); - Mockito.verify(engineHealthReporter).reportRoutingCandidateCacheMatchMetrics( - RoleType.PREFILL, "127.0.0.2", 1024L, 4096L); - Mockito.verify(engineHealthReporter).reportRoutingSelectedCacheMatchMetrics( - RoleType.PREFILL, "127.0.0.2", 1024L, 4096L); - } - - private ShortestTTFTStrategy createStrategy(EngineWorkerStatus engineWorkerStatus, ConfigService configService) { - EngineHealthReporter engineHealthReporter = Mockito.mock(EngineHealthReporter.class); - CacheAwareService cacheAwareService = Mockito.mock(CacheAwareService.class); - ResourceMeasureFactory resourceMeasureFactory = Mockito.mock(ResourceMeasureFactory.class); - org.flexlb.balance.resource.ResourceMeasure resourceMeasure = Mockito.mock(org.flexlb.balance.resource.ResourceMeasure.class); - - Mockito.when(resourceMeasureFactory.getMeasure(Mockito.any())).thenReturn(resourceMeasure); - Mockito.when(resourceMeasure.isResourceAvailable(Mockito.any())).thenReturn(true); - Mockito.when(cacheAwareService.findMatchingEngines(Mockito.anyList(), Mockito.any(), Mockito.any())).thenReturn(new HashMap<>()); - - return new ShortestTTFTStrategy( - engineWorkerStatus, - engineHealthReporter, - cacheAwareService, - resourceMeasureFactory, - configService); - } - - private ConfigService fixedCandidatePoolConfigService(int size) { - StrategyConfigs strategyConfigs = new StrategyConfigs(); - StrategyConfigs.CandidatePoolConfig candidatePool = strategyConfigs.getShortestTtft().getCandidatePool(); - candidatePool.setMode(StrategyConfigs.CandidatePoolMode.FIXED); - candidatePool.setSize(size); - strategyConfigs.normalize(); - - ConfigService configService = Mockito.mock(ConfigService.class); - Mockito.when(configService.getStrategyConfigs()).thenReturn(strategyConfigs); - return configService; - } - - private ConfigService ratioCandidatePoolConfigService(double ratio) { - StrategyConfigs strategyConfigs = new StrategyConfigs(); - StrategyConfigs.CandidatePoolConfig candidatePool = strategyConfigs.getShortestTtft().getCandidatePool(); - candidatePool.setMode(StrategyConfigs.CandidatePoolMode.RATIO); - candidatePool.setRatio(ratio); - strategyConfigs.normalize(); - - ConfigService configService = Mockito.mock(ConfigService.class); - Mockito.when(configService.getStrategyConfigs()).thenReturn(strategyConfigs); - return configService; - } - - private BalanceContext createBalanceContext(long seqLen) { - Request req = new Request(); - req.setSeqLen(seqLen); - req.setRequestId(12345L); - req.setBlockCacheKeys(List.of(1L, 2L)); - - BalanceContext balanceContext = new BalanceContext(); - balanceContext.setConfig(new FlexlbConfig()); - balanceContext.setRequest(req); - return balanceContext; - } - - WorkerStatus createWorkerStatus(String ip, long runningQueueTime) { - return createWorkerStatus( - ip, - runningQueueTime, - new HashMap<>(), - new HashMap<>(), - new HashMap<>(), - new ConcurrentHashMap<>()); - } - - WorkerStatus createWorkerStatus(String ip, - long runningQueueTime, - Map waitingTaskInfo, - Map finishedTaskList, - Map runningTaslList, - ConcurrentHashMap localTaskList) { - WorkerStatus workerStatus = new WorkerStatus(); - - workerStatus.setIp(ip); - workerStatus.setPort(8080); - workerStatus.setSite("na61"); - workerStatus.setAlive(true); - workerStatus.setRole(RoleType.PREFILL.getCode()); - CacheStatus cacheStatus = new CacheStatus(); - cacheStatus.setAvailableKvCache(10000); - cacheStatus.setBlockSize(256); - workerStatus.setCacheStatus(cacheStatus); - workerStatus.getRunningQueueTime().getAndSet(runningQueueTime); - workerStatus.setWaitingTaskList(waitingTaskInfo); - workerStatus.updateTaskStates(waitingTaskInfo, runningTaslList, finishedTaskList); - workerStatus.setRunningTaskList(runningTaslList); - return workerStatus; - } - -} diff --git a/rtp_llm/flexlb/flexlb-sync/src/test/java/org/flexlb/service/RouteServiceTest.java b/rtp_llm/flexlb/flexlb-sync/src/test/java/org/flexlb/service/RouteServiceTest.java index 12435de2cc..b52ec1e2db 100644 --- a/rtp_llm/flexlb/flexlb-sync/src/test/java/org/flexlb/service/RouteServiceTest.java +++ b/rtp_llm/flexlb/flexlb-sync/src/test/java/org/flexlb/service/RouteServiceTest.java @@ -34,6 +34,9 @@ class RouteServiceTest { @Mock private QueueManager queueManager; + @Mock + private org.flexlb.balance.scheduler.FlexlbBatchScheduler flexlbBatchScheduler; + @Mock private RecentCacheKeyTraceReporter recentCacheKeyTraceReporter; @@ -45,7 +48,8 @@ class RouteServiceTest { @BeforeEach void setUp() { when(configService.loadBalanceConfig()).thenReturn(flexlbConfig); - routeService = new RouteService(configService, defaultRouter, queueManager, recentCacheKeyTraceReporter); + routeService = new RouteService(configService, defaultRouter, queueManager, + flexlbBatchScheduler, recentCacheKeyTraceReporter); } @Test diff --git a/rtp_llm/flexlb/flexlb-sync/src/test/java/org/flexlb/sync/runner/EngineSyncRunnerTest.java b/rtp_llm/flexlb/flexlb-sync/src/test/java/org/flexlb/sync/runner/EngineSyncRunnerTest.java index ea2a20cb26..ff8cd166a1 100644 --- a/rtp_llm/flexlb/flexlb-sync/src/test/java/org/flexlb/sync/runner/EngineSyncRunnerTest.java +++ b/rtp_llm/flexlb/flexlb-sync/src/test/java/org/flexlb/sync/runner/EngineSyncRunnerTest.java @@ -71,7 +71,9 @@ void setUp() { localKvCacheAwareManager, syncRequestTimeoutMs, syncCount, - syncEngineStatusInterval + syncEngineStatusInterval, + null, + null ); } @@ -98,7 +100,9 @@ void should_handle_null_worker_status_gracefully() { localKvCacheAwareManager, syncRequestTimeoutMs, syncCount, - syncEngineStatusInterval + syncEngineStatusInterval, + null, + null ); // Execute diff --git a/rtp_llm/flexlb/flexlb-sync/src/test/java/org/flexlb/sync/runner/GrpcWorkerStatusCheckRunnerTest.java b/rtp_llm/flexlb/flexlb-sync/src/test/java/org/flexlb/sync/runner/GrpcWorkerStatusCheckRunnerTest.java index f122c0f9a1..97f3fd34a6 100644 --- a/rtp_llm/flexlb/flexlb-sync/src/test/java/org/flexlb/sync/runner/GrpcWorkerStatusCheckRunnerTest.java +++ b/rtp_llm/flexlb/flexlb-sync/src/test/java/org/flexlb/sync/runner/GrpcWorkerStatusCheckRunnerTest.java @@ -9,13 +9,24 @@ import org.mockito.Mockito; import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.assertNull; import static org.mockito.ArgumentMatchers.anyInt; import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +/** + * Tests for {@link GrpcWorkerStatusRunner}. + * + *

Key API changes since original implementation: + *

    + *
  • Proto field {@code is_waiting} replaced by {@code TaskPhase phase}
  • + *
  • {@code WorkerStatus.runningTaskList} replaces old {@code waitingTaskList + localTaskMap}
  • + *
  • Constructor requires {@code FlexlbBatchScheduler + EndpointRegistry} (nullable)
  • + *
  • Task list refresh only occurs when status version advances (not on equal version)
  • + *
+ */ class GrpcWorkerStatusCheckRunnerTest { private final EngineGrpcService engineGrpcService = Mockito.mock(EngineGrpcService.class); @@ -35,7 +46,7 @@ void should_callGrpcServiceAndVerifyInteraction_when_runnerExecutes() { workerStatus.setPort(8080); EngineRpcService.WorkerStatusPB workerStatusPB = EngineRpcService.WorkerStatusPB.newBuilder() - .setRole("test-role") + .setRole(EngineRpcService.RoleTypePB.ROLE_TYPE_PREFILL) .setAvailableConcurrency(10) .setRunningQueryLen(5) .setWaitingQueryLen(3) @@ -47,21 +58,25 @@ void should_callGrpcServiceAndVerifyInteraction_when_runnerExecutes() { .setAlive(true) .build(); - when(engineGrpcService.getWorkerStatus(anyString(), anyInt(), anyLong(), anyLong(), org.mockito.ArgumentMatchers.any(RoleType.class))).thenReturn(workerStatusPB); + when(engineGrpcService.getWorkerStatus(anyString(), anyInt(), anyLong(), anyLong(), + org.mockito.ArgumentMatchers.any(RoleType.class))).thenReturn(workerStatusPB); - // Act + // Act — pass null for FlexlbBatchScheduler and EndpointRegistry (not needed in unit test) GrpcWorkerStatusRunner runner = new GrpcWorkerStatusRunner( modelName, ipPort, site, RoleType.PREFILL, - group, workerStatus, engineHealthReporter, engineGrpcService, 20); + group, workerStatus, engineHealthReporter, engineGrpcService, 20L, null, null); runner.run(); - // Assert + // Assert — gRPC port is derived from HTTP port 8080 → 8081 verify(engineGrpcService).getWorkerStatus("127.0.0.1", 8081, -1L, 20L, RoleType.PREFILL); } @Test - void should_refreshTaskLists_when_statusVersionIsNotUpdated() { + void should_not_update_task_list_when_status_version_is_unchanged() { + // When the gRPC response version equals the local version, the status update + // is skipped — including the runningTaskList refresh. This avoids unnecessary + // state churn when the engine hasn't changed. String modelName = "test-model"; String ipPort = "127.0.0.1:8080"; String site = "test-site"; @@ -72,27 +87,30 @@ void should_refreshTaskLists_when_statusVersionIsNotUpdated() { workerStatus.setPort(8080); workerStatus.getStatusVersion().set(100L); - EngineRpcService.TaskInfoPB waitingTask = EngineRpcService.TaskInfoPB.newBuilder() + // Use TaskPhasePB instead of the removed is_waiting field + EngineRpcService.TaskInfoPB taskInfo = EngineRpcService.TaskInfoPB.newBuilder() .setRequestId(123L) .setInputLength(100) - .setIsWaiting(true) + .setPhase(EngineRpcService.TaskPhase.TASK_PHASE_RECEIVED) .build(); EngineRpcService.WorkerStatusPB workerStatusPB = EngineRpcService.WorkerStatusPB.newBuilder() - .setRole(RoleType.PREFILL.getCode()) + .setRole(EngineRpcService.RoleTypePB.ROLE_TYPE_PREFILL) .setStatusVersion(100L) .setAlive(true) - .addRunningTaskInfo(waitingTask) + .addRunningTaskInfo(taskInfo) .build(); - when(engineGrpcService.getWorkerStatus(anyString(), anyInt(), anyLong(), anyLong(), org.mockito.ArgumentMatchers.any(RoleType.class))).thenReturn(workerStatusPB); + when(engineGrpcService.getWorkerStatus(anyString(), anyInt(), anyLong(), anyLong(), + org.mockito.ArgumentMatchers.any(RoleType.class))).thenReturn(workerStatusPB); GrpcWorkerStatusRunner runner = new GrpcWorkerStatusRunner( modelName, ipPort, site, RoleType.PREFILL, - group, workerStatus, engineHealthReporter, engineGrpcService, 20); + group, workerStatus, engineHealthReporter, engineGrpcService, 20L, null, null); runner.run(); - assertEquals(1, workerStatus.getWaitingTaskList().size()); - assertTrue(workerStatus.getWaitingTaskList().containsKey("123")); + // Version not advanced → runningTaskList should NOT be populated from response + assertNull(workerStatus.getRunningTaskList(), + "runningTaskList should not be updated when status version is unchanged"); } } diff --git a/rtp_llm/flexlb/flexlb-sync/src/test/java/org/flexlb/sync/status/EngineWorkerStatusTest.java b/rtp_llm/flexlb/flexlb-sync/src/test/java/org/flexlb/sync/status/EngineWorkerStatusTest.java index a4baec0a27..0861dab7ef 100644 --- a/rtp_llm/flexlb/flexlb-sync/src/test/java/org/flexlb/sync/status/EngineWorkerStatusTest.java +++ b/rtp_llm/flexlb/flexlb-sync/src/test/java/org/flexlb/sync/status/EngineWorkerStatusTest.java @@ -1,10 +1,15 @@ package org.flexlb.sync.status; +import org.flexlb.balance.endpoint.EndpointRegistry; +import org.flexlb.config.ConfigService; +import org.flexlb.config.FlexlbConfig; import org.flexlb.config.ModelMetaConfig; import org.flexlb.dao.master.WorkerStatus; import org.flexlb.dao.route.RoleType; +import org.flexlb.service.monitor.BatchSchedulerReporter; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.mockito.Mockito; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; @@ -15,12 +20,17 @@ class EngineWorkerStatusTest { private EngineWorkerStatus engineWorkerStatus; + private EndpointRegistry registry; + private ConfigService configService; private WorkerStatus workerStatus1; private WorkerStatus workerStatus2; @BeforeEach void setUp() { - engineWorkerStatus = new EngineWorkerStatus(new ModelMetaConfig()); + configService = Mockito.mock(ConfigService.class); + Mockito.when(configService.loadBalanceConfig()).thenReturn(new FlexlbConfig()); + registry = new EndpointRegistry(configService, null, Mockito.mock(BatchSchedulerReporter.class)); + engineWorkerStatus = new EngineWorkerStatus(new ModelMetaConfig(), registry); workerStatus1 = new WorkerStatus(); workerStatus1.setGroup("group1"); workerStatus2 = new WorkerStatus(); @@ -33,7 +43,7 @@ void should_create_engine_worker_status_with_config_when_constructing_with_model ModelMetaConfig config = new ModelMetaConfig(); // When - EngineWorkerStatus status = new EngineWorkerStatus(config); + EngineWorkerStatus status = new EngineWorkerStatus(config, registry); // Then assertNotNull(status); @@ -50,6 +60,16 @@ void should_return_filtered_worker_status_when_selecting_model_worker_status_wit EngineWorkerStatus.MODEL_ROLE_WORKER_STATUS.getDecodeStatusMap().put(ipPort1, workerStatus1); // group1 EngineWorkerStatus.MODEL_ROLE_WORKER_STATUS.getDecodeStatusMap().put(ipPort2, workerStatus2); // group2 + // Register corresponding DecodeEndpoints + workerStatus1.setIp("127.0.0.1"); + workerStatus1.setPort(8080); + workerStatus1.setGrpcPort(9090); + workerStatus2.setIp("127.0.0.1"); + workerStatus2.setPort(8081); + workerStatus2.setGrpcPort(9091); + registry.ensureDecodeEndpoint(ipPort1, workerStatus1); + registry.ensureDecodeEndpoint(ipPort2, workerStatus2); + // When var result = engineWorkerStatus.selectModelWorkerStatus(RoleType.DECODE, "group1"); @@ -72,6 +92,16 @@ void should_return_all_worker_status_when_selecting_model_worker_status_without_ EngineWorkerStatus.MODEL_ROLE_WORKER_STATUS.getPrefillStatusMap().put(ipPort1, workerStatus1); EngineWorkerStatus.MODEL_ROLE_WORKER_STATUS.getPrefillStatusMap().put(ipPort2, workerStatus2); + // Register corresponding PrefillEndpoints + workerStatus1.setIp("127.0.0.1"); + workerStatus1.setPort(8080); + workerStatus1.setGrpcPort(9090); + workerStatus2.setIp("127.0.0.1"); + workerStatus2.setPort(8081); + workerStatus2.setGrpcPort(9091); + registry.ensurePrefillEndpoint(ipPort1, workerStatus1); + registry.ensurePrefillEndpoint(ipPort2, workerStatus2); + // When var result = engineWorkerStatus.selectModelWorkerStatus(RoleType.PREFILL, null); diff --git a/rtp_llm/flexlb/flexlb-sync/src/test/java/org/flexlb/sync/worker/WorkerStatusResponseTest.java b/rtp_llm/flexlb/flexlb-sync/src/test/java/org/flexlb/sync/worker/WorkerStatusResponseTest.java index 388f0ed77e..a021183171 100644 --- a/rtp_llm/flexlb/flexlb-sync/src/test/java/org/flexlb/sync/worker/WorkerStatusResponseTest.java +++ b/rtp_llm/flexlb/flexlb-sync/src/test/java/org/flexlb/sync/worker/WorkerStatusResponseTest.java @@ -1,6 +1,7 @@ package org.flexlb.sync.worker; -import org.flexlb.domain.worker.WorkerStatusResponse; +import org.flexlb.dao.master.WorkerStatusResponse; +import org.flexlb.dao.route.RoleType; import org.flexlb.util.JsonUtils; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; @@ -9,10 +10,10 @@ class WorkerStatusResponseTest { @Test void testConfigLoader() throws Exception { - String TEST_JSON = "{\"role\":\"RoleType.PREFILL\",\"available_concurrency\":1637,\"running_task_info\":{},\"finished_task_info\":{},\"step_latency_ms\":36.636,\"iterate_count\":1,\"dp_size\":1,\"tp_size\":1,\"alive\":true,\"version\":1,\"status_version\":1752025357566,\"cache_status\":{\"available_kv_cache\":82944,\"total_kv_cache\":82944,\"block_size\":256,\"version\":-1},\"waiting_query_len\":0,\"running_query_len\":0}"; + String TEST_JSON = "{\"role\":\"PREFILL\",\"available_concurrency\":1637,\"running_task_info\":{},\"finished_task_info\":{},\"step_latency_ms\":36.636,\"iterate_count\":1,\"dp_size\":1,\"tp_size\":1,\"alive\":true,\"version\":1,\"status_version\":1752025357566,\"cache_status\":{\"available_kv_cache\":82944,\"total_kv_cache\":82944,\"block_size\":256,\"version\":-1},\"waiting_query_len\":0,\"running_query_len\":0}"; WorkerStatusResponse workerStatusResponse = JsonUtils.toObject(TEST_JSON, new com.fasterxml.jackson.core.type.TypeReference() { }); - Assertions.assertEquals("RoleType.PREFILL", workerStatusResponse.getRole()); + Assertions.assertEquals(RoleType.PREFILL, workerStatusResponse.getRole()); Assertions.assertTrue(workerStatusResponse.isAlive()); Assertions.assertEquals(1637, workerStatusResponse.getAvailableConcurrency()); } diff --git a/rtp_llm/frontend/frontend_server.py b/rtp_llm/frontend/frontend_server.py index dbca0e2191..ebbc7d948e 100644 --- a/rtp_llm/frontend/frontend_server.py +++ b/rtp_llm/frontend/frontend_server.py @@ -260,6 +260,9 @@ async def inference(self, req: Union[str, Dict[Any, Any]], raw_request: RawReque sequence, ) request_headers = extract_request_headers(raw_request.headers) + logging.info("request_arrival: trace_id=%s request_id=%s model=%s prompt_len=%d", + req.get("trace_id", "-"), req.get(request_id_field_name), + self.py_env_configs.model_args.model_type, len(req.get("prompt", ""))) except Exception as e: return self._handle_exception(req, e) @@ -439,6 +442,8 @@ async def _collect_complete_response_and_record_access_log( else complete_response ) self._access_logger.log_success_access(req, complete_response) + logging.info("request_completion: trace_id=%s request_id=%s status=success", + req.get("trace_id", "-"), req.get(request_id_field_name, "unknown")) return complete_response diff --git a/rtp_llm/frontend/frontend_worker.py b/rtp_llm/frontend/frontend_worker.py index 8ecf973b8e..d1b24bdd5b 100644 --- a/rtp_llm/frontend/frontend_worker.py +++ b/rtp_llm/frontend/frontend_worker.py @@ -112,6 +112,7 @@ def __init__( ) self.backend_rpc_server_visitor = self.pipeline.backend_rpc_server_visitor self.generate_env_config = py_env_configs.generate_env_config + self.server_config = py_env_configs.server_config logging.info("frontend worker start done.") @@ -169,9 +170,9 @@ def _inference(self, request: Request, **kwargs: Any): num_return_sequences = request.generate_configs[0].num_return_sequences generators: List[AsyncGenerator[Dict[str, Any], None]] = [] # TODO temp fix sp with batch infer, will change request_id to str later - batch_group_size = len(request.input_texts) - # Use request.request_id as batch_group_id for all streams in the same batch - batch_group_id = request.request_id + group_size = len(request.input_texts) + # Use request.request_id as group_id for all streams in the same batch + group_id = request.request_id for i, (text, urls, generate_config) in enumerate( zip(request.input_texts, request.input_urls, request.generate_configs) ): @@ -181,8 +182,8 @@ def _inference(self, request: Request, **kwargs: Any): text, urls, generate_config=generate_config, - batch_group_size=batch_group_size, - batch_group_id=batch_group_id, + group_size=group_size, + group_id=group_id, **kwargs, ) ) @@ -201,7 +202,9 @@ def _inference(self, request: Request, **kwargs: Any): ) def _format_response( - self, gen_responses: GenerateResponse, generate_config: GenerateConfig + self, + gen_responses: GenerateResponse, + generate_config: GenerateConfig, ) -> Dict[str, Any]: generate_texts = gen_responses.generate_texts finished = gen_responses.generate_outputs.generate_outputs[0].finished @@ -249,16 +252,18 @@ def _format_response( return response def _format_response_new( - self, gen_responses: GenerateResponse, generate_config: GenerateConfig + self, + gen_responses: GenerateResponse, + generate_config: GenerateConfig, ) -> Dict[str, Any]: generate_texts = gen_responses.generate_texts if generate_config.num_return_sequences > 0: aux_info = [] if generate_config.aux_info: - aux_info = [ - asdict(seq.aux_info) - for seq in gen_responses.generate_outputs.generate_outputs - ] + aux_info = [] + for seq in gen_responses.generate_outputs.generate_outputs: + info = asdict(seq.aux_info) + aux_info.append(info) sequences_pipeline_response = MultiSequencesPipelineResponse( response=generate_texts, finished=all( @@ -279,8 +284,8 @@ async def _yield_generate( text: str, urls: List[str], generate_config: GenerateConfig, - batch_group_size: int = 1, - batch_group_id: int = -1, + group_size: int = 1, + group_id: int = -1, **kwargs: Any, ) -> AsyncGenerator[Dict[str, Any], None]: stream = self.pipeline.pipeline_async( @@ -289,8 +294,8 @@ async def _yield_generate( urls=urls, generate_config=generate_config, generate_env_config=self.generate_env_config, - batch_group_size=batch_group_size, - batch_group_id=batch_group_id, + group_size=group_size, + group_id=group_id, **kwargs, ) async for generate_response in stream: diff --git a/rtp_llm/model_loader/w4a8_int4_per_channel_quant_weight.py b/rtp_llm/model_loader/w4a8_int4_per_channel_quant_weight.py index ac88f8ce77..95ff9b9577 100644 --- a/rtp_llm/model_loader/w4a8_int4_per_channel_quant_weight.py +++ b/rtp_llm/model_loader/w4a8_int4_per_channel_quant_weight.py @@ -4,9 +4,10 @@ import torch from rtp_llm.config.quant_config import ( - W4a8Int4PerChannelQuantConfig, QuantizationConfig, + W4a8Int4PerChannelQuantConfig, ) +from rtp_llm.model_loader.dynamic_fp8_quant_weight import quantize_weight_to_fp8 from rtp_llm.model_loader.load_config import LoadConfig from rtp_llm.model_loader.tensor_source import TensorSource from rtp_llm.model_loader.w8a8_weight import create_w8a8_fp8_weight @@ -16,22 +17,29 @@ QuantWeight, WeightModule, ) -from rtp_llm.model_loader.dynamic_fp8_quant_weight import quantize_weight_to_fp8 from rtp_llm.utils.model_weight import W, WeightStyle def quantize_weight_to_int4b(input: torch.Tensor, group_size: int, eps: float = 1e-12): - from rtp_kernel.w4a8_group_gemm import unified_encode_int4b, reorder_tensor, pack_scale_fp8 + try: + from rtp_kernel.w4a8_group_gemm import ( + pack_scale_fp8, + reorder_tensor, + unified_encode_int4b, + ) + except ImportError as e: + logging.warning(f"rtp_kernel.w4a8_group_gemm not available: {e}") + raise N, K = input.shape - assert (K % group_size == 0), f"invalid params {K} or {group_size}" + assert K % group_size == 0, f"invalid params {K} or {group_size}" n_groups = K // group_size input_g = input.view(N, n_groups, group_size) amax = input_g.abs().amax(dim=2, keepdim=True) finfo = torch.finfo(torch.float8_e4m3fn) - scale = (amax / 7.).clamp(min=eps, max=finfo.max / 8.) + scale = (amax / 7.0).clamp(min=eps, max=finfo.max / 8.0) scale_f = scale.to(torch.float8_e4m3fn).to(input.dtype) output_int8 = torch.round(input_g / scale).clamp_(min=-8, max=7).to(torch.int8) @@ -168,11 +176,19 @@ def _load_raw_tensor( N = kernel_tensor.shape[1] K = kernel_tensor.shape[2] - quant_kernel = torch.empty((E, N, K // 2), device=kernel_tensor.device, dtype=torch.int8) - scale = torch.empty((E, K // self.group_size, N, 8), device=kernel_tensor.device, dtype=torch.float8_e4m3fn) + quant_kernel = torch.empty( + (E, N, K // 2), device=kernel_tensor.device, dtype=torch.int8 + ) + scale = torch.empty( + (E, K // self.group_size, N, 8), + device=kernel_tensor.device, + dtype=torch.float8_e4m3fn, + ) for i in range(E): - quant_kernel[i, :, :], scale[i] = quantize_weight_to_int4b(kernel_tensor[i, :, :], self.group_size) + quant_kernel[i, :, :], scale[i] = quantize_weight_to_int4b( + kernel_tensor[i, :, :], self.group_size + ) else: quant_kernel, scale = quantize_weight_to_fp8(kernel.get(self.kernel.name)) diff --git a/rtp_llm/models/base_model.py b/rtp_llm/models/base_model.py index ad3abde19d..7080cb9a24 100644 --- a/rtp_llm/models/base_model.py +++ b/rtp_llm/models/base_model.py @@ -143,7 +143,8 @@ def load_default_generate_config(self, generate_env_config: Optional[Any] = None def _get_device_str(self) -> str: """Get device string from parallelism_config.""" - return f"cuda:{self.parallelism_config.local_rank}" + local_device_offset = int(os.environ.get("RTP_LLM_LOCAL_DEVICE_OFFSET", "0")) + return f"cuda:{self.parallelism_config.local_rank + local_device_offset}" @timer_wrapper(description="load model") def load(self, skip_python_model: bool = False): diff --git a/rtp_llm/models_py/bindings/core/BUILD b/rtp_llm/models_py/bindings/core/BUILD index b47c04d28c..9dfbae71bc 100644 --- a/rtp_llm/models_py/bindings/core/BUILD +++ b/rtp_llm/models_py/bindings/core/BUILD @@ -135,6 +135,7 @@ cc_library( ":types_hdr", ":common_defines", "//rtp_llm/cpp/distribute:cpu_tp_broadcaster", + "//rtp_llm/cpp/distribute:rpc_cpu_tp_broadcaster", "//rtp_llm/cpp/utils:signal_utils", "//rtp_llm/cpp/utils:kv_cache_utils", "//rtp_llm/cpp/utils:debug_utils", @@ -214,6 +215,7 @@ cc_library( ":types_hdr", ":common_defines", "//rtp_llm/cpp/distribute:cpu_tp_broadcaster", + "//rtp_llm/cpp/distribute:rpc_cpu_tp_broadcaster", "//rtp_llm/cpp/utils:signal_utils", "//rtp_llm/cpp/utils:kv_cache_utils", "//rtp_llm/cpp/utils:debug_utils", diff --git a/rtp_llm/models_py/bindings/core/ExecOps.cc b/rtp_llm/models_py/bindings/core/ExecOps.cc index 304ec5b3da..69a4961f94 100644 --- a/rtp_llm/models_py/bindings/core/ExecOps.cc +++ b/rtp_llm/models_py/bindings/core/ExecOps.cc @@ -2,6 +2,7 @@ #include "rtp_llm/models_py/bindings/core/CommonDefines.h" #include "rtp_llm/cpp/disaggregate/cache_store/CacheStore.h" #include "rtp_llm/cpp/distribute/CpuTpBroadcaster.h" +#include "rtp_llm/cpp/distribute/RpcCpuTpBroadcaster.h" #include "rtp_llm/cpp/utils/Logger.h" #include "rtp_llm/cpp/cache/CacheGroupType.h" #include "rtp_llm/cpp/cache/KVCacheResource.h" @@ -629,6 +630,19 @@ void execBroadcastCpu(const BroadcastParams& params) { } return; } + auto& rpc_bcast = RpcCpuTpBroadcaster::instance(); + if (rpc_bcast.isInitialized()) { + for (auto& t : params.buffers) { + RTP_LLM_CHECK_WITH_INFO( + t.is_cpu(), "execBroadcastCpu requires CPU tensors (got device=%s)", t.device().str().c_str()); + auto contig = t.contiguous(); + rpc_bcast.broadcast(contig.data_ptr(), contig.nbytes(), params.root); + if (!contig.is_same(t)) { + t.copy_(contig); + } + } + return; + } // Fallback to NCCL via Python callback, typically for cross-node TP. // Preserve immediate-read semantics with the original sync sequence. execBroadcast(params); @@ -820,8 +834,9 @@ void registerExecCtxOps(pybind11::module& m) { []() { py::gil_scoped_release release; CpuTpBroadcaster::instance().reset(); + RpcCpuTpBroadcaster::instance().reset(); }, - "Tear down the UDS-backed intra-node TP broadcaster and clear its singleton state."); + "Tear down CPU TP broadcasters and clear singleton state."); } } // namespace rtp_llm diff --git a/rtp_llm/models_py/modules/dsv4/fp8/attention.py b/rtp_llm/models_py/modules/dsv4/fp8/attention.py index 1ad936e6a2..6a3c7f2c4f 100644 --- a/rtp_llm/models_py/modules/dsv4/fp8/attention.py +++ b/rtp_llm/models_py/modules/dsv4/fp8/attention.py @@ -2443,11 +2443,9 @@ def forward( ) -> torch.Tensor: """Prefill entry point. - ``x``: flat ``[T, dim]`` (single-request, B==1 — enforced by - the FIFO scheduler's ``max_context_batch_size=1`` setting and - ``DeepSeekV4Model.forward``). ``positions``: ``[T]`` int64 of - absolute token positions; ``positions[0]`` is the prefill - start position. We don't read it eagerly — under broadcast + ``x``: flat ``[T, dim]`` where T is the total token count + across all requests in the batch (ragged via ``cu_seqlens``). + ``positions``: ``[T]`` int64 of absolute token positions. We don't read it eagerly — under broadcast meta the sp_int is already on ``self._prefill_meta_shared`` (synced once in ``forward.py`` for all layers); standalone path syncs once inside ``_build_shared_prefill_meta``. diff --git a/rtp_llm/models_py/modules/dsv4/test/test_dsv4_kernel_jit_warmup.py b/rtp_llm/models_py/modules/dsv4/test/test_dsv4_kernel_jit_warmup.py index 39a613d22e..464195961b 100644 --- a/rtp_llm/models_py/modules/dsv4/test/test_dsv4_kernel_jit_warmup.py +++ b/rtp_llm/models_py/modules/dsv4/test/test_dsv4_kernel_jit_warmup.py @@ -34,9 +34,10 @@ def _stub_package(name: str, path: str) -> None: os.path.join(_REPO, "rtp_llm", "models_py", "modules", "dsv4"), ) +import rtp_llm.models_py.modules.dsv4.dsv4_kernel_jit_warmup as warmup_module from rtp_llm.models_py.modules.dsv4.dsv4_kernel_jit_warmup import ( - _collect_dsv4_branch_kernel_configs, _collect_dsv4_batched_fp8_einsum_shapes, + _collect_dsv4_branch_kernel_configs, _collect_dsv4_dense_gemm_shapes, _collect_dsv4_fp8_mqa_logits_shapes, _collect_dsv4_mhc_head_fused_shapes, @@ -55,7 +56,6 @@ def _stub_package(name: str, path: str) -> None: _warmup_fused_kv_compress_norm_rope_insert, resolve_dense_gemm_warmup_max_m, ) -import rtp_llm.models_py.modules.dsv4.dsv4_kernel_jit_warmup as warmup_module def _module_type(name, attrs): @@ -149,7 +149,7 @@ def test_dense_warmup_decode_fallback_accounts_for_speculative_width(self): resolve_dense_gemm_warmup_max_m( max_seq_len=1048576, max_batch_size=1024, - role_type_name="RoleType.DECODE", + role_type_name="DECODE", is_speculative=True, gen_num_per_cycle=4, ), @@ -758,9 +758,7 @@ def with_patch(name, value): old_enabled = mhc_tilelang.tk_mhc_head_fused_enabled mhc_tilelang.tk_mhc_head_fused_enabled = lambda: True self.addCleanup( - lambda: setattr( - mhc_tilelang, "tk_mhc_head_fused_enabled", old_enabled - ) + lambda: setattr(mhc_tilelang, "tk_mhc_head_fused_enabled", old_enabled) ) warmup_module._MHC_HEAD_FUSED_JIT_WARMED_KEYS.clear() @@ -937,9 +935,11 @@ def test_mhc_pre_big_fuse_warmup_initializes_tilelang_env_first(self): ) def test_jit_kernel_specialization_contracts(self): - from rtp_llm.models_py.modules.dsv4.fp8 import _compressor_vllm_triton - from rtp_llm.models_py.modules.dsv4.fp8 import _swa_dequant_triton - from rtp_llm.models_py.modules.dsv4.fp8 import _swa_kv_insert_triton + from rtp_llm.models_py.modules.dsv4.fp8 import ( + _compressor_vllm_triton, + _swa_dequant_triton, + _swa_kv_insert_triton, + ) compress_src = inspect.getsource( _compressor_vllm_triton._fused_kv_compress_norm_rope_insert_sparse_attn.fn @@ -1025,7 +1025,7 @@ def launch(): calls.append(None) if len(calls) == 1: raise RuntimeError( - 'Catastrophic error: cannot open source file ' + "Catastrophic error: cannot open source file " '"/tmp/tmpxft_000011ba_00000000-7_tvm_kernels.cpp1.ii"' ) @@ -1045,7 +1045,7 @@ def launch(): calls.append(None) if len(calls) == 1: cause = RuntimeError( - 'Catastrophic error: cannot open source file ' + "Catastrophic error: cannot open source file " '"/tmp/tmpxft_000011ba_00000000-7_tvm_kernels.cpp1.ii"' ) raise RuntimeError("TileLang mhc_pre failed: shape=(1, 2)") from cause diff --git a/rtp_llm/models_py/modules/factory/fused_moe/impl/cuda/executors/cutlass_moe.py b/rtp_llm/models_py/modules/factory/fused_moe/impl/cuda/executors/cutlass_moe.py index a85de4e38e..404e0af693 100644 --- a/rtp_llm/models_py/modules/factory/fused_moe/impl/cuda/executors/cutlass_moe.py +++ b/rtp_llm/models_py/modules/factory/fused_moe/impl/cuda/executors/cutlass_moe.py @@ -1,10 +1,14 @@ +import logging from math import prod from typing import Any, Callable, Dict, Optional import torch -from rtp_kernel.fp8_group_gemm import ( - get_cutlass_batched_moe_mm_data, -) + +try: + from rtp_kernel.fp8_group_gemm import get_cutlass_batched_moe_mm_data +except ImportError as e: + logging.warning(f"rtp_kernel.fp8_group_gemm not available: {e}") + get_cutlass_batched_moe_mm_data = None from rtp_llm.models_py.kernels.cuda.fp8_kernel import ( cutlass_moe_mm_fp8_scaled, @@ -28,8 +32,8 @@ ) from rtp_llm.models_py.triton_kernels.moe.ep_kernels import ( cutlass_moe_pre_reorder, - post_reorder_triton_kernel, get_cutlass_moe_mm_without_permute_info, + post_reorder_triton_kernel, ) from rtp_llm.utils.model_weight import W diff --git a/rtp_llm/models_py/modules/factory/fused_moe/impl/cuda/executors/cutlass_w4a8_moe.py b/rtp_llm/models_py/modules/factory/fused_moe/impl/cuda/executors/cutlass_w4a8_moe.py index 68c5d25b6c..8a8015976a 100644 --- a/rtp_llm/models_py/modules/factory/fused_moe/impl/cuda/executors/cutlass_w4a8_moe.py +++ b/rtp_llm/models_py/modules/factory/fused_moe/impl/cuda/executors/cutlass_w4a8_moe.py @@ -1,10 +1,16 @@ +import logging from typing import Any, Dict, Optional import torch -from rtp_kernel.fp8_group_gemm import ( - get_cutlass_batched_moe_mm_data, -) -from rtp_kernel.w4a8_group_gemm import w4a8_group_gemm_ptpc, compute_reorder_stride + +try: + from rtp_kernel.fp8_group_gemm import get_cutlass_batched_moe_mm_data + from rtp_kernel.w4a8_group_gemm import compute_reorder_stride, w4a8_group_gemm_ptpc +except ImportError as e: + logging.warning(f"rtp_kernel MoE kernels not available: {e}") + get_cutlass_batched_moe_mm_data = None + w4a8_group_gemm_ptpc = None + compute_reorder_stride = None from rtp_llm.models_py.modules.factory.fused_moe.defs.config_adapter import ( MoEConfigAdapter, @@ -24,8 +30,8 @@ ) from rtp_llm.models_py.triton_kernels.moe.ep_kernels import ( cutlass_moe_pre_reorder, - post_reorder_triton_kernel, get_cutlass_moe_mm_without_permute_info, + post_reorder_triton_kernel, ) from rtp_llm.utils.model_weight import W diff --git a/rtp_llm/ops/fused_rope_kvcache_op.py b/rtp_llm/ops/fused_rope_kvcache_op.py index acc3723e0f..42fd2df977 100644 --- a/rtp_llm/ops/fused_rope_kvcache_op.py +++ b/rtp_llm/ops/fused_rope_kvcache_op.py @@ -1,3 +1,4 @@ +import logging from dataclasses import dataclass from typing import Optional @@ -8,11 +9,18 @@ check_rope_cache, get_rope_cache_once, ) -from rtp_kernel.fused_rope_kvcache import ( - convert_offset_to_block_array, - decode_fused_rope_kvcache, - prefill_fused_rope_kvcache, -) + +try: + from rtp_kernel.fused_rope_kvcache import ( + convert_offset_to_block_array, + decode_fused_rope_kvcache, + prefill_fused_rope_kvcache, + ) +except ImportError as e: + logging.warning(f"rtp_kernel.fused_rope_kvcache not available: {e}") + convert_offset_to_block_array = None + decode_fused_rope_kvcache = None + prefill_fused_rope_kvcache = None @dataclass diff --git a/rtp_llm/ops/libth_transformer_config.pyi b/rtp_llm/ops/libth_transformer_config.pyi index 212026a7ee..6ba50f8fe6 100644 --- a/rtp_llm/ops/libth_transformer_config.pyi +++ b/rtp_llm/ops/libth_transformer_config.pyi @@ -1318,6 +1318,9 @@ class NormType: class PDSepConfig: + batch_dispatch_timeout_ms: int + batch_load_timeout_ms: int + batch_prepare_timeout_ms: int cache_store_connect_port: int cache_store_listen_port: int cache_store_rdma_connect_port: int @@ -1779,6 +1782,7 @@ class RuntimeConfig: use_gather_batch_scheduler: bool warm_up: bool warm_up_with_loss: bool + all_worker_grpc_addrs: list[str] worker_addrs: list[str] worker_grpc_addrs: list[str] diff --git a/rtp_llm/pipeline/pipeline.py b/rtp_llm/pipeline/pipeline.py index cc616c7fc4..901cb2e048 100644 --- a/rtp_llm/pipeline/pipeline.py +++ b/rtp_llm/pipeline/pipeline.py @@ -473,8 +473,8 @@ async def generate_stream( generate_config=generate_config, tokenizer=self.tokenizer, token_type_ids=token_type_ids, - batch_group_size=kwargs.get("batch_group_size", 1), - batch_group_id=kwargs.get("batch_group_id", -1), + group_size=kwargs.get("group_size", 1), + group_id=kwargs.get("group_id", -1), headers=request_headers, ) diff --git a/rtp_llm/server/backend_rpc_server_visitor.py b/rtp_llm/server/backend_rpc_server_visitor.py index d68d3c0984..41ab35759e 100644 --- a/rtp_llm/server/backend_rpc_server_visitor.py +++ b/rtp_llm/server/backend_rpc_server_visitor.py @@ -1,7 +1,7 @@ +import asyncio import logging import os import time -import asyncio from typing import TYPE_CHECKING, AsyncGenerator, List, Optional import torch @@ -9,7 +9,7 @@ from rtp_llm.config.exceptions import ExceptionType, FtRuntimeException from rtp_llm.config.generate_config import RoleAddr, RoleType from rtp_llm.config.model_config import ModelConfig as PyModelConfig -from rtp_llm.cpp.model_rpc.model_rpc_client import ModelRpcClient +from rtp_llm.cpp.model_rpc.model_rpc_client import ModelRpcClient, trans_input from rtp_llm.metrics import kmonitor from rtp_llm.metrics.kmonitor_metric_reporter import AccMetrics, GaugeMetrics from rtp_llm.ops import SpeculativeExecutionConfig, VitSeparation, get_block_cache_keys @@ -216,6 +216,7 @@ async def get_master_route_addrs( full_block_cache_keys = get_block_cache_keys(token_ids, self.seq_size_per_block) block_cache_keys = self._route_cache_keys(full_block_cache_keys) self._report_recent_cache_key_metrics(block_cache_keys) + input_pb = trans_input(input) try: route_result = await self.master_client.get_backend_role_addrs( @@ -223,6 +224,7 @@ async def get_master_route_addrs( cache_key_block_size=self._cache_key_block_size(), input=input, request_id=input.request_id, + input_pb=input_pb, ) except BaseException as e: exception_json = format_exception(e) @@ -235,6 +237,7 @@ async def get_master_route_addrs( if route_result.is_ok: input.generate_config.role_addrs = route_result.role_addrs + input.enqueued_by_master = route_result.enqueued_by_master route_logger.debug( "master route success, request_id=%s, addrs=%s", input.request_id, @@ -360,6 +363,12 @@ async def route_ips(self, input: GenerateInput): allow_domain_fallback = master_route_result is None or ( master_route_result.connection_failed ) + if self.master_config and self.master_config.disable_domain_fallback: + allow_domain_fallback = False + route_logger.warning( + "master_config.disable_domain_fallback is enabled, " + "skipping domain fallback routing" + ) if ( not input.generate_config.role_addrs or need_domain_routing ) and allow_domain_fallback: @@ -544,15 +553,18 @@ async def stream_with_aux_info(): return stream_with_aux_info() def is_backend_service_ready(self, refresh: bool = False) -> bool: - roles: List[RoleAddr] = self.host_service.get_backend_role_addrs( - self.backend_role_list, refresh - ) - if not roles: - return False - for role in self.backend_role_list: - if role not in [r.role for r in roles]: - logging.warning(f"role {role} not in available roles {roles}") - return False + # COMMENTED OUT: Direct connection to prefill/decode bypasses FlexLB + # roles: List[RoleAddr] = self.host_service.get_backend_role_addrs( + # self.backend_role_list, refresh + # ) + # if not roles: + # return False + # for role in self.backend_role_list: + # if role not in [r.role for r in roles]: + # logging.warning(f"role {role} not in available roles {roles}") + # return False + # return True + # Always return True to force routing through FlexLB return True diff --git a/rtp_llm/server/host_service.py b/rtp_llm/server/host_service.py index 118838ce2a..6fc9a4c9c2 100644 --- a/rtp_llm/server/host_service.py +++ b/rtp_llm/server/host_service.py @@ -151,9 +151,15 @@ def get_hosts(self, refresh: bool = False) -> List[Host]: class EndPoint(BaseModel): + """Service discovery endpoint. + + address MUST carry the HTTP port (e.g. "10.0.0.1:7001"). + Health check probes HTTP on this port. gRPC is derived as HTTP + FLEXLB_GRPC_PORT_OFFSET. + """ + type: str - address: str - protocol: str + address: str # "ip:HTTP_PORT" + protocol: str # protocol of the port in address (typically "http") path: str diff --git a/rtp_llm/server/master_client.py b/rtp_llm/server/master_client.py index 7cb31e8c58..bd5bc9e598 100644 --- a/rtp_llm/server/master_client.py +++ b/rtp_llm/server/master_client.py @@ -1,29 +1,34 @@ -"""FlexLB schedule client: request role addrs from master/slave and parse response.""" +"""FlexLB schedule client: request role addrs from master/slave via gRPC.""" -from __future__ import annotations - -import json import logging import time from dataclasses import dataclass from typing import Any, Dict, List, Optional +import grpc +import grpc.aio + from rtp_llm.config.exceptions import ExceptionType, FtRuntimeException from rtp_llm.config.generate_config import RoleAddr, RoleType from rtp_llm.config.py_config_modules import MasterConfig +from rtp_llm.cpp.model_rpc.proto.model_rpc_service_pb2 import ( + FlexlbScheduleRequestPB, + GenerateInputPB, +) +from rtp_llm.cpp.model_rpc.proto.model_rpc_service_pb2_grpc import FlexlbServiceStub +from rtp_llm.metrics import kmonitor +from rtp_llm.metrics.kmonitor_metric_reporter import AccMetrics from rtp_llm.server.host_service import HostService -from rtp_llm.server.request_headers import normalize_request_headers -from rtp_llm.server.worker_status import ScheduleMeta from rtp_llm.utils.base_model_datatypes import GenerateInput route_logger = logging.getLogger("route_logger") -SCHEDULE_PATH = "/rtp_llm/schedule" -DEFAULT_REQUEST_TIMEOUT_SEC = 0.5 SUCCESS_CODE = 200 DEFAULT_REQUEST_PRIORITY = 100 -CONNECTOR_LIMIT_PER_HOST = 30 -CONNECTOR_KEEPALIVE_TIMEOUT_SEC = 30 +# gRPC = HTTP + 2 for FlexLB's own servers (consistent with FlexlbGrpcServer.FLEXLB_GRPC_PORT_OFFSET). +# This is NOT the same as the backend engine offset (HTTP+1)—see CommonConstants.GRPC_PORT_OFFSET. +FLEXLB_GRPC_PORT_OFFSET = 2 +BEARER_PREFIX = "Bearer " @dataclass @@ -40,7 +45,8 @@ class FlexlbResponse: connection_failed: bool = False error_code: Optional[int] = None error_message: Optional[str] = None - result: Optional[Dict[str, Any]] = None # internal: raw JSON from scheduler + result: Optional[Dict[str, Any]] = None + enqueued_by_master: bool = False @property def is_ok(self) -> bool: @@ -55,10 +61,13 @@ def ok_with_result(cls, result: Dict[str, Any]) -> "FlexlbResponse": error_code=None, error_message=None, result=result, + enqueued_by_master=False, ) @classmethod - def ok(cls, role_addrs: List[RoleAddr]) -> "FlexlbResponse": + def ok( + cls, role_addrs: List[RoleAddr], enqueued_by_master: bool = False + ) -> "FlexlbResponse": """Business success: parsed role addrs.""" return cls( role_addrs=role_addrs, @@ -66,6 +75,7 @@ def ok(cls, role_addrs: List[RoleAddr]) -> "FlexlbResponse": error_code=None, error_message=None, result=None, + enqueued_by_master=enqueued_by_master, ) @classmethod @@ -81,71 +91,67 @@ def error_response( error_code=error_code, error_message=error_message, result=None, + enqueued_by_master=False, ) @classmethod def connection_failed_response(cls) -> "FlexlbResponse": - """No HTTP response (connection/timeout). Triggers slave retry and domain fallback.""" + """No response (connection/timeout). Triggers slave retry and domain fallback.""" return cls( role_addrs=None, connection_failed=True, error_code=None, error_message=None, result=None, + enqueued_by_master=False, ) class MasterClient: - """Client for FlexLB schedule API (master and optional slave).""" + """Client for FlexLB schedule gRPC API (master and optional slave).""" def __init__(self, host_service=None, server_config=None, master_config=None): - # Always use a MasterConfig instance; when the caller passes None we fall - # back to dataclass defaults aligned with the args registration - # (master_default_timeout_ms=3600000). self.master_config = ( master_config if master_config is not None else MasterConfig() ) self.host_service: Optional[HostService] = host_service - self.max_connect_pool_size = self.master_config.master_max_connect_pool_size - self._session: Optional[Any] = None + self._channels: Dict[str, grpc.aio.Channel] = {} self.latest_queue_length: int = 0 - self.session_timeout_s = self._get_session_timeout_s() - - def _get_session_timeout_s(self) -> Optional[float]: - # Session-level timeout is a safety net for the connection pool lifetime, - # not for individual requests. Per-request timeout in _send_schedule_request - # always takes precedence (aiohttp per-request timeout overrides session timeout). - # master_session_timeout_s semantics: - # > 0 -> use that value - # == 0 -> None (aiohttp treats it as no total timeout; unlimited) - # < 0 -> auto (3600s in queue mode, 0.5s otherwise) - ts = self.master_config.master_session_timeout_s - if ts == 0: - return None - if ts > 0: - return float(ts) - if self.host_service and self.host_service.master_vip.domain: - return 3600.0 - return DEFAULT_REQUEST_TIMEOUT_SEC - - async def _get_session(self): - import aiohttp - from aiohttp import ClientTimeout - - if self._session is None or self._session.closed: - timeout = ClientTimeout(total=self.session_timeout_s) - connector = aiohttp.TCPConnector( - limit=self.max_connect_pool_size, - limit_per_host=CONNECTOR_LIMIT_PER_HOST, - keepalive_timeout=CONNECTOR_KEEPALIVE_TIMEOUT_SEC, - enable_cleanup_closed=True, + + def _get_grpc_target(self, addr: str) -> str: + """Resolve gRPC target from service discovery address (ip:HTTP_PORT). + + gRPC port is always derived as HTTP port + FLEXLB_GRPC_PORT_OFFSET. + """ + ip = addr.split(":")[0] + try: + http_port = int(addr.split(":")[1]) + return f"{ip}:{http_port + FLEXLB_GRPC_PORT_OFFSET}" + except (IndexError, ValueError): + return f"{ip}:{7001 + FLEXLB_GRPC_PORT_OFFSET}" + + def _get_channel(self, target: str) -> grpc.aio.Channel: + if target not in self._channels: + self._channels[target] = grpc.aio.insecure_channel( + target, + options=[ + ("grpc.max_receive_message_length", 16 * 1024 * 1024), + ("grpc.max_send_message_length", 16 * 1024 * 1024), + ("grpc.keepalive_time_ms", 30000), + ("grpc.keepalive_timeout_ms", 10000), + ], ) - self._session = aiohttp.ClientSession(timeout=timeout, connector=connector) - return self._session + return self._channels[target] + + async def _close_channel(self, target: str) -> None: + channel = self._channels.pop(target, None) + if channel is not None: + await channel.close() async def close(self) -> None: - if self._session and not self._session.closed: - await self._session.close() + for channel in self._channels.values(): + await channel.close() + self._channels.clear() def get_latest_queue_length(self) -> int: return self.latest_queue_length @@ -153,81 +159,40 @@ def get_latest_queue_length(self) -> int: async def _send_schedule_request( self, addr: str, - payload: Dict[str, Any], - generate_timeout_ms: int, + request_pb: "FlexlbScheduleRequestPB", + timeout_s: Optional[float], request_id: int, - request_headers: Optional[Dict[str, str]] = None, - ) -> FlexlbResponse: - """ - Send one schedule request to the given host (master or slave). - Returns FlexlbResponse: ok_with_result on HTTP success, error_response on - non-200 body, connection_failed_response when no response received. - """ - url = f"http://{addr}{SCHEDULE_PATH}" - headers = {"Content-Type": "application/json"} - headers.update(normalize_request_headers(request_headers)) - # generate_timeout_ms <= 0 -> ClientTimeout(total=None); aiohttp treats as unlimited. - timeout_total = ( - generate_timeout_ms / 1000.0 if generate_timeout_ms > 0 else None - ) + ): + """Send gRPC schedule request. Returns proto response on success, None on transport failure.""" + target = self._get_grpc_target(addr) start = time.time() - import aiohttp - from aiohttp import ClientTimeout - try: - session = await self._get_session() - request_timeout = ClientTimeout(total=timeout_total) - async with session.post( - url, - data=json.dumps(payload), - headers=headers, - timeout=request_timeout, - ) as response: - if response.status != SUCCESS_CODE: - error_code = int(ExceptionType.MASTER_NO_AVAILABLE_WORKER) - error_message = None - try: - raw = await response.json() - if isinstance(raw, dict): - raw_code = raw.get("code") - if raw_code is not None: - try: - error_code = int(raw_code) - except (TypeError, ValueError): - pass - error_message = raw.get("error_message") - except (json.JSONDecodeError, aiohttp.ClientError): - pass - route_logger.error( - "FlexLB schedule failed, request_id=%s, error_code=%s, error_message=%s", - request_id, - error_code, - error_message or "", - ) - return FlexlbResponse.error_response(error_code, error_message) - - result = await response.json() - return FlexlbResponse.ok_with_result(result) - - except (aiohttp.ClientError, TimeoutError, ConnectionError, OSError) as e: + channel = self._get_channel(target) + stub = FlexlbServiceStub(channel) + response = await stub.Schedule(request_pb, timeout=timeout_s) + return response + except grpc.aio.AioRpcError as e: elapsed = time.time() - start route_logger.error( - "Schedule request failed, addr=%s, request_id=%s, error=%s, elapsed=%.3fs", + "gRPC schedule failed, addr=%s, request_id=%s, status=%s, detail=%s, elapsed=%.3fs", addr, request_id, - e, + e.code(), + e.details(), elapsed, ) - return FlexlbResponse.connection_failed_response() + await self._close_channel(target) + return None except Exception as e: elapsed = time.time() - start route_logger.exception( - "Unexpected error in schedule request, addr=%s, request_id=%s, elapsed=%.3fs", + "Unexpected gRPC error, addr=%s, request_id=%s, elapsed=%.3fs", addr, request_id, elapsed, ) - return FlexlbResponse.connection_failed_response() + await self._close_channel(target) + return None async def get_backend_role_addrs( self, @@ -235,6 +200,7 @@ async def get_backend_role_addrs( cache_key_block_size: int, input: GenerateInput, request_id: int, + input_pb: Optional["GenerateInputPB"] = None, ) -> FlexlbResponse: """ Resolve backend role addrs from FlexLB scheduler (master, then slave on connection failure). @@ -254,79 +220,90 @@ async def get_backend_role_addrs( input.generate_config, "ttft_timeout_ms", None ) or getattr(input.generate_config, "timeout_ms", None) if ttft_timeout_ms is None or ttft_timeout_ms <= 0: - # per-request not provided -> use args default (master_default_timeout_ms). - # When that default is <= 0, _send_schedule_request builds ClientTimeout(total=None). ttft_timeout_ms = self.master_config.master_default_timeout_ms - request_priority = getattr( - input.generate_config, - "traffic_reject_priority", - DEFAULT_REQUEST_PRIORITY, + timeout_s = ttft_timeout_ms / 1000.0 if ttft_timeout_ms > 0 else None + + gc = input.generate_config + api_key = self._extract_api_key(input) + request_pb = FlexlbScheduleRequestPB( + request_id=request_id, + block_cache_keys=block_cache_keys, + seq_len=input.prompt_length, + generate_timeout=ttft_timeout_ms, + request_time_ms=int(time.time() * 1000), + max_new_tokens=gc.max_new_tokens, + num_beams=gc.num_beams, + force_disable_sp_run=gc.force_disable_sp_run, + model="engine_service", + api_key=api_key, ) - start = time.time() + if input_pb is not None: + request_pb.generate_input.CopyFrom(input_pb) - payload: Dict[str, Any] = { - "model": "engine_service", - "block_cache_keys": block_cache_keys, - "cache_key_block_size": cache_key_block_size, - "seq_len": input.prompt_length, - "debug": False, - "request_priority": request_priority, - "generate_timeout": ttft_timeout_ms, - "request_id": request_id, - "request_time_ms": int(start * 1000), - } - - request_headers = getattr(input, "headers", None) - resp = await self._send_schedule_request( - master_addr, payload, ttft_timeout_ms, request_id, request_headers + response = await self._send_schedule_request( + master_addr, request_pb, timeout_s, request_id ) - if resp.connection_failed and slave_addr: + if response is None and slave_addr: route_logger.info( "Master connection failed, retrying slave, slave=%s, request_id=%s", slave_addr, request_id, ) - resp = await self._send_schedule_request( - slave_addr, payload, ttft_timeout_ms, request_id, request_headers + response = await self._send_schedule_request( + slave_addr, request_pb, timeout_s, request_id ) - if resp.result is None: - return FlexlbResponse( - role_addrs=None, - connection_failed=resp.connection_failed, - error_code=resp.error_code, - error_message=resp.error_message, - result=None, - ) + if response is None: + return FlexlbResponse.connection_failed_response() - if resp.result.get("code", SUCCESS_CODE) != SUCCESS_CODE: - raw_code = resp.result.get("code", SUCCESS_CODE) - try: - code = int(raw_code) - except (TypeError, ValueError): - code = int(ExceptionType.MASTER_NO_AVAILABLE_WORKER) + self.latest_queue_length = response.queue_length + + if response.code != SUCCESS_CODE: try: - exception_type = ExceptionType(code) + exception_type = ExceptionType(response.code) except ValueError: exception_type = ExceptionType.MASTER_NO_AVAILABLE_WORKER - message = resp.result.get("error_message") or "master schedule error" + message = response.error_message or "master schedule error" route_logger.error( - "Master schedule error, request_id=%s, error_code=%s, error_message=%s", + "Master schedule error, request_id=%s, error_code=%s, " + "error_message=%s", request_id, - code, + response.code, message, ) - raise FtRuntimeException(exception_type=exception_type, message=message) + kmonitor.report( + AccMetrics.MASTER_ROUTE_ERROR_QPS_METRIC, + 1, + {"error_code": str(response.code)}, + ) + raise FtRuntimeException( + exception_type=exception_type, + message=message, + ) - schedule_meta = ScheduleMeta.model_validate(resp.result) role_addrs = [ RoleAddr( - role=RoleType(s.role), # type: ignore[arg-type] + role=RoleType(s.role), ip=s.server_ip, http_port=s.http_port, grpc_port=s.grpc_port, ) - for s in schedule_meta.server_status + for s in response.server_status ] - return FlexlbResponse.ok(role_addrs) + return FlexlbResponse.ok( + role_addrs, enqueued_by_master=response.enqueued_by_master + ) + + @staticmethod + def _extract_api_key(input: GenerateInput) -> str: + headers = getattr(input, "headers", None) + if not headers: + return "" + api_key = headers.get("x-api-key") or headers.get("api-key") + if api_key: + return api_key + auth = headers.get("authorization", "") + if auth.startswith(BEARER_PREFIX): + return auth[len(BEARER_PREFIX) :].strip() + return "" diff --git a/rtp_llm/server/server_args/master_group_args.py b/rtp_llm/server/server_args/master_group_args.py index ee8a6d1c18..8bec926ab2 100644 --- a/rtp_llm/server/server_args/master_group_args.py +++ b/rtp_llm/server/server_args/master_group_args.py @@ -29,6 +29,15 @@ def init_master_group_args(parser, master_config): help="Master max connect pool size", ) + master_group.add_argument( + "--master_connector_limit_per_host", + env_name="MASTER_CONNECTOR_LIMIT_PER_HOST", + bind_to=(master_config, "master_connector_limit_per_host"), + type=int, + default=0, + help="Max HTTP connections per master host (0 = use default 30)", + ) + master_group.add_argument( "--master_session_timeout_s", env_name="MASTER_SESSION_TIMEOUT_S", @@ -39,3 +48,13 @@ def init_master_group_args(parser, master_config): "<0: auto (3600 when queue mode, 0.5 otherwise); " "==0: 不设超时(链路不超时); >0: 使用该值", ) + + master_group.add_argument( + "--master_disable_domain_fallback", + env_name="MASTER_DISABLE_DOMAIN_FALLBACK", + bind_to=(master_config, "disable_domain_fallback"), + type=bool, + default=False, + help="When True, disable domain fallback routing when master is unavailable or not configured. " + "Requests will fail with ROUTE_ERROR instead of falling back to VipServer domain routing.", + ) diff --git a/rtp_llm/server/server_args/pd_separation_group_args.py b/rtp_llm/server/server_args/pd_separation_group_args.py index f665e60e30..894ec6635c 100644 --- a/rtp_llm/server/server_args/pd_separation_group_args.py +++ b/rtp_llm/server/server_args/pd_separation_group_args.py @@ -53,6 +53,60 @@ def init_pd_separation_group_args(parser, pd_separation_config): "per-request generate_config.timeout_ms 优先级更高", ) + pd_separation_group.add_argument( + "--batch_dispatch_timeout_ms", + env_name="BATCH_DISPATCH_TIMEOUT_MS", + bind_to=(pd_separation_config, "batch_dispatch_timeout_ms"), + type=int, + default=60000, + help="EnqueueBatch 跨 DP 分发超时(毫秒),防止远端 DP 卡死阻塞整个 batch", + ) + + pd_separation_group.add_argument( + "--batch_prepare_timeout_ms", + env_name="BATCH_PREPARE_TIMEOUT_MS", + bind_to=(pd_separation_config, "batch_prepare_timeout_ms"), + type=int, + default=10000, + help="EnqueueGroup 内部 prepareAllocateResource 超时(毫秒)", + ) + + pd_separation_group.add_argument( + "--batch_load_timeout_ms", + env_name="BATCH_LOAD_TIMEOUT_MS", + bind_to=(pd_separation_config, "batch_load_timeout_ms"), + type=int, + default=10000, + help="EnqueueGroup 内部 remoteLoadCacheStart 超时(毫秒)", + ) + + pd_separation_group.add_argument( + "--prefill_enqueue_pool_size", + env_name="PREFILL_ENQUEUE_POOL_SIZE", + bind_to=(pd_separation_config, "prefill_enqueue_pool_size"), + type=int, + default=0, + help="Prefill L1 enqueue 线程池大小,0 表示使用公式默认值", + ) + + pd_separation_group.add_argument( + "--prefill_worker_lambda_pool_size", + env_name="PREFILL_WORKER_LAMBDA_POOL_SIZE", + bind_to=(pd_separation_config, "prefill_worker_lambda_pool_size"), + type=int, + default=0, + help="Prefill worker lambda 线程池大小,0 表示使用公式默认值", + ) + + pd_separation_group.add_argument( + "--prefill_slot_pool_size", + env_name="PREFILL_SLOT_POOL_SIZE", + bind_to=(pd_separation_config, "prefill_slot_pool_size"), + type=int, + default=0, + help="Prefill slot 线程池大小,0 表示使用公式默认值", + ) + pd_separation_group.add_argument( "--decode_retry_times", env_name="DECODE_RETRY_TIMES", diff --git a/rtp_llm/server/test/backend_rpc_server_visitor_test.py b/rtp_llm/server/test/backend_rpc_server_visitor_test.py index 1d218e32d8..82e8252ca3 100644 --- a/rtp_llm/server/test/backend_rpc_server_visitor_test.py +++ b/rtp_llm/server/test/backend_rpc_server_visitor_test.py @@ -31,6 +31,22 @@ def __init__(self, is_streaming=False): self.generate_config = _FakeGenerateConfig(is_streaming=is_streaming) +class _FakeRouteTokenIds: + shape = (3,) + + def tolist(self): + return [1, 2, 3] + + +class _FakeRouteInput: + request_id = 456 + token_ids = _FakeRouteTokenIds() + + def __init__(self): + self.generate_config = _FakeGenerateConfig() + self.enqueued_by_master = False + + class _FakeHostService: service_available = False @@ -38,6 +54,29 @@ def get_master_addr(self): return "master:1234" +class _FakeInputPB: + def SerializeToString(self): + return b"serialized-input" + + +class _FakeMasterClient: + def __init__(self): + self.calls = [] + + async def get_backend_role_addrs( + self, block_cache_keys, input, request_id, input_pb_bytes=None + ): + self.calls.append( + { + "block_cache_keys": block_cache_keys, + "input": input, + "request_id": request_id, + "input_pb_bytes": input_pb_bytes, + } + ) + return FlexlbResponse.ok(["prefill-role"], enqueued_by_master=True) + + class BackendRPCServerVisitorRouteCacheKeysTest(unittest.TestCase): def test_route_cache_keys_passthrough_when_page_rr_disabled(self): self.assertEqual( @@ -68,6 +107,35 @@ def test_cache_key_block_size_tracks_routed_key_granularity(self): class BackendRPCServerVisitorRouteIpsTest(unittest.IsolatedAsyncioTestCase): + async def test_get_master_route_addrs_passes_pb_and_marks_master_enqueue(self): + visitor = BackendRPCServerVisitor.__new__(BackendRPCServerVisitor) + visitor.seq_size_per_block = 16 + visitor.master_client = _FakeMasterClient() + visitor._route_cache_keys = lambda keys: keys + visitor._report_recent_cache_key_metrics = lambda keys: None + + input = _FakeRouteInput() + + with patch( + "rtp_llm.server.backend_rpc_server_visitor.get_block_cache_keys", + return_value=[11, 22], + ), patch( + "rtp_llm.server.backend_rpc_server_visitor.trans_input", + return_value=_FakeInputPB(), + ), patch( + "rtp_llm.server.backend_rpc_server_visitor.kmonitor" + ): + result = await visitor.get_master_route_addrs(input) + + self.assertIsNone(result) + self.assertEqual(input.generate_config.role_addrs, ["prefill-role"]) + self.assertTrue(input.enqueued_by_master) + self.assertEqual(visitor.master_client.calls[0]["block_cache_keys"], [11, 22]) + self.assertEqual(visitor.master_client.calls[0]["request_id"], 456) + self.assertEqual( + visitor.master_client.calls[0]["input_pb_bytes"], b"serialized-input" + ) + async def test_route_ips_preserves_master_route_error_code_on_route_error(self): visitor = BackendRPCServerVisitor.__new__(BackendRPCServerVisitor) visitor.master_config = None diff --git a/rtp_llm/server/test/master_client_test.py b/rtp_llm/server/test/master_client_test.py new file mode 100644 index 0000000000..eba506d12c --- /dev/null +++ b/rtp_llm/server/test/master_client_test.py @@ -0,0 +1,110 @@ +import base64 +import unittest + +from rtp_llm.server.master_client import FlexlbResponse, MasterClient + + +class _FakeMasterConfig: + master_max_connect_pool_size = 4 + master_session_timeout_s = 1 + master_default_timeout_ms = 3600000 + + +class _FakeHostService: + def get_master_addr(self): + return "master:1234" + + def get_slave_addr(self): + return None + + +class _FakeGenerateConfig: + max_new_tokens = 17 + num_beams = 2 + force_disable_sp_run = True + ttft_timeout_ms = 3000 + timeout_ms = -1 + traffic_reject_priority = 12 + + +class _FakeInput: + prompt_length = 5 + headers = {"x-request-id": "req-1"} + + def __init__(self): + self.generate_config = _FakeGenerateConfig() + + +class _CaptureMasterClient(MasterClient): + def __init__(self): + super().__init__( + host_service=_FakeHostService(), + master_config=_FakeMasterConfig(), + ) + self.calls = [] + + async def _send_schedule_request( + self, addr, payload, generate_timeout_ms, request_id, request_headers=None + ): + self.calls.append( + { + "addr": addr, + "payload": payload, + "generate_timeout_ms": generate_timeout_ms, + "request_id": request_id, + "request_headers": request_headers, + } + ) + return FlexlbResponse.ok_with_result( + { + "code": 200, + "server_status": [ + { + "role": "PREFILL", + "server_ip": "10.0.0.7", + "http_port": 8080, + "grpc_port": 9000, + } + ], + "enqueued_by_master": True, + } + ) + + +class MasterClientBatchPayloadTest(unittest.IsolatedAsyncioTestCase): + async def test_schedule_payload_contains_batch_fields_and_pb(self): + client = _CaptureMasterClient() + + response = await client.get_backend_role_addrs( + block_cache_keys=[1, 2, 3], + input=_FakeInput(), + request_id=99, + input_pb_bytes=b"serialized-input", + ) + + self.assertTrue(response.is_ok) + self.assertTrue(response.enqueued_by_master) + self.assertEqual(response.role_addrs[0].ip, "10.0.0.7") + + call = client.calls[0] + payload = call["payload"] + self.assertEqual(call["addr"], "master:1234") + self.assertEqual(call["generate_timeout_ms"], 3000) + self.assertEqual(call["request_id"], 99) + self.assertEqual(call["request_headers"], {"x-request-id": "req-1"}) + self.assertEqual(payload["block_cache_keys"], [1, 2, 3]) + self.assertEqual(payload["seq_len"], 5) + self.assertEqual(payload["request_priority"], 12) + self.assertEqual(payload["generate_timeout"], 3000) + self.assertEqual(payload["request_id"], 99) + self.assertEqual(payload["max_new_tokens"], 17) + self.assertEqual(payload["num_beams"], 2) + self.assertTrue(payload["force_disable_sp_run"]) + self.assertEqual( + payload["generate_input_pb_b64"], + base64.b64encode(b"serialized-input").decode("ascii"), + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/rtp_llm/server/test/schedule_meta_test.py b/rtp_llm/server/test/schedule_meta_test.py index dd543e2c17..0c07972230 100644 --- a/rtp_llm/server/test/schedule_meta_test.py +++ b/rtp_llm/server/test/schedule_meta_test.py @@ -122,6 +122,16 @@ def test_schedule_meta_with_empty_server_status(self): schedule_meta = ScheduleMeta.model_validate(test_data) self.assertEqual(len(schedule_meta.server_status), 0) + def test_schedule_meta_enqueued_by_master_defaults_false(self): + schedule_meta = ScheduleMeta.model_validate(self.valid_schedule_meta) + self.assertFalse(schedule_meta.enqueued_by_master) + + def test_schedule_meta_accepts_enqueued_by_master(self): + test_data = dict(self.valid_schedule_meta) + test_data["enqueued_by_master"] = True + schedule_meta = ScheduleMeta.model_validate(test_data) + self.assertTrue(schedule_meta.enqueued_by_master) + if __name__ == "__main__": unittest.main() diff --git a/rtp_llm/server/worker_status.py b/rtp_llm/server/worker_status.py index 5ab4aa9037..2063ec2d80 100644 --- a/rtp_llm/server/worker_status.py +++ b/rtp_llm/server/worker_status.py @@ -73,10 +73,17 @@ class ServerStatus(BaseModel): @model_validator(mode="before") def validate_role(cls, values: Dict[str, Any]): role = values.get("role") - if isinstance(role, str): - values["role"] = getattr(RoleType, role) + if isinstance(role, int): + values["role"] = RoleType(role) + elif isinstance(role, RoleType): + pass # already correct + elif isinstance(role, str): + try: + values["role"] = getattr(RoleType, role.upper()) + except AttributeError: + raise ValueError(f"Invalid role: {role}") from None else: - raise ValueError(f"Invalid role: {role}, expected str") + raise ValueError(f"Invalid role: {role}, expected int, str, or RoleType") return values @@ -87,3 +94,4 @@ class ScheduleMeta(BaseModel): error_message: Optional[str] = None success: Optional[bool] = True real_master_host: Optional[str] = None + enqueued_by_master: bool = False diff --git a/rtp_llm/test/utils/maga_server_manager.py b/rtp_llm/test/utils/maga_server_manager.py index b14a153672..53d2e45384 100644 --- a/rtp_llm/test/utils/maga_server_manager.py +++ b/rtp_llm/test/utils/maga_server_manager.py @@ -36,6 +36,7 @@ def __init__( role_name: str = "main", process_file_name: str = "process.log", smoke_args_str: str = "", + health_check_path: str = "/health", ): self._username = os.getenv("USER") self._env_args = env_args @@ -47,6 +48,7 @@ def __init__( self._process_file_name = process_file_name self._port = port self._smoke_args_str = smoke_args_str + self._health_check_path = health_check_path self._exit_code: Optional[int] = None self._state_lock = threading.Lock() self._stop_requested = False @@ -105,7 +107,9 @@ def wait_sever_done(self, timeout: int = 1600): from rtp_llm.utils.util import wait_sever_done # Health check uses START_PORT (self._port); when VIT_SEPARATION==1 we return True above - result = wait_sever_done(self._server_process, int(self._port), timeout) + result = wait_sever_done( + self._server_process, int(self._port), timeout, self._health_check_path + ) if not result: rc = self._server_process.poll() if self._server_process else None self._exit_code = rc diff --git a/rtp_llm/utils/base_model_datatypes.py b/rtp_llm/utils/base_model_datatypes.py index cf7601527a..daf1df66bc 100644 --- a/rtp_llm/utils/base_model_datatypes.py +++ b/rtp_llm/utils/base_model_datatypes.py @@ -6,6 +6,7 @@ from rtp_llm.config.generate_config import GenerateConfig, RoleAddr from rtp_llm.utils.multimodal_util import MultimodalInput + class EmbeddingOutput: text_embedding: torch.Tensor extra_input: Optional[torch.Tensor] @@ -43,8 +44,9 @@ class GenerateInput: tokenizer: Any = None # TODO: remove this prefix_length: int = 0 token_type_ids: List[int] = field(default_factory=list) - batch_group_size: int = 1 - batch_group_id: int = -1 # Batch group ID for force batch grouping, -1 means not set + group_size: int = 1 + group_id: int = -1 # Batch group ID for force batch grouping, -1 means not set + enqueued_by_master: bool = False headers: Dict[str, str] = field(default_factory=dict, repr=False) request_info: RequestInfo = field(default_factory=RequestInfo, repr=False) diff --git a/rtp_llm/utils/util.py b/rtp_llm/utils/util.py index 3a505f2b4a..5b8634d5ff 100644 --- a/rtp_llm/utils/util.py +++ b/rtp_llm/utils/util.py @@ -10,6 +10,9 @@ from pathlib import Path from typing import Any, Dict, List, Optional, Set, Union +import psutil +import requests + class AtomicCounter: def __init__(self, initial: int = 0): @@ -319,22 +322,27 @@ async def _handle_response(response) -> Dict[str, Any]: } -def wait_sever_done(server_process, port: int, timeout: int = 1600): - import psutil - import requests - +def wait_sever_done( + server_process, port: int, timeout: int = 1600, health_check_path: str = "/health" +): host = "localhost" retry_interval = 1 # 重试间隔(秒) start_time = time.time() port = str(port) + health_check_path = health_check_path or "/health" + if not health_check_path.startswith("/"): + health_check_path = "/" + health_check_path - logging.info(f"等待pid[{server_process.pid}]启动中...\n端口 {port}") + logging.info( + f"等待pid[{server_process.pid}]启动中...\n端口 {port}, health path {health_check_path}" + ) while True: try: # 使用 HTTP health check 检查服务是否准备就绪 response = requests.get( - f"http://{host}:{port}/health", timeout=retry_interval + f"http://{host}:{port}{health_check_path}", + timeout=retry_interval, ) logging.info( f"response status_code = {response.status_code}, text = {response.text}, len = {len(response.text)}" @@ -380,8 +388,6 @@ def wait_sever_done(server_process, port: int, timeout: int = 1600): def stop_server( server_process, ): - import psutil - if server_process is not None and server_process.pid is not None: try: # 如果只kill start_server,会残留 backend/frontend 占用显存。