Skip to content

Commit c22f2a3

Browse files
committed
Fix bug in BaseMCRTensor:__torch_function__()
There is a bug when args is a collection of collections instead of a plain tuple. In this case, the old args parsing was unable to search for block_size in nested structures containing BaseMCRTensor. For instance, let `a` and `b` be BaseMCRTensor variables. Calling `torch.stack((a, b))` results in error, as the `args` received in `__torch_function__()` is a nested tuple.
1 parent d634328 commit c22f2a3

File tree

1 file changed

+14
-1
lines changed

1 file changed

+14
-1
lines changed

torchhd/tensors/basemcr.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -359,7 +359,20 @@ def __torch_function__(cls, func, types, args=(), kwargs=None):
359359
if kwargs is None:
360360
kwargs = {}
361361

362-
block_sizes = set(a.block_size for a in args if hasattr(a, "block_size"))
362+
def _parse_container_for_attr(container, attr):
363+
s = set()
364+
for a in container:
365+
if type(a) is tuple or type(a) is list:
366+
s |= _parse_container_for_attr(a, attr)
367+
else:
368+
if hasattr(a, attr):
369+
s.add(a.block_size)
370+
return s
371+
372+
# Args is a tuple that can contain other tuples or lists. Parse it
373+
# reccursively to find any BaseMCRTensor object
374+
block_sizes = _parse_container_for_attr(args, "block_size")
375+
363376
if len(block_sizes) != 1:
364377
raise RuntimeError(
365378
f"Call to {func} must contain exactly one block size, got {list(block_sizes)}"

0 commit comments

Comments
 (0)