diff --git a/benchmarks/benchmark_single_batch.py b/benchmarks/benchmark_single_batch.py index c629b545ea..69af2f46c7 100644 --- a/benchmarks/benchmark_single_batch.py +++ b/benchmarks/benchmark_single_batch.py @@ -133,19 +133,20 @@ def benchmark_flexkv(model_config: ModelConfig, all_tokens = 0 start_time = time.time() batch_get_ids = [] + return_masks = [] + cached_tokens = 0 for i in range(batch_size): all_tokens += len(batch_sequence_tensor[i]) - task_id, _ = kvmanager.get_match(batch_sequence_tensor[i], + task_id, return_mask = kvmanager.get_match(batch_sequence_tensor[i], token_mask=None) batch_get_ids.append(task_id) + cached_tokens += return_mask.sum().item() get_match_time = time.time() - start_time - kvmanager.launch(batch_get_ids, batch_slot_mapping) - get_result = kvmanager.wait(batch_get_ids) + batch_id_list =kvmanager.launch(batch_get_ids, batch_slot_mapping, as_batch=True, layerwise_transfer=True) + get_result = kvmanager.wait(batch_id_list) elapsed_time_get = time.time() - start_time - cached_tokens = 0 for _, response in get_result.items(): - if response.status == KVResponseStatus.SUCCESS: - cached_tokens += response.return_mask.sum().item() + assert response.status == KVResponseStatus.SUCCESS transfer_data_size_GB = cached_tokens * model_config.token_size_in_bytes / 1024 / 1024 / 1024 transfer_bandwidth_get = transfer_data_size_GB / elapsed_time_get print(f"get {cached_tokens} tokens, data_size: {transfer_data_size_GB:.3f} GB, " diff --git a/csrc/bindings.cpp b/csrc/bindings.cpp index ecdb00f3d7..f69e6494fd 100644 --- a/csrc/bindings.cpp +++ b/csrc/bindings.cpp @@ -22,6 +22,7 @@ #include "transfer.cuh" #include "transfer_ssd.h" #include "radix_tree.h" +#include "layerwise.h" namespace py = pybind11; @@ -32,7 +33,7 @@ void transfer_kv_blocks_binding( int64_t cpu_kv_stride_in_bytes, int64_t cpu_layer_stride_in_bytes, int64_t cpu_block_stride_in_bytes, int64_t chunk_size_in_bytes, int start_layer_id, int num_layers, int transfer_sms = -1, bool is_host_to_device = true, - bool use_ce_transfer = false, bool is_mla = false, int gpu_block_type = 0) { + bool use_ce_transfer = false, bool is_mla = false, int gpu_block_type = 0, bool sync = true) { int num_blocks = gpu_block_id_tensor.numel(); int64_t *gpu_block_ids = @@ -74,21 +75,21 @@ void transfer_kv_blocks_binding( num_blocks, start_layer_id, num_layers, gpu_block_ids, handler, 0, cpu_block_ids, cpu_ptr, cpu_kv_stride_in_bytes, cpu_layer_stride_in_bytes, cpu_block_stride_in_bytes, 0, chunk_size_in_bytes, stream, transfer_sms, - is_host_to_device, use_ce_transfer, is_mla); + is_host_to_device, use_ce_transfer, is_mla, sync); break; case flexkv::BackendType::TRTLLM: flexkv::transfer_kv_blocks( num_blocks, start_layer_id, num_layers, gpu_block_ids, handler, 0, cpu_block_ids, cpu_ptr, cpu_kv_stride_in_bytes, cpu_layer_stride_in_bytes, cpu_block_stride_in_bytes, 0, chunk_size_in_bytes, stream, transfer_sms, - is_host_to_device, use_ce_transfer, is_mla); + is_host_to_device, use_ce_transfer, is_mla, sync); break; case flexkv::BackendType::SGLANG: flexkv::transfer_kv_blocks( num_blocks, start_layer_id, num_layers, gpu_block_ids, handler, 0, cpu_block_ids, cpu_ptr, cpu_kv_stride_in_bytes, cpu_layer_stride_in_bytes, cpu_block_stride_in_bytes, 0, chunk_size_in_bytes, stream, transfer_sms, - is_host_to_device, use_ce_transfer, is_mla); + is_host_to_device, use_ce_transfer, is_mla, sync); break; } @@ -351,7 +352,7 @@ PYBIND11_MODULE(c_ext, m) { py::arg("chunk_size_in_bytes"), py::arg("start_layer_id"), py::arg("num_layers"), py::arg("transfer_sms") = -1, py::arg("is_host_to_device") = true, py::arg("use_ce_transfer") = false, - py::arg("is_mla") = false, py::arg("gpu_block_type") = 0); + py::arg("is_mla") = false, py::arg("gpu_block_type") = 0, py::arg("sync") = true); m.def("transfer_kv_blocks_ssd", &transfer_kv_blocks_ssd_binding, "Transfer KV blocks between SSD and CPU memory", py::arg("ioctx"), py::arg("cpu_layer_id_list"), @@ -362,6 +363,33 @@ PYBIND11_MODULE(c_ext, m) { py::arg("block_stride_in_bytes"), py::arg("is_read"), py::arg("num_blocks_per_file"), py::arg("round_robin") = 1, py::arg("num_threads_per_device") = 16, py::arg("is_mla") = false); + py::class_(m, "LayerwiseTransferGroup") + .def(py::init> &, + torch::Tensor &, std::map> &, + int, int, torch::Tensor &, torch::Tensor &, torch::Tensor &, + torch::Tensor &, int, int, torch::Tensor &, int>(), + py::arg("num_gpus"), py::arg("gpu_blocks"), py::arg("cpu_blocks"), + py::arg("ssd_files"), py::arg("dp_group_id"), py::arg("num_layers"), + py::arg("gpu_kv_strides_tensor"), + py::arg("gpu_block_strides_tensor"), + py::arg("gpu_layer_strides_tensor"), + py::arg("gpu_chunk_sizes_tensor"), py::arg("iouring_entries"), + py::arg("iouring_flags"), py::arg("layer_eventfds_tensor"), + py::arg("tp_size")) + .def("layerwise_transfer", + &flexkv::LayerwiseTransferGroup::layerwise_transfer, + py::arg("ssd_block_ids"), py::arg("cpu_block_ids_d2h"), + py::arg("ssd_layer_stride_in_bytes"), + py::arg("ssd_kv_stride_in_bytes"), py::arg("num_blocks_per_file"), + py::arg("round_robin"), py::arg("num_threads_per_device"), + py::arg("gpu_block_id_tensor"), py::arg("cpu_block_id_tensor"), + py::arg("cpu_kv_stride_in_bytes"), + py::arg("cpu_layer_stride_in_bytes"), + py::arg("cpu_block_stride_in_bytes"), + py::arg("cpu_chunk_size_in_bytes"), py::arg("transfer_sms"), + py::arg("use_ce_transfer"), py::arg("num_layers"), + py::arg("layer_granularity"), py::arg("is_mla"), + py::arg("counter_id") = 0); #ifdef FLEXKV_ENABLE_CFS m.def("transfer_kv_blocks_remote", &transfer_kv_blocks_remote, "Transfer KV blocks between remote and CPU memory", diff --git a/csrc/layerwise.cpp b/csrc/layerwise.cpp new file mode 100644 index 0000000000..728c53a6da --- /dev/null +++ b/csrc/layerwise.cpp @@ -0,0 +1,416 @@ +#include "layerwise.h" +#include +#include +#include +#include +#include +#include +#include + +namespace flexkv { + +struct LayerCallbackData { + int start_layer; + int layers_this_batch; + int num_gpus; + std::atomic *counter; + // Eventfd info for notification + bool enable_eventfd; + int tp_size; + int num_layers; + int *layer_eventfds; // Pointer to eventfds array for current counter set + // NVTX range id for CPU->GPU transfer + nvtxRangeId_t *current_range_id_ptr; // Pointer to current layer's range ID + bool is_last_batch; // Whether this is the last batch + char next_range_name[64]; // Name for next layer's range (if not last batch) + nvtxRangeId_t *next_range_id_ptr; // Pointer to next layer's range ID storage +}; + +static void CUDART_CB layer_done_host_callback(void *userData) { + LayerCallbackData *data = static_cast(userData); + int completed = data->counter->fetch_add(1) + 1; + if (completed == data->num_gpus) { + // Notify via eventfd when all GPUs complete this layer batch + if (data->enable_eventfd && data->layer_eventfds != nullptr) { + // Signal each tp_rank's eventfd for completed layers + for (int layer = data->start_layer; + layer < data->start_layer + data->layers_this_batch; ++layer) { + for (int tp_rank = 0; tp_rank < data->tp_size; ++tp_rank) { + int fd = data->layer_eventfds[tp_rank * data->num_layers + layer]; + if (fd >= 0) { + // Write 2 to support both get_key_buffer and get_value_buffer waits + uint64_t val = 2; + ssize_t ret = write(fd, &val, sizeof(val)); + } + } + } + } + // End current NVTX range when all GPUs complete + if (data->current_range_id_ptr != nullptr && *data->current_range_id_ptr != 0) { + nvtxRangeEnd(*data->current_range_id_ptr); + } + // Start next layer's NVTX range (so it begins right after current layer ends) + if (!data->is_last_batch && data->next_range_id_ptr != nullptr) { + *data->next_range_id_ptr = nvtxRangeStartA(data->next_range_name); + } + delete data->counter; + } + delete data; +} + +LayerwiseTransferGroup::LayerwiseTransferGroup( + int num_gpus, const std::vector> &gpu_blocks, + torch::Tensor &cpu_blocks, + std::map> &ssd_files, int dp_group_id, + int num_layers, torch::Tensor &gpu_kv_strides_tensor, + torch::Tensor &gpu_block_strides_tensor, + torch::Tensor &gpu_layer_strides_tensor, + torch::Tensor &gpu_chunk_sizes_tensor, int iouring_entries, + int iouring_flags, torch::Tensor &layer_eventfds_tensor, int tp_size) { + + num_gpus_ = num_gpus; + num_layers_ = num_layers; + tp_size_ = tp_size; + current_counter_id_ = 0; + + // Initialize eventfds + enable_eventfd_ = (layer_eventfds_tensor.numel() > 0); + if (enable_eventfd_) { + // layer_eventfds_tensor layout: [num_counters, tp_size, num_layers] + // Index formula: counter_id * tp_size * num_layers + tp_rank * num_layers + layer + int total_fds = layer_eventfds_tensor.numel(); + num_counters_ = total_fds / (tp_size * num_layers); + + int32_t *fds_ptr = layer_eventfds_tensor.data_ptr(); + layer_eventfds_.assign(fds_ptr, fds_ptr + total_fds); + + printf("[LayerwiseTransferGroup] Initialized with eventfds: " + "tp_size=%d, num_counters=%d, num_layers=%d, total_fds=%d\n", + tp_size_, num_counters_, num_layers_, total_fds); + } else { + num_counters_ = 0; + printf("[LayerwiseTransferGroup] Initialized without eventfds\n"); + } + + gpu_kv_strides_in_bytes_ = new int64_t[num_gpus]; + gpu_block_strides_in_bytes_ = new int64_t[num_gpus]; + gpu_layer_strides_in_bytes_ = new int64_t[num_gpus]; + gpu_chunk_sizes_in_bytes_ = new int64_t[num_gpus]; + + int64_t *kv_strides_ptr = gpu_kv_strides_tensor.data_ptr(); + int64_t *block_strides_ptr = gpu_block_strides_tensor.data_ptr(); + int64_t *layer_strides_ptr = gpu_layer_strides_tensor.data_ptr(); + int64_t *chunk_sizes_ptr = gpu_chunk_sizes_tensor.data_ptr(); + + for (int i = 0; i < num_gpus; i++) { + gpu_kv_strides_in_bytes_[i] = kv_strides_ptr[i]; + gpu_block_strides_in_bytes_[i] = block_strides_ptr[i]; + gpu_chunk_sizes_in_bytes_[i] = chunk_sizes_ptr[i]; + gpu_layer_strides_in_bytes_[i] = layer_strides_ptr[i]; + } + + num_tensors_per_gpu_ = gpu_blocks[0].size(); + cudaMallocHost((void **)&gpu_blocks_, + num_gpus_ * num_tensors_per_gpu_ * sizeof(void *)); + for (int i = 0; i < num_gpus_; ++i) { + for (int j = 0; j < num_tensors_per_gpu_; ++j) { + gpu_blocks_[i * num_tensors_per_gpu_ + j] = gpu_blocks[i][j].data_ptr(); + } + } + + if (num_tensors_per_gpu_ == 1) { + backend_type_ = BackendType::TRTLLM; + } else if (num_tensors_per_gpu_ == num_layers) { + backend_type_ = BackendType::VLLM; + } else if (num_tensors_per_gpu_ == num_layers * 2) { + backend_type_ = BackendType::SGLANG; + } else { + throw std::runtime_error("Unsupported GPU block type: " + + std::to_string(num_tensors_per_gpu_)); + } + + gpu_tensor_handlers_.reserve(num_gpus_); + for (int i = 0; i < num_gpus_; i++) { + int64_t **gpu_blocks_ptr = + reinterpret_cast(gpu_blocks_ + i * num_tensors_per_gpu_); + gpu_tensor_handlers_.emplace_back( + backend_type_, gpu_blocks_ptr, num_layers, gpu_kv_strides_in_bytes_[i], + gpu_block_strides_in_bytes_[i], gpu_layer_strides_in_bytes_[i]); + } + + cpu_blocks_ = cpu_blocks.data_ptr(); + + dp_group_id_ = dp_group_id; + + // Create CUDA streams for each GPU + streams_.resize(num_gpus_); + events_.resize(num_gpus_); + + // Get highest priority (lowest value) + int leastPriority, greatestPriority; + cudaDeviceGetStreamPriorityRange(&leastPriority, &greatestPriority); + + for (int i = 0; i < num_gpus_; i++) { + cudaSetDevice(dp_group_id * num_gpus_ + i); + cudaStreamCreateWithPriority(&streams_[i], cudaStreamNonBlocking, greatestPriority); + cudaEventCreate(&events_[i]); + } + + // Initialize SSD IO context if ssd_files is not empty + enable_ssd_ = !ssd_files.empty(); + if (enable_ssd_) { + ioctx_ = std::make_unique(ssd_files, ssd_files.size(), + iouring_entries, iouring_flags); + } +} + +LayerwiseTransferGroup::~LayerwiseTransferGroup() { + for (int i = 0; i < num_gpus_; i++) { + cudaSetDevice(dp_group_id_ * num_gpus_ + i); + cudaStreamDestroy(streams_[i]); + cudaEventDestroy(events_[i]); + } + + cudaFreeHost(gpu_blocks_); + + gpu_tensor_handlers_.clear(); + delete[] gpu_kv_strides_in_bytes_; + delete[] gpu_block_strides_in_bytes_; + delete[] gpu_layer_strides_in_bytes_; + delete[] gpu_chunk_sizes_in_bytes_; +} + +void LayerwiseTransferGroup::layer_done_callback(int start_layer, + int layers_this_batch, + nvtxRangeId_t *current_range_id_ptr, + bool is_last_batch, + const char *next_range_name, + nvtxRangeId_t *next_range_id_ptr) { + std::atomic *counter = new std::atomic(0); + + // Get eventfd pointer for current counter set + int *eventfds_ptr = nullptr; + if (enable_eventfd_ && num_counters_ > 0) { + // Offset into layer_eventfds_ for current counter set + int offset = current_counter_id_ * tp_size_ * num_layers_; + eventfds_ptr = layer_eventfds_.data() + offset; + } + + for (int i = 0; i < num_gpus_; ++i) { + LayerCallbackData *data = new LayerCallbackData{ + start_layer, layers_this_batch, num_gpus_, counter, + enable_eventfd_, tp_size_, num_layers_, eventfds_ptr, + current_range_id_ptr, is_last_batch, {0}, next_range_id_ptr}; + // Copy next range name + if (next_range_name != nullptr) { + snprintf(data->next_range_name, sizeof(data->next_range_name), "%s", next_range_name); + } + cudaLaunchHostFunc(streams_[i], layer_done_host_callback, data); + } +} + +void LayerwiseTransferGroup::layerwise_transfer( + const torch::Tensor &ssd_block_ids, const torch::Tensor &cpu_block_ids_d2h, + const int64_t ssd_layer_stride_in_bytes, + const int64_t ssd_kv_stride_in_bytes, const int num_blocks_per_file, + const int round_robin, const int num_threads_per_device, + const torch::Tensor &gpu_block_id_tensor, + const torch::Tensor &cpu_block_id_tensor, + const int64_t cpu_kv_stride_in_bytes, + const int64_t cpu_layer_stride_in_bytes, + const int64_t cpu_block_stride_in_bytes, + const int64_t cpu_chunk_size_in_bytes, const int transfer_sms, + const bool use_ce_transfer, const int num_layers, + const int layer_granularity, const bool is_mla, + const int counter_id) { + + // Set current counter ID for eventfd notification + current_counter_id_ = counter_id; + + int num_blocks = gpu_block_id_tensor.numel(); + int64_t *gpu_block_ids = + static_cast(gpu_block_id_tensor.data_ptr()); + int64_t *cpu_block_ids = + static_cast(cpu_block_id_tensor.data_ptr()); + void *cpu_ptr = cpu_blocks_; + + // Create CUDA events for timing each layer batch (on GPU 0) + int num_batches = (num_layers + layer_granularity - 1) / layer_granularity; + std::vector timing_events(num_batches + 1); // +1 for start event + std::vector batch_start_layers(num_batches); + std::vector batch_layers_count(num_batches); + + cudaSetDevice(dp_group_id_ * num_gpus_); + for (int i = 0; i <= num_batches; ++i) { + cudaEventCreate(&timing_events[i]); + } + + // Record start event + cudaEventRecord(timing_events[0], streams_[0]); + + // Allocate storage for NVTX range IDs (one per batch) + std::vector h2d_range_ids(num_batches, 0); + // Pre-generate all range names with data size info + std::vector h2d_range_names(num_batches); + for (int b = 0; b < num_batches; ++b) { + int sl = b * layer_granularity; + int ltb = std::min(layer_granularity, num_layers - sl); + // Calculate data size for this batch: chunk_size * 2 (K+V) * layers * num_blocks + int64_t bytes_this_batch = 0; + for (int g = 0; g < num_gpus_; ++g) { + bytes_this_batch += gpu_chunk_sizes_in_bytes_[g] * 2 * ltb * num_blocks; + } + double mb_this_batch = bytes_this_batch / (1024.0 * 1024.0); + char name[128]; + snprintf(name, sizeof(name), "CPU->GPU Layer[%d,%d) %.2fMB", sl, sl + ltb, mb_this_batch); + h2d_range_names[b] = name; + } + + // Start the first batch's NVTX range in main thread + if (num_batches > 0) { + h2d_range_ids[0] = nvtxRangeStartA(h2d_range_names[0].c_str()); + } + + int batch_idx = 0; + for (int start_layer = 0; start_layer < num_layers; + start_layer += layer_granularity) { + int layers_this_batch = + std::min(layer_granularity, num_layers - start_layer); + + batch_start_layers[batch_idx] = start_layer; + batch_layers_count[batch_idx] = layers_this_batch; + + // Step 1: SSD -> CPU transfer + if (enable_ssd_ && ssd_block_ids.numel() > 0) { + // Calculate SSD->CPU data size: cpu_chunk_size * 2 (K+V) * layers * num_ssd_blocks + int num_ssd_blocks = ssd_block_ids.numel(); + int64_t ssd_bytes = cpu_chunk_size_in_bytes * 2 * layers_this_batch * num_ssd_blocks; + double ssd_mb = ssd_bytes / (1024.0 * 1024.0); + char ssd_range_name[128]; + snprintf(ssd_range_name, sizeof(ssd_range_name), + "SSD->CPU Layer[%d,%d) %.2fMB", start_layer, start_layer + layers_this_batch, ssd_mb); + nvtxRangePushA(ssd_range_name); + + torch::Tensor layer_id_list = + torch::arange(start_layer, start_layer + layers_this_batch, + torch::TensorOptions().dtype(torch::kInt32)); + transfer_kv_blocks_ssd( + *ioctx_, layer_id_list, reinterpret_cast(cpu_blocks_), + ssd_block_ids, cpu_block_ids_d2h, cpu_layer_stride_in_bytes, + cpu_kv_stride_in_bytes, ssd_layer_stride_in_bytes, + ssd_kv_stride_in_bytes, cpu_chunk_size_in_bytes, + cpu_block_stride_in_bytes, + true, // is_read: SSD -> CPU + num_blocks_per_file, round_robin, num_threads_per_device, is_mla); + + nvtxRangePop(); + } + + // Step 2: CPU -> GPU transfer + // NVTX range for this batch was already started (by main thread for first batch, + // or by previous batch's callback for subsequent batches) + + for (int i = 0; i < num_gpus_; ++i) { + // TODO: support multi-instance + cudaSetDevice(dp_group_id_ * num_gpus_ + i); + int64_t cpu_startoff_inside_chunks = i * gpu_chunk_sizes_in_bytes_[i]; + if (is_mla) { + cpu_startoff_inside_chunks = 0; + } + int64_t gpu_startoff_inside_chunks = 0; + int64_t chunk_size = gpu_chunk_sizes_in_bytes_[i]; + + switch (backend_type_) { + case BackendType::VLLM: + flexkv::transfer_kv_blocks( + num_blocks, start_layer, layers_this_batch, gpu_block_ids, + gpu_tensor_handlers_[i], gpu_startoff_inside_chunks, cpu_block_ids, + cpu_ptr, cpu_kv_stride_in_bytes, cpu_layer_stride_in_bytes, + cpu_block_stride_in_bytes, cpu_startoff_inside_chunks, chunk_size, + streams_[i], transfer_sms, true, use_ce_transfer, is_mla, false); + break; + case BackendType::TRTLLM: + flexkv::transfer_kv_blocks( + num_blocks, start_layer, layers_this_batch, gpu_block_ids, + gpu_tensor_handlers_[i], gpu_startoff_inside_chunks, cpu_block_ids, + cpu_ptr, cpu_kv_stride_in_bytes, cpu_layer_stride_in_bytes, + cpu_block_stride_in_bytes, cpu_startoff_inside_chunks, chunk_size, + streams_[i], transfer_sms, true, use_ce_transfer, is_mla, false); + break; + case BackendType::SGLANG: + flexkv::transfer_kv_blocks( + num_blocks, start_layer, layers_this_batch, gpu_block_ids, + gpu_tensor_handlers_[i], gpu_startoff_inside_chunks, cpu_block_ids, + cpu_ptr, cpu_kv_stride_in_bytes, cpu_layer_stride_in_bytes, + cpu_block_stride_in_bytes, cpu_startoff_inside_chunks, chunk_size, + streams_[i], transfer_sms, true, use_ce_transfer, is_mla, false); + break; + } + } + + // Record event after this batch on GPU 0 + cudaSetDevice(dp_group_id_ * num_gpus_); + cudaEventRecord(timing_events[batch_idx + 1], streams_[0]); + + // NVTX: current range ends in callback, next range starts in callback + bool is_last_batch = (batch_idx == num_batches - 1); + const char *next_name = is_last_batch ? nullptr : h2d_range_names[batch_idx + 1].c_str(); + nvtxRangeId_t *next_id_ptr = is_last_batch ? nullptr : &h2d_range_ids[batch_idx + 1]; + + layer_done_callback(start_layer, layers_this_batch, + &h2d_range_ids[batch_idx], is_last_batch, + next_name, next_id_ptr); + batch_idx++; + } + for (int i = 0; i < num_gpus_; ++i) { + cudaError_t err = cudaStreamSynchronize(streams_[i]); + if (err != cudaSuccess) { + throw std::runtime_error("layerwise_transfer failed on GPU " + + std::to_string(i) + ": " + + cudaGetErrorString(err)); + } + } + + // Calculate and print timing for each layer batch + // chunk_size per GPU * num_gpus * 2 (K+V) * layers_this_batch * num_blocks + // fprintf(stderr, "\n[LayerwiseTransfer] CPU->GPU Transfer Timing (num_blocks=%d):\n", num_blocks); + float total_time_ms = 0.0f; + int64_t total_bytes = 0; + + for (int i = 0; i < num_batches; ++i) { + float elapsed_ms = 0.0f; + cudaEventElapsedTime(&elapsed_ms, timing_events[i], timing_events[i + 1]); + + // Calculate bytes transferred for this batch + // For each GPU: chunk_size * 2 (K+V) * layers * num_blocks + int64_t bytes_this_batch = 0; + for (int g = 0; g < num_gpus_; ++g) { + bytes_this_batch += gpu_chunk_sizes_in_bytes_[g] * 2 * batch_layers_count[i] * num_blocks; + } + + double bandwidth_gbps = (bytes_this_batch / (1024.0 * 1024.0 * 1024.0)) / (elapsed_ms / 1000.0); + + // fprintf(stderr, " Layers [%d, %d): time=%.3f ms, size=%.2f MB, bandwidth=%.2f GB/s\n", + // batch_start_layers[i], + // batch_start_layers[i] + batch_layers_count[i], + // elapsed_ms, + // bytes_this_batch / (1024.0 * 1024.0), + // bandwidth_gbps); + + total_time_ms += elapsed_ms; + total_bytes += bytes_this_batch; + } + + double total_bandwidth_gbps = (total_bytes / (1024.0 * 1024.0 * 1024.0)) / (total_time_ms / 1000.0); + // fprintf(stderr, " Total: time=%.3f ms, size=%.2f MB, avg_bandwidth=%.2f GB/s\n\n", + // total_time_ms, total_bytes / (1024.0 * 1024.0), total_bandwidth_gbps); + // fflush(stderr); + + // Cleanup timing events + cudaSetDevice(dp_group_id_ * num_gpus_); + for (int i = 0; i <= num_batches; ++i) { + cudaEventDestroy(timing_events[i]); + } +} + +} // namespace flexkv diff --git a/csrc/layerwise.h b/csrc/layerwise.h new file mode 100644 index 0000000000..e1941f589d --- /dev/null +++ b/csrc/layerwise.h @@ -0,0 +1,92 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "gtensor_handler.cuh" +#include "transfer.cuh" +#include "transfer_ssd.h" + +namespace flexkv { + +class LayerwiseTransferGroup { +public: + LayerwiseTransferGroup( + int num_gpus, const std::vector> &gpu_blocks, + torch::Tensor &cpu_blocks, + std::map> &ssd_files, int dp_group_id, + int num_layers, torch::Tensor &gpu_kv_strides_tensor, + torch::Tensor &gpu_block_strides_tensor, + torch::Tensor &gpu_layer_strides_tensor, + torch::Tensor &gpu_chunk_sizes_tensor, int iouring_entries, + int iouring_flags, torch::Tensor &layer_eventfds_tensor, int tp_size); + + ~LayerwiseTransferGroup(); + + // Layerwise transfer: SSD->CPU + CPU->GPU + void layerwise_transfer( + const torch::Tensor + &ssd_block_ids, // SSD source block ids (for disk2host) + const torch::Tensor + &cpu_block_ids_d2h, // CPU dest block ids (for disk2host) + const int64_t ssd_layer_stride_in_bytes, + const int64_t ssd_kv_stride_in_bytes, const int num_blocks_per_file, + const int round_robin, const int num_threads_per_device, + const torch::Tensor + &gpu_block_id_tensor, // GPU dest block ids (for host2device) + const torch::Tensor + &cpu_block_id_tensor, // CPU source block ids (for host2device) + const int64_t cpu_kv_stride_in_bytes, + const int64_t cpu_layer_stride_in_bytes, + const int64_t cpu_block_stride_in_bytes, + const int64_t cpu_chunk_size_in_bytes, const int transfer_sms, + const bool use_ce_transfer, const int num_layers, + const int layer_granularity, const bool is_mla, + const int counter_id = 0); // Counter set index for triple buffering + +private: + int num_gpus_; + int dp_group_id_; + void **gpu_blocks_; + void *cpu_blocks_; + int num_tensors_per_gpu_; + int64_t *gpu_kv_strides_in_bytes_; + int64_t *gpu_block_strides_in_bytes_; + int64_t *gpu_layer_strides_in_bytes_; + int64_t *gpu_chunk_sizes_in_bytes_; + + BackendType backend_type_; + std::vector gpu_tensor_handlers_; + + std::vector streams_; + std::vector events_; + + // SSD IO context + bool enable_ssd_; + std::unique_ptr ioctx_; + + // Layer eventfds for notification + // Shape: [tp_size, num_counters, num_layers] + bool enable_eventfd_; + int tp_size_; + int num_counters_; + int num_layers_; + std::vector layer_eventfds_; // Flat array + int current_counter_id_; // Current counter set index for this transfer + + void layer_done_callback(int start_layer, int layers_this_batch, + nvtxRangeId_t *current_range_id_ptr, + bool is_last_batch, + const char *next_range_name, + nvtxRangeId_t *next_range_id_ptr); +}; + +} // namespace flexkv diff --git a/csrc/transfer.cu b/csrc/transfer.cu index 9dda5d4fe3..18b1e72d47 100644 --- a/csrc/transfer.cu +++ b/csrc/transfer.cu @@ -1,6 +1,6 @@ /* - * SPDX-FileCopyrightText: Copyright (c) <2025> NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: Apache-2.0 + * SPDX-FileCopyrightText: Copyright (c) <2025> NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -24,72 +24,83 @@ namespace flexkv { #define FLOAT4_PTR(ptr) reinterpret_cast(ptr) // Templated CUDA kernel - backend type determined at compile time -template +template __global__ void transfer_kv_blocks_kernel( int num_blocks, int start_layer_id, int num_layers, int64_t *gpu_block_ids, - GTensorHandler gpu_handler, - int64_t gpu_startoff_inside_chunks, + GTensorHandler gpu_handler, int64_t gpu_startoff_inside_chunks, int64_t *cpu_block_ids, int64_t *cpu_ptr, int64_t cpu_kv_stride, int64_t cpu_layer_stride, int64_t cpu_block_stride, int64_t cpu_startoff_inside_chunks, int64_t copy_size, bool is_mla, - bool is_host_to_device) { - // start layer id should also be provided for gpu location calculation - // but for now, we only support full-layer transfer, so start_layer_id is always 0 + bool is_host_to_device) { int kv_dim = is_mla ? 1 : 2; int num_chunks = num_layers * kv_dim * num_blocks; int64_t copy_size_in_float4 = copy_size * sizeof(int64_t) / sizeof(float4); - for (int chunk_idx = blockIdx.x; chunk_idx < num_chunks; - chunk_idx += gridDim.x) { - int layer_idx = chunk_idx / (num_blocks * kv_dim); + // 计算warp信息 + int warp_id = threadIdx.x / 32; + int lane_id = threadIdx.x % 32; + int warps_per_block = blockDim.x / 32; + int total_warps = gridDim.x * warps_per_block; + + // 每个warp处理一个chunk + for (int chunk_idx = blockIdx.x * warps_per_block + warp_id; + chunk_idx < num_chunks; chunk_idx += total_warps) { + int layer_idx = start_layer_id + chunk_idx / (num_blocks * kv_dim); int kv_idx = (chunk_idx % (num_blocks * kv_dim)) / num_blocks; int gpu_block_idx = gpu_block_ids[chunk_idx % num_blocks]; int cpu_block_idx = cpu_block_ids[chunk_idx % num_blocks]; int64_t *cpu_chunk_ptr = - cpu_ptr + (layer_idx + start_layer_id) * cpu_layer_stride + - kv_idx * cpu_kv_stride + cpu_block_idx * cpu_block_stride + - cpu_startoff_inside_chunks; - + cpu_ptr + layer_idx * cpu_layer_stride + kv_idx * cpu_kv_stride + + cpu_block_idx * cpu_block_stride + cpu_startoff_inside_chunks; + // Use template specialization to compute gpu pointer - int64_t *gpu_ptr = ptr_at(gpu_handler, layer_idx, kv_idx, gpu_block_idx); - int64_t *gpu_chunk_ptr = reinterpret_cast(gpu_ptr) + gpu_startoff_inside_chunks; + int64_t *gpu_ptr = + ptr_at(gpu_handler, layer_idx, kv_idx, gpu_block_idx); + int64_t *gpu_chunk_ptr = + reinterpret_cast(gpu_ptr) + gpu_startoff_inside_chunks; int64_t *src_chunk_ptr = is_host_to_device ? cpu_chunk_ptr : gpu_chunk_ptr; int64_t *dst_chunk_ptr = is_host_to_device ? gpu_chunk_ptr : cpu_chunk_ptr; - for (int64_t idx = threadIdx.x; idx < copy_size_in_float4; - idx += blockDim.x) { - float4 element = __ldg(&FLOAT4_PTR(src_chunk_ptr)[idx]); - FLOAT4_PTR(dst_chunk_ptr)[idx] = element; + // warp内的线程协作拷贝数据 + for (int64_t idx = lane_id; idx < copy_size_in_float4; idx += 32) { + float4 element; + asm volatile("ld.global.nc.v4.f32 {%0,%1,%2,%3},[%4];" + : "=f"(element.x), "=f"(element.y), "=f"(element.z), "=f"(element.w) + : "l"(&FLOAT4_PTR(src_chunk_ptr)[idx]) + : "memory"); + asm volatile("st.global.cg.v4.f32 [%0],{%1,%2,%3,%4};" + :: "l"(&FLOAT4_PTR(dst_chunk_ptr)[idx]), + "f"(element.x), "f"(element.y), "f"(element.z), "f"(element.w) + : "memory"); } } } // Templated host function -template +template void transfer_kv_blocks( int num_blocks, int start_layer_id, int num_layers, int64_t *gpu_block_ids, - GTensorHandler gpu_tensor_handler, - int64_t gpu_startoff_inside_chunks, - int64_t *cpu_block_ids, void *cpu_ptr, - int64_t cpu_kv_stride_in_bytes, int64_t cpu_layer_stride_in_bytes, - int64_t cpu_block_stride_in_bytes, int64_t cpu_startoff_inside_chunks, - int64_t chunk_size_in_bytes, cudaStream_t stream, int transfer_sms, - bool is_host_to_device, bool use_ce_transfer, bool is_mla) { - - int block_size = 128; + GTensorHandler gpu_tensor_handler, int64_t gpu_startoff_inside_chunks, + int64_t *cpu_block_ids, void *cpu_ptr, int64_t cpu_kv_stride_in_bytes, + int64_t cpu_layer_stride_in_bytes, int64_t cpu_block_stride_in_bytes, + int64_t cpu_startoff_inside_chunks, int64_t chunk_size_in_bytes, + cudaStream_t stream, int transfer_sms, bool is_host_to_device, + bool use_ce_transfer, bool is_mla, bool sync) { + + int block_size = 1024; static int max_blocks_per_sm = -1; if (max_blocks_per_sm == -1) { cudaOccupancyMaxActiveBlocksPerMultiprocessor( &max_blocks_per_sm, transfer_kv_blocks_kernel, block_size, 0); } - if (transfer_sms == -1) { - transfer_sms = 4; - } + // if (transfer_sms == -1) { + // transfer_sms = 4; + // } - int block_count = transfer_sms * max_blocks_per_sm; + int block_count = transfer_sms; // transfer_sms * max_blocks_per_sm; int64_t *cpu_ptr_int64 = reinterpret_cast(cpu_ptr); int64_t cpu_kv_stride_int64 = cpu_kv_stride_in_bytes / sizeof(int64_t); @@ -103,7 +114,7 @@ void transfer_kv_blocks( dim3 blockDim(block_size); dim3 gridDim(block_count); - + // CE transfer mode (Copy Engine using cudaMemcpyAsync) if (use_ce_transfer) { int kv_dim = is_mla ? 1 : 2; @@ -112,14 +123,15 @@ void transfer_kv_blocks( for (int k = 0; k < num_blocks; k++) { int64_t gpu_block_idx = gpu_block_ids[k]; int64_t cpu_block_idx = cpu_block_ids[k]; - + int64_t *cpu_chunk_ptr = cpu_ptr_int64 + (i + start_layer_id) * cpu_layer_stride_int64 + j * cpu_kv_stride_int64 + cpu_block_idx * cpu_block_stride_int64 + cpu_startoff_inside_chunks_int64; - - int64_t *gpu_ptr = ptr_at(gpu_tensor_handler, i, j, gpu_block_idx); - int64_t *gpu_chunk_ptr = reinterpret_cast(gpu_ptr) + + + int64_t *gpu_ptr = ptr_at(gpu_tensor_handler, + i + start_layer_id, j, gpu_block_idx); + int64_t *gpu_chunk_ptr = reinterpret_cast(gpu_ptr) + gpu_startoff_inside_chunks_int64; if (is_host_to_device) { @@ -136,26 +148,32 @@ void transfer_kv_blocks( // Custom kernel transfer transfer_kv_blocks_kernel<<>>( num_blocks, start_layer_id, num_layers, gpu_block_ids, - gpu_tensor_handler, gpu_startoff_inside_chunks_int64, - cpu_block_ids, cpu_ptr_int64, cpu_kv_stride_int64, - cpu_layer_stride_int64, cpu_block_stride_int64, - cpu_startoff_inside_chunks_int64, chunk_size_in_int64, is_mla, - is_host_to_device); + gpu_tensor_handler, gpu_startoff_inside_chunks_int64, cpu_block_ids, + cpu_ptr_int64, cpu_kv_stride_int64, cpu_layer_stride_int64, + cpu_block_stride_int64, cpu_startoff_inside_chunks_int64, + chunk_size_in_int64, is_mla, is_host_to_device); + } + if (sync) { + cudaStreamSynchronize(stream); } - cudaStreamSynchronize(stream); } // Explicit template instantiations -template void transfer_kv_blocks( - int, int, int, int64_t*, GTensorHandler, int64_t, int64_t*, void*, - int64_t, int64_t, int64_t, int64_t, int64_t, cudaStream_t, int, bool, bool, bool); +template void transfer_kv_blocks(int, int, int, int64_t *, + GTensorHandler, int64_t, + int64_t *, void *, int64_t, + int64_t, int64_t, int64_t, + int64_t, cudaStream_t, int, + bool, bool, bool, bool); template void transfer_kv_blocks( - int, int, int, int64_t*, GTensorHandler, int64_t, int64_t*, void*, - int64_t, int64_t, int64_t, int64_t, int64_t, cudaStream_t, int, bool, bool, bool); + int, int, int, int64_t *, GTensorHandler, int64_t, int64_t *, void *, + int64_t, int64_t, int64_t, int64_t, int64_t, cudaStream_t, int, bool, bool, + bool, bool); template void transfer_kv_blocks( - int, int, int, int64_t*, GTensorHandler, int64_t, int64_t*, void*, - int64_t, int64_t, int64_t, int64_t, int64_t, cudaStream_t, int, bool, bool, bool); + int, int, int, int64_t *, GTensorHandler, int64_t, int64_t *, void *, + int64_t, int64_t, int64_t, int64_t, int64_t, cudaStream_t, int, bool, bool, + bool, bool); } // namespace flexkv diff --git a/csrc/transfer.cuh b/csrc/transfer.cuh index 5e834be660..9e30542cfe 100644 --- a/csrc/transfer.cuh +++ b/csrc/transfer.cuh @@ -1,6 +1,6 @@ /* - * SPDX-FileCopyrightText: Copyright (c) <2025> NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: Apache-2.0 + * SPDX-FileCopyrightText: Copyright (c) <2025> NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,21 +16,21 @@ */ #pragma once -#include #include "gtensor_handler.cuh" +#include namespace flexkv { // Template function for transfer, specialized for each backend type -template +template void transfer_kv_blocks( int num_blocks, int start_layer_id, int num_layers, int64_t *gpu_block_ids, - GTensorHandler gpu_tensor_handler, // Pass by value! - int64_t gpu_startoff_inside_chunks, - int64_t *cpu_block_ids, void *cpu_ptr, + GTensorHandler gpu_tensor_handler, // Pass by value! + int64_t gpu_startoff_inside_chunks, int64_t *cpu_block_ids, void *cpu_ptr, int64_t cpu_kv_stride_in_bytes, int64_t cpu_layer_stride_in_bytes, int64_t cpu_block_stride_in_bytes, int64_t cpu_startoff_inside_chunks, int64_t chunk_size_in_bytes, cudaStream_t stream, int transfer_sms, - bool is_host_to_device, bool use_ce_transfer, bool is_mla); + bool is_host_to_device, bool use_ce_transfer, bool is_mla, + bool sync = true); } // namespace flexkv diff --git a/flexkv/cache/cache_engine.py b/flexkv/cache/cache_engine.py index 3a06372fff..74726d8839 100644 --- a/flexkv/cache/cache_engine.py +++ b/flexkv/cache/cache_engine.py @@ -629,6 +629,7 @@ def _get_impl_local(self, fragment2_num_blocks = max(len(ssd_matched_blocks) - len(cpu_matched_blocks), 0) #early return if no blocks to transfer if fragment12_num_blocks == 0: + nvtx.end_range(nvtx_range) return self._empty_get_return(request_id) assert fragment12_num_blocks <= len(gpu_block_ids) diff --git a/flexkv/common/config.py b/flexkv/common/config.py index 6b01fd4b03..96283623c6 100644 --- a/flexkv/common/config.py +++ b/flexkv/common/config.py @@ -61,6 +61,8 @@ class CacheConfig: remote_layout_type=KVCacheLayoutType(os.getenv('FLEXKV_REMOTE_LAYOUT', 'BLOCKFIRST').upper()), gds_layout_type=KVCacheLayoutType(os.getenv('FLEXKV_GDS_LAYOUT', 'BLOCKFIRST').upper()), + enable_layerwise_transfer=bool(int(os.getenv('FLEXKV_ENABLE_LAYERWISE_TRANSFER', 1))), + use_ce_transfer_h2d=bool(int(os.getenv('FLEXKV_USE_CE_TRANSFER_H2D', 0))), use_ce_transfer_d2h=bool(int(os.getenv('FLEXKV_USE_CE_TRANSFER_D2H', 0))), transfer_sms_h2d=int(os.getenv('FLEXKV_TRANSFER_SMS_H2D', 8)), diff --git a/flexkv/common/transfer.py b/flexkv/common/transfer.py index 76ccfd1b67..669fcc41a4 100644 --- a/flexkv/common/transfer.py +++ b/flexkv/common/transfer.py @@ -5,6 +5,8 @@ import numpy as np +from flexkv.common.debug import flexkv_logger + @dataclass(frozen=True) class CompletedOp: @@ -46,6 +48,7 @@ class TransferType(Enum): # so that the op 3 will not be executed actually, but can indicate the completion of # a group of transfer ops VIRTUAL = "Virtual" + LAYERWISE = "LAYERWISE" class PartitionBlockType(Enum): ROUND_ROBIN = 0 @@ -92,6 +95,54 @@ def __post_init__(self) -> None: assert self.dst_block_ids.dtype == np.int64 self.valid_block_num = self.src_block_ids.size +@dataclass +class LayerwiseTransferOp(TransferOp): + + src_block_ids_h2d: np.ndarray = field(default_factory=lambda: np.array([], dtype=np.int64)) + dst_block_ids_h2d: np.ndarray = field(default_factory=lambda: np.array([], dtype=np.int64)) + src_block_ids_disk2h: np.ndarray = field(default_factory=lambda: np.array([], dtype=np.int64)) + dst_block_ids_disk2h: np.ndarray = field(default_factory=lambda: np.array([], dtype=np.int64)) + counter_id: int = 0 # Counter set index for triple buffering eventfd notification + + def __init__(self, + graph_id: int, + src_block_ids_h2d: np.ndarray, + dst_block_ids_h2d: np.ndarray, + src_block_ids_disk2h: np.ndarray, + dst_block_ids_disk2h: np.ndarray, + layer_id: int = 0, + layer_granularity: int = 1, + dp_id: int = 0, + counter_id: int = 0) -> None: + self.src_block_ids_h2d = src_block_ids_h2d + self.dst_block_ids_h2d = dst_block_ids_h2d + self.src_block_ids_disk2h = src_block_ids_disk2h + self.dst_block_ids_disk2h = dst_block_ids_disk2h + self.counter_id = counter_id + + super().__init__( + graph_id=graph_id, + transfer_type=TransferType.LAYERWISE, + src_block_ids=np.array([], dtype=np.int64), + dst_block_ids=np.array([], dtype=np.int64), + layer_id=layer_id, + layer_granularity=layer_granularity, + dp_id=dp_id, + ) + + def __post_init__(self) -> None: + super().__post_init__() + + if self.layer_granularity == -1: + flexkv_logger.warning("layer_granularity is not set, using default value 1") + self.layer_granularity = 1 + assert self.src_block_ids_h2d.size == self.dst_block_ids_h2d.size + assert self.src_block_ids_disk2h.size == self.dst_block_ids_disk2h.size + + assert self.src_block_ids_h2d.dtype == np.int64 + assert self.dst_block_ids_h2d.dtype == np.int64 + assert self.src_block_ids_disk2h.dtype == np.int64 + assert self.dst_block_ids_disk2h.dtype == np.int64 class TransferOpGraph: _next_graph_id = 0 @@ -289,8 +340,12 @@ def format_blocks(block_ids, max_show=4): print(result) return result -def merge_to_batch_graph(batch_id: int, transfer_graphs: List[TransferOpGraph], task_end_op_ids: List[int], - op_callback_dict: Dict[int, Callable]) -> Tuple[TransferOpGraph, int, Dict[int, Callable]]: +def merge_to_batch_graph(batch_id: int, + transfer_graphs: List[TransferOpGraph], + task_end_op_ids: List[int], + op_callback_dict: Dict[int, Callable], + layerwise_transfer: bool = False, + counter_id: int = 0) -> Tuple[TransferOpGraph, int, Dict[int, Callable]]: """ Merge multiple TransferOpGraphs into a single batch graph. @@ -306,6 +361,7 @@ def merge_to_batch_graph(batch_id: int, transfer_graphs: List[TransferOpGraph], transfer_graphs: List of graphs to merge task_end_op_ids: List of end op IDs for each task (one per graph) op_callback_dict: Dict mapping old op_id -> callback + layerwise_transfer: Whether to merge the graphs into a layerwise transfer op Returns: (merged_graph, batch_end_op_id, new_op_callback_dict) @@ -365,7 +421,6 @@ def merge_to_batch_graph(batch_id: int, transfer_graphs: List[TransferOpGraph], layer_granularity=disk2h_ops[0].layer_granularity, dp_id=disk2h_ops[0].dp_id, ) - merged_graph.add_transfer_op(merged_disk2h_op) # Attach callbacks - create a combined callback if multiple if disk2h_callbacks: @@ -394,7 +449,6 @@ def combined_callback(*args, **kwargs): layer_granularity=h2d_ops[0].layer_granularity, dp_id=h2d_ops[0].dp_id, ) - merged_graph.add_transfer_op(merged_h2d_op) # Attach callbacks - create a combined callback if multiple if h2d_callbacks: @@ -407,19 +461,45 @@ def combined_callback(*args, **kwargs): cb(*args, **kwargs) return combined_callback new_op_callback_dict[merged_h2d_op.op_id] = make_combined_callback(h2d_callbacks) - - # Add dependency: DISK2H -> H2D - if merged_disk2h_op is not None and merged_h2d_op is not None: - merged_graph.add_dependency(merged_h2d_op.op_id, merged_disk2h_op.op_id) - - # Determine the batch_end_op_id - # Priority: H2D (if exists) -> DISK2H (if exists) -> -1 - if merged_h2d_op is not None: - batch_end_op_id = merged_h2d_op.op_id - elif merged_disk2h_op is not None: - batch_end_op_id = merged_disk2h_op.op_id - else: + if layerwise_transfer: + if merged_h2d_op is not None: + layerwise_transfer_op = LayerwiseTransferOp( + graph_id=merged_graph.graph_id, + src_block_ids_h2d=merged_h2d_op.src_block_ids, + dst_block_ids_h2d=merged_h2d_op.dst_block_ids, + src_block_ids_disk2h=merged_disk2h_op.src_block_ids \ + if merged_disk2h_op is not None \ + else np.array([], dtype=np.int64), + dst_block_ids_disk2h=merged_disk2h_op.dst_block_ids \ + if merged_disk2h_op is not None \ + else np.array([], dtype=np.int64), + layer_id=0, + layer_granularity=1, + dp_id=h2d_ops[0].dp_id, + counter_id=counter_id, + ) + merged_graph.add_transfer_op(layerwise_transfer_op) batch_end_op_id = -1 + # layerwise transfer op does not need callbacks + new_op_callback_dict.clear() + else: + if merged_disk2h_op is not None: + merged_graph.add_transfer_op(merged_disk2h_op) + if merged_h2d_op is not None: + merged_graph.add_transfer_op(merged_h2d_op) + # Add dependency: DISK2H -> H2D + if merged_disk2h_op is not None and merged_h2d_op is not None: + merged_graph.add_dependency(merged_h2d_op.op_id, merged_disk2h_op.op_id) + + # Determine the batch_end_op_id + # Priority: H2D (if exists) -> DISK2H (if exists) -> -1 + if merged_h2d_op is not None: + batch_end_op_id = merged_h2d_op.op_id + elif merged_disk2h_op is not None: + batch_end_op_id = merged_disk2h_op.op_id + else: + batch_end_op_id = -1 + return merged_graph, batch_end_op_id, new_op_callback_dict diff --git a/flexkv/kvmanager.py b/flexkv/kvmanager.py index b9016a6dad..db55789fde 100644 --- a/flexkv/kvmanager.py +++ b/flexkv/kvmanager.py @@ -186,7 +186,9 @@ def prefetch_async(self, def launch(self, task_ids: Union[int, List[int]], slot_mappings: Union[np.ndarray, List[np.ndarray], torch.Tensor, List[torch.Tensor]], - as_batch: bool = False) -> List[int]: + as_batch: bool = False, + layerwise_transfer: bool = False, + counter_id: int = 0) -> List[int]: if isinstance(task_ids, int): task_ids = [task_ids] if not isinstance(slot_mappings, List): @@ -194,10 +196,16 @@ def launch(self, if isinstance(slot_mappings[0], torch.Tensor): slot_mappings = [slot_mapping.numpy() for slot_mapping in slot_mappings] if self.server_client_mode: - return self.dp_client.launch_tasks(task_ids, slot_mappings, as_batch) + return self.dp_client.launch_tasks(task_ids, slot_mappings, as_batch, layerwise_transfer, counter_id) else: - return self.kv_task_engine.launch_tasks(task_ids, slot_mappings, as_batch) - + return self.kv_task_engine.launch_tasks( + task_ids, + slot_mappings, + as_batch=as_batch, + layerwise_transfer=layerwise_transfer, + counter_id=counter_id + ) + def cancel(self, task_ids: Union[int, List[int]]) -> None: if isinstance(task_ids, int): task_ids = [task_ids] diff --git a/flexkv/kvtask.py b/flexkv/kvtask.py index c87087d41b..c5f1834746 100644 --- a/flexkv/kvtask.py +++ b/flexkv/kvtask.py @@ -12,7 +12,7 @@ import torch import numpy as np -from flexkv.common.config import CacheConfig, ModelConfig +from flexkv.common.config import CacheConfig, ModelConfig, GLOBAL_CONFIG_FROM_ENV from flexkv.common.debug import flexkv_logger from flexkv.common.transfer import TransferOpGraph, merge_to_batch_graph, get_nvtx_default_color, CompletedOp from flexkv.common.tracer import FlexKVTracer @@ -735,7 +735,11 @@ def prefetch_async(self, self._launch_task(task_id) return task_id - def merge_to_batch_kvtask(self, batch_id: int, task_ids: List[int]) -> TransferOpGraph: + def merge_to_batch_kvtask(self, + batch_id: int, + task_ids: List[int], + layerwise_transfer: bool = False, + counter_id: int = 0) -> TransferOpGraph: op_callback_dict = {} task_end_op_ids = [] callbacks = [] @@ -750,11 +754,12 @@ def merge_to_batch_kvtask(self, batch_id: int, task_ids: List[int]) -> TransferO task_end_op_ids.append(self.tasks[task_id].task_end_op_id) callbacks.append(self.tasks[task_id].callback) return_masks.append(self.tasks[task_id].return_mask) - batch_task_graph, task_end_op_id, op_callback_dict = merge_to_batch_graph(batch_id, transfer_graphs, task_end_op_ids, - op_callback_dict) + op_callback_dict, + layerwise_transfer, + counter_id) self.tasks[batch_id] = KVTask( task_id=batch_id, token_ids=np.concatenate([self.tasks[task_id].token_ids for task_id in task_ids]), @@ -775,13 +780,16 @@ def merge_to_batch_kvtask(self, batch_id: int, task_ids: List[int]) -> TransferO for task_id in task_ids: self.graph_to_task.pop(self.tasks[task_id].graph.graph_id, None) self.tasks.pop(task_id, None) + batch_task_graph = self.check_task_ready(batch_id) return batch_task_graph def launch_tasks(self, task_ids: List[int], slot_mappings: List[np.ndarray], as_batch: bool = False, - batch_id: int = -1) -> List[int]: + batch_id: int = -1, + layerwise_transfer: bool = False, + counter_id: int = 0) -> List[int]: assert isinstance(slot_mappings[0], np.ndarray) # trace launch tasks self.tracer.trace_launch_tasks(task_ids, slot_mappings, as_batch) @@ -790,10 +798,21 @@ def launch_tasks(self, # Batch optimization: collect all transfer graphs first nvtx_range = nvtx.start_range(message=f"KVTaskEngine.launch_tasks batch={len(task_ids)}", color="blue") - if len(task_ids) > 1 and as_batch: + if as_batch: if batch_id == -1: batch_id = self._gen_task_id() - transfer_graphs = [self.merge_to_batch_kvtask(batch_id, task_ids)] + if layerwise_transfer: + if not GLOBAL_CONFIG_FROM_ENV.enable_layerwise_transfer: + flexkv_logger.warning("layerwise transfer is not enabled") + layerwise_transfer = False + else: + for task_id in task_ids: + if self.tasks[task_id].task_type != TaskType.GET: + flexkv_logger.warning("only support layerwise get") + layerwise_transfer = False + break + batch_task_graph = self.merge_to_batch_kvtask(batch_id, task_ids, layerwise_transfer, counter_id) + transfer_graphs = [batch_task_graph] task_ids = [batch_id] else: transfer_graphs = [] diff --git a/flexkv/server/client.py b/flexkv/server/client.py index cf96a08a38..988a0ddc68 100644 --- a/flexkv/server/client.py +++ b/flexkv/server/client.py @@ -168,11 +168,13 @@ def launch_tasks( task_ids: List[int], slot_mappings: List[np.ndarray], as_batch: bool = False, + layerwise_transfer: bool = False, + counter_id: int = 0, ) -> List[int]: batch_id = -1 if as_batch: batch_id = self._get_task_id() - req = LaunchTaskRequest(self.dp_client_id, task_ids, slot_mappings, as_batch, batch_id) + req = LaunchTaskRequest(self.dp_client_id, task_ids, slot_mappings, as_batch, batch_id, layerwise_transfer, counter_id) self.send_to_server.send_pyobj(req) return [batch_id] if as_batch else task_ids diff --git a/flexkv/server/request.py b/flexkv/server/request.py index 43ecdaba57..df49ac5abe 100644 --- a/flexkv/server/request.py +++ b/flexkv/server/request.py @@ -73,6 +73,8 @@ class LaunchTaskRequest: slot_mappings: List[np.ndarray] as_batch: bool = False batch_id: int = -1 + layerwise_transfer: bool = False + counter_id: int = 0 # Counter set index for triple buffering eventfd notification @dataclass class CancelTaskRequest: diff --git a/flexkv/server/server.py b/flexkv/server/server.py index 1a01519445..529d16530c 100644 --- a/flexkv/server/server.py +++ b/flexkv/server/server.py @@ -354,7 +354,12 @@ def _handle_prefetch_request(self, req: PrefetchRequest) -> None: def _handle_launch_task_request(self, req: LaunchTaskRequest) -> None: """Handle LaunchTask request""" - self.kv_task_engine.launch_tasks(req.task_ids, req.slot_mappings, req.as_batch, req.batch_id) + self.kv_task_engine.launch_tasks(req.task_ids, + req.slot_mappings, + req.as_batch, + req.batch_id, + req.layerwise_transfer, + req.counter_id) def _handle_cancel_task_request(self, req: CancelTaskRequest) -> None: """Handle CancelTask request""" diff --git a/flexkv/transfer/layerwise.py b/flexkv/transfer/layerwise.py new file mode 100644 index 0000000000..12c63224d5 --- /dev/null +++ b/flexkv/transfer/layerwise.py @@ -0,0 +1,305 @@ +import copy +import torch.multiprocessing as mp +import threading +import time +import os +import socket +import struct +from abc import ABC, abstractmethod +from dataclasses import dataclass +from torch.multiprocessing import Queue as MPQueue, Pipe as MPPipe +from multiprocessing.connection import Connection +from threading import Thread +from typing import List, Any, Dict, Union, Optional, Tuple + +import ctypes +import numpy as np +import nvtx +import torch + +from flexkv import c_ext + +from flexkv.c_ext import LayerwiseTransferGroup +from flexkv.common.debug import flexkv_logger +from flexkv.common.memory_handle import TensorSharedHandle +from flexkv.common.storage import KVCacheLayout, KVCacheLayoutType +from flexkv.common.transfer import TransferOp, TransferType, PartitionBlockType +from flexkv.common.transfer import get_nvtx_range_color +from flexkv.common.config import CacheConfig, GLOBAL_CONFIG_FROM_ENV + +from flexkv.transfer.worker_op import WorkerTransferOp, WorkerLayerwiseTransferOp +from flexkv.transfer.worker import TransferWorkerBase, cudaHostRegister + + +def _recv_fds(sock: socket.socket, num_fds: int) -> Tuple[List[int], bytes]: + """Receive multiple fds + extra_data via Unix domain socket (SCM_RIGHTS).""" + data_buf = bytearray(256) + anc_buf_size = socket.CMSG_SPACE(num_fds * struct.calcsize("i")) + + nbytes, ancdata, flags, addr = sock.recvmsg_into([data_buf], anc_buf_size, anc_buf_size) + data = bytes(data_buf[:nbytes]) + + fds = [] + for level, ctype, cdata in ancdata: + if level == socket.SOL_SOCKET and ctype == socket.SCM_RIGHTS: + num_received = len(cdata) // struct.calcsize("i") + fds = list(struct.unpack(f"{num_received}i", cdata[:num_received * struct.calcsize("i")])) + break + if not fds: + raise RuntimeError("did not receive fds via SCM_RIGHTS") + return fds, data + +class LayerwiseTransferWorker(TransferWorkerBase): + def __init__(self, + worker_id: int, + transfer_conn: Connection, + finished_ops_queue: MPQueue, + op_buffer_tensor: torch.Tensor, + gpu_blocks: List[List[TensorSharedHandle]], + cpu_blocks: torch.Tensor, + ssd_files: Dict[int, List[str]], + gpu_kv_layouts: List[KVCacheLayout], + cpu_kv_layout: KVCacheLayout, + ssd_kv_layout: KVCacheLayout, + dtype: torch.dtype, + tp_group_size: int, + dp_group_id: int, + num_blocks_per_file: int, + use_ce_transfer_h2d: bool = False, + use_ce_transfer_d2h: bool = False, + transfer_sms_h2d: int = 8, + transfer_sms_d2h: int = 8) -> None: + super().__init__(worker_id, transfer_conn, finished_ops_queue, op_buffer_tensor) + assert len(gpu_blocks) == tp_group_size, f"len(gpu_blocks) = {len(gpu_blocks)}, tp_group_size = {tp_group_size}" + imported_gpu_blocks = [] + for handles_in_one_gpu in gpu_blocks: + blocks_in_one_gpu = [] + for handle in handles_in_one_gpu: + blocks_in_one_gpu.append(handle.get_tensor()) + imported_gpu_blocks.append(blocks_in_one_gpu) + self.gpu_blocks = imported_gpu_blocks + self.dtype = dtype # note this should be quantized data type + self.is_mla = gpu_kv_layouts[0].is_mla + + self.num_gpus = len(self.gpu_blocks) + self.tp_group_size = tp_group_size + self.dp_group_id = dp_group_id + + # initialize GPU storage + self.num_layers = gpu_kv_layouts[0].num_layer + # here the chunk size doesn't include the layer info + self.gpu_chunk_sizes_in_bytes = [gpu_kv_layout.get_chunk_size() * self.dtype.itemsize \ + for gpu_kv_layout in gpu_kv_layouts] + self.gpu_kv_strides_in_bytes = [gpu_kv_layout.get_kv_stride() * self.dtype.itemsize \ + for gpu_kv_layout in gpu_kv_layouts] + self.gpu_block_strides_in_bytes = [gpu_kv_layout.get_block_stride() * self.dtype.itemsize \ + for gpu_kv_layout in gpu_kv_layouts] + self.gpu_layer_strides_in_bytes = [gpu_kv_layout.get_layer_stride() * self.dtype.itemsize \ + for gpu_kv_layout in gpu_kv_layouts] + + num_blocks_first_gpu = len(imported_gpu_blocks[0]) if imported_gpu_blocks else 0 + if num_blocks_first_gpu == 1: + self.gpu_block_type_ = 1 # TRTLLM + elif num_blocks_first_gpu == self.num_layers: + self.gpu_block_type_ = 0 # VLLM + elif num_blocks_first_gpu == self.num_layers * 2: + self.gpu_block_type_ = 2 # SGLANG + else: + raise ValueError(f"Invalid GPU block type: {num_blocks_first_gpu}") + + # initialize CPU storage + flexkv_logger.info(f"Pinning CPU Memory: " + f"{cpu_blocks.numel() * cpu_blocks.element_size() / (1024 ** 3):.2f} GB") + cudaHostRegister(cpu_blocks) + self.cpu_blocks = cpu_blocks + + self.cpu_chunk_size_in_bytes = cpu_kv_layout.get_chunk_size() * self.dtype.itemsize + self.cpu_kv_stride_in_bytes = cpu_kv_layout.get_kv_stride() * self.dtype.itemsize + self.cpu_block_stride_in_bytes = cpu_kv_layout.get_block_stride() * self.dtype.itemsize + self.cpu_layer_stride_in_bytes = cpu_kv_layout.get_layer_stride() * self.dtype.itemsize + + self.use_ce_transfer_h2d = use_ce_transfer_h2d + self.use_ce_transfer_d2h = use_ce_transfer_d2h + self.transfer_sms_h2d = transfer_sms_h2d + self.transfer_sms_d2h = transfer_sms_d2h + + # initialize SSD storage + self.enable_ssd = len(ssd_files) > 0 + self.ssd_files = ssd_files + if self.enable_ssd: + self.num_blocks_per_file = num_blocks_per_file + self.num_files = sum(len(file_list) for file_list in ssd_files.values()) + self.round_robin = 1 + + ssd_kv_layout_per_file = ssd_kv_layout.div_block(self.num_files, padding=True) + self.ssd_kv_stride_in_bytes = ssd_kv_layout_per_file.get_kv_stride() * self.dtype.itemsize + self.ssd_layer_stride_in_bytes = ssd_kv_layout_per_file.get_layer_stride() * self.dtype.itemsize + self.ssd_block_stride_in_bytes = ssd_kv_layout_per_file.get_block_stride() * self.dtype.itemsize + else: + self.num_blocks_per_file = 0 + self.round_robin = 1 + self.ssd_kv_stride_in_bytes = 0 + self.ssd_layer_stride_in_bytes = 0 + self.ssd_block_stride_in_bytes = 0 + + gpu_kv_strides_tensor = torch.tensor(self.gpu_kv_strides_in_bytes, dtype=torch.int64) + gpu_block_strides_tensor = torch.tensor(self.gpu_block_strides_in_bytes, dtype=torch.int64) + gpu_chunk_sizes_tensor = torch.tensor(self.gpu_chunk_sizes_in_bytes, dtype=torch.int64) + gpu_layer_strides_tensor = torch.tensor(self.gpu_layer_strides_in_bytes, dtype=torch.int64) + + layer_eventfds_tensor = self._receive_eventfds_from_sglang(tp_group_size) + + # Create LayerwiseTransferGroup which handles both SSD->CPU and CPU->GPU transfers + self.layerwise_transfer_group = LayerwiseTransferGroup( + self.num_gpus, self.gpu_blocks, cpu_blocks, ssd_files, + dp_group_id, self.num_layers, + gpu_kv_strides_tensor, gpu_block_strides_tensor, + gpu_layer_strides_tensor, gpu_chunk_sizes_tensor, + GLOBAL_CONFIG_FROM_ENV.iouring_entries, + GLOBAL_CONFIG_FROM_ENV.iouring_flags, + layer_eventfds_tensor, tp_group_size) + + def _receive_eventfds_from_sglang(self, tp_group_size: int, + max_retries: int = 180, + retry_interval: float = 1.0) -> torch.Tensor: + """Receive eventfds from SGLang via Unix socket (FlexKV as server).""" + socket_path = os.environ.get('FLEXKV_LAYERWISE_EVENTFD_SOCKET', '/tmp/flexkv_layerwise_eventfd.sock') + + def cleanup_socket(): + try: + if os.path.exists(socket_path): + os.unlink(socket_path) + except OSError: + pass + + cleanup_socket() + server_sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + server_sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + + try: + server_sock.bind(socket_path) + server_sock.listen(tp_group_size) + os.chmod(socket_path, 0o777) + flexkv_logger.info(f"[LayerwiseWorker] Listening on {socket_path}, waiting for {tp_group_size} connections") + except Exception as e: + flexkv_logger.error(f"[LayerwiseWorker] Failed to bind/listen: {e}") + server_sock.close() + return torch.empty(0, dtype=torch.int32) + + server_sock.settimeout(max_retries * retry_interval) + all_rank_eventfds: Dict[int, Dict[int, List[int]]] = {} + num_layers, num_counters = self.num_layers, 3 + + try: + for conn_idx in range(tp_group_size): + try: + conn, _ = server_sock.accept() + except socket.timeout: + flexkv_logger.warning(f"[LayerwiseWorker] Timeout, received {conn_idx}/{tp_group_size}") + break + + with conn: + metadata = conn.recv(16) + if len(metadata) < 16: + flexkv_logger.error(f"[LayerwiseWorker] Incomplete metadata: {len(metadata)} bytes") + continue + + tp_rank, _, recv_num_layers, recv_num_counters = struct.unpack("iiii", metadata) + if conn_idx == 0: + num_layers, num_counters = recv_num_layers, recv_num_counters + + rank_eventfds = {} + for _ in range(recv_num_counters): + fds, extra_data = _recv_fds(conn, recv_num_layers) + counter_id = struct.unpack("i", extra_data[:4])[0] + rank_eventfds[counter_id] = fds + + all_rank_eventfds[tp_rank] = rank_eventfds + flexkv_logger.info(f"[LayerwiseWorker] Received eventfds from tp_rank={tp_rank}") + except Exception as e: + flexkv_logger.error(f"[LayerwiseWorker] Error in accept loop: {e}") + finally: + server_sock.close() + cleanup_socket() + + if not all_rank_eventfds: + flexkv_logger.warning("[LayerwiseWorker] No connections received") + return torch.empty(0, dtype=torch.int32) + + # Build tensor: [num_counters, tp_size, num_layers] + eventfds_list = [] + for counter_id in range(num_counters): + for tp_rank in range(tp_group_size): + fds = all_rank_eventfds.get(tp_rank, {}).get(counter_id, [-1] * num_layers) + eventfds_list.extend(fds) + + tensor = torch.tensor(eventfds_list, dtype=torch.int32) + flexkv_logger.info(f"[LayerwiseWorker] Eventfds tensor: {tensor.shape}, counters={num_counters}, tp={tp_group_size}, layers={num_layers}") + return tensor + + def _transfer_impl(self, + src_block_ids_h2d: torch.Tensor, + dst_block_ids_h2d: torch.Tensor, + src_block_ids_disk2h: Optional[torch.Tensor], + dst_block_ids_disk2h: Optional[torch.Tensor], + layer_granularity: int, + counter_id: int = 0, + **kwargs: Any) -> None: + assert src_block_ids_h2d.dtype == torch.int64 + assert dst_block_ids_h2d.dtype == torch.int64 + assert len(src_block_ids_h2d) == len(dst_block_ids_h2d) + if src_block_ids_disk2h is not None: + assert src_block_ids_disk2h.dtype == torch.int64 + assert dst_block_ids_disk2h.dtype == torch.int64 + assert len(src_block_ids_disk2h) == len(dst_block_ids_disk2h) + + # Use unified layerwise transfer C++ interface + ssd_block_ids = src_block_ids_disk2h if src_block_ids_disk2h is not None else torch.empty(0, dtype=torch.int64) + cpu_block_ids_d2h = dst_block_ids_disk2h if dst_block_ids_disk2h is not None \ + else torch.empty(0, dtype=torch.int64) + + self.layerwise_transfer_group.layerwise_transfer( + ssd_block_ids, + cpu_block_ids_d2h, + self.ssd_layer_stride_in_bytes, + self.ssd_kv_stride_in_bytes, + self.num_blocks_per_file, + self.round_robin, + 32, # num_threads_per_device + dst_block_ids_h2d, + src_block_ids_h2d, + self.cpu_kv_stride_in_bytes, + self.cpu_layer_stride_in_bytes, + self.cpu_block_stride_in_bytes, + self.cpu_chunk_size_in_bytes, + self.transfer_sms_h2d, + self.use_ce_transfer_h2d, + self.num_layers, + layer_granularity, + self.is_mla, + counter_id, + ) + + def launch_transfer(self, transfer_op: WorkerLayerwiseTransferOp) -> None: + layer_granularity = transfer_op.layer_granularity + if layer_granularity == -1: + layer_granularity = self.num_layers + + src_block_ids_h2d = torch.from_numpy(transfer_op.src_block_ids_h2d).to(dtype=torch.int64).pin_memory() + dst_block_ids_h2d = torch.from_numpy(transfer_op.dst_block_ids_h2d).to(dtype=torch.int64).pin_memory() + + if transfer_op.src_block_ids_disk2h.size > 0: + src_block_ids_disk2h = torch.from_numpy(transfer_op.src_block_ids_disk2h).to(dtype=torch.int64) + dst_block_ids_disk2h = torch.from_numpy(transfer_op.dst_block_ids_disk2h).to(dtype=torch.int64) + else: + src_block_ids_disk2h = None + dst_block_ids_disk2h = None + + self._transfer_impl( + src_block_ids_h2d, + dst_block_ids_h2d, + src_block_ids_disk2h, + dst_block_ids_disk2h, + layer_granularity, + transfer_op.counter_id, + ) diff --git a/flexkv/transfer/transfer_engine.py b/flexkv/transfer/transfer_engine.py index 3efe771be1..369822fd2c 100644 --- a/flexkv/transfer/transfer_engine.py +++ b/flexkv/transfer/transfer_engine.py @@ -40,11 +40,14 @@ GDSTransferWorker, tpGDSTransferWorker, ) +from flexkv.transfer.layerwise import LayerwiseTransferWorker from flexkv.common.config import CacheConfig, ModelConfig, GLOBAL_CONFIG_FROM_ENV from flexkv.common.ring_buffer import SharedOpPool def register_op_to_buffer(op: TransferOp, pin_buffer: SharedOpPool) -> None: + if op.transfer_type == TransferType.LAYERWISE: + return op.src_slot_id = pin_buffer.allocate_slot(op.src_block_ids) op.dst_slot_id = pin_buffer.allocate_slot(op.dst_block_ids) @@ -245,6 +248,35 @@ def _init_workers(self) -> None: # GDS workers handle DISK2D/D2DISK operations using the GDS transfer path self._worker_map[TransferType.DISK2D] = self.gds_workers self._worker_map[TransferType.D2DISK] = self.gds_workers + if GLOBAL_CONFIG_FROM_ENV.enable_layerwise_transfer: + ssd_files = {} if self._ssd_handle is None else self._ssd_handle.get_file_list() + ssd_kv_layout = None if self._ssd_handle is None else self._ssd_handle.kv_layout + num_blocks_per_file = 0 if self._ssd_handle is None else self._ssd_handle.num_blocks_per_file + self.layerwise_workers = [ + LayerwiseTransferWorker.create_worker( + mp_ctx=self.mp_ctx, + finished_ops_queue=self.finished_ops_queue, + op_buffer_tensor=self.pin_buffer.get_buffer(), + gpu_blocks=[self.gpu_handles[j].get_tensor_handle_list() \ + for j in range(i * self.tp_size, (i + 1) * self.tp_size)], + cpu_blocks=self._cpu_handle.get_tensor(), + ssd_files=ssd_files, + gpu_kv_layouts=[self.gpu_handles[i].kv_layout \ + for i in range(i * self.tp_size, (i + 1) * self.tp_size)], + cpu_kv_layout=self._cpu_handle.kv_layout, + ssd_kv_layout=ssd_kv_layout, + dtype=self.gpu_handles[i].dtype, + tp_group_size=self.tp_size, + dp_group_id=i, + num_blocks_per_file=num_blocks_per_file, + use_ce_transfer_h2d=GLOBAL_CONFIG_FROM_ENV.use_ce_transfer_h2d, + use_ce_transfer_d2h=GLOBAL_CONFIG_FROM_ENV.use_ce_transfer_d2h, + transfer_sms_h2d=GLOBAL_CONFIG_FROM_ENV.transfer_sms_h2d, + transfer_sms_d2h=GLOBAL_CONFIG_FROM_ENV.transfer_sms_d2h, + ) + for i in range(self.dp_size) + ] + self._worker_map[TransferType.LAYERWISE] = self.layerwise_workers if len(self._worker_map) == 0: raise ValueError("No workers initialized, please check the config") # Wait for all workers to ready diff --git a/flexkv/transfer/worker.py b/flexkv/transfer/worker.py index f4d695ec1d..3230c16c83 100644 --- a/flexkv/transfer/worker.py +++ b/flexkv/transfer/worker.py @@ -29,8 +29,9 @@ from flexkv.common.memory_handle import TensorSharedHandle from flexkv.common.storage import KVCacheLayout, KVCacheLayoutType from flexkv.common.transfer import TransferOp, TransferType, PartitionBlockType -from flexkv.common.transfer import get_nvtx_range_color +from flexkv.common.transfer import get_nvtx_range_color, LayerwiseTransferOp from flexkv.common.config import CacheConfig, GLOBAL_CONFIG_FROM_ENV +from flexkv.transfer.worker_op import WorkerTransferOp, WorkerLayerwiseTransferOp try: from flexkv.c_ext import transfer_kv_blocks_remote @@ -54,36 +55,6 @@ def cudaHostUnregister(tensor: torch.Tensor) -> None: size = tensor.numel() * tensor.element_size() ret = cudart.cudaHostUnregister(ctypes.c_void_p(ptr)) -@dataclass -class WorkerTransferOp: - transfer_op_id: int - transfer_graph_id: int - transfer_type: TransferType - layer_id: int - layer_granularity: int - src_slot_id: int - dst_slot_id: int - valid_block_num: int - src_block_ids: np.ndarray - dst_block_ids: np.ndarray - # successors: List[int] - - def __init__(self, transfer_op: TransferOp): - self.transfer_op_id = transfer_op.op_id - self.transfer_graph_id = transfer_op.graph_id - self.transfer_type = transfer_op.transfer_type - self.layer_id = transfer_op.layer_id - self.layer_granularity = transfer_op.layer_granularity - self.src_slot_id = transfer_op.src_slot_id - self.dst_slot_id = transfer_op.dst_slot_id - self.valid_block_num = transfer_op.valid_block_num - if self.src_slot_id == -1: - self.src_block_ids = transfer_op.src_block_ids - self.dst_block_ids = transfer_op.dst_block_ids - else: - self.src_block_ids = np.empty(0) - self.dst_block_ids = np.empty(0) - # self.successors = list(transfer_op.successors) # for nvtx class TransferWorkerBase(ABC): _worker_id_counter = 0 @@ -227,14 +198,13 @@ def run(self) -> None: for op in batch_ops: try: nvtx.push_range(f"launch {op.transfer_type.name} op_id: {op.transfer_op_id}, " - f"graph_id: {op.transfer_graph_id}, " - f"num_blocks: {op.valid_block_num}", + f"graph_id: {op.transfer_graph_id}", color=get_nvtx_range_color(op.transfer_graph_id)) self.launch_transfer(op) nvtx.pop_range() except Exception as e: flexkv_logger.error(f"Error launching transfer: {e}\n" - f"Failed transfer op: {op}") + f"Failed transfer op: {op.transfer_op_id}") self.finished_ops_queue.put(op.transfer_op_id) else: continue @@ -253,8 +223,12 @@ def __init__(self, worker_id: int, transfer_conn: Connection, process: mp.Proces self.process = process self.ready_event = ready_event - def submit_transfer(self, op: TransferOp) -> None: - self.transfer_conn.send(WorkerTransferOp(op)) + def submit_transfer(self, op: Union[TransferOp, LayerwiseTransferOp]) -> None: + if isinstance(op, LayerwiseTransferOp): + worker_op = WorkerLayerwiseTransferOp(op) + else: + worker_op = WorkerTransferOp(op) + self.transfer_conn.send(worker_op) def shutdown(self) -> None: try: diff --git a/flexkv/transfer/worker_op.py b/flexkv/transfer/worker_op.py new file mode 100644 index 0000000000..329e0a21ae --- /dev/null +++ b/flexkv/transfer/worker_op.py @@ -0,0 +1,62 @@ +from dataclasses import dataclass + +import numpy as np + +from flexkv.common.transfer import TransferOp, TransferType, LayerwiseTransferOp + + +@dataclass +class WorkerTransferOp: + transfer_op_id: int + transfer_graph_id: int + transfer_type: TransferType + layer_id: int + layer_granularity: int + src_slot_id: int + dst_slot_id: int + valid_block_num: int + src_block_ids: np.ndarray + dst_block_ids: np.ndarray + + def __init__(self, transfer_op: TransferOp): + self.transfer_op_id = transfer_op.op_id + self.transfer_graph_id = transfer_op.graph_id + self.transfer_type = transfer_op.transfer_type + self.layer_id = transfer_op.layer_id + self.layer_granularity = transfer_op.layer_granularity + self.src_slot_id = transfer_op.src_slot_id + self.dst_slot_id = transfer_op.dst_slot_id + self.valid_block_num = transfer_op.valid_block_num + if self.src_slot_id == -1: + self.src_block_ids = transfer_op.src_block_ids + self.dst_block_ids = transfer_op.dst_block_ids + else: + self.src_block_ids = np.empty(0) + self.dst_block_ids = np.empty(0) + + +@dataclass +class WorkerLayerwiseTransferOp: + transfer_op_id: int + transfer_graph_id: int + transfer_type: TransferType + layer_id: int + layer_granularity: int + src_block_ids_h2d: np.ndarray + dst_block_ids_h2d: np.ndarray + src_block_ids_disk2h: np.ndarray + dst_block_ids_disk2h: np.ndarray + counter_id: int # Counter set index for triple buffering eventfd notification + + def __init__(self, transfer_op: LayerwiseTransferOp): + self.transfer_op_id = transfer_op.op_id + self.transfer_graph_id = transfer_op.graph_id + assert transfer_op.transfer_type == TransferType.LAYERWISE + self.transfer_type = transfer_op.transfer_type + self.layer_id = transfer_op.layer_id + self.layer_granularity = transfer_op.layer_granularity + self.src_block_ids_h2d = transfer_op.src_block_ids_h2d + self.dst_block_ids_h2d = transfer_op.dst_block_ids_h2d + self.src_block_ids_disk2h = transfer_op.src_block_ids_disk2h + self.dst_block_ids_disk2h = transfer_op.dst_block_ids_disk2h + self.counter_id = transfer_op.counter_id diff --git a/setup.py b/setup.py index d99ef02f1a..4b5a90944b 100755 --- a/setup.py +++ b/setup.py @@ -30,6 +30,7 @@ def get_version(): "csrc/tp_transfer_thread_group.cpp", "csrc/transfer_ssd.cpp", "csrc/radix_tree.cpp", + "csrc/layerwise.cpp" ] hpp_sources = [ @@ -37,6 +38,7 @@ def get_version(): "csrc/tp_transfer_thread_group.h", "csrc/transfer_ssd.h", "csrc/radix_tree.h", + "csrc/layerwise.h", ] extra_link_args = ["-lcuda", "-lxxhash", "-lpthread", "-lrt", "-luring"] @@ -168,4 +170,3 @@ def copy_shared_libraries(self): }, python_requires=">=3.8", ) - diff --git a/tests/test_kvmanager.py b/tests/test_kvmanager.py index 3f2e29cef8..55e7bb456e 100644 --- a/tests/test_kvmanager.py +++ b/tests/test_kvmanager.py @@ -333,19 +333,19 @@ def test_kvmanager(model_config, cache_config, test_config, gpu_layout_type): # =============== Test batched launched get =============== if not enable_gds: print("\n========== Testing batched launched get ==========") - + # Use the first few request_pairs that were written in initial phase batch_size = 6 - + batched_get_task_ids = [] batched_slot_mappings = [] batched_req_info = [] # Store (token_ids, block_ids) for verification - + # Create multiple get_match requests for i in range(batch_size): token_ids, block_ids, dp_id = request_pairs[random.randint(0, num_requests - 1)] slot_mapping = block_ids_2_slot_mapping(block_ids, tokens_per_block) - + request_id, return_mask = kvmanager.get_match( token_ids=token_ids, layer_granularity=-1, @@ -356,7 +356,7 @@ def test_kvmanager(model_config, cache_config, test_config, gpu_layout_type): batched_slot_mappings.append(slot_mapping) batched_req_info.append((token_ids, block_ids, request_id)) print(f"Created get_match request {request_id} for request_pair[{i}]") - + # Launch all get requests as a batch print(f"Launching {len(batched_get_task_ids)} get requests as batch...") batch_id = kvmanager.launch( @@ -365,12 +365,12 @@ def test_kvmanager(model_config, cache_config, test_config, gpu_layout_type): as_batch=True )[0] print(f"Returned task_ids after batch launch: {batch_id}") - + # Wait for the batched get to complete # When as_batch=True, launch returns [batch_id], we need to wait on batch_id batch_results = kvmanager.wait(batch_id, completely=True) print(f"Batch wait returned {len(batch_results)} results") - + # Verify results batched_cache_hit = 0 batched_cache_miss = 0 @@ -381,7 +381,7 @@ def test_kvmanager(model_config, cache_config, test_config, gpu_layout_type): batched_cache_hit += return_mask.sum().item() batched_cache_miss += len(return_mask) - return_mask.sum().item() print(f"Task {batch_id}: cache_hit={batched_cache_hit}, cache_miss={batched_cache_miss}") - + # GPU KV cache verification for batched get if gpu_kv_verifier is not None: for idx, (token_ids, block_ids, req_id) in enumerate(batched_req_info): @@ -395,9 +395,9 @@ def test_kvmanager(model_config, cache_config, test_config, gpu_layout_type): token_ids[:valid_fetched_tokens], block_ids[:valid_fetched_tokens // tokens_per_block] ) - + print(f"Batched get test completed: hit={batched_cache_hit}, miss={batched_cache_miss}") - + # Since we read data that was written before, cache hit should be high if enable_cpu and num_cpu_blocks >= num_gpu_blocks: assert batched_cache_miss == 0, \