diff --git a/torchrec/distributed/fp_embeddingbag.py b/torchrec/distributed/fp_embeddingbag.py
index 4b069437f..40c13db85 100644
--- a/torchrec/distributed/fp_embeddingbag.py
+++ b/torchrec/distributed/fp_embeddingbag.py
@@ -8,11 +8,21 @@
 # pyre-strict
 
 from functools import partial
-from typing import Any, Dict, Iterator, List, Optional, Type, Union
+from typing import (
+    Any,
+    Dict,
+    Iterator,
+    List,
+    Mapping,
+    Optional,
+    Tuple,
+    Type,
+    TypeVar,
+    Union,
+)
 
 import torch
 from torch import nn
-
 from torchrec.distributed.embedding_types import (
     BaseEmbeddingSharder,
     KJTList,
@@ -31,7 +41,11 @@
     ShardingEnv,
     ShardingType,
 )
-from torchrec.distributed.utils import append_prefix, init_parameters
+from torchrec.distributed.utils import (
+    append_prefix,
+    init_parameters,
+    modify_input_for_feature_processor,
+)
 from torchrec.modules.feature_processor_ import FeatureProcessorsCollection
 from torchrec.modules.fp_embedding_modules import (
     apply_feature_processors_to_kjt,
@@ -39,6 +53,8 @@
 )
 from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor
 
+_T = TypeVar("_T")
+
 
 def param_dp_sync(kt: KeyedTensor, no_op_tensor: torch.Tensor) -> KeyedTensor:
     kt._values.add_(no_op_tensor)
@@ -74,6 +90,16 @@ def __init__(
             )
         )
 
+        self._row_wise_sharded: bool = False
+        for param_sharding in table_name_to_parameter_sharding.values():
+            if param_sharding.sharding_type in [
+                ShardingType.ROW_WISE.value,
+                ShardingType.TABLE_ROW_WISE.value,
+                ShardingType.GRID_SHARD.value,
+            ]:
+                self._row_wise_sharded = True
+                break
+
         self._lookups: List[nn.Module] = self._embedding_bag_collection._lookups
 
         self._is_collection: bool = False
@@ -96,6 +122,14 @@ def __init__(
     def input_dist(
         self, ctx: EmbeddingBagCollectionContext, features: KeyedJaggedTensor
     ) -> Awaitable[Awaitable[KJTList]]:
+        if torch._utils_internal.justknobs_check(
+            "pytorch/torchrec:enable_rw_feature_processor"
+        ):
+            if not self.is_pipelined and self._row_wise_sharded:
+                # transform the input to support row-based sharding when not pipelined
+                modify_input_for_feature_processor(
+                    features, self._feature_processors, self._is_collection
+                )
         return self._embedding_bag_collection.input_dist(ctx, features)
 
     def apply_feature_processors_to_kjt_list(self, dist_input: KJTList) -> KJTList:
@@ -166,6 +200,21 @@ def sharded_parameter_names(self, prefix: str = "") -> Iterator[str]:
     def _initialize_torch_state(self, skip_registering: bool = False) -> None:  # noqa
         self._embedding_bag_collection._initialize_torch_state(skip_registering)
 
+    def preprocess_input(
+        self, args: List[_T], kwargs: Mapping[str, _T]
+    ) -> Tuple[List[_T], Mapping[str, _T]]:
+        if self._row_wise_sharded:
+            # only when we are row-wise sharded do we preprocess the KJT
+            # and call the associated preprocess methods in the feature processor
+            for x in args + list(kwargs.values()):
+                if isinstance(x, KeyedJaggedTensor):
+                    modify_input_for_feature_processor(
+                        features=x,
+                        feature_processors=self._feature_processors,
+                        is_collection=self._is_collection,
+                    )
+        return args, kwargs
+
 
 class FeatureProcessedEmbeddingBagCollectionSharder(
     BaseEmbeddingSharder[FeatureProcessedEmbeddingBagCollection]
@@ -236,4 +285,15 @@ def sharding_types(self, compute_device_type: str) -> List[str]:
             ShardingType.TABLE_COLUMN_WISE.value,
         ]
 
+        if torch._utils_internal.justknobs_check(
+            "pytorch/torchrec:enable_rw_feature_processor"
+        ):
+            types.extend(
+                [
+                    ShardingType.TABLE_ROW_WISE.value,
+                    ShardingType.ROW_WISE.value,
+                    ShardingType.GRID_SHARD.value,
+                ]
+            )
+
         return types
diff --git a/torchrec/distributed/planner/enumerators.py b/torchrec/distributed/planner/enumerators.py
index 202be6b71..b24fe2cfa 100644
--- a/torchrec/distributed/planner/enumerators.py
+++ b/torchrec/distributed/planner/enumerators.py
@@ -54,6 +54,15 @@
     EmbeddingComputeKernel.DRAM_VIRTUAL_TABLE,
 }
 
+# Sharding types that require explicit user specification for feature-processed modules.
+# Row-wise sharding uses a different pipelined configuration for feature processing, so
+# to guard against paths that aren't well tested, the user must explicitly specify row-wise sharding.
+GUARDED_SHARDING_TYPES_FOR_FP_MODULES: Set[str] = {
+    ShardingType.ROW_WISE.value,
+    ShardingType.TABLE_ROW_WISE.value,
+    ShardingType.GRID_SHARD.value,
+}
+
 
 class EmbeddingEnumerator(Enumerator):
     """
@@ -186,7 +195,7 @@ def enumerate(
             sharding_options_per_table: List[ShardingOption] = []
 
             for sharding_type in self._filter_sharding_types(
-                name, sharder.sharding_types(self._compute_device)
+                name, sharder.sharding_types(self._compute_device), sharder_key
             ):
                 for compute_kernel in self._filter_compute_kernels(
                     name,
@@ -308,18 +317,37 @@ def populate_estimates(self, sharding_options: List[ShardingOption]) -> None:
             estimator.estimate(sharding_options, self._sharder_map)
 
     def _filter_sharding_types(
-        self, name: str, allowed_sharding_types: List[str]
+        self, name: str, allowed_sharding_types: List[str], sharder_key: str = ""
    ) -> List[str]:
         # GRID_SHARD is only supported if specified by user in parameter constraints
+        # Row-based sharding types are only supported for FP modules if specified by user in parameter constraints
         if not self._constraints or not self._constraints.get(name):
-            return [
+            filtered = [
                 t for t in allowed_sharding_types if t != ShardingType.GRID_SHARD.value
             ]
+            # For feature-processed modules, row-wise sharding types require explicit
+            # user specification due to potential issues with position-weighted FPs
+            if "FeatureProcessedEmbeddingBagCollection" in sharder_key:
+                filtered = [
+                    t
+                    for t in filtered
+                    if t not in GUARDED_SHARDING_TYPES_FOR_FP_MODULES
+                ]
+            return filtered
         constraints: ParameterConstraints = self._constraints[name]
         if not constraints.sharding_types:
-            return [
+            filtered = [
                 t for t in allowed_sharding_types if t != ShardingType.GRID_SHARD.value
             ]
+            # For feature-processed modules, row-wise sharding types require explicit
+            # user specification due to potential issues with position-weighted FPs
+            if "FeatureProcessedEmbeddingBagCollection" in sharder_key:
+                filtered = [
+                    t
+                    for t in filtered
+                    if t not in GUARDED_SHARDING_TYPES_FOR_FP_MODULES
+                ]
+            return filtered
         constrained_sharding_types: List[str] = constraints.sharding_types
 
         filtered_sharding_types = list(
diff --git a/torchrec/distributed/planner/tests/test_enumerators.py b/torchrec/distributed/planner/tests/test_enumerators.py
index 39a39d9f0..c625d0111 100644
--- a/torchrec/distributed/planner/tests/test_enumerators.py
+++ b/torchrec/distributed/planner/tests/test_enumerators.py
@@ -1115,3 +1115,76 @@ def test_throw_ex_no_sharding_option_for_table(self) -> None:
             "Module: torchrec.modules.embedding_modules.EmbeddingBagCollection, sharder: CWSharder, compute device: cuda. "
" "To debug, search above for warning logs about no available sharding types/compute kernels for table: table_1", ) + + def test_filter_sharding_types_fp_ebc_no_constraints(self) -> None: + """Test that row-wise sharding types are filtered out for FP modules without constraints.""" + enumerator = EmbeddingEnumerator( + topology=MagicMock(), + batch_size=MagicMock(), + constraints=None, + ) + + # Test with FeatureProcessedEmbeddingBagCollection sharder key + # Without constraints, row-wise types should be filtered out + all_sharding_types = [ + ShardingType.DATA_PARALLEL.value, + ShardingType.TABLE_WISE.value, + ShardingType.ROW_WISE.value, + ShardingType.TABLE_ROW_WISE.value, + ShardingType.COLUMN_WISE.value, + ShardingType.GRID_SHARD.value, + ] + + allowed_sharding_types = enumerator._filter_sharding_types( + "table_0", + all_sharding_types, + "torchrec.modules.fp_embedding_modules.FeatureProcessedEmbeddingBagCollection", + ) + + # ROW_WISE, TABLE_ROW_WISE, and GRID_SHARD should be filtered out + self.assertEqual( + set(allowed_sharding_types), + { + ShardingType.DATA_PARALLEL.value, + ShardingType.TABLE_WISE.value, + ShardingType.COLUMN_WISE.value, + }, + ) + + def test_filter_sharding_types_fp_ebc_with_row_wise_constraint(self) -> None: + """Test that row-wise sharding types are allowed for FP modules with explicit constraints.""" + constraint = ParameterConstraints( + sharding_types=[ + ShardingType.ROW_WISE.value, + ShardingType.TABLE_WISE.value, + ], + ) + constraints = {"table_0": constraint} + enumerator = EmbeddingEnumerator( + topology=MagicMock(), + batch_size=MagicMock(), + constraints=constraints, + ) + + all_sharding_types = [ + ShardingType.DATA_PARALLEL.value, + ShardingType.TABLE_WISE.value, + ShardingType.ROW_WISE.value, + ShardingType.TABLE_ROW_WISE.value, + ShardingType.COLUMN_WISE.value, + ] + + # With explicit constraint specifying ROW_WISE, it should be allowed + allowed_sharding_types = enumerator._filter_sharding_types( + "table_0", + all_sharding_types, + "torchrec.modules.fp_embedding_modules.FeatureProcessedEmbeddingBagCollection", + ) + + self.assertEqual( + set(allowed_sharding_types), + { + ShardingType.ROW_WISE.value, + ShardingType.TABLE_WISE.value, + }, + ) diff --git a/torchrec/distributed/train_pipeline/tests/test_train_pipelines.py b/torchrec/distributed/train_pipeline/tests/test_train_pipelines.py index e5c5d5d7f..71577b1fe 100644 --- a/torchrec/distributed/train_pipeline/tests/test_train_pipelines.py +++ b/torchrec/distributed/train_pipeline/tests/test_train_pipelines.py @@ -8,7 +8,6 @@ # pyre-strict import copy - import unittest from contextlib import contextmanager, ExitStack from dataclasses import dataclass @@ -22,7 +21,10 @@ from torch._dynamo.testing import reduce_to_scalar_loss from torch._dynamo.utils import counters from torchrec.distributed import DistributedModelParallel -from torchrec.distributed.embedding_types import EmbeddingComputeKernel +from torchrec.distributed.embedding_types import ( + EmbeddingComputeKernel, + EmbeddingTableConfig, +) from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder from torchrec.distributed.fp_embeddingbag import ( FeatureProcessedEmbeddingBagCollectionSharder, @@ -31,8 +33,13 @@ from torchrec.distributed.model_parallel import DMPCollection from torchrec.distributed.sharding_plan import ( construct_module_sharding_plan, + row_wise, table_wise, ) +from torchrec.distributed.test_utils.multi_process import ( + MultiProcessContext, + MultiProcessTestBase, +) from 
 from torchrec.distributed.test_utils.test_model import (
     ModelInput,
     TestNegSamplingModule,
@@ -78,7 +85,6 @@
 )
 from torchrec.modules.embedding_configs import DataType, EmbeddingBagConfig
 from torchrec.modules.embedding_modules import EmbeddingBagCollection
-
 from torchrec.optim.keyed import KeyedOptimizerWrapper
 from torchrec.optim.optimizers import in_backward_optimizer_filter
 from torchrec.pt2.utils import kjt_for_pt2_tracing
@@ -695,6 +701,161 @@ def custom_model_fwd(
         self.assertEqual(pred_pipeline.size(0), 64)
 
+
+def fp_ebc_rw_sharding_test_runner(
+    rank: int,
+    world_size: int,
+    tables: List[EmbeddingTableConfig],
+    weighted_tables: List[EmbeddingTableConfig],
+    data: List[Tuple[ModelInput, List[ModelInput]]],
+    backend: str = "nccl",
+    local_size: Optional[int] = None,
+) -> None:
+    with MultiProcessContext(rank, world_size, backend, local_size) as ctx:
+        assert ctx.pg is not None
+        sharder = cast(
+            ModuleSharder[nn.Module],
+            FeatureProcessedEmbeddingBagCollectionSharder(),
+        )
+
+        class DummyWrapper(nn.Module):
+            def __init__(self, sparse_arch):
+                super().__init__()
+                self.m = sparse_arch
+
+            def forward(self, model_input) -> Tuple[torch.Tensor, torch.Tensor]:
+                return self.m(model_input.idlist_features)
+
+        max_feature_lengths = [10, 10, 12, 12]
+        sparse_arch = DummyWrapper(
+            create_module_and_freeze(
+                tables=tables,  # pyre-ignore[6]
+                device=ctx.device,
+                use_fp_collection=False,
+                max_feature_lengths=max_feature_lengths,
+            )
+        )
+
+        # compute_kernel = EmbeddingComputeKernel.FUSED.value
+        module_sharding_plan = construct_module_sharding_plan(
+            sparse_arch.m._fp_ebc,
+            per_param_sharding={
+                "table_0": row_wise(),
+                "table_1": row_wise(),
+                "table_2": row_wise(),
+                "table_3": row_wise(),
+            },
+            world_size=2,
+            device_type=ctx.device.type,
+            sharder=sharder,
+        )
+        sharded_sparse_arch_pipeline = DistributedModelParallel(
+            module=copy.deepcopy(sparse_arch),
+            plan=ShardingPlan({"m._fp_ebc": module_sharding_plan}),
+            env=ShardingEnv.from_process_group(ctx.pg),  # pyre-ignore[6]
+            sharders=[sharder],
+            device=ctx.device,
+        )
+        sharded_sparse_arch_no_pipeline = DistributedModelParallel(
+            module=copy.deepcopy(sparse_arch),
+            plan=ShardingPlan({"m._fp_ebc": module_sharding_plan}),
+            env=ShardingEnv.from_process_group(ctx.pg),  # pyre-ignore[6]
+            sharders=[sharder],
+            device=ctx.device,
+        )
+
+        batches = []
+        for d in data:
+            batches.append(d[1][ctx.rank].to(ctx.device))
+        dataloader = iter(batches)
+
+        optimizer_no_pipeline = optim.SGD(
+            sharded_sparse_arch_no_pipeline.parameters(), lr=0.1
+        )
+        optimizer_pipeline = optim.SGD(
+            sharded_sparse_arch_pipeline.parameters(), lr=0.1
+        )
+
+        pipeline = TrainPipelineSparseDist(
+            sharded_sparse_arch_pipeline,
+            optimizer_pipeline,
+            ctx.device,
+        )
+
+        for batch in batches[:-2]:
+            batch = batch.to(ctx.device)
+            optimizer_no_pipeline.zero_grad()
+            loss, pred = sharded_sparse_arch_no_pipeline(batch)
+            loss.backward()
+            optimizer_no_pipeline.step()
+
+            pred_pipeline = pipeline.progress(dataloader)
+            torch.testing.assert_close(pred_pipeline.cpu(), pred.cpu())
+
+
+class TrainPipelineGPUTest(MultiProcessTestBase):
+    def setUp(self, backend: str = "nccl") -> None:
+        super().setUp()
+
+        self.pipeline_class = TrainPipelineSparseDist
+        num_features = 4
+        num_weighted_features = 4
+        self.tables = [
+            EmbeddingBagConfig(
+                num_embeddings=(i + 1) * 100,
+                embedding_dim=(i + 1) * 4,
+                name="table_" + str(i),
+                feature_names=["feature_" + str(i)],
+            )
+            for i in range(num_features)
+        ]
+        self.weighted_tables = [
+            EmbeddingBagConfig(
+                num_embeddings=(i + 1) * 100,
+                embedding_dim=(i + 1) * 4,
+                name="weighted_table_" + str(i),
+                feature_names=["weighted_feature_" + str(i)],
+            )
+            for i in range(num_weighted_features)
+        ]
+
+        self.backend = backend
+        if torch.cuda.is_available():
+            self.device = torch.device("cuda")
+        else:
+            self.device = torch.device("cpu")
+
+        if self.backend == "nccl" and self.device == torch.device("cpu"):
+            self.skipTest("NCCL not supported on CPUs.")
+
+    def _generate_data(
+        self,
+        num_batches: int = 5,
+        batch_size: int = 1,
+        max_feature_lengths: Optional[List[int]] = None,
+    ) -> List[Tuple[ModelInput, List[ModelInput]]]:
+        return [
+            ModelInput.generate(
+                tables=self.tables,
+                weighted_tables=self.weighted_tables,
+                batch_size=batch_size,
+                world_size=2,
+                num_float_features=10,
+                max_feature_lengths=max_feature_lengths,
+            )
+            for i in range(num_batches)
+        ]
+
+    def test_fp_ebc_rw(self) -> None:
+        data = self._generate_data(max_feature_lengths=[10, 10, 12, 12])
+        self._run_multi_process_test(
+            callable=fp_ebc_rw_sharding_test_runner,
+            world_size=2,
+            tables=self.tables,
+            weighted_tables=self.weighted_tables,
+            data=data,
+        )
+
 
 class TrainPipelineSparseDist2DShardingTest(unittest.TestCase):
     @contextmanager
     def _mocked_pipeline(self, obj: T) -> Generator[T, None, None]:
diff --git a/torchrec/distributed/train_pipeline/utils.py b/torchrec/distributed/train_pipeline/utils.py
index 8bea1ff37..662a14550 100644
--- a/torchrec/distributed/train_pipeline/utils.py
+++ b/torchrec/distributed/train_pipeline/utils.py
@@ -11,7 +11,6 @@
 import logging
 from collections import defaultdict
 from contextlib import AbstractContextManager
-
 from threading import Event, Thread
 from typing import (
     Any,
@@ -29,7 +28,6 @@
 
 import torch
 from torch.profiler import record_function
-
 from torchrec.distributed.dist_data import KJTAllToAll, KJTAllToAllTensorsAwaitable
 from torchrec.distributed.embedding_sharding import (
     FusedKJTListSplitsAwaitable,
@@ -169,6 +167,10 @@ def _start_data_dist(
             # and this info was done in the _rewrite_model by tracing the
             # entire model to get the arg_info_list
             args, kwargs = forward.args.build_args_kwargs(batch)
+            if torch._utils_internal.justknobs_check(
+                "pytorch/torchrec:enable_rw_feature_processor"
+            ):
+                args, kwargs = module.preprocess_input(args, kwargs)
 
             # Start input distribution.
             module_ctx = module.create_context()
@@ -404,6 +406,8 @@ def _rewrite_model(  # noqa C901
 
             logger.info(f"Module '{node.target}' will be pipelined")
             child = sharded_modules[node.target]
             original_forwards.append(child.forward)
+            # Set the pipelining flag on the child module
+            child.is_pipelined = True  # pyre-ignore[8] Incompatible attribute type
             child.forward = pipelined_forward(
                 node.target,
diff --git a/torchrec/distributed/types.py b/torchrec/distributed/types.py
index e74aab933..228c22789 100644
--- a/torchrec/distributed/types.py
+++ b/torchrec/distributed/types.py
@@ -19,7 +19,10 @@
     Generic,
     Iterator,
     List,
+    Mapping,
     Optional,
+    ParamSpec,
+    Sequence,
     Tuple,
     Type,
     TypeVar,
@@ -79,6 +82,8 @@ class GenericMeta(type):
 )
 from torchrec.streamable import Multistreamable
 
+_T = TypeVar("_T")
+
 
 def _tabulate(
     table: List[List[Union[str, int]]], headers: Optional[List[str]] = None
@@ -1098,6 +1103,8 @@ def __init__(
         if qcomm_codecs_registry is None:
             qcomm_codecs_registry = {}
         self._qcomm_codecs_registry = qcomm_codecs_registry
+        # In pipelining, this flag is flipped in _rewrite_model when the forward is replaced with the pipelined forward
+        self.is_pipelined = False
 
     @abc.abstractmethod
     def create_context(self) -> ShrdCtx:
@@ -1200,6 +1207,19 @@ def sharded_parameter_names(self, prefix: str = "") -> Iterator[str]:
         for key, _ in self.named_parameters(prefix):
             yield key
 
+    def preprocess_input(
+        self,
+        args: List[_T],
+        kwargs: Mapping[str, _T],
+    ) -> Tuple[List[_T], Mapping[str, _T]]:
+        """
+        This function can be used to preprocess the input arguments prior to the module forward call.
+
+        For example, it is used in ShardedFeatureProcessedEmbeddingBagCollection to transform the input data
+        prior to the forward call.
+        """
+        return args, kwargs
+
     @property
     @abc.abstractmethod
     def unsharded_module_type(self) -> Type[nn.Module]:
diff --git a/torchrec/distributed/utils.py b/torchrec/distributed/utils.py
index b12660e97..bfeea20d4 100644
--- a/torchrec/distributed/utils.py
+++ b/torchrec/distributed/utils.py
@@ -26,8 +26,10 @@
 from torch import nn
 from torch.autograd.profiler import record_function
 from torchrec import optim as trec_optim
-from torchrec.distributed.embedding_types import EmbeddingComputeKernel
-
+from torchrec.distributed.embedding_types import (
+    EmbeddingComputeKernel,
+    KeyedJaggedTensor,
+)
 from torchrec.distributed.types import (
     DataType,
     EmbeddingEvent,
@@ -38,6 +40,7 @@
     ShardMetadata,
 )
 from torchrec.modules.embedding_configs import data_type_to_sparse_type
+from torchrec.modules.feature_processor_ import FeatureProcessorsCollection
 from torchrec.types import CopyMixIn
 
 logger: logging.Logger = logging.getLogger(__name__)
@@ -758,3 +761,47 @@ def _recalculate_torch_state_helper(
             _recalculate_torch_state_helper(child)
 
     _recalculate_torch_state_helper(module)
+
+
+def modify_input_for_feature_processor(
+    features: KeyedJaggedTensor,
+    feature_processors: Union[nn.ModuleDict, FeatureProcessorsCollection],
+    is_collection: bool,
+) -> None:
+    """
+    This function applies the feature processor before the input dist. This way we
+    can support row-wise sharding mechanisms.
+
+    This is an in-place modification of the input KJT.
+ """ + with torch.no_grad(): + if features.weights_or_none() is None: + # force creation of weights, this way the feature jagged tensor weights are tied to the original KJT + features._weights = torch.zeros_like(features.values(), dtype=torch.float32) + + if is_collection: + if hasattr(feature_processors, "pre_process_input"): + feature_processors.pre_process_input(features) # pyre-ignore[29] + else: + logging.info( + f"[Feature Processor Pipeline] Skipping pre_process_input for feature processor {feature_processors=}" + ) + else: + # per feature process + for feature in features.keys(): + if feature in feature_processors: # pyre-ignore[58] + feature_processor = feature_processors[feature] # pyre-ignore[29] + if hasattr(feature_processor, "pre_process_input"): + feature_processor.pre_process_input(features[feature]) + else: + logging.info( + f"[Feature Processor Pipeline] Skipping pre_process_input for feature processor {feature_processor=}" + ) + else: + features[feature].weights().copy_( + torch.ones( + features[feature].values().shape[0], + dtype=torch.float32, + device=features[feature].values().device, + ) + ) diff --git a/torchrec/modules/feature_processor_.py b/torchrec/modules/feature_processor_.py index 707f5bd2b..e612047af 100644 --- a/torchrec/modules/feature_processor_.py +++ b/torchrec/modules/feature_processor_.py @@ -14,7 +14,7 @@ import torch -from torch import nn +from torch import distributed as dist, nn from torch.nn.modules.module import _IncompatibleKeys from torchrec.pt2.checks import is_non_strict_exporting @@ -72,6 +72,7 @@ def __init__( torch.empty([max_feature_length], device=device), requires_grad=True, ) + self._pre_processed = False self.reset_parameters() @@ -90,10 +91,13 @@ def forward( Returns: JaggedTensor: same as input features with `weights` field being populated. 
""" - - seq = torch.ops.fbgemm.offsets_range( - features.offsets().long(), torch.numel(features.values()) - ) + if self._pre_processed: + # position is embedded as weights + seq = features.weights().clone().to(torch.int64) + else: + seq = torch.ops.fbgemm.offsets_range( + features.offsets().long(), torch.numel(features.values()) + ) weighted_features = JaggedTensor( values=features.values(), lengths=features.lengths(), @@ -102,6 +106,20 @@ def forward( ) return weighted_features + def pre_process_input(self, features: JaggedTensor) -> None: + """ + Args: + features (JaggedTensor]): feature representation + + Returns: + torch.Tensor: position weights + """ + self._pre_processed = True + cat_seq = torch.ops.fbgemm.offsets_range( + features.offsets().long(), torch.numel(features.values()) + ) + features.weights().copy_(cat_seq.to(torch.float32)) + class FeatureProcessorsCollection(nn.Module): """ @@ -169,6 +187,7 @@ def __init__( for length in self.max_feature_lengths.values(): if length <= 0: raise + self._pre_processed = False self.position_weights: nn.ParameterDict = nn.ParameterDict() # needed since nn.ParameterDict isn't torchscriptable (get_items) @@ -203,9 +222,12 @@ def forward(self, features: KeyedJaggedTensor) -> KeyedJaggedTensor: features.offsets().long(), torch.numel(features.values()) ) else: - cat_seq = torch.ops.fbgemm.offsets_range( - features.offsets().long(), torch.numel(features.values()) - ) + if self._pre_processed: + cat_seq = features.weights().clone().to(torch.int64) + else: + cat_seq = torch.ops.fbgemm.offsets_range( + features.offsets().long(), torch.numel(features.values()) + ) return KeyedJaggedTensor( keys=features.keys(), @@ -245,3 +267,10 @@ def load_state_dict( for k, param in self.position_weights.items(): self.position_weights_dict[k] = param return result + + def pre_process_input(self, features: KeyedJaggedTensor) -> None: + self._pre_processed = True + cat_seq = torch.ops.fbgemm.offsets_range( + features.offsets().long(), torch.numel(features.values()) + ) + features.weights().copy_(cat_seq.to(torch.float32))