66 changes: 63 additions & 3 deletions torchrec/distributed/fp_embeddingbag.py
@@ -8,11 +8,21 @@
# pyre-strict

from functools import partial
from typing import Any, Dict, Iterator, List, Optional, Type, Union
from typing import (
Any,
Dict,
Iterator,
List,
Mapping,
Optional,
Tuple,
Type,
TypeVar,
Union,
)

import torch
from torch import nn

from torchrec.distributed.embedding_types import (
BaseEmbeddingSharder,
KJTList,
@@ -31,14 +41,20 @@
ShardingEnv,
ShardingType,
)
from torchrec.distributed.utils import append_prefix, init_parameters
from torchrec.distributed.utils import (
append_prefix,
init_parameters,
modify_input_for_feature_processor,
)
from torchrec.modules.feature_processor_ import FeatureProcessorsCollection
from torchrec.modules.fp_embedding_modules import (
apply_feature_processors_to_kjt,
FeatureProcessedEmbeddingBagCollection,
)
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor

_T = TypeVar("_T")


def param_dp_sync(kt: KeyedTensor, no_op_tensor: torch.Tensor) -> KeyedTensor:
kt._values.add_(no_op_tensor)
@@ -74,6 +90,16 @@ def __init__(
)
)

self._row_wise_sharded: bool = False
for param_sharding in table_name_to_parameter_sharding.values():
if param_sharding.sharding_type in [
ShardingType.ROW_WISE.value,
ShardingType.TABLE_ROW_WISE.value,
ShardingType.GRID_SHARD.value,
]:
self._row_wise_sharded = True
break

self._lookups: List[nn.Module] = self._embedding_bag_collection._lookups

self._is_collection: bool = False
@@ -96,6 +122,14 @@ def __init__(
def input_dist(
self, ctx: EmbeddingBagCollectionContext, features: KeyedJaggedTensor
) -> Awaitable[Awaitable[KJTList]]:
if torch._utils_internal.justknobs_check(
"pytorch/torchrec:enable_rw_feature_processor"
):
if not self.is_pipelined and self._row_wise_sharded:
# transform the input to support row-based sharding when not pipelined
modify_input_for_feature_processor(
features, self._feature_processors, self._is_collection
)
return self._embedding_bag_collection.input_dist(ctx, features)

def apply_feature_processors_to_kjt_list(self, dist_input: KJTList) -> KJTList:
@@ -166,6 +200,21 @@ def sharded_parameter_names(self, prefix: str = "") -> Iterator[str]:
def _initialize_torch_state(self, skip_registering: bool = False) -> None: # noqa
self._embedding_bag_collection._initialize_torch_state(skip_registering)

def preprocess_input(
self, args: List[_T], kwargs: Mapping[str, _T]
) -> Tuple[List[_T], Mapping[str, _T]]:
if self._row_wise_sharded:
# only pre-process the KJT and call the feature processors' preprocess
# methods when the module is row-wise sharded
for x in args + list(kwargs.values()):
if isinstance(x, KeyedJaggedTensor):
modify_input_for_feature_processor(
features=x,
feature_processors=self._feature_processors,
is_collection=self._is_collection,
)
return args, kwargs


class FeatureProcessedEmbeddingBagCollectionSharder(
BaseEmbeddingSharder[FeatureProcessedEmbeddingBagCollection]
@@ -236,4 +285,15 @@ def sharding_types(self, compute_device_type: str) -> List[str]:
ShardingType.TABLE_COLUMN_WISE.value,
]

if torch._utils_internal.justknobs_check(
"pytorch/torchrec:enable_rw_feature_processor"
):
types.extend(
[
ShardingType.TABLE_ROW_WISE.value,
ShardingType.ROW_WISE.value,
ShardingType.GRID_SHARD.value,
]
)

return types
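
Taken together, the changes in this file gate row-based sharding of a FeatureProcessedEmbeddingBagCollection behind the "pytorch/torchrec:enable_rw_feature_processor" killswitch and rewrite the incoming KeyedJaggedTensor for the feature processors: in input_dist on the non-pipelined path, and via the new preprocess_input hook, which a pipeline is expected to call before dispatch. The following standalone sketch (illustrative only, not torchrec code; the function name and string labels are placeholders) summarizes which hook performs the rewrite under which conditions.

# Standalone sketch (illustrative only, not the torchrec implementation):
# which hook rewrites the KJT input for feature processing under row-based
# sharding, per the logic added in this file.
from typing import Dict, Optional

ROW_BASED_SHARDING = {"row_wise", "table_row_wise", "grid_shard"}


def fp_input_rewrite_site(
    table_to_sharding_type: Dict[str, str],
    is_pipelined: bool,
    killswitch_enabled: bool,
) -> Optional[str]:
    """Return the hook that rewrites the input, or None if no rewrite happens."""
    row_wise = any(
        s in ROW_BASED_SHARDING for s in table_to_sharding_type.values()
    )
    if not row_wise:
        return None
    if is_pipelined:
        # A pipeline presumably calls preprocess_input explicitly.
        return "preprocess_input"
    return "input_dist" if killswitch_enabled else None


print(fp_input_rewrite_site({"t0": "row_wise"}, False, True))    # input_dist
print(fp_input_rewrite_site({"t0": "table_wise"}, False, True))  # None
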
36 changes: 32 additions & 4 deletions torchrec/distributed/planner/enumerators.py
@@ -54,6 +54,15 @@
EmbeddingComputeKernel.DRAM_VIRTUAL_TABLE,
}

# Sharding types that require explicit user specification for feature-processed modules.
# Row-wise sharding uses a different pipelined configuration for feature processing, so
# to guard against paths that are not yet well tested, the user must opt in to row-wise
# sharding explicitly via parameter constraints.
GUARDED_SHARDING_TYPES_FOR_FP_MODULES: Set[str] = {
ShardingType.ROW_WISE.value,
ShardingType.TABLE_ROW_WISE.value,
ShardingType.GRID_SHARD.value,
}


class EmbeddingEnumerator(Enumerator):
"""
@@ -186,7 +195,7 @@ def enumerate(
sharding_options_per_table: List[ShardingOption] = []

for sharding_type in self._filter_sharding_types(
name, sharder.sharding_types(self._compute_device)
name, sharder.sharding_types(self._compute_device), sharder_key
):
for compute_kernel in self._filter_compute_kernels(
name,
@@ -308,18 +317,37 @@ def populate_estimates(self, sharding_options: List[ShardingOption]) -> None:
estimator.estimate(sharding_options, self._sharder_map)

def _filter_sharding_types(
self, name: str, allowed_sharding_types: List[str]
self, name: str, allowed_sharding_types: List[str], sharder_key: str = ""
) -> List[str]:
# GRID_SHARD is only supported if specified by the user in parameter constraints
# ROW-based sharding types are only supported for FP modules if specified by the user in parameter constraints
if not self._constraints or not self._constraints.get(name):
return [
filtered = [
t for t in allowed_sharding_types if t != ShardingType.GRID_SHARD.value
]
# For feature-processed modules, row-wise sharding types require explicit
# user specification due to potential issues with position weighted FPs
if "FeatureProcessedEmbeddingBagCollection" in sharder_key:
filtered = [
t
for t in filtered
if t not in GUARDED_SHARDING_TYPES_FOR_FP_MODULES
]
return filtered
constraints: ParameterConstraints = self._constraints[name]
if not constraints.sharding_types:
return [
filtered = [
t for t in allowed_sharding_types if t != ShardingType.GRID_SHARD.value
]
# For feature-processed modules, row-wise sharding types require explicit
# user specification due to potential issues with position weighted FPs
if "FeatureProcessedEmbeddingBagCollection" in sharder_key:
filtered = [
t
for t in filtered
if t not in GUARDED_SHARDING_TYPES_FOR_FP_MODULES
]
return filtered
constrained_sharding_types: List[str] = constraints.sharding_types

filtered_sharding_types = list(
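
The enumerator change above amounts to a pure filtering rule. The following self-contained sketch (illustrative only; the lowercase strings mirror the ShardingType enum values, and the function name is a placeholder) shows the behavior with no per-table constraint versus an explicit user constraint.

# Illustrative sketch of the guarded filtering rule (not the torchrec code).
from typing import List, Optional, Set

GUARDED: Set[str] = {"row_wise", "table_row_wise", "grid_shard"}


def filter_for_fp_module(
    allowed: List[str],
    is_fp_module: bool,
    constrained: Optional[List[str]] = None,
) -> List[str]:
    if constrained:
        # Explicit per-table constraints take precedence.
        return [t for t in allowed if t in constrained]
    # Without constraints, grid shard is always dropped ...
    filtered = [t for t in allowed if t != "grid_shard"]
    if is_fp_module:
        # ... and for feature-processed modules the row-based types are too.
        filtered = [t for t in filtered if t not in GUARDED]
    return filtered


print(filter_for_fp_module(["table_wise", "row_wise", "grid_shard"], True))
# ['table_wise']
print(filter_for_fp_module(["table_wise", "row_wise"], True, ["row_wise"]))
# ['row_wise']
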
73 changes: 73 additions & 0 deletions torchrec/distributed/planner/tests/test_enumerators.py
@@ -1115,3 +1115,76 @@ def test_throw_ex_no_sharding_option_for_table(self) -> None:
"Module: torchrec.modules.embedding_modules.EmbeddingBagCollection, sharder: CWSharder, compute device: cuda. "
"To debug, search above for warning logs about no available sharding types/compute kernels for table: table_1",
)

def test_filter_sharding_types_fp_ebc_no_constraints(self) -> None:
"""Test that row-wise sharding types are filtered out for FP modules without constraints."""
enumerator = EmbeddingEnumerator(
topology=MagicMock(),
batch_size=MagicMock(),
constraints=None,
)

# Test with FeatureProcessedEmbeddingBagCollection sharder key
# Without constraints, row-wise types should be filtered out
all_sharding_types = [
ShardingType.DATA_PARALLEL.value,
ShardingType.TABLE_WISE.value,
ShardingType.ROW_WISE.value,
ShardingType.TABLE_ROW_WISE.value,
ShardingType.COLUMN_WISE.value,
ShardingType.GRID_SHARD.value,
]

allowed_sharding_types = enumerator._filter_sharding_types(
"table_0",
all_sharding_types,
"torchrec.modules.fp_embedding_modules.FeatureProcessedEmbeddingBagCollection",
)

# ROW_WISE, TABLE_ROW_WISE, and GRID_SHARD should be filtered out
self.assertEqual(
set(allowed_sharding_types),
{
ShardingType.DATA_PARALLEL.value,
ShardingType.TABLE_WISE.value,
ShardingType.COLUMN_WISE.value,
},
)

def test_filter_sharding_types_fp_ebc_with_row_wise_constraint(self) -> None:
"""Test that row-wise sharding types are allowed for FP modules with explicit constraints."""
constraint = ParameterConstraints(
sharding_types=[
ShardingType.ROW_WISE.value,
ShardingType.TABLE_WISE.value,
],
)
constraints = {"table_0": constraint}
enumerator = EmbeddingEnumerator(
topology=MagicMock(),
batch_size=MagicMock(),
constraints=constraints,
)

all_sharding_types = [
ShardingType.DATA_PARALLEL.value,
ShardingType.TABLE_WISE.value,
ShardingType.ROW_WISE.value,
ShardingType.TABLE_ROW_WISE.value,
ShardingType.COLUMN_WISE.value,
]

# With explicit constraint specifying ROW_WISE, it should be allowed
allowed_sharding_types = enumerator._filter_sharding_types(
"table_0",
all_sharding_types,
"torchrec.modules.fp_embedding_modules.FeatureProcessedEmbeddingBagCollection",
)

self.assertEqual(
set(allowed_sharding_types),
{
ShardingType.ROW_WISE.value,
ShardingType.TABLE_WISE.value,
},
)
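
The tests above mirror how a user opts a feature-processed table into row-wise sharding in practice: pass a ParameterConstraints entry for that table to the planner. A hedged sketch follows (the table name "table_0", world size, and compute device are placeholders, and the model/sharder setup is omitted).

from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
from torchrec.distributed.planner.types import ParameterConstraints
from torchrec.distributed.types import ShardingType

constraints = {
    "table_0": ParameterConstraints(
        sharding_types=[ShardingType.ROW_WISE.value],
    ),
}

planner = EmbeddingShardingPlanner(
    topology=Topology(world_size=2, compute_device="cuda"),
    constraints=constraints,
)
# planner.plan(model, sharders) would then be allowed to consider ROW_WISE
# placements for table_0 even though it belongs to a
# FeatureProcessedEmbeddingBagCollection (assuming the sharder advertises
# row-based sharding types, i.e. the killswitch above is enabled).
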