Skip to content

Commit 5854b0e

Browse files
Joey Yangfacebook-github-bot
authored andcommitted
Extend raw_id_tracker to support hash_zch_runtime_meta (#3598)
Summary: See https://fb.workplace.com/groups/1404957374198553/permalink/1610214197006202/ This diff extends `raw_id_tracker` to store `hash_zch_runtime_meta` which will be alongside with `hash_zch_identities` when presented. Note that it is possible that a mpzch table only has `hash_zch_identities` without `hash_zch_runtime_meta` but is not true vice versa. Differential Revision: D88600497
1 parent 3a1d5f3 commit 5854b0e

File tree

3 files changed

+45
-9
lines changed

3 files changed

+45
-9
lines changed

torchrec/distributed/model_tracker/delta_store.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ def append(
9393
ids: torch.Tensor,
9494
states: Optional[torch.Tensor] = None,
9595
raw_ids: Optional[torch.Tensor] = None,
96+
runtime_meta: Optional[torch.Tensor] = None,
9697
) -> None:
9798
"""
9899
Append a batch of ids and states to the store for a specific table.
@@ -165,6 +166,7 @@ def append(
165166
ids: torch.Tensor,
166167
states: Optional[torch.Tensor] = None,
167168
raw_ids: Optional[torch.Tensor] = None,
169+
runtime_meta: Optional[torch.Tensor] = None,
168170
) -> None:
169171
table_fqn_lookup = self.per_fqn_lookups.get(fqn, [])
170172
table_fqn_lookup.append(
@@ -284,10 +286,13 @@ def append(
284286
ids: torch.Tensor,
285287
states: Optional[torch.Tensor] = None,
286288
raw_ids: Optional[torch.Tensor] = None,
289+
runtime_meta: Optional[torch.Tensor] = None,
287290
) -> None:
288291
table_fqn_lookup = self.per_fqn_lookups.get(fqn, [])
289292
table_fqn_lookup.append(
290-
RawIndexedLookup(batch_idx=batch_idx, ids=ids, raw_ids=raw_ids)
293+
RawIndexedLookup(
294+
batch_idx=batch_idx, ids=ids, raw_ids=raw_ids, runtime_meta=runtime_meta
295+
)
291296
)
292297
self.per_fqn_lookups[fqn] = table_fqn_lookup
293298

torchrec/distributed/model_tracker/trackers/raw_id_tracker.py

Lines changed: 38 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -185,9 +185,11 @@ def record_lookup(
185185
states: torch.Tensor,
186186
emb_module: Optional[nn.Module] = None,
187187
raw_ids: Optional[torch.Tensor] = None,
188+
runtime_meta: Optional[torch.Tensor] = None,
188189
) -> None:
189190
per_table_ids: Dict[str, List[torch.Tensor]] = {}
190191
per_table_raw_ids: Dict[str, List[torch.Tensor]] = {}
192+
per_table_runtime_meta: Dict[str, List[torch.Tensor]] = {}
191193

192194
# Skip storing invalid input or raw ids
193195
if (
@@ -197,28 +199,50 @@ def record_lookup(
197199
):
198200
return
199201

200-
embeddings_2d = raw_ids.view(kjt.values().numel(), -1)
202+
# Skip storing if runtime_meta is provided but has invalid shape
203+
if runtime_meta is not None and not (
204+
runtime_meta.numel() % kjt.values().numel() == 0
205+
):
206+
return
207+
208+
raw_ids_2d = raw_ids.view(kjt.values().numel(), -1)
209+
runtime_meta_2d = None
210+
# It is possible that runtime_meta is None while raw_ids is not None so we will proceed
211+
if runtime_meta is not None:
212+
runtime_meta_2d = runtime_meta.view(kjt.values().numel(), -1)
201213

202214
offset: int = 0
203215
for key in kjt.keys():
204216
table_fqn = self.table_to_fqn[key]
205217
ids_list: List[torch.Tensor] = per_table_ids.get(table_fqn, [])
206-
emb_list: List[torch.Tensor] = per_table_raw_ids.get(table_fqn, [])
218+
raw_ids_list: List[torch.Tensor] = per_table_raw_ids.get(table_fqn, [])
219+
runtime_meta_list: List[torch.Tensor] = per_table_runtime_meta.get(
220+
table_fqn, []
221+
)
207222

208223
ids = kjt[key].values()
209224
ids_list.append(ids)
210-
emb_list.append(embeddings_2d[offset : offset + ids.numel()])
225+
raw_ids_list.append(raw_ids_2d[offset : offset + ids.numel()])
226+
if runtime_meta_2d is not None:
227+
runtime_meta_list.append(runtime_meta_2d[offset : offset + ids.numel()])
211228
offset += ids.numel()
212229

213230
per_table_ids[table_fqn] = ids_list
214-
per_table_raw_ids[table_fqn] = emb_list
231+
per_table_raw_ids[table_fqn] = raw_ids_list
232+
if runtime_meta_2d is not None:
233+
per_table_runtime_meta[table_fqn] = runtime_meta_list
215234

216235
for table_fqn, ids_list in per_table_ids.items():
217236
self.store.append(
218237
batch_idx=self.curr_batch_idx,
219238
fqn=table_fqn,
220239
ids=torch.cat(ids_list),
221240
raw_ids=torch.cat(per_table_raw_ids[table_fqn]),
241+
runtime_meta=(
242+
torch.cat(per_table_runtime_meta[table_fqn])
243+
if table_fqn in per_table_runtime_meta
244+
else None
245+
),
222246
)
223247

224248
def _clean_fqn_fn(self, fqn: str) -> str:
@@ -277,8 +301,8 @@ def get_indexed_lookups(
277301
self,
278302
tables: List[str],
279303
consumer: Optional[str] = None,
280-
) -> Dict[str, List[torch.Tensor]]:
281-
raw_id_per_table: Dict[str, List[torch.Tensor]] = {}
304+
) -> Dict[str, Tuple[List[torch.Tensor], List[torch.Tensor]]]:
305+
result: Dict[str, Tuple[List[torch.Tensor], List[torch.Tensor]]] = {}
282306
consumer = consumer or self.DEFAULT_CONSUMER
283307
assert (
284308
consumer in self.per_consumer_batch_idx
@@ -293,17 +317,23 @@ def get_indexed_lookups(
293317

294318
for table in tables:
295319
raw_ids_list = []
320+
runtime_meta_list = []
296321
fqn = self.table_to_fqn[table]
297322
if fqn in indexed_lookups:
298323
for indexed_lookup in indexed_lookups[fqn]:
299324
if indexed_lookup.raw_ids is not None:
300325
raw_ids_list.append(indexed_lookup.raw_ids)
301-
raw_id_per_table[table] = raw_ids_list
326+
if indexed_lookup.runtime_meta is not None:
327+
runtime_meta_list.append(indexed_lookup.runtime_meta)
328+
if (
329+
raw_ids_list
330+
): # if raw_ids doesn't exist runtime_meta will not exist so no need to check for runtime_meta
331+
result[table] = (raw_ids_list, runtime_meta_list)
302332

303333
if self._delete_on_read:
304334
self.store.delete(up_to_idx=min(self.per_consumer_batch_idx.values()))
305335

306-
return raw_id_per_table
336+
return result
307337

308338
def delete(self, up_to_idx: Optional[int]) -> None:
309339
self.store.delete(up_to_idx)

torchrec/distributed/model_tracker/types.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ class RawIndexedLookup:
3535
batch_idx: int
3636
ids: torch.Tensor
3737
raw_ids: Optional[torch.Tensor] = None
38+
runtime_meta: Optional[torch.Tensor] = None
3839

3940

4041
@dataclass

0 commit comments

Comments
 (0)