torchrec/modules/hash_mc_metrics.py (46 additions & 4 deletions)
@@ -42,6 +42,9 @@ def __init__(
zch_size: int,
frequency: int,
start_bucket: int,
+        num_buckets_per_rank: int,
+        num_reserved_slots_per_bucket: int,
+        device: torch.device,
disable_fallback: bool,
log_file_path: str = "",
) -> None:
@@ -57,8 +60,32 @@ def __init__(
self._zch_size: int = zch_size
self._frequency: int = frequency
self._start_bucket: int = start_bucket
+        self._num_buckets_per_rank: int = num_buckets_per_rank
+        self._num_reserved_slots_per_bucket: int = num_reserved_slots_per_bucket
+        self._device: torch.device = device
self._disable_fallback: bool = disable_fallback

+        assert (
+            self._zch_size % self._num_buckets_per_rank == 0
+        ), f"{self._zch_size} must be divisible by {self._num_buckets_per_rank}"
+        indice_per_bucket = torch.tensor(
+            [
+                (self._zch_size // self._num_buckets_per_rank) * bucket
+                for bucket in range(1, self._num_buckets_per_rank + 1)
+            ],
+            dtype=torch.int64,
+            device=self._device,
+        )
+
+        self._opt_in_ranges: torch.Tensor = torch.sub(
+            indice_per_bucket,
+            (
+                self._num_reserved_slots_per_bucket
+                if self._num_reserved_slots_per_bucket > 0
+                else 0
+            ),
+        )
+
self._dtype_checked: bool = False
self._total_cnt: int = 0
self._hit_cnt: int = 0
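The added constructor block precomputes one opt-in threshold per bucket: each bucket's exclusive end index, minus the reserved tail whenever num_reserved_slots_per_bucket is positive. A minimal standalone sketch of that arithmetic with made-up sizes (zch_size=12, three buckets, two reserved slots per bucket; all names are local to this example):

```python
import torch

# Assumed toy values; the real ones arrive via the ScalarLogger constructor.
zch_size, num_buckets_per_rank, num_reserved_slots_per_bucket = 12, 3, 2

# Exclusive end index of each bucket: tensor([4, 8, 12])
indice_per_bucket = torch.tensor(
    [(zch_size // num_buckets_per_rank) * b for b in range(1, num_buckets_per_rank + 1)],
    dtype=torch.int64,
)

# Carve the reserved tail out of each bucket: tensor([2, 6, 10])
opt_in_ranges = torch.sub(
    indice_per_bucket,
    num_reserved_slots_per_bucket if num_reserved_slots_per_bucket > 0 else 0,
)
```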
@@ -77,6 +104,13 @@ def __init__(
) # initialize file handler
self.logger.addHandler(file_handler) # add file handler to logger

+        self.logger.info(
+            f"ScalarLogger: {self._name=}, {self._device=}, "
+            f"{self._zch_size=}, {self._frequency=}, {self._start_bucket=}, "
+            f"{self._num_buckets_per_rank=}, {self._num_reserved_slots_per_bucket=}, "
+            f"{self._opt_in_ranges=}, {self._disable_fallback=}"
+        )
+
def should_report(self) -> bool:
# We only need to report metrics from rank0 (start_bucket = 0)

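The body of should_report is collapsed in this diff, so the following is only a hypothetical illustration of the gate the comment describes, not the actual implementation:

```python
def should_report(self) -> bool:
    # Hypothetical sketch: report only from the rank that owns bucket 0.
    # The real body is elided in the diff above.
    return self._start_bucket == 0
```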
@@ -95,9 +129,9 @@ def update(
identities_1: torch.Tensor,
values: torch.Tensor,
remapped_ids: torch.Tensor,
+        hit_indices: torch.Tensor,
evicted_emb_indices: Optional[torch.Tensor],
metadata: Optional[torch.Tensor],
-        num_reserved_slots: int,
eviction_config: Optional[HashZchEvictionConfig] = None,
) -> None:
if not self._dtype_checked:
@@ -125,9 +159,17 @@ def update(
self._hit_cnt += hit_cnt
self._collision_cnt += values.numel() - hit_cnt - insert_cnt

-        opt_in_range = self._zch_size - num_reserved_slots
-        opt_in_ids = torch.lt(remapped_ids, opt_in_range)
-        self._opt_in_cnt += int(torch.sum(opt_in_ids).item())
+        if self._disable_fallback:
+            hit_values = values[hit_indices]
+            train_buckets = hit_values % self._num_buckets_per_rank
+        else:
+            train_buckets = values % self._num_buckets_per_rank
+
+        opt_in_ranges = self._opt_in_ranges.index_select(dim=0, index=train_buckets)
+        opt_in_ids = torch.lt(remapped_ids, opt_in_ranges)
+        opt_in_ids_cnt = int(torch.sum(opt_in_ids).item())
+        # opt_in_cnt: # of ids assigned to opt-in block
+        self._opt_in_cnt += opt_in_ids_cnt

if evicted_emb_indices is not None and evicted_emb_indices.numel() > 0:
deduped_evicted_indices = torch.unique(evicted_emb_indices)
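The deleted lines counted opt-in ids against a single global threshold, zch_size - num_reserved_slots; the new lines look up a per-bucket threshold instead, deriving each id's bucket from its raw value (and, when fallback is disabled, only from the values at hit_indices). A self-contained sketch of the counting with assumed toy tensors, all names local to the example:

```python
import torch

num_buckets_per_rank = 3
opt_in_ranges = torch.tensor([2, 6, 10])  # per-bucket thresholds, as in the constructor
values = torch.tensor([7, 4, 11])         # raw ids; bucket = value % num_buckets_per_rank
remapped_ids = torch.tensor([1, 9, 9])    # slots assigned by the ZCH table (made up)

train_buckets = values % num_buckets_per_rank                        # tensor([1, 1, 2])
thresholds = opt_in_ranges.index_select(dim=0, index=train_buckets)  # tensor([6, 6, 10])
opt_in_cnt = int(torch.sum(torch.lt(remapped_ids, thresholds)).item())  # 2
```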
torchrec/modules/hash_mc_modules.py (5 additions & 2 deletions)
@@ -286,6 +286,9 @@ def __init__(
zch_size=self._zch_size,
frequency=self._tb_logging_frequency,
start_bucket=self._start_bucket,
+                num_buckets_per_rank=self._end_bucket - self._start_bucket,
+                num_reserved_slots_per_bucket=self.get_reserved_slots_per_bucket(),
+                device=self._device,
disable_fallback=self._disable_fallback,
)
else:
@@ -542,9 +545,9 @@ def remap(self, features: Dict[str, JaggedTensor]) -> Dict[str, JaggedTensor]:
# record the on-device remapped ids
self.table_name_on_device_remapped_ids_dict[name] = remapped_ids.clone()
lengths: torch.Tensor = feature.lengths()
+            hit_indices = remapped_ids != -1
if self._disable_fallback:
# Only works on GPU when read only is true.
-                hit_indices = remapped_ids != -1
remapped_ids = remapped_ids[hit_indices]
lengths = torch.masked_fill(lengths, ~hit_indices, 0)
if self._scalar_logger is not None:
@@ -554,9 +557,9 @@ def remap(self, features: Dict[str, JaggedTensor]) -> Dict[str, JaggedTensor]:
identities_1=self._hash_zch_identities,
values=values,
remapped_ids=remapped_ids,
+                    hit_indices=hit_indices,
evicted_emb_indices=evictions,
metadata=metadata,
-                    num_reserved_slots=num_reserved_slots,
eviction_config=self._eviction_config,
)

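With fallback disabled, remap now computes hit_indices before the branch (so the logger always receives it), drops ids the table could not resolve (returned as -1), and zeroes their lengths. A hedged sketch of just that masking step, with invented inputs:

```python
import torch

remapped_ids = torch.tensor([5, -1, 3, -1])  # -1 marks unresolved lookups (made up)
lengths = torch.tensor([1, 1, 1, 1])

hit_indices = remapped_ids != -1                       # tensor([True, False, True, False])
remapped_ids = remapped_ids[hit_indices]               # tensor([5, 3])
lengths = torch.masked_fill(lengths, ~hit_indices, 0)  # tensor([1, 0, 1, 0])
```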