Skip to content

Commit d871e47

Browse files
jd7-trfacebook-github-bot
authored andcommitted
[APS] Implement __torch_function__ for KeyedTensor (pytorch#161683)
Summary: X-link: pytorch/torchrec#3329 1. There are a bunch of `torch.ops.aten` operations that can't handle `KeyedTensor`: The error was occurring because these ops expects a regular `Tensor` but was receiving a `KeyedTensor` object. 2. Implement `__torch_function__` for `KeyedTensor`, so when these incompatible operations are called with a `KeyedTensor`, the `__torch_function__` method automatically delegates the op to the underlying values tensor from the `KeyedTensor` and returns a new `KeyedTensor` with updated values. Test Plan: ``` buck2 run mode/opt fbcode//aps_models/ads/gmp:launcher_with_publish mode=mtml_mobile_cvr_model/managed/Y2025Q2/local_mode_mtml_mobile_cvr_model_733415799_v0_fork +training.ir_serializer=manifold ``` MAST job: https://fburl.com/mlhub/pp937uxf Rollback Plan: Reviewed By: malaybag Differential Revision: D81047278
1 parent 4fd761f commit d871e47

File tree

1 file changed

+9
-3
lines changed

1 file changed

+9
-3
lines changed

torch/_subclasses/fake_impls.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -817,10 +817,16 @@ def assert_tensor_metadata(
817817
f"Tensor layout mismatch! Expected: {layout}, Got: {t.layout()}"
818818
)
819819
if device is not None:
820-
assert t.device == device, (
821-
f"Tensor device mismatch! Expected: {device}, Got: {t.device}"
820+
assert t.device.type == device.type, (
821+
f"Tensor device type mismatch! Expected: {device.type}, Got: {t.device.type}"
822+
)
823+
assert (
824+
(not device.index)
825+
or (not t.device.index)
826+
or (t.device.index == device.index)
827+
), (
828+
f"Tensor device index mismatch! Expected: {device.type}, Got: {t.device.type}"
822829
)
823-
824830

825831
# NB: this must be ordered after local_scalar_dense
826832
@register_op_impl(lambda func: torch.Tag.data_dependent_output in func.tags)

0 commit comments

Comments
 (0)