
Commit 377fba1

iamzainhuda authored and facebook-github-bot committed
rowwise for feature processors (#3606)
Summary: This diff introduces support for row-based sharding types (TWRW, RW, GRID) for feature processors. Previously, feature processors did not support row-based sharding because they are data parallel: splitting the input across row-based shards meant the wrong feature processor weights were accessed. Under column-wise or data-parallel sharding the input is duplicated, so the correct weight is accessed on every rank.

The indices/buckets used by the feature processor were previously calculated after the input split/distribution; to make this compatible with row-based sharding, we now calculate them before the split/distribution, which couples the train pipeline and the feature processors. For each feature, we preprocess the input and place the calculated indices in KJT.weights. This propagates the indices correctly and indexes into the right weight for the final step of feature processing. It applies in both pipelined and non-pipelined setups: the input modification happens either in the pipelined forward call or in the input dist of the FPEBC, as determined by the pipelining flag set via rewrite_model in the train pipeline.

**Previous versions of this diff were reverted because the change applied to all feature processors regardless of whether row-wise sharding was used, which surfaced errors not caught by the usual E2E and unit tests. We now gate the change in two ways: 1) row-based sharding types must be explicitly specified by users for FP sharding, and 2) input pre-processing in the pipeline happens ONLY when row-based sharding is present. FP sharding without row-based sharding therefore keeps the original forward path.**

Differential Revision: D88093763
1 parent cc8df00 commit 377fba1
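To make the KJT.weights mechanism from the summary concrete, here is a minimal sketch, assuming a position-weighted feature processor: per-value position indices are computed on the full KJT before the input split and stashed in KJT.weights, so each row-wise shard selects the correct processor weight after the input dist. The index computation and the `max_feature_length` parameter are illustrative assumptions, not the actual `modify_input_for_feature_processor` implementation.

```python
import torch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

# Two-sample batch with a single feature "f1"; lengths give 2 and 1 values per bag.
kjt = KeyedJaggedTensor.from_lengths_sync(
    keys=["f1"],
    values=torch.tensor([10, 11, 12]),
    lengths=torch.tensor([2, 1]),
)

# Compute each value's position within its bag (0, 1, ... per bag),
# capped at an assumed max_feature_length, before any input split.
max_feature_length = 8
offsets = kjt.offsets()
positions = torch.arange(kjt.values().numel()) - torch.repeat_interleave(
    offsets[:-1], kjt.lengths()
)

# Stash the indices in KJT.weights so they travel with the values through the
# row-wise input dist and index into the right FP weight on each rank.
kjt_with_indices = KeyedJaggedTensor(
    keys=kjt.keys(),
    values=kjt.values(),
    lengths=kjt.lengths(),
    weights=positions.clamp(max=max_feature_length - 1).to(torch.float),
)
```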

File tree

8 files changed: +444 −22 lines changed


torchrec/distributed/fp_embeddingbag.py

Lines changed: 63 additions & 3 deletions
@@ -8,11 +8,21 @@
 # pyre-strict
 
 from functools import partial
-from typing import Any, Dict, Iterator, List, Optional, Type, Union
+from typing import (
+    Any,
+    Dict,
+    Iterator,
+    List,
+    Mapping,
+    Optional,
+    Tuple,
+    Type,
+    TypeVar,
+    Union,
+)
 
 import torch
 from torch import nn
-
 from torchrec.distributed.embedding_types import (
     BaseEmbeddingSharder,
     KJTList,
@@ -31,14 +41,20 @@
     ShardingEnv,
     ShardingType,
 )
-from torchrec.distributed.utils import append_prefix, init_parameters
+from torchrec.distributed.utils import (
+    append_prefix,
+    init_parameters,
+    modify_input_for_feature_processor,
+)
 from torchrec.modules.feature_processor_ import FeatureProcessorsCollection
 from torchrec.modules.fp_embedding_modules import (
     apply_feature_processors_to_kjt,
     FeatureProcessedEmbeddingBagCollection,
 )
 from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor
 
+_T = TypeVar("_T")
+
 
 def param_dp_sync(kt: KeyedTensor, no_op_tensor: torch.Tensor) -> KeyedTensor:
     kt._values.add_(no_op_tensor)
@@ -74,6 +90,16 @@ def __init__(
             )
         )
 
+        self._row_wise_sharded: bool = False
+        for param_sharding in table_name_to_parameter_sharding.values():
+            if param_sharding.sharding_type in [
+                ShardingType.ROW_WISE.value,
+                ShardingType.TABLE_ROW_WISE.value,
+                ShardingType.GRID_SHARD.value,
+            ]:
+                self._row_wise_sharded = True
+                break
+
         self._lookups: List[nn.Module] = self._embedding_bag_collection._lookups
 
         self._is_collection: bool = False
@@ -96,6 +122,14 @@ def __init__(
     def input_dist(
         self, ctx: EmbeddingBagCollectionContext, features: KeyedJaggedTensor
     ) -> Awaitable[Awaitable[KJTList]]:
+        if torch._utils_internal.justknobs_check(
+            "pytorch/torchrec:enable_rw_feature_processor"
+        ):
+            if not self.is_pipelined and self._row_wise_sharded:
+                # transform input to support row based sharding when not pipelined
+                modify_input_for_feature_processor(
+                    features, self._feature_processors, self._is_collection
+                )
         return self._embedding_bag_collection.input_dist(ctx, features)
 
     def apply_feature_processors_to_kjt_list(self, dist_input: KJTList) -> KJTList:
@@ -166,6 +200,21 @@ def sharded_parameter_names(self, prefix: str = "") -> Iterator[str]:
     def _initialize_torch_state(self, skip_registering: bool = False) -> None:  # noqa
         self._embedding_bag_collection._initialize_torch_state(skip_registering)
 
+    def preprocess_input(
+        self, args: List[_T], kwargs: Mapping[str, _T]
+    ) -> Tuple[List[_T], Mapping[str, _T]]:
+        if self._row_wise_sharded:
+            # only if we are row wise sharded we will pre process the KJT
+            # and call the associated preprocess methods in feature processor
+            for x in args + list(kwargs.values()):
+                if isinstance(x, KeyedJaggedTensor):
+                    modify_input_for_feature_processor(
+                        features=x,
+                        feature_processors=self._feature_processors,
+                        is_collection=self._is_collection,
+                    )
+        return args, kwargs
+
 
 class FeatureProcessedEmbeddingBagCollectionSharder(
     BaseEmbeddingSharder[FeatureProcessedEmbeddingBagCollection]
@@ -236,4 +285,15 @@ def sharding_types(self, compute_device_type: str) -> List[str]:
             ShardingType.TABLE_COLUMN_WISE.value,
         ]
 
+        if torch._utils_internal.justknobs_check(
+            "pytorch/torchrec:enable_rw_feature_processor"
+        ):
+            types.extend(
+                [
+                    ShardingType.TABLE_ROW_WISE.value,
+                    ShardingType.ROW_WISE.value,
+                    ShardingType.GRID_SHARD.value,
+                ]
+            )
+
         return types
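With the sharder now advertising row-based sharding types behind the JustKnobs flag, the planner still filters them out for feature-processed modules unless the user opts in explicitly (see the enumerator change below). A minimal sketch of the opt-in, assuming an illustrative table name "fp_table_0":

```python
from torchrec.distributed.planner.types import ParameterConstraints
from torchrec.distributed.types import ShardingType

# Explicitly request row-wise sharding for a feature-processed table; without
# this constraint the enumerator filters ROW_WISE / TABLE_ROW_WISE / GRID_SHARD
# out for FeatureProcessedEmbeddingBagCollection sharders.
constraints = {
    "fp_table_0": ParameterConstraints(
        sharding_types=[ShardingType.ROW_WISE.value],
    ),
}
```

The constraints dict would then be passed to the planner in the usual way (e.g. EmbeddingShardingPlanner(..., constraints=constraints)).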

torchrec/distributed/planner/enumerators.py

Lines changed: 32 additions & 4 deletions
@@ -54,6 +54,15 @@
     EmbeddingComputeKernel.DRAM_VIRTUAL_TABLE,
 }
 
+# sharding types that require explicit user specification for feature-processed modules
+# row wise sharding uses a different pipelined configuration for feature processing
+# to guard against areas that aren't well tested, user must specify row wise sharding
+GUARDED_SHARDING_TYPES_FOR_FP_MODULES: Set[str] = {
+    ShardingType.ROW_WISE.value,
+    ShardingType.TABLE_ROW_WISE.value,
+    ShardingType.GRID_SHARD.value,
+}
+
 
 class EmbeddingEnumerator(Enumerator):
     """
@@ -186,7 +195,7 @@ def enumerate(
                 sharding_options_per_table: List[ShardingOption] = []
 
                 for sharding_type in self._filter_sharding_types(
-                    name, sharder.sharding_types(self._compute_device)
+                    name, sharder.sharding_types(self._compute_device), sharder_key
                 ):
                     for compute_kernel in self._filter_compute_kernels(
                         name,
@@ -308,18 +317,37 @@ def populate_estimates(self, sharding_options: List[ShardingOption]) -> None:
             estimator.estimate(sharding_options, self._sharder_map)
 
     def _filter_sharding_types(
-        self, name: str, allowed_sharding_types: List[str]
+        self, name: str, allowed_sharding_types: List[str], sharder_key: str = ""
    ) -> List[str]:
         # GRID_SHARD is only supported if specified by user in parameter constraints
+        # ROW based shardings only supported for FP modules if specified by user in parameter constraints
         if not self._constraints or not self._constraints.get(name):
-            return [
+            filtered = [
                 t for t in allowed_sharding_types if t != ShardingType.GRID_SHARD.value
             ]
+            # For feature-processed modules, row-wise sharding types require explicit
+            # user specification due to potential issues with position weighted FPs
+            if "FeatureProcessedEmbeddingBagCollection" in sharder_key:
+                filtered = [
+                    t
+                    for t in filtered
+                    if t not in GUARDED_SHARDING_TYPES_FOR_FP_MODULES
+                ]
+            return filtered
         constraints: ParameterConstraints = self._constraints[name]
         if not constraints.sharding_types:
-            return [
+            filtered = [
                 t for t in allowed_sharding_types if t != ShardingType.GRID_SHARD.value
             ]
+            # For feature-processed modules, row-wise sharding types require explicit
+            # user specification due to potential issues with position weighted FPs
+            if "FeatureProcessedEmbeddingBagCollection" in sharder_key:
+                filtered = [
+                    t
+                    for t in filtered
+                    if t not in GUARDED_SHARDING_TYPES_FOR_FP_MODULES
+                ]
+            return filtered
         constrained_sharding_types: List[str] = constraints.sharding_types
 
         filtered_sharding_types = list(

torchrec/distributed/planner/tests/test_enumerators.py

Lines changed: 73 additions & 0 deletions
@@ -1115,3 +1115,76 @@ def test_throw_ex_no_sharding_option_for_table(self) -> None:
             "Module: torchrec.modules.embedding_modules.EmbeddingBagCollection, sharder: CWSharder, compute device: cuda. "
             "To debug, search above for warning logs about no available sharding types/compute kernels for table: table_1",
         )
+
+    def test_filter_sharding_types_fp_ebc_no_constraints(self) -> None:
+        """Test that row-wise sharding types are filtered out for FP modules without constraints."""
+        enumerator = EmbeddingEnumerator(
+            topology=MagicMock(),
+            batch_size=MagicMock(),
+            constraints=None,
+        )
+
+        # Test with FeatureProcessedEmbeddingBagCollection sharder key
+        # Without constraints, row-wise types should be filtered out
+        all_sharding_types = [
+            ShardingType.DATA_PARALLEL.value,
+            ShardingType.TABLE_WISE.value,
+            ShardingType.ROW_WISE.value,
+            ShardingType.TABLE_ROW_WISE.value,
+            ShardingType.COLUMN_WISE.value,
+            ShardingType.GRID_SHARD.value,
+        ]
+
+        allowed_sharding_types = enumerator._filter_sharding_types(
+            "table_0",
+            all_sharding_types,
+            "torchrec.modules.fp_embedding_modules.FeatureProcessedEmbeddingBagCollection",
+        )
+
+        # ROW_WISE, TABLE_ROW_WISE, and GRID_SHARD should be filtered out
+        self.assertEqual(
+            set(allowed_sharding_types),
+            {
+                ShardingType.DATA_PARALLEL.value,
+                ShardingType.TABLE_WISE.value,
+                ShardingType.COLUMN_WISE.value,
+            },
+        )
+
+    def test_filter_sharding_types_fp_ebc_with_row_wise_constraint(self) -> None:
+        """Test that row-wise sharding types are allowed for FP modules with explicit constraints."""
+        constraint = ParameterConstraints(
+            sharding_types=[
+                ShardingType.ROW_WISE.value,
+                ShardingType.TABLE_WISE.value,
+            ],
+        )
+        constraints = {"table_0": constraint}
+        enumerator = EmbeddingEnumerator(
+            topology=MagicMock(),
+            batch_size=MagicMock(),
+            constraints=constraints,
+        )
+
+        all_sharding_types = [
+            ShardingType.DATA_PARALLEL.value,
+            ShardingType.TABLE_WISE.value,
+            ShardingType.ROW_WISE.value,
+            ShardingType.TABLE_ROW_WISE.value,
+            ShardingType.COLUMN_WISE.value,
+        ]
+
+        # With explicit constraint specifying ROW_WISE, it should be allowed
+        allowed_sharding_types = enumerator._filter_sharding_types(
+            "table_0",
+            all_sharding_types,
+            "torchrec.modules.fp_embedding_modules.FeatureProcessedEmbeddingBagCollection",
+        )
+
+        self.assertEqual(
+            set(allowed_sharding_types),
+            {
+                ShardingType.ROW_WISE.value,
+                ShardingType.TABLE_WISE.value,
+            },
+        )
