|
12 | 12 |
|
13 | 13 | import operator
|
14 | 14 |
|
15 |
| -from typing import Any, Callable, Dict, List, Optional, Tuple, Union |
| 15 | +from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union |
16 | 16 |
|
17 | 17 | import torch
|
18 | 18 | from torch.autograd.profiler import record_function
|
@@ -3191,6 +3191,52 @@ class KeyedTensor(Pipelineable, metaclass=JaggedTensorMeta):
|
3191 | 3191 | # torch.Tensor([[2, 1, 2], [2, 1, 2], [2, 1, 2]])
|
3192 | 3192 | """
|
3193 | 3193 |
|
| 3194 | + @classmethod |
| 3195 | + # pyre-ignore |
| 3196 | + def __torch_function__(cls: Type["KeyedTensor"], func, types, args=(), kwargs=None): |
| 3197 | + """ |
| 3198 | + Enable KeyedTensor compatibility with PyTorch operations by delegating |
| 3199 | + operations to the underlying values tensor and reconstructing KeyedTensor |
| 3200 | + when appropriate. |
| 3201 | +
|
| 3202 | + This method allows KeyedTensor to work with various PyTorch operations |
| 3203 | + including compilation operations like torch.compile. |
| 3204 | + """ |
| 3205 | + if kwargs is None: |
| 3206 | + kwargs = {} |
| 3207 | + |
| 3208 | + # Handle operations that expect regular tensors but should return KeyedTensor |
| 3209 | + tensor_ops = [ |
| 3210 | + torch.ops.aten._assert_tensor_metadata.default, |
| 3211 | + torch.ops.aten.to.dtype, |
| 3212 | + torch.ops.aten.to.device, |
| 3213 | + torch.ops.aten.detach.default, |
| 3214 | + torch.ops.aten.clone.default, |
| 3215 | + ] |
| 3216 | + |
| 3217 | + if func in tensor_ops: |
| 3218 | + if len(args) > 0 and isinstance(args[0], cls): |
| 3219 | + keyed_tensor = args[0] |
| 3220 | + values_tensor = keyed_tensor.values() |
| 3221 | + new_args = (values_tensor,) + args[1:] |
| 3222 | + result = func(*new_args, **kwargs) |
| 3223 | + |
| 3224 | + # For operations that return tensors, create new KeyedTensor with updated values |
| 3225 | + if isinstance(result, torch.Tensor): |
| 3226 | + return cls( |
| 3227 | + keys=keyed_tensor.keys(), |
| 3228 | + length_per_key=keyed_tensor.length_per_key(), |
| 3229 | + values=result, |
| 3230 | + key_dim=keyed_tensor.key_dim(), |
| 3231 | + offset_per_key=keyed_tensor._offset_per_key, |
| 3232 | + index_per_key=keyed_tensor._index_per_key, |
| 3233 | + ) |
| 3234 | + |
| 3235 | + return result |
| 3236 | + |
| 3237 | + # For all other operations, return NotImplemented to allow normal handling |
| 3238 | + return NotImplementedError(f"{func} cannot be applied to KeyedTensor.") |
| 3239 | + |
3194 | 3240 | def __init__(
|
3195 | 3241 | self,
|
3196 | 3242 | keys: List[str],
|
|
0 commit comments