diff --git a/docs/source/python-api-reference/mooncake-store.md b/docs/source/python-api-reference/mooncake-store.md
index 5ecc8bbc6..e8cac8cf7 100644
--- a/docs/source/python-api-reference/mooncake-store.md
+++ b/docs/source/python-api-reference/mooncake-store.md
@@ -1004,6 +1004,91 @@ def batch_get_tensor_with_tp(self, base_keys: List[str], tp_rank: int = 0, tp_si
 ---
+
+### PyTorch Tensor Operations (Zero Copy)
+
+These methods retrieve PyTorch tensors from the store directly into pre-allocated, registered buffers, avoiding an extra copy. They handle serialization metadata automatically and include built-in support for **Tensor Parallelism (TP)** by fetching the shard that belongs to a given rank.
+
+⚠️ **Note**: These methods require `torch` to be installed and available in the environment, and the destination buffers must be registered with `register_buffer()` beforehand.
+
+#### get_tensor_into()
+
+Get a PyTorch tensor from the store directly into a pre-allocated buffer.
+
+```python
+def get_tensor_into(self, key: str, buffer_ptr: int, size: int) -> torch.Tensor
+```
+
+**Parameters:**
+
+ - `key` (str): Base identifier of the tensor.
+ - `buffer_ptr` (int): Pointer to the pre-allocated buffer that receives the tensor; the buffer must be registered.
+ - `size` (int): Size of the buffer in bytes.
+
+**Returns:**
+
+ - `torch.Tensor`: The retrieved tensor. Returns `None` if not found.
+
+#### batch_get_tensor_into()
+
+Get a batch of PyTorch tensors from the store directly into pre-allocated buffers.
+
+```python
+def batch_get_tensor_into(self, keys: List[str], buffer_ptrs: List[int], sizes: List[int]) -> List[torch.Tensor]
+```
+
+**Parameters:**
+
+ - `keys` (List[str]): List of base identifiers.
+ - `buffer_ptrs` (List[int]): Pointers to the pre-allocated buffers, one per key; the underlying memory must be registered.
+ - `sizes` (List[int]): Sizes of the corresponding buffers in bytes.
+
+**Returns:**
+
+ - `List[torch.Tensor]`: List of retrieved tensors. Contains `None` for missing keys.
+
+#### get_tensor_into_with_tp()
+
+Get the tensor shard corresponding to a given Tensor Parallel rank from the store, directly into a pre-allocated buffer.
+
+```python
+def get_tensor_into_with_tp(self, key: str, buffer_ptr: int, size: int, tp_rank: int = 0, tp_size: int = 1, split_dim: int = 0) -> torch.Tensor
+```
+
+**Parameters:**
+
+ - `key` (str): Base identifier of the tensor.
+ - `buffer_ptr` (int): Pointer to the pre-allocated buffer that receives the shard; the buffer must be registered.
+ - `size` (int): Size of the buffer in bytes.
+ - `tp_rank` (int): The tensor parallel rank to retrieve (default: 0). Fetches key `key_tp_{rank}` if `tp_size > 1`.
+ - `tp_size` (int): Total tensor parallel size (default: 1).
+ - `split_dim` (int): The dimension used during splitting (default: 0).
+
+**Returns:**
+
+ - `torch.Tensor`: The retrieved tensor (or shard). Returns `None` if not found.
+
+#### batch_get_tensor_into_with_tp()
+
+Get a batch of PyTorch tensor shards for a given Tensor Parallel rank from the store, directly into pre-allocated buffers.
+
+```python
+def batch_get_tensor_into_with_tp(self, base_keys: List[str], buffer_ptrs: List[int], sizes: List[int], tp_rank: int = 0, tp_size: int = 1) -> List[torch.Tensor]
+```
+
+**Parameters:**
+
+ - `base_keys` (List[str]): List of base identifiers.
+ - `buffer_ptrs` (List[int]): Pointers to the pre-allocated buffers, one per key; the underlying memory must be registered.
+ - `sizes` (List[int]): Sizes of the corresponding buffers in bytes.
+ - `tp_rank` (int): The tensor parallel rank to retrieve (default: 0).
+ - `tp_size` (int): Total tensor parallel size (default: 1). + +**Returns:** + + - `List[torch.Tensor]`: List of retrieved tensors (or shards). Contains `None` for missing keys. + +--- + ### Batch Zero-Copy Operations #### batch_put_from() diff --git a/mooncake-integration/integration_utils.h b/mooncake-integration/integration_utils.h index dfea8638a..8309787a6 100644 --- a/mooncake-integration/integration_utils.h +++ b/mooncake-integration/integration_utils.h @@ -62,6 +62,35 @@ static const std::array array_creators = {{ create_typed_array, // FLOAT8_E5M2 = 14 (using uint8_t as storage) }}; +template +py::array create_typed_array_view(char *data_ptr, size_t offset, + size_t total_length) { + return py::array_t({static_cast(total_length / sizeof(T))}, + (T *)(data_ptr + offset), py::none()); +} + +static const std::array array_creators_view = {{ + create_typed_array_view, // FLOAT32 = 0 + create_typed_array_view, // FLOAT64 = 1 + create_typed_array_view, // INT8 = 2 + create_typed_array_view, // UINT8 = 3 + create_typed_array_view, // INT16 = 4 + create_typed_array_view, // UINT16 = 5 + create_typed_array_view, // INT32 = 6 + create_typed_array_view, // UINT32 = 7 + create_typed_array_view, // INT64 = 8 + create_typed_array_view, // UINT64 = 9 + create_typed_array_view, // BOOL = 10 + create_typed_array_view, // FLOAT16 = 11 (using uint16_t as + // storage) + create_typed_array_view, // BFLOAT16 = 12 (using uint16_t as + // storage) + create_typed_array_view, // FLOAT8_E4M3 = 13 (using uint8_t as + // storage) + create_typed_array_view, // FLOAT8_E5M2 = 14 (using uint8_t as + // storage) +}}; + inline TensorDtype get_tensor_dtype(py::object dtype_obj) { if (dtype_obj.is_none()) { return TensorDtype::UNKNOWN; diff --git a/mooncake-integration/store/store_py.cpp b/mooncake-integration/store/store_py.cpp index 4b0aaf28b..9bbc12255 100644 --- a/mooncake-integration/store/store_py.cpp +++ b/mooncake-integration/store/store_py.cpp @@ -320,6 +320,297 @@ class MooncakeStorePyWrapper { return results_list; } + pybind11::object get_tensor_into(const std::string &key, + uintptr_t buffer_ptr, size_t size) { + void *buffer = reinterpret_cast(buffer_ptr); + if (!is_client_initialized()) { + LOG(ERROR) << "Client is not initialized"; + return pybind11::none(); + } + + if (use_dummy_client_) { + LOG(ERROR) << "get_tensor is not supported for dummy client now"; + return pybind11::none(); + } + + try { + // Section with GIL released + py::gil_scoped_release release_gil; + auto total_length = store_->get_into(key, buffer, size); + if (total_length <= 0) { + py::gil_scoped_acquire acquire_gil; + return pybind11::none(); + } + + TensorMetadata metadata; + // Copy data from buffer to contiguous memory + memcpy(&metadata, static_cast(buffer), + sizeof(TensorMetadata)); + + if (metadata.ndim < 0 || metadata.ndim > 4) { + py::gil_scoped_acquire acquire_gil; + LOG(ERROR) << "Invalid tensor metadata: ndim=" << metadata.ndim; + return pybind11::none(); + } + + TensorDtype dtype_enum = static_cast(metadata.dtype); + if (dtype_enum == TensorDtype::UNKNOWN) { + py::gil_scoped_acquire acquire_gil; + LOG(ERROR) << "Unknown tensor dtype!"; + return pybind11::none(); + } + + size_t tensor_size = total_length - sizeof(TensorMetadata); + if (tensor_size == 0) { + py::gil_scoped_acquire acquire_gil; + LOG(ERROR) << "Invalid data format: no tensor data found"; + return pybind11::none(); + } + + py::gil_scoped_acquire acquire_gil; + // Convert bytes to tensor using torch.from_numpy + pybind11::object np_array; + int dtype_index 
= static_cast(dtype_enum); + if (dtype_index >= 0 && + dtype_index < static_cast(array_creators.size())) { + np_array = array_creators_view[dtype_index]( + static_cast(buffer), sizeof(TensorMetadata), + tensor_size); + } else { + LOG(ERROR) << "Unsupported dtype enum: " << dtype_index; + return pybind11::none(); + } + + if (metadata.ndim > 0) { + std::vector shape_vec; + for (int i = 0; i < metadata.ndim; i++) { + shape_vec.push_back(metadata.shape[i]); + } + py::tuple shape_tuple = py::cast(shape_vec); + np_array = np_array.attr("reshape")(shape_tuple); + } + pybind11::object tensor = + torch_module().attr("from_numpy")(np_array); + // Handle BFloat16/Float16 view checks + if (dtype_enum == TensorDtype::BFLOAT16) { + tensor = tensor.attr("view")(torch_module().attr("bfloat16")); + } else if (dtype_enum == TensorDtype::FLOAT16) { + tensor = tensor.attr("view")(torch_module().attr("float16")); + } else if (dtype_enum == TensorDtype::FLOAT8_E4M3) { + tensor = + tensor.attr("view")(torch_module().attr("float8_e4m3fn")); + } else if (dtype_enum == TensorDtype::FLOAT8_E5M2) { + tensor = + tensor.attr("view")(torch_module().attr("float8_e5m2")); + } + return tensor; + + } catch (const pybind11::error_already_set &e) { + LOG(ERROR) << "Failed to get tensor data: " << e.what(); + return pybind11::none(); + } + } + + pybind11::list batch_get_tensor_into( + const std::vector &keys, + const std::vector &buffer_ptrs, + const std::vector &sizes) { + std::vector buffers; + buffers.reserve(buffer_ptrs.size()); + for (uintptr_t ptr : buffer_ptrs) { + buffers.push_back(reinterpret_cast(ptr)); + } + + if (!is_client_initialized()) { + LOG(ERROR) << "Client is not initialized"; + py::list empty_list; + for (size_t i = 0; i < keys.size(); ++i) { + empty_list.append(to_py_ret(ErrorCode::INVALID_PARAMS)); + } + return empty_list; + } + + if (use_dummy_client_) { + LOG(ERROR) << "batch_get_tensor is not supported for dummy client " + "now"; + py::list empty_list; + for (size_t i = 0; i < keys.size(); ++i) { + empty_list.append(to_py_ret(ErrorCode::INVALID_PARAMS)); + } + return empty_list; + } + + // Phase 1: Batch Get Buffers (GIL Released) + std::vector total_lengths; + { + py::gil_scoped_release release_gil; + // This internal call already handles logging for query failures + total_lengths = store_->batch_get_into(keys, buffers, sizes); + } + + py::list results_list; + try { + py::gil_scoped_acquire acquire_gil; + auto torch = torch_module(); + + for (size_t i = 0; i < total_lengths.size(); i++) { + const auto &buffer = buffers[i]; + if (!buffer) { + results_list.append(to_py_ret(ErrorCode::INVALID_PARAMS)); + continue; + } + + auto total_length = total_lengths[i]; + if (total_length <= 0) { + LOG(ERROR) << "Invalid data format: insufficient data for" + "metadata"; + results_list.append(to_py_ret(ErrorCode::INVALID_PARAMS)); + continue; + } + if (total_length <= static_cast(sizeof(TensorMetadata))) { + LOG(ERROR) << "Invalid data format: insufficient data for " + "metadata"; + results_list.append(to_py_ret(ErrorCode::INVALID_PARAMS)); + continue; + } + + TensorMetadata metadata; + memcpy(&metadata, static_cast(buffer), + sizeof(TensorMetadata)); + + if (metadata.ndim < 0 || metadata.ndim > 4) { + LOG(ERROR) + << "Invalid tensor metadata: ndim=" << metadata.ndim; + results_list.append(to_py_ret(ErrorCode::INVALID_PARAMS)); + continue; + } + + TensorDtype dtype_enum = + static_cast(metadata.dtype); + if (dtype_enum == TensorDtype::UNKNOWN) { + LOG(ERROR) << "Unknown tensor dtype!"; + 
results_list.append(to_py_ret(ErrorCode::INVALID_PARAMS)); + continue; + } + + size_t tensor_size = total_length - sizeof(TensorMetadata); + if (tensor_size == 0) { + LOG(ERROR) << "Invalid data format: no tensor data found"; + results_list.append(to_py_ret(ErrorCode::INVALID_PARAMS)); + continue; + } + + pybind11::object np_array; + int dtype_index = static_cast(dtype_enum); + if (dtype_index >= 0 && + dtype_index < static_cast(array_creators.size())) { + // This call MUST take ownership of exported_data + np_array = array_creators_view[dtype_index]( + static_cast(buffer), sizeof(TensorMetadata), + tensor_size); + } else { + LOG(ERROR) << "Unsupported dtype enum: " << dtype_index; + results_list.append(to_py_ret(ErrorCode::INVALID_PARAMS)); + continue; + } + + if (metadata.ndim > 0) { + std::vector shape_vec; + for (int i = 0; i < metadata.ndim; i++) { + shape_vec.push_back(metadata.shape[i]); + } + py::tuple shape_tuple = py::cast(shape_vec); + np_array = np_array.attr("reshape")(shape_tuple); + } + pybind11::object tensor = torch.attr("from_numpy")(np_array); + // Handle BFloat16/Float16 view checks + if (dtype_enum == TensorDtype::BFLOAT16) { + tensor = + tensor.attr("view")(torch_module().attr("bfloat16")); + } else if (dtype_enum == TensorDtype::FLOAT16) { + tensor = + tensor.attr("view")(torch_module().attr("float16")); + } else if (dtype_enum == TensorDtype::FLOAT8_E4M3) { + tensor = tensor.attr("view")( + torch_module().attr("float8_e4m3fn")); + } else if (dtype_enum == TensorDtype::FLOAT8_E5M2) { + tensor = + tensor.attr("view")(torch_module().attr("float8_e5m2")); + } + results_list.append(tensor); + } + } catch (const pybind11::error_already_set &e) { + LOG(ERROR) << "Failed during batch tensor deserialization: " + << e.what(); + } + return results_list; + } + + pybind11::object get_tensor_into_with_tp(const std::string &key, + uintptr_t buffer_ptr, size_t size, + int tp_rank = 0, int tp_size = 1, + int split_dim = 0) { + if (!is_client_initialized()) { + LOG(ERROR) << "Client is not initialized"; + return pybind11::none(); + } + + if (use_dummy_client_) { + LOG(ERROR) + << "get_tensor_into is not supported for dummy client now"; + return pybind11::none(); + } + + if (tp_size <= 1) { + return get_tensor_into(key, buffer_ptr, size); + } + + // Construct the specific key for this rank: e.g., "key_tp_0" + std::string tp_key = get_tp_key_name(key, tp_rank); + + // Delegate to the standard get_tensor_into method + return get_tensor_into(tp_key, buffer_ptr, size); + } + + pybind11::list batch_get_tensor_into_with_tp( + const std::vector &base_keys, + const std::vector &buffer_ptrs, + const std::vector &sizes, int tp_rank = 0, int tp_size = 1) { + if (!is_client_initialized()) { + LOG(ERROR) << "Client is not initialized"; + py::list empty_list; + for (size_t i = 0; i < base_keys.size(); ++i) { + empty_list.append(py::none()); + } + return empty_list; + } + + if (use_dummy_client_) { + LOG(ERROR) << "batch_get_tensor_into_with_tp is not supported for " + "dummy client"; + py::list empty_list; + for (size_t i = 0; i < base_keys.size(); ++i) { + empty_list.append(py::none()); + } + return empty_list; + } + + // If tp_size is 1, it's just a normal batch_get_tensor_into + if (tp_size <= 1) { + return batch_get_tensor_into(base_keys, buffer_ptrs, sizes); + } + + // Generate the specific shard keys for the given tp_rank + std::vector shard_keys; + shard_keys.reserve(base_keys.size()); + for (const auto &key : base_keys) { + shard_keys.push_back(get_tp_key_name(key, tp_rank)); + } + + // Use 
the existing batch_get_tensor_into to fetch all shards at once + return batch_get_tensor_into(shard_keys, buffer_ptrs, sizes); + } + int put_tensor_impl(const std::string &key, pybind11::object tensor, const ReplicateConfig &config) { // Validation & Metadata extraction (GIL Held) @@ -927,6 +1218,33 @@ PYBIND11_MODULE(store, m) { .def("pub_tensor", &MooncakeStorePyWrapper::pub_tensor, py::arg("key"), py::arg("tensor"), py::arg("config") = ReplicateConfig{}, "Publish a PyTorch tensor with configurable replication settings") + .def("get_tensor_into", &MooncakeStorePyWrapper::get_tensor_into, + py::arg("key"), py::arg("buffer_ptr"), py::arg("size"), + "Get tensor directly into a pre-allocated buffer") + .def("batch_get_tensor_into", + &MooncakeStorePyWrapper::batch_get_tensor_into, py::arg("keys"), + py::arg("buffer_ptrs"), py::arg("sizes"), + "Get tensors directly into pre-allocated buffers for multiple keys") + .def( + "get_tensor_into_with_tp", + &MooncakeStorePyWrapper::get_tensor_into_with_tp, py::arg("key"), + py::arg("buffer_ptr"), py::arg("size"), py::arg("tp_rank") = 0, + py::arg("tp_size") = 1, py::arg("split_dim") = 0, + "Get a PyTorch tensor from the store directly into a pre-allocated" + "buffer, optionally sliced for Tensor Parallelism.\n" + "Args:\n" + " key: The key of the tensor.\n" + " buffer_ptr: The buffer pointer pre-allocated for tensor.\n" + " size: The size of buffer.\n" + " tp_rank: The current tensor parallel rank (default 0).\n" + " tp_size: The total tensor parallel size (default 1).\n" + " split_dim: The dimension to split the tensor along (default 0).") + .def("batch_get_tensor_into_with_tp", + &MooncakeStorePyWrapper::batch_get_tensor_into_with_tp, + py::arg("base_keys"), py::arg("buffer_ptrs"), py::arg("sizes"), + py::arg("tp_rank") = 0, py::arg("tp_size") = 1, + "Get a batch of PyTorch tensor shards from the store directly into" + "pre-allocated buffers for a given Tensor Parallel rank.") .def( "register_buffer", [](MooncakeStorePyWrapper &self, uintptr_t buffer_ptr, diff --git a/mooncake-store/include/dummy_client.h b/mooncake-store/include/dummy_client.h index 7f3a56986..fde0089ea 100644 --- a/mooncake-store/include/dummy_client.h +++ b/mooncake-store/include/dummy_client.h @@ -228,4 +228,4 @@ class DummyClient : public PyClient { volatile bool connected_ = false; }; -} // namespace mooncake \ No newline at end of file +} // namespace mooncake diff --git a/mooncake-store/src/dummy_client.cpp b/mooncake-store/src/dummy_client.cpp index b7df019a9..41c113a66 100644 --- a/mooncake-store/src/dummy_client.cpp +++ b/mooncake-store/src/dummy_client.cpp @@ -755,4 +755,4 @@ void DummyClient::ping_thread_main() { } } -} // namespace mooncake \ No newline at end of file +} // namespace mooncake diff --git a/scripts/test_tensor_api.py b/scripts/test_tensor_api.py index 61d65337d..1f410c2bc 100644 --- a/scripts/test_tensor_api.py +++ b/scripts/test_tensor_api.py @@ -1,3 +1,4 @@ +import ctypes import os import sys import json @@ -20,11 +21,76 @@ GLOBAL_CONFIG = None DEFAULT_MOONCAKE_CONFIG_PATH_ENV = "MOONCAKE_CONFIG_PATH" -DEFAULT_GLOBAL_SEGMENT_SIZE = 4 * 1024 * 1024 * 1024 # 4 GiB -DEFAULT_LOCAL_BUFFER_SIZE = 2 * 1024 * 1024 * 1024 # 2 GB +DEFAULT_GLOBAL_SEGMENT_SIZE = 16 * 1024 * 1024 * 1024 # 16 GiB +DEFAULT_LOCAL_BUFFER_SIZE = 8 * 1024 * 1024 * 1024 # 8 GB DEFAULT_MASTER_METRICS_PORT = 9003 DEFAULT_CHECK_SERVER = False +DTYPE_MAP = { + 0: np.float32, # FLOAT32 + 1: np.float64, # FLOAT64 + 2: np.int8, # INT8 + 3: np.uint8, # UINT8 + 4: np.int16, # INT16 + 5: 
np.uint16, # UINT16 + 6: np.int32, # INT32 + 7: np.uint32, # UINT32 + 8: np.int64, # INT64 + 9: np.uint64, # UINT64 + 10: np.bool_, # BOOL + 11: np.float16, # FLOAT16 + # Note: BFLOAT16 (12), FLOAT8 (13,14), W8A8 (15) not supported in NumPy +} + +def verify_tensor_equality(original, received, rtol=0, atol=0, verbose=True): + """ + compare two tensors. + """ + def to_numpy(x): + if isinstance(x, torch.Tensor): + if x.is_cuda: + x = x.cpu() + return x.detach().numpy() + elif isinstance(x, np.ndarray): + return x + else: + raise TypeError(f"Unsupported tensor type: {type(x)}") + + try: + orig_np = to_numpy(original) + recv_np = to_numpy(received) + except Exception as e: + if verbose: + print(f"❌ Error converting tensors: {e}") + return False + + if orig_np.shape != recv_np.shape: + if verbose: + print(f"❌ Shape mismatch: original {orig_np.shape} vs received {recv_np.shape}") + return False + + if orig_np.dtype != recv_np.dtype: + if verbose: + print(f"❌ Dtype mismatch: original {orig_np.dtype} vs received {recv_np.dtype}") + return False + + if np.array_equal(orig_np, recv_np): + + return True + else: + diff_mask = orig_np != recv_np + diff_indices = np.where(diff_mask) + if len(diff_indices[0]) > 0: + first_diff_idx = tuple(idx[0] for idx in diff_indices) + orig_val = orig_np[first_diff_idx] + recv_val = recv_np[first_diff_idx] + if verbose: + print(f"❌ Tensors differ at index {first_diff_idx}") + print(f" Original: {orig_val}") + print(f" Received: {recv_val}") + print(f" Difference: {abs(orig_val - recv_val)}") + return False + def parse_global_segment_size(value) -> int: """Parse human-readable size strings (e.g., '4GB') into bytes.""" if isinstance(value, int): return value @@ -91,7 +157,7 @@ def generate_tensors(num_tensors, size_mb): num_elements = size_bytes // element_size dim = int(np.sqrt(num_elements)) dim = (dim // 8) * 8 # Adjust dimension to be divisible by common TP sizes (2, 4, 8) - + # Use random data and ensure the tensor is contiguous in memory tensors = [torch.randn(dim, dim, dtype=torch.float32).contiguous() for _ in range(num_tensors)] # Use timestamp to prevent key collision in rare edge cases (though we remove_all anyway) @@ -130,10 +196,10 @@ def setUp(self): # 1. Access the global connection if GLOBAL_STORE is None: self.skipTest("Store not initialized") - + self.store = GLOBAL_STORE self.config = GLOBAL_CONFIG - + # 2. [Critical] Clean environment before the test starts # This ensures no stale data from previous tests affects the current one self.store.remove_all() @@ -147,12 +213,12 @@ def test_01_basic_put_get(self): """Verify basic put and get functionality.""" key = "func_test_single" tensor = torch.randn(1024, 1024, dtype=torch.float32) - + # Perform Put rc = self.store.put_tensor(key, tensor) self.assertEqual(rc, 0, f"put_tensor failed with rc={rc}") self.assertTrue(self.store.is_exist(key), "Key not found after put") - + # Perform Get retrieved = self.store.get_tensor(key) self.assertIsNotNone(retrieved, "Get returned None") @@ -163,7 +229,7 @@ def test_02_tp_single_tensor(self): tp_size = 4 split_dim = 1 key = "func_test_tp_single" - + # Create a small tensor (e.g., 16MB) _, tensors = generate_tensors(1, 16) target_tensor = tensors[0] @@ -195,7 +261,7 @@ def test_03_tp_batch(self): split_dim = 0 num_tensors = 4 keys, tensors = generate_tensors(num_tensors, 8) # Small size for functional testing - + # 1. 
Batch Put with TP results = self.store.batch_put_tensor_with_tp(keys, tensors, tp_size=tp_size, split_dim=split_dim) self.assertTrue(all(r == 0 for r in results), f"Batch put failed. Results: {results}") @@ -212,13 +278,13 @@ def test_03_tp_batch(self): original = tensors[i] expected_chunks = original.chunk(tp_size, split_dim) reconstruction_parts = [] - + for rank in range(tp_size): shard = all_shards[rank][i] self.assertTrue(torch.equal(shard, expected_chunks[rank]), f"Tensor {i} Rank {rank} data mismatch") reconstruction_parts.append(shard) - + recon = torch.cat(reconstruction_parts, dim=split_dim) self.assertTrue(torch.equal(recon, original), f"Tensor {i} final reconstruction mismatch") @@ -232,6 +298,237 @@ def test_04_tp_consistency(self): self.assertTrue(tmp_tensor_0.sum() == chunked_tensors[0].sum()) self.assertTrue(tmp_tensor_1.sum() == chunked_tensors[1].sum()) + buffer_spacing = 1 * 1024 * 1024 + buffer_2 = (ctypes.c_ubyte * buffer_spacing)() + buffer_3 = (ctypes.c_ubyte * buffer_spacing)() + buffer_ptr_2 = ctypes.addressof(buffer_2) + buffer_ptr_3 = ctypes.addressof(buffer_3) + self.store.register_buffer(buffer_ptr_2, buffer_spacing) + self.store.register_buffer(buffer_ptr_3, buffer_spacing) + tmp_tensor_2 = self.store.batch_get_tensor_into_with_tp(['key'], [buffer_ptr_2], [buffer_spacing], tp_rank=0, tp_size=tp_size)[0] + tmp_tensor_3 = self.store.batch_get_tensor_into_with_tp(['key'], [buffer_ptr_3], [buffer_spacing], tp_rank=1, tp_size=tp_size)[0] + self.assertTrue(tmp_tensor_2.sum() == chunked_tensors[0].sum()) + self.assertTrue(tmp_tensor_3.sum() == chunked_tensors[1].sum()) + self.store.unregister_buffer(buffer_ptr_2) + self.store.unregister_buffer(buffer_ptr_3) + + def test_05_put_get_into(self): + """Verify basic put and get into functionality.""" + key = "get_into_test" + tensor = torch.randn(1024, 1024, dtype=torch.float32) + buffer_spacing = 64 * 1024 * 1024 + total_buffer_size = buffer_spacing + + # Allocate large contiguous buffer + large_buffer = (ctypes.c_ubyte * total_buffer_size)() + large_buffer_ptr = ctypes.addressof(large_buffer) + + res = self.store.register_buffer(large_buffer_ptr, total_buffer_size) + self.assertEqual(res, 0, "Buffer registration should succeed") + + # Perform Put + rc = self.store.put_tensor(key, tensor) + self.assertEqual(rc, 0, f"put_tensor failed with rc={rc}") + self.assertTrue(self.store.is_exist(key), "Key not found after put") + + # Perform Get + retrieved = self.store.get_tensor_into(key, large_buffer_ptr, total_buffer_size) + self.assertIsNotNone(retrieved, "Get returned None") + self.assertTrue(torch.equal(tensor, retrieved), f"Data mismatch between original and retrieved tensor, tensor: {tensor}, retrieved: {retrieved}") + # Unregister buffer + self.assertEqual( + self.store.unregister_buffer(large_buffer_ptr), + 0, + "Buffer unregistration should succeed" + ) + + def test_06_batch_put_get_into(self): + """Zero copy Batch Get.""" + num_tensors = 4 + keys, tensors = generate_tensors(num_tensors, 8) + buffer_spacing = 64 * 1024 * 1024 # 64MB per tensor slot + batch_size = len(keys) + total_buffer_size = buffer_spacing * batch_size + + # Allocate large contiguous buffer + large_buffer = (ctypes.c_ubyte * total_buffer_size)() + large_buffer_ptr = ctypes.addressof(large_buffer) + + # Prepare pointers (only addresses, no ctypes array objects) + buffer_ptrs = [] + buffer_sizes = [] + for i in range(batch_size): + offset = i * buffer_spacing + ptr = large_buffer_ptr + offset + buffer_ptrs.append(ptr) + 
buffer_sizes.append(buffer_spacing) + + # Register the entire buffer with the store + res = self.store.register_buffer(large_buffer_ptr, total_buffer_size) + self.assertEqual(res, 0, "Buffer registration should succeed") + + results = self.store.batch_put_tensor(keys, tensors) + self.assertTrue(all(r == 0 for r in results), f"Batch put failed. Results: {results}") + + res = self.store.batch_get_tensor_into(keys, buffer_ptrs, buffer_sizes) + self.assertEqual(len(res), len(tensors)) + for j in range(batch_size): + self.assertTrue( + verify_tensor_equality(tensors[j], res[j]), + f"Tensor {j} content mismatch, tensor: {tensors[j]}, res: {res[j]}" + ) + # Unregister buffer + self.assertEqual( + self.store.unregister_buffer(large_buffer_ptr), + 0, + "Buffer unregistration should succeed" + ) + + def test_07_put_get_into_with_tp(self): + """Zero copy Batch Get with TP — each rank has its own buffer.""" + tp_size = 4 + split_dim = 0 + key = "get_into_with_tp_test" + tensor = torch.randn(1024, 1024, dtype=torch.float32) + + # Step 1: Put full tensors with TP splitting + result = self.store.put_tensor_with_tp( + key, tensor, tp_size=tp_size, split_dim=split_dim + ) + self.assertEqual(result, 0, f"Put failed. Result: {result}") + + # Step 2: For each TP rank, allocate its own buffer and get shards + all_shards = [] + registered_buffers = [] # Keep track of (ptr, size) for cleanup + + for rank in range(tp_size): + buffer_spacing = 64 * 1024 * 1024 # 64MB per tensor slot + total_buffer_size = buffer_spacing + + # Allocate buffer for this rank + large_buffer = (ctypes.c_ubyte * total_buffer_size)() + large_buffer_ptr = ctypes.addressof(large_buffer) + + # Register buffer for this rank + res = self.store.register_buffer(large_buffer_ptr, total_buffer_size) + self.assertEqual(res, 0, f"Buffer registration failed for rank {rank}") + + # Keep buffer alive and record for cleanup + # We store (large_buffer, large_buffer_ptr, total_buffer_size) + # large_buffer must be kept alive until unregister! + registered_buffers.append((large_buffer, large_buffer_ptr, total_buffer_size)) + + # Get shard for this rank + shard = self.store.get_tensor_into_with_tp( + key, large_buffer_ptr, total_buffer_size, + tp_rank=rank, tp_size=tp_size + ) + all_shards.append(shard) + + # Step 3: Validate reconstruction + original = tensor + expected_chunks = original.chunk(tp_size, split_dim) + reconstruction_parts = [] + + for rank in range(tp_size): + shard = all_shards[rank] + self.assertTrue( + torch.equal(shard, expected_chunks[rank]), + f"Tensor Rank {rank} data mismatch" + ) + reconstruction_parts.append(shard) + + recon = torch.cat(reconstruction_parts, dim=split_dim) + self.assertTrue( + torch.equal(recon, original), + f"Tensor final reconstruction mismatch" + ) + + # Step 4: Unregister all buffers + for large_buffer, ptr, size in registered_buffers: + res = self.store.unregister_buffer(ptr) + self.assertEqual(res, 0, f"Buffer unregistration failed for buffer at {ptr}") + # large_buffer will be GC'd after this; no need to explicitly delete + + def test_08_batch_put_get_into_with_tp(self): + """Zero copy Batch Get with TP — each rank has its own buffer.""" + tp_size = 4 + split_dim = 0 + num_tensors = 4 + keys, tensors = generate_tensors(num_tensors, 8) + + # Step 1: Put full tensors with TP splitting + results = self.store.batch_put_tensor_with_tp( + keys, tensors, tp_size=tp_size, split_dim=split_dim + ) + self.assertTrue(all(r == 0 for r in results), f"Batch put failed. 
Results: {results}") + + # Step 2: For each TP rank, allocate its own buffer and get shards + all_shards = [] + registered_buffers = [] # Keep track of (ptr, size) for cleanup + + for rank in range(tp_size): + batch_size = len(keys) + buffer_spacing = 64 * 1024 * 1024 # 64MB per tensor slot + total_buffer_size = buffer_spacing * batch_size + + # Allocate buffer for this rank + large_buffer = (ctypes.c_ubyte * total_buffer_size)() + large_buffer_ptr = ctypes.addressof(large_buffer) + + # Prepare buffer pointers for this rank + buffer_ptrs = [] + buffer_sizes = [] + for i in range(batch_size): + offset = i * buffer_spacing + ptr = large_buffer_ptr + offset + buffer_ptrs.append(ptr) + buffer_sizes.append(buffer_spacing) + + # Register buffer for this rank + res = self.store.register_buffer(large_buffer_ptr, total_buffer_size) + self.assertEqual(res, 0, f"Buffer registration failed for rank {rank}") + + # Keep buffer alive and record for cleanup + # We store (large_buffer, large_buffer_ptr, total_buffer_size) + # large_buffer must be kept alive until unregister! + registered_buffers.append((large_buffer, large_buffer_ptr, total_buffer_size)) + + # Get shards for this rank + shards = self.store.batch_get_tensor_into_with_tp( + keys, buffer_ptrs, buffer_sizes, + tp_rank=rank, tp_size=tp_size + ) + self.assertEqual(len(shards), num_tensors) + all_shards.append(shards) + + # Step 3: Validate reconstruction + for i in range(num_tensors): + original = tensors[i] + expected_chunks = original.chunk(tp_size, split_dim) + reconstruction_parts = [] + + for rank in range(tp_size): + shard = all_shards[rank][i] + self.assertTrue( + torch.equal(shard, expected_chunks[rank]), + f"Tensor {i} Rank {rank} data mismatch" + ) + reconstruction_parts.append(shard) + + recon = torch.cat(reconstruction_parts, dim=split_dim) + self.assertTrue(torch.equal(recon, original), f"Tensor {i} final reconstruction mismatch") + self.assertTrue( + torch.equal(recon, original), + f"Tensor {i} final reconstruction mismatch" + ) + + # Step 4: Unregister all buffers + for large_buffer, ptr, size in registered_buffers: + res = self.store.unregister_buffer(ptr) + self.assertEqual(res, 0, f"Buffer unregistration failed for buffer at {ptr}") + # large_buffer will be GC'd after this; no need to explicitly delete + # ========================================== # Performance/Benchmark Tests # ========================================== @@ -239,19 +536,19 @@ def test_04_tp_consistency(self): class TestMooncakeBenchmark(MooncakeTestBase): # Benchmark Settings BENCH_ITERATIONS = 5 - TENSOR_SIZE_MB = 64 - TOTAL_SIZE_GB = 1 + TENSOR_SIZE_MB = 16 + TOTAL_SIZE_MB = 256 def setUp(self): """Benchmark-specific setUp.""" # 1. Call parent setUp to clean the store (remove_all) super().setUp() - + # 2. 
Generate test data - total_bytes = self.TOTAL_SIZE_GB * 1024**3 + total_bytes = int(self.TOTAL_SIZE_MB * 1024**2) tensor_bytes = self.TENSOR_SIZE_MB * 1024**2 self.num_tensors = max(1, total_bytes // tensor_bytes) - + print(f"\n[Gen] Generating {self.num_tensors} tensors (~{self.TENSOR_SIZE_MB}MB each)...") self.keys, self.tensors = generate_tensors(self.num_tensors, self.TENSOR_SIZE_MB) self.total_bits = (tensor_bytes * self.num_tensors) * 8 @@ -270,7 +567,7 @@ def test_benchmark_01_batch_put_get(self): for i in range(self.BENCH_ITERATIONS): # Clean store before each iteration for "cold" writes self.store.remove_all() - + # Measure Put t0 = time.perf_counter() self.store.batch_put_tensor(self.keys, self.tensors) @@ -311,6 +608,147 @@ def test_benchmark_02_tp_batch(self): self._print_perf(f"TP Batch Put (TP={tp_size})", put_times) self._print_perf(f"TP Batch Get (TP={tp_size})", get_times) + def test_benchmark_03_batch_put_get_into(self): + """Benchmark: Zero copy Batch Get.""" + self.store.remove_all() + buffer_spacing = 300 * 1024 * 1024 # 300MB per tensor slot + batch_size = len(self.keys) + total_buffer_size = buffer_spacing * batch_size + + # Allocate large contiguous buffer + large_buffer = (ctypes.c_ubyte * total_buffer_size)() + large_buffer_ptr = ctypes.addressof(large_buffer) + + # Prepare pointers (only addresses, no ctypes array objects) + buffer_ptrs = [] + buffer_sizes = [] + for i in range(batch_size): + offset = i * buffer_spacing + ptr = large_buffer_ptr + offset + buffer_ptrs.append(ptr) + buffer_sizes.append(buffer_spacing) + + # Register the entire buffer with the store + res = self.store.register_buffer(large_buffer_ptr, total_buffer_size) + self.assertEqual(res, 0, "Buffer registration should succeed") + + print(f"--- Running zero copy Batch Benchmark ({self.BENCH_ITERATIONS} iters) ---") + put_times = [] + get_times = [] + + for i in range(self.BENCH_ITERATIONS): + # Clean store before each iteration for "cold" writes + self.store.remove_all() + + # Measure Put + t0 = time.perf_counter() + self.store.batch_put_tensor(self.keys, self.tensors) + put_times.append(time.perf_counter() - t0) + + # Measure Get + t0 = time.perf_counter() + res = self.store.batch_get_tensor_into(self.keys, buffer_ptrs, buffer_sizes) + get_times.append(time.perf_counter() - t0) + self.assertEqual(len(res), len(self.tensors)) + for j in range(batch_size): + self.assertTrue( + verify_tensor_equality(self.tensors[j], res[j]), + f"Tensor {j} content mismatch" + ) + + self._print_perf("Standard Batch Put", put_times) + self._print_perf("Zero copy Batch Get", get_times) + + # Unregister buffer + self.assertEqual( + self.store.unregister_buffer(large_buffer_ptr), + 0, + "Buffer unregistration should succeed" + ) + + def test_benchmark_04_batch_put_get_into_with_tp(self): + """Benchmark: Zero copy Batch Get with tp.""" + tp_size = 4 + split_dim = 0 + batch_size = len(self.keys) + self.store.remove_all() + buffer_spacing = 64 * 1024 * 1024 # 64MB per tensor slot + + # Allocate and register a separate buffer for each TP rank + rank_buffers = [] # Store metadata for cleanup + + for rank in range(tp_size): + total_buffer_size = buffer_spacing * batch_size + large_buffer = (ctypes.c_ubyte * total_buffer_size)() + large_buffer_ptr = ctypes.addressof(large_buffer) + + buffer_ptrs = [] + buffer_sizes = [] + for i in range(batch_size): + offset = i * buffer_spacing + ptr = large_buffer_ptr + offset + buffer_ptrs.append(ptr) + buffer_sizes.append(buffer_spacing) + + # Register buffer for this rank + res = 
self.store.register_buffer(large_buffer_ptr, total_buffer_size) + self.assertEqual(res, 0, f"Buffer registration failed for rank {rank}") + + rank_buffers.append({ + 'buffer_obj': large_buffer, # keep alive + 'ptrs': buffer_ptrs, + 'sizes': buffer_sizes, + 'base_ptr': large_buffer_ptr, + 'total_size': total_buffer_size + }) + + print(f"--- Running zero copy Batch Benchmark (TP={tp_size}, {self.BENCH_ITERATIONS} iters) ---") + put_times = [] + get_times = [] + + for i in range(self.BENCH_ITERATIONS): + # Clean store before each iteration for "cold" writes + self.store.remove_all() + + # Measure Put + t0 = time.perf_counter() + self.store.batch_put_tensor_with_tp( + self.keys, self.tensors, tp_size=tp_size, split_dim=split_dim + ) + put_times.append(time.perf_counter() - t0) + + # Measure Get: each rank uses its own buffer + t0 = time.perf_counter() + all_res = [] + for rank in range(tp_size): + res = self.store.batch_get_tensor_into_with_tp( + self.keys, + rank_buffers[rank]['ptrs'], + rank_buffers[rank]['sizes'], + tp_rank=rank, + tp_size=tp_size + ) + self.assertEqual(len(res), batch_size) + all_res.append(res) + get_times.append(time.perf_counter() - t0) + + # Verify correctness using rank 0's result + for j in range(batch_size): + original = self.tensors[j] + expected_shard = original.chunk(tp_size, split_dim)[0] # rank 0 shard + actual = all_res[0][j] + self.assertTrue( + torch.equal(actual, expected_shard), + f"Tensor {j} content mismatch on rank 0" + ) + + self._print_perf(f"Standard Batch Put with tp (TP={tp_size})", put_times) + self._print_perf(f"Zero copy Batch Get with tp (TP={tp_size})", get_times) + + # Unregister all buffers + for buf_info in rank_buffers: + res = self.store.unregister_buffer(buf_info['base_ptr']) + self.assertEqual(res, 0, f"Buffer unregistration failed for buffer at {buf_info['base_ptr']}") # ========================================== # Stress/Concurrency Tests @@ -488,8 +926,39 @@ def _test_dtype_roundtrip(self, dtype, name, expected_enum_name=None): print(f" [Fail] {name:<15} Data content mismatch") self.fail(f"Data content mismatch for {name}") + buffer_spacing = 1 * 1024 * 1024 + buffer = (ctypes.c_ubyte * buffer_spacing)() + buffer_ptr = ctypes.addressof(buffer) + self.store.register_buffer(buffer_ptr, buffer_spacing) + retrieved = self.store.get_tensor_into(key, buffer_ptr, buffer_spacing) + if retrieved is None: + print(f" [Fail] {name:<15} Get returned None") + self.fail(f"Get returned None for {name}") + + # We expect the retrieved tensor to have the same dtype as input + if original.dtype != retrieved.dtype: + msg = f"Dtype mismatch for {name}! 
Input: {original.dtype}, Output: {retrieved.dtype}" + print(f" [Fail] {name:<15} {msg}") + self.fail(msg) + + # Use byte-view comparison for robustness (especially for FP8/BF16 on CPU) + try: + # Cast to untyped storage byte view (or uint8 view) + t1_bytes = original.view(torch.uint8) if original.element_size() > 0 else original + t2_bytes = retrieved.view(torch.uint8) if retrieved.element_size() > 0 else retrieved + is_equal = torch.equal(t1_bytes, t2_bytes) + except Exception: + # Fallback for types that might fail view() or equal() + is_equal = torch.equal(original.cpu(), retrieved.cpu()) + + if not is_equal: + print(f" [Fail] {name:<15} Data content mismatch") + self.fail(f"Data content mismatch for {name}") + self.store.unregister_buffer(buffer_ptr) + print(f" [Pass] {name:<15} {str(dtype)}") + def test_all_dtypes(self): print("\n--- Testing All Supported PyTorch Data Types ---") @@ -566,5 +1035,5 @@ def test_fp8_types(self): runner = unittest.TextTestRunner(verbosity=2) result = runner.run(suite) - - sys.exit(not result.wasSuccessful()) \ No newline at end of file + + sys.exit(not result.wasSuccessful())
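For reference, a minimal usage sketch of the zero-copy getters introduced by this patch (illustrative only, not part of the diff). It assumes an already-initialized, connected `store` handle and a tensor previously stored under `"my_key"` via `put_tensor()`; the key name and the 64 MB buffer size are arbitrary examples.

```python
import ctypes

def read_back(store, key="my_key", buf_size=64 * 1024 * 1024):
    # Pre-allocate a host buffer and register it; the *_into getters
    # require the destination memory to be registered with the store.
    buf = (ctypes.c_ubyte * buf_size)()
    ptr = ctypes.addressof(buf)
    assert store.register_buffer(ptr, buf_size) == 0

    # Fetch directly into the buffer; returns None if the key is missing.
    t = store.get_tensor_into(key, ptr, buf_size)
    # TP variant (rank-0 shard of a tensor stored with put_tensor_with_tp):
    # t = store.get_tensor_into_with_tp(key, ptr, buf_size, tp_rank=0, tp_size=4)

    if t is not None:
        # The result views the registered buffer rather than owning a copy,
        # so clone it before the buffer is unregistered or freed.
        t = t.clone()

    assert store.unregister_buffer(ptr) == 0
    return t
```

The batch forms work the same way, taking parallel lists of keys, buffer pointers, and sizes; as in the tests above, the pointers can be offsets into one large registered region.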