Skip to content

Commit 0823584

Browse files
jd7-trfacebook-github-bot
authored andcommitted
Implement __torch_function__ for KeyedTensor
Summary: X-link: pytorch/pytorch#161683 1. There are a bunch of `torch.ops.aten` operations that can't handle `KeyedTensor`: The error was occurring because these ops expects a regular `Tensor` but was receiving a `KeyedTensor` object. 2. Implement `__torch_function__` for `KeyedTensor`, so when these incompatible operations are called with a `KeyedTensor`, the `__torch_function__` method automatically delegates the op to the underlying values tensor from the `KeyedTensor` and returns a new `KeyedTensor` with updated values. Reviewed By: malaybag Differential Revision: D81047278
1 parent a4832f4 commit 0823584

File tree

1 file changed

+47
-1
lines changed

1 file changed

+47
-1
lines changed

torchrec/sparse/jagged_tensor.py

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
import operator
1414

15-
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
15+
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
1616

1717
import torch
1818
from torch.autograd.profiler import record_function
@@ -3191,6 +3191,52 @@ class KeyedTensor(Pipelineable, metaclass=JaggedTensorMeta):
31913191
# torch.Tensor([[2, 1, 2], [2, 1, 2], [2, 1, 2]])
31923192
"""
31933193

3194+
@classmethod
3195+
# pyre-ignore
3196+
def __torch_function__(cls: Type["KeyedTensor"], func, types, args=(), kwargs=None):
3197+
"""
3198+
Enable KeyedTensor compatibility with PyTorch operations by delegating
3199+
operations to the underlying values tensor and reconstructing KeyedTensor
3200+
when appropriate.
3201+
3202+
This method allows KeyedTensor to work with various PyTorch operations
3203+
including compilation operations like torch.compile.
3204+
"""
3205+
if kwargs is None:
3206+
kwargs = {}
3207+
3208+
# Handle operations that expect regular tensors but should return KeyedTensor
3209+
tensor_ops = [
3210+
torch.ops.aten._assert_tensor_metadata.default,
3211+
torch.ops.aten.to.dtype,
3212+
torch.ops.aten.to.device,
3213+
torch.ops.aten.detach.default,
3214+
torch.ops.aten.clone.default,
3215+
]
3216+
3217+
if func in tensor_ops:
3218+
if len(args) > 0 and isinstance(args[0], cls):
3219+
keyed_tensor = args[0]
3220+
values_tensor = keyed_tensor.values()
3221+
new_args = (values_tensor,) + args[1:]
3222+
result = func(*new_args, **kwargs)
3223+
3224+
# For operations that return tensors, create new KeyedTensor with updated values
3225+
if isinstance(result, torch.Tensor):
3226+
return cls(
3227+
keys=keyed_tensor.keys(),
3228+
length_per_key=keyed_tensor.length_per_key(),
3229+
values=result,
3230+
key_dim=keyed_tensor.key_dim(),
3231+
offset_per_key=keyed_tensor._offset_per_key,
3232+
index_per_key=keyed_tensor._index_per_key,
3233+
)
3234+
3235+
return result
3236+
3237+
# For all other operations, return NotImplemented to allow normal handling
3238+
return NotImplementedError(f"{func} cannot be applied to KeyedTensor.")
3239+
31943240
def __init__(
31953241
self,
31963242
keys: List[str],

0 commit comments

Comments
 (0)