You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
[APS] Implement __torch_function__ for KeyedTensor (pytorch#161683)
Summary:
X-link: pytorch/torchrec#3329
1. There are a bunch of `torch.ops.aten` operations that can't handle `KeyedTensor`: The error was occurring because these ops expects a regular `Tensor` but was receiving a `KeyedTensor` object.
2. Implement `__torch_function__` for `KeyedTensor`, so when these incompatible operations are called with a `KeyedTensor`, the `__torch_function__` method automatically delegates the op to the underlying values tensor from the `KeyedTensor` and returns a new `KeyedTensor` with updated values.
Test Plan:
```
buck2 run mode/opt fbcode//aps_models/ads/gmp:launcher_with_publish mode=mtml_mobile_cvr_model/managed/Y2025Q2/local_mode_mtml_mobile_cvr_model_733415799_v0_fork +training.ir_serializer=manifold
```
MAST job: https://fburl.com/mlhub/pp937uxf
Rollback Plan:
Reviewed By: malaybag
Differential Revision: D81047278
0 commit comments