Commit c7557d4

iamzainhuda authored and facebook-github-bot committed
rowwise for feature processors
Summary: In this diff we introduce support for row-based sharding types (TWRW, RW, GRID) for feature processors. Previously, feature processors did not support row-based sharding because feature processor weights are data parallel: splitting the input for row-based shards meant the wrong feature processor weights were accessed. In column/data sharding approaches the data is duplicated, which ensures the correct weight is accessed across ranks. The indices/buckets are normally calculated after the input split/distribution; to make this compatible with row-based sharding we calculate them before the input split/distribution, which couples the train pipeline and feature processors. For each feature, we preprocess the input and place the calculated indices in KJT.weights; this propagates the indices correctly and indexes into the right weight for the final step of feature processing. This applies in both pipelined and non-pipelined situations: the input modification is done either at the pipelined forward call or in the input dist of the FPEBC, determined by the pipelining flag set through rewrite_model in the train pipeline.

**Previous versions of this diff were reverted because the change applied to all feature processors regardless of whether row-wise sharding was used, which surfaced errors not captured by the usual E2E and unit tests. We now gate the change in two ways: 1) row-based sharding types must be explicitly specified by users to be applied for FP sharding, and 2) preprocessing the input in the pipeline happens ONLY when row-based sharding is present. This way, FP sharding without row-based sharding goes through the original forward path.**

Differential Revision: D88093763
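For intuition, the sketch below shows the general idea of stamping per-position indices into KJT.weights before the input is split, so each row-wise shard can still look up the correct feature processor weight after distribution. It is an illustrative assumption only: the helper name precompute_position_indices and the max_feature_length parameter are made up here, and the real logic lives in modify_input_for_feature_processor in torchrec.distributed.utils, which is not shown in this diff.

import torch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

# Illustrative sketch (not the actual TorchRec helper): compute each value's
# position within its bag and carry it in KJT.weights, so the index survives
# the row-wise input split and still points at the right processor weight.
def precompute_position_indices(
    kjt: KeyedJaggedTensor, max_feature_length: int
) -> KeyedJaggedTensor:
    lengths = kjt.lengths()
    positions = (
        torch.cat([torch.arange(int(l), dtype=torch.float32) for l in lengths])
        if lengths.numel() > 0
        else torch.empty(0)
    )
    positions = positions.clamp_(max=max_feature_length - 1)
    return KeyedJaggedTensor(
        keys=kjt.keys(),
        values=kjt.values(),
        lengths=lengths,
        weights=positions,  # indices travel with the values through input dist
    )

features = KeyedJaggedTensor(
    keys=["f1", "f2"],
    values=torch.tensor([10, 11, 12, 20, 21]),
    lengths=torch.tensor([2, 1, 1, 1]),
)
print(precompute_position_indices(features, max_feature_length=8).weights())
# tensor([0., 1., 0., 0., 0.])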
1 parent 5cf0f0d commit c7557d4

File tree

8 files changed: +429 -19 lines changed

torchrec/distributed/fp_embeddingbag.py

Lines changed: 52 additions & 2 deletions

@@ -8,7 +8,18 @@
 # pyre-strict

 from functools import partial
-from typing import Any, Dict, Iterator, List, Optional, Type, Union
+from typing import (
+    Any,
+    Dict,
+    Iterator,
+    List,
+    Mapping,
+    Optional,
+    Tuple,
+    Type,
+    TypeVar,
+    Union,
+)

 import torch
 from torch import nn
@@ -31,14 +42,20 @@
     ShardingEnv,
     ShardingType,
 )
-from torchrec.distributed.utils import append_prefix, init_parameters
+from torchrec.distributed.utils import (
+    append_prefix,
+    init_parameters,
+    modify_input_for_feature_processor,
+)
 from torchrec.modules.feature_processor_ import FeatureProcessorsCollection
 from torchrec.modules.fp_embedding_modules import (
     apply_feature_processors_to_kjt,
     FeatureProcessedEmbeddingBagCollection,
 )
 from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor

+_T = TypeVar("_T")
+

 def param_dp_sync(kt: KeyedTensor, no_op_tensor: torch.Tensor) -> KeyedTensor:
     kt._values.add_(no_op_tensor)
@@ -74,6 +91,16 @@ def __init__(
             )
         )

+        self._row_wise_sharded: bool = False
+        for param_sharding in table_name_to_parameter_sharding.values():
+            if param_sharding.sharding_type in [
+                ShardingType.ROW_WISE.value,
+                ShardingType.TABLE_ROW_WISE.value,
+                ShardingType.GRID_SHARD.value,
+            ]:
+                self._row_wise_sharded = True
+                break
+
         self._lookups: List[nn.Module] = self._embedding_bag_collection._lookups

         self._is_collection: bool = False
@@ -96,6 +123,11 @@ def __init__(
     def input_dist(
         self, ctx: EmbeddingBagCollectionContext, features: KeyedJaggedTensor
     ) -> Awaitable[Awaitable[KJTList]]:
+        if not self.is_pipelined and self._row_wise_sharded:
+            # transform input to support row based sharding when not pipelined
+            modify_input_for_feature_processor(
+                features, self._feature_processors, self._is_collection
+            )
         return self._embedding_bag_collection.input_dist(ctx, features)

     def apply_feature_processors_to_kjt_list(self, dist_input: KJTList) -> KJTList:
@@ -166,6 +198,21 @@ def sharded_parameter_names(self, prefix: str = "") -> Iterator[str]:
     def _initialize_torch_state(self, skip_registering: bool = False) -> None:  # noqa
         self._embedding_bag_collection._initialize_torch_state(skip_registering)

+    def preprocess_input(
+        self, args: List[_T], kwargs: Mapping[str, _T]
+    ) -> Tuple[List[_T], Mapping[str, _T]]:
+        if self._row_wise_sharded:
+            # only if we are row wise sharded we will pre process the KJT
+            # and call the associated preprocess methods in feature processor
+            for x in args + list(kwargs.values()):
+                if isinstance(x, KeyedJaggedTensor):
+                    modify_input_for_feature_processor(
+                        features=x,
+                        feature_processors=self._feature_processors,
+                        is_collection=self._is_collection,
+                    )
+        return args, kwargs
+

 class FeatureProcessedEmbeddingBagCollectionSharder(
     BaseEmbeddingSharder[FeatureProcessedEmbeddingBagCollection]
@@ -234,6 +281,9 @@ def sharding_types(self, compute_device_type: str) -> List[str]:
             ShardingType.TABLE_WISE.value,
             ShardingType.COLUMN_WISE.value,
             ShardingType.TABLE_COLUMN_WISE.value,
+            ShardingType.TABLE_ROW_WISE.value,
+            ShardingType.ROW_WISE.value,
+            ShardingType.GRID_SHARD.value,
         ]

         return types
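A rough sketch of how the two code paths above divide the work. The wiring below is a hypothetical stand-in for the train pipeline (the function name start_input_dist_for_module is made up), not TorchRec's actual pipeline code:

from typing import Any, Dict, List

# Hypothetical pipeline step: when pipelining is enabled, call the sharded
# module's preprocess_input hook on the full, un-split batch before kicking
# off its input dist. The hook is a no-op unless the module is row-wise
# sharded; in the non-pipelined case, input_dist() above does the same
# preprocessing itself, so the indices end up in KJT.weights either way.
def start_input_dist_for_module(
    sharded_module: Any, ctx: Any, args: List[Any], kwargs: Dict[str, Any]
) -> Any:
    if getattr(sharded_module, "is_pipelined", False) and hasattr(
        sharded_module, "preprocess_input"
    ):
        args, kwargs = sharded_module.preprocess_input(args, kwargs)
    return sharded_module.input_dist(ctx, *args, **kwargs)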

torchrec/distributed/planner/enumerators.py

Lines changed: 32 additions & 4 deletions

@@ -54,6 +54,15 @@
     EmbeddingComputeKernel.DRAM_VIRTUAL_TABLE,
 }

+# sharding types that require explicit user specification for feature-processed modules
+# row wise sharding uses a different pipelined configuration for feature processing
+# to guard against areas that aren't well tested, user must specify row wise sharding
+GUARDED_SHARDING_TYPES_FOR_FP_MODULES: Set[str] = {
+    ShardingType.ROW_WISE.value,
+    ShardingType.TABLE_ROW_WISE.value,
+    ShardingType.GRID_SHARD.value,
+}
+

 class EmbeddingEnumerator(Enumerator):
     """
@@ -186,7 +195,7 @@ def enumerate(
         sharding_options_per_table: List[ShardingOption] = []

         for sharding_type in self._filter_sharding_types(
-            name, sharder.sharding_types(self._compute_device)
+            name, sharder.sharding_types(self._compute_device), sharder_key
         ):
             for compute_kernel in self._filter_compute_kernels(
                 name,
@@ -308,18 +317,37 @@ def populate_estimates(self, sharding_options: List[ShardingOption]) -> None:
             estimator.estimate(sharding_options, self._sharder_map)

     def _filter_sharding_types(
-        self, name: str, allowed_sharding_types: List[str]
+        self, name: str, allowed_sharding_types: List[str], sharder_key: str = ""
     ) -> List[str]:
         # GRID_SHARD is only supported if specified by user in parameter constraints
+        # ROW based shardings only supported for FP modules if specified by user in parameter constraints
         if not self._constraints or not self._constraints.get(name):
-            return [
+            filtered = [
                 t for t in allowed_sharding_types if t != ShardingType.GRID_SHARD.value
             ]
+            # For feature-processed modules, row-wise sharding types require explicit
+            # user specification due to potential issues with position weighted FPs
+            if "FeatureProcessedEmbeddingBagCollection" in sharder_key:
+                filtered = [
+                    t
+                    for t in filtered
+                    if t not in GUARDED_SHARDING_TYPES_FOR_FP_MODULES
+                ]
+            return filtered
         constraints: ParameterConstraints = self._constraints[name]
         if not constraints.sharding_types:
-            return [
+            filtered = [
                 t for t in allowed_sharding_types if t != ShardingType.GRID_SHARD.value
            ]
+            # For feature-processed modules, row-wise sharding types require explicit
+            # user specification due to potential issues with position weighted FPs
+            if "FeatureProcessedEmbeddingBagCollection" in sharder_key:
+                filtered = [
+                    t
+                    for t in filtered
+                    if t not in GUARDED_SHARDING_TYPES_FOR_FP_MODULES
+                ]
+            return filtered
         constrained_sharding_types: List[str] = constraints.sharding_types

         filtered_sharding_types = list(
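With this guard in place, the planner only proposes RW/TWRW/GRID for a FeatureProcessedEmbeddingBagCollection table when the user opts in through parameter constraints. A minimal opt-in sketch (the table name fp_table_0 and the planner wiring in the trailing comment are illustrative):

from torchrec.distributed.planner.types import ParameterConstraints
from torchrec.distributed.types import ShardingType

# Explicitly request row-wise sharding for a feature-processed table;
# without this constraint the enumerator filters ROW_WISE, TABLE_ROW_WISE,
# and GRID_SHARD out for FP modules.
constraints = {
    "fp_table_0": ParameterConstraints(
        sharding_types=[ShardingType.ROW_WISE.value],
    ),
}
# Pass `constraints` to the planner, e.g.
# EmbeddingShardingPlanner(topology=topology, constraints=constraints)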

torchrec/distributed/planner/tests/test_enumerators.py

Lines changed: 73 additions & 0 deletions

@@ -1115,3 +1115,76 @@ def test_throw_ex_no_sharding_option_for_table(self) -> None:
             "Module: torchrec.modules.embedding_modules.EmbeddingBagCollection, sharder: CWSharder, compute device: cuda. "
             "To debug, search above for warning logs about no available sharding types/compute kernels for table: table_1",
         )
+
+    def test_filter_sharding_types_fp_ebc_no_constraints(self) -> None:
+        """Test that row-wise sharding types are filtered out for FP modules without constraints."""
+        enumerator = EmbeddingEnumerator(
+            topology=MagicMock(),
+            batch_size=MagicMock(),
+            constraints=None,
+        )
+
+        # Test with FeatureProcessedEmbeddingBagCollection sharder key
+        # Without constraints, row-wise types should be filtered out
+        all_sharding_types = [
+            ShardingType.DATA_PARALLEL.value,
+            ShardingType.TABLE_WISE.value,
+            ShardingType.ROW_WISE.value,
+            ShardingType.TABLE_ROW_WISE.value,
+            ShardingType.COLUMN_WISE.value,
+            ShardingType.GRID_SHARD.value,
+        ]
+
+        allowed_sharding_types = enumerator._filter_sharding_types(
+            "table_0",
+            all_sharding_types,
+            "torchrec.modules.fp_embedding_modules.FeatureProcessedEmbeddingBagCollection",
+        )
+
+        # ROW_WISE, TABLE_ROW_WISE, and GRID_SHARD should be filtered out
+        self.assertEqual(
+            set(allowed_sharding_types),
+            {
+                ShardingType.DATA_PARALLEL.value,
+                ShardingType.TABLE_WISE.value,
+                ShardingType.COLUMN_WISE.value,
+            },
+        )
+
+    def test_filter_sharding_types_fp_ebc_with_row_wise_constraint(self) -> None:
+        """Test that row-wise sharding types are allowed for FP modules with explicit constraints."""
+        constraint = ParameterConstraints(
+            sharding_types=[
+                ShardingType.ROW_WISE.value,
+                ShardingType.TABLE_WISE.value,
+            ],
+        )
+        constraints = {"table_0": constraint}
+        enumerator = EmbeddingEnumerator(
+            topology=MagicMock(),
+            batch_size=MagicMock(),
+            constraints=constraints,
+        )
+
+        all_sharding_types = [
+            ShardingType.DATA_PARALLEL.value,
+            ShardingType.TABLE_WISE.value,
+            ShardingType.ROW_WISE.value,
+            ShardingType.TABLE_ROW_WISE.value,
+            ShardingType.COLUMN_WISE.value,
+        ]
+
+        # With explicit constraint specifying ROW_WISE, it should be allowed
+        allowed_sharding_types = enumerator._filter_sharding_types(
+            "table_0",
+            all_sharding_types,
+            "torchrec.modules.fp_embedding_modules.FeatureProcessedEmbeddingBagCollection",
+        )
+
+        self.assertEqual(
+            set(allowed_sharding_types),
+            {
+                ShardingType.ROW_WISE.value,
+                ShardingType.TABLE_WISE.value,
+            },
+        )
