From bb7fded2d89ddb02e79e0d6bf4f04084fbceddf4 Mon Sep 17 00:00:00 2001 From: Mengjiao Zhou Date: Tue, 11 Nov 2025 10:47:55 -0800 Subject: [PATCH] sync opt_in metric fix Differential Revision: D86741254 --- torchrec/modules/hash_mc_metrics.py | 50 ++++++++++++++++++++++++++--- torchrec/modules/hash_mc_modules.py | 7 ++-- 2 files changed, 51 insertions(+), 6 deletions(-) diff --git a/torchrec/modules/hash_mc_metrics.py b/torchrec/modules/hash_mc_metrics.py index b1fc73655..54ce55fbf 100644 --- a/torchrec/modules/hash_mc_metrics.py +++ b/torchrec/modules/hash_mc_metrics.py @@ -42,6 +42,9 @@ def __init__( zch_size: int, frequency: int, start_bucket: int, + num_buckets_per_rank: int, + num_reserved_slots_per_bucket: int, + device: torch.device, disable_fallback: bool, log_file_path: str = "", ) -> None: @@ -57,8 +60,32 @@ def __init__( self._zch_size: int = zch_size self._frequency: int = frequency self._start_bucket: int = start_bucket + self._num_buckets_per_rank: int = num_buckets_per_rank + self._num_reserved_slots_per_bucket: int = num_reserved_slots_per_bucket + self._device: torch.device = device self._disable_fallback: bool = disable_fallback + assert ( + self._zch_size % self._num_buckets_per_rank == 0 + ), f"{self._zch_size} must be divisible by {self._num_buckets_per_rank}" + indice_per_bucket = torch.tensor( + [ + (self._zch_size // self._num_buckets_per_rank) * bucket + for bucket in range(1, self._num_buckets_per_rank + 1) + ], + dtype=torch.int64, + device=self._device, + ) + + self._opt_in_ranges: torch.Tensor = torch.sub( + indice_per_bucket, + ( + self._num_reserved_slots_per_bucket + if self._num_reserved_slots_per_bucket > 0 + else 0 + ), + ) + self._dtype_checked: bool = False self._total_cnt: int = 0 self._hit_cnt: int = 0 @@ -77,6 +104,13 @@ def __init__( ) # initialize file handler self.logger.addHandler(file_handler) # add file handler to logger + self.logger.info( + f"ScalarLogger: {self._name=}, {self._device=}, " + f"{self._zch_size=}, {self._frequency=}, {self._start_bucket=}, " + f"{self._num_buckets_per_rank=}, {self._num_reserved_slots_per_bucket=}, " + f"{self._opt_in_ranges=}, {self._disable_fallback=}" + ) + def should_report(self) -> bool: # We only need to report metrics from rank0 (start_bucket = 0) @@ -95,9 +129,9 @@ def update( identities_1: torch.Tensor, values: torch.Tensor, remapped_ids: torch.Tensor, + hit_indices: torch.Tensor, evicted_emb_indices: Optional[torch.Tensor], metadata: Optional[torch.Tensor], - num_reserved_slots: int, eviction_config: Optional[HashZchEvictionConfig] = None, ) -> None: if not self._dtype_checked: @@ -125,9 +159,17 @@ def update( self._hit_cnt += hit_cnt self._collision_cnt += values.numel() - hit_cnt - insert_cnt - opt_in_range = self._zch_size - num_reserved_slots - opt_in_ids = torch.lt(remapped_ids, opt_in_range) - self._opt_in_cnt += int(torch.sum(opt_in_ids).item()) + if self._disable_fallback: + hit_values = values[hit_indices] + train_buckets = hit_values % self._num_buckets_per_rank + else: + train_buckets = values % self._num_buckets_per_rank + + opt_in_ranges = self._opt_in_ranges.index_select(dim=0, index=train_buckets) + opt_in_ids = torch.lt(remapped_ids, opt_in_ranges) + opt_in_ids_cnt = int(torch.sum(opt_in_ids).item()) + # opt_in_cnt: # of ids assigned to opt-in block + self._opt_in_cnt += opt_in_ids_cnt if evicted_emb_indices is not None and evicted_emb_indices.numel() > 0: deduped_evicted_indices = torch.unique(evicted_emb_indices) diff --git a/torchrec/modules/hash_mc_modules.py b/torchrec/modules/hash_mc_modules.py index 9b6df6bb1..c617287e0 100644 --- a/torchrec/modules/hash_mc_modules.py +++ b/torchrec/modules/hash_mc_modules.py @@ -286,6 +286,9 @@ def __init__( zch_size=self._zch_size, frequency=self._tb_logging_frequency, start_bucket=self._start_bucket, + num_buckets_per_rank=self._end_bucket - self._start_bucket, + num_reserved_slots_per_bucket=self.get_reserved_slots_per_bucket(), + device=self._device, disable_fallback=self._disable_fallback, ) else: @@ -542,9 +545,9 @@ def remap(self, features: Dict[str, JaggedTensor]) -> Dict[str, JaggedTensor]: # record the on-device remapped ids self.table_name_on_device_remapped_ids_dict[name] = remapped_ids.clone() lengths: torch.Tensor = feature.lengths() + hit_indices = remapped_ids != -1 if self._disable_fallback: # Only works on GPU when read only is true. - hit_indices = remapped_ids != -1 remapped_ids = remapped_ids[hit_indices] lengths = torch.masked_fill(lengths, ~hit_indices, 0) if self._scalar_logger is not None: @@ -554,9 +557,9 @@ def remap(self, features: Dict[str, JaggedTensor]) -> Dict[str, JaggedTensor]: identities_1=self._hash_zch_identities, values=values, remapped_ids=remapped_ids, + hit_indices=hit_indices, evicted_emb_indices=evictions, metadata=metadata, - num_reserved_slots=num_reserved_slots, eviction_config=self._eviction_config, )