diff --git a/torchrec/distributed/quant_embeddingbag.py b/torchrec/distributed/quant_embeddingbag.py
index eb5f8ea21..68dfa8e08 100644
--- a/torchrec/distributed/quant_embeddingbag.py
+++ b/torchrec/distributed/quant_embeddingbag.py
@@ -15,7 +15,6 @@
     IntNBitTableBatchedEmbeddingBagsCodegen,
 )
 from torch import nn
-from torch.distributed._shard.sharding_spec import EnumerableShardingSpec
 from torchrec.distributed.embedding_lookup import EmbeddingComputeKernel
 from torchrec.distributed.embedding_sharding import (
@@ -53,6 +52,7 @@
 from torchrec.distributed.sharding.rw_sharding import InferRwPooledEmbeddingSharding
 from torchrec.distributed.sharding.tw_sharding import InferTwEmbeddingSharding
 from torchrec.distributed.types import (
+    Multistreamable,
     NullShardedModuleContext,
     NullShardingContext,
     ParameterSharding,
@@ -168,12 +168,21 @@ def create_infer_embedding_bag_sharding(
         raise ValueError(f"Sharding type not supported {sharding_type}")


+class ShardedQuantManagedCollisionContext(Multistreamable):
+    remapped_kjt: Optional[ListOfKJTList] = None
+
+    def record_stream(self, stream: torch.Stream) -> None:
+        super().record_stream(stream)
+        if self.remapped_kjt is not None:
+            self.remapped_kjt.record_stream(stream)
+
+
 class ShardedQuantEmbeddingBagCollection(
     ShardedQuantEmbeddingModuleState[
         ListOfKJTList,
         List[List[torch.Tensor]],
         KeyedTensor,
-        NullShardedModuleContext,
+        ShardedQuantManagedCollisionContext,
     ],
 ):
     """
@@ -329,7 +338,7 @@ def _create_output_dist(self, device: Optional[torch.device] = None) -> None:
     # pyre-ignore [14]
     # pyre-ignore
     def input_dist(
-        self, ctx: NullShardedModuleContext, features: KeyedJaggedTensor
+        self, ctx: ShardedQuantManagedCollisionContext, features: KeyedJaggedTensor
     ) -> ListOfKJTList:
         input_dist_outputs = self._input_dist_module(features)
@@ -341,7 +350,7 @@ def input_dist(
     def compute(
         self,
-        ctx: NullShardedModuleContext,
+        ctx: ShardedQuantManagedCollisionContext,
         dist_input: ListOfKJTList,
     ) -> List[List[torch.Tensor]]:
         # syntax for torchscript
@@ -350,26 +359,29 @@ def compute(
     # pyre-ignore
     def output_dist(
         self,
-        ctx: NullShardedModuleContext,
+        ctx: ShardedQuantManagedCollisionContext,
         output: List[List[torch.Tensor]],
-    ) -> KeyedTensor:
-        return construct_output_kt(
-            embeddings=[
-                dist.forward(output[i]) for i, dist in enumerate(self._output_dists)
-            ],
-            embedding_dims=self._embedding_dims,
-            embedding_names=self._embedding_names,
+    ) -> Tuple[KeyedTensor, Optional[ListOfKJTList]]:
+        return (
+            construct_output_kt(
+                embeddings=[
+                    dist.forward(output[i]) for i, dist in enumerate(self._output_dists)
+                ],
+                embedding_dims=self._embedding_dims,
+                embedding_names=self._embedding_names,
+            ),
+            None,
         )

     # pyre-ignore
     def compute_and_output_dist(
-        self, ctx: NullShardedModuleContext, input: ListOfKJTList
-    ) -> KeyedTensor:
+        self, ctx: ShardedQuantManagedCollisionContext, input: ListOfKJTList
+    ) -> Tuple[KeyedTensor, Optional[ListOfKJTList]]:
         return self.output_dist(ctx, self.compute(ctx, input))

     # pyre-ignore
-    def forward(self, *input, **kwargs) -> KeyedTensor:
-        ctx = self.create_context()
+    def forward(self, *input, **kwargs) -> Tuple[KeyedTensor, Optional[ListOfKJTList]]:
+        ctx: ShardedQuantManagedCollisionContext = self.create_context()
         dist_input = self.input_dist(ctx, *input, **kwargs)
         return self.compute_and_output_dist(ctx, dist_input)
@@ -386,7 +398,7 @@ def shardings(
         # pyre-ignore [7]
         return self._sharding_type_device_group_to_sharding

-    def create_context(self) -> NullShardedModuleContext:
+    def create_context(self) -> ShardedQuantManagedCollisionContext:
         if is_torchdynamo_compiling():
             # Context creation is not supported by dynamo yet.
             # Context is not needed for TW sharding =>
@@ -394,7 +406,7 @@ def create_context(self) -> NullShardedModuleContext:
             # pyre-ignore
             return None

-        return NullShardedModuleContext()
+        return ShardedQuantManagedCollisionContext()


 class QuantEmbeddingBagCollectionSharder(
@@ -843,10 +855,10 @@ def _create_mcebc_lookups(self) -> None:
     def input_dist(
         self,
-        ctx: NullShardedModuleContext,
+        ctx: ShardedQuantManagedCollisionContext,
         features: KeyedJaggedTensor,
     ) -> ListOfKJTList:
-        # TODO: resolve incompatiblity with different contexts
+        # TODO: resolve incompatibility with different contexts. Until then, override the context.
         if self._has_uninitialized_output_dist:
             self._create_output_dist(features.device())
             self._has_uninitialized_output_dist = False
@@ -858,12 +870,14 @@ def input_dist(
             is_sequence_embedding=False,
         )

+    # pyre-ignore
     def compute(
         self,
-        ctx: NullShardedModuleContext,
+        ctx: ShardedQuantManagedCollisionContext,
         dist_input: ListOfKJTList,
     ) -> List[List[torch.Tensor]]:
         ret: List[List[torch.Tensor]] = []
+        ctx.remapped_kjt = dist_input
         for i in range(len(self._managed_collision_collection._embedding_shardings)):
             dist_input_i = dist_input[i]
             lookups = self._mcebc_lookup[i]
@@ -879,16 +893,16 @@ def compute(
     # pyre-ignore
     def output_dist(
         self,
-        ctx: NullShardedModuleContext,
+        ctx: ShardedQuantManagedCollisionContext,
         output: List[List[torch.Tensor]],
-    ) -> Tuple[
-        Union[KeyedTensor, Dict[str, JaggedTensor]], Optional[KeyedJaggedTensor]
-    ]:
+    ) -> Tuple[Union[KeyedTensor, Dict[str, JaggedTensor]], Optional[ListOfKJTList]]:
         # pyre-ignore [6]
         ebc_out = super().output_dist(ctx, output)

-        kjt_out: Optional[KeyedJaggedTensor] = None
+        kjt_out: Optional[ListOfKJTList] = None
+        if self._return_remapped_features:
+            kjt_out = ctx.remapped_kjt

         return ebc_out, kjt_out
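
Note on the API change above: `forward`, `output_dist`, and `compute_and_output_dist` on `ShardedQuantEmbeddingBagCollection` now return `Tuple[KeyedTensor, Optional[ListOfKJTList]]` instead of a bare `KeyedTensor`. The second element is always `None` for the plain collection; only the managed-collision variant populates it, and only when `_return_remapped_features` is set. A minimal call-site sketch, where `run_lookup`, `sharded_qebc`, and `features` are illustrative stand-ins rather than names from this diff:

    def run_lookup(sharded_qebc, features):
        # Before this change: kt = sharded_qebc(features)
        kt, remapped_kjt = sharded_qebc(features)
        if remapped_kjt is not None:
            # Populated only by the managed-collision variant, and only
            # when _return_remapped_features is enabled.
            pass
        return kt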
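
The remapped features travel from `compute` to `output_dist` by being stashed on the shared context: `compute` saves the post-input-dist `ListOfKJTList` on `ctx.remapped_kjt`, and `output_dist` reads it back when `_return_remapped_features` is true. A simplified, self-contained sketch of that pattern; only the `remapped_kjt` field mirrors the diff, every other name is an illustrative stand-in:

    class _Ctx:
        def __init__(self):
            self.remapped_kjt = None

    def _compute(ctx, dist_input):
        ctx.remapped_kjt = dist_input       # stash remapped features on the context
        return [x * 2 for x in dist_input]  # stand-in for the embedding lookup

    def _output_dist(ctx, output, return_remapped):
        kjt_out = ctx.remapped_kjt if return_remapped else None
        return output, kjt_out              # mirrors (ebc_out, kjt_out)

    ctx = _Ctx()
    out, remapped = _output_dist(ctx, _compute(ctx, [1, 2, 3]), True)
    assert remapped == [1, 2, 3]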
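
Why the context must be `Multistreamable`: inference pipelines may run stages on a side CUDA stream, and any tensors the context keeps alive across stages must be recorded against the consuming stream so the caching allocator does not reclaim their memory while side-stream kernels are still in flight. Forwarding `ShardedQuantManagedCollisionContext.record_stream` to `remapped_kjt.record_stream` handles exactly that. A stand-alone illustration of the underlying PyTorch contract (plain `torch`, not TorchRec code):

    import torch

    if torch.cuda.is_available():
        side = torch.cuda.Stream()
        with torch.cuda.stream(side):
            t = torch.randn(1024, device="cuda")  # allocated and written on `side`
        current = torch.cuda.current_stream()
        current.wait_stream(side)  # order execution between the streams
        t.record_stream(current)   # keep t's memory alive for `current`'s consumers
        out = t * 2                # now safe to consume on the current stream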