Skip to content

Commit 5a5ce67

Browse files
Joey Yang authored and facebook-github-bot committed
Store and retrieve hash_zch_runtime_meta within mc_embedding submodules
Summary: See https://fb.workplace.com/groups/1404957374198553/permalink/1610214197006202/ Similar to D85999577, we store the `hash_zch_runtime_meta` in `raw_id_tracker` when it is looked up in `mc_modules.py`, then access it in `batched_embedding_kernel.py`, from which it will be streamed to the inference side (see the other diffs in this stack, D87810125). Differential Revision: D88623165
1 parent 5854b0e commit 5a5ce67

File tree

3 files changed

+59
-18
lines changed

3 files changed

+59
-18
lines changed

torchrec/distributed/batched_embedding_kernel.py

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1764,19 +1764,29 @@ def init_parameters(self) -> None:
17641764
weight_init_max,
17651765
)
17661766

1767-
def forward(self, features: KeyedJaggedTensor) -> torch.Tensor:
1768-
hash_zch_identities = self._get_hash_zch_identities(features)
1769-
if hash_zch_identities is None:
1767+
def forward(
1768+
self,
1769+
features: KeyedJaggedTensor,
1770+
) -> torch.Tensor:
1771+
forward_args: Dict[str, Any] = {}
1772+
identities_and_metadata = self._get_hash_zch_identities_and_metadata(features)
1773+
if identities_and_metadata is not None:
1774+
hash_zch_identities, hash_zch_runtime_meta = identities_and_metadata
1775+
forward_args["hash_zch_identities"] = hash_zch_identities
1776+
if hash_zch_runtime_meta is not None:
1777+
forward_args["hash_zch_runtime_meta"] = hash_zch_runtime_meta
1778+
1779+
if len(forward_args) == 0:
17701780
return self.emb_module(
17711781
indices=features.values().long(),
17721782
offsets=features.offsets().long(),
17731783
)
1774-
1775-
return self.emb_module(
1776-
indices=features.values().long(),
1777-
offsets=features.offsets().long(),
1778-
hash_zch_identities=hash_zch_identities,
1779-
)
1784+
else:
1785+
return self.emb_module(
1786+
indices=features.values().long(),
1787+
offsets=features.offsets().long(),
1788+
**forward_args,
1789+
)
17801790

17811791
# pyre-fixme[14]: `state_dict` overrides method defined in `Module` inconsistently.
17821792
def state_dict(
@@ -2841,9 +2851,12 @@ def forward(
28412851
features: KeyedJaggedTensor,
28422852
) -> torch.Tensor:
28432853
forward_args: Dict[str, Any] = {}
2844-
hash_zch_identities = self._get_hash_zch_identities(features)
2845-
if hash_zch_identities is not None:
2854+
identities_and_metadata = self._get_hash_zch_identities_and_metadata(features)
2855+
if identities_and_metadata is not None:
2856+
hash_zch_identities, hash_zch_runtime_meta = identities_and_metadata
28462857
forward_args["hash_zch_identities"] = hash_zch_identities
2858+
if hash_zch_runtime_meta is not None:
2859+
forward_args["hash_zch_runtime_meta"] = hash_zch_runtime_meta
28472860

28482861
weights = features.weights_or_none()
28492862
if weights is not None and not torch.is_floating_point(weights):

torchrec/distributed/embedding_kernel.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -110,9 +110,9 @@ def init_raw_id_tracker(
110110
get_indexed_lookups, delete
111111
)
112112

113-
def _get_hash_zch_identities(
113+
def _get_hash_zch_identities_and_metadata(
114114
self, features: KeyedJaggedTensor
115-
) -> Optional[torch.Tensor]:
115+
) -> Optional[Tuple[torch.Tensor, Optional[torch.Tensor]]]:
116116
if self._raw_id_tracker_wrapper is None or not isinstance(
117117
self.emb_module, SplitTableBatchedEmbeddingBagsCodegen
118118
):
@@ -131,8 +131,8 @@ def _get_hash_zch_identities(
131131
# across multiple training iterations. Current logic appends raw_ids from
132132
# all batches sequentially. This may cause misalignment with
133133
# features.values() which only contains the current batch.
134-
raw_ids_dict = raw_id_tracker_wrapper.get_indexed_lookups(
135-
table_names, emb_module.uuid
134+
indexed_lookups_dict = raw_id_tracker_wrapper.get_indexed_lookups(
135+
table_names, self.emb_module.uuid
136136
)
137137

138138
# Build hash_zch_identities by concatenating raw IDs from tracked tables.
@@ -148,11 +148,14 @@ def _get_hash_zch_identities(
148148
# raw_ids are included. If some tables lack identity while others have them,
149149
# padding with -1 may be needed to maintain alignment.
150150
all_raw_ids = []
151+
all_runtime_meta = []
151152
for table_name in table_names:
152-
if table_name in raw_ids_dict:
153-
raw_ids_list = raw_ids_dict[table_name]
153+
if table_name in indexed_lookups_dict:
154+
raw_ids_list, runtime_meta_list = indexed_lookups_dict[table_name]
154155
for raw_ids in raw_ids_list:
155156
all_raw_ids.append(raw_ids)
157+
for runtime_meta in runtime_meta_list:
158+
all_runtime_meta.append(runtime_meta)
156159

157160
if not all_raw_ids:
158161
return None
@@ -162,7 +165,12 @@ def _get_hash_zch_identities(
162165
f"hash_zch_identities row count ({hash_zch_identities.size(0)}) must match "
163166
f"features.values() length ({features.values().numel()}) to maintain 1-to-1 alignment"
164167
)
165-
return hash_zch_identities
168+
169+
if all_runtime_meta:
170+
hash_zch_runtime_meta = torch.cat(all_runtime_meta)
171+
return (hash_zch_identities, hash_zch_runtime_meta)
172+
else:
173+
return (hash_zch_identities, None)
166174

167175

168176
def create_virtual_table_local_metadata(

torchrec/distributed/mc_modules.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,7 @@ def __init__(
245245
torch.Tensor,
246246
Optional[nn.Module],
247247
Optional[torch.Tensor],
248+
Optional[torch.Tensor],
248249
],
249250
None,
250251
]
@@ -763,13 +764,22 @@ def compute(
763764
"_hash_zch_identities",
764765
):
765766
if self.post_lookup_tracker_fn is not None:
767+
runtime_meta = None
768+
if (
769+
hasattr(mcm, "_hash_zch_runtime_meta")
770+
and mcm._hash_zch_runtime_meta is not None
771+
):
772+
runtime_meta = mcm._hash_zch_runtime_meta.index_select(
773+
dim=0, index=mc_input[table].values()
774+
)
766775
self.post_lookup_tracker_fn(
767776
KeyedJaggedTensor.from_jt_dict(mc_input),
768777
torch.empty(0),
769778
None,
770779
mcm._hash_zch_identities.index_select(
771780
dim=0, index=mc_input[table].values()
772781
),
782+
runtime_meta,
773783
)
774784
values = torch.cat([jt.values() for jt in output.values()])
775785
else:
@@ -791,11 +801,20 @@ def compute(
791801
values = mc_input[table].values()
792802
if hasattr(mcm, "_hash_zch_identities"):
793803
if self.post_lookup_tracker_fn is not None:
804+
runtime_meta = None
805+
if (
806+
hasattr(mcm, "_hash_zch_runtime_meta")
807+
and mcm._hash_zch_runtime_meta is not None
808+
):
809+
runtime_meta = mcm._hash_zch_runtime_meta.index_select(
810+
dim=0, index=values
811+
)
794812
self.post_lookup_tracker_fn(
795813
KeyedJaggedTensor.from_jt_dict(mc_input),
796814
torch.empty(0),
797815
None,
798816
mcm._hash_zch_identities.index_select(dim=0, index=values),
817+
runtime_meta,
799818
)
800819

801820
remapped_kjts.append(
@@ -895,6 +914,7 @@ def register_post_lookup_tracker_fn(
895914
torch.Tensor,
896915
Optional[nn.Module],
897916
Optional[torch.Tensor],
917+
Optional[torch.Tensor],
898918
],
899919
None,
900920
],

0 commit comments

Comments
 (0)