Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions core/raiden_manager_base.cc
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,18 @@ absl::Status RaidenManagerBase::PullWeightsChunk(
dst_shard_idx, dst_offset_bytes, size_bytes);
}

absl::Status RaidenManagerBase::PushWeightsChunk(const std::string& peer,
size_t dst_shard_idx,
size_t dst_offset_bytes,
const uint8_t* data_ptr,
size_t size_bytes) {
if (!server_) {
return absl::FailedPreconditionError("Transport server is not running");
}
return server_->PushWeightsChunk(peer, dst_shard_idx, dst_offset_bytes,
data_ptr, size_bytes);
}

size_t RaidenManagerBase::bytes_per_block() const {
return block_size_ * slice_byte_size_;
}
Expand Down
4 changes: 4 additions & 0 deletions core/raiden_manager_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,10 @@ class RaidenManagerBase : public tpu_raiden::transport::BlockTransportDelegate {
size_t src_offset_bytes, size_t dst_shard_idx,
size_t dst_offset_bytes, size_t size_bytes);

absl::Status PushWeightsChunk(const std::string& peer, size_t dst_shard_idx,
size_t dst_offset_bytes,
const uint8_t* data_ptr, size_t size_bytes);

std::optional<int> local_port() const;

uint8_t* GetHostPointer(size_t layer_idx, size_t shard_idx) override;
Expand Down
272 changes: 259 additions & 13 deletions rpc/raiden_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,11 @@

import asyncio
import dataclasses
import math
import socket
import threading
import time
from typing import Optional, Protocol, runtime_checkable
from typing import Optional, Protocol, Union, runtime_checkable

from weight_sync import weight_synchronizer_service_pb2

Expand Down Expand Up @@ -95,14 +96,18 @@ class TransferPlan:
# index, and the n-dimensional slice offsets for the shard index.
plan: dict[RaidenId, list[list[tuple[RaidenId, int, list[NDSlice]]]]]

# NEW: Maps every RaidenId in the plan to its physical Control-Plane RPC
# address!
shard_push_schedules: dict[
RaidenId, dict[int, list[tuple[str, int, int, int, int]]]
] = dataclasses.field(default_factory=dict)

# Maps every RaidenId in the plan to its physical Control-Plane RPC
# address
worker_rpc_addresses: dict[RaidenId, str] = dataclasses.field(
default_factory=dict
)

# NEW: Maps every RaidenId in the plan to its physical Data TCP socket
# endpoints (e.g. ['10.0.0.2:8500'])
# Maps every RaidenId in the plan to its physical Data TCP socket
# endpoints
worker_data_addresses: dict[RaidenId, list[str]] = dataclasses.field(
default_factory=dict
)
Expand Down Expand Up @@ -361,6 +366,35 @@ def _encode_start_transfer(
command=weight_synchronizer_service_pb2.ControlRequest.COMMAND_START_TRANSFER,
peers=peers,
)

start_req = weight_synchronizer_service_pb2.StartTransferRequest(
src_units=[_raiden_id_to_proto(u) for u in transfer_plan.src_units],
dst_units=[_raiden_id_to_proto(u) for u in transfer_plan.dst_units],
)

if transfer_plan.shard_push_schedules:
push_schedules = transfer_plan.shard_push_schedules.get(target_id)
if push_schedules:
for shard_idx, entries in push_schedules.items():
schedule_proto = (
weight_synchronizer_service_pb2.ShardPushScheduleProto()
)
for (
dst_peer,
dst_shard_idx,
dst_offset,
src_offset,
size,
) in entries:
entry_proto = schedule_proto.entries.add()
entry_proto.dst_peer = dst_peer
entry_proto.dst_shard_idx = dst_shard_idx
entry_proto.dst_offset_bytes = dst_offset
entry_proto.src_offset_bytes = src_offset
entry_proto.size_bytes = size
start_req.shard_push_schedules[shard_idx].CopyFrom(schedule_proto)

req.start_transfer_request.CopyFrom(start_req)
return req.SerializeToString()

def _verify_response(self, resp_bytes: bytes) -> None:
Expand Down Expand Up @@ -409,6 +443,116 @@ def done(self) -> bool:
return self._completed


def intersect_nd_slices(
slice1: list[tuple[int, int]], slice2: list[tuple[int, int]]
) -> Optional[list[tuple[int, int]]]:
"""Computes the precise N-dimensional intersection bounding box between two multi-dimensional slices.

Each slice is represented as a list of coordinate bounds (start, end) for
each dimension.

Args:
slice1: First N-dimensional slice bounding box.
slice2: Second N-dimensional slice bounding box.

Returns:
A list of (start, end) coordinate bounds representing the intersecting
subgrid, or None if the slices do not overlap in any dimension.
"""
result = []
for (s1, e1), (s2, e2) in zip(slice1, slice2):
start = max(s1, s2)
end = min(e1, e2)
if start >= end:
return None
result.append((start, end))
return result


def generate_1d_copy_chunks(
src_shard_slice: list[tuple[int, int]],
dst_shard_slice: list[tuple[int, int]],
intersection_slice: list[tuple[int, int]],
itemsize: int,
) -> list[tuple[int, int, int]]:
"""Translates an N-dimensional grid intersection into non-contiguous 1D memory copy byte chunks.

When linearizing multi-dimensional arrays, an intersecting subgrid is often
non-contiguous in memory. This function computes the exact linear source and
destination byte offsets and chunk sizes needed to transmit the non-adjacent
strided minor rows over a flat 1D data stream.

Args:
src_shard_slice: Global multi-dimensional bounding box of the source shard.
dst_shard_slice: Global multi-dimensional bounding box of the destination
shard.
intersection_slice: Global multi-dimensional bounding box of the overlapping
subgrid.
itemsize: Byte size of a single array scalar element (e.g., 4 for float32).

Returns:
A list of (src_offset_bytes, dst_offset_bytes, size_bytes) tuples defining
the concrete 1D linear memory copy chunk operations.
"""
rank = len(src_shard_slice)
src_shape = [e - s for s, e in src_shard_slice]
dst_shape = [e - s for s, e in dst_shard_slice]
int_shape = [e - s for s, e in intersection_slice]

src_local_int_slice = [
(int_s - src_s, int_e - src_s)
for (src_s, _), (int_s, int_e) in zip(src_shard_slice, intersection_slice)
]
dst_local_int_slice = [
(int_s - dst_s, int_e - dst_s)
for (dst_s, _), (int_s, int_e) in zip(dst_shard_slice, intersection_slice)
]

src_strides = [1] * rank
for i in range(rank - 2, -1, -1):
src_strides[i] = src_strides[i + 1] * src_shape[i + 1]

dst_strides = [1] * rank
for i in range(rank - 2, -1, -1):
dst_strides[i] = dst_strides[i + 1] * dst_shape[i + 1]

minor_dim_size = int_shape[-1]
contiguous_bytes = minor_dim_size * itemsize

chunks = []
outer_shape = int_shape[:-1]
num_outer_elements = math.prod(outer_shape) if outer_shape else 1

for i in range(num_outer_elements):
multi_index = []
temp = i
for dim_size in reversed(outer_shape):
multi_index.append(temp % dim_size)
temp //= dim_size
multi_index.reverse()

src_offset_items = 0
dst_offset_items = 0

for d in range(rank - 1):
src_idx = src_local_int_slice[d][0] + multi_index[d]
src_offset_items += src_idx * src_strides[d]

dst_idx = dst_local_int_slice[d][0] + multi_index[d]
dst_offset_items += dst_idx * dst_strides[d]

src_offset_items += src_local_int_slice[-1][0] * src_strides[-1]
dst_offset_items += dst_local_int_slice[-1][0] * dst_strides[-1]

chunks.append((
src_offset_items * itemsize,
dst_offset_items * itemsize,
contiguous_bytes,
))

return chunks


class RaidenController:
"""High-level transfer controller managing active transfers and generating transfer plans."""

Expand All @@ -418,6 +562,10 @@ def __init__(
self.port = port
self._active_transfers: dict[str, TransferPlan] = {}
self._registered_shards: dict[RaidenId, list[str]] = {}
self._registered_shard_slices: dict[
RaidenId, list[list[tuple[int, int]]]
] = {}
self._registered_itemsizes: dict[RaidenId, int] = {}
self._lock = threading.Lock()
self.worker_rpc_client = worker_rpc_client or WorkerRpcClient()

Expand All @@ -426,10 +574,25 @@ def register_work_unit(
unit: RaidenId,
shards: list[str],
control_plane_rpc_address: Optional[str] = None,
shard_nd_slices: Optional[
Union[
list[list[tuple[int, int]]],
list[weight_synchronizer_service_pb2.NDSliceProto],
]
] = None,
itemsize: Optional[int] = None,
) -> None:
"""Registers physical worker shard Data addresses and optional Control-Plane RPC endpoint."""
if shard_nd_slices is not None and itemsize is None:
raise ValueError(
"itemsize must not be None if shard_nd_slices is provided."
)
with self._lock:
self._registered_shards[unit] = shards
if shard_nd_slices:
self._registered_shard_slices[unit] = shard_nd_slices
if itemsize:
self._registered_itemsizes[unit] = itemsize
if control_plane_rpc_address and hasattr(
self.worker_rpc_client, "register_worker_endpoint"
):
Expand Down Expand Up @@ -528,6 +691,50 @@ def start_transfer(

plan[selected_src] = src_plan

src_slices = self._registered_shard_slices.get(selected_src)
dst_slices = {}
for d in dst_units:
d_slices = self._registered_shard_slices.get(d)
if d_slices:
dst_slices[d] = d_slices

itemsize = self._registered_itemsizes.get(selected_src)
if src_slices is not None and itemsize is None:
raise ValueError(
"itemsize must be registered if shard_nd_slices is provided."
)

shard_push_schedules: dict[
RaidenId, dict[int, list[tuple[str, int, int, int, int]]]
] = {}

if src_slices and len(dst_slices) == len(dst_units):
push_schedules: dict[int, list[tuple[str, int, int, int, int]]] = {}
for src_shard_idx, src_slice in enumerate(src_slices):
shard_entries = []
for dst_unit in dst_units:
d_slices = dst_slices[dst_unit]
dst_shards = self._resolve_shards(dst_unit)
for dst_shard_idx, dst_slice in enumerate(d_slices):
intersection = intersect_nd_slices(src_slice, dst_slice)
if intersection:
dst_peer = dst_shards[dst_shard_idx % len(dst_shards)]
chunks = generate_1d_copy_chunks(
src_slice, dst_slice, intersection, itemsize
)
for src_offset, dst_offset, size in chunks:
shard_entries.append((
dst_peer,
dst_shard_idx,
dst_offset,
src_offset,
size,
))
if shard_entries:
push_schedules[src_shard_idx] = shard_entries

shard_push_schedules[selected_src] = push_schedules

rpc_addresses = {}
if hasattr(self.worker_rpc_client, "get_worker_endpoints"):
rpc_addresses = self.worker_rpc_client.get_worker_endpoints()
Expand All @@ -538,6 +745,7 @@ def start_transfer(
src_units=[selected_src],
dst_units=dst_units,
plan=plan,
shard_push_schedules=shard_push_schedules,
worker_rpc_addresses=rpc_addresses,
worker_data_addresses=data_addresses,
)
Expand Down Expand Up @@ -660,7 +868,18 @@ def _handle_conn(
if reg.control_plane_rpc_address
else None
)
self._controller.register_work_unit(unit, shards, ctrl_addr)
shard_slices = []
for nd_slice_proto in reg.shard_nd_slices:
dims = []
for dim_proto in nd_slice_proto.dimensions:
dims.append((dim_proto.start, dim_proto.end))
shard_slices.append(dims)

itemsize = reg.itemsize if reg.itemsize > 0 else None

self._controller.register_work_unit(
unit, shards, ctrl_addr, shard_slices, itemsize
)
resp.success = True
elif (
req.command
Expand Down Expand Up @@ -764,6 +983,13 @@ def register_work_unit(
unit: RaidenId,
shards: list[str],
control_plane_rpc_address: Optional[str] = None,
shard_nd_slices: Optional[
Union[
list[list[tuple[int, int]]],
list[weight_synchronizer_service_pb2.NDSliceProto],
]
] = None,
itemsize: Optional[int] = None,
) -> None:
"""Sends remote RPC to register a physical worker entity with the central RaidenControllerServer.

Expand All @@ -772,16 +998,36 @@ def register_work_unit(
shards: list of physical Data TCP addresses (e.g. 'IP:Port').
control_plane_rpc_address: Optional worker Control-Plane RPC servicer
endpoint coordinate.
shard_nd_slices: Optional bounding boxes for each shard.
itemsize: Optional item size in bytes.
"""
if shard_nd_slices is not None and itemsize is None:
raise ValueError(
"itemsize must not be None if shard_nd_slices is provided."
)
reg_req = weight_synchronizer_service_pb2.RegisterWorkUnitRequest(
unit=_raiden_id_to_proto(unit),
shards=shards,
control_plane_rpc_address=(
control_plane_rpc_address if control_plane_rpc_address else ""
),
)
if shard_nd_slices:
for nd_slice in shard_nd_slices:
if isinstance(nd_slice, weight_synchronizer_service_pb2.NDSliceProto):
reg_req.shard_nd_slices.add().CopyFrom(nd_slice)
else:
slice_proto = reg_req.shard_nd_slices.add()
for s, e in nd_slice:
dim_proto = slice_proto.dimensions.add()
dim_proto.start = s
dim_proto.end = e
if itemsize:
reg_req.itemsize = itemsize

req = weight_synchronizer_service_pb2.ControlRequest(
command=weight_synchronizer_service_pb2.ControlRequest.COMMAND_REGISTER_WORK_UNIT,
register_work_unit_request=weight_synchronizer_service_pb2.RegisterWorkUnitRequest(
unit=_raiden_id_to_proto(unit),
shards=shards,
control_plane_rpc_address=(
control_plane_rpc_address if control_plane_rpc_address else ""
),
),
register_work_unit_request=reg_req,
)
self._send_protobuf_rpc(req)

Expand Down
Loading