Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions rtp_llm/cpp/disaggregate/cache_store/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ cc_library(
deps = [
":cache_store_interface",
"//rtp_llm/cpp/disaggregate/cache_store/proto:cache_store_service_cc_proto",
"//rtp_llm/cpp/utils:device_pin",
"//rtp_llm/cpp/utils:time_util",
"//:rtp_compute_ops",
":arpc_dep",
Expand Down
2 changes: 2 additions & 0 deletions rtp_llm/cpp/disaggregate/cache_store/InitParams.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ struct MessagerInitParams {
uint32_t server_port = 0;
uint32_t io_thread_count = 2;
uint32_t worker_thread_count = 4;
int device_id = -1;

uint32_t rdma_server_port = 0;
uint32_t rdma_io_thread_count = 1;
Expand Down Expand Up @@ -43,6 +44,7 @@ class CacheStoreInitParams {
uint32_t messager_worker_thread_count = 32;

kmonitor::MetricsReporterPtr metrics_reporter;
int device_id{-1};

// for test
std::shared_ptr<MemoryUtil> memory_util;
Expand Down
8 changes: 7 additions & 1 deletion rtp_llm/cpp/disaggregate/cache_store/NormalCacheStore.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include "rtp_llm/cpp/disaggregate/cache_store/NormalCacheStore.h"
#include "rtp_llm/cpp/disaggregate/cache_store/Interface.h"
#include "rtp_llm/cpp/utils/DevicePin.h"
#include "rtp_llm/cpp/utils/Logger.h"

#include "autil/LockFreeThreadPool.h"
Expand Down Expand Up @@ -30,7 +31,8 @@ std::shared_ptr<NormalCacheStore> NormalCacheStore::createNormalCacheStore(const
}

bool NormalCacheStore::init(const CacheStoreInitParams& params) {
params_ = params;
params_ = params;
device_id_ = params.device_id;

if (params_.memory_util != nullptr) {
memory_util_ = params.memory_util;
Expand All @@ -53,6 +55,7 @@ bool NormalCacheStore::init(const CacheStoreInitParams& params) {
messager_init_params.rdma_worker_thread_count = params.rdma_worker_thread_count;
messager_init_params.io_thread_count = params.messager_io_thread_count;
messager_init_params.worker_thread_count = params.messager_worker_thread_count;
messager_init_params.device_id = params.device_id;

if (!messager_->init(messager_init_params)) {
RTP_LLM_LOG_ERROR("normal cache store init failed : init messager failed");
Expand All @@ -67,6 +70,7 @@ bool NormalCacheStore::init(const CacheStoreInitParams& params) {
}

auto check_task_readiness = [this]() {
setCurrentThreadDeviceIfNeeded(this->device_id_);
while (!thread_pool_close_) {
std::this_thread::sleep_for(std::chrono::milliseconds(1));
std::unique_lock<std::shared_mutex> lock(store_tasks_mutex_);
Expand Down Expand Up @@ -114,6 +118,7 @@ void NormalCacheStore::store(const std::shared_ptr<RequestBlockBuffer>& request_
metrics_reporter_, request_block_buffer->getBlocksCount(), request_block_buffer->getBlocksSize());
// task 只在threadpool中运行, threadpool退出前会清理所有running task, 用this是安全的
auto task = [this, request_block_buffer, callback, collector]() {
setCurrentThreadDeviceIfNeeded(this->device_id_);
this->runStoreTask(request_block_buffer, callback, collector);
};

Expand Down Expand Up @@ -192,6 +197,7 @@ void NormalCacheStore::load(const std::shared_ptr<RequestBlockBuffer>& request_b
collector,
partition_count,
partition_id]() {
setCurrentThreadDeviceIfNeeded(this->device_id_);
this->runLoadTask(
request_block_buffer, callback, ip, port, rdma_port, timeout_ms, collector, partition_count, partition_id);
};
Expand Down
14 changes: 9 additions & 5 deletions rtp_llm/cpp/disaggregate/cache_store/NormalCacheStore.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,11 @@ class NormalCacheStore: public CacheStore {
storeBuffers(const std::vector<std::shared_ptr<RequestBlockBuffer>>& request_block_buffers,
int64_t timeout_ms) override;

std::shared_ptr<RemoteStoreTask> submitRemoteStoreTask(const std::shared_ptr<RemoteStoreRequest>& request,
const std::shared_ptr<CacheStoreRemoteStoreMetricsCollector>& collector,
RemoteStoreTask::CheckCancelFunc check_cancel_func) override;
void releaseRemoteStoreTask(const std::shared_ptr<RemoteStoreTask>& task) override;
std::shared_ptr<RemoteStoreTask>
submitRemoteStoreTask(const std::shared_ptr<RemoteStoreRequest>& request,
const std::shared_ptr<CacheStoreRemoteStoreMetricsCollector>& collector,
RemoteStoreTask::CheckCancelFunc check_cancel_func) override;
void releaseRemoteStoreTask(const std::shared_ptr<RemoteStoreTask>& task) override;

bool regUserBuffers(const std::vector<std::shared_ptr<BlockBuffer>>& buffers) override;
std::shared_ptr<BlockBuffer> findUserBuffer(const std::string& buffer_key) override;
Expand Down Expand Up @@ -80,6 +81,7 @@ class NormalCacheStore: public CacheStore {

private:
bool thread_pool_close_{false};
int device_id_{-1};
CacheStoreInitParams params_;
std::shared_ptr<MemoryUtil> memory_util_;
std::shared_ptr<RequestBlockBufferStore> request_block_buffer_store_;
Expand All @@ -89,7 +91,9 @@ class NormalCacheStore: public CacheStore {
std::shared_mutex remote_store_tasks_mutex_;
std::unordered_map<std::string, std::list<std::shared_ptr<RemoteStoreTaskImpl>>> remote_store_tasks_;
std::shared_mutex store_tasks_mutex_;
std::unordered_map<std::shared_ptr<RequestBlockBuffer>, std::pair<CacheStoreStoreDoneCallback, std::function<void()>>> store_tasks_;
std::unordered_map<std::shared_ptr<RequestBlockBuffer>,
std::pair<CacheStoreStoreDoneCallback, std::function<void()>>>
store_tasks_;
};

} // namespace rtp_llm
11 changes: 8 additions & 3 deletions rtp_llm/cpp/disaggregate/cache_store/TcpBlockReadClosure.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include "rtp_llm/cpp/disaggregate/cache_store/TcpBlockReadClosure.h"
#include "rtp_llm/models_py/bindings/core/ExecOps.h"
#include "rtp_llm/cpp/utils/DevicePin.h"
#include "rtp_llm/cpp/utils/Logger.h"
#include <torch/torch.h>

Expand All @@ -12,13 +13,15 @@ TcpBlockReadClosure::TcpBlockReadClosure(const std::vector<std::shared_ptr<Block
TransferConnection::ReadDoneCallback callback,
BlockReadRequest* request,
BlockReadResponse* response,
arpc::ANetRPCController* controller):
arpc::ANetRPCController* controller,
int device_id):
local_blocks_(local_blocks),
remote_blocks_(remote_blocks),
callback_(callback),
request_(request),
response_(response),
controller_(controller) {}
controller_(controller),
device_id_(device_id) {}

TcpBlockReadClosure::~TcpBlockReadClosure() {
delete request_;
Expand All @@ -27,6 +30,8 @@ TcpBlockReadClosure::~TcpBlockReadClosure() {
}

void TcpBlockReadClosure::Run() {
setCurrentThreadDeviceIfNeeded(device_id_);

if (controller_->Failed()) {
RTP_LLM_LOG_WARNING("tcp transfer connection read failed, error is %s", controller_->ErrorText().c_str());
end(false, CacheStoreUtil::fromArpcErrorCode(controller_->GetErrorCode()));
Expand Down Expand Up @@ -76,4 +81,4 @@ void TcpBlockReadClosure::end(bool success, CacheStoreErrorCode error_code) {
delete this;
}

} // namespace rtp_llm
} // namespace rtp_llm
6 changes: 4 additions & 2 deletions rtp_llm/cpp/disaggregate/cache_store/TcpBlockReadClosure.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ class TcpBlockReadClosure: public RPCClosure {
TransferConnection::ReadDoneCallback callback,
BlockReadRequest* request,
BlockReadResponse* response,
arpc::ANetRPCController* controller);
arpc::ANetRPCController* controller,
int device_id = -1);
~TcpBlockReadClosure();

public:
Expand All @@ -30,5 +31,6 @@ class TcpBlockReadClosure: public RPCClosure {
BlockReadRequest* request_;
BlockReadResponse* response_;
arpc::ANetRPCController* controller_;
int device_id_{-1};
};
} // namespace rtp_llm
} // namespace rtp_llm
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include "rtp_llm/cpp/disaggregate/cache_store/MemoryUtil.h"
#include <torch/torch.h>
#include "rtp_llm/cpp/disaggregate/cache_store/CacheStoreUtil.h"
#include "rtp_llm/cpp/utils/DevicePin.h"
#include "rtp_llm/cpp/utils/Logger.h"

namespace rtp_llm {
Expand All @@ -20,6 +21,7 @@ TcpCacheStoreLoadServiceClosure::~TcpCacheStoreLoadServiceClosure() {
}

void TcpCacheStoreLoadServiceClosure::Run() {
setCurrentThreadDeviceIfNeeded(device_id_);
collector_->markRequestCallEnd(currentTimeUs() - response_->response_send_start_time_us());

if (controller_->Failed()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,16 @@ class TcpCacheStoreLoadServiceClosure: public RPCClosure {
CacheLoadRequest* request,
CacheLoadResponse* response,
CacheStoreLoadDoneCallback callback,
const std::shared_ptr<CacheStoreClientLoadMetricsCollector>& collector):
const std::shared_ptr<CacheStoreClientLoadMetricsCollector>& collector,
int device_id = -1):
memory_util_(memory_util),
request_block_buffer_(request_block_buffer),
controller_(controller),
request_(request),
response_(response),
callback_(callback),
collector_(collector) {}
collector_(collector),
device_id_(device_id) {}

~TcpCacheStoreLoadServiceClosure();

Expand All @@ -43,6 +45,7 @@ class TcpCacheStoreLoadServiceClosure: public RPCClosure {
CacheLoadResponse* response_{nullptr};
CacheStoreLoadDoneCallback callback_{nullptr};
std::shared_ptr<CacheStoreClientLoadMetricsCollector> collector_;
int device_id_{-1};
};

} // namespace rtp_llm
} // namespace rtp_llm
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include "rtp_llm/cpp/disaggregate/cache_store/TcpCacheStoreServiceImpl.h"
#include "rtp_llm/models_py/bindings/core/ExecOps.h"
#include "rtp_llm/cpp/utils/DevicePin.h"
#include "rtp_llm/cpp/utils/Logger.h"
#include <torch/torch.h>
#include "rtp_llm/cpp/disaggregate/cache_store/TcpCacheStoreServiceImplContext.h"
Expand All @@ -13,10 +14,12 @@ TcpCacheStoreServiceImpl::TcpCacheStoreServiceImpl(
const kmonitor::MetricsReporterPtr& metrics_reporter,
const std::shared_ptr<TimerManager>& timer_manager,
const std::shared_ptr<LockedBlockBufferManager>& locked_block_buffer_manager,
const std::shared_ptr<TcpClient>& tcp_client):
const std::shared_ptr<TcpClient>& tcp_client,
int device_id):
CacheStoreServiceImpl(
memory_util, request_block_buffer_store, metrics_reporter, timer_manager, locked_block_buffer_manager),
tcp_client_(tcp_client) {}
tcp_client_(tcp_client),
device_id_(device_id) {}

void TcpCacheStoreServiceImpl::loadImpl(::google::protobuf::RpcController* controller,
const ::CacheLoadRequest* request,
Expand Down Expand Up @@ -82,7 +85,7 @@ void TcpCacheStoreServiceImpl::transferImpl(::google::protobuf::RpcController*
::google::protobuf::Closure* done,
const std::vector<std::shared_ptr<BlockBuffer>>& local_blocks,
const std::vector<std::shared_ptr<BlockBufferInfo>>& remote_blocks) {
auto connection = tcp_client_->getTransferConnection(request->client_ip(), request->client_port());
auto connection = tcp_client_->getTransferConnection(request->client_ip(), request->client_port(), device_id_);
auto context = std::make_shared<CacheTransferServiceImplContext>(
request, response, done, local_blocks, remote_blocks, locked_block_buffer_manager_, memory_util_, connection);
context->run();
Expand All @@ -92,6 +95,8 @@ void TcpCacheStoreServiceImpl::blockReadImpl(::google::protobuf::RpcController*
const ::BlockReadRequest* request,
BlockReadResponse* response,
::google::protobuf::Closure* done) {
setCurrentThreadDeviceIfNeeded(device_id_);

for (int i = 0; i < request->blocks_size(); i++) {
auto& block_info = request->blocks(i);
auto block_buffer = request_block_buffer_store_->findUserBuffer(block_info.key());
Expand All @@ -112,6 +117,7 @@ void TcpCacheStoreServiceImpl::blockReadImpl(::google::protobuf::RpcController*
resp_block_info->set_addr(block_info.addr());
resp_block_info->set_len(block_info.len());

// ROCm PyTorch exposes ordinary GPU tensors through the CUDA device type.
auto src_tensor = torch::from_blob((void*)block_info.addr(),
{(int64_t)block_info.len()},
torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA));
Expand All @@ -131,4 +137,4 @@ void TcpCacheStoreServiceImpl::blockReadImpl(::google::protobuf::RpcController*
done->Run();
}

} // namespace rtp_llm
} // namespace rtp_llm
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ class TcpCacheStoreServiceImpl: public CacheStoreServiceImpl {
const kmonitor::MetricsReporterPtr& metrics_reporter,
const std::shared_ptr<TimerManager>& timer_manager,
const std::shared_ptr<LockedBlockBufferManager>& locked_block_buffer_manager,
const std::shared_ptr<TcpClient>& tcp_client);
const std::shared_ptr<TcpClient>& tcp_client,
int device_id = -1);
virtual ~TcpCacheStoreServiceImpl() = default;

protected:
Expand Down Expand Up @@ -41,6 +42,7 @@ class TcpCacheStoreServiceImpl: public CacheStoreServiceImpl {

private:
std::shared_ptr<TcpClient> tcp_client_;
int device_id_{-1};
};

} // namespace rtp_llm
} // namespace rtp_llm
7 changes: 4 additions & 3 deletions rtp_llm/cpp/disaggregate/cache_store/TcpClient.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,12 +77,13 @@ std::shared_ptr<arpc::RPCChannelBase> TcpClient::openChannel(const std::string&
dynamic_cast<arpc::RPCChannelBase*>(rpc_channel_manager_->OpenChannel(spec, false, 1000ul)));
}

std::shared_ptr<TransferConnection> TcpClient::getTransferConnection(const std::string& ip, uint32_t port) {
std::shared_ptr<TransferConnection>
TcpClient::getTransferConnection(const std::string& ip, uint32_t port, int device_id) {
auto channel = getChannel(ip, port);
if (channel == nullptr) {
return nullptr;
}
return std::make_shared<TcpTransferConnection>(channel);
return std::make_shared<TcpTransferConnection>(channel, device_id);
}

} // namespace rtp_llm
} // namespace rtp_llm
4 changes: 2 additions & 2 deletions rtp_llm/cpp/disaggregate/cache_store/TcpClient.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class TcpClient {
public:
bool init(int io_thread_count);
std::shared_ptr<arpc::RPCChannelBase> getChannel(const std::string& ip, uint32_t port);
std::shared_ptr<TransferConnection> getTransferConnection(const std::string& ip, uint32_t port);
std::shared_ptr<TransferConnection> getTransferConnection(const std::string& ip, uint32_t port, int device_id = -1);

private:
void stop();
Expand All @@ -29,4 +29,4 @@ class TcpClient {
std::shared_ptr<arpc::ANetRPCChannelManager> rpc_channel_manager_;
};

} // namespace rtp_llm
} // namespace rtp_llm
8 changes: 5 additions & 3 deletions rtp_llm/cpp/disaggregate/cache_store/TcpMessager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ bool TcpMessager::init(MessagerInitParams params) {
metrics_reporter_,
timer_manager_,
locked_block_buffer_manager_,
tcp_client_);
tcp_client_,
init_params_.device_id);
if (!tcp_server_->registerService(service_.get())) {
RTP_LLM_LOG_WARNING("messager init failed, tcp server register service failed");
return false;
Expand Down Expand Up @@ -70,7 +71,8 @@ void TcpMessager::load(const std::shared_ptr<LoadRequest>&
load_request,
load_response,
request->callback,
collector);
collector,
init_params_.device_id);

collector->markRequestCallBegin();
KvCacheStoreService_Stub stub((::google::protobuf::RpcChannel*)(channel.get()),
Expand All @@ -88,4 +90,4 @@ bool TcpMessager::generateBlockInfo(BlockBufferInfo* block_in
return true;
}

} // namespace rtp_llm
} // namespace rtp_llm
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@

namespace rtp_llm {

TcpTransferConnection::TcpTransferConnection(const std::shared_ptr<arpc::RPCChannelBase>& channel): channel_(channel) {}
TcpTransferConnection::TcpTransferConnection(const std::shared_ptr<arpc::RPCChannelBase>& channel, int device_id):
channel_(channel), device_id_(device_id) {}

void TcpTransferConnection::read(const std::vector<std::shared_ptr<BlockBuffer>>& local_blocks,
const std::vector<std::shared_ptr<BlockBufferInfo>>& remote_blocks,
Expand All @@ -18,10 +19,11 @@ void TcpTransferConnection::read(const std::vector<std::shared_ptr<BlockBuffer>>
auto response = new BlockReadResponse;
arpc::ANetRPCController* controller = new arpc::ANetRPCController();
controller->SetExpireTime(timeout_ms);
auto closure = new TcpBlockReadClosure(local_blocks, remote_blocks, callback, request, response, controller);
auto closure =
new TcpBlockReadClosure(local_blocks, remote_blocks, callback, request, response, controller, device_id_);

KvCacheStoreService_Stub stub((::google::protobuf::RpcChannel*)(channel_.get()),
::google::protobuf::Service::STUB_DOESNT_OWN_CHANNEL);
stub.blockRead(controller, request, response, closure);
}
} // namespace rtp_llm
} // namespace rtp_llm
Loading
Loading