diff --git a/torchrec/distributed/dist_data.py b/torchrec/distributed/dist_data.py
index b58ca1b12..99ed9e049 100644
--- a/torchrec/distributed/dist_data.py
+++ b/torchrec/distributed/dist_data.py
@@ -1240,7 +1240,9 @@ def forward(self, tensors: List[torch.Tensor]) -> List[torch.Tensor]:
             Awaitable[torch.Tensor]: awaitable of the merged pooled embeddings.
         """

-        assert len(tensors) == self._world_size
+        assert (
+            len(tensors) == self._world_size
+        ), f"number of input tensors {len(tensors)} must match world size {self._world_size}"
         return torch.ops.fbgemm.all_to_one_device(
             tensors,
             self._device,
diff --git a/torchrec/distributed/quant_embedding.py b/torchrec/distributed/quant_embedding.py
index 9f47745ba..8d58931d8 100644
--- a/torchrec/distributed/quant_embedding.py
+++ b/torchrec/distributed/quant_embedding.py
@@ -305,37 +305,49 @@ def _construct_jagged_tensors_tw(
     storage_device_type: str,
 ) -> Dict[str, JaggedTensor]:
     ret: Dict[str, JaggedTensor] = {}
+    index = 0
     for i in range(len(embedding_names_per_rank)):
-        embeddings_i = embeddings[i]
-        features_i: KeyedJaggedTensor = features[i]
-        if storage_device_type in ["ssd", "cpu"]:
-            embeddings_i = _get_batching_hinted_output(
-                _fx_trec_get_feature_length(features_i, embedding_names_per_rank[i]),
-                embeddings_i,
-            )
-
-        lengths = features_i.lengths().view(-1, features_i.stride())
-        values = features_i.values()
-        embeddings_list = _fx_split_embeddings_per_feature_length(
-            embeddings_i, features_i
-        )
-        stride = features_i.stride()
-        lengths_tuple = torch.unbind(lengths.view(-1, stride), dim=0)
-        if need_indices:
-            values_list = _fx_split_embeddings_per_feature_length(values, features_i)
-            for j, key in enumerate(embedding_names_per_rank[i]):
-                ret[key] = JaggedTensor(
-                    lengths=lengths_tuple[j],
-                    values=embeddings_list[j],
-                    weights=values_list[j],
+        if len(embedding_names_per_rank[i]) > 0:
+            embeddings_i = embeddings[index]
+            features_i: KeyedJaggedTensor = features[i]
+            if storage_device_type in ["ssd", "cpu"]:
+                embeddings_i = _get_batching_hinted_output(
+                    _fx_trec_get_feature_length(
+                        features_i, embedding_names_per_rank[index]
+                    ),
+                    embeddings_i,
                 )
-        else:
-            for j, key in enumerate(embedding_names_per_rank[i]):
-                ret[key] = JaggedTensor(
-                    lengths=lengths_tuple[j],
-                    values=embeddings_list[j],
-                    weights=None,
+
+            lengths = features_i.lengths().view(-1, features_i.stride())
+            values = features_i.values()
+            embeddings_list = _fx_split_embeddings_per_feature_length(
+                embeddings_i, features_i
+            )
+            stride = features_i.stride()
+            lengths_tuple = torch.unbind(lengths.view(-1, stride), dim=0)
+            if need_indices:
+                values_list = _fx_split_embeddings_per_feature_length(
+                    values, features_i
                 )
+                for j, key in enumerate(embedding_names_per_rank[i]):
+                    ret[key] = JaggedTensor(
+                        lengths=lengths_tuple[j],
+                        values=embeddings_list[j],
+                        weights=values_list[j],
+                    )
+            else:
+                for j, key in enumerate(embedding_names_per_rank[i]):
+                    ret[key] = JaggedTensor(
+                        lengths=lengths_tuple[j],
+                        values=embeddings_list[j],
+                        weights=None,
+                    )
+            index += 1
+        # for cuda storage device, empty embedding per rank is already skipped
+        # as part of tw_sequence_sharding output dist before executing
+        # SeqEmbeddingsAllToOne (for cpu / ssd SeqEmbeddingsAllToOne is not required)
+        elif storage_device_type in ["cpu", "ssd"]:
+            index += 1
     return ret

diff --git a/torchrec/distributed/sharding/tw_sequence_sharding.py b/torchrec/distributed/sharding/tw_sequence_sharding.py
index 2a8d482f0..1a7b517bc 100644
--- a/torchrec/distributed/sharding/tw_sequence_sharding.py
+++ b/torchrec/distributed/sharding/tw_sequence_sharding.py
@@ -175,12 +175,21 @@ def __init__(
         device: torch.device,
         world_size: int,
         storage_device_type_from_sharding_infos: Optional[str] = None,
+        embedding_names_per_rank: Optional[List[List[str]]] = None,
     ) -> None:
         super().__init__()
-        self._dist: SeqEmbeddingsAllToOne = SeqEmbeddingsAllToOne(device, world_size)
+        self._adjusted_world_size: int = (
+            world_size
+            if embedding_names_per_rank is None
+            else sum(1 for sublist in embedding_names_per_rank if len(sublist) > 0)
+        )
+        self._dist: SeqEmbeddingsAllToOne = SeqEmbeddingsAllToOne(
+            device, self._adjusted_world_size
+        )
         self._storage_device_type_from_sharding_infos: Optional[str] = (
             storage_device_type_from_sharding_infos
         )
+        self._embedding_names_per_rank = embedding_names_per_rank

     def forward(
         self,
@@ -216,6 +225,8 @@ def forward(
                     local_emb,
                 )
                 for i, local_emb in enumerate(local_embs)
+                if self._embedding_names_per_rank is not None
+                and len(self._embedding_names_per_rank[i]) > 0
             ]
             return self._dist(local_embs)
         else:
@@ -269,4 +280,5 @@ def create_output_dist(
             device,
             self._world_size,
             self._storage_device_type_from_sharding_infos,
+            self.embedding_names_per_rank(),
         )