Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions benchmarks/benchmark_single_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,19 +133,20 @@ def benchmark_flexkv(model_config: ModelConfig,
all_tokens = 0
start_time = time.time()
batch_get_ids = []
return_masks = []
cached_tokens = 0
for i in range(batch_size):
all_tokens += len(batch_sequence_tensor[i])
task_id, _ = kvmanager.get_match(batch_sequence_tensor[i],
task_id, return_mask = kvmanager.get_match(batch_sequence_tensor[i],
token_mask=None)
batch_get_ids.append(task_id)
cached_tokens += return_mask.sum().item()
get_match_time = time.time() - start_time
kvmanager.launch(batch_get_ids, batch_slot_mapping)
get_result = kvmanager.wait(batch_get_ids)
batch_id_list =kvmanager.launch(batch_get_ids, batch_slot_mapping, as_batch=True, layerwise_transfer=True)
get_result = kvmanager.wait(batch_id_list)
elapsed_time_get = time.time() - start_time
cached_tokens = 0
for _, response in get_result.items():
if response.status == KVResponseStatus.SUCCESS:
cached_tokens += response.return_mask.sum().item()
assert response.status == KVResponseStatus.SUCCESS
transfer_data_size_GB = cached_tokens * model_config.token_size_in_bytes / 1024 / 1024 / 1024
transfer_bandwidth_get = transfer_data_size_GB / elapsed_time_get
print(f"get {cached_tokens} tokens, data_size: {transfer_data_size_GB:.3f} GB, "
Expand Down
38 changes: 33 additions & 5 deletions csrc/bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include "transfer.cuh"
#include "transfer_ssd.h"
#include "radix_tree.h"
#include "layerwise.h"

namespace py = pybind11;

Expand All @@ -32,7 +33,7 @@ void transfer_kv_blocks_binding(
int64_t cpu_kv_stride_in_bytes, int64_t cpu_layer_stride_in_bytes,
int64_t cpu_block_stride_in_bytes, int64_t chunk_size_in_bytes,
int start_layer_id, int num_layers, int transfer_sms = -1, bool is_host_to_device = true,
bool use_ce_transfer = false, bool is_mla = false, int gpu_block_type = 0) {
bool use_ce_transfer = false, bool is_mla = false, int gpu_block_type = 0, bool sync = true) {
int num_blocks = gpu_block_id_tensor.numel();

int64_t *gpu_block_ids =
Expand Down Expand Up @@ -74,21 +75,21 @@ void transfer_kv_blocks_binding(
num_blocks, start_layer_id, num_layers, gpu_block_ids, handler, 0,
cpu_block_ids, cpu_ptr, cpu_kv_stride_in_bytes, cpu_layer_stride_in_bytes,
cpu_block_stride_in_bytes, 0, chunk_size_in_bytes, stream, transfer_sms,
is_host_to_device, use_ce_transfer, is_mla);
is_host_to_device, use_ce_transfer, is_mla, sync);
break;
case flexkv::BackendType::TRTLLM:
flexkv::transfer_kv_blocks<flexkv::BackendType::TRTLLM>(
num_blocks, start_layer_id, num_layers, gpu_block_ids, handler, 0,
cpu_block_ids, cpu_ptr, cpu_kv_stride_in_bytes, cpu_layer_stride_in_bytes,
cpu_block_stride_in_bytes, 0, chunk_size_in_bytes, stream, transfer_sms,
is_host_to_device, use_ce_transfer, is_mla);
is_host_to_device, use_ce_transfer, is_mla, sync);
break;
case flexkv::BackendType::SGLANG:
flexkv::transfer_kv_blocks<flexkv::BackendType::SGLANG>(
num_blocks, start_layer_id, num_layers, gpu_block_ids, handler, 0,
cpu_block_ids, cpu_ptr, cpu_kv_stride_in_bytes, cpu_layer_stride_in_bytes,
cpu_block_stride_in_bytes, 0, chunk_size_in_bytes, stream, transfer_sms,
is_host_to_device, use_ce_transfer, is_mla);
is_host_to_device, use_ce_transfer, is_mla, sync);
break;
}

Expand Down Expand Up @@ -351,7 +352,7 @@ PYBIND11_MODULE(c_ext, m) {
py::arg("chunk_size_in_bytes"), py::arg("start_layer_id"),
py::arg("num_layers"), py::arg("transfer_sms") = -1,
py::arg("is_host_to_device") = true, py::arg("use_ce_transfer") = false,
py::arg("is_mla") = false, py::arg("gpu_block_type") = 0);
py::arg("is_mla") = false, py::arg("gpu_block_type") = 0, py::arg("sync") = true);
m.def("transfer_kv_blocks_ssd", &transfer_kv_blocks_ssd_binding,
"Transfer KV blocks between SSD and CPU memory",
py::arg("ioctx"), py::arg("cpu_layer_id_list"),
Expand All @@ -362,6 +363,33 @@ PYBIND11_MODULE(c_ext, m) {
py::arg("block_stride_in_bytes"), py::arg("is_read"),
py::arg("num_blocks_per_file"), py::arg("round_robin") = 1,
py::arg("num_threads_per_device") = 16, py::arg("is_mla") = false);
py::class_<flexkv::LayerwiseTransferGroup>(m, "LayerwiseTransferGroup")
.def(py::init<int, const std::vector<std::vector<torch::Tensor>> &,
torch::Tensor &, std::map<int, std::vector<std::string>> &,
int, int, torch::Tensor &, torch::Tensor &, torch::Tensor &,
torch::Tensor &, int, int, torch::Tensor &, int>(),
py::arg("num_gpus"), py::arg("gpu_blocks"), py::arg("cpu_blocks"),
py::arg("ssd_files"), py::arg("dp_group_id"), py::arg("num_layers"),
py::arg("gpu_kv_strides_tensor"),
py::arg("gpu_block_strides_tensor"),
py::arg("gpu_layer_strides_tensor"),
py::arg("gpu_chunk_sizes_tensor"), py::arg("iouring_entries"),
py::arg("iouring_flags"), py::arg("layer_eventfds_tensor"),
py::arg("tp_size"))
.def("layerwise_transfer",
&flexkv::LayerwiseTransferGroup::layerwise_transfer,
py::arg("ssd_block_ids"), py::arg("cpu_block_ids_d2h"),
py::arg("ssd_layer_stride_in_bytes"),
py::arg("ssd_kv_stride_in_bytes"), py::arg("num_blocks_per_file"),
py::arg("round_robin"), py::arg("num_threads_per_device"),
py::arg("gpu_block_id_tensor"), py::arg("cpu_block_id_tensor"),
py::arg("cpu_kv_stride_in_bytes"),
py::arg("cpu_layer_stride_in_bytes"),
py::arg("cpu_block_stride_in_bytes"),
py::arg("cpu_chunk_size_in_bytes"), py::arg("transfer_sms"),
py::arg("use_ce_transfer"), py::arg("num_layers"),
py::arg("layer_granularity"), py::arg("is_mla"),
py::arg("counter_id") = 0);
#ifdef FLEXKV_ENABLE_CFS
m.def("transfer_kv_blocks_remote", &transfer_kv_blocks_remote,
"Transfer KV blocks between remote and CPU memory",
Expand Down
Loading