torchrec/distributed/quant_embeddingbag.py (66 changes: 40 additions and 26 deletions)
@@ -15,7 +15,6 @@
IntNBitTableBatchedEmbeddingBagsCodegen,
)
from torch import nn

from torch.distributed._shard.sharding_spec import EnumerableShardingSpec
from torchrec.distributed.embedding_lookup import EmbeddingComputeKernel
from torchrec.distributed.embedding_sharding import (
@@ -53,6 +52,7 @@
from torchrec.distributed.sharding.rw_sharding import InferRwPooledEmbeddingSharding
from torchrec.distributed.sharding.tw_sharding import InferTwEmbeddingSharding
from torchrec.distributed.types import (
Multistreamable,
NullShardedModuleContext,
NullShardingContext,
ParameterSharding,
@@ -168,12 +168,21 @@ def create_infer_embedding_bag_sharding(
raise ValueError(f"Sharding type not supported {sharding_type}")


class ShardedQuantManagedCollisionContext(Multistreamable):
remapped_kjt: Optional[ListOfKJTList] = None

def record_stream(self, stream: torch.Stream) -> None:
super().record_stream(stream)
if self.remapped_kjt is not None:
self.remapped_kjt.record_stream(stream)


class ShardedQuantEmbeddingBagCollection(
ShardedQuantEmbeddingModuleState[
ListOfKJTList,
List[List[torch.Tensor]],
KeyedTensor,
NullShardedModuleContext,
ShardedQuantManagedCollisionContext,
],
):
"""
@@ -329,7 +338,7 @@ def _create_output_dist(self, device: Optional[torch.device] = None) -> None:
# pyre-ignore [14]
# pyre-ignore
def input_dist(
self, ctx: NullShardedModuleContext, features: KeyedJaggedTensor
self, ctx: ShardedQuantManagedCollisionContext, features: KeyedJaggedTensor
) -> ListOfKJTList:
input_dist_outputs = self._input_dist_module(features)

@@ -341,7 +350,7 @@ def input_dist(

def compute(
self,
ctx: NullShardedModuleContext,
ctx: ShardedQuantManagedCollisionContext,
dist_input: ListOfKJTList,
) -> List[List[torch.Tensor]]:
# syntax for torchscript
@@ -350,26 +359,29 @@
# pyre-ignore
def output_dist(
self,
ctx: NullShardedModuleContext,
ctx: ShardedQuantManagedCollisionContext,
output: List[List[torch.Tensor]],
) -> KeyedTensor:
return construct_output_kt(
embeddings=[
dist.forward(output[i]) for i, dist in enumerate(self._output_dists)
],
embedding_dims=self._embedding_dims,
embedding_names=self._embedding_names,
) -> Tuple[KeyedTensor, Optional[ListOfKJTList]]:
return (
construct_output_kt(
embeddings=[
dist.forward(output[i]) for i, dist in enumerate(self._output_dists)
],
embedding_dims=self._embedding_dims,
embedding_names=self._embedding_names,
),
None,
)

# pyre-ignore
def compute_and_output_dist(
self, ctx: NullShardedModuleContext, input: ListOfKJTList
) -> KeyedTensor:
self, ctx: ShardedQuantManagedCollisionContext, input: ListOfKJTList
) -> Tuple[KeyedTensor, Optional[ListOfKJTList]]:
return self.output_dist(ctx, self.compute(ctx, input))

# pyre-ignore
def forward(self, *input, **kwargs) -> KeyedTensor:
ctx = self.create_context()
def forward(self, *input, **kwargs) -> Tuple[KeyedTensor, Optional[ListOfKJTList]]:
ctx: ShardedQuantManagedCollisionContext = self.create_context()
dist_input = self.input_dist(ctx, *input, **kwargs)
return self.compute_and_output_dist(ctx, dist_input)

@@ -386,15 +398,15 @@ def shardings(
# pyre-ignore [7]
return self._sharding_type_device_group_to_sharding

def create_context(self) -> NullShardedModuleContext:
def create_context(self) -> ShardedQuantManagedCollisionContext:
if is_torchdynamo_compiling():
# Context creation is not supported by dynamo yet.
# Context is not needed for TW sharding =>
# Unblocking dynamo TW with None.
# pyre-ignore
return None

return NullShardedModuleContext()
return ShardedQuantManagedCollisionContext()


class QuantEmbeddingBagCollectionSharder(
@@ -843,10 +855,10 @@ def _create_mcebc_lookups(self) -> None:

def input_dist(
self,
ctx: NullShardedModuleContext,
ctx: ShardedQuantManagedCollisionContext,
features: KeyedJaggedTensor,
) -> ListOfKJTList:
# TODO: resolve incompatibility with different contexts
# TODO: resolve incompatibility with different contexts. Until then, override the context
if self._has_uninitialized_output_dist:
self._create_output_dist(features.device())
self._has_uninitialized_output_dist = False
@@ -858,12 +870,14 @@ def input_dist(
is_sequence_embedding=False,
)

# pyre-ignore
def compute(
self,
ctx: NullShardedModuleContext,
ctx: ShardedQuantManagedCollisionContext,
dist_input: ListOfKJTList,
) -> List[List[torch.Tensor]]:
ret: List[List[torch.Tensor]] = []
ctx.remapped_kjt = dist_input
for i in range(len(self._managed_collision_collection._embedding_shardings)):
dist_input_i = dist_input[i]
lookups = self._mcebc_lookup[i]
@@ -879,16 +893,16 @@
# pyre-ignore
def output_dist(
self,
ctx: NullShardedModuleContext,
ctx: ShardedQuantManagedCollisionContext,
output: List[List[torch.Tensor]],
) -> Tuple[
Union[KeyedTensor, Dict[str, JaggedTensor]], Optional[KeyedJaggedTensor]
]:
) -> Tuple[Union[KeyedTensor, Dict[str, JaggedTensor]], Optional[ListOfKJTList]]:

# pyre-ignore [6]
ebc_out = super().output_dist(ctx, output)

kjt_out: Optional[KeyedJaggedTensor] = None
kjt_out: Optional[ListOfKJTList] = None
if self._return_remapped_features:
kjt_out = ctx.remapped_kjt

return ebc_out, kjt_out

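Reviewer note: below is a minimal, self-contained sketch (not part of this PR) of the stream-recording pattern the new `ShardedQuantManagedCollisionContext` follows. The context keeps a handle to the remapped features and forwards `record_stream()` to them so their memory is not reused while another CUDA stream may still read it. `_ToyRemapped` and `_ToyContext` are hypothetical stand-ins for the torchrec KJT/context types, used only to keep the sketch runnable.

```python
# A minimal sketch (not this PR's code) of the record_stream cascade that
# ShardedQuantManagedCollisionContext performs on its remapped features.
# _ToyRemapped/_ToyContext are hypothetical stand-ins for KJT types.
from typing import Optional

import torch


class _ToyRemapped:
    def __init__(self, values: torch.Tensor) -> None:
        self.values = values

    def record_stream(self, stream: torch.Stream) -> None:
        # torch.Tensor.record_stream marks the tensor as in use on `stream`,
        # so the caching allocator will not hand its memory back too early.
        self.values.record_stream(stream)


class _ToyContext:
    def __init__(self) -> None:
        self.remapped_kjt: Optional[_ToyRemapped] = None

    def record_stream(self, stream: torch.Stream) -> None:
        # Cascade to the held features, mirroring the new context class.
        if self.remapped_kjt is not None:
            self.remapped_kjt.record_stream(stream)


if torch.cuda.is_available():
    side_stream = torch.cuda.Stream()
    ctx = _ToyContext()
    ctx.remapped_kjt = _ToyRemapped(torch.ones(8, device="cuda"))
    ctx.record_stream(side_stream)  # keeps `values` alive w.r.t. side_stream
```

On the caller side, `forward()` now returns `Tuple[KeyedTensor, Optional[ListOfKJTList]]`; the second element carries the remapped features only when `_return_remapped_features` is set, and is `None` otherwise.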