diff --git a/rtp_llm/cpp/disaggregate/cache_store/BUILD b/rtp_llm/cpp/disaggregate/cache_store/BUILD index 556ed18bf4..f446de59ee 100644 --- a/rtp_llm/cpp/disaggregate/cache_store/BUILD +++ b/rtp_llm/cpp/disaggregate/cache_store/BUILD @@ -70,6 +70,7 @@ cc_library( deps = [ ":cache_store_interface", "//rtp_llm/cpp/disaggregate/cache_store/proto:cache_store_service_cc_proto", + "//rtp_llm/cpp/utils:device_pin", "//rtp_llm/cpp/utils:time_util", "//:rtp_compute_ops", ":arpc_dep", diff --git a/rtp_llm/cpp/disaggregate/cache_store/InitParams.h b/rtp_llm/cpp/disaggregate/cache_store/InitParams.h index 9422f2ecb0..fde8915921 100644 --- a/rtp_llm/cpp/disaggregate/cache_store/InitParams.h +++ b/rtp_llm/cpp/disaggregate/cache_store/InitParams.h @@ -10,6 +10,7 @@ struct MessagerInitParams { uint32_t server_port = 0; uint32_t io_thread_count = 2; uint32_t worker_thread_count = 4; + int device_id = -1; uint32_t rdma_server_port = 0; uint32_t rdma_io_thread_count = 1; @@ -43,6 +44,7 @@ class CacheStoreInitParams { uint32_t messager_worker_thread_count = 32; kmonitor::MetricsReporterPtr metrics_reporter; + int device_id{-1}; // for test std::shared_ptr memory_util; diff --git a/rtp_llm/cpp/disaggregate/cache_store/NormalCacheStore.cpp b/rtp_llm/cpp/disaggregate/cache_store/NormalCacheStore.cpp index 2ad4693774..a0f13a054c 100644 --- a/rtp_llm/cpp/disaggregate/cache_store/NormalCacheStore.cpp +++ b/rtp_llm/cpp/disaggregate/cache_store/NormalCacheStore.cpp @@ -1,5 +1,6 @@ #include "rtp_llm/cpp/disaggregate/cache_store/NormalCacheStore.h" #include "rtp_llm/cpp/disaggregate/cache_store/Interface.h" +#include "rtp_llm/cpp/utils/DevicePin.h" #include "rtp_llm/cpp/utils/Logger.h" #include "autil/LockFreeThreadPool.h" @@ -30,7 +31,8 @@ std::shared_ptr NormalCacheStore::createNormalCacheStore(const } bool NormalCacheStore::init(const CacheStoreInitParams& params) { - params_ = params; + params_ = params; + device_id_ = params.device_id; if (params_.memory_util != nullptr) { memory_util_ = params.memory_util; @@ -53,6 +55,7 @@ bool NormalCacheStore::init(const CacheStoreInitParams& params) { messager_init_params.rdma_worker_thread_count = params.rdma_worker_thread_count; messager_init_params.io_thread_count = params.messager_io_thread_count; messager_init_params.worker_thread_count = params.messager_worker_thread_count; + messager_init_params.device_id = params.device_id; if (!messager_->init(messager_init_params)) { RTP_LLM_LOG_ERROR("normal cache store init failed : init messager failed"); @@ -67,6 +70,7 @@ bool NormalCacheStore::init(const CacheStoreInitParams& params) { } auto check_task_readiness = [this]() { + setCurrentThreadDeviceIfNeeded(this->device_id_); while (!thread_pool_close_) { std::this_thread::sleep_for(std::chrono::milliseconds(1)); std::unique_lock lock(store_tasks_mutex_); @@ -114,6 +118,7 @@ void NormalCacheStore::store(const std::shared_ptr& request_ metrics_reporter_, request_block_buffer->getBlocksCount(), request_block_buffer->getBlocksSize()); // task 只在threadpool中运行, threadpool退出前会清理所有running task, 用this是安全的 auto task = [this, request_block_buffer, callback, collector]() { + setCurrentThreadDeviceIfNeeded(this->device_id_); this->runStoreTask(request_block_buffer, callback, collector); }; @@ -192,6 +197,7 @@ void NormalCacheStore::load(const std::shared_ptr& request_b collector, partition_count, partition_id]() { + setCurrentThreadDeviceIfNeeded(this->device_id_); this->runLoadTask( request_block_buffer, callback, ip, port, rdma_port, timeout_ms, collector, partition_count, partition_id); }; diff --git a/rtp_llm/cpp/disaggregate/cache_store/NormalCacheStore.h b/rtp_llm/cpp/disaggregate/cache_store/NormalCacheStore.h index 8c9bd8e88c..22830d029f 100644 --- a/rtp_llm/cpp/disaggregate/cache_store/NormalCacheStore.h +++ b/rtp_llm/cpp/disaggregate/cache_store/NormalCacheStore.h @@ -46,10 +46,11 @@ class NormalCacheStore: public CacheStore { storeBuffers(const std::vector>& request_block_buffers, int64_t timeout_ms) override; - std::shared_ptr submitRemoteStoreTask(const std::shared_ptr& request, - const std::shared_ptr& collector, - RemoteStoreTask::CheckCancelFunc check_cancel_func) override; - void releaseRemoteStoreTask(const std::shared_ptr& task) override; + std::shared_ptr + submitRemoteStoreTask(const std::shared_ptr& request, + const std::shared_ptr& collector, + RemoteStoreTask::CheckCancelFunc check_cancel_func) override; + void releaseRemoteStoreTask(const std::shared_ptr& task) override; bool regUserBuffers(const std::vector>& buffers) override; std::shared_ptr findUserBuffer(const std::string& buffer_key) override; @@ -80,6 +81,7 @@ class NormalCacheStore: public CacheStore { private: bool thread_pool_close_{false}; + int device_id_{-1}; CacheStoreInitParams params_; std::shared_ptr memory_util_; std::shared_ptr request_block_buffer_store_; @@ -89,7 +91,9 @@ class NormalCacheStore: public CacheStore { std::shared_mutex remote_store_tasks_mutex_; std::unordered_map>> remote_store_tasks_; std::shared_mutex store_tasks_mutex_; - std::unordered_map, std::pair>> store_tasks_; + std::unordered_map, + std::pair>> + store_tasks_; }; } // namespace rtp_llm diff --git a/rtp_llm/cpp/disaggregate/cache_store/TcpBlockReadClosure.cpp b/rtp_llm/cpp/disaggregate/cache_store/TcpBlockReadClosure.cpp index 9a843673de..de40d15d7a 100644 --- a/rtp_llm/cpp/disaggregate/cache_store/TcpBlockReadClosure.cpp +++ b/rtp_llm/cpp/disaggregate/cache_store/TcpBlockReadClosure.cpp @@ -1,5 +1,6 @@ #include "rtp_llm/cpp/disaggregate/cache_store/TcpBlockReadClosure.h" #include "rtp_llm/models_py/bindings/core/ExecOps.h" +#include "rtp_llm/cpp/utils/DevicePin.h" #include "rtp_llm/cpp/utils/Logger.h" #include @@ -12,13 +13,15 @@ TcpBlockReadClosure::TcpBlockReadClosure(const std::vectorFailed()) { RTP_LLM_LOG_WARNING("tcp transfer connection read failed, error is %s", controller_->ErrorText().c_str()); end(false, CacheStoreUtil::fromArpcErrorCode(controller_->GetErrorCode())); @@ -76,4 +81,4 @@ void TcpBlockReadClosure::end(bool success, CacheStoreErrorCode error_code) { delete this; } -} // namespace rtp_llm \ No newline at end of file +} // namespace rtp_llm diff --git a/rtp_llm/cpp/disaggregate/cache_store/TcpBlockReadClosure.h b/rtp_llm/cpp/disaggregate/cache_store/TcpBlockReadClosure.h index 40379aa837..82d76cc646 100644 --- a/rtp_llm/cpp/disaggregate/cache_store/TcpBlockReadClosure.h +++ b/rtp_llm/cpp/disaggregate/cache_store/TcpBlockReadClosure.h @@ -13,7 +13,8 @@ class TcpBlockReadClosure: public RPCClosure { TransferConnection::ReadDoneCallback callback, BlockReadRequest* request, BlockReadResponse* response, - arpc::ANetRPCController* controller); + arpc::ANetRPCController* controller, + int device_id = -1); ~TcpBlockReadClosure(); public: @@ -30,5 +31,6 @@ class TcpBlockReadClosure: public RPCClosure { BlockReadRequest* request_; BlockReadResponse* response_; arpc::ANetRPCController* controller_; + int device_id_{-1}; }; -} // namespace rtp_llm \ No newline at end of file +} // namespace rtp_llm diff --git a/rtp_llm/cpp/disaggregate/cache_store/TcpCacheStoreLoadServiceClosure.cpp b/rtp_llm/cpp/disaggregate/cache_store/TcpCacheStoreLoadServiceClosure.cpp index cdae01474b..2ea56fa4ea 100644 --- a/rtp_llm/cpp/disaggregate/cache_store/TcpCacheStoreLoadServiceClosure.cpp +++ b/rtp_llm/cpp/disaggregate/cache_store/TcpCacheStoreLoadServiceClosure.cpp @@ -3,6 +3,7 @@ #include "rtp_llm/cpp/disaggregate/cache_store/MemoryUtil.h" #include #include "rtp_llm/cpp/disaggregate/cache_store/CacheStoreUtil.h" +#include "rtp_llm/cpp/utils/DevicePin.h" #include "rtp_llm/cpp/utils/Logger.h" namespace rtp_llm { @@ -20,6 +21,7 @@ TcpCacheStoreLoadServiceClosure::~TcpCacheStoreLoadServiceClosure() { } void TcpCacheStoreLoadServiceClosure::Run() { + setCurrentThreadDeviceIfNeeded(device_id_); collector_->markRequestCallEnd(currentTimeUs() - response_->response_send_start_time_us()); if (controller_->Failed()) { diff --git a/rtp_llm/cpp/disaggregate/cache_store/TcpCacheStoreLoadServiceClosure.h b/rtp_llm/cpp/disaggregate/cache_store/TcpCacheStoreLoadServiceClosure.h index 6a1ffcbe23..600e817ac6 100644 --- a/rtp_llm/cpp/disaggregate/cache_store/TcpCacheStoreLoadServiceClosure.h +++ b/rtp_llm/cpp/disaggregate/cache_store/TcpCacheStoreLoadServiceClosure.h @@ -18,14 +18,16 @@ class TcpCacheStoreLoadServiceClosure: public RPCClosure { CacheLoadRequest* request, CacheLoadResponse* response, CacheStoreLoadDoneCallback callback, - const std::shared_ptr& collector): + const std::shared_ptr& collector, + int device_id = -1): memory_util_(memory_util), request_block_buffer_(request_block_buffer), controller_(controller), request_(request), response_(response), callback_(callback), - collector_(collector) {} + collector_(collector), + device_id_(device_id) {} ~TcpCacheStoreLoadServiceClosure(); @@ -43,6 +45,7 @@ class TcpCacheStoreLoadServiceClosure: public RPCClosure { CacheLoadResponse* response_{nullptr}; CacheStoreLoadDoneCallback callback_{nullptr}; std::shared_ptr collector_; + int device_id_{-1}; }; -} // namespace rtp_llm \ No newline at end of file +} // namespace rtp_llm diff --git a/rtp_llm/cpp/disaggregate/cache_store/TcpCacheStoreServiceImpl.cpp b/rtp_llm/cpp/disaggregate/cache_store/TcpCacheStoreServiceImpl.cpp index dae91a4690..5b711b92c2 100644 --- a/rtp_llm/cpp/disaggregate/cache_store/TcpCacheStoreServiceImpl.cpp +++ b/rtp_llm/cpp/disaggregate/cache_store/TcpCacheStoreServiceImpl.cpp @@ -1,5 +1,6 @@ #include "rtp_llm/cpp/disaggregate/cache_store/TcpCacheStoreServiceImpl.h" #include "rtp_llm/models_py/bindings/core/ExecOps.h" +#include "rtp_llm/cpp/utils/DevicePin.h" #include "rtp_llm/cpp/utils/Logger.h" #include #include "rtp_llm/cpp/disaggregate/cache_store/TcpCacheStoreServiceImplContext.h" @@ -13,10 +14,12 @@ TcpCacheStoreServiceImpl::TcpCacheStoreServiceImpl( const kmonitor::MetricsReporterPtr& metrics_reporter, const std::shared_ptr& timer_manager, const std::shared_ptr& locked_block_buffer_manager, - const std::shared_ptr& tcp_client): + const std::shared_ptr& tcp_client, + int device_id): CacheStoreServiceImpl( memory_util, request_block_buffer_store, metrics_reporter, timer_manager, locked_block_buffer_manager), - tcp_client_(tcp_client) {} + tcp_client_(tcp_client), + device_id_(device_id) {} void TcpCacheStoreServiceImpl::loadImpl(::google::protobuf::RpcController* controller, const ::CacheLoadRequest* request, @@ -82,7 +85,7 @@ void TcpCacheStoreServiceImpl::transferImpl(::google::protobuf::RpcController* ::google::protobuf::Closure* done, const std::vector>& local_blocks, const std::vector>& remote_blocks) { - auto connection = tcp_client_->getTransferConnection(request->client_ip(), request->client_port()); + auto connection = tcp_client_->getTransferConnection(request->client_ip(), request->client_port(), device_id_); auto context = std::make_shared( request, response, done, local_blocks, remote_blocks, locked_block_buffer_manager_, memory_util_, connection); context->run(); @@ -92,6 +95,8 @@ void TcpCacheStoreServiceImpl::blockReadImpl(::google::protobuf::RpcController* const ::BlockReadRequest* request, BlockReadResponse* response, ::google::protobuf::Closure* done) { + setCurrentThreadDeviceIfNeeded(device_id_); + for (int i = 0; i < request->blocks_size(); i++) { auto& block_info = request->blocks(i); auto block_buffer = request_block_buffer_store_->findUserBuffer(block_info.key()); @@ -112,6 +117,7 @@ void TcpCacheStoreServiceImpl::blockReadImpl(::google::protobuf::RpcController* resp_block_info->set_addr(block_info.addr()); resp_block_info->set_len(block_info.len()); + // ROCm PyTorch exposes ordinary GPU tensors through the CUDA device type. auto src_tensor = torch::from_blob((void*)block_info.addr(), {(int64_t)block_info.len()}, torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA)); @@ -131,4 +137,4 @@ void TcpCacheStoreServiceImpl::blockReadImpl(::google::protobuf::RpcController* done->Run(); } -} // namespace rtp_llm \ No newline at end of file +} // namespace rtp_llm diff --git a/rtp_llm/cpp/disaggregate/cache_store/TcpCacheStoreServiceImpl.h b/rtp_llm/cpp/disaggregate/cache_store/TcpCacheStoreServiceImpl.h index a32cdf47ac..384412b57e 100644 --- a/rtp_llm/cpp/disaggregate/cache_store/TcpCacheStoreServiceImpl.h +++ b/rtp_llm/cpp/disaggregate/cache_store/TcpCacheStoreServiceImpl.h @@ -13,7 +13,8 @@ class TcpCacheStoreServiceImpl: public CacheStoreServiceImpl { const kmonitor::MetricsReporterPtr& metrics_reporter, const std::shared_ptr& timer_manager, const std::shared_ptr& locked_block_buffer_manager, - const std::shared_ptr& tcp_client); + const std::shared_ptr& tcp_client, + int device_id = -1); virtual ~TcpCacheStoreServiceImpl() = default; protected: @@ -41,6 +42,7 @@ class TcpCacheStoreServiceImpl: public CacheStoreServiceImpl { private: std::shared_ptr tcp_client_; + int device_id_{-1}; }; -} // namespace rtp_llm \ No newline at end of file +} // namespace rtp_llm diff --git a/rtp_llm/cpp/disaggregate/cache_store/TcpClient.cpp b/rtp_llm/cpp/disaggregate/cache_store/TcpClient.cpp index ffca4c2c22..a6e9f54cc2 100644 --- a/rtp_llm/cpp/disaggregate/cache_store/TcpClient.cpp +++ b/rtp_llm/cpp/disaggregate/cache_store/TcpClient.cpp @@ -77,12 +77,13 @@ std::shared_ptr TcpClient::openChannel(const std::string& dynamic_cast(rpc_channel_manager_->OpenChannel(spec, false, 1000ul))); } -std::shared_ptr TcpClient::getTransferConnection(const std::string& ip, uint32_t port) { +std::shared_ptr +TcpClient::getTransferConnection(const std::string& ip, uint32_t port, int device_id) { auto channel = getChannel(ip, port); if (channel == nullptr) { return nullptr; } - return std::make_shared(channel); + return std::make_shared(channel, device_id); } -} // namespace rtp_llm \ No newline at end of file +} // namespace rtp_llm diff --git a/rtp_llm/cpp/disaggregate/cache_store/TcpClient.h b/rtp_llm/cpp/disaggregate/cache_store/TcpClient.h index 4625e19310..56951ba020 100644 --- a/rtp_llm/cpp/disaggregate/cache_store/TcpClient.h +++ b/rtp_llm/cpp/disaggregate/cache_store/TcpClient.h @@ -14,7 +14,7 @@ class TcpClient { public: bool init(int io_thread_count); std::shared_ptr getChannel(const std::string& ip, uint32_t port); - std::shared_ptr getTransferConnection(const std::string& ip, uint32_t port); + std::shared_ptr getTransferConnection(const std::string& ip, uint32_t port, int device_id = -1); private: void stop(); @@ -29,4 +29,4 @@ class TcpClient { std::shared_ptr rpc_channel_manager_; }; -} // namespace rtp_llm \ No newline at end of file +} // namespace rtp_llm diff --git a/rtp_llm/cpp/disaggregate/cache_store/TcpMessager.cpp b/rtp_llm/cpp/disaggregate/cache_store/TcpMessager.cpp index 217e056217..65511f7d53 100644 --- a/rtp_llm/cpp/disaggregate/cache_store/TcpMessager.cpp +++ b/rtp_llm/cpp/disaggregate/cache_store/TcpMessager.cpp @@ -27,7 +27,8 @@ bool TcpMessager::init(MessagerInitParams params) { metrics_reporter_, timer_manager_, locked_block_buffer_manager_, - tcp_client_); + tcp_client_, + init_params_.device_id); if (!tcp_server_->registerService(service_.get())) { RTP_LLM_LOG_WARNING("messager init failed, tcp server register service failed"); return false; @@ -70,7 +71,8 @@ void TcpMessager::load(const std::shared_ptr& load_request, load_response, request->callback, - collector); + collector, + init_params_.device_id); collector->markRequestCallBegin(); KvCacheStoreService_Stub stub((::google::protobuf::RpcChannel*)(channel.get()), @@ -88,4 +90,4 @@ bool TcpMessager::generateBlockInfo(BlockBufferInfo* block_in return true; } -} // namespace rtp_llm \ No newline at end of file +} // namespace rtp_llm diff --git a/rtp_llm/cpp/disaggregate/cache_store/TcpTransferConnection.cpp b/rtp_llm/cpp/disaggregate/cache_store/TcpTransferConnection.cpp index 7279e5d5d9..2f1a83d7ce 100644 --- a/rtp_llm/cpp/disaggregate/cache_store/TcpTransferConnection.cpp +++ b/rtp_llm/cpp/disaggregate/cache_store/TcpTransferConnection.cpp @@ -3,7 +3,8 @@ namespace rtp_llm { -TcpTransferConnection::TcpTransferConnection(const std::shared_ptr& channel): channel_(channel) {} +TcpTransferConnection::TcpTransferConnection(const std::shared_ptr& channel, int device_id): + channel_(channel), device_id_(device_id) {} void TcpTransferConnection::read(const std::vector>& local_blocks, const std::vector>& remote_blocks, @@ -18,10 +19,11 @@ void TcpTransferConnection::read(const std::vector> auto response = new BlockReadResponse; arpc::ANetRPCController* controller = new arpc::ANetRPCController(); controller->SetExpireTime(timeout_ms); - auto closure = new TcpBlockReadClosure(local_blocks, remote_blocks, callback, request, response, controller); + auto closure = + new TcpBlockReadClosure(local_blocks, remote_blocks, callback, request, response, controller, device_id_); KvCacheStoreService_Stub stub((::google::protobuf::RpcChannel*)(channel_.get()), ::google::protobuf::Service::STUB_DOESNT_OWN_CHANNEL); stub.blockRead(controller, request, response, closure); } -} // namespace rtp_llm \ No newline at end of file +} // namespace rtp_llm diff --git a/rtp_llm/cpp/disaggregate/cache_store/TcpTransferConnection.h b/rtp_llm/cpp/disaggregate/cache_store/TcpTransferConnection.h index 36b32d7134..1ab475c376 100644 --- a/rtp_llm/cpp/disaggregate/cache_store/TcpTransferConnection.h +++ b/rtp_llm/cpp/disaggregate/cache_store/TcpTransferConnection.h @@ -9,7 +9,7 @@ namespace rtp_llm { class TcpTransferConnection: public TransferConnection { public: - TcpTransferConnection(const std::shared_ptr& channel); + TcpTransferConnection(const std::shared_ptr& channel, int device_id = -1); public: void read(const std::vector>& local_blocks, @@ -19,6 +19,7 @@ class TcpTransferConnection: public TransferConnection { private: std::shared_ptr channel_; + int device_id_{-1}; }; -} // namespace rtp_llm \ No newline at end of file +} // namespace rtp_llm diff --git a/rtp_llm/cpp/model_rpc/RemoteRpcServer.cc b/rtp_llm/cpp/model_rpc/RemoteRpcServer.cc index 1e4f41167d..631e66ef78 100644 --- a/rtp_llm/cpp/model_rpc/RemoteRpcServer.cc +++ b/rtp_llm/cpp/model_rpc/RemoteRpcServer.cc @@ -79,6 +79,7 @@ void RemoteRpcServer::initCacheStore(const EngineInitParams& init params.messager_io_thread_count = init_params.cache_store_config.messager_io_thread_count; params.messager_worker_thread_count = init_params.cache_store_config.messager_worker_thread_count; params.metrics_reporter = metrics_reporter_; + params.device_id = static_cast(init_params.parallelism_config.local_rank); RTP_LLM_LOG_INFO("cache store listen port is [%ld], rdma listen port is [%ld] rdma_mode is [%d]", params.listen_port, params.rdma_listen_port, diff --git a/rtp_llm/cpp/models/PyWrappedModel.h b/rtp_llm/cpp/models/PyWrappedModel.h index 06c62a0c7e..a6fb30101b 100644 --- a/rtp_llm/cpp/models/PyWrappedModel.h +++ b/rtp_llm/cpp/models/PyWrappedModel.h @@ -272,7 +272,7 @@ inline PyWrappedModel::PyWrappedModel(const GptModelInitParams& params, throw std::runtime_error("PyWrappedModel constructor: Python model initialization failed."); } - cache_store_async_writer_ = std::make_unique(); + cache_store_async_writer_ = std::make_unique(params.parallelism_config.local_rank); if (device_props_.enable_prefill_cp) { context_parallel_processor_ = diff --git a/rtp_llm/cpp/utils/BUILD b/rtp_llm/cpp/utils/BUILD index bb3e70b3de..d6f8550c7f 100644 --- a/rtp_llm/cpp/utils/BUILD +++ b/rtp_llm/cpp/utils/BUILD @@ -54,6 +54,28 @@ cc_library( visibility = ["//visibility:public"], ) +cc_library( + name = "device_pin", + hdrs = [ + "DevicePin.h", + ], + deps = [ + ":core_utils", + ] + torch_deps() + select({ + "//:using_cuda": [ + "@local_config_cuda//cuda:cuda_headers", + "@local_config_cuda//cuda:cudart", + ], + "@//:using_rocm": [ + "@local_config_rocm//rocm:rocm_headers", + "@local_config_rocm//rocm:rocm", + ], + "//conditions:default": [], + }), + copts = copts(), + visibility = ["//visibility:public"], +) + cc_library( name = "error_code", hdrs = [ @@ -209,4 +231,4 @@ cc_library( ], copts = copts(), visibility = ["//visibility:public"], -) \ No newline at end of file +) diff --git a/rtp_llm/cpp/utils/DevicePin.h b/rtp_llm/cpp/utils/DevicePin.h new file mode 100644 index 0000000000..21ebfc1b96 --- /dev/null +++ b/rtp_llm/cpp/utils/DevicePin.h @@ -0,0 +1,59 @@ +#pragma once + +#include + +#include "rtp_llm/cpp/utils/Logger.h" + +#if USING_CUDA +#include +#include +#elif USING_ROCM +#include +#include +#endif + +namespace rtp_llm { + +namespace detail { + +template +inline void setCurrentThreadDeviceIfNeededImpl(int device_id, int& current_device, SetDevice&& set_device) { + if (device_id < 0) { + return; + } + + if (current_device == device_id) { + return; + } + + std::forward(set_device)(device_id); + // Cache only successful backend calls so invalid devices still retry later. + current_device = device_id; +} + +} // namespace detail + +inline void setCurrentThreadDeviceIfNeeded(int device_id) { + if (device_id < 0) { + return; + } + +#if USING_CUDA || USING_ROCM + thread_local int current_device = -1; + + // Thread-pool workers may serve different cache stores over time; a new + // device_id intentionally retargets the current thread instead of no-oping. + // ROCm PyTorch exposes ordinary GPU tensors as CUDA; only pinning uses HIP APIs. +#if USING_CUDA + detail::setCurrentThreadDeviceIfNeededImpl( + device_id, current_device, [](int device) { at::cuda::set_device(device); }); +#elif USING_ROCM + detail::setCurrentThreadDeviceIfNeededImpl( + device_id, current_device, [](int device) { at::hip::set_device(device); }); +#endif +#else + // CPU-only builds intentionally no-op; production cache-store builds pin to a GPU backend. +#endif +} + +} // namespace rtp_llm diff --git a/rtp_llm/cpp/utils/test/BUILD b/rtp_llm/cpp/utils/test/BUILD index 0f29b4adcc..47406f313d 100644 --- a/rtp_llm/cpp/utils/test/BUILD +++ b/rtp_llm/cpp/utils/test/BUILD @@ -71,4 +71,13 @@ cc_test( env = { "TEST_USING_DEVICE": "CUDA", }, -) \ No newline at end of file +) + +cc_test( + name = "device_pin_test", + srcs = [ + "DevicePinTest.cc", + ], + copts = copts(), + deps = google_test_deps + ["//rtp_llm/cpp/utils:device_pin"], +) diff --git a/rtp_llm/cpp/utils/test/DevicePinTest.cc b/rtp_llm/cpp/utils/test/DevicePinTest.cc new file mode 100644 index 0000000000..2cb4b665fa --- /dev/null +++ b/rtp_llm/cpp/utils/test/DevicePinTest.cc @@ -0,0 +1,61 @@ +#include "gtest/gtest.h" + +#include "rtp_llm/cpp/utils/DevicePin.h" + +#include +#include + +namespace rtp_llm { + +TEST(DevicePinTest, NegativeDeviceIdSkipsSetterAndKeepsCache) { + int cached_device = 3; + int set_calls = 0; + + detail::setCurrentThreadDeviceIfNeededImpl(-1, cached_device, [&set_calls](int) { ++set_calls; }); + + ASSERT_EQ(0, set_calls); + ASSERT_EQ(3, cached_device); +} + +TEST(DevicePinTest, CachedDeviceIdSkipsRepeatedSetDevice) { + int cached_device = -1; + std::vector set_devices; + + auto set_device = [&set_devices](int device) { set_devices.push_back(device); }; + + detail::setCurrentThreadDeviceIfNeededImpl(0, cached_device, set_device); + detail::setCurrentThreadDeviceIfNeededImpl(0, cached_device, set_device); + + ASSERT_EQ(1u, set_devices.size()); + ASSERT_EQ(0, set_devices[0]); + ASSERT_EQ(0, cached_device); +} + +TEST(DevicePinTest, DifferentDeviceIdRetargetsThread) { + int cached_device = 0; + std::vector set_devices; + + detail::setCurrentThreadDeviceIfNeededImpl( + 1, cached_device, [&set_devices](int device) { set_devices.push_back(device); }); + + ASSERT_EQ(1u, set_devices.size()); + ASSERT_EQ(1, set_devices[0]); + ASSERT_EQ(1, cached_device); +} + +TEST(DevicePinTest, SetterExceptionPropagatesAndDoesNotUpdateCache) { + int cached_device = 0; + + ASSERT_THROW(detail::setCurrentThreadDeviceIfNeededImpl( + 1, cached_device, [](int) { throw std::runtime_error("set device failed"); }), + std::runtime_error); + ASSERT_EQ(0, cached_device); + + int set_calls = 0; + detail::setCurrentThreadDeviceIfNeededImpl(1, cached_device, [&set_calls](int) { ++set_calls; }); + + ASSERT_EQ(1, set_calls); + ASSERT_EQ(1, cached_device); +} + +} // namespace rtp_llm diff --git a/rtp_llm/models_py/bindings/core/BUILD b/rtp_llm/models_py/bindings/core/BUILD index 65ce847bd2..7410107247 100644 --- a/rtp_llm/models_py/bindings/core/BUILD +++ b/rtp_llm/models_py/bindings/core/BUILD @@ -335,6 +335,7 @@ cc_library( srcs = ["CacheStoreAsyncWriter.cc"], deps = [ "//rtp_llm/cpp/utils:core_utils", + "//rtp_llm/cpp/utils:device_pin", "@havenask//aios/autil:lock_free", ], visibility = ["//visibility:public"], diff --git a/rtp_llm/models_py/bindings/core/CacheStoreAsyncWriter.cc b/rtp_llm/models_py/bindings/core/CacheStoreAsyncWriter.cc index 1cc3dd7417..64f32a6a30 100644 --- a/rtp_llm/models_py/bindings/core/CacheStoreAsyncWriter.cc +++ b/rtp_llm/models_py/bindings/core/CacheStoreAsyncWriter.cc @@ -1,10 +1,17 @@ #include "rtp_llm/models_py/bindings/core/CacheStoreAsyncWriter.h" #include "autil/LockFreeThreadPool.h" #include "rtp_llm/cpp/utils/AssertUtils.h" +#include "rtp_llm/cpp/utils/DevicePin.h" namespace rtp_llm { -CacheStoreAsyncWriter::CacheStoreAsyncWriter() { +CacheStoreAsyncWriter::PendingTaskGuard::PendingTaskGuard(CacheStoreAsyncWriter& writer): writer_(writer) {} + +CacheStoreAsyncWriter::PendingTaskGuard::~PendingTaskGuard() { + writer_.completePendingTask(); +} + +CacheStoreAsyncWriter::CacheStoreAsyncWriter(int device_id): device_id_(device_id) { constexpr size_t kThreadCount = 3; constexpr size_t kQueueSize = 10000; auto pool = std::make_shared(kThreadCount, kQueueSize, nullptr, "CacheStoreAsync"); @@ -12,6 +19,20 @@ CacheStoreAsyncWriter::CacheStoreAsyncWriter() { thread_pool_ = std::move(pool); } +void CacheStoreAsyncWriter::completePendingTask() { + if (pending_count_.fetch_sub(1, std::memory_order_acq_rel) == 1) { + std::lock_guard lock(wait_mutex_); + wait_cv_.notify_all(); + } +} + +void CacheStoreAsyncWriter::storeCurrentException() { + std::lock_guard ex_lock(exception_mutex_); + if (!stored_exception_) { + stored_exception_ = std::current_exception(); + } +} + CacheStoreAsyncWriter::~CacheStoreAsyncWriter() { if (state_ == State::RUNNING) { RTP_LLM_LOG_WARNING("CacheStoreAsyncWriter destroyed while RUNNING — " @@ -43,26 +64,19 @@ void CacheStoreAsyncWriter::submit(std::function task) { pending_count_.fetch_add(1, std::memory_order_acq_rel); auto wrapped = [this, task = std::move(task)]() { + PendingTaskGuard pending_task_guard(*this); try { + setCurrentThreadDeviceIfNeeded(device_id_); task(); } catch (...) { - { - std::lock_guard ex_lock(exception_mutex_); - if (!stored_exception_) { - stored_exception_ = std::current_exception(); - } - } + storeCurrentException(); RTP_LLM_LOG_ERROR("CacheStoreAsyncWriter: background task threw an exception"); } - if (pending_count_.fetch_sub(1, std::memory_order_acq_rel) == 1) { - std::lock_guard lock(wait_mutex_); - wait_cv_.notify_all(); - } }; auto rc = thread_pool_->pushTask(std::move(wrapped)); if (rc != autil::ThreadPoolBase::ERROR_NONE) { - pending_count_.fetch_sub(1, std::memory_order_acq_rel); + completePendingTask(); RTP_LLM_CHECK_WITH_INFO(false, "CacheStoreAsyncWriter: pushTask failed (rc=%d). " "Queue full or thread pool in bad state.", diff --git a/rtp_llm/models_py/bindings/core/CacheStoreAsyncWriter.h b/rtp_llm/models_py/bindings/core/CacheStoreAsyncWriter.h index c6f1cfdcec..0f2197ebc0 100644 --- a/rtp_llm/models_py/bindings/core/CacheStoreAsyncWriter.h +++ b/rtp_llm/models_py/bindings/core/CacheStoreAsyncWriter.h @@ -17,7 +17,7 @@ namespace rtp_llm { // Lifecycle: init() -> submit()* -> waitAllDone() -> init() -> ... class CacheStoreAsyncWriter { public: - CacheStoreAsyncWriter(); + explicit CacheStoreAsyncWriter(int device_id = -1); ~CacheStoreAsyncWriter(); void init(); @@ -25,6 +25,21 @@ class CacheStoreAsyncWriter { void waitAllDone(); private: + class PendingTaskGuard { + public: + explicit PendingTaskGuard(CacheStoreAsyncWriter& writer); + ~PendingTaskGuard(); + + PendingTaskGuard(const PendingTaskGuard&) = delete; + PendingTaskGuard& operator=(const PendingTaskGuard&) = delete; + + private: + CacheStoreAsyncWriter& writer_; + }; + + void completePendingTask(); + void storeCurrentException(); + enum class State { IDLE, RUNNING @@ -38,6 +53,7 @@ class CacheStoreAsyncWriter { std::mutex exception_mutex_; std::exception_ptr stored_exception_; State state_{State::IDLE}; + int device_id_{-1}; }; } // namespace rtp_llm diff --git a/rtp_llm/models_py/bindings/core/test/CacheStoreAsyncWriterTest.cpp b/rtp_llm/models_py/bindings/core/test/CacheStoreAsyncWriterTest.cpp index 2def1ec2ec..84eb78039b 100644 --- a/rtp_llm/models_py/bindings/core/test/CacheStoreAsyncWriterTest.cpp +++ b/rtp_llm/models_py/bindings/core/test/CacheStoreAsyncWriterTest.cpp @@ -1,21 +1,88 @@ #include "gtest/gtest.h" -#define private public #include "rtp_llm/models_py/bindings/core/CacheStoreAsyncWriter.h" #include +#include +#include #include #include +#if USING_CUDA +#include +#elif USING_ROCM +#include +#endif + namespace rtp_llm { +#if USING_CUDA || USING_ROCM +namespace { + +int gpuDeviceCountForTest() { +#if USING_CUDA + int device_count = 0; + return cudaGetDeviceCount(&device_count) == cudaSuccess ? device_count : 0; +#elif USING_ROCM + int device_count = 0; + return hipGetDeviceCount(&device_count) == hipSuccess ? device_count : 0; +#else + return 0; +#endif +} + +int currentDeviceForTest() { +#if USING_CUDA + int device = -1; + if (cudaGetDevice(&device) != cudaSuccess) { + return -1; + } + return device; +#elif USING_ROCM + int device = -1; + if (hipGetDevice(&device) != hipSuccess) { + return -1; + } + return device; +#else + return -1; +#endif +} + +bool setDeviceForTest(int device) { +#if USING_CUDA + return cudaSetDevice(device) == cudaSuccess; +#elif USING_ROCM + return hipSetDevice(device) == hipSuccess; +#else + return false; +#endif +} + +class ScopedDeviceResetForTest { +public: + ScopedDeviceResetForTest(): original_device_(currentDeviceForTest()) {} + ~ScopedDeviceResetForTest() { + if (original_device_ >= 0) { + setDeviceForTest(original_device_); + } + } + + ScopedDeviceResetForTest(const ScopedDeviceResetForTest&) = delete; + ScopedDeviceResetForTest& operator=(const ScopedDeviceResetForTest&) = delete; + +private: + int original_device_; +}; + +} // namespace +#endif + class CacheStoreAsyncWriterTest: public ::testing::Test {}; TEST_F(CacheStoreAsyncWriterTest, InitAndWaitBasic) { CacheStoreAsyncWriter writer; - ASSERT_TRUE(writer.state_ == CacheStoreAsyncWriter::State::IDLE); writer.init(); - ASSERT_FALSE(writer.state_ == CacheStoreAsyncWriter::State::IDLE); std::atomic counter{0}; writer.submit([&counter]() { counter.fetch_add(1); }); @@ -23,7 +90,6 @@ TEST_F(CacheStoreAsyncWriterTest, InitAndWaitBasic) { writer.submit([&counter]() { counter.fetch_add(1); }); writer.waitAllDone(); - ASSERT_TRUE(writer.state_ == CacheStoreAsyncWriter::State::IDLE); ASSERT_EQ(3, counter.load()); } @@ -86,6 +152,37 @@ TEST_F(CacheStoreAsyncWriterTest, AsyncExecution) { ASSERT_TRUE(different_thread.load()); } +TEST_F(CacheStoreAsyncWriterTest, AsyncExecutionWithDeviceId) { +#if USING_CUDA || USING_ROCM + if (gpuDeviceCountForTest() < 2) { + GTEST_SKIP() << "Need at least two GPU devices to prove non-default device pinning"; + } + + ScopedDeviceResetForTest reset_device; + constexpr int kMainThreadDevice = 0; + constexpr int kWriterDevice = 1; + ASSERT_TRUE(setDeviceForTest(kMainThreadDevice)); + ASSERT_EQ(kMainThreadDevice, currentDeviceForTest()); + + CacheStoreAsyncWriter writer(kWriterDevice); + writer.init(); + + std::atomic counter{0}; + std::atomic observed_device{-1}; + writer.submit([&counter, &observed_device]() { + observed_device.store(currentDeviceForTest(), std::memory_order_release); + counter.fetch_add(1); + }); + writer.waitAllDone(); + + ASSERT_EQ(1, counter.load()); + ASSERT_EQ(kWriterDevice, observed_device.load(std::memory_order_acquire)); + ASSERT_EQ(kMainThreadDevice, currentDeviceForTest()); +#else + GTEST_SKIP() << "GPU device pinning is unavailable in CPU-only builds"; +#endif +} + TEST_F(CacheStoreAsyncWriterTest, ExceptionPropagation) { CacheStoreAsyncWriter writer; writer.init(); @@ -95,7 +192,6 @@ TEST_F(CacheStoreAsyncWriterTest, ExceptionPropagation) { ASSERT_THROW(writer.waitAllDone(), std::runtime_error); // After exception, writer should be back in IDLE and re-initializable. - ASSERT_TRUE(writer.state_ == CacheStoreAsyncWriter::State::IDLE); writer.init(); std::atomic counter{0}; writer.submit([&counter]() { counter.fetch_add(1); }); @@ -123,7 +219,6 @@ TEST_F(CacheStoreAsyncWriterTest, WaitWithoutSubmit) { CacheStoreAsyncWriter writer; writer.init(); writer.waitAllDone(); - ASSERT_TRUE(writer.state_ == CacheStoreAsyncWriter::State::IDLE); } TEST_F(CacheStoreAsyncWriterTest, ManyCycles) {