@@ -185,9 +185,11 @@ def record_lookup(
185185 states : torch .Tensor ,
186186 emb_module : Optional [nn .Module ] = None ,
187187 raw_ids : Optional [torch .Tensor ] = None ,
188+ runtime_meta : Optional [torch .Tensor ] = None ,
188189 ) -> None :
189190 per_table_ids : Dict [str , List [torch .Tensor ]] = {}
190191 per_table_raw_ids : Dict [str , List [torch .Tensor ]] = {}
192+ per_table_runtime_meta : Dict [str , List [torch .Tensor ]] = {}
191193
192194 # Skip storing invalid input or raw ids
193195 if (
@@ -197,28 +199,50 @@ def record_lookup(
197199 ):
198200 return
199201
200- embeddings_2d = raw_ids .view (kjt .values ().numel (), - 1 )
202+ # Skip storing if runtime_meta is provided but has invalid shape
203+ if runtime_meta is not None and not (
204+ runtime_meta .numel () % kjt .values ().numel () == 0
205+ ):
206+ return
207+
208+ raw_ids_2d = raw_ids .view (kjt .values ().numel (), - 1 )
209+ runtime_meta_2d = None
210+ # It is possible that runtime_meta is None while raw_ids is not None so we will proceed
211+ if runtime_meta is not None :
212+ runtime_meta_2d = runtime_meta .view (kjt .values ().numel (), - 1 )
201213
202214 offset : int = 0
203215 for key in kjt .keys ():
204216 table_fqn = self .table_to_fqn [key ]
205217 ids_list : List [torch .Tensor ] = per_table_ids .get (table_fqn , [])
206- emb_list : List [torch .Tensor ] = per_table_raw_ids .get (table_fqn , [])
218+ raw_ids_list : List [torch .Tensor ] = per_table_raw_ids .get (table_fqn , [])
219+ runtime_meta_list : List [torch .Tensor ] = per_table_runtime_meta .get (
220+ table_fqn , []
221+ )
207222
208223 ids = kjt [key ].values ()
209224 ids_list .append (ids )
210- emb_list .append (embeddings_2d [offset : offset + ids .numel ()])
225+ raw_ids_list .append (raw_ids_2d [offset : offset + ids .numel ()])
226+ if runtime_meta_2d is not None :
227+ runtime_meta_list .append (runtime_meta_2d [offset : offset + ids .numel ()])
211228 offset += ids .numel ()
212229
213230 per_table_ids [table_fqn ] = ids_list
214- per_table_raw_ids [table_fqn ] = emb_list
231+ per_table_raw_ids [table_fqn ] = raw_ids_list
232+ if runtime_meta_2d is not None :
233+ per_table_runtime_meta [table_fqn ] = runtime_meta_list
215234
216235 for table_fqn , ids_list in per_table_ids .items ():
217236 self .store .append (
218237 batch_idx = self .curr_batch_idx ,
219238 fqn = table_fqn ,
220239 ids = torch .cat (ids_list ),
221240 raw_ids = torch .cat (per_table_raw_ids [table_fqn ]),
241+ runtime_meta = (
242+ torch .cat (per_table_runtime_meta [table_fqn ])
243+ if table_fqn in per_table_runtime_meta
244+ else None
245+ ),
222246 )
223247
224248 def _clean_fqn_fn (self , fqn : str ) -> str :
@@ -277,8 +301,8 @@ def get_indexed_lookups(
277301 self ,
278302 tables : List [str ],
279303 consumer : Optional [str ] = None ,
280- ) -> Dict [str , List [torch .Tensor ]]:
281- raw_id_per_table : Dict [str , List [torch .Tensor ]] = {}
304+ ) -> Dict [str , Tuple [ List [torch .Tensor ], List [ torch . Tensor ] ]]:
305+ result : Dict [str , Tuple [ List [torch .Tensor ], List [ torch . Tensor ] ]] = {}
282306 consumer = consumer or self .DEFAULT_CONSUMER
283307 assert (
284308 consumer in self .per_consumer_batch_idx
@@ -293,17 +317,23 @@ def get_indexed_lookups(
293317
294318 for table in tables :
295319 raw_ids_list = []
320+ runtime_meta_list = []
296321 fqn = self .table_to_fqn [table ]
297322 if fqn in indexed_lookups :
298323 for indexed_lookup in indexed_lookups [fqn ]:
299324 if indexed_lookup .raw_ids is not None :
300325 raw_ids_list .append (indexed_lookup .raw_ids )
301- raw_id_per_table [table ] = raw_ids_list
326+ if indexed_lookup .runtime_meta is not None :
327+ runtime_meta_list .append (indexed_lookup .runtime_meta )
328+ if (
329+ raw_ids_list
330+ ): # if raw_ids doesn't exist runtime_meta will not exist so no need to check for runtime_meta
331+ result [table ] = (raw_ids_list , runtime_meta_list )
302332
303333 if self ._delete_on_read :
304334 self .store .delete (up_to_idx = min (self .per_consumer_batch_idx .values ()))
305335
306- return raw_id_per_table
336+ return result
307337
308338 def delete (self , up_to_idx : Optional [int ]) -> None :
309339 self .store .delete (up_to_idx )
0 commit comments