From 57776d107611d1a0e74a10302b22d7fb75a06f08 Mon Sep 17 00:00:00 2001 From: datenglin Date: Mon, 15 Jun 2026 10:11:55 -0700 Subject: [PATCH] Implemented resharding for the push model with the Raiden controller. PiperOrigin-RevId: 932528741 --- core/raiden_manager_base.cc | 12 + core/raiden_manager_base.h | 4 + rpc/raiden_controller.py | 272 +++++++++++++++++- rpc/raiden_controller_test.py | 26 ++ tpu_raiden/frameworks/jax/BUILD | 5 +- .../frameworks/jax/resharding_planner.py | 54 ++++ .../frameworks/jax/resharding_planner_test.py | 24 ++ transport/block_transport.cc | 69 +++++ transport/block_transport.h | 29 +- weight_sync/weight_synchronizer_base.cc | 38 +++ weight_sync/weight_synchronizer_base.h | 19 ++ .../weight_synchronizer_control_service.cc | 26 +- ...eight_synchronizer_control_service_test.cc | 101 +++++++ weight_sync/weight_synchronizer_service.proto | 50 ++++ 14 files changed, 708 insertions(+), 21 deletions(-) diff --git a/core/raiden_manager_base.cc b/core/raiden_manager_base.cc index 6d25452..58ffb36 100644 --- a/core/raiden_manager_base.cc +++ b/core/raiden_manager_base.cc @@ -160,6 +160,18 @@ absl::Status RaidenManagerBase::PullWeightsChunk( dst_shard_idx, dst_offset_bytes, size_bytes); } +absl::Status RaidenManagerBase::PushWeightsChunk(const std::string& peer, + size_t dst_shard_idx, + size_t dst_offset_bytes, + const uint8_t* data_ptr, + size_t size_bytes) { + if (!server_) { + return absl::FailedPreconditionError("Transport server is not running"); + } + return server_->PushWeightsChunk(peer, dst_shard_idx, dst_offset_bytes, + data_ptr, size_bytes); +} + size_t RaidenManagerBase::bytes_per_block() const { return block_size_ * slice_byte_size_; } diff --git a/core/raiden_manager_base.h b/core/raiden_manager_base.h index dfe9b51..e4e9a6d 100644 --- a/core/raiden_manager_base.h +++ b/core/raiden_manager_base.h @@ -80,6 +80,10 @@ class RaidenManagerBase : public tpu_raiden::transport::BlockTransportDelegate { size_t src_offset_bytes, size_t 
dst_shard_idx, size_t dst_offset_bytes, size_t size_bytes); + absl::Status PushWeightsChunk(const std::string& peer, size_t dst_shard_idx, + size_t dst_offset_bytes, + const uint8_t* data_ptr, size_t size_bytes); + std::optional local_port() const; uint8_t* GetHostPointer(size_t layer_idx, size_t shard_idx) override; diff --git a/rpc/raiden_controller.py b/rpc/raiden_controller.py index 98355a1..e9b3069 100644 --- a/rpc/raiden_controller.py +++ b/rpc/raiden_controller.py @@ -29,10 +29,11 @@ import asyncio import dataclasses +import math import socket import threading import time -from typing import Optional, Protocol, runtime_checkable +from typing import Optional, Protocol, Union, runtime_checkable from weight_sync import weight_synchronizer_service_pb2 @@ -95,14 +96,18 @@ class TransferPlan: # index, and the n-dimensional slice offsets for the shard index. plan: dict[RaidenId, list[list[tuple[RaidenId, int, list[NDSlice]]]]] - # NEW: Maps every RaidenId in the plan to its physical Control-Plane RPC - # address! + shard_push_schedules: dict[ + RaidenId, dict[int, list[tuple[str, int, int, int, int]]] + ] = dataclasses.field(default_factory=dict) + + # Maps every RaidenId in the plan to its physical Control-Plane RPC + # address worker_rpc_addresses: dict[RaidenId, str] = dataclasses.field( default_factory=dict ) - # NEW: Maps every RaidenId in the plan to its physical Data TCP socket - # endpoints (e.g. 
['10.0.0.2:8500']) + # Maps every RaidenId in the plan to its physical Data TCP socket + # endpoints worker_data_addresses: dict[RaidenId, list[str]] = dataclasses.field( default_factory=dict ) @@ -361,6 +366,35 @@ def _encode_start_transfer( command=weight_synchronizer_service_pb2.ControlRequest.COMMAND_START_TRANSFER, peers=peers, ) + + start_req = weight_synchronizer_service_pb2.StartTransferRequest( + src_units=[_raiden_id_to_proto(u) for u in transfer_plan.src_units], + dst_units=[_raiden_id_to_proto(u) for u in transfer_plan.dst_units], + ) + + if transfer_plan.shard_push_schedules: + push_schedules = transfer_plan.shard_push_schedules.get(target_id) + if push_schedules: + for shard_idx, entries in push_schedules.items(): + schedule_proto = ( + weight_synchronizer_service_pb2.ShardPushScheduleProto() + ) + for ( + dst_peer, + dst_shard_idx, + dst_offset, + src_offset, + size, + ) in entries: + entry_proto = schedule_proto.entries.add() + entry_proto.dst_peer = dst_peer + entry_proto.dst_shard_idx = dst_shard_idx + entry_proto.dst_offset_bytes = dst_offset + entry_proto.src_offset_bytes = src_offset + entry_proto.size_bytes = size + start_req.shard_push_schedules[shard_idx].CopyFrom(schedule_proto) + + req.start_transfer_request.CopyFrom(start_req) return req.SerializeToString() def _verify_response(self, resp_bytes: bytes) -> None: @@ -409,6 +443,116 @@ def done(self) -> bool: return self._completed +def intersect_nd_slices( + slice1: list[tuple[int, int]], slice2: list[tuple[int, int]] +) -> Optional[list[tuple[int, int]]]: + """Computes the precise N-dimensional intersection bounding box between two multi-dimensional slices. + + Each slice is represented as a list of coordinate bounds (start, end) for + each dimension. + + Args: + slice1: First N-dimensional slice bounding box. + slice2: Second N-dimensional slice bounding box. 
+ + Returns: + A list of (start, end) coordinate bounds representing the intersecting + subgrid, or None if the slices do not overlap in any dimension. + """ + result = [] + for (s1, e1), (s2, e2) in zip(slice1, slice2): + start = max(s1, s2) + end = min(e1, e2) + if start >= end: + return None + result.append((start, end)) + return result + + +def generate_1d_copy_chunks( + src_shard_slice: list[tuple[int, int]], + dst_shard_slice: list[tuple[int, int]], + intersection_slice: list[tuple[int, int]], + itemsize: int, +) -> list[tuple[int, int, int]]: + """Translates an N-dimensional grid intersection into non-contiguous 1D memory copy byte chunks. + + When linearizing multi-dimensional arrays, an intersecting subgrid is often + non-contiguous in memory. This function computes the exact linear source and + destination byte offsets and chunk sizes needed to transmit the non-adjacent + strided minor rows over a flat 1D data stream. + + Args: + src_shard_slice: Global multi-dimensional bounding box of the source shard. + dst_shard_slice: Global multi-dimensional bounding box of the destination + shard. + intersection_slice: Global multi-dimensional bounding box of the overlapping + subgrid. + itemsize: Byte size of a single array scalar element (e.g., 4 for float32). + + Returns: + A list of (src_offset_bytes, dst_offset_bytes, size_bytes) tuples defining + the concrete 1D linear memory copy chunk operations. 
+ """ + rank = len(src_shard_slice) + src_shape = [e - s for s, e in src_shard_slice] + dst_shape = [e - s for s, e in dst_shard_slice] + int_shape = [e - s for s, e in intersection_slice] + + src_local_int_slice = [ + (int_s - src_s, int_e - src_s) + for (src_s, _), (int_s, int_e) in zip(src_shard_slice, intersection_slice) + ] + dst_local_int_slice = [ + (int_s - dst_s, int_e - dst_s) + for (dst_s, _), (int_s, int_e) in zip(dst_shard_slice, intersection_slice) + ] + + src_strides = [1] * rank + for i in range(rank - 2, -1, -1): + src_strides[i] = src_strides[i + 1] * src_shape[i + 1] + + dst_strides = [1] * rank + for i in range(rank - 2, -1, -1): + dst_strides[i] = dst_strides[i + 1] * dst_shape[i + 1] + + minor_dim_size = int_shape[-1] + contiguous_bytes = minor_dim_size * itemsize + + chunks = [] + outer_shape = int_shape[:-1] + num_outer_elements = math.prod(outer_shape) if outer_shape else 1 + + for i in range(num_outer_elements): + multi_index = [] + temp = i + for dim_size in reversed(outer_shape): + multi_index.append(temp % dim_size) + temp //= dim_size + multi_index.reverse() + + src_offset_items = 0 + dst_offset_items = 0 + + for d in range(rank - 1): + src_idx = src_local_int_slice[d][0] + multi_index[d] + src_offset_items += src_idx * src_strides[d] + + dst_idx = dst_local_int_slice[d][0] + multi_index[d] + dst_offset_items += dst_idx * dst_strides[d] + + src_offset_items += src_local_int_slice[-1][0] * src_strides[-1] + dst_offset_items += dst_local_int_slice[-1][0] * dst_strides[-1] + + chunks.append(( + src_offset_items * itemsize, + dst_offset_items * itemsize, + contiguous_bytes, + )) + + return chunks + + class RaidenController: """High-level transfer controller managing active transfers and generating transfer plans.""" @@ -418,6 +562,10 @@ def __init__( self.port = port self._active_transfers: dict[str, TransferPlan] = {} self._registered_shards: dict[RaidenId, list[str]] = {} + self._registered_shard_slices: dict[ + RaidenId, 
list[list[tuple[int, int]]] + ] = {} + self._registered_itemsizes: dict[RaidenId, int] = {} self._lock = threading.Lock() self.worker_rpc_client = worker_rpc_client or WorkerRpcClient() @@ -426,10 +574,25 @@ def register_work_unit( unit: RaidenId, shards: list[str], control_plane_rpc_address: Optional[str] = None, + shard_nd_slices: Optional[ + Union[ + list[list[tuple[int, int]]], + list[weight_synchronizer_service_pb2.NDSliceProto], + ] + ] = None, + itemsize: Optional[int] = None, ) -> None: """Registers physical worker shard Data addresses and optional Control-Plane RPC endpoint.""" + if shard_nd_slices is not None and itemsize is None: + raise ValueError( + "itemsize must not be None if shard_nd_slices is provided." + ) with self._lock: self._registered_shards[unit] = shards + if shard_nd_slices: + self._registered_shard_slices[unit] = shard_nd_slices + if itemsize: + self._registered_itemsizes[unit] = itemsize if control_plane_rpc_address and hasattr( self.worker_rpc_client, "register_worker_endpoint" ): @@ -528,6 +691,50 @@ def start_transfer( plan[selected_src] = src_plan + src_slices = self._registered_shard_slices.get(selected_src) + dst_slices = {} + for d in dst_units: + d_slices = self._registered_shard_slices.get(d) + if d_slices: + dst_slices[d] = d_slices + + itemsize = self._registered_itemsizes.get(selected_src) + if src_slices is not None and itemsize is None: + raise ValueError( + "itemsize must be registered if shard_nd_slices is provided." 
+ ) + + shard_push_schedules: dict[ + RaidenId, dict[int, list[tuple[str, int, int, int, int]]] + ] = {} + + if src_slices and len(dst_slices) == len(dst_units): + push_schedules: dict[int, list[tuple[str, int, int, int, int]]] = {} + for src_shard_idx, src_slice in enumerate(src_slices): + shard_entries = [] + for dst_unit in dst_units: + d_slices = dst_slices[dst_unit] + dst_shards = self._resolve_shards(dst_unit) + for dst_shard_idx, dst_slice in enumerate(d_slices): + intersection = intersect_nd_slices(src_slice, dst_slice) + if intersection: + dst_peer = dst_shards[dst_shard_idx % len(dst_shards)] + chunks = generate_1d_copy_chunks( + src_slice, dst_slice, intersection, itemsize + ) + for src_offset, dst_offset, size in chunks: + shard_entries.append(( + dst_peer, + dst_shard_idx, + dst_offset, + src_offset, + size, + )) + if shard_entries: + push_schedules[src_shard_idx] = shard_entries + + shard_push_schedules[selected_src] = push_schedules + rpc_addresses = {} if hasattr(self.worker_rpc_client, "get_worker_endpoints"): rpc_addresses = self.worker_rpc_client.get_worker_endpoints() @@ -538,6 +745,7 @@ def start_transfer( src_units=[selected_src], dst_units=dst_units, plan=plan, + shard_push_schedules=shard_push_schedules, worker_rpc_addresses=rpc_addresses, worker_data_addresses=data_addresses, ) @@ -660,7 +868,18 @@ def _handle_conn( if reg.control_plane_rpc_address else None ) - self._controller.register_work_unit(unit, shards, ctrl_addr) + shard_slices = [] + for nd_slice_proto in reg.shard_nd_slices: + dims = [] + for dim_proto in nd_slice_proto.dimensions: + dims.append((dim_proto.start, dim_proto.end)) + shard_slices.append(dims) + + itemsize = reg.itemsize if reg.itemsize > 0 else None + + self._controller.register_work_unit( + unit, shards, ctrl_addr, shard_slices, itemsize + ) resp.success = True elif ( req.command @@ -764,6 +983,13 @@ def register_work_unit( unit: RaidenId, shards: list[str], control_plane_rpc_address: Optional[str] = None, + 
shard_nd_slices: Optional[ + Union[ + list[list[tuple[int, int]]], + list[weight_synchronizer_service_pb2.NDSliceProto], + ] + ] = None, + itemsize: Optional[int] = None, ) -> None: """Sends remote RPC to register a physical worker entity with the central RaidenControllerServer. @@ -772,16 +998,36 @@ def register_work_unit( shards: list of physical Data TCP addresses (e.g. 'IP:Port'). control_plane_rpc_address: Optional worker Control-Plane RPC servicer endpoint coordinate. + shard_nd_slices: Optional bounding boxes for each shard. + itemsize: Optional item size in bytes. """ + if shard_nd_slices is not None and itemsize is None: + raise ValueError( + "itemsize must not be None if shard_nd_slices is provided." + ) + reg_req = weight_synchronizer_service_pb2.RegisterWorkUnitRequest( + unit=_raiden_id_to_proto(unit), + shards=shards, + control_plane_rpc_address=( + control_plane_rpc_address if control_plane_rpc_address else "" + ), + ) + if shard_nd_slices: + for nd_slice in shard_nd_slices: + if isinstance(nd_slice, weight_synchronizer_service_pb2.NDSliceProto): + reg_req.shard_nd_slices.add().CopyFrom(nd_slice) + else: + slice_proto = reg_req.shard_nd_slices.add() + for s, e in nd_slice: + dim_proto = slice_proto.dimensions.add() + dim_proto.start = s + dim_proto.end = e + if itemsize: + reg_req.itemsize = itemsize + req = weight_synchronizer_service_pb2.ControlRequest( command=weight_synchronizer_service_pb2.ControlRequest.COMMAND_REGISTER_WORK_UNIT, - register_work_unit_request=weight_synchronizer_service_pb2.RegisterWorkUnitRequest( - unit=_raiden_id_to_proto(unit), - shards=shards, - control_plane_rpc_address=( - control_plane_rpc_address if control_plane_rpc_address else "" - ), - ), + register_work_unit_request=reg_req, ) self._send_protobuf_rpc(req) diff --git a/rpc/raiden_controller_test.py b/rpc/raiden_controller_test.py index 9417229..fa57265 100644 --- a/rpc/raiden_controller_test.py +++ b/rpc/raiden_controller_test.py @@ -167,6 +167,32 @@ async def 
def test_enforce_itemsize_when_shard_nd_slices_provided(self):
  """Checks that slices without an itemsize are rejected at both entry points."""
  controller = raiden_controller.RaidenController(port=10003)
  unit = raiden_controller.RaidenId(
      job_name="trainer", job_replica_id="0", data_name="weights"
  )
  # One shard with one dimension, shaped as list[list[tuple[int, int]]] to
  # match the register_work_unit annotation.
  one_shard_slices = [[(0, 2)]]

  # 1. Verification during register_work_unit
  with self.assertRaisesWithPredicateMatch(
      ValueError, lambda e: "itemsize must not be None" in str(e)
  ):
    controller.register_work_unit(
        unit, ["10.0.0.1:8000"], shard_nd_slices=one_shard_slices
    )

  # 2. Verification during start_transfer if bypassed
  controller._registered_shard_slices[unit] = one_shard_slices
  controller._registered_itemsizes.pop(unit, None)
  dst = raiden_controller.RaidenId(
      job_name="sampler", job_replica_id="0", data_name="weights"
  )

  with self.assertRaisesWithPredicateMatch(
      ValueError, lambda e: "itemsize must be registered" in str(e)
  ):
    controller.start_transfer(src_units=[unit], dst_units=[dst])
def compute_nd_shard_slices(
    global_shape: Tuple[int, ...],
    mesh_shape: Tuple[int, ...],
) -> List[weight_synchronizer_service_pb2.NDSliceProto]:
  """Builds the bounding box owned by every shard of an evenly tiled tensor.

  Along each dimension the tensor is cut into `mesh_shape[d]` tiles of
  `global_shape[d] // mesh_shape[d]` elements; the last tile is extended to
  `global_shape[d]` so it absorbs any remainder. Shards are emitted in
  row-major order of their mesh coordinates.

  Args:
    global_shape: The global multi-dimensional shape of the tensor.
    mesh_shape: Number of shards along each tensor dimension.

  Returns:
    One NDSliceProto per shard, each holding a (start, end) interval for every
    dimension.

  Raises:
    ValueError: If the ranks differ or a mesh dimension is not positive.
  """
  if len(global_shape) != len(mesh_shape):
    raise ValueError(
        f"Tensor rank ({len(global_shape)}) and sharding mesh rank"
        f" ({len(mesh_shape)}) must match exactly."
    )

  # For every axis, list the (start, end) interval of each tile on that axis.
  per_axis_intervals = []
  for d, (extent, num_tiles) in enumerate(zip(global_shape, mesh_shape)):
    if num_tiles <= 0:
      raise ValueError(f"Mesh shape at dimension {d} must be positive.")
    tile = extent // num_tiles
    intervals = [(c * tile, (c + 1) * tile) for c in range(num_tiles - 1)]
    # The final tile runs to the end of the axis, picking up the remainder.
    intervals.append(((num_tiles - 1) * tile, extent))
    per_axis_intervals.append(intervals)

  # The cartesian product walks the mesh coordinates in row-major order.
  shard_slices = []
  for box in itertools.product(*per_axis_intervals):
    slice_proto = weight_synchronizer_service_pb2.NDSliceProto()
    for start, end in box:
      slice_proto.dimensions.add(start=start, end=end)
    shard_slices.append(slice_proto)

  return shard_slices
def test_compute_nd_shard_slices(self):
  """Spot-checks shard bounding boxes for a 2D and a 3D sharding grid."""

  def expect_bounds(nd_slice, expected):
    # Compare each listed dimension's (start, end) with the expected pair.
    dims = nd_slice.dimensions
    for axis, want in enumerate(expected):
      self.assertEqual((dims[axis].start, dims[axis].end), want)

  # Test 2D grid: 128x1024 across 2x4 mesh
  grid_2d = resharding_planner.compute_nd_shard_slices((128, 1024), (2, 4))
  self.assertLen(grid_2d, 8)
  # Shard 0 (0, 0)
  expect_bounds(grid_2d[0], [(0, 64), (0, 256)])
  # Shard 7 (1, 3)
  expect_bounds(grid_2d[7], [(64, 128), (768, 1024)])

  # Test 3D grid: 16x32x64 across 2x1x4 mesh
  grid_3d = resharding_planner.compute_nd_shard_slices(
      (16, 32, 64), (2, 1, 4)
  )
  self.assertLen(grid_3d, 8)
  # Shard 5 (1, 0, 1) -> index 1*4 + 0*4 + 1 = 5
  expect_bounds(grid_3d[5], [(8, 16), (0, 32), (16, 32)])
absl::InvalidArgumentError( + absl::StrCat("Invalid destination shard index: ", dst_shard_idx, + ", total shards: ", delegate_->num_shards())); + } + uint8_t* base_host_ptr = delegate_->GetHostPointer(0, dst_shard_idx); + size_t host_size = delegate_->GetHostSize(0, dst_shard_idx); + if (dst_offset + size_bytes > host_size) { + return absl::InvalidArgumentError(absl::StrCat( + "Destination out of bounds. Offset: ", dst_offset, + ", Size: ", size_bytes, ", Shard Host Size: ", host_size)); + } + uint8_t* dest_ptr = base_host_ptr + dst_offset; + RETURN_IF_ERROR(ReadExact(client_fd, dest_ptr, size_bytes)); + + uint8_t ack = 1; + RETURN_IF_ERROR(WriteExact(client_fd, &ack, 1)); } return absl::OkStatus(); } @@ -979,5 +1005,48 @@ absl::Status BlockTransport::PullWeightsChunk( return absl::OkStatus(); } +absl::Status BlockTransport::PushWeightsChunk(const std::string& peer, + size_t dst_shard_idx, + size_t dst_offset_bytes, + const uint8_t* data_ptr, + size_t size_bytes) { + if (peer.empty()) { + return absl::InvalidArgumentError( + "Destination peer address cannot be empty"); + } + + TF_ASSIGN_OR_RETURN(const int fd, AcquireConnection(peer)); + bool ok_to_pool = false; + auto fd_cleaner = absl::MakeCleanup([&] { + if (ok_to_pool) { + ReleaseConnection(peer, fd); + } else { + shutdown(fd, SHUT_RDWR); + close(fd); + } + }); + + BlockPacketHeader header = {}; + // Operation code 5 signals low-overhead streaming of arbitrary, + // non-contiguous strided byte slices directly into a remote TPU shard Host + // buffer offset. 
+ header.op = 5; + header.remote_block_id = static_cast(dst_offset_bytes); + header.local_block_id = static_cast(dst_shard_idx); + header.num_blocks = static_cast(size_bytes); + + RETURN_IF_ERROR(WriteExact(fd, &header, sizeof(header))); + RETURN_IF_ERROR(WriteExact(fd, data_ptr, size_bytes)); + + uint8_t ack = 0; + RETURN_IF_ERROR(ReadExact(fd, &ack, 1)); + if (ack != 1) { + return absl::InternalError("PushWeightsChunk verification failed"); + } + + ok_to_pool = true; + return absl::OkStatus(); +} + } // namespace transport } // namespace tpu_raiden \ No newline at end of file diff --git a/transport/block_transport.h b/transport/block_transport.h index d643ceb..28769cd 100644 --- a/transport/block_transport.h +++ b/transport/block_transport.h @@ -88,7 +88,9 @@ class BlockTransport { public: // Binary packet header layout for H2H transfers. struct alignas(8) BlockPacketHeader { - uint8_t op; // 1 = Push, 2 = Pull, 3 = Byte-Range Pull + // 1 = Push, 2 = Pull, 3 = Byte-Range Pull, 4 = Single Block Push, 5 = + // Resharding Slice Push + uint8_t op; uint8_t major_order; // See MajorOrder. Ignored by legacy ops. uint16_t reserved = 0; uint32_t remote_block_id; @@ -125,6 +127,31 @@ class BlockTransport { absl::Status WriteBlockDirect(const std::string& peer, int remote_block_id, const uint8_t* data_ptr, size_t size_bytes); + /** + * Directly pushes an arbitrary byte slice from local Host memory into a + * specific memory offset of a remote peer's destination shard buffer. + * + * Utilizes a highly optimized request packet framing to stream raw + * bytes and blocks synchronously until an explicit TCP ACK response is + * returned, ensuring perfect end-to-end memory consistency. + * + * @param peer Direct network coordinate "ip:port" of the remote destination + * worker. + * @param dst_shard_idx Target logical shard index residing on the remote + * peer. + * @param dst_offset_bytes Exact linear byte offset into the target shard + * buffer. 
+ * @param data_ptr Pointer to the beginning of the local staging source byte + * array. + * @param size_bytes Number of continuous bytes to push across the network + * stream. + * @return absl::OkStatus() upon verified written completion confirmed by the + * destination. + */ + absl::Status PushWeightsChunk(const std::string& peer, size_t dst_shard_idx, + size_t dst_offset_bytes, + const uint8_t* data_ptr, size_t size_bytes); + int local_port() const { return local_port_; } private: diff --git a/weight_sync/weight_synchronizer_base.cc b/weight_sync/weight_synchronizer_base.cc index 328a195..d4ec5c6 100644 --- a/weight_sync/weight_synchronizer_base.cc +++ b/weight_sync/weight_synchronizer_base.cc @@ -41,6 +41,7 @@ #include "core/raiden_manager_base.h" #include "core/raw_transfer_core.h" #include "weight_sync/weight_synchronizer_control_service.h" +#include "weight_sync/weight_synchronizer_service.pb.h" ABSL_FLAG(size_t, raiden_weight_sync_host_buffer_scratchpad_size, 256 * 1024, "Amount of scratchpad to allocate to host buffers for resharding " @@ -518,6 +519,43 @@ absl::Status WeightSynchronizerBase::PushWeights( return absl::OkStatus(); } +absl::Status WeightSynchronizerBase::PushWeightsResharded( + const tpu_raiden::weight_sync::StartTransferRequest& request) { + const auto& schedules = request.shard_push_schedules(); + TF_ASSIGN_OR_RETURN(raiden::PjRtCopyFuture d2h_future, D2h()); + TF_RETURN_IF_ERROR(d2h_future.Await().status()); + + for (size_t i = 0; i < num_shards_; ++i) { + auto it = schedules.find(static_cast(i)); + if (it == schedules.end()) { + continue; + } + const auto& schedule = it->second; + const uint8_t* base_host_ptr = GetHostPointer(0, i); + if (base_host_ptr == nullptr) { + return absl::InternalError("Host pointer is null during resharded push"); + } + size_t shard_host_size = GetHostSize(0, i); + + for (const auto& entry : schedule.entries()) { + const std::string& dst_peer = entry.dst_peer(); + size_t dst_shard_idx = entry.dst_shard_idx(); + 
size_t dst_offset = entry.dst_offset_bytes(); + size_t src_offset = entry.src_offset_bytes(); + size_t size = entry.size_bytes(); + + if (src_offset + size > shard_host_size) { + return absl::InvalidArgumentError("Push range out of bounds"); + } + + const uint8_t* data_ptr = base_host_ptr + src_offset; + TF_RETURN_IF_ERROR(PushWeightsChunk(dst_peer, dst_shard_idx, dst_offset, + data_ptr, size)); + } + } + return absl::OkStatus(); +} + absl::Status WeightSynchronizerBase::PullWeights(const std::string& source) { if (source.empty()) { return absl::InvalidArgumentError( diff --git a/weight_sync/weight_synchronizer_base.h b/weight_sync/weight_synchronizer_base.h index 4ddd781..ee8dd67 100644 --- a/weight_sync/weight_synchronizer_base.h +++ b/weight_sync/weight_synchronizer_base.h @@ -34,6 +34,7 @@ namespace tpu_raiden { namespace weight_sync { class WeightSynchronizerControlService; +class StartTransferRequest; class WeightSynchronizerBase : public tpu_raiden::RaidenManagerBase { public: @@ -62,6 +63,24 @@ class WeightSynchronizerBase : public tpu_raiden::RaidenManagerBase { // network push) absl::Status PushWeights(const std::vector& peers); + /** + * Executes a distributed resharding push transfer based on precise + * centralized Controller schedules. + * + * Automatically copies active local weight buffers from TPU device HBM to + * Host staging memory (via D2H), iterates over all active local shards, and + * pipelines non-contiguous byte chunks across persistent TCP connections to + * target remote peer host buffers. + * + * @param request Demarshaled StartTransferRequest protobuf containing exact + * 1D memory copy byte chunks, peer network coordinates, and + * offset schedules. + * @return absl::OkStatus() upon complete, successfully ACK-handshaked + * delivery to all remote peers. 
+ */ + absl::Status PushWeightsResharded( + const tpu_raiden::weight_sync::StartTransferRequest& request); + // Inference server pulls current weights from the source peer E2E (network // pull + H2D) absl::Status PullWeights(const std::string& source); diff --git a/weight_sync/weight_synchronizer_control_service.cc b/weight_sync/weight_synchronizer_control_service.cc index 9326caa..6ad78a2 100644 --- a/weight_sync/weight_synchronizer_control_service.cc +++ b/weight_sync/weight_synchronizer_control_service.cc @@ -163,15 +163,29 @@ void WeightSynchronizerControlService::ConnectionWorker(int client_fd) { if (req.command() == tpu_raiden::weight_sync::ControlRequest::COMMAND_START_TRANSFER) { - std::vector peers(req.peers().begin(), req.peers().end()); - LOG(INFO) << "C++ Control Service received START_TRANSFER request to " - << peers.size() << " peers"; - if (!peers.empty()) { - absl::Status status = engine_->PushWeights(peers); + if (req.has_start_transfer_request() && + !req.start_transfer_request().shard_push_schedules().empty()) { + LOG(INFO) << "C++ Control Service received START_TRANSFER with " + "shard_push_schedules"; + const auto& start_req = req.start_transfer_request(); + absl::Status status = engine_->PushWeightsResharded(start_req); if (!status.ok()) { resp.set_success(false); resp.set_message(std::string(status.message())); - LOG(ERROR) << "PushWeights native execution failed: " << status; + LOG(ERROR) << "PushWeightsResharded native execution failed: " + << status; + } + } else { + std::vector peers(req.peers().begin(), req.peers().end()); + LOG(INFO) << "C++ Control Service received START_TRANSFER request to " + << peers.size() << " peers"; + if (!peers.empty()) { + absl::Status status = engine_->PushWeights(peers); + if (!status.ok()) { + resp.set_success(false); + resp.set_message(std::string(status.message())); + LOG(ERROR) << "PushWeights native execution failed: " << status; + } } } } else if (req.command() == diff --git 
a/weight_sync/weight_synchronizer_control_service_test.cc b/weight_sync/weight_synchronizer_control_service_test.cc index bad87dc..02d87a8 100644 --- a/weight_sync/weight_synchronizer_control_service_test.cc +++ b/weight_sync/weight_synchronizer_control_service_test.cc @@ -37,6 +37,7 @@ #include #include #include +#include #include #include "weight_sync/weight_synchronizer_base.h" @@ -150,6 +151,106 @@ TEST(WeightSynchronizerControlServiceTest, ShutdownCommandStopsService) { close(sock); } +TEST(WeightSynchronizerControlServiceTest, PushWeightsReshardedSuccess) { + WeightSynchronizerBase src_engine( + /*num_layers=*/1, /*num_shards=*/4, /*slice_byte_size=*/16, + /*local_port=*/0, /*host_blocks_to_allocate=*/std::nullopt, + /*parallelism=*/1, /*control_port=*/std::nullopt); + + WeightSynchronizerControlService control_service(&src_engine, + /*control_port=*/0); + ASSERT_GT(control_service.control_port(), 0); + + int sock = ConnectToControlPort(control_service.control_port()); + ASSERT_GE(sock, 0); + + WeightSynchronizerBase dst_engine( + /*num_layers=*/1, /*num_shards=*/4, /*slice_byte_size=*/16, + /*local_port=*/0, /*host_blocks_to_allocate=*/1, + /*parallelism=*/1, /*control_port=*/std::nullopt); + ASSERT_TRUE(dst_engine.local_port().has_value()); + std::string dst_peer = + "127.0.0.1:" + std::to_string(*dst_engine.local_port()); + + // Populate source buffers + std::vector> src_data = { + {0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27}, + {4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31}, + {32, 33, 34, 35, 40, 41, 42, 43, 48, 49, 50, 51, 56, 57, 58, 59}, + {36, 37, 38, 39, 44, 45, 46, 47, 52, 53, 54, 55, 60, 61, 62, 63}, + }; + + for (size_t i = 0; i < 4; ++i) { + uint8_t* ptr = src_engine.GetHostPointer(0, i); + ASSERT_NE(ptr, nullptr); + std::memcpy(ptr, src_data[i].data(), 16); + } + + ControlRequest req; + req.set_command(ControlRequest::COMMAND_START_TRANSFER); + + StartTransferRequest* start_req = req.mutable_start_transfer_request(); 
+ + // Construct precise resharding push schedules for S0 and S2 pushing to D0 + auto& push_schedules = *start_req->mutable_shard_push_schedules(); + + // S0 -> D0 + ShardPushScheduleProto s0_sched; + for (int r = 0; r < 4; ++r) { + ShardPushEntryProto* e = s0_sched.add_entries(); + e->set_dst_peer(dst_peer); + e->set_dst_shard_idx(0); + e->set_src_offset_bytes(r * 4); + e->set_dst_offset_bytes(r * 2); + e->set_size_bytes(2); + } + push_schedules[0] = s0_sched; + + // S2 -> D0 + ShardPushScheduleProto s2_sched; + for (int r = 0; r < 4; ++r) { + ShardPushEntryProto* e = s2_sched.add_entries(); + e->set_dst_peer(dst_peer); + e->set_dst_shard_idx(0); + e->set_src_offset_bytes(r * 4); + e->set_dst_offset_bytes(8 + r * 2); + e->set_size_bytes(2); + } + push_schedules[2] = s2_sched; + + std::string payload; + ASSERT_TRUE(req.SerializeToString(&payload)); + uint32_t req_len = htonl(payload.size()); + + EXPECT_EQ(write(sock, &req_len, sizeof(req_len)), sizeof(req_len)); + EXPECT_EQ(write(sock, payload.data(), payload.size()), payload.size()); + + // Read response + uint32_t resp_len_net = 0; + ASSERT_EQ(read(sock, &resp_len_net, sizeof(resp_len_net)), + sizeof(resp_len_net)); + uint32_t resp_len = ntohl(resp_len_net); + + std::string resp_bytes(resp_len, '\0'); + ASSERT_EQ(read(sock, resp_bytes.data(), resp_len), resp_len); + + ControlResponse resp; + ASSERT_TRUE(resp.ParseFromString(resp_bytes)); + EXPECT_TRUE(resp.success()) << resp.message(); + + // Verify Destination Shard 0 final host memory! 
+ uint8_t* dst_ptr = dst_engine.GetHostPointer(0, 0); + ASSERT_NE(dst_ptr, nullptr); + + std::vector expected_d0 = {0, 1, 8, 9, 16, 17, 24, 25, + 32, 33, 40, 41, 48, 49, 56, 57}; + for (size_t k = 0; k < 16; ++k) { + EXPECT_EQ(dst_ptr[k], expected_d0[k]) << "Mismatch at byte " << k; + } + + close(sock); +} + } // namespace } // namespace weight_sync } // namespace tpu_raiden diff --git a/weight_sync/weight_synchronizer_service.proto b/weight_sync/weight_synchronizer_service.proto index 0b2dd03..38a018e 100644 --- a/weight_sync/weight_synchronizer_service.proto +++ b/weight_sync/weight_synchronizer_service.proto @@ -25,15 +25,65 @@ message RaidenIdProto { string data_name = 3; } +// Coordinate intervals representing a contiguous slice along a single tensor +// dimension. +message NDSliceDimensionProto { + // Inclusive start coordinate. + int64 start = 1; + // Exclusive end coordinate. + int64 end = 2; +} + +// Slices representing a multi-dimensional bounding box inside a distributed +// tensor. +message NDSliceProto { + // Dimension intervals ordered from outer-most dimension to inner-most minor + // dimension. + repeated NDSliceDimensionProto dimensions = 1; +} + message RegisterWorkUnitRequest { RaidenIdProto unit = 1; repeated string shards = 2; string control_plane_rpc_address = 3; + // Multi-dimensional bounding boxes owned by each logical device shard in + // row-major order. + repeated NDSliceProto shard_nd_slices = 4; + // Byte size of a single array scalar element (e.g., 4 for float32). + int32 itemsize = 5; +} + +// Defines an explicit 1D linear memory push operation from a local source to a +// remote destination shard. +message ShardPushEntryProto { + // Network coordinate "ip:port" or BNS of the remote Destination Worker Data + // servicer. + string dst_peer = 1; + // Logical shard partition index managed by the target Destination Worker + // process. 
+ int32 dst_shard_idx = 2; + // Exact linear memory byte offset to paste data into within the target + // destination buffer. + int64 dst_offset_bytes = 3; + // Exact linear memory byte offset to copy data out of within the local source + // buffer. + int64 src_offset_bytes = 4; + // Number of continuous bytes to transmit across the TCP stream. + int64 size_bytes = 5; +} + +// A collection of network push entries orchestrated by a specific source TPU +// shard. +message ShardPushScheduleProto { + repeated ShardPushEntryProto entries = 1; } message StartTransferRequest { repeated RaidenIdProto src_units = 1; repeated RaidenIdProto dst_units = 2; + // Centralized transfer execution plans mapped by local Source Worker logical + // shard partition index. + map shard_push_schedules = 3; } message ControlRequest {