From 5b6141c75fc8a0445829b1e7674855f83e0356b2 Mon Sep 17 00:00:00 2001 From: Cruz Zhao Date: Thu, 11 Dec 2025 22:30:24 +0800 Subject: [PATCH 01/24] add get_tensor_into() and batch_get_tensor_into() Signed-off-by: Cruz Zhao --- mooncake-integration/store/store_py.cpp | 179 ++++++++++++++++++++++++ mooncake-store/include/dummy_client.h | 10 +- mooncake-store/include/pyclient.h | 8 ++ mooncake-store/src/dummy_client.cpp | 25 +++- 4 files changed, 216 insertions(+), 6 deletions(-) diff --git a/mooncake-integration/store/store_py.cpp b/mooncake-integration/store/store_py.cpp index 4b0aaf28b..cdc130839 100644 --- a/mooncake-integration/store/store_py.cpp +++ b/mooncake-integration/store/store_py.cpp @@ -320,6 +320,176 @@ class MooncakeStorePyWrapper { return results_list; } + int64_t get_tensor_into(const std::string &key, uintptr_t buffer_ptr, + size_t size) { + void *buffer = reinterpret_cast(buffer_ptr); + if (!is_client_initialized()) { + LOG(ERROR) << "Client is not initialized"; + return to_py_ret(ErrorCode::INVALID_PARAMS); + } + + if (use_dummy_client_) { + LOG(ERROR) << "get_tensor is not supported for dummy client now"; + return to_py_ret(ErrorCode::INVALID_PARAMS); + } + + try { + // Section with GIL released + py::gil_scoped_release release_gil; + auto total_length = store_->get_into_internal(key, buffer, size); + if (!total_length.has_value()) { + py::gil_scoped_acquire acquire_gil; + return to_py_ret(ErrorCode::INVALID_PARAMS); + } + + TensorMetadata metadata; + // Copy data from buffer to contiguous memory + memcpy(&metadata, static_cast(buffer), + sizeof(TensorMetadata)); + + if (metadata.ndim < 0 || metadata.ndim > 4) { + py::gil_scoped_acquire acquire_gil; + LOG(ERROR) << "Invalid tensor metadata: ndim=" << metadata.ndim; + return to_py_ret(ErrorCode::INVALID_PARAMS); + } + + TensorDtype dtype_enum = static_cast(metadata.dtype); + if (dtype_enum == TensorDtype::UNKNOWN) { + py::gil_scoped_acquire acquire_gil; + LOG(ERROR) << "Unknown tensor dtype!"; + return to_py_ret(ErrorCode::INVALID_PARAMS); + } + + size_t tensor_size = total_length.value() - sizeof(TensorMetadata); + if (tensor_size == 0) { + py::gil_scoped_acquire acquire_gil; + LOG(ERROR) << "Invalid data format: no tensor data found"; + return to_py_ret(ErrorCode::INVALID_PARAMS); + } + + py::gil_scoped_acquire acquire_gil; + // Convert bytes to tensor using torch.from_numpy + pybind11::object np_array; + int dtype_index = static_cast(dtype_enum); + if (dtype_index < 0 || + dtype_index >= static_cast(array_creators.size())) { + LOG(ERROR) << "Unsupported dtype enum: " << dtype_index; + return to_py_ret(ErrorCode::INVALID_PARAMS); + } + + return total_length.value(); + + } catch (const pybind11::error_already_set &e) { + LOG(ERROR) << "Failed to get tensor data: " << e.what(); + return to_py_ret(ErrorCode::INVALID_PARAMS); + } + } + + pybind11::list batch_get_tensor_into(const std::vector &keys, + const std::vector &buffer_ptrs, + const std::vector &sizes) { + std::vector buffers; + buffers.reserve(buffer_ptrs.size()); + for (uintptr_t ptr : buffer_ptrs) { + buffers.push_back(reinterpret_cast(ptr)); + } + + if (!is_client_initialized()) { + LOG(ERROR) << "Client is not initialized"; + py::list empty_list; + for (size_t i = 0; i < keys.size(); ++i) { + empty_list.append(to_py_ret(ErrorCode::INVALID_PARAMS)); + } + return empty_list; + } + + if (use_dummy_client_) { + LOG(ERROR) << "batch_get_tensor is not supported for dummy client " + "now"; + py::list empty_list; + for (size_t i = 0; i < keys.size(); ++i) { + 
empty_list.append(to_py_ret(ErrorCode::INVALID_PARAMS)); + } + return empty_list; + } + + // Phase 1: Batch Get Buffers (GIL Released) + py::gil_scoped_release release_gil; + // This internal call already handles logging for query failures + auto total_lengths = + store_->batch_get_into_internal(keys, buffers, sizes); + + py::list results_list; + try { + py::gil_scoped_acquire acquire_gil; + auto torch = torch_module(); + + for (size_t i = 0; i < total_lengths.size(); i++) { + const auto &buffer = buffers[i]; + if (!buffer) { + results_list.append(to_py_ret(ErrorCode::INVALID_PARAMS)); + continue; + } + + auto total_length = total_lengths[i]; + if (!total_length.has_value()) { + LOG(ERROR) << "Invalid data format: insufficient data for" + "metadata"; + results_list.append(to_py_ret(ErrorCode::INVALID_PARAMS)); + continue; + } + if (total_length.value() <= + static_cast(sizeof(TensorMetadata))) { + LOG(ERROR) << "Invalid data format: insufficient data for " + "metadata"; + results_list.append(to_py_ret(ErrorCode::INVALID_PARAMS)); + continue; + } + + TensorMetadata metadata; + memcpy(&metadata, static_cast(buffer), + sizeof(TensorMetadata)); + + if (metadata.ndim < 0 || metadata.ndim > 4) { + LOG(ERROR) + << "Invalid tensor metadata: ndim=" << metadata.ndim; + results_list.append(to_py_ret(ErrorCode::INVALID_PARAMS)); + continue; + } + + TensorDtype dtype_enum = + static_cast(metadata.dtype); + if (dtype_enum == TensorDtype::UNKNOWN) { + LOG(ERROR) << "Unknown tensor dtype!"; + results_list.append(to_py_ret(ErrorCode::INVALID_PARAMS)); + continue; + } + + size_t tensor_size = + total_length.value() - sizeof(TensorMetadata); + if (tensor_size == 0) { + LOG(ERROR) << "Invalid data format: no tensor data found"; + results_list.append(to_py_ret(ErrorCode::INVALID_PARAMS)); + continue; + } + + int dtype_index = static_cast(dtype_enum); + if (dtype_index < 0 || + dtype_index >= static_cast(array_creators.size())) { + LOG(ERROR) << "Unsupported dtype enum: " << dtype_index; + results_list.append(to_py_ret(ErrorCode::INVALID_PARAMS)); + continue; + } + + results_list.append(total_length.value()); + } + } catch (const pybind11::error_already_set &e) { + LOG(ERROR) << "Failed during batch tensor deserialization: " + << e.what(); + } + return results_list; + } + int put_tensor_impl(const std::string &key, pybind11::object tensor, const ReplicateConfig &config) { // Validation & Metadata extraction (GIL Held) @@ -927,6 +1097,15 @@ PYBIND11_MODULE(store, m) { .def("pub_tensor", &MooncakeStorePyWrapper::pub_tensor, py::arg("key"), py::arg("tensor"), py::arg("config") = ReplicateConfig{}, "Publish a PyTorch tensor with configurable replication settings") + .def("get_tensor_into", &MooncakeStorePyWrapper::get_tensor_into, + py::arg("key"), py::arg("buffer_ptr"), py::arg("size"), + "Get tensor directly into a pre-allocated buffer") + .def("batch_get_tensor_into", + &MooncakeStorePyWrapper::batch_get_tensor_into, py::arg("keys"), + py::arg("buffer_ptrs"), py::arg("sizes"), + "Get tensors directly into pre-allocated buffers for " + "multiple " + "keys") .def( "register_buffer", [](MooncakeStorePyWrapper &self, uintptr_t buffer_ptr, diff --git a/mooncake-store/include/dummy_client.h b/mooncake-store/include/dummy_client.h index 7f3a56986..fc83035d0 100644 --- a/mooncake-store/include/dummy_client.h +++ b/mooncake-store/include/dummy_client.h @@ -82,8 +82,16 @@ class DummyClient : public PyClient { int unregister_buffer(void *buffer); + tl::expected get_into_internal(const std::string &key, + void *buffer, + 
size_t size); + int64_t get_into(const std::string &key, void *buffer, size_t size); + std::vector> batch_get_into_internal( + const std::vector &keys, + const std::vector &buffers, const std::vector &sizes); + std::vector batch_get_into(const std::vector &keys, const std::vector &buffers, const std::vector &sizes); @@ -228,4 +236,4 @@ class DummyClient : public PyClient { volatile bool connected_ = false; }; -} // namespace mooncake \ No newline at end of file +} // namespace mooncake diff --git a/mooncake-store/include/pyclient.h b/mooncake-store/include/pyclient.h index 0ec6e56fa..542c92209 100644 --- a/mooncake-store/include/pyclient.h +++ b/mooncake-store/include/pyclient.h @@ -54,9 +54,17 @@ class PyClient { virtual int unregister_buffer(void *buffer) = 0; + virtual tl::expected get_into_internal( + const std::string &key, void *buffer, size_t size) = 0; + virtual int64_t get_into(const std::string &key, void *buffer, size_t size) = 0; + virtual std::vector> + batch_get_into_internal(const std::vector &keys, + const std::vector &buffers, + const std::vector &sizes) = 0; + virtual std::vector batch_get_into( const std::vector &keys, const std::vector &buffers, diff --git a/mooncake-store/src/dummy_client.cpp b/mooncake-store/src/dummy_client.cpp index b7df019a9..6b6cbec9f 100644 --- a/mooncake-store/src/dummy_client.cpp +++ b/mooncake-store/src/dummy_client.cpp @@ -563,6 +563,12 @@ std::vector> DummyClient::batch_get_buffer( return std::vector>(); } +tl::expected DummyClient::get_into_internal( + const std::string& key, void* buffer, size_t size) { + // TODO: implement this function + return tl::unexpected(ErrorCode::INVALID_PARAMS); +} + int64_t DummyClient::get_into(const std::string& key, void* buffer, size_t size) { // TODO: implement this function @@ -600,16 +606,25 @@ int DummyClient::put_from(const std::string& key, void* buffer, size_t size, return -1; } -std::vector DummyClient::batch_get_into( - const std::vector& keys, const std::vector& buffer_ptrs, - const std::vector& sizes) { +std::vector> +DummyClient::batch_get_into_internal(const std::vector& keys, + const std::vector& buffer_ptrs, + const std::vector& sizes) { std::vector buffers; for (auto ptr : buffer_ptrs) { buffers.push_back(reinterpret_cast(ptr)); } - auto internal_results = + auto results = invoke_batch_rpc<&RealClient::batch_get_into_dummy_helper, int64_t>( keys.size(), keys, buffers, sizes, client_id_); + + return results; +} + +std::vector DummyClient::batch_get_into( + const std::vector& keys, const std::vector& buffer_ptrs, + const std::vector& sizes) { + auto internal_results = batch_get_into_internal(keys, buffer_ptrs, sizes); std::vector results; results.reserve(internal_results.size()); @@ -755,4 +770,4 @@ void DummyClient::ping_thread_main() { } } -} // namespace mooncake \ No newline at end of file +} // namespace mooncake From 1cb3fd9ed95d3c26b432b163add472daf2bf9c72 Mon Sep 17 00:00:00 2001 From: Cruz Zhao Date: Sun, 14 Dec 2025 10:24:04 +0800 Subject: [PATCH 02/24] fix coredump Signed-off-by: Cruz Zhao --- mooncake-integration/store/store_py.cpp | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/mooncake-integration/store/store_py.cpp b/mooncake-integration/store/store_py.cpp index cdc130839..3f265c6ac 100644 --- a/mooncake-integration/store/store_py.cpp +++ b/mooncake-integration/store/store_py.cpp @@ -414,10 +414,12 @@ class MooncakeStorePyWrapper { } // Phase 1: Batch Get Buffers (GIL Released) - py::gil_scoped_release release_gil; - // This internal call already 
handles logging for query failures - auto total_lengths = - store_->batch_get_into_internal(keys, buffers, sizes); + std::vector> total_lengths; + { + py::gil_scoped_release release_gil; + // This internal call already handles logging for query failures + total_lengths = store_->batch_get_into_internal(keys, buffers, sizes); + } py::list results_list; try { From 55f2b35daacb872d54276d9b1b8f3ccba757cb88 Mon Sep 17 00:00:00 2001 From: Cruz Zhao Date: Sun, 14 Dec 2025 12:25:00 +0800 Subject: [PATCH 03/24] add test case for batch_get_tensor_into Signed-off-by: Cruz Zhao --- scripts/test_tensor_api.py | 187 ++++++++++++++++++++++++++++++++++++- 1 file changed, 186 insertions(+), 1 deletion(-) diff --git a/scripts/test_tensor_api.py b/scripts/test_tensor_api.py index 61d65337d..9b9ede59d 100644 --- a/scripts/test_tensor_api.py +++ b/scripts/test_tensor_api.py @@ -1,3 +1,4 @@ +import ctypes import os import sys import json @@ -5,6 +6,7 @@ import argparse import unittest import torch +import struct import numpy as np from dataclasses import dataclass from mooncake.store import MooncakeDistributedStore @@ -25,6 +27,120 @@ DEFAULT_MASTER_METRICS_PORT = 9003 DEFAULT_CHECK_SERVER = False +DTYPE_MAP = { + 0: np.float32, # FLOAT32 + 1: np.float64, # FLOAT64 + 2: np.int8, # INT8 + 3: np.uint8, # UINT8 + 4: np.int16, # INT16 + 5: np.uint16, # UINT16 + 6: np.int32, # INT32 + 7: np.uint32, # UINT32 + 8: np.int64, # INT64 + 9: np.uint64, # UINT64 + 10: np.bool_, # BOOL + 11: np.float16, # FLOAT16 + # Note: BFLOAT16 (12), FLOAT8 (13,14), W8A8 (15) not supported in NumPy +} + +def parse_tensor_from_buffer(buffer_view): + """ + 从 memoryview 中解析 tensor。 + 假设格式: [TensorMetadata (40 bytes)][raw data] + TensorMetadata layout (C++): + int32_t dtype; // offset 0 + int32_t ndim; // offset 4 + uint64_t shape[4]; // offsets 8,16,24,32 + """ + if len(buffer_view) < 40: + raise ValueError(f"Buffer too small for TensorMetadata (got {len(buffer_view)} bytes, need >=40)") + + # 解析 metadata + dtype_enum = struct.unpack_from(' 4: + raise ValueError(f"Invalid ndim: {ndim}") + + shape = [] + for i in range(4): + dim = struct.unpack_from(' 0 else () + + # 映射 dtype + if dtype_enum not in DTYPE_MAP: + raise ValueError(f"Unsupported or unknown TensorDtype enum: {dtype_enum}") + np_dtype = DTYPE_MAP[dtype_enum] + + # 计算数据部分 + data_start = 40 + if len(buffer_view) <= data_start: + raise ValueError("No tensor data found after metadata") + + raw_data = buffer_view[data_start:] # memoryview slice → still bytes-like + + # 构造 NumPy 数组(零拷贝) + try: + arr = np.frombuffer(raw_data, dtype=np_dtype) + if arr.size == 0 and np.prod(actual_shape) != 0: + raise ValueError("Data size mismatch") + tensor = torch.from_numpy(arr.reshape(actual_shape)) + return tensor + except Exception as e: + raise ValueError(f"Failed to construct tensor from buffer: {e}") + + +def verify_tensor_equality(original, received, rtol=0, atol=0, verbose=True): + """ + 验证两个张量是否完全一致(逐元素精确比较)。 + """ + def to_numpy(x): + if isinstance(x, torch.Tensor): + if x.is_cuda: + x = x.cpu() + return x.detach().numpy() + elif isinstance(x, np.ndarray): + return x + else: + raise TypeError(f"Unsupported tensor type: {type(x)}") + + try: + orig_np = to_numpy(original) + recv_np = to_numpy(received) + except Exception as e: + if verbose: + print(f"❌ Error converting tensors: {e}") + return False + + if orig_np.shape != recv_np.shape: + if verbose: + print(f"❌ Shape mismatch: original {orig_np.shape} vs received {recv_np.shape}") + return False + + if orig_np.dtype != recv_np.dtype: + if 
verbose: + print(f"❌ Dtype mismatch: original {orig_np.dtype} vs received {recv_np.dtype}") + return False + + if np.array_equal(orig_np, recv_np): +# if verbose: +# print("✅ Tensors are identical!") + return True + else: + diff_mask = orig_np != recv_np + diff_indices = np.where(diff_mask) + if len(diff_indices[0]) > 0: + first_diff_idx = tuple(idx[0] for idx in diff_indices) + orig_val = orig_np[first_diff_idx] + recv_val = recv_np[first_diff_idx] + if verbose: + print(f"❌ Tensors differ at index {first_diff_idx}") + print(f" Original: {orig_val}") + print(f" Received: {recv_val}") + print(f" Difference: {abs(orig_val - recv_val)}") + return False + def parse_global_segment_size(value) -> int: """Parse human-readable size strings (e.g., '4GB') into bytes.""" if isinstance(value, int): return value @@ -311,6 +427,75 @@ def test_benchmark_02_tp_batch(self): self._print_perf(f"TP Batch Put (TP={tp_size})", put_times) self._print_perf(f"TP Batch Get (TP={tp_size})", get_times) + def test_benchmark_03_batch_put_get_into(self): + """Benchmark: Standard Batch Put/Get.""" + buffer_spacing = 1024 * 1024 * 1024 # 1GB per tensor slot + batch_size = len(self.keys) + total_buffer_size = buffer_spacing * batch_size + + # Allocate large contiguous buffer + large_buffer = (ctypes.c_ubyte * total_buffer_size)() + large_buffer_ptr = ctypes.addressof(large_buffer) + + # Prepare pointers (only addresses, no ctypes array objects) + buffer_ptrs = [] + buffer_sizes = [] + for i in range(batch_size): + offset = i * buffer_spacing + ptr = large_buffer_ptr + offset + buffer_ptrs.append(ptr) + buffer_sizes.append(buffer_spacing) + + # Register the entire buffer with the store + res = self.store.register_buffer(large_buffer_ptr, total_buffer_size) + self.assertEqual(res, 0, "Buffer registration should succeed") + + print(f"--- Running Standard Batch Benchmark ({self.BENCH_ITERATIONS} iters) ---") + put_times = [] + get_times = [] + + for i in range(self.BENCH_ITERATIONS): + self.store.remove_all() + + # Measure Put + t0 = time.perf_counter() + self.store.batch_put_tensor(self.keys, self.tensors) + put_times.append(time.perf_counter() - t0) + + # Measure Get + t0 = time.perf_counter() + bytes_read_list = self.store.batch_get_tensor_into(self.keys, buffer_ptrs, buffer_sizes) + get_times.append(time.perf_counter() - t0) + + # Validate results + self.assertEqual(len(bytes_read_list), batch_size) + for j in range(batch_size): + bytes_read = bytes_read_list[j] + self.assertGreater(bytes_read, 0, f"Tensor {j} read failed (bytes={bytes_read})") + + # ✅ Create memoryview slice for this tensor only + offset = j * buffer_spacing + tensor_mv = memoryview(large_buffer)[offset : offset + bytes_read] + + try: + reconstructed_tensor = parse_tensor_from_buffer(tensor_mv) + except Exception as e: + self.fail(f"Failed to parse tensor {j}: {e}") + + self.assertTrue( + verify_tensor_equality(self.tensors[j], reconstructed_tensor), + f"Tensor {j} content mismatch" + ) + + self._print_perf("Standard Batch Put", put_times) + self._print_perf("Standard Batch Get", get_times) + + # Unregister buffer + self.assertEqual( + self.store.unregister_buffer(large_buffer_ptr), + 0, + "Buffer unregistration should succeed" + ) # ========================================== # Stress/Concurrency Tests @@ -567,4 +752,4 @@ def test_fp8_types(self): runner = unittest.TextTestRunner(verbosity=2) result = runner.run(suite) - sys.exit(not result.wasSuccessful()) \ No newline at end of file + sys.exit(not result.wasSuccessful()) From 
2d1f66ec521b09983d4098810ed2247377826615 Mon Sep 17 00:00:00 2001 From: Cruz Zhao Date: Sun, 14 Dec 2025 12:28:26 +0800 Subject: [PATCH 04/24] fix discription Signed-off-by: Cruz Zhao --- scripts/test_tensor_api.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/scripts/test_tensor_api.py b/scripts/test_tensor_api.py index 9b9ede59d..4bc293012 100644 --- a/scripts/test_tensor_api.py +++ b/scripts/test_tensor_api.py @@ -428,7 +428,7 @@ def test_benchmark_02_tp_batch(self): self._print_perf(f"TP Batch Get (TP={tp_size})", get_times) def test_benchmark_03_batch_put_get_into(self): - """Benchmark: Standard Batch Put/Get.""" + """Benchmark: Zero copy Batch Get.""" buffer_spacing = 1024 * 1024 * 1024 # 1GB per tensor slot batch_size = len(self.keys) total_buffer_size = buffer_spacing * batch_size @@ -450,7 +450,7 @@ def test_benchmark_03_batch_put_get_into(self): res = self.store.register_buffer(large_buffer_ptr, total_buffer_size) self.assertEqual(res, 0, "Buffer registration should succeed") - print(f"--- Running Standard Batch Benchmark ({self.BENCH_ITERATIONS} iters) ---") + print(f"--- Running zero copy Batch Benchmark ({self.BENCH_ITERATIONS} iters) ---") put_times = [] get_times = [] @@ -488,7 +488,7 @@ def test_benchmark_03_batch_put_get_into(self): ) self._print_perf("Standard Batch Put", put_times) - self._print_perf("Standard Batch Get", get_times) + self._print_perf("Zero copy Batch Get", get_times) # Unregister buffer self.assertEqual( From 4a602dd08e9beb3576eab6069b8eb8b1db0f291b Mon Sep 17 00:00:00 2001 From: Cruz Zhao Date: Sun, 14 Dec 2025 12:29:17 +0800 Subject: [PATCH 05/24] fix format Signed-off-by: Cruz Zhao --- mooncake-integration/store/store_py.cpp | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/mooncake-integration/store/store_py.cpp b/mooncake-integration/store/store_py.cpp index 3f265c6ac..685012bb0 100644 --- a/mooncake-integration/store/store_py.cpp +++ b/mooncake-integration/store/store_py.cpp @@ -321,8 +321,8 @@ class MooncakeStorePyWrapper { } int64_t get_tensor_into(const std::string &key, uintptr_t buffer_ptr, - size_t size) { - void *buffer = reinterpret_cast(buffer_ptr); + size_t size) { + void *buffer = reinterpret_cast(buffer_ptr); if (!is_client_initialized()) { LOG(ERROR) << "Client is not initialized"; return to_py_ret(ErrorCode::INVALID_PARAMS); @@ -385,9 +385,10 @@ class MooncakeStorePyWrapper { } } - pybind11::list batch_get_tensor_into(const std::vector &keys, - const std::vector &buffer_ptrs, - const std::vector &sizes) { + pybind11::list batch_get_tensor_into( + const std::vector &keys, + const std::vector &buffer_ptrs, + const std::vector &sizes) { std::vector buffers; buffers.reserve(buffer_ptrs.size()); for (uintptr_t ptr : buffer_ptrs) { @@ -414,12 +415,13 @@ class MooncakeStorePyWrapper { } // Phase 1: Batch Get Buffers (GIL Released) - std::vector> total_lengths; - { + std::vector> total_lengths; + { py::gil_scoped_release release_gil; // This internal call already handles logging for query failures - total_lengths = store_->batch_get_into_internal(keys, buffers, sizes); - } + total_lengths = + store_->batch_get_into_internal(keys, buffers, sizes); + } py::list results_list; try { From 77106051d530505381dadd527ce07ade96ec9ff0 Mon Sep 17 00:00:00 2001 From: Cruz Zhao Date: Mon, 15 Dec 2025 15:11:25 +0800 Subject: [PATCH 06/24] fix Signed-off-by: Cruz Zhao --- mooncake-integration/store/store_py.cpp | 22 ++++++++++------------ mooncake-store/include/dummy_client.h | 8 -------- 
mooncake-store/include/pyclient.h | 8 -------- mooncake-store/src/dummy_client.cpp | 23 ++++------------------- scripts/test_tensor_api.py | 12 +++++------- 5 files changed, 19 insertions(+), 54 deletions(-) diff --git a/mooncake-integration/store/store_py.cpp b/mooncake-integration/store/store_py.cpp index 685012bb0..f4b99cc77 100644 --- a/mooncake-integration/store/store_py.cpp +++ b/mooncake-integration/store/store_py.cpp @@ -336,8 +336,8 @@ class MooncakeStorePyWrapper { try { // Section with GIL released py::gil_scoped_release release_gil; - auto total_length = store_->get_into_internal(key, buffer, size); - if (!total_length.has_value()) { + auto total_length = store_->get_into(key, buffer, size); + if (total_length <= 0) { py::gil_scoped_acquire acquire_gil; return to_py_ret(ErrorCode::INVALID_PARAMS); } @@ -360,7 +360,7 @@ class MooncakeStorePyWrapper { return to_py_ret(ErrorCode::INVALID_PARAMS); } - size_t tensor_size = total_length.value() - sizeof(TensorMetadata); + size_t tensor_size = total_length - sizeof(TensorMetadata); if (tensor_size == 0) { py::gil_scoped_acquire acquire_gil; LOG(ERROR) << "Invalid data format: no tensor data found"; @@ -377,7 +377,7 @@ class MooncakeStorePyWrapper { return to_py_ret(ErrorCode::INVALID_PARAMS); } - return total_length.value(); + return total_length; } catch (const pybind11::error_already_set &e) { LOG(ERROR) << "Failed to get tensor data: " << e.what(); @@ -415,12 +415,11 @@ class MooncakeStorePyWrapper { } // Phase 1: Batch Get Buffers (GIL Released) - std::vector> total_lengths; + std::vector total_lengths; { py::gil_scoped_release release_gil; // This internal call already handles logging for query failures - total_lengths = - store_->batch_get_into_internal(keys, buffers, sizes); + total_lengths = store_->batch_get_into(keys, buffers, sizes); } py::list results_list; @@ -436,13 +435,13 @@ class MooncakeStorePyWrapper { } auto total_length = total_lengths[i]; - if (!total_length.has_value()) { + if (total_length <= 0) { LOG(ERROR) << "Invalid data format: insufficient data for" "metadata"; results_list.append(to_py_ret(ErrorCode::INVALID_PARAMS)); continue; } - if (total_length.value() <= + if (total_length <= static_cast(sizeof(TensorMetadata))) { LOG(ERROR) << "Invalid data format: insufficient data for " "metadata"; @@ -469,8 +468,7 @@ class MooncakeStorePyWrapper { continue; } - size_t tensor_size = - total_length.value() - sizeof(TensorMetadata); + size_t tensor_size = total_length - sizeof(TensorMetadata); if (tensor_size == 0) { LOG(ERROR) << "Invalid data format: no tensor data found"; results_list.append(to_py_ret(ErrorCode::INVALID_PARAMS)); @@ -485,7 +483,7 @@ class MooncakeStorePyWrapper { continue; } - results_list.append(total_length.value()); + results_list.append(total_length); } } catch (const pybind11::error_already_set &e) { LOG(ERROR) << "Failed during batch tensor deserialization: " diff --git a/mooncake-store/include/dummy_client.h b/mooncake-store/include/dummy_client.h index fc83035d0..fde0089ea 100644 --- a/mooncake-store/include/dummy_client.h +++ b/mooncake-store/include/dummy_client.h @@ -82,16 +82,8 @@ class DummyClient : public PyClient { int unregister_buffer(void *buffer); - tl::expected get_into_internal(const std::string &key, - void *buffer, - size_t size); - int64_t get_into(const std::string &key, void *buffer, size_t size); - std::vector> batch_get_into_internal( - const std::vector &keys, - const std::vector &buffers, const std::vector &sizes); - std::vector batch_get_into(const std::vector 
&keys, const std::vector &buffers, const std::vector &sizes); diff --git a/mooncake-store/include/pyclient.h b/mooncake-store/include/pyclient.h index 542c92209..0ec6e56fa 100644 --- a/mooncake-store/include/pyclient.h +++ b/mooncake-store/include/pyclient.h @@ -54,17 +54,9 @@ class PyClient { virtual int unregister_buffer(void *buffer) = 0; - virtual tl::expected get_into_internal( - const std::string &key, void *buffer, size_t size) = 0; - virtual int64_t get_into(const std::string &key, void *buffer, size_t size) = 0; - virtual std::vector> - batch_get_into_internal(const std::vector &keys, - const std::vector &buffers, - const std::vector &sizes) = 0; - virtual std::vector batch_get_into( const std::vector &keys, const std::vector &buffers, diff --git a/mooncake-store/src/dummy_client.cpp b/mooncake-store/src/dummy_client.cpp index 6b6cbec9f..41c113a66 100644 --- a/mooncake-store/src/dummy_client.cpp +++ b/mooncake-store/src/dummy_client.cpp @@ -563,12 +563,6 @@ std::vector> DummyClient::batch_get_buffer( return std::vector>(); } -tl::expected DummyClient::get_into_internal( - const std::string& key, void* buffer, size_t size) { - // TODO: implement this function - return tl::unexpected(ErrorCode::INVALID_PARAMS); -} - int64_t DummyClient::get_into(const std::string& key, void* buffer, size_t size) { // TODO: implement this function @@ -606,25 +600,16 @@ int DummyClient::put_from(const std::string& key, void* buffer, size_t size, return -1; } -std::vector> -DummyClient::batch_get_into_internal(const std::vector& keys, - const std::vector& buffer_ptrs, - const std::vector& sizes) { +std::vector DummyClient::batch_get_into( + const std::vector& keys, const std::vector& buffer_ptrs, + const std::vector& sizes) { std::vector buffers; for (auto ptr : buffer_ptrs) { buffers.push_back(reinterpret_cast(ptr)); } - auto results = + auto internal_results = invoke_batch_rpc<&RealClient::batch_get_into_dummy_helper, int64_t>( keys.size(), keys, buffers, sizes, client_id_); - - return results; -} - -std::vector DummyClient::batch_get_into( - const std::vector& keys, const std::vector& buffer_ptrs, - const std::vector& sizes) { - auto internal_results = batch_get_into_internal(keys, buffer_ptrs, sizes); std::vector results; results.reserve(internal_results.size()); diff --git a/scripts/test_tensor_api.py b/scripts/test_tensor_api.py index 4bc293012..d750b4cf6 100644 --- a/scripts/test_tensor_api.py +++ b/scripts/test_tensor_api.py @@ -45,8 +45,8 @@ def parse_tensor_from_buffer(buffer_view): """ - 从 memoryview 中解析 tensor。 - 假设格式: [TensorMetadata (40 bytes)][raw data] + parse tensor from memoryview + format: [TensorMetadata (40 bytes)][raw data] TensorMetadata layout (C++): int32_t dtype; // offset 0 int32_t ndim; // offset 4 @@ -55,7 +55,7 @@ def parse_tensor_from_buffer(buffer_view): if len(buffer_view) < 40: raise ValueError(f"Buffer too small for TensorMetadata (got {len(buffer_view)} bytes, need >=40)") - # 解析 metadata + # parse metadata dtype_enum = struct.unpack_from(' 0 else () - # 映射 dtype + # map dtype if dtype_enum not in DTYPE_MAP: raise ValueError(f"Unsupported or unknown TensorDtype enum: {dtype_enum}") np_dtype = DTYPE_MAP[dtype_enum] - # 计算数据部分 data_start = 40 if len(buffer_view) <= data_start: raise ValueError("No tensor data found after metadata") raw_data = buffer_view[data_start:] # memoryview slice → still bytes-like - # 构造 NumPy 数组(零拷贝) try: arr = np.frombuffer(raw_data, dtype=np_dtype) if arr.size == 0 and np.prod(actual_shape) != 0: @@ -93,7 +91,7 @@ def 
parse_tensor_from_buffer(buffer_view): def verify_tensor_equality(original, received, rtol=0, atol=0, verbose=True): """ - 验证两个张量是否完全一致(逐元素精确比较)。 + compare two tensors。 """ def to_numpy(x): if isinstance(x, torch.Tensor): From 469da73eb754614da46a1da75b7b65a9b33d788e Mon Sep 17 00:00:00 2001 From: Cruz Zhao Date: Mon, 15 Dec 2025 15:13:50 +0800 Subject: [PATCH 07/24] fix format Signed-off-by: Cruz Zhao --- mooncake-integration/store/store_py.cpp | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/mooncake-integration/store/store_py.cpp b/mooncake-integration/store/store_py.cpp index f4b99cc77..c53a84f5f 100644 --- a/mooncake-integration/store/store_py.cpp +++ b/mooncake-integration/store/store_py.cpp @@ -415,7 +415,7 @@ class MooncakeStorePyWrapper { } // Phase 1: Batch Get Buffers (GIL Released) - std::vector total_lengths; + std::vector total_lengths; { py::gil_scoped_release release_gil; // This internal call already handles logging for query failures @@ -441,8 +441,7 @@ class MooncakeStorePyWrapper { results_list.append(to_py_ret(ErrorCode::INVALID_PARAMS)); continue; } - if (total_length <= - static_cast(sizeof(TensorMetadata))) { + if (total_length <= static_cast(sizeof(TensorMetadata))) { LOG(ERROR) << "Invalid data format: insufficient data for " "metadata"; results_list.append(to_py_ret(ErrorCode::INVALID_PARAMS)); From 915b40f14966879f92f616e4235448485b8b0474 Mon Sep 17 00:00:00 2001 From: Cruz Zhao Date: Mon, 15 Dec 2025 19:11:28 +0800 Subject: [PATCH 08/24] fix return value Signed-off-by: Cruz Zhao --- mooncake-integration/integration_utils.h | 29 ++++++++++ mooncake-integration/store/store_py.cpp | 61 ++++++++++++++------ scripts/test_tensor_api.py | 71 ++---------------------- 3 files changed, 80 insertions(+), 81 deletions(-) diff --git a/mooncake-integration/integration_utils.h b/mooncake-integration/integration_utils.h index dfea8638a..c0c8bba24 100644 --- a/mooncake-integration/integration_utils.h +++ b/mooncake-integration/integration_utils.h @@ -62,6 +62,35 @@ static const std::array array_creators = {{ create_typed_array, // FLOAT8_E5M2 = 14 (using uint8_t as storage) }}; +template +py::array create_typed_array_without_free(char *data_ptr, size_t offset, + size_t total_length) { + return py::array_t({static_cast(total_length / sizeof(T))}, + (T *)(data_ptr + offset), py::none()); +} + +static const std::array array_creators_without_free = {{ + create_typed_array_without_free, // FLOAT32 = 0 + create_typed_array_without_free, // FLOAT64 = 1 + create_typed_array_without_free, // INT8 = 2 + create_typed_array_without_free, // UINT8 = 3 + create_typed_array_without_free, // INT16 = 4 + create_typed_array_without_free, // UINT16 = 5 + create_typed_array_without_free, // INT32 = 6 + create_typed_array_without_free, // UINT32 = 7 + create_typed_array_without_free, // INT64 = 8 + create_typed_array_without_free, // UINT64 = 9 + create_typed_array_without_free, // BOOL = 10 + create_typed_array_without_free, // FLOAT16 = 11 (using uint16_t + // as storage) + create_typed_array_without_free, // BFLOAT16 = 12 (using uint16_t + // as storage) + create_typed_array_without_free, // FLOAT8_E4M3 = 13 (using + // uint8_t as storage) + create_typed_array_without_free, // FLOAT8_E5M2 = 14 (using + // uint8_t as storage) +}}; + inline TensorDtype get_tensor_dtype(py::object dtype_obj) { if (dtype_obj.is_none()) { return TensorDtype::UNKNOWN; diff --git a/mooncake-integration/store/store_py.cpp b/mooncake-integration/store/store_py.cpp index c53a84f5f..85945a0e4 100644 
--- a/mooncake-integration/store/store_py.cpp +++ b/mooncake-integration/store/store_py.cpp @@ -320,17 +320,17 @@ class MooncakeStorePyWrapper { return results_list; } - int64_t get_tensor_into(const std::string &key, uintptr_t buffer_ptr, - size_t size) { + pybind11::object get_tensor_into(const std::string &key, + uintptr_t buffer_ptr, size_t size) { void *buffer = reinterpret_cast(buffer_ptr); if (!is_client_initialized()) { LOG(ERROR) << "Client is not initialized"; - return to_py_ret(ErrorCode::INVALID_PARAMS); + return pybind11::none(); } if (use_dummy_client_) { LOG(ERROR) << "get_tensor is not supported for dummy client now"; - return to_py_ret(ErrorCode::INVALID_PARAMS); + return pybind11::none(); } try { @@ -339,7 +339,7 @@ class MooncakeStorePyWrapper { auto total_length = store_->get_into(key, buffer, size); if (total_length <= 0) { py::gil_scoped_acquire acquire_gil; - return to_py_ret(ErrorCode::INVALID_PARAMS); + return pybind11::none(); } TensorMetadata metadata; @@ -350,38 +350,52 @@ class MooncakeStorePyWrapper { if (metadata.ndim < 0 || metadata.ndim > 4) { py::gil_scoped_acquire acquire_gil; LOG(ERROR) << "Invalid tensor metadata: ndim=" << metadata.ndim; - return to_py_ret(ErrorCode::INVALID_PARAMS); + return pybind11::none(); } TensorDtype dtype_enum = static_cast(metadata.dtype); if (dtype_enum == TensorDtype::UNKNOWN) { py::gil_scoped_acquire acquire_gil; LOG(ERROR) << "Unknown tensor dtype!"; - return to_py_ret(ErrorCode::INVALID_PARAMS); + return pybind11::none(); } size_t tensor_size = total_length - sizeof(TensorMetadata); if (tensor_size == 0) { py::gil_scoped_acquire acquire_gil; LOG(ERROR) << "Invalid data format: no tensor data found"; - return to_py_ret(ErrorCode::INVALID_PARAMS); + return pybind11::none(); } py::gil_scoped_acquire acquire_gil; // Convert bytes to tensor using torch.from_numpy pybind11::object np_array; int dtype_index = static_cast(dtype_enum); - if (dtype_index < 0 || - dtype_index >= static_cast(array_creators.size())) { + if (dtype_index >= 0 && + dtype_index < static_cast(array_creators.size())) { + np_array = array_creators_without_free[dtype_index]( + static_cast(buffer), sizeof(TensorMetadata), + tensor_size); + } else { LOG(ERROR) << "Unsupported dtype enum: " << dtype_index; - return to_py_ret(ErrorCode::INVALID_PARAMS); + return pybind11::none(); } - return total_length; + if (metadata.ndim > 0) { + std::vector shape_vec; + for (int i = 0; i < metadata.ndim; i++) { + shape_vec.push_back(metadata.shape[i]); + } + py::tuple shape_tuple = py::cast(shape_vec); + np_array = np_array.attr("reshape")(shape_tuple); + } + pybind11::object tensor = + torch_module().attr("from_numpy")(np_array); + return tensor; } catch (const pybind11::error_already_set &e) { LOG(ERROR) << "Failed to get tensor data: " << e.what(); - return to_py_ret(ErrorCode::INVALID_PARAMS); + return pybind11::none(); } } @@ -474,15 +488,30 @@ class MooncakeStorePyWrapper { continue; } + pybind11::object np_array; int dtype_index = static_cast(dtype_enum); - if (dtype_index < 0 || - dtype_index >= static_cast(array_creators.size())) { + if (dtype_index >= 0 && + dtype_index < static_cast(array_creators.size())) { + // This call MUST take ownership of exported_data + np_array = array_creators_without_free[dtype_index]( + static_cast(buffer), sizeof(TensorMetadata), + tensor_size); + } else { LOG(ERROR) << "Unsupported dtype enum: " << dtype_index; results_list.append(to_py_ret(ErrorCode::INVALID_PARAMS)); continue; } - results_list.append(total_length); + if 
(metadata.ndim > 0) { + std::vector shape_vec; + for (int i = 0; i < metadata.ndim; i++) { + shape_vec.push_back(metadata.shape[i]); + } + py::tuple shape_tuple = py::cast(shape_vec); + np_array = np_array.attr("reshape")(shape_tuple); + } + pybind11::object tensor = torch.attr("from_numpy")(np_array); + results_list.append(tensor); } } catch (const pybind11::error_already_set &e) { LOG(ERROR) << "Failed during batch tensor deserialization: " diff --git a/scripts/test_tensor_api.py b/scripts/test_tensor_api.py index d750b4cf6..2d48ddae3 100644 --- a/scripts/test_tensor_api.py +++ b/scripts/test_tensor_api.py @@ -43,52 +43,6 @@ # Note: BFLOAT16 (12), FLOAT8 (13,14), W8A8 (15) not supported in NumPy } -def parse_tensor_from_buffer(buffer_view): - """ - parse tensor from memoryview - format: [TensorMetadata (40 bytes)][raw data] - TensorMetadata layout (C++): - int32_t dtype; // offset 0 - int32_t ndim; // offset 4 - uint64_t shape[4]; // offsets 8,16,24,32 - """ - if len(buffer_view) < 40: - raise ValueError(f"Buffer too small for TensorMetadata (got {len(buffer_view)} bytes, need >=40)") - - # parse metadata - dtype_enum = struct.unpack_from(' 4: - raise ValueError(f"Invalid ndim: {ndim}") - - shape = [] - for i in range(4): - dim = struct.unpack_from(' 0 else () - - # map dtype - if dtype_enum not in DTYPE_MAP: - raise ValueError(f"Unsupported or unknown TensorDtype enum: {dtype_enum}") - np_dtype = DTYPE_MAP[dtype_enum] - - data_start = 40 - if len(buffer_view) <= data_start: - raise ValueError("No tensor data found after metadata") - - raw_data = buffer_view[data_start:] # memoryview slice → still bytes-like - - try: - arr = np.frombuffer(raw_data, dtype=np_dtype) - if arr.size == 0 and np.prod(actual_shape) != 0: - raise ValueError("Data size mismatch") - tensor = torch.from_numpy(arr.reshape(actual_shape)) - return tensor - except Exception as e: - raise ValueError(f"Failed to construct tensor from buffer: {e}") - - def verify_tensor_equality(original, received, rtol=0, atol=0, verbose=True): """ compare two tensors。 @@ -360,7 +314,7 @@ def setUp(self): """Benchmark-specific setUp.""" # 1. Call parent setUp to clean the store (remove_all) super().setUp() - + # 2. 
Generate test data total_bytes = self.TOTAL_SIZE_GB * 1024**3 tensor_bytes = self.TENSOR_SIZE_MB * 1024**2 @@ -384,7 +338,7 @@ def test_benchmark_01_batch_put_get(self): for i in range(self.BENCH_ITERATIONS): # Clean store before each iteration for "cold" writes self.store.remove_all() - + # Measure Put t0 = time.perf_counter() self.store.batch_put_tensor(self.keys, self.tensors) @@ -453,6 +407,7 @@ def test_benchmark_03_batch_put_get_into(self): get_times = [] for i in range(self.BENCH_ITERATIONS): + # Clean store before each iteration for "cold" writes self.store.remove_all() # Measure Put @@ -462,26 +417,12 @@ def test_benchmark_03_batch_put_get_into(self): # Measure Get t0 = time.perf_counter() - bytes_read_list = self.store.batch_get_tensor_into(self.keys, buffer_ptrs, buffer_sizes) + res = self.store.batch_get_tensor_into(self.keys, buffer_ptrs, buffer_sizes) get_times.append(time.perf_counter() - t0) - - # Validate results - self.assertEqual(len(bytes_read_list), batch_size) + self.assertEqual(len(res), len(self.tensors)) for j in range(batch_size): - bytes_read = bytes_read_list[j] - self.assertGreater(bytes_read, 0, f"Tensor {j} read failed (bytes={bytes_read})") - - # ✅ Create memoryview slice for this tensor only - offset = j * buffer_spacing - tensor_mv = memoryview(large_buffer)[offset : offset + bytes_read] - - try: - reconstructed_tensor = parse_tensor_from_buffer(tensor_mv) - except Exception as e: - self.fail(f"Failed to parse tensor {j}: {e}") - self.assertTrue( - verify_tensor_equality(self.tensors[j], reconstructed_tensor), + verify_tensor_equality(self.tensors[j], res[j]), f"Tensor {j} content mismatch" ) From 9daaa9794504b0c6c6e4eae26e121c34f06dde90 Mon Sep 17 00:00:00 2001 From: Cruz Zhao Date: Tue, 16 Dec 2025 11:32:11 +0800 Subject: [PATCH 09/24] introduce get_tensor_into_with_tp() and batch_get_tensor_into_with_tp() Signed-off-by: Cruz Zhao --- mooncake-integration/store/store_py.cpp | 85 +++++++++++++++++++++++++ scripts/test_tensor_api.py | 61 ++++++++++++++++++ 2 files changed, 146 insertions(+) diff --git a/mooncake-integration/store/store_py.cpp b/mooncake-integration/store/store_py.cpp index 85945a0e4..9eb885393 100644 --- a/mooncake-integration/store/store_py.cpp +++ b/mooncake-integration/store/store_py.cpp @@ -520,6 +520,71 @@ class MooncakeStorePyWrapper { return results_list; } + pybind11::object get_tensor_into_with_tp(const std::string &key, + uintptr_t buffer_ptr, size_t size, + int tp_rank = 0, int tp_size = 1, + int split_dim = 0) { + if (!is_client_initialized()) { + LOG(ERROR) << "Client is not initialized"; + return pybind11::none(); + } + + if (use_dummy_client_) { + LOG(ERROR) + << "get_tensor_into is not supported for dummy client now"; + return pybind11::none(); + } + + if (tp_size <= 1) { + return get_tensor_into(key, buffer_ptr, size); + } + + // Construct the specific key for this rank: e.g., "key_tp_0" + std::string tp_key = get_tp_key_name(key, tp_rank); + + // Delegate to the standard get_tensor_into method + return get_tensor_into(tp_key, buffer_ptr, size); + } + + pybind11::list batch_get_tensor_into_with_tp( + const std::vector &base_keys, + const std::vector &buffer_ptrs, + const std::vector &sizes, int tp_rank = 0, int tp_size = 1) { + if (!is_client_initialized()) { + LOG(ERROR) << "Client is not initialized"; + py::list empty_list; + for (size_t i = 0; i < base_keys.size(); ++i) { + empty_list.append(py::none()); + } + return empty_list; + } + + if (use_dummy_client_) { + LOG(ERROR) << "batch_get_tensor_into_with_tp is not 
supported for " + "dummy client"; + py::list empty_list; + for (size_t i = 0; i < base_keys.size(); ++i) { + empty_list.append(py::none()); + } + return empty_list; + } + + // If tp_size is 1, it's just a normal batch_get_tensor_into + if (tp_size <= 1) { + return batch_get_tensor_into(base_keys, buffer_ptrs, sizes); + } + + // Generate the specific shard keys for the given tp_rank + std::vector shard_keys; + shard_keys.reserve(base_keys.size()); + for (const auto &key : base_keys) { + shard_keys.push_back(get_tp_key_name(key, tp_rank)); + } + + // Use the existing batch_get_tensor_into to fetch all shards at once + return batch_get_tensor_into(shard_keys, buffer_ptrs, sizes); + } + int put_tensor_impl(const std::string &key, pybind11::object tensor, const ReplicateConfig &config) { // Validation & Metadata extraction (GIL Held) @@ -1136,6 +1201,26 @@ PYBIND11_MODULE(store, m) { "Get tensors directly into pre-allocated buffers for " "multiple " "keys") + .def( + "get_tensor_into_with_tp", + &MooncakeStorePyWrapper::get_tensor_into_with_tp, py::arg("key"), + py::arg("buffer_ptr"), py::arg("size"), py::arg("tp_rank") = 0, + py::arg("tp_size") = 1, py::arg("split_dim") = 0, + "Get a PyTorch tensor from the store directly into a pre-allocated" + "buffer, optionally sliced for Tensor Parallelism.\n" + "Args:\n" + " key: The key of the tensor.\n" + " buffer_ptr: The buffer pointer pre-allocated for tensor.\n" + " size: The size of buffer.\n" + " tp_rank: The current tensor parallel rank (default 0).\n" + " tp_size: The total tensor parallel size (default 1).\n" + " split_dim: The dimension to split the tensor along (default 0).") + .def("batch_get_tensor_into_with_tp", + &MooncakeStorePyWrapper::batch_get_tensor_into_with_tp, + py::arg("base_keys"), py::arg("buffer_ptrs"), py::arg("sizes"), + py::arg("tp_rank") = 0, py::arg("tp_size") = 1, + "Get a batch of PyTorch tensor shards from the store directly into" + "pre-allocated buffers for a given Tensor Parallel rank.") .def( "register_buffer", [](MooncakeStorePyWrapper &self, uintptr_t buffer_ptr, diff --git a/scripts/test_tensor_api.py b/scripts/test_tensor_api.py index 2d48ddae3..e8b3e9a6b 100644 --- a/scripts/test_tensor_api.py +++ b/scripts/test_tensor_api.py @@ -436,6 +436,67 @@ def test_benchmark_03_batch_put_get_into(self): "Buffer unregistration should succeed" ) + def test_benchmark_04_batch_put_get_into_with_tp(self): + """Benchmark: Zero copy Batch Get with tp.""" + buffer_spacing = 1024 * 1024 * 1024 # 1GB per tensor slot + batch_size = len(self.keys) + total_buffer_size = buffer_spacing * batch_size + + # Allocate large contiguous buffer + large_buffer = (ctypes.c_ubyte * total_buffer_size)() + large_buffer_ptr = ctypes.addressof(large_buffer) + + # Prepare pointers (only addresses, no ctypes array objects) + buffer_ptrs = [] + buffer_sizes = [] + for i in range(batch_size): + offset = i * buffer_spacing + ptr = large_buffer_ptr + offset + buffer_ptrs.append(ptr) + buffer_sizes.append(buffer_spacing) + + # Register the entire buffer with the store + res = self.store.register_buffer(large_buffer_ptr, total_buffer_size) + self.assertEqual(res, 0, "Buffer registration should succeed") + + tp_size = 4 + print(f"--- Running zero copy Batch Benchmark (TP={tp_size}, {self.BENCH_ITERATIONS} iters) ---") + split_dim = 0 + put_times = [] + get_times = [] + + for i in range(self.BENCH_ITERATIONS): + # Clean store before each iteration for "cold" writes + self.store.remove_all() + + # Measure Put + t0 = time.perf_counter() + 
self.store.batch_put_tensor_with_tp(self.keys, self.tensors, tp_size=tp_size, split_dim=split_dim) + put_times.append(time.perf_counter() - t0) + + # Measure Get + t0 = time.perf_counter() + for rank in range(tp_size): + res = self.store.batch_get_tensor_into_with_tp(self.keys, buffer_ptrs, buffer_sizes, tp_rank=rank, tp_size=tp_size) + self.assertEqual(len(res), len(self.tensors)) + get_times.append(time.perf_counter() - t0) + self.assertEqual(len(res), len(self.tensors)) + for j in range(batch_size): + self.assertTrue( + verify_tensor_equality(self.tensors[j], res[j]), + f"Tensor {j} content mismatch" + ) + + self._print_perf("Standard Batch Put with tp (TP={tp_size})", put_times) + self._print_perf("Zero copy Batch Get with tp (TP={tp_size})", get_times) + + # Unregister buffer + self.assertEqual( + self.store.unregister_buffer(large_buffer_ptr), + 0, + "Buffer unregistration should succeed" + ) + # ========================================== # Stress/Concurrency Tests # ========================================== From 360e5c8f016739382c99a83b60b5ab92cc770849 Mon Sep 17 00:00:00 2001 From: Cruz Zhao Date: Tue, 16 Dec 2025 15:05:12 +0800 Subject: [PATCH 10/24] add testcases Signed-off-by: Cruz Zhao --- scripts/test_tensor_api.py | 128 +++++++++++++++++++++++++++++++++++++ 1 file changed, 128 insertions(+) diff --git a/scripts/test_tensor_api.py b/scripts/test_tensor_api.py index e8b3e9a6b..3e1a22ad3 100644 --- a/scripts/test_tensor_api.py +++ b/scripts/test_tensor_api.py @@ -300,6 +300,134 @@ def test_04_tp_consistency(self): self.assertTrue(tmp_tensor_0.sum() == chunked_tensors[0].sum()) self.assertTrue(tmp_tensor_1.sum() == chunked_tensors[1].sum()) + def test_05_put_get_into(self): + """Verify basic put and get into functionality.""" + key = "get_into_test" + tensor = torch.randn(4, 4, dtype=torch.float32) + buffer_spacing = 64 * 1024 * 1024 # 1GB per tensor slot + total_buffer_size = buffer_spacing + + # Allocate large contiguous buffer + large_buffer = (ctypes.c_ubyte * total_buffer_size)() + large_buffer_ptr = ctypes.addressof(large_buffer) + + res = self.store.register_buffer(large_buffer_ptr, total_buffer_size) + self.assertEqual(res, 0, "Buffer registration should succeed") + + # Perform Put + rc = self.store.put_tensor(key, tensor) + self.assertEqual(rc, 0, f"put_tensor failed with rc={rc}") + self.assertTrue(self.store.is_exist(key), "Key not found after put") + + # Perform Get + retrieved = self.store.get_tensor_into(key, large_buffer_ptr, total_buffer_size) + self.assertIsNotNone(retrieved, "Get returned None") + self.assertTrue(torch.equal(tensor, retrieved), f"Data mismatch between original and retrieved tensor, tensor: {tensor}, retrieved: {retrieved}") + # Unregister buffer + self.assertEqual( + self.store.unregister_buffer(large_buffer_ptr), + 0, + "Buffer unregistration should succeed" + ) + + def test_06_batch_put_get_into(self): + """Benchmark: Zero copy Batch Get.""" + num_tensors = 4 + keys, tensors = generate_tensors(num_tensors, 8) + buffer_spacing = 1024 * 1024 * 1024 # 1GB per tensor slot + batch_size = len(keys) + total_buffer_size = buffer_spacing * batch_size + + # Allocate large contiguous buffer + large_buffer = (ctypes.c_ubyte * total_buffer_size)() + large_buffer_ptr = ctypes.addressof(large_buffer) + + # Prepare pointers (only addresses, no ctypes array objects) + buffer_ptrs = [] + buffer_sizes = [] + for i in range(batch_size): + offset = i * buffer_spacing + ptr = large_buffer_ptr + offset + buffer_ptrs.append(ptr) + buffer_sizes.append(buffer_spacing) 
+ + # Register the entire buffer with the store + res = self.store.register_buffer(large_buffer_ptr, total_buffer_size) + self.assertEqual(res, 0, "Buffer registration should succeed") + + results = self.store.batch_put_tensor(keys, tensors) + self.assertTrue(all(r == 0 for r in results), f"Batch put failed. Results: {results}") + + res = self.store.batch_get_tensor_into(keys, buffer_ptrs, buffer_sizes) + self.assertEqual(len(res), len(tensors)) + for j in range(batch_size): + self.assertTrue( + verify_tensor_equality(tensors[j], res[j]), + f"Tensor {j} content mismatch, tensor: {tensors[j]}, res: {res[j]}" + ) + # Unregister buffer + self.assertEqual( + self.store.unregister_buffer(large_buffer_ptr), + 0, + "Buffer unregistration should succeed" + ) + + def test_07_batch_put_get_into_with_tp(self): + """Benchmark: Zero copy Batch Get with tp.""" + tp_size = 4 + split_dim = 0 + num_tensors = 4 + keys, tensors = generate_tensors(num_tensors, 8) + buffer_spacing = 1024 * 1024 * 1024 # 1GB per tensor slot + batch_size = len(keys) + total_buffer_size = buffer_spacing * batch_size + + # Allocate large contiguous buffer + large_buffer = (ctypes.c_ubyte * total_buffer_size)() + large_buffer_ptr = ctypes.addressof(large_buffer) + + # Prepare pointers (only addresses, no ctypes array objects) + buffer_ptrs = [] + buffer_sizes = [] + for i in range(batch_size): + offset = i * buffer_spacing + ptr = large_buffer_ptr + offset + buffer_ptrs.append(ptr) + buffer_sizes.append(buffer_spacing) + + # Register the entire buffer with the store + res = self.store.register_buffer(large_buffer_ptr, total_buffer_size) + self.assertEqual(res, 0, "Buffer registration should succeed") + + results = self.store.batch_put_tensor_with_tp(keys, tensors, tp_size = tp_size, split_dim=split_dim) + self.assertTrue(all(r == 0 for r in results), f"Batch put failed. 
Results: {results}") + + all_shards = [] + for rank in range(tp_size): + shards = self.store.batch_get_tensor_into_with_tp(keys, buffer_ptrs, buffer_sizes, tp_rank=rank, tp_size=tp_size) + self.assertEqual(len(shards), num_tensors) + all_shards.append(shards) + + for i in range(num_tensors): + original = tensors[i] + expected_chunks = original.chunk(tp_size, split_dim) + reconstruction_parts = [] + + for rank in range(tp_size): + shard = all_shards[rank][i] + self.assertTrue(torch.equal(shard, expected_chunks[rank]), + f"Tensor {i} Rank {rank} data mismatch") + reconstruction_parts.append(shard) + + recon = torch.cat(reconstruction_parts, dim=split_dim) + self.assertTrue(torch.equal(recon, original), f"Tensor {i} final reconstruction mismatch") + # Unregister buffer + self.assertEqual( + self.store.unregister_buffer(large_buffer_ptr), + 0, + "Buffer unregistration should succeed" + ) + # ========================================== # Performance/Benchmark Tests # ========================================== From 6e373f26763e6540ce6d657369a5ec7ac860379e Mon Sep 17 00:00:00 2001 From: Cruz Zhao Date: Tue, 16 Dec 2025 15:40:26 +0800 Subject: [PATCH 11/24] fix test cases Signed-off-by: Cruz Zhao --- scripts/test_tensor_api.py | 278 ++++++++++++++++++++++++++----------- 1 file changed, 194 insertions(+), 84 deletions(-) diff --git a/scripts/test_tensor_api.py b/scripts/test_tensor_api.py index 3e1a22ad3..4afe0692c 100644 --- a/scripts/test_tensor_api.py +++ b/scripts/test_tensor_api.py @@ -22,8 +22,8 @@ GLOBAL_CONFIG = None DEFAULT_MOONCAKE_CONFIG_PATH_ENV = "MOONCAKE_CONFIG_PATH" -DEFAULT_GLOBAL_SEGMENT_SIZE = 4 * 1024 * 1024 * 1024 # 4 GiB -DEFAULT_LOCAL_BUFFER_SIZE = 2 * 1024 * 1024 * 1024 # 2 GB +DEFAULT_GLOBAL_SEGMENT_SIZE = 16 * 1024 * 1024 * 1024 # 4 GiB +DEFAULT_LOCAL_BUFFER_SIZE = 8 * 1024 * 1024 * 1024 # 2 GB DEFAULT_MASTER_METRICS_PORT = 9003 DEFAULT_CHECK_SERVER = False @@ -159,7 +159,7 @@ def generate_tensors(num_tensors, size_mb): num_elements = size_bytes // element_size dim = int(np.sqrt(num_elements)) dim = (dim // 8) * 8 # Adjust dimension to be divisible by common TP sizes (2, 4, 8) - + # Use random data and ensure the tensor is contiguous in memory tensors = [torch.randn(dim, dim, dtype=torch.float32).contiguous() for _ in range(num_tensors)] # Use timestamp to prevent key collision in rare edge cases (though we remove_all anyway) @@ -198,10 +198,10 @@ def setUp(self): # 1. Access the global connection if GLOBAL_STORE is None: self.skipTest("Store not initialized") - + self.store = GLOBAL_STORE self.config = GLOBAL_CONFIG - + # 2. 
[Critical] Clean environment before the test starts # This ensures no stale data from previous tests affects the current one self.store.remove_all() @@ -215,12 +215,12 @@ def test_01_basic_put_get(self): """Verify basic put and get functionality.""" key = "func_test_single" tensor = torch.randn(1024, 1024, dtype=torch.float32) - + # Perform Put rc = self.store.put_tensor(key, tensor) self.assertEqual(rc, 0, f"put_tensor failed with rc={rc}") self.assertTrue(self.store.is_exist(key), "Key not found after put") - + # Perform Get retrieved = self.store.get_tensor(key) self.assertIsNotNone(retrieved, "Get returned None") @@ -231,7 +231,7 @@ def test_02_tp_single_tensor(self): tp_size = 4 split_dim = 1 key = "func_test_tp_single" - + # Create a small tensor (e.g., 16MB) _, tensors = generate_tensors(1, 16) target_tensor = tensors[0] @@ -263,7 +263,7 @@ def test_03_tp_batch(self): split_dim = 0 num_tensors = 4 keys, tensors = generate_tensors(num_tensors, 8) # Small size for functional testing - + # 1. Batch Put with TP results = self.store.batch_put_tensor_with_tp(keys, tensors, tp_size=tp_size, split_dim=split_dim) self.assertTrue(all(r == 0 for r in results), f"Batch put failed. Results: {results}") @@ -280,13 +280,13 @@ def test_03_tp_batch(self): original = tensors[i] expected_chunks = original.chunk(tp_size, split_dim) reconstruction_parts = [] - + for rank in range(tp_size): shard = all_shards[rank][i] self.assertTrue(torch.equal(shard, expected_chunks[rank]), f"Tensor {i} Rank {rank} data mismatch") reconstruction_parts.append(shard) - + recon = torch.cat(reconstruction_parts, dim=split_dim) self.assertTrue(torch.equal(recon, original), f"Tensor {i} final reconstruction mismatch") @@ -303,14 +303,14 @@ def test_04_tp_consistency(self): def test_05_put_get_into(self): """Verify basic put and get into functionality.""" key = "get_into_test" - tensor = torch.randn(4, 4, dtype=torch.float32) + tensor = torch.randn(1024, 1024, dtype=torch.float32) buffer_spacing = 64 * 1024 * 1024 # 1GB per tensor slot total_buffer_size = buffer_spacing # Allocate large contiguous buffer large_buffer = (ctypes.c_ubyte * total_buffer_size)() large_buffer_ptr = ctypes.addressof(large_buffer) - + res = self.store.register_buffer(large_buffer_ptr, total_buffer_size) self.assertEqual(res, 0, "Buffer registration should succeed") @@ -318,7 +318,7 @@ def test_05_put_get_into(self): rc = self.store.put_tensor(key, tensor) self.assertEqual(rc, 0, f"put_tensor failed with rc={rc}") self.assertTrue(self.store.is_exist(key), "Key not found after put") - + # Perform Get retrieved = self.store.get_tensor_into(key, large_buffer_ptr, total_buffer_size) self.assertIsNotNone(retrieved, "Get returned None") @@ -331,7 +331,7 @@ def test_05_put_get_into(self): ) def test_06_batch_put_get_into(self): - """Benchmark: Zero copy Batch Get.""" + """Zero copy Batch Get.""" num_tensors = 4 keys, tensors = generate_tensors(num_tensors, 8) buffer_spacing = 1024 * 1024 * 1024 # 1GB per tensor slot @@ -372,61 +372,150 @@ def test_06_batch_put_get_into(self): "Buffer unregistration should succeed" ) - def test_07_batch_put_get_into_with_tp(self): - """Benchmark: Zero copy Batch Get with tp.""" + def test_07_put_get_into_with_tp(self): + """Zero copy Batch Get with TP — each rank has its own buffer.""" tp_size = 4 split_dim = 0 - num_tensors = 4 - keys, tensors = generate_tensors(num_tensors, 8) - buffer_spacing = 1024 * 1024 * 1024 # 1GB per tensor slot - batch_size = len(keys) - total_buffer_size = buffer_spacing * batch_size + key = 
"get_into_with_tp_test" + tensor = torch.randn(1024, 1024, dtype=torch.float32) - # Allocate large contiguous buffer - large_buffer = (ctypes.c_ubyte * total_buffer_size)() - large_buffer_ptr = ctypes.addressof(large_buffer) + # Step 1: Put full tensors with TP splitting + result = self.store.put_tensor_with_tp( + key, tensor, tp_size=tp_size, split_dim=split_dim + ) + self.assertEqual(result, 0, f"Put failed. Result: {result}") - # Prepare pointers (only addresses, no ctypes array objects) - buffer_ptrs = [] - buffer_sizes = [] - for i in range(batch_size): - offset = i * buffer_spacing - ptr = large_buffer_ptr + offset - buffer_ptrs.append(ptr) - buffer_sizes.append(buffer_spacing) + # Step 2: For each TP rank, allocate its own buffer and get shards + all_shards = [] + registered_buffers = [] # Keep track of (ptr, size) for cleanup - # Register the entire buffer with the store - res = self.store.register_buffer(large_buffer_ptr, total_buffer_size) - self.assertEqual(res, 0, "Buffer registration should succeed") + for rank in range(tp_size): + buffer_spacing = 1024 * 1024 * 1024 # 1GB per tensor slot + total_buffer_size = buffer_spacing + + # Allocate buffer for this rank + large_buffer = (ctypes.c_ubyte * total_buffer_size)() + large_buffer_ptr = ctypes.addressof(large_buffer) + + # Register buffer for this rank + res = self.store.register_buffer(large_buffer_ptr, total_buffer_size) + self.assertEqual(res, 0, f"Buffer registration failed for rank {rank}") + + # Keep buffer alive and record for cleanup + # We store (large_buffer, large_buffer_ptr, total_buffer_size) + # large_buffer must be kept alive until unregister! + registered_buffers.append((large_buffer, large_buffer_ptr, total_buffer_size)) + + # Get shard for this rank + shard = self.store.get_tensor_into_with_tp( + key, large_buffer_ptr, total_buffer_size, + tp_rank=rank, tp_size=tp_size + ) + all_shards.append(shard) + + # Step 3: Validate reconstruction + original = tensor + expected_chunks = original.chunk(tp_size, split_dim) + reconstruction_parts = [] - results = self.store.batch_put_tensor_with_tp(keys, tensors, tp_size = tp_size, split_dim=split_dim) + for rank in range(tp_size): + shard = all_shards[rank] + self.assertTrue( + torch.equal(shard, expected_chunks[rank]), + f"Tensor Rank {rank} data mismatch" + ) + reconstruction_parts.append(shard) + + recon = torch.cat(reconstruction_parts, dim=split_dim) + self.assertTrue( + torch.equal(recon, original), + f"Tensor final reconstruction mismatch" + ) + + # Step 4: Unregister all buffers + for large_buffer, ptr, size in registered_buffers: + res = self.store.unregister_buffer(ptr) + self.assertEqual(res, 0, f"Buffer unregistration failed for buffer at {ptr}") + # large_buffer will be GC'd after this; no need to explicitly delete + + def test_08_batch_put_get_into_with_tp(self): + """Zero copy Batch Get with TP — each rank has its own buffer.""" + tp_size = 4 + split_dim = 0 + num_tensors = 4 + keys, tensors = generate_tensors(num_tensors, 8) + + # Step 1: Put full tensors with TP splitting + results = self.store.batch_put_tensor_with_tp( + keys, tensors, tp_size=tp_size, split_dim=split_dim + ) self.assertTrue(all(r == 0 for r in results), f"Batch put failed. 
Results: {results}") + # Step 2: For each TP rank, allocate its own buffer and get shards all_shards = [] + registered_buffers = [] # Keep track of (ptr, size) for cleanup + for rank in range(tp_size): - shards = self.store.batch_get_tensor_into_with_tp(keys, buffer_ptrs, buffer_sizes, tp_rank=rank, tp_size=tp_size) + batch_size = len(keys) + buffer_spacing = 1024 * 1024 * 1024 # 1GB per tensor slot + total_buffer_size = buffer_spacing * batch_size + + # Allocate buffer for this rank + large_buffer = (ctypes.c_ubyte * total_buffer_size)() + large_buffer_ptr = ctypes.addressof(large_buffer) + + # Prepare buffer pointers for this rank + buffer_ptrs = [] + buffer_sizes = [] + for i in range(batch_size): + offset = i * buffer_spacing + ptr = large_buffer_ptr + offset + buffer_ptrs.append(ptr) + buffer_sizes.append(buffer_spacing) + + # Register buffer for this rank + res = self.store.register_buffer(large_buffer_ptr, total_buffer_size) + self.assertEqual(res, 0, f"Buffer registration failed for rank {rank}") + + # Keep buffer alive and record for cleanup + # We store (large_buffer, large_buffer_ptr, total_buffer_size) + # large_buffer must be kept alive until unregister! + registered_buffers.append((large_buffer, large_buffer_ptr, total_buffer_size)) + + # Get shards for this rank + shards = self.store.batch_get_tensor_into_with_tp( + keys, buffer_ptrs, buffer_sizes, + tp_rank=rank, tp_size=tp_size + ) self.assertEqual(len(shards), num_tensors) all_shards.append(shards) + # Step 3: Validate reconstruction for i in range(num_tensors): original = tensors[i] expected_chunks = original.chunk(tp_size, split_dim) reconstruction_parts = [] - + for rank in range(tp_size): shard = all_shards[rank][i] - self.assertTrue(torch.equal(shard, expected_chunks[rank]), - f"Tensor {i} Rank {rank} data mismatch") + self.assertTrue( + torch.equal(shard, expected_chunks[rank]), + f"Tensor {i} Rank {rank} data mismatch" + ) reconstruction_parts.append(shard) - + recon = torch.cat(reconstruction_parts, dim=split_dim) self.assertTrue(torch.equal(recon, original), f"Tensor {i} final reconstruction mismatch") - # Unregister buffer - self.assertEqual( - self.store.unregister_buffer(large_buffer_ptr), - 0, - "Buffer unregistration should succeed" - ) + self.assertTrue( + torch.equal(recon, original), + f"Tensor {i} final reconstruction mismatch" + ) + + # Step 4: Unregister all buffers + for large_buffer, ptr, size in registered_buffers: + res = self.store.unregister_buffer(ptr) + self.assertEqual(res, 0, f"Buffer unregistration failed for buffer at {ptr}") + # large_buffer will be GC'd after this; no need to explicitly delete # ========================================== # Performance/Benchmark Tests @@ -447,7 +536,7 @@ def setUp(self): total_bytes = self.TOTAL_SIZE_GB * 1024**3 tensor_bytes = self.TENSOR_SIZE_MB * 1024**2 self.num_tensors = max(1, total_bytes // tensor_bytes) - + print(f"\n[Gen] Generating {self.num_tensors} tensors (~{self.TENSOR_SIZE_MB}MB each)...") self.keys, self.tensors = generate_tensors(self.num_tensors, self.TENSOR_SIZE_MB) self.total_bits = (tensor_bytes * self.num_tensors) * 8 @@ -566,30 +655,40 @@ def test_benchmark_03_batch_put_get_into(self): def test_benchmark_04_batch_put_get_into_with_tp(self): """Benchmark: Zero copy Batch Get with tp.""" - buffer_spacing = 1024 * 1024 * 1024 # 1GB per tensor slot + tp_size = 4 + split_dim = 0 batch_size = len(self.keys) - total_buffer_size = buffer_spacing * batch_size - - # Allocate large contiguous buffer - large_buffer = (ctypes.c_ubyte * 
total_buffer_size)() - large_buffer_ptr = ctypes.addressof(large_buffer) + buffer_spacing = 1024 * 1024 * 1024 # 1GB per tensor slot - # Prepare pointers (only addresses, no ctypes array objects) - buffer_ptrs = [] - buffer_sizes = [] - for i in range(batch_size): - offset = i * buffer_spacing - ptr = large_buffer_ptr + offset - buffer_ptrs.append(ptr) - buffer_sizes.append(buffer_spacing) + # Allocate and register a separate buffer for each TP rank + rank_buffers = [] # Store metadata for cleanup - # Register the entire buffer with the store - res = self.store.register_buffer(large_buffer_ptr, total_buffer_size) - self.assertEqual(res, 0, "Buffer registration should succeed") + for rank in range(tp_size): + total_buffer_size = buffer_spacing * batch_size + large_buffer = (ctypes.c_ubyte * total_buffer_size)() + large_buffer_ptr = ctypes.addressof(large_buffer) + + buffer_ptrs = [] + buffer_sizes = [] + for i in range(batch_size): + offset = i * buffer_spacing + ptr = large_buffer_ptr + offset + buffer_ptrs.append(ptr) + buffer_sizes.append(buffer_spacing) + + # Register buffer for this rank + res = self.store.register_buffer(large_buffer_ptr, total_buffer_size) + self.assertEqual(res, 0, f"Buffer registration failed for rank {rank}") + + rank_buffers.append({ + 'buffer_obj': large_buffer, # keep alive + 'ptrs': buffer_ptrs, + 'sizes': buffer_sizes, + 'base_ptr': large_buffer_ptr, + 'total_size': total_buffer_size + }) - tp_size = 4 print(f"--- Running zero copy Batch Benchmark (TP={tp_size}, {self.BENCH_ITERATIONS} iters) ---") - split_dim = 0 put_times = [] get_times = [] @@ -599,31 +698,42 @@ def test_benchmark_04_batch_put_get_into_with_tp(self): # Measure Put t0 = time.perf_counter() - self.store.batch_put_tensor_with_tp(self.keys, self.tensors, tp_size=tp_size, split_dim=split_dim) + self.store.batch_put_tensor_with_tp( + self.keys, self.tensors, tp_size=tp_size, split_dim=split_dim + ) put_times.append(time.perf_counter() - t0) - # Measure Get + # Measure Get: each rank uses its own buffer t0 = time.perf_counter() for rank in range(tp_size): - res = self.store.batch_get_tensor_into_with_tp(self.keys, buffer_ptrs, buffer_sizes, tp_rank=rank, tp_size=tp_size) - self.assertEqual(len(res), len(self.tensors)) + res = self.store.batch_get_tensor_into_with_tp( + self.keys, + rank_buffers[rank]['ptrs'], + rank_buffers[rank]['sizes'], + tp_rank=rank, + tp_size=tp_size + ) + self.assertEqual(len(res), batch_size) + all_res.append(res) get_times.append(time.perf_counter() - t0) - self.assertEqual(len(res), len(self.tensors)) + + # Verify correctness using rank 0's result for j in range(batch_size): + original = self.tensors[j] + expected_shard = original.chunk(tp_size, split_dim)[0] # rank 0 shard + actual = all_res[0][j] self.assertTrue( - verify_tensor_equality(self.tensors[j], res[j]), - f"Tensor {j} content mismatch" + torch.equal(actual, expected_shard), + f"Tensor {j} content mismatch on rank 0" ) - self._print_perf("Standard Batch Put with tp (TP={tp_size})", put_times) - self._print_perf("Zero copy Batch Get with tp (TP={tp_size})", get_times) + self._print_perf(f"Standard Batch Put with tp (TP={tp_size})", put_times) + self._print_perf(f"Zero copy Batch Get with tp (TP={tp_size})", get_times) - # Unregister buffer - self.assertEqual( - self.store.unregister_buffer(large_buffer_ptr), - 0, - "Buffer unregistration should succeed" - ) + # Unregister all buffers + for buf_info in rank_buffers: + res = self.store.unregister_buffer(buf_info['base_ptr']) + self.assertEqual(res, 0, f"Buffer 
unregistration failed for buffer at {buf_info['base_ptr']}") # ========================================== # Stress/Concurrency Tests @@ -879,5 +989,5 @@ def test_fp8_types(self): runner = unittest.TextTestRunner(verbosity=2) result = runner.run(suite) - + sys.exit(not result.wasSuccessful()) From 41befb1f45a01d08f4b1e0e5d7466972645a4870 Mon Sep 17 00:00:00 2001 From: Cruz Zhao Date: Tue, 16 Dec 2025 16:25:51 +0800 Subject: [PATCH 12/24] fix testcase Signed-off-by: Cruz Zhao --- scripts/test_tensor_api.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/scripts/test_tensor_api.py b/scripts/test_tensor_api.py index 4afe0692c..b3de3160e 100644 --- a/scripts/test_tensor_api.py +++ b/scripts/test_tensor_api.py @@ -524,8 +524,8 @@ def test_08_batch_put_get_into_with_tp(self): class TestMooncakeBenchmark(MooncakeTestBase): # Benchmark Settings BENCH_ITERATIONS = 5 - TENSOR_SIZE_MB = 64 - TOTAL_SIZE_GB = 1 + TENSOR_SIZE_MB = 16 + TOTAL_SIZE_GB = 256 def setUp(self): """Benchmark-specific setUp.""" @@ -598,7 +598,7 @@ def test_benchmark_02_tp_batch(self): def test_benchmark_03_batch_put_get_into(self): """Benchmark: Zero copy Batch Get.""" - buffer_spacing = 1024 * 1024 * 1024 # 1GB per tensor slot + buffer_spacing = 512 * 1024 * 1024 # 1GB per tensor slot batch_size = len(self.keys) total_buffer_size = buffer_spacing * batch_size @@ -658,7 +658,7 @@ def test_benchmark_04_batch_put_get_into_with_tp(self): tp_size = 4 split_dim = 0 batch_size = len(self.keys) - buffer_spacing = 1024 * 1024 * 1024 # 1GB per tensor slot + buffer_spacing = 512 * 1024 * 1024 # 1GB per tensor slot # Allocate and register a separate buffer for each TP rank rank_buffers = [] # Store metadata for cleanup From ea678736ef5e36d930ea3531896874228a172634 Mon Sep 17 00:00:00 2001 From: Cruz Zhao Date: Tue, 16 Dec 2025 16:46:49 +0800 Subject: [PATCH 13/24] fix test case Signed-off-by: Cruz Zhao --- scripts/test_tensor_api.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/scripts/test_tensor_api.py b/scripts/test_tensor_api.py index b3de3160e..bb85a6268 100644 --- a/scripts/test_tensor_api.py +++ b/scripts/test_tensor_api.py @@ -525,7 +525,7 @@ class TestMooncakeBenchmark(MooncakeTestBase): # Benchmark Settings BENCH_ITERATIONS = 5 TENSOR_SIZE_MB = 16 - TOTAL_SIZE_GB = 256 + TOTAL_SIZE_MB = 256 def setUp(self): """Benchmark-specific setUp.""" @@ -533,7 +533,7 @@ def setUp(self): super().setUp() # 2. 
Generate test data - total_bytes = self.TOTAL_SIZE_GB * 1024**3 + total_bytes = int(self.TOTAL_SIZE_MB * 1024**2) tensor_bytes = self.TENSOR_SIZE_MB * 1024**2 self.num_tensors = max(1, total_bytes // tensor_bytes) From eb530ff2478c8e80c93f478dc41177ca9b1807f0 Mon Sep 17 00:00:00 2001 From: Cruz Zhao Date: Tue, 16 Dec 2025 17:53:22 +0800 Subject: [PATCH 14/24] fix test case Signed-off-by: Cruz Zhao --- scripts/test_tensor_api.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/scripts/test_tensor_api.py b/scripts/test_tensor_api.py index bb85a6268..56d1f2c47 100644 --- a/scripts/test_tensor_api.py +++ b/scripts/test_tensor_api.py @@ -304,7 +304,7 @@ def test_05_put_get_into(self): """Verify basic put and get into functionality.""" key = "get_into_test" tensor = torch.randn(1024, 1024, dtype=torch.float32) - buffer_spacing = 64 * 1024 * 1024 # 1GB per tensor slot + buffer_spacing = 64 * 1024 * 1024 total_buffer_size = buffer_spacing # Allocate large contiguous buffer @@ -334,7 +334,7 @@ def test_06_batch_put_get_into(self): """Zero copy Batch Get.""" num_tensors = 4 keys, tensors = generate_tensors(num_tensors, 8) - buffer_spacing = 1024 * 1024 * 1024 # 1GB per tensor slot + buffer_spacing = 64 * 1024 * 1024 # 1GB per tensor slot batch_size = len(keys) total_buffer_size = buffer_spacing * batch_size @@ -390,7 +390,7 @@ def test_07_put_get_into_with_tp(self): registered_buffers = [] # Keep track of (ptr, size) for cleanup for rank in range(tp_size): - buffer_spacing = 1024 * 1024 * 1024 # 1GB per tensor slot + buffer_spacing = 64 * 1024 * 1024 # 1GB per tensor slot total_buffer_size = buffer_spacing # Allocate buffer for this rank @@ -457,7 +457,7 @@ def test_08_batch_put_get_into_with_tp(self): for rank in range(tp_size): batch_size = len(keys) - buffer_spacing = 1024 * 1024 * 1024 # 1GB per tensor slot + buffer_spacing = 64 * 1024 * 1024 # 1GB per tensor slot total_buffer_size = buffer_spacing * batch_size # Allocate buffer for this rank @@ -598,7 +598,8 @@ def test_benchmark_02_tp_batch(self): def test_benchmark_03_batch_put_get_into(self): """Benchmark: Zero copy Batch Get.""" - buffer_spacing = 512 * 1024 * 1024 # 1GB per tensor slot + self.store.remove_all() + buffer_spacing = 300 * 1024 * 1024 # 1GB per tensor slot batch_size = len(self.keys) total_buffer_size = buffer_spacing * batch_size @@ -658,7 +659,8 @@ def test_benchmark_04_batch_put_get_into_with_tp(self): tp_size = 4 split_dim = 0 batch_size = len(self.keys) - buffer_spacing = 512 * 1024 * 1024 # 1GB per tensor slot + self.store.remove_all() + buffer_spacing = 300 * 1024 * 1024 # 1GB per tensor slot # Allocate and register a separate buffer for each TP rank rank_buffers = [] # Store metadata for cleanup From cfe61f0f42ddcf49fd54ead418cb5af498f1d36d Mon Sep 17 00:00:00 2001 From: Cruz Zhao Date: Wed, 17 Dec 2025 10:33:38 +0800 Subject: [PATCH 15/24] fix test case Signed-off-by: Cruz Zhao --- scripts/test_tensor_api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/test_tensor_api.py b/scripts/test_tensor_api.py index 56d1f2c47..4feeab38a 100644 --- a/scripts/test_tensor_api.py +++ b/scripts/test_tensor_api.py @@ -660,7 +660,7 @@ def test_benchmark_04_batch_put_get_into_with_tp(self): split_dim = 0 batch_size = len(self.keys) self.store.remove_all() - buffer_spacing = 300 * 1024 * 1024 # 1GB per tensor slot + buffer_spacing = 64 * 1024 * 1024 # 1GB per tensor slot # Allocate and register a separate buffer for each TP rank rank_buffers = [] # Store metadata for cleanup 
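For reference, a minimal usage sketch of the zero-copy read path that the tests in this series exercise. It condenses `test_05_put_get_into` from `scripts/test_tensor_api.py`: allocate a buffer with `ctypes`, register it, then pull the tensor straight into it with `get_tensor_into`. The store construction and the buffer size are illustrative assumptions; the real tests rely on `MooncakeTestBase` for cluster setup, which is omitted here.

```python
import ctypes

import torch
from mooncake.store import MooncakeDistributedStore

store = MooncakeDistributedStore()
# ... connection/setup omitted; see MooncakeTestBase in scripts/test_tensor_api.py ...

key = "zero_copy_example"
tensor = torch.randn(1024, 1024, dtype=torch.float32)  # ~4 MB of float32 data

# Pre-allocate a destination buffer large enough for the serialized tensor
# plus its metadata header, and register it with the store.
buffer_size = 64 * 1024 * 1024  # 64 MB slot (illustrative)
buffer = (ctypes.c_ubyte * buffer_size)()
buffer_ptr = ctypes.addressof(buffer)
assert store.register_buffer(buffer_ptr, buffer_size) == 0

# Write the tensor, then read it back directly into the registered buffer.
assert store.put_tensor(key, tensor) == 0
retrieved = store.get_tensor_into(key, buffer_ptr, buffer_size)
assert retrieved is not None and torch.equal(retrieved, tensor)

# Done with the tensor view; release the registration.
assert store.unregister_buffer(buffer_ptr) == 0
```

Because the returned tensor is built as a view over the registered buffer rather than a copy, the `ctypes` buffer object has to stay alive (and registered) for as long as the tensor is used; the tests above keep the buffer objects in a list until `unregister_buffer` for exactly this reason.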
From 1c8ad1be74a4165bf4ecfcd5bfd812b4a4159592 Mon Sep 17 00:00:00 2001
From: Cruz Zhao
Date: Wed, 17 Dec 2025 15:44:37 +0800
Subject: [PATCH 16/24] add docs

Signed-off-by: Cruz Zhao
---
 .../python-api-reference/mooncake-store.md    | 85 +++++++++++++++++++
 1 file changed, 85 insertions(+)

diff --git a/docs/source/python-api-reference/mooncake-store.md b/docs/source/python-api-reference/mooncake-store.md
index 5ecc8bbc6..e8cac8cf7 100644
--- a/docs/source/python-api-reference/mooncake-store.md
+++ b/docs/source/python-api-reference/mooncake-store.md
@@ -1004,6 +1004,91 @@ def batch_get_tensor_with_tp(self, base_keys: List[str], tp_rank: int = 0, tp_si
 
 ---
 
+### PyTorch Tensor Operations (Zero Copy)
+
+These methods retrieve PyTorch tensors from the store directly into pre-allocated, registered buffers, avoiding an extra copy on the read path. They automatically handle tensor metadata and include built-in support for **Tensor Parallelism (TP)** by fetching the shard that belongs to a given rank.
+
+⚠️ **Note**: These methods require `torch` to be installed and available in the environment.
+
+#### get_tensor_into()
+
+Get a PyTorch tensor from the store directly into a pre-allocated buffer.
+
+```python
+def get_tensor_into(self, key: str, buffer_ptr: int, size: int) -> torch.Tensor
+```
+
+**Parameters:**
+
+ - `key` (str): Identifier of the tensor to retrieve.
+ - `buffer_ptr` (int): Pointer to the pre-allocated destination buffer; the buffer must already be registered.
+ - `size` (int): Size of the buffer in bytes.
+
+**Returns:**
+
+ - `torch.Tensor`: The retrieved tensor. Returns `None` if not found.
+
+#### batch_get_tensor_into()
+
+Get a batch of PyTorch tensors from the store directly into pre-allocated buffers.
+
+```python
+def batch_get_tensor_into(self, base_keys: List[str], buffer_ptrs: List[int], sizes: List[int]) -> List[torch.Tensor]
+```
+
+**Parameters:**
+
+ - `base_keys` (List[str]): List of base identifiers.
+ - `buffer_ptrs` (List[int]): List of pointers to pre-allocated destination buffers; the buffers must already be registered.
+ - `sizes` (List[int]): List of buffer sizes in bytes.
+
+**Returns:**
+
+ - `List[torch.Tensor]`: List of retrieved tensors. Contains `None` for missing keys.
+
+#### get_tensor_into_with_tp()
+
+Get the shard of a PyTorch tensor corresponding to the given Tensor Parallel rank, directly into a pre-allocated buffer.
+
+```python
+def get_tensor_into_with_tp(self, key: str, buffer_ptr: int, size: int, tp_rank: int = 0, tp_size: int = 1, split_dim: int = 0) -> torch.Tensor
+```
+
+**Parameters:**
+
+ - `key` (str): Base identifier of the tensor.
+ - `buffer_ptr` (int): Pointer to the pre-allocated destination buffer; the buffer must already be registered.
+ - `size` (int): Size of the buffer in bytes.
+ - `tp_rank` (int): The tensor parallel rank to retrieve (default: 0). Fetches key `key_tp_{rank}` if `tp_size > 1`.
+ - `tp_size` (int): Total tensor parallel size (default: 1).
+ - `split_dim` (int): The dimension used during splitting (default: 0).
+
+**Returns:**
+
+ - `torch.Tensor`: The retrieved tensor shard. Returns `None` if not found.
+
+#### batch_get_tensor_into_with_tp()
+
+Get a batch of PyTorch tensor shards from the store for a given Tensor Parallel rank, directly into pre-allocated buffers.
+
+```python
+def batch_get_tensor_into_with_tp(self, base_keys: List[str], buffer_ptrs: List[int], sizes: List[int], tp_rank: int = 0, tp_size: int = 1) -> List[torch.Tensor]
+```
+
+**Parameters:**
+
+ - `base_keys` (List[str]): List of base identifiers.
+ - `buffer_ptrs` (List[int]): List of the buffers pointer pre-allocated for tensor, and the buffers should be registered. + - `sizes` (List[int]): List of the size of buffers. + - `tp_rank` (int): The tensor parallel rank to retrieve (default: 0). + - `tp_size` (int): Total tensor parallel size (default: 1). + +**Returns:** + + - `List[torch.Tensor]`: List of retrieved tensors (or shards). Contains `None` for missing keys. + +--- + ### Batch Zero-Copy Operations #### batch_put_from() From de71fe29141bcb06a3bd813ba240c68d8c3b059b Mon Sep 17 00:00:00 2001 From: Cruz Zhao Date: Wed, 17 Dec 2025 17:46:31 +0800 Subject: [PATCH 17/24] fix test cases Signed-off-by: Cruz Zhao --- scripts/test_tensor_api.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/scripts/test_tensor_api.py b/scripts/test_tensor_api.py index 4feeab38a..f077db056 100644 --- a/scripts/test_tensor_api.py +++ b/scripts/test_tensor_api.py @@ -6,7 +6,6 @@ import argparse import unittest import torch -import struct import numpy as np from dataclasses import dataclass from mooncake.store import MooncakeDistributedStore @@ -22,8 +21,8 @@ GLOBAL_CONFIG = None DEFAULT_MOONCAKE_CONFIG_PATH_ENV = "MOONCAKE_CONFIG_PATH" -DEFAULT_GLOBAL_SEGMENT_SIZE = 16 * 1024 * 1024 * 1024 # 4 GiB -DEFAULT_LOCAL_BUFFER_SIZE = 8 * 1024 * 1024 * 1024 # 2 GB +DEFAULT_GLOBAL_SEGMENT_SIZE = 16 * 1024 * 1024 * 1024 # 16 GiB +DEFAULT_LOCAL_BUFFER_SIZE = 8 * 1024 * 1024 * 1024 # 8 GB DEFAULT_MASTER_METRICS_PORT = 9003 DEFAULT_CHECK_SERVER = False @@ -45,7 +44,7 @@ def verify_tensor_equality(original, received, rtol=0, atol=0, verbose=True): """ - compare two tensors。 + compare two tensors. """ def to_numpy(x): if isinstance(x, torch.Tensor): @@ -334,7 +333,7 @@ def test_06_batch_put_get_into(self): """Zero copy Batch Get.""" num_tensors = 4 keys, tensors = generate_tensors(num_tensors, 8) - buffer_spacing = 64 * 1024 * 1024 # 1GB per tensor slot + buffer_spacing = 64 * 1024 * 1024 # 64MB per tensor slot batch_size = len(keys) total_buffer_size = buffer_spacing * batch_size @@ -390,7 +389,7 @@ def test_07_put_get_into_with_tp(self): registered_buffers = [] # Keep track of (ptr, size) for cleanup for rank in range(tp_size): - buffer_spacing = 64 * 1024 * 1024 # 1GB per tensor slot + buffer_spacing = 64 * 1024 * 1024 # 64MB per tensor slot total_buffer_size = buffer_spacing # Allocate buffer for this rank @@ -457,7 +456,7 @@ def test_08_batch_put_get_into_with_tp(self): for rank in range(tp_size): batch_size = len(keys) - buffer_spacing = 64 * 1024 * 1024 # 1GB per tensor slot + buffer_spacing = 64 * 1024 * 1024 # 64MB per tensor slot total_buffer_size = buffer_spacing * batch_size # Allocate buffer for this rank @@ -599,7 +598,7 @@ def test_benchmark_02_tp_batch(self): def test_benchmark_03_batch_put_get_into(self): """Benchmark: Zero copy Batch Get.""" self.store.remove_all() - buffer_spacing = 300 * 1024 * 1024 # 1GB per tensor slot + buffer_spacing = 300 * 1024 * 1024 # 300MB per tensor slot batch_size = len(self.keys) total_buffer_size = buffer_spacing * batch_size @@ -660,7 +659,7 @@ def test_benchmark_04_batch_put_get_into_with_tp(self): split_dim = 0 batch_size = len(self.keys) self.store.remove_all() - buffer_spacing = 64 * 1024 * 1024 # 1GB per tensor slot + buffer_spacing = 64 * 1024 * 1024 # 64MB per tensor slot # Allocate and register a separate buffer for each TP rank rank_buffers = [] # Store metadata for cleanup @@ -707,6 +706,7 @@ def test_benchmark_04_batch_put_get_into_with_tp(self): # Measure Get: each rank 
uses its own buffer t0 = time.perf_counter() + all_res = [] for rank in range(tp_size): res = self.store.batch_get_tensor_into_with_tp( self.keys, From 3ac809a0a3d8acf95102b62201b76857389bb59c Mon Sep 17 00:00:00 2001 From: Cruz Zhao Date: Wed, 17 Dec 2025 17:48:26 +0800 Subject: [PATCH 18/24] rename Signed-off-by: Cruz Zhao --- mooncake-integration/integration_utils.h | 44 ++++++++++++------------ mooncake-integration/store/store_py.cpp | 4 +-- 2 files changed, 24 insertions(+), 24 deletions(-) diff --git a/mooncake-integration/integration_utils.h b/mooncake-integration/integration_utils.h index c0c8bba24..8309787a6 100644 --- a/mooncake-integration/integration_utils.h +++ b/mooncake-integration/integration_utils.h @@ -63,32 +63,32 @@ static const std::array array_creators = {{ }}; template -py::array create_typed_array_without_free(char *data_ptr, size_t offset, - size_t total_length) { +py::array create_typed_array_view(char *data_ptr, size_t offset, + size_t total_length) { return py::array_t({static_cast(total_length / sizeof(T))}, (T *)(data_ptr + offset), py::none()); } -static const std::array array_creators_without_free = {{ - create_typed_array_without_free, // FLOAT32 = 0 - create_typed_array_without_free, // FLOAT64 = 1 - create_typed_array_without_free, // INT8 = 2 - create_typed_array_without_free, // UINT8 = 3 - create_typed_array_without_free, // INT16 = 4 - create_typed_array_without_free, // UINT16 = 5 - create_typed_array_without_free, // INT32 = 6 - create_typed_array_without_free, // UINT32 = 7 - create_typed_array_without_free, // INT64 = 8 - create_typed_array_without_free, // UINT64 = 9 - create_typed_array_without_free, // BOOL = 10 - create_typed_array_without_free, // FLOAT16 = 11 (using uint16_t - // as storage) - create_typed_array_without_free, // BFLOAT16 = 12 (using uint16_t - // as storage) - create_typed_array_without_free, // FLOAT8_E4M3 = 13 (using - // uint8_t as storage) - create_typed_array_without_free, // FLOAT8_E5M2 = 14 (using - // uint8_t as storage) +static const std::array array_creators_view = {{ + create_typed_array_view, // FLOAT32 = 0 + create_typed_array_view, // FLOAT64 = 1 + create_typed_array_view, // INT8 = 2 + create_typed_array_view, // UINT8 = 3 + create_typed_array_view, // INT16 = 4 + create_typed_array_view, // UINT16 = 5 + create_typed_array_view, // INT32 = 6 + create_typed_array_view, // UINT32 = 7 + create_typed_array_view, // INT64 = 8 + create_typed_array_view, // UINT64 = 9 + create_typed_array_view, // BOOL = 10 + create_typed_array_view, // FLOAT16 = 11 (using uint16_t as + // storage) + create_typed_array_view, // BFLOAT16 = 12 (using uint16_t as + // storage) + create_typed_array_view, // FLOAT8_E4M3 = 13 (using uint8_t as + // storage) + create_typed_array_view, // FLOAT8_E5M2 = 14 (using uint8_t as + // storage) }}; inline TensorDtype get_tensor_dtype(py::object dtype_obj) { diff --git a/mooncake-integration/store/store_py.cpp b/mooncake-integration/store/store_py.cpp index 9eb885393..ba121a01b 100644 --- a/mooncake-integration/store/store_py.cpp +++ b/mooncake-integration/store/store_py.cpp @@ -373,7 +373,7 @@ class MooncakeStorePyWrapper { int dtype_index = static_cast(dtype_enum); if (dtype_index >= 0 && dtype_index < static_cast(array_creators.size())) { - np_array = array_creators_without_free[dtype_index]( + np_array = array_creators_view[dtype_index]( static_cast(buffer), sizeof(TensorMetadata), tensor_size); } else { @@ -493,7 +493,7 @@ class MooncakeStorePyWrapper { if (dtype_index >= 0 && dtype_index < 
static_cast(array_creators.size())) { // This call MUST take ownership of exported_data - np_array = array_creators_without_free[dtype_index]( + np_array = array_creators_view[dtype_index]( static_cast(buffer), sizeof(TensorMetadata), tensor_size); } else { From 68db9164cb1b20fc198800271fa52a1aeae88fd3 Mon Sep 17 00:00:00 2001 From: Cruz Zhao Date: Wed, 17 Dec 2025 22:40:55 +0800 Subject: [PATCH 19/24] add testcases Signed-off-by: Cruz Zhao --- scripts/test_tensor_api.py | 47 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 47 insertions(+) diff --git a/scripts/test_tensor_api.py b/scripts/test_tensor_api.py index f077db056..66e8d8813 100644 --- a/scripts/test_tensor_api.py +++ b/scripts/test_tensor_api.py @@ -299,6 +299,20 @@ def test_04_tp_consistency(self): self.assertTrue(tmp_tensor_0.sum() == chunked_tensors[0].sum()) self.assertTrue(tmp_tensor_1.sum() == chunked_tensors[1].sum()) + buffer_spacing = 1 * 1024 * 1024 + buffer_2 = (ctypes.c_ubyte * buffer_spacing)() + buffer_3 = (ctypes.c_ubyte * buffer_spacing)() + buffer_ptr_2 = ctypes.addressof(buffer_2) + buffer_ptr_3 = ctypes.addressof(buffer_3) + self.store.register_buffer(buffer_ptr_2, buffer_spacing) + self.store.register_buffer(buffer_ptr_3, buffer_spacing) + tmp_tensor_2 = self.store.batch_get_tensor_into_with_tp(['key'], [buffer_ptr_2], [buffer_spacing], tp_rank=0, tp_size=tp_size)[0] + tmp_tensor_3 = self.store.batch_get_tensor_into_with_tp(['key'], [buffer_ptr_3], [buffer_spacing], tp_rank=1, tp_size=tp_size)[0] + self.assertTrue(tmp_tensor_2.sum() == chunked_tensors[0].sum()) + self.assertTrue(tmp_tensor_3.sum() == chunked_tensors[1].sum()) + self.store.unregister_buffer(buffer_ptr_2) + self.store.unregister_buffer(buffer_ptr_3) + def test_05_put_get_into(self): """Verify basic put and get into functionality.""" key = "get_into_test" @@ -915,6 +929,39 @@ def _test_dtype_roundtrip(self, dtype, name, expected_enum_name=None): print(f" [Pass] {name:<15} {str(dtype)}") + buffer_spacing = 1 * 1024 * 1024 + buffer = (ctypes.c_ubyte * buffer_spacing)() + buffer_ptr = ctypes.addressof(buffer) + self.store.register_buffer(buffer_ptr, buffer_spacing) + retrieved = self.store.get_tensor_into(key, buffer_ptr, buffer_spacing) + if retrieved is None: + print(f" [Fail] {name:<15} Get returned None") + self.fail(f"Get returned None for {name}") + + # We expect the retrieved tensor to have the same dtype as input + if original.dtype != retrieved.dtype: + msg = f"Dtype mismatch for {name}! 
Input: {original.dtype}, Output: {retrieved.dtype}" + print(f" [Fail] {name:<15} {msg}") + self.fail(msg) + + # Use byte-view comparison for robustness (especially for FP8/BF16 on CPU) + try: + # Cast to untyped storage byte view (or uint8 view) + t1_bytes = original.view(torch.uint8) if original.element_size() > 0 else original + t2_bytes = retrieved.view(torch.uint8) if retrieved.element_size() > 0 else retrieved + is_equal = torch.equal(t1_bytes, t2_bytes) + except Exception: + # Fallback for types that might fail view() or equal() + is_equal = torch.equal(original.cpu(), retrieved.cpu()) + + if not is_equal: + print(f" [Fail] {name:<15} Data content mismatch") + self.fail(f"Data content mismatch for {name}") + self.store.unregister_buffer(buffer_ptr) + + print(f" [Pass] {name:<15} {str(dtype)}") + + def test_all_dtypes(self): print("\n--- Testing All Supported PyTorch Data Types ---") From 3cfa936e51ec2e4ba8e3d038cca72cccbe4b48c4 Mon Sep 17 00:00:00 2001 From: Cruz Zhao Date: Wed, 17 Dec 2025 22:41:13 +0800 Subject: [PATCH 20/24] fix data type Signed-off-by: Cruz Zhao --- mooncake-integration/store/store_py.cpp | 24 ++++++++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) diff --git a/mooncake-integration/store/store_py.cpp b/mooncake-integration/store/store_py.cpp index ba121a01b..5e3318478 100644 --- a/mooncake-integration/store/store_py.cpp +++ b/mooncake-integration/store/store_py.cpp @@ -391,6 +391,16 @@ class MooncakeStorePyWrapper { } pybind11::object tensor = torch_module().attr("from_numpy")(np_array); + // Handle BFloat16/Float16 view checks + if (dtype_enum == TensorDtype::BFLOAT16) { + tensor = tensor.attr("view")(torch_module().attr("bfloat16")); + } else if (dtype_enum == TensorDtype::FLOAT16) { + tensor = tensor.attr("view")(torch_module().attr("float16")); + } else if (dtype_enum == TensorDtype::FLOAT8_E4M3) { + tensor = tensor.attr("view")(torch_module().attr("float8_e4m3fn")); + } else if (dtype_enum == TensorDtype::FLOAT8_E5M2) { + tensor = tensor.attr("view")(torch_module().attr("float8_e5m2")); + } return tensor; } catch (const pybind11::error_already_set &e) { @@ -511,8 +521,18 @@ class MooncakeStorePyWrapper { np_array = np_array.attr("reshape")(shape_tuple); } pybind11::object tensor = torch.attr("from_numpy")(np_array); - results_list.append(tensor); - } + // Handle BFloat16/Float16 view checks + if (dtype_enum == TensorDtype::BFLOAT16) { + tensor = tensor.attr("view")(torch_module().attr("bfloat16")); + } else if (dtype_enum == TensorDtype::FLOAT16) { + tensor = tensor.attr("view")(torch_module().attr("float16")); + } else if (dtype_enum == TensorDtype::FLOAT8_E4M3) { + tensor = tensor.attr("view")(torch_module().attr("float8_e4m3fn")); + } else if (dtype_enum == TensorDtype::FLOAT8_E5M2) { + tensor = tensor.attr("view")(torch_module().attr("float8_e5m2")); + } + results_list.append(tensor); + } } catch (const pybind11::error_already_set &e) { LOG(ERROR) << "Failed during batch tensor deserialization: " << e.what(); From 6b80b412127e929dec405f27d03853ae5fc03cc2 Mon Sep 17 00:00:00 2001 From: Cruz Zhao Date: Wed, 17 Dec 2025 22:42:22 +0800 Subject: [PATCH 21/24] test case Signed-off-by: Cruz Zhao --- scripts/test_tensor_api.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/scripts/test_tensor_api.py b/scripts/test_tensor_api.py index 66e8d8813..867c60740 100644 --- a/scripts/test_tensor_api.py +++ b/scripts/test_tensor_api.py @@ -927,8 +927,6 @@ def _test_dtype_roundtrip(self, dtype, name, expected_enum_name=None): print(f" [Fail] 
{name:<15} Data content mismatch") self.fail(f"Data content mismatch for {name}") - print(f" [Pass] {name:<15} {str(dtype)}") - buffer_spacing = 1 * 1024 * 1024 buffer = (ctypes.c_ubyte * buffer_spacing)() buffer_ptr = ctypes.addressof(buffer) From e21faca01f7ba881bdc556a50522d55781b460b8 Mon Sep 17 00:00:00 2001 From: Cruz Zhao Date: Thu, 18 Dec 2025 10:06:47 +0800 Subject: [PATCH 22/24] fix format Signed-off-by: Cruz Zhao --- mooncake-integration/store/store_py.cpp | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/mooncake-integration/store/store_py.cpp b/mooncake-integration/store/store_py.cpp index 5e3318478..74686b604 100644 --- a/mooncake-integration/store/store_py.cpp +++ b/mooncake-integration/store/store_py.cpp @@ -397,9 +397,11 @@ class MooncakeStorePyWrapper { } else if (dtype_enum == TensorDtype::FLOAT16) { tensor = tensor.attr("view")(torch_module().attr("float16")); } else if (dtype_enum == TensorDtype::FLOAT8_E4M3) { - tensor = tensor.attr("view")(torch_module().attr("float8_e4m3fn")); + tensor = + tensor.attr("view")(torch_module().attr("float8_e4m3fn")); } else if (dtype_enum == TensorDtype::FLOAT8_E5M2) { - tensor = tensor.attr("view")(torch_module().attr("float8_e5m2")); + tensor = + tensor.attr("view")(torch_module().attr("float8_e5m2")); } return tensor; @@ -523,16 +525,20 @@ class MooncakeStorePyWrapper { pybind11::object tensor = torch.attr("from_numpy")(np_array); // Handle BFloat16/Float16 view checks if (dtype_enum == TensorDtype::BFLOAT16) { - tensor = tensor.attr("view")(torch_module().attr("bfloat16")); + tensor = + tensor.attr("view")(torch_module().attr("bfloat16")); } else if (dtype_enum == TensorDtype::FLOAT16) { - tensor = tensor.attr("view")(torch_module().attr("float16")); + tensor = + tensor.attr("view")(torch_module().attr("float16")); } else if (dtype_enum == TensorDtype::FLOAT8_E4M3) { - tensor = tensor.attr("view")(torch_module().attr("float8_e4m3fn")); + tensor = tensor.attr("view")( + torch_module().attr("float8_e4m3fn")); } else if (dtype_enum == TensorDtype::FLOAT8_E5M2) { - tensor = tensor.attr("view")(torch_module().attr("float8_e5m2")); - } - results_list.append(tensor); + tensor = + tensor.attr("view")(torch_module().attr("float8_e5m2")); } + results_list.append(tensor); + } } catch (const pybind11::error_already_set &e) { LOG(ERROR) << "Failed during batch tensor deserialization: " << e.what(); From 3e63ac1b3587819cd4379d02ee1ac843fa3888ed Mon Sep 17 00:00:00 2001 From: Xuchun Shang Date: Thu, 18 Dec 2025 11:29:46 +0800 Subject: [PATCH 23/24] Apply suggestions from code review Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- scripts/test_tensor_api.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/scripts/test_tensor_api.py b/scripts/test_tensor_api.py index 867c60740..1f410c2bc 100644 --- a/scripts/test_tensor_api.py +++ b/scripts/test_tensor_api.py @@ -75,8 +75,7 @@ def to_numpy(x): return False if np.array_equal(orig_np, recv_np): -# if verbose: -# print("✅ Tensors are identical!") + return True else: diff_mask = orig_np != recv_np From 40f1bc6b0dc5b3ee7c69b659c93b56e98e142637 Mon Sep 17 00:00:00 2001 From: Xuchun Shang Date: Thu, 18 Dec 2025 11:38:14 +0800 Subject: [PATCH 24/24] Apply suggestion from @Copilot Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- mooncake-integration/store/store_py.cpp | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/mooncake-integration/store/store_py.cpp 
b/mooncake-integration/store/store_py.cpp index 74686b604..9bbc12255 100644 --- a/mooncake-integration/store/store_py.cpp +++ b/mooncake-integration/store/store_py.cpp @@ -1224,9 +1224,7 @@ PYBIND11_MODULE(store, m) { .def("batch_get_tensor_into", &MooncakeStorePyWrapper::batch_get_tensor_into, py::arg("keys"), py::arg("buffer_ptrs"), py::arg("sizes"), - "Get tensors directly into pre-allocated buffers for " - "multiple " - "keys") + "Get tensors directly into pre-allocated buffers for multiple keys") .def( "get_tensor_into_with_tp", &MooncakeStorePyWrapper::get_tensor_into_with_tp, py::arg("key"),
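Finally, a sketch of how the TP-aware batch variant bound above can be driven from Python. It condenses `test_08_batch_put_get_into_with_tp`: each rank registers its own buffer region and fetches its shards with `batch_get_tensor_into_with_tp`, and concatenating the shards along `split_dim` reproduces the original tensors. The store construction and the buffer sizes are illustrative assumptions, not part of this series.

```python
import ctypes

import torch
from mooncake.store import MooncakeDistributedStore

store = MooncakeDistributedStore()
# ... connection/setup omitted; see MooncakeTestBase in scripts/test_tensor_api.py ...

tp_size, split_dim = 4, 0
keys = [f"tp_batch_example_{i}" for i in range(4)]
tensors = [torch.randn(1024, 1024, dtype=torch.float32) for _ in keys]

# Store the full tensors once; the store splits each into tp_size shards.
results = store.batch_put_tensor_with_tp(keys, tensors,
                                         tp_size=tp_size, split_dim=split_dim)
assert all(r == 0 for r in results)

spacing = 64 * 1024 * 1024  # one 64 MB slot per key (illustrative)
buffers, all_shards = [], []
for rank in range(tp_size):
    buf = (ctypes.c_ubyte * (spacing * len(keys)))()
    base = ctypes.addressof(buf)
    assert store.register_buffer(base, spacing * len(keys)) == 0
    buffers.append((buf, base))  # keep the ctypes object alive while shards are used

    ptrs = [base + i * spacing for i in range(len(keys))]
    sizes = [spacing] * len(keys)
    all_shards.append(store.batch_get_tensor_into_with_tp(
        keys, ptrs, sizes, tp_rank=rank, tp_size=tp_size))

# Shards gathered from all ranks concatenate back to the original tensors.
for i, original in enumerate(tensors):
    recon = torch.cat([all_shards[r][i] for r in range(tp_size)], dim=split_dim)
    assert torch.equal(recon, original)

for _, base in buffers:
    assert store.unregister_buffer(base) == 0
```

As in the benchmark test, keeping one registered buffer per rank avoids any aliasing between ranks and lets each rank unregister independently once its shards are no longer needed.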