Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 94 additions & 0 deletions tests/test_tensormeta_views.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
"""Unit tests for TensorMeta segment views (select/slice without hydration).

Pure-CPU: refs are faked with a minimal handle protocol (.local/.shape/.dtype/
.device), no transport backend required (materialize falls back to per-handle
fetch when backend is None).
"""

import torch

from unirl.distributed.tensor.transport import TensorMeta


class _FakeHandle:
def __init__(self, t: torch.Tensor):
self.t = t
self.shape = t.shape
self.dtype = t.dtype
self.device = t.device

def local(self) -> torch.Tensor:
return self.t


def _meta(*tensors: torch.Tensor) -> TensorMeta:
return TensorMeta(
refs=[_FakeHandle(t) for t in tensors],
sizes=[int(t.shape[0]) for t in tensors],
shape=(sum(int(t.shape[0]) for t in tensors), *tensors[0].shape[1:]),
dtype=tensors[0].dtype,
device="cpu",
)


def test_select_permutation_with_ragged_pad():
t0 = torch.arange(12).reshape(3, 4).float()
t1 = torch.arange(100, 112).reshape(2, 6).float()
t2 = torch.arange(200, 210).reshape(2, 5).float()
tm = _meta(t0, t1, t2)
perm = [5, 0, 3, 6, 2, 1, 4]
v = tm.select(perm)
assert v.view_plan is not None and v.batch_size == 7
out = v.materialize(backend=None)
assert out.shape == (7, 6) # ragged refs right-padded to the max width
assert torch.equal(out[0, :5], t2[0])
assert torch.equal(out[1, :4], t0[0])
assert torch.all(out[1, 4:] == 0)


def test_view_slice_matches_materialized_rows():
t0 = torch.arange(12).reshape(3, 4).float()
t1 = torch.arange(100, 108).reshape(2, 4).float()
v = _meta(t0, t1).select([4, 0, 2, 1])
full = v.materialize(backend=None)
half = v.slice(1, 3)
assert torch.equal(half.materialize(backend=None), full[1:3])


def test_misaligned_slice_degrades_to_view():
t0 = torch.arange(12).reshape(3, 4).float()
t1 = torch.arange(100, 108).reshape(2, 4).float()
tm = _meta(t0, t1)
mid = tm.slice(1, 4) # crosses the ref boundary off-alignment
assert mid.view_plan is not None and mid.batch_size == 3
assert torch.equal(mid.materialize(backend=None), torch.cat([t0[1:], t1[:1]]))


def test_packed_segment_view():
p0 = torch.arange(10).float()
p1 = torch.arange(100, 106).float()
pm = _meta(p0, p1)
pv = pm.select_segments([(12, 16), (0, 3)]) # out-of-order token ranges
assert pv.batch_size == 7
assert torch.equal(pv.materialize(backend=None), torch.cat([p1[2:6], p0[0:3]]))


def test_with_refs_preserves_plan():
t0 = torch.arange(8).reshape(2, 4).float()
v = _meta(t0).select([1, 0])
v2 = v.with_refs(list(v.refs))
assert v2.view_plan == v.view_plan


def test_empty_selection():
p0 = torch.arange(10).float()
e = _meta(p0).select_segments([])
assert e.batch_size == 0
assert e.materialize(backend=None).numel() == 0


def test_assemble_from_prefetched_parts():
t0 = torch.arange(12).reshape(3, 4).float()
t1 = torch.arange(100, 112).reshape(2, 6).float()
v = _meta(t0, t1).select([4, 0, 2])
assert torch.equal(v.materialize(backend=None), v.assemble({0: t0, 1: t1}))
4 changes: 2 additions & 2 deletions unirl/distributed/tensor/backend/gpu_store/transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,8 @@ def get_batch(self, metas: Dict[str, TensorMeta]) -> Dict[str, torch.Tensor]:
borrow_map = self._batch_borrow(all_handles)
out: Dict[str, torch.Tensor] = {}
for k, m in metas.items():
parts = [self._resolve_handle(h, borrow_map) for h in m.refs]
out[k] = parts[0] if len(parts) == 1 else torch.cat(parts, dim=0)
resolved = {i: self._resolve_handle(h, borrow_map) for i, h in enumerate(m.refs)}
out[k] = m.assemble(resolved)
return out

def put_batch(self, tensors: Dict[str, torch.Tensor]) -> Dict[str, TensorMeta]:
Expand Down
7 changes: 7 additions & 0 deletions unirl/distributed/tensor/backend/transfer_queue/transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,13 @@ def _bs(t: Any) -> int:
def get_batch(self, metas: Dict[str, TensorMeta]) -> Dict[str, torch.Tensor]:
if not metas:
return {}
viewed = {k: m for k, m in metas.items() if getattr(m, "view_plan", None) is not None}
if viewed:
plain = {k: m for k, m in metas.items() if k not in viewed}
out = self.get_batch(plain) if plain else {}
for k, m in viewed.items():
out[k] = m.materialize(backend=self)
return out

async def _get_batch() -> Dict[str, torch.Tensor]:
# Flatten every key's refs into one handle list (each handle carries
Expand Down
7 changes: 7 additions & 0 deletions unirl/distributed/tensor/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,7 +503,14 @@ def _select_packed_data(
"construct via the regular dataclass __init__ with per-sample lists."
)
if not indices:
if hasattr(value, "select_segments"):
return value.select_segments([])
return value[:0].clone()
if hasattr(value, "select_segments"):
# TensorMeta: token-range gather as a lazy segment view (no data motion).
return value.select_segments(
[(int(cu[i].item()), int(cu[i + 1].item())) for i in indices]
)
chunks = [value[int(cu[i].item()) : int(cu[i + 1].item())] for i in indices]
return torch.cat(chunks, dim=0)

Expand Down
188 changes: 174 additions & 14 deletions unirl/distributed/tensor/transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,18 @@ class TensorMeta(Batch):
device: Optional[str] = shared_field(default=None)
grad: Optional["TensorMeta"] = shared_field(default=None)
retain_grad_flag: bool = shared_field(default=False)
# Row/segment VIEW over the refs (the permutation primitive): an ordered
# list of ``(ref_idx, start, end)`` segments in ref-local units (rows for
# CONCAT fields, tokens for PACKED fields). None = the legacy identity
# view (all of every ref, in ref order). A view keeps the data remote —
# selection/permutation across the dispatch boundary no longer needs
# driver-side hydration; segments are gathered lazily at materialize().
view_plan: Optional[List[Tuple[int, int, int]]] = shared_field(default=None)

@property
def batch_size(self) -> int:
if self.view_plan is not None:
return sum(int(e) - int(s) for _, s, e in self.view_plan)
return sum(self.sizes) if self.sizes else 0

@classmethod
Expand All @@ -75,8 +84,87 @@ def concat(cls, items: "list[TensorMeta]") -> "TensorMeta":
device=first.device,
)

def select(self, indices):
raise NotImplementedError("TensorMeta does not support select — hydrate first")
def select(self, indices) -> "TensorMeta":
"""Re-index along the unit axis by building a segment VIEW (no data motion)."""
idx = [int(i) for i in (indices.tolist() if hasattr(indices, "tolist") else indices)]
return self.select_units(idx)

def _identity_plan(self) -> List[Tuple[int, int, int]]:
return [(r, 0, int(n)) for r, n in enumerate(self.sizes)]

def _resolve_unit(self, g: int) -> Tuple[int, int]:
"""Global unit index (in THIS view's order) -> (ref_idx, ref-local unit)."""
plan = self.view_plan if self.view_plan is not None else self._identity_plan()
off = 0
for r, s, e in plan:
n = int(e) - int(s)
if g < off + n:
return int(r), int(s) + (g - off)
off += n
raise IndexError(f"unit {g} out of range for view of size {off}")

def select_units(self, idx: List[int]) -> "TensorMeta":
"""Arbitrary re-index (gather/permute) as a lazy segment view."""
pairs = [self._resolve_unit(int(g)) for g in idx]
plan: List[Tuple[int, int, int]] = []
for r, u in pairs:
if plan and plan[-1][0] == r and plan[-1][2] == u:
plan[-1] = (r, plan[-1][1], u + 1) # extend the run
else:
plan.append((r, u, u + 1))
return self._with_plan(plan)

def select_segments(self, segments: List[Tuple[int, int]]) -> "TensorMeta":
"""Re-index by global (start, end) unit ranges (PACKED token ranges)."""
plan: List[Tuple[int, int, int]] = []
for g0, g1 in segments:
g0, g1 = int(g0), int(g1)
while g0 < g1:
r, u = self._resolve_unit(g0)
# extend within the same source segment as far as possible
src_plan = self.view_plan if self.view_plan is not None else self._identity_plan()
# find the containing segment's end in ref-local units
off = 0
seg_end = None
for rr, ss, ee in src_plan:
n = int(ee) - int(ss)
if g0 < off + n:
seg_end = int(ee)
break
off += n
take = min(g1 - g0, seg_end - u)
if plan and plan[-1][0] == r and plan[-1][2] == u:
plan[-1] = (r, plan[-1][1], u + take)
else:
plan.append((r, u, u + take))
g0 += take
return self._with_plan(plan)

def _with_plan(self, plan: List[Tuple[int, int, int]]) -> "TensorMeta":
total = sum(int(e) - int(s) for _, s, e in plan)
return TensorMeta(
refs=list(self.refs),
sizes=list(self.sizes),
shape=(total, *self.shape[1:]) if self.shape else None,
dtype=self.dtype,
device=self.device,
view_plan=plan,
)

def with_refs(self, refs: List[Any]) -> "TensorMeta":
"""Clone with substituted (routed) refs, PRESERVING the view plan.

``localize`` rebuilds metas after routing; ``from_handles`` would drop
the plan and re-derive sizes — views must survive the trip.
"""
return TensorMeta(
refs=list(refs),
sizes=list(self.sizes),
shape=self.shape,
dtype=self.dtype,
device=self.device,
view_plan=None if self.view_plan is None else list(self.view_plan),
)

def _slice_by_refs(self, start: int, end: int) -> "TensorMeta":
"""Partition refs for the row range ``[start:end)`` — inverse of concat.
Expand All @@ -91,17 +179,19 @@ def _slice_by_refs(self, start: int, end: int) -> "TensorMeta":
hydration first.
"""
start, end = int(start), int(end)
if self.view_plan is not None:
# Views slice anywhere: just trim the segment list.
return self.select_segments([(start, end)])
offsets = [0]
for s in self.sizes:
offsets.append(offsets[-1] + int(s))
try:
i0 = offsets.index(start)
i1 = offsets.index(end)
except ValueError:
raise NotImplementedError(
f"TensorMeta slice [{start}:{end}] does not align to ref boundaries "
f"{offsets}; intra-handle slicing requires hydration first."
)
# Misaligned range on a non-view: degrade gracefully to a segment
# view instead of demanding hydration (the data stays remote).
return self.select_segments([(start, end)])
refs = list(self.refs[i0:i1])
sizes = list(self.sizes[i0:i1])
total = sum(int(s) for s in sizes)
Expand Down Expand Up @@ -140,10 +230,66 @@ def permute(self, *dims: int) -> "TensorMeta":
return self.transform(lambda t: t.permute(*dims))

def local(self) -> torch.Tensor:
backend = TensorTransportRuntime.current()
return self.materialize()

def materialize(self, backend: "Optional[TensorTransport]" = None) -> torch.Tensor:
"""Fetch + assemble this (possibly viewed) meta into a real tensor.

Legacy metas (no plan) keep the exact old path: ``backend.get(refs)``.
Views fetch each NEEDED ref once, then gather the plan's segments in
order. Trailing-dim CONTRACT for views: refs may be padded to
different widths per producing shard (e.g. per-worker prompt blocks);
segments crossing such refs are right-padded with zeros to the max
width — consumers of 2D+ per-shard-padded fields must be mask-driven
(the convention TextTokenCondition.concat already establishes).
"""
if backend is None:
raise RuntimeError("No TensorTransport backend installed")
return backend.get(self.refs)
backend = TensorTransportRuntime.current()
def fetch(ref):
if backend is not None:
return backend.get([ref])
return ref.local()
if self.view_plan is None and backend is not None:
return backend.get(self.refs)
resolved = {}
if self.view_plan is None:
resolved = {i: fetch(r) for i, r in enumerate(self.refs)}
else:
for r in sorted({r for r, _, _ in self.view_plan}):
resolved[r] = fetch(self.refs[r])
return self.assemble(resolved)

def assemble(self, resolved: "Dict[int, torch.Tensor]") -> torch.Tensor:
"""Assemble pre-fetched per-ref tensors into this meta's tensor.

``resolved`` maps ref index -> that ref's full tensor (only the refs a
view actually needs must be present). Legacy metas concatenate in ref
order. Views gather their plan segments in order; trailing-dim CONTRACT:
segments crossing refs padded to different widths are right-padded with
zeros to the max width — consumers of 2D+ per-shard-padded fields must
be mask-driven (the TextTokenCondition.concat convention).
"""
if self.view_plan is None:
parts = [resolved[i] for i in range(len(self.refs))]
return parts[0] if len(parts) == 1 else torch.cat(parts, dim=0)
if not self.view_plan:
base = next(iter(resolved.values()), None)
return base[:0] if base is not None else torch.empty(0)
parts = [resolved[r][s:e] for r, s, e in self.view_plan]
if len(parts) == 1:
return parts[0]
if parts[0].dim() >= 2:
widths = {int(t.shape[1]) for t in parts}
if len(widths) > 1:
target = max(widths)
padded = []
for t in parts:
if int(t.shape[1]) < target:
pad = t.new_zeros((t.shape[0], target - t.shape[1]) + tuple(t.shape[2:]))
t = torch.cat([t, pad], dim=1)
padded.append(t)
parts = padded
return torch.cat(parts, dim=0)

def retain_grad(self) -> "TensorMeta":
self.retain_grad_flag = True
Expand Down Expand Up @@ -323,7 +469,15 @@ def put_batch(self, tensors: Dict[str, torch.Tensor]) -> Dict[str, TensorMeta]:

def get_batch(self, metas: Dict[str, TensorMeta]) -> Dict[str, torch.Tensor]:
"""Fetch multiple named tensors. Default: iterate per key."""
return {k: self.get(m.refs) for k, m in metas.items()}
out: Dict[str, torch.Tensor] = {}
for k, m in metas.items():
if getattr(m, "view_plan", None) is not None:
# Segment views assemble themselves (plan-ordered gather with
# the documented ragged right-pad contract).
out[k] = m.materialize(backend=self)
else:
out[k] = self.get(m.refs)
return out

def transform(self, meta: TensorMeta, fn: Callable[[torch.Tensor], torch.Tensor]) -> TensorMeta:
"""Apply fn to the remote tensor, return new TensorMeta.
Expand Down Expand Up @@ -404,7 +558,7 @@ def hydrate(self, value: Any, fields: Optional[Set[str]] = None) -> Any:
*fields* are hydrated; the rest stay as ``TensorMeta``.
"""
if isinstance(value, TensorMeta):
return self.get(value.refs)
return value.materialize(backend=self)

filter_fn: Optional[Callable[[str], bool]] = None
if fields is not None:
Expand All @@ -418,7 +572,11 @@ def filter_fn(key):
if not meta_map:
return value

tensors = self.get_batch(meta_map)
viewed = {k: m for k, m in meta_map.items() if m.view_plan is not None}
plain = {k: m for k, m in meta_map.items() if m.view_plan is None}
tensors = self.get_batch(plain) if plain else {}
for key, meta in viewed.items():
tensors[key] = meta.materialize(backend=self)
for key, tensor in tensors.items():
if key in setters:
setters[key](tensor)
Expand Down Expand Up @@ -546,7 +704,9 @@ def route(ref: Any, dst_worker_id: str, dst_device_id: int) -> Any:

def unwrap(obj: Any, dst_worker_id: str, dst_device_id: int) -> Any:
if isinstance(obj, TensorMeta):
return TensorMeta.from_handles([route(h, dst_worker_id, dst_device_id) for h in obj.refs])
# with_refs (NOT from_handles): a permuted shard is a segment
# VIEW over the refs — the plan must survive routing.
return obj.with_refs([route(h, dst_worker_id, dst_device_id) for h in obj.refs])
return obj

routed: list = []
Expand Down Expand Up @@ -582,7 +742,7 @@ def leaf(o, _w=worker_ids[i], _d=device_ids[i]):

def substitute(obj: Any) -> Any:
if isinstance(obj, TensorMeta):
return TensorMeta.from_handles([subs.get(id(h), h) for h in obj.refs])
return obj.with_refs([subs.get(id(h), h) for h in obj.refs])
return obj

return [(map_tree(a, substitute), map_tree(k, substitute)) for a, k in routed]
Expand Down
Loading
Loading