diff --git a/torchrec/distributed/keyed_jagged_tensor_pool.py b/torchrec/distributed/keyed_jagged_tensor_pool.py
index 3c605c943..8afe34b04 100644
--- a/torchrec/distributed/keyed_jagged_tensor_pool.py
+++ b/torchrec/distributed/keyed_jagged_tensor_pool.py
@@ -608,7 +608,7 @@ def create_context(self) -> ObjectPoolShardingContext:
     def _lookup_ids_dist(
         self,
         ids: torch.Tensor,
-    ) -> Tuple[List[torch.Tensor], torch.Tensor]:
+    ) -> Tuple[List[torch.Tensor], torch.Tensor, torch.Tensor, torch.Tensor]:
         return self._lookup_ids_dist_impl(ids)
 
     # pyre-ignore
@@ -630,7 +630,7 @@ def _lookup_values_dist(
 
     # pyre-ignore
     def forward(self, ids: torch.Tensor) -> KeyedJaggedTensor:
-        dist_input, unbucketize_permute = self._lookup_ids_dist(ids)
+        dist_input, unbucketize_permute, _, _ = self._lookup_ids_dist(ids)
         lookup = self._lookup_local(dist_input)
         # Here we are playing a trick to workaround a fx tracing issue,
         # as proxy is not iteratable.
diff --git a/torchrec/distributed/quant_embedding.py b/torchrec/distributed/quant_embedding.py
index 8d58931d8..2e86053bb 100644
--- a/torchrec/distributed/quant_embedding.py
+++ b/torchrec/distributed/quant_embedding.py
@@ -84,6 +84,7 @@ from torchrec.modules.utils import (
     _fx_trec_get_feature_length,
     _get_batching_hinted_output,
+    _get_unbucketize_tensor_via_length_alignment,
 )
 from torchrec.quant.embedding_modules import (
     EmbeddingCollection as QuantEmbeddingCollection,
@@ -96,6 +97,7 @@ torch.fx.wrap("len")
 torch.fx.wrap("_get_batching_hinted_output")
 torch.fx.wrap("_fx_trec_get_feature_length")
+torch.fx.wrap("_get_unbucketize_tensor_via_length_alignment")
 
 try:
     torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")
@@ -278,16 +280,6 @@ def _fx_trec_wrap_length_tolist(length: torch.Tensor) -> List[int]:
     return length.long().tolist()
 
 
-@torch.fx.wrap
-def _get_unbucketize_tensor_via_length_alignment(
-    lengths: torch.Tensor,
-    bucketize_length: torch.Tensor,
-    bucketize_permute_tensor: torch.Tensor,
-    bucket_mapping_tensor: torch.Tensor,
-) -> torch.Tensor:
-    return bucketize_permute_tensor
-
-
 @torch.fx.wrap
 def _fx_split_embeddings_per_feature_length(
     embeddings: torch.Tensor,
diff --git a/torchrec/distributed/sharding/rw_pool_sharding.py b/torchrec/distributed/sharding/rw_pool_sharding.py
index 2e0823dd1..e5bff5c73 100644
--- a/torchrec/distributed/sharding/rw_pool_sharding.py
+++ b/torchrec/distributed/sharding/rw_pool_sharding.py
@@ -166,6 +166,9 @@ class InferRwObjectPoolInputDist(torch.nn.Module):
         block_size (torch.Tensor): tensor containing block sizes for each rank.
             e.g. if block_size=torch.tensor(100), then IDs 0-99 will be assigned
             to rank 0, 100-199 to rank 1, and so on.
+        block_bucketize_row_pos (Optional[List[torch.Tensor]]): tensor containing shard/row offsets for each
+            rank in case of uneven sharding of the tensor pool across ranks. If not provided,
+            then block_size will be used to permute the IDs across ranks.
 
     Example:
         device = torch.device("cpu")
@@ -179,22 +182,27 @@ class InferRwObjectPoolInputDist(torch.nn.Module):
     _world_size: int
     _device: torch.device
     _block_size: torch.Tensor
+    _block_bucketize_row_pos: list[torch.Tensor]
 
     def __init__(
         self,
         env: ShardingEnv,
         device: torch.device,
         block_size: torch.Tensor,
+        block_bucketize_row_pos: Optional[list[torch.Tensor]] = None,
     ) -> None:
         super().__init__()
         self._world_size = env.world_size
         self._device = device
         self._block_size = block_size
+        self._block_bucketize_row_pos: list[torch.Tensor] = (
+            [] if block_bucketize_row_pos is None else block_bucketize_row_pos
+        )
 
     def forward(
         self,
         ids: torch.Tensor,
-    ) -> Tuple[List[torch.Tensor], torch.Tensor]:
+    ) -> Tuple[List[torch.Tensor], torch.Tensor, torch.Tensor, torch.Tensor]:
         """
         Bucketizes ids tensor into a list of tensors each containing ids for
         the corresponding rank. Places each tensor on the appropriate device.
@@ -203,24 +211,34 @@ def forward(
             ids (torch.Tensor): Tensor with ids
 
         Returns:
-            Tuple[List[torch.Tensor], torch.Tensor]: Tuple containing list of ids tensors
-                for each rank given the bucket sizes, and the tensor containing indices
-                to permute the ids to get the original order before bucketization.
+            Tuple[List[torch.Tensor], torch.Tensor, torch.Tensor, torch.Tensor]:
+                Tuple containing
+                1. list of ids tensors for each rank given the bucket sizes
+                2. the tensor containing indices to permute the ids to get the original order before bucketization.
+                3. the tensor containing the bucket mapping for each id
+                4. the tensor containing the bucketized lengths
         """
         (
             bucketized_lengths,
            bucketized_indices,
-            _bucketized_weights,
-            _bucketize_permute,
+            _,  # bucketized_weights
+            _,  # _bucketize_permute
            unbucketize_permute,
-        ) = torch.ops.fbgemm.block_bucketize_sparse_features(
-            _get_bucketize_shape(ids, ids.device),
-            ids.long(),
+            bucket_mapping,
+        ) = torch.ops.fbgemm.block_bucketize_sparse_features_inference(
+            lengths=_get_bucketize_shape(ids, ids.device),
+            indices=ids.long(),
            bucketize_pos=False,
            sequence=True,
            block_sizes=self._block_size.long(),
            my_size=self._world_size,
            weights=None,
+            block_bucketize_pos=(
+                self._block_bucketize_row_pos
+                if len(self._block_bucketize_row_pos) > 0
+                else None
+            ),
+            return_bucket_mapping=True,
         )
 
         id_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(bucketized_lengths)
@@ -236,7 +254,13 @@ def forward(
         )
 
         assert unbucketize_permute is not None, "unbucketize permute must not be None"
-        return dist_ids, unbucketize_permute
+        assert bucket_mapping is not None, "bucket mapping must not be None"
+        return (
+            dist_ids,
+            unbucketize_permute,
+            bucket_mapping,
+            bucketized_lengths,
+        )
 
     def update(
         self,
@@ -270,6 +294,11 @@ def update(
            block_sizes=self._block_size.long(),
            my_size=self._world_size,
            weights=None,
+            block_bucketize_pos=(
+                self._block_bucketize_row_pos
+                if len(self._block_bucketize_row_pos) > 0
+                else None
+            ),
         )
 
         id_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(bucketized_lengths)
diff --git a/torchrec/distributed/sharding/rw_tensor_pool_sharding.py b/torchrec/distributed/sharding/rw_tensor_pool_sharding.py
index 63ea13fc4..942c36172 100644
--- a/torchrec/distributed/sharding/rw_tensor_pool_sharding.py
+++ b/torchrec/distributed/sharding/rw_tensor_pool_sharding.py
@@ -213,6 +213,8 @@ class InferRwTensorPoolOutputDist(torch.nn.Module):
         vals = torch.Tensor([1,2,3,4,5,6], device=device)
     """
 
+    __annotations__ = {"_device": Optional[torch.device]}
+
     def __init__(
         self,
         env: ShardingEnv,
@@ -224,6 +226,11 @@ def __init__(
         self._cat_dim = 0
         self._placeholder: torch.Tensor = torch.ones(1, device=device)
 
+    @torch.jit.export
+    def set_device(self, device_str: str) -> None:
+        self._device = torch.device(device_str)
+        self._placeholder = torch.ones(1, device=self._device)
+
     def forward(
         self,
         lookups: List[torch.Tensor],
@@ -256,12 +263,16 @@ def __init__(
         pool_size: int,
         env: ShardingEnv,
         device: torch.device,
+        memory_capacity_per_rank: Optional[list[int]] = None,
     ) -> None:
-        super().__init__(pool_size, env, device)
+        super().__init__(pool_size, env, device, memory_capacity_per_rank)
 
     def create_lookup_ids_dist(self) -> InferRwObjectPoolInputDist:
         return InferRwObjectPoolInputDist(
-            self._env, device=self._device, block_size=self._block_size_t
+            self._env,
+            device=self._device,
+            block_size=self._block_size_t,
+            block_bucketize_row_pos=self._block_bucketize_row_pos,
         )
 
     def create_lookup_values_dist(
diff --git a/torchrec/distributed/tensor_pool.py b/torchrec/distributed/tensor_pool.py
index fffda2a10..80d0abbd3 100644
--- a/torchrec/distributed/tensor_pool.py
+++ b/torchrec/distributed/tensor_pool.py
@@ -32,7 +32,14 @@ )
 from torchrec.modules.object_pool_lookups import TensorLookup, TensorPoolLookup
 from torchrec.modules.tensor_pool import TensorPool
-from torchrec.modules.utils import deterministic_dedup
+from torchrec.modules.utils import (
+    _get_batching_hinted_output,
+    _get_unbucketize_tensor_via_length_alignment,
+    deterministic_dedup,
+)
+
+torch.fx.wrap("_get_unbucketize_tensor_via_length_alignment")
+torch.fx.wrap("_get_batching_hinted_output")
 
 
 @torch.fx.wrap
@@ -44,6 +51,17 @@ def index_select_view(
     return output[unbucketize_permute].view(-1, dim)
 
 
+@torch.fx.wrap
+def _fx_item_unwrap_optional_tensor(optional: Optional[torch.Tensor]) -> torch.Tensor:
+    assert optional is not None, "Expected optional to be non-None Tensor"
+    return optional
+
+
+@torch.fx.wrap
+def _get_id_length_sharded_tensor_pool(ids: torch.Tensor) -> torch.Tensor:
+    return torch.tensor([ids.size(dim=0)], device=ids.device, dtype=torch.long)
+
+
 class TensorPoolAwaitable(LazyAwaitable[torch.Tensor]):
     def __init__(
         self,
@@ -271,6 +289,8 @@ class LocalShardPool(torch.nn.Module):
         # out is tensor([1,2,3]) i.e. first row of the shard
     """
 
+    current_device: torch.device
+
     def __init__(
         self,
         shard: torch.Tensor,
@@ -280,6 +300,12 @@ def __init__(
             shard,
             requires_grad=False,
         )
+        self.current_device = self._shard.device
+
+    @torch.jit.export
+    def set_device(self, device_str: str) -> None:
+        self.current_device = torch.device(device_str)
+        self._shard.to(self.current_device)
 
     def forward(self, rank_ids: torch.Tensor) -> torch.Tensor:
         """
         Lookup the values in the shard corresponding to the given rank ids.
 
         Args:
             rank_ids (torch.Tensor): Tensor of rank ids to look up.
 
         Returns:
             torch.Tensor: Tensor of values corresponding to the given rank ids.
""" - return self._shard[rank_ids] + return self._shard[rank_ids.to(self.current_device)] def update(self, rank_ids: torch.Tensor, values: torch.Tensor) -> None: _ = update(self._shard, rank_ids, values) @@ -337,6 +363,11 @@ def __init__( env=self._sharding_env, device=self._device, pool_size=self._pool_size, + memory_capacity_per_rank=( + self._sharding_plan.memory_capacity_per_rank + if self._sharding_plan.memory_capacity_per_rank is not None + else None + ), ) else: raise NotImplementedError( @@ -356,6 +387,7 @@ def __init__( if device == torch.device("cpu") else torch.device("cuda", rank) ) + self._local_shard_pools.append( LocalShardPool( torch.empty( @@ -409,7 +441,7 @@ def create_context(self) -> ObjectPoolShardingContext: def _lookup_ids_dist( self, ids: torch.Tensor, - ) -> Tuple[List[torch.Tensor], torch.Tensor]: + ) -> Tuple[List[torch.Tensor], torch.Tensor, torch.Tensor, torch.Tensor]: return self._lookup_ids_dist_impl(ids) # pyre-ignore @@ -439,18 +471,54 @@ def _lookup_values_dist( # pyre-ignore def forward(self, ids: torch.Tensor) -> torch.Tensor: - dist_input, unbucketize_permute = self._lookup_ids_dist(ids) + dist_input, unbucketize_permute, bucket_mapping, bucketized_lengths = ( + self._lookup_ids_dist(ids) + ) + unbucketize_permute_non_opt = _fx_item_unwrap_optional_tensor( + unbucketize_permute + ) + lookup = self._lookup_local(dist_input) # Here we are playing a trick to workaround a fx tracing issue, # as proxy is not iteratable. lookup_list = [] - for i in range(self._world_size): - lookup_list.append(lookup[i]) + # In case of non-heterogenous even sharding keeping the behavior + # consistent with existing logic to ensure that additional fx wrappers + # do not impact the model split logic during inference in anyway + if self._sharding_plan.memory_capacity_per_rank is None: + for i in range(self._world_size): + lookup_list.append(lookup[i]) + else: + # Adding fx wrappers in case of uneven heterogenous sharding to + # make it compatible with model split boundaries during inference + for i in range(self._world_size): + lookup_list.append( + _get_batching_hinted_output( + _get_id_length_sharded_tensor_pool(dist_input[i]), lookup[i] + ) + ) + + features_before_input_dist_length = _get_id_length_sharded_tensor_pool(ids) + bucketized_lengths_col_view = bucketized_lengths.view(self._world_size, -1) + unbucketize_permute_non_opt = _fx_item_unwrap_optional_tensor( + unbucketize_permute + ) + bucket_mapping_non_opt = _fx_item_unwrap_optional_tensor(bucket_mapping) + unbucketize_permute_non_opt = _get_unbucketize_tensor_via_length_alignment( + features_before_input_dist_length, + bucketized_lengths_col_view, + unbucketize_permute_non_opt, + bucket_mapping_non_opt, + ) output = self._lookup_values_dist(lookup_list) - return index_select_view(output, unbucketize_permute, self._dim) + return index_select_view( + output, + unbucketize_permute_non_opt.to(device=output.device), + self._dim, + ) # pyre-ignore def _update_values_dist(self, ctx: ObjectPoolShardingContext, values: torch.Tensor): diff --git a/torchrec/distributed/tensor_sharding.py b/torchrec/distributed/tensor_sharding.py index 6a7ca0715..8df207108 100644 --- a/torchrec/distributed/tensor_sharding.py +++ b/torchrec/distributed/tensor_sharding.py @@ -102,6 +102,7 @@ def __init__( pool_size: int, env: ShardingEnv, device: torch.device, + memory_capacity_per_rank: Optional[list[int]] = None, ) -> None: self._pool_size = pool_size self._env = env @@ -117,13 +118,40 @@ def __init__( self._last_block_size: int = self._pool_size 
             self._world_size - 1
         )
-        self.local_pool_size_per_rank: List[int] = [self._block_size] * (
-            self._world_size - 1
-        ) + [self._last_block_size]
-
+        # only used for uneven sharding case when memory_capacity_per_rank is provided
+        row_offset_per_rank = []
+
+        if memory_capacity_per_rank is None:
+            self.local_pool_size_per_rank: List[int] = [self._block_size] * (
+                self._world_size - 1
+            ) + [self._last_block_size]
+        else:
+            row_offset_per_rank = [0]
+            self.local_pool_size_per_rank: List[int] = []
+            row_offset = 0
+            assert (
+                len(memory_capacity_per_rank) == self._world_size
+            ), "If memory_capacity_per_rank is provided for sharded tensor pool, it must have the same length as world_size"
+            total_mem_cap = sum(memory_capacity_per_rank)
+            for cap in memory_capacity_per_rank[:-1]:
+                rows_per_shard = int(cap / total_mem_cap * self._pool_size)
+                self.local_pool_size_per_rank.append(rows_per_shard)
+                row_offset += rows_per_shard
+                row_offset_per_rank.append(row_offset)
+            self.local_pool_size_per_rank.append(
+                self._pool_size - sum(self.local_pool_size_per_rank)
+            )
+            row_offset_per_rank.append(self._pool_size)
 
         self._block_size_t: torch.Tensor = torch.tensor(
             [self._block_size], device=self._device, dtype=torch.long
         )
+        # for uneven sharding case, we get the row offsets for each rank to
+        # enable input_dist and lookup of ids to correct rank
+        self._block_bucketize_row_pos: Optional[List[torch.Tensor]] = (
+            None
+            if memory_capacity_per_rank is None
+            else [torch.tensor(row_offset_per_rank, device=self._device)]
+        )
 
     @abstractmethod
     def create_lookup_ids_dist(self) -> torch.nn.Module:
diff --git a/torchrec/distributed/tests/test_tensor_pool_rw_sharding.py b/torchrec/distributed/tests/test_tensor_pool_rw_sharding.py
index 0a6992c0d..dc56684c1 100644
--- a/torchrec/distributed/tests/test_tensor_pool_rw_sharding.py
+++ b/torchrec/distributed/tests/test_tensor_pool_rw_sharding.py
@@ -10,7 +10,10 @@
 import unittest
 
 import torch
-from torchrec.distributed.sharding.rw_tensor_pool_sharding import TensorPoolRwSharding
+from torchrec.distributed.sharding.rw_tensor_pool_sharding import (
+    InferRwTensorPoolSharding,
+    TensorPoolRwSharding,
+)
 from torchrec.distributed.tensor_sharding import TensorPoolRwShardingContext
 from torchrec.distributed.test_utils.multi_process import (
     MultiProcessContext,
@@ -200,3 +203,344 @@ def test_lookup(
     ) -> None:
         world_size = 2
         self._run_multi_process_test(callable=self._test_lookup, world_size=world_size)
+
+
+class TestInferRwTensorPoolSharding(unittest.TestCase):
+    def test_uneven_sharding_with_memory_capacity_per_rank(self) -> None:
+        # Setup: create a sharding configuration with uneven memory capacity per rank
+        # Rank 0 gets 60% capacity, Rank 1 gets 20%, Rank 2 gets 20%
+        pool_size = 1000
+        world_size = 3
+        memory_capacity_per_rank = [600, 200, 200]  # Uneven distribution
+        device = torch.device("cpu")
+
+        # Create a mock ShardingEnv
+        class MockShardingEnv:
+            def __init__(self, world_size: int, rank: int) -> None:
+                self.world_size = world_size
+                self.rank = rank
+                self.process_group = None
+
+        env = MockShardingEnv(world_size=world_size, rank=0)
+
+        # Execute: create InferRwTensorPoolSharding with memory_capacity_per_rank
+        sharding = InferRwTensorPoolSharding(
+            pool_size=pool_size,
+            # pyre-fixme: Incompatible parameter type [6]
+            env=env,
+            device=device,
+            memory_capacity_per_rank=memory_capacity_per_rank,
+        )
+
+        # Assert: verify the local pool size per rank is computed based on memory capacity
+        # Expected: rank 0 gets 600 rows, rank 1 gets 200 rows, rank 2 gets 200 rows
+        expected_local_pool_size_per_rank = [600, 200, 200]
+        self.assertEqual(
+            sharding.local_pool_size_per_rank, expected_local_pool_size_per_rank
+        )
+
+        # Assert: verify block_bucketize_row_pos is set for uneven sharding
+        self.assertIsNotNone(sharding._block_bucketize_row_pos)
+        self.assertEqual(len(sharding._block_bucketize_row_pos), 1)
+
+        # Assert: verify the row offsets are correct [0, 600, 800, 1000]
+        expected_row_offsets = torch.tensor([0, 600, 800, 1000], device=device)
+        torch.testing.assert_close(
+            # pyre-fixme: Undefined attribute [16]
+            sharding._block_bucketize_row_pos[0],
+            expected_row_offsets,
+        )
+
+    def test_uneven_sharding_with_different_capacities(self) -> None:
+        # Setup: create a sharding configuration with different memory capacities
+        # Rank 0 gets 50% capacity, Rank 1 gets 30%, Rank 2 gets 20%
+        pool_size = 500
+        world_size = 3
+        memory_capacity_per_rank = [500, 300, 200]  # Different distribution
+        device = torch.device("cpu")
+
+        # Create a mock ShardingEnv
+        class MockShardingEnv:
+            def __init__(self, world_size: int, rank: int) -> None:
+                self.world_size = world_size
+                self.rank = rank
+                self.process_group = None
+
+        env = MockShardingEnv(world_size=world_size, rank=1)
+
+        # Execute: create InferRwTensorPoolSharding with memory_capacity_per_rank
+        sharding = InferRwTensorPoolSharding(
+            pool_size=pool_size,
+            # pyre-fixme: Incompatible parameter type [6]
+            env=env,
+            device=device,
+            memory_capacity_per_rank=memory_capacity_per_rank,
+        )
+
+        # Assert: verify the local pool size per rank is computed proportionally
+        # Total capacity = 1000
+        # Rank 0: 500/1000 * 500 = 250 rows
+        # Rank 1: 300/1000 * 500 = 150 rows
+        # Rank 2: remaining = 500 - 250 - 150 = 100 rows
+        expected_local_pool_size_per_rank = [250, 150, 100]
+        self.assertEqual(
+            sharding.local_pool_size_per_rank, expected_local_pool_size_per_rank
+        )
+
+        # Assert: verify the row offsets are correct [0, 250, 400, 500]
+        expected_row_offsets = torch.tensor([0, 250, 400, 500], device=device)
+        torch.testing.assert_close(
+            # pyre-fixme: Undefined attribute [16]
+            sharding._block_bucketize_row_pos[0],
+            expected_row_offsets,
+        )
+
+    def test_uneven_sharding_total_rows_equals_pool_size(self) -> None:
+        # Setup: verify that the sum of local pool sizes equals the pool size
+        pool_size = 1234
+        world_size = 4
+        memory_capacity_per_rank = [100, 200, 300, 400]
+        device = torch.device("cpu")
+
+        # Create a mock ShardingEnv
+        class MockShardingEnv:
+            def __init__(self, world_size: int, rank: int) -> None:
+                self.world_size = world_size
+                self.rank = rank
+                self.process_group = None
+
+        env = MockShardingEnv(world_size=world_size, rank=0)
+
+        # Execute: create InferRwTensorPoolSharding with memory_capacity_per_rank
+        sharding = InferRwTensorPoolSharding(
+            pool_size=pool_size,
+            # pyre-fixme: Incompatible parameter type [6]
+            env=env,
+            device=device,
+            memory_capacity_per_rank=memory_capacity_per_rank,
+        )
+
+        # Assert: verify the sum of local pool sizes equals the total pool size
+        total_rows = sum(sharding.local_pool_size_per_rank)
+        self.assertEqual(total_rows, pool_size)
+
+        # Assert: verify the last row offset equals the pool size
+        # pyre-fixme: Undefined attribute [16]
+        self.assertEqual(sharding._block_bucketize_row_pos[0][-1].item(), pool_size)
+
+    def test_even_sharding_without_memory_capacity_per_rank(self) -> None:
+        # Setup: create a sharding configuration without memory_capacity_per_rank
+        # This should result in even sharding
+        pool_size = 1000
+        world_size = 4
+        device = torch.device("cpu")
+
+        # Create a mock ShardingEnv
+        class MockShardingEnv:
+            def __init__(self, world_size: int, rank: int) -> None:
+                self.world_size = world_size
+                self.rank = rank
+                self.process_group = None
+
+        env = MockShardingEnv(world_size=world_size, rank=0)
+
+        # Execute: create InferRwTensorPoolSharding without memory_capacity_per_rank
+        sharding = InferRwTensorPoolSharding(
+            pool_size=pool_size,
+            # pyre-fixme: Incompatible parameter type [6]
+            env=env,
+            device=device,
+            memory_capacity_per_rank=None,
+        )
+
+        # Assert: verify the local pool size per rank is evenly distributed
+        # block_size = (1000 + 4 - 1) // 4 = 250
+        # Expected: [250, 250, 250, 250]
+        expected_local_pool_size_per_rank = [250, 250, 250, 250]
+        self.assertEqual(
+            sharding.local_pool_size_per_rank, expected_local_pool_size_per_rank
+        )
+
+        # Assert: verify block_bucketize_row_pos is None for even sharding
+        self.assertIsNone(sharding._block_bucketize_row_pos)
+
+    def test_lookup_ids_dist_uses_block_bucketize_row_pos(self) -> None:
+        # Setup: create a sharding configuration with uneven memory capacity
+        pool_size = 1000
+        world_size = 3
+        memory_capacity_per_rank = [600, 200, 200]
+        device = torch.device("cpu")
+
+        # Create a mock ShardingEnv
+        class MockShardingEnv:
+            def __init__(self, world_size: int, rank: int) -> None:
+                self.world_size = world_size
+                self.rank = rank
+                self.process_group = None
+
+        env = MockShardingEnv(world_size=world_size, rank=0)
+
+        sharding = InferRwTensorPoolSharding(
+            pool_size=pool_size,
+            # pyre-fixme: Incompatible parameter type [6]
+            env=env,
+            device=device,
+            memory_capacity_per_rank=memory_capacity_per_rank,
+        )
+
+        # Execute: create the lookup ids distribution module
+        lookup_ids_dist = sharding.create_lookup_ids_dist()
+
+        # Assert: verify the lookup_ids_dist has the correct block_bucketize_row_pos
+        self.assertIsNotNone(lookup_ids_dist._block_bucketize_row_pos)
+        self.assertEqual(len(lookup_ids_dist._block_bucketize_row_pos), 1)
+
+        # Assert: verify the row offsets match the sharding configuration
+        expected_row_offsets = torch.tensor([0, 600, 800, 1000], device=device)
+        torch.testing.assert_close(
+            lookup_ids_dist._block_bucketize_row_pos[0], expected_row_offsets
+        )
+
+    def test_lookup_with_uneven_sharding_bucketizes_ids_correctly(self) -> None:
+        # Setup: create a sharding configuration with uneven memory capacity
+        # Rank 0 gets rows 0-599 (60% capacity)
+        # Rank 1 gets rows 600-799 (20% capacity)
+        # Rank 2 gets rows 800-999 (20% capacity)
+        pool_size = 1000
+        world_size = 3
+        memory_capacity_per_rank = [600, 200, 200]
+        device = torch.device("cpu")
+
+        # Create a mock ShardingEnv
+        class MockShardingEnv:
+            def __init__(self, world_size: int, rank: int) -> None:
+                self.world_size = world_size
+                self.rank = rank
+                self.process_group = None
+
+        env = MockShardingEnv(world_size=world_size, rank=0)
+
+        sharding = InferRwTensorPoolSharding(
+            pool_size=pool_size,
+            # pyre-fixme: Incompatible parameter type [6]
+            env=env,
+            device=device,
+            memory_capacity_per_rank=memory_capacity_per_rank,
+        )
+
+        # Execute: create the lookup ids distribution and test with various IDs
+        # IDs 0, 100, 599 should go to rank 0
+        # IDs 600, 700 should go to rank 1
+        # IDs 800, 900, 999 should go to rank 2
+        lookup_ids_dist = sharding.create_lookup_ids_dist()
+        test_ids = torch.tensor([0, 100, 599, 600, 700, 800, 900, 999], device=device)
+
+        dist_ids, unbucketize_permute, bucket_mapping, bucketized_lengths = (
+            lookup_ids_dist(test_ids)
+        )
+
+        # Assert: verify the number of IDs per rank matches expected distribution
+        # Rank 0 should receive 3 IDs (0, 100, 599)
+        # Rank 1 should receive 2 IDs (600, 700)
+        # Rank 2 should receive 3 IDs (800, 900, 999)
+        expected_lengths = torch.tensor([3, 2, 3], device=device)
+        torch.testing.assert_close(bucketized_lengths, expected_lengths)
+
+        # Assert: verify IDs are correctly distributed to each rank
+        # Note: IDs are stored as local offsets within each rank's pool
+        # Rank 0 IDs: 0, 100, 599 (no offset needed, rank 0 starts at 0)
+        torch.testing.assert_close(
+            dist_ids[0], torch.tensor([0, 100, 599], device=device)
+        )
+        # Rank 1 IDs: 600, 700 become 0, 100 (offset by 600, rank 1 starts at 600)
+        torch.testing.assert_close(dist_ids[1], torch.tensor([0, 100], device=device))
+        # Rank 2 IDs: 800, 900, 999 become 0, 100, 199 (offset by 800, rank 2 starts at 800)
+        torch.testing.assert_close(
+            dist_ids[2], torch.tensor([0, 100, 199], device=device)
+        )
+
+        # Assert: verify bucket_mapping assigns IDs to correct ranks
+        # IDs 0, 100, 599 -> rank 0
+        # IDs 600, 700 -> rank 1
+        # IDs 800, 900, 999 -> rank 2
+        expected_bucket_mapping = torch.tensor([0, 0, 0, 1, 1, 2, 2, 2], device=device)
+        torch.testing.assert_close(bucket_mapping, expected_bucket_mapping)
+
+    def test_lookup_output_with_uneven_sharding(self) -> None:
+        # Setup: create a sharding configuration with uneven memory capacity
+        # This test validates the complete lookup flow including output distribution
+        pool_size = 10
+        world_size = 3
+        # Rank 0: 60% -> 6 rows (IDs 0-5)
+        # Rank 1: 20% -> 2 rows (IDs 6-7)
+        # Rank 2: 20% -> 2 rows (IDs 8-9)
+        memory_capacity_per_rank = [600, 200, 200]
+        device = torch.device("cpu")
+        dim = 3
+
+        # Create a mock ShardingEnv
+        class MockShardingEnv:
+            def __init__(self, world_size: int, rank: int) -> None:
+                self.world_size = world_size
+                self.rank = rank
+                self.process_group = None
+
+        env = MockShardingEnv(world_size=world_size, rank=0)
+
+        sharding = InferRwTensorPoolSharding(
+            pool_size=pool_size,
+            # pyre-fixme: Incompatible parameter type [6]
+            env=env,
+            device=device,
+            memory_capacity_per_rank=memory_capacity_per_rank,
+        )
+
+        # Execute: test lookup with IDs from different ranks
+        lookup_ids_dist = sharding.create_lookup_ids_dist()
+        lookup_values_dist = sharding.create_lookup_values_dist()
+
+        # Test IDs: 0, 5 (rank 0), 6, 7 (rank 1), 8 (rank 2)
+        test_ids = torch.tensor([0, 5, 6, 7, 8], device=device)
+
+        dist_ids, unbucketize_permute, bucket_mapping, bucketized_lengths = (
+            lookup_ids_dist(test_ids)
+        )
+
+        # Simulate lookup values from each rank's local pool
+        # Rank 0 values for IDs 0, 5
+        rank_0_values = torch.tensor(
+            [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=torch.float32, device=device
+        )
+        # Rank 1 values for IDs 6, 7
+        rank_1_values = torch.tensor(
+            [[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]],
+            dtype=torch.float32,
+            device=device,
+        )
+        # Rank 2 values for ID 8
+        rank_2_values = torch.tensor(
+            [[13.0, 14.0, 15.0]], dtype=torch.float32, device=device
+        )
+
+        # Execute: merge the lookup values from all ranks
+        lookups = [rank_0_values, rank_1_values, rank_2_values]
+        merged_values = lookup_values_dist(lookups)
+
+        # Assert: verify the merged values have the correct shape
+        self.assertEqual(merged_values.shape, (5, dim))
+
+        # Assert: verify values are correctly merged
+        # The merged tensor should contain values in the order they were bucketized
+        # After unbucketize_permute, values should be in original order
+        expected_merged = torch.tensor(
+            [
+                [1.0, 2.0, 3.0],  # ID 0 from rank 0
+                [4.0, 5.0, 6.0],  # ID 5 from rank 0
+                [7.0, 8.0, 9.0],  # ID 6 from rank 1
+                [10.0, 11.0, 12.0],  # ID 7 from rank 1
+                [13.0, 14.0, 15.0],  # ID 8 from rank 2
+            ],
+            dtype=torch.float32,
+            device=device,
+        )
+        torch.testing.assert_close(merged_values, expected_merged)
diff --git a/torchrec/distributed/types.py b/torchrec/distributed/types.py
index 5bac4e396..70b10c53b 100644
--- a/torchrec/distributed/types.py
+++ b/torchrec/distributed/types.py
@@ -1326,12 +1326,22 @@ class ObjectPoolShardingType(Enum):
 class ObjectPoolShardingPlan(ModuleShardingPlan):
     sharding_type: ObjectPoolShardingType
     inference: bool = False
+    # Currently used to propagate the metadata used to shard the tensor pool
+    # unevenly across different devices (HBM/DRAM) based on the memory capacity
+    # of each device type
+    memory_capacity_per_rank: Optional[list[int]] = None
 
     def _serialize(self) -> dict[str, Any]:
-        return {
+        output = {
             "sharding_type": self.sharding_type.name,
             "inference": self.inference,
         }
+        if self.memory_capacity_per_rank is not None:
+            output["memory_capacity_per_rank"] = (
+                # pyre-fixme: Incompatible parameter type [6]
+                self.memory_capacity_per_rank
+            )
+        return output
 
 
 @dataclass
diff --git a/torchrec/modules/utils.py b/torchrec/modules/utils.py
index 95b082b4a..0da67adaf 100644
--- a/torchrec/modules/utils.py
+++ b/torchrec/modules/utils.py
@@ -434,3 +434,13 @@ def _fx_trec_get_feature_length(
         "embedding output and features mismatch",
     )
     return features.lengths()
+
+
+@torch.fx.wrap
+def _get_unbucketize_tensor_via_length_alignment(
+    lengths: torch.Tensor,
+    bucketize_length: torch.Tensor,
+    bucketize_permute_tensor: torch.Tensor,
+    bucket_mapping_tensor: torch.Tensor,
+) -> torch.Tensor:
+    return bucketize_permute_tensor
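
Note (illustrative, not part of the patch): the uneven-sharding path in tensor_sharding.py above splits pool_size across ranks in proportion to memory_capacity_per_rank, gives the last rank the rounding remainder, and records cumulative row offsets that later feed block_bucketize_pos. A minimal standalone sketch of that allocation, assuming the same truncation-based rounding as the patch; the function name is hypothetical:

from typing import List, Tuple


def split_pool_by_capacity(pool_size: int, capacities: List[int]) -> Tuple[List[int], List[int]]:
    # Proportional split: all ranks but the last get floor(cap / total * pool_size) rows,
    # the last rank absorbs whatever remains so the sizes always sum to pool_size.
    total = sum(capacities)
    sizes = [int(cap / total * pool_size) for cap in capacities[:-1]]
    sizes.append(pool_size - sum(sizes))
    # Cumulative row offsets per rank, ending at pool_size.
    offsets = [0]
    for size in sizes:
        offsets.append(offsets[-1] + size)
    return sizes, offsets


print(split_pool_by_capacity(500, [500, 300, 200]))
# ([250, 150, 100], [0, 250, 400, 500]) -- matches test_uneven_sharding_with_different_capacities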
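Note (illustrative, not part of the patch): with uneven shards, those cumulative row offsets act as bucket boundaries when ids are routed to ranks. The patch delegates this to torch.ops.fbgemm.block_bucketize_sparse_features_inference with block_bucketize_pos; the sketch below only mirrors the routing math with plain torch ops to show the dist_ids and bucket_mapping expected by the new test, and route_ids is a hypothetical helper:

import torch


def route_ids(ids: torch.Tensor, row_offsets: torch.Tensor):
    # row_offsets = [0, 600, 800, 1000] means rank 0 owns rows [0, 600),
    # rank 1 owns [600, 800), and rank 2 owns [800, 1000).
    # bucket_mapping[i] is the rank whose half-open range contains ids[i].
    bucket_mapping = torch.bucketize(ids, row_offsets[1:-1], right=True)
    dist_ids = []
    for rank in range(row_offsets.numel() - 1):
        mask = bucket_mapping == rank
        # ids are rebased to local row offsets within the owning rank's shard
        dist_ids.append(ids[mask] - row_offsets[rank])
    return dist_ids, bucket_mapping


ids = torch.tensor([0, 100, 599, 600, 700, 800, 900, 999])
dist_ids, bucket_mapping = route_ids(ids, torch.tensor([0, 600, 800, 1000]))
# dist_ids       -> [tensor([0, 100, 599]), tensor([0, 100]), tensor([0, 100, 199])]
# bucket_mapping -> tensor([0, 0, 0, 1, 1, 2, 2, 2])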