Skip to content

Commit bfe3db3

Browse files
Joey Yang authored and facebook-github-bot committed
Store and retrieve hash_zch_runtime_meta within mc_embedding submodules (#3599)
Summary: See https://fb.workplace.com/groups/1404957374198553/permalink/1610214197006202/ Similar to D85999577, we store the `hash_zch_runtime_meta` when it is being looked up in `mc_modules.py` in `raw_id_tracker`, then access it in `batched_embedding_kernel.py`, from which it will be streamed to the inference side (see the other diffs in this stack, D87810125). Reviewed By: chouxi Differential Revision: D88623165
1 parent 9421c38 commit bfe3db3

File tree

3 files changed

+75
-38
lines changed

3 files changed

+75
-38
lines changed

torchrec/distributed/batched_embedding_kernel.py

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1758,19 +1758,29 @@ def init_parameters(self) -> None:
17581758
weight_init_max,
17591759
)
17601760

1761-
def forward(self, features: KeyedJaggedTensor) -> torch.Tensor:
1762-
hash_zch_identities = self._get_hash_zch_identities(features)
1763-
if hash_zch_identities is None:
1761+
def forward(
1762+
self,
1763+
features: KeyedJaggedTensor,
1764+
) -> torch.Tensor:
1765+
forward_args: Dict[str, Any] = {}
1766+
identities_and_metadata = self._get_hash_zch_identities_and_metadata(features)
1767+
if identities_and_metadata is not None:
1768+
hash_zch_identities, hash_zch_runtime_meta = identities_and_metadata
1769+
forward_args["hash_zch_identities"] = hash_zch_identities
1770+
if hash_zch_runtime_meta is not None:
1771+
forward_args["hash_zch_runtime_meta"] = hash_zch_runtime_meta
1772+
1773+
if len(forward_args) == 0:
17641774
return self.emb_module(
17651775
indices=features.values().long(),
17661776
offsets=features.offsets().long(),
17671777
)
1768-
1769-
return self.emb_module(
1770-
indices=features.values().long(),
1771-
offsets=features.offsets().long(),
1772-
hash_zch_identities=hash_zch_identities,
1773-
)
1778+
else:
1779+
return self.emb_module(
1780+
indices=features.values().long(),
1781+
offsets=features.offsets().long(),
1782+
**forward_args,
1783+
)
17741784

17751785
# pyre-fixme[14]: `state_dict` overrides method defined in `Module` inconsistently.
17761786
def state_dict(
@@ -2832,9 +2842,12 @@ def forward(
28322842
features: KeyedJaggedTensor,
28332843
) -> torch.Tensor:
28342844
forward_args: Dict[str, Any] = {}
2835-
hash_zch_identities = self._get_hash_zch_identities(features)
2836-
if hash_zch_identities is not None:
2845+
identities_and_metadata = self._get_hash_zch_identities_and_metadata(features)
2846+
if identities_and_metadata is not None:
2847+
hash_zch_identities, hash_zch_runtime_meta = identities_and_metadata
28372848
forward_args["hash_zch_identities"] = hash_zch_identities
2849+
if hash_zch_runtime_meta is not None:
2850+
forward_args["hash_zch_runtime_meta"] = hash_zch_runtime_meta
28382851

28392852
weights = features.weights_or_none()
28402853
if weights is not None and not torch.is_floating_point(weights):

torchrec/distributed/embedding_kernel.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -110,9 +110,9 @@ def init_raw_id_tracker(
110110
get_indexed_lookups, delete
111111
)
112112

113-
def _get_hash_zch_identities(
113+
def _get_hash_zch_identities_and_metadata(
114114
self, features: KeyedJaggedTensor
115-
) -> Optional[torch.Tensor]:
115+
) -> Optional[Tuple[torch.Tensor, Optional[torch.Tensor]]]:
116116
if self._raw_id_tracker_wrapper is None or not isinstance(
117117
self.emb_module, SplitTableBatchedEmbeddingBagsCodegen
118118
):
@@ -131,7 +131,7 @@ def _get_hash_zch_identities(
131131
# across multiple training iterations. Current logic appends raw_ids from
132132
# all batches sequentially. This may cause misalignment with
133133
# features.values() which only contains the current batch.
134-
raw_ids_dict = raw_id_tracker_wrapper.get_indexed_lookups(
134+
indexed_lookups_dict = raw_id_tracker_wrapper.get_indexed_lookups(
135135
table_names, emb_module.uuid
136136
)
137137

@@ -148,11 +148,14 @@ def _get_hash_zch_identities(
148148
# raw_ids are included. If some tables lack identity while others have them,
149149
# padding with -1 may be needed to maintain alignment.
150150
all_raw_ids = []
151+
all_runtime_meta = []
151152
for table_name in table_names:
152-
if table_name in raw_ids_dict:
153-
raw_ids_list = raw_ids_dict[table_name]
153+
if table_name in indexed_lookups_dict:
154+
raw_ids_list, runtime_meta_list = indexed_lookups_dict[table_name]
154155
for raw_ids in raw_ids_list:
155156
all_raw_ids.append(raw_ids)
157+
for runtime_meta in runtime_meta_list:
158+
all_runtime_meta.append(runtime_meta)
156159

157160
if not all_raw_ids:
158161
return None
@@ -162,7 +165,16 @@ def _get_hash_zch_identities(
162165
f"hash_zch_identities row count ({hash_zch_identities.size(0)}) must match "
163166
f"features.values() length ({features.values().numel()}) to maintain 1-to-1 alignment"
164167
)
165-
return hash_zch_identities
168+
169+
if all_runtime_meta:
170+
hash_zch_runtime_meta = torch.cat(all_runtime_meta)
171+
assert hash_zch_runtime_meta.size(0) == hash_zch_identities.size(0), (
172+
f"hash_zch_runtime_meta row count ({hash_zch_runtime_meta.size(0)}) must match "
173+
f"hash_zch_identities length ({hash_zch_identities.size(0)}) to maintain 1-to-1 alignment"
174+
)
175+
return (hash_zch_identities, hash_zch_runtime_meta)
176+
else:
177+
return (hash_zch_identities, None)
166178

167179

168180
def create_virtual_table_local_metadata(

torchrec/distributed/mc_modules.py

Lines changed: 33 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,7 @@ def __init__(
245245
torch.Tensor,
246246
Optional[nn.Module],
247247
Optional[torch.Tensor],
248+
Optional[torch.Tensor],
248249
],
249250
None,
250251
]
@@ -716,6 +717,31 @@ def global_to_local_index(
716717
jt._values = jt.values() - self._table_to_offset[table]
717718
return jt_dict
718719

720+
def _retrieve_and_track_hash_zch_identities_and_metadata(
721+
self,
722+
mcm: nn.Module,
723+
mc_input: Dict[str, JaggedTensor],
724+
indices: torch.Tensor,
725+
) -> None:
726+
if self.post_lookup_tracker_fn is None:
727+
return
728+
if not hasattr(mcm, "_hash_zch_identities"):
729+
return
730+
# _hash_zch_identities should always exist but _hash_zch_runtime_meta is optional
731+
runtime_meta = None
732+
if (
733+
hasattr(mcm, "_hash_zch_runtime_meta")
734+
and mcm._hash_zch_runtime_meta is not None
735+
):
736+
runtime_meta = mcm._hash_zch_runtime_meta.index_select(dim=0, index=indices)
737+
self.post_lookup_tracker_fn(
738+
KeyedJaggedTensor.from_jt_dict(mc_input),
739+
torch.empty(0),
740+
None,
741+
mcm._hash_zch_identities.index_select(dim=0, index=indices),
742+
runtime_meta,
743+
)
744+
719745
def compute(
720746
self,
721747
ctx: ManagedCollisionCollectionContext,
@@ -758,19 +784,9 @@ def compute(
758784
mc_input = mcm.remap(mc_input)
759785
mc_input = self.global_to_local_index(mc_input)
760786
output.update(mc_input)
761-
if hasattr(
762-
mcm,
763-
"_hash_zch_identities",
764-
):
765-
if self.post_lookup_tracker_fn is not None:
766-
self.post_lookup_tracker_fn(
767-
KeyedJaggedTensor.from_jt_dict(mc_input),
768-
torch.empty(0),
769-
None,
770-
mcm._hash_zch_identities.index_select(
771-
dim=0, index=mc_input[table].values()
772-
),
773-
)
787+
self._retrieve_and_track_hash_zch_identities_and_metadata(
788+
mcm, mc_input, mc_input[table].values()
789+
)
774790
values = torch.cat([jt.values() for jt in output.values()])
775791
else:
776792
table: str = tables[0]
@@ -789,14 +805,9 @@ def compute(
789805
mc_input = mcm.remap(mc_input)
790806
mc_input = self.global_to_local_index(mc_input)
791807
values = mc_input[table].values()
792-
if hasattr(mcm, "_hash_zch_identities"):
793-
if self.post_lookup_tracker_fn is not None:
794-
self.post_lookup_tracker_fn(
795-
KeyedJaggedTensor.from_jt_dict(mc_input),
796-
torch.empty(0),
797-
None,
798-
mcm._hash_zch_identities.index_select(dim=0, index=values),
799-
)
808+
self._retrieve_and_track_hash_zch_identities_and_metadata(
809+
mcm, mc_input, values
810+
)
800811

801812
remapped_kjts.append(
802813
KeyedJaggedTensor(
@@ -895,6 +906,7 @@ def register_post_lookup_tracker_fn(
895906
torch.Tensor,
896907
Optional[nn.Module],
897908
Optional[torch.Tensor],
909+
Optional[torch.Tensor],
898910
],
899911
None,
900912
],

0 commit comments

Comments
 (0)