Commit e5815fc

iamzainhuda authored and facebook-github-bot committed
rowwise for feature processors (#3606)
Summary: In this diff we introduce row-based sharding (TWRW, RW, GRID) support for feature processors. Previously, feature processors did not support row-based sharding because they are data parallel: when the input is split for row-based shards, the feature processor weights that get accessed are incorrect. In column-wise/data-parallel sharding the data is duplicated, which ensures the correct weight is accessed on every rank.

The indices/buckets are normally calculated after the input split/distribution; to make this compatible with row-based sharding, we calculate them before the input split/distribution. This couples the train pipeline and feature processors. For each feature, we preprocess the input and place the calculated indices in KJT.weights, which propagates the indices correctly and indexes into the right weight for the final step of feature processing. This applies in both pipelined and non-pipelined situations: the input modification is done either at the pipelined forward call or in the input dist of the FPEBC, as determined by the pipelining flag set through rewrite_model in the train pipeline.

**Previous versions of this diff were reverted because the change applied to all feature processors regardless of whether row-wise sharding was used, which surfaced errors not caught by the usual E2E and unit tests. We now gate the change in two ways: 1) row-based sharding types must be explicitly specified by users to be applied for FP sharding, and 2) pre-processing the input in the pipeline will ONLY happen when row-based sharding is present. This way, FP sharding without row-based sharding goes through the original forward path.**

Differential Revision: D88093763
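To make the mechanism described in the summary concrete, the sketch below shows how per-value position indices could be attached to a KeyedJaggedTensor as weights before the input is split across ranks. It is an illustrative sketch only: the helper name attach_position_indices and the max_feature_length parameter are hypothetical, and the real logic lives in torchrec's modify_input_for_feature_processor (used in the fp_embeddingbag.py diff below), whose implementation is not part of this commit.

```python
# Illustrative sketch only -- NOT torchrec's modify_input_for_feature_processor.
# Idea from the summary: compute per-value position indices BEFORE the input is
# split across ranks and carry them in KJT.weights, so each row-wise shard
# indexes into the correct feature-processor weight.
import torch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor


def attach_position_indices(
    features: KeyedJaggedTensor, max_feature_length: int
) -> KeyedJaggedTensor:
    lengths = features.lengths()
    # position of each value within its bag: 0, 1, ..., len - 1 per (key, sample)
    positions = (
        torch.cat([torch.arange(int(length)) for length in lengths])
        if lengths.numel() > 0
        else torch.empty(0, dtype=torch.long)
    )
    # clamp to the size of the (hypothetical) position-weight table
    positions = torch.clamp(positions, max=max_feature_length - 1)
    # rebuild the KJT with the indices stored as weights so that they travel
    # with the values through input dist / the train pipeline
    return KeyedJaggedTensor(
        keys=features.keys(),
        values=features.values(),
        lengths=lengths,
        weights=positions.to(torch.float32),
    )
```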
1 parent 5cf0f0d commit e5815fc

File tree: 8 files changed, +440 −22 lines


torchrec/distributed/fp_embeddingbag.py

Lines changed: 60 additions & 3 deletions
@@ -8,11 +8,22 @@
 # pyre-strict

 from functools import partial
-from typing import Any, Dict, Iterator, List, Optional, Type, Union
+from typing import (
+    Any,
+    Dict,
+    Iterator,
+    List,
+    Mapping,
+    Optional,
+    Tuple,
+    Type,
+    TypeVar,
+    Union,
+)

+import pyjk as justknobs
 import torch
 from torch import nn
-
 from torchrec.distributed.embedding_types import (
     BaseEmbeddingSharder,
     KJTList,
@@ -31,14 +42,20 @@
     ShardingEnv,
     ShardingType,
 )
-from torchrec.distributed.utils import append_prefix, init_parameters
+from torchrec.distributed.utils import (
+    append_prefix,
+    init_parameters,
+    modify_input_for_feature_processor,
+)
 from torchrec.modules.feature_processor_ import FeatureProcessorsCollection
 from torchrec.modules.fp_embedding_modules import (
     apply_feature_processors_to_kjt,
     FeatureProcessedEmbeddingBagCollection,
 )
 from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor

+_T = TypeVar("_T")
+

 def param_dp_sync(kt: KeyedTensor, no_op_tensor: torch.Tensor) -> KeyedTensor:
     kt._values.add_(no_op_tensor)
@@ -74,6 +91,16 @@ def __init__(
             )
         )

+        self._row_wise_sharded: bool = False
+        for param_sharding in table_name_to_parameter_sharding.values():
+            if param_sharding.sharding_type in [
+                ShardingType.ROW_WISE.value,
+                ShardingType.TABLE_ROW_WISE.value,
+                ShardingType.GRID_SHARD.value,
+            ]:
+                self._row_wise_sharded = True
+                break
+
         self._lookups: List[nn.Module] = self._embedding_bag_collection._lookups

         self._is_collection: bool = False
@@ -96,6 +123,12 @@ def __init__(
     def input_dist(
         self, ctx: EmbeddingBagCollectionContext, features: KeyedJaggedTensor
     ) -> Awaitable[Awaitable[KJTList]]:
+        if justknobs.check("pytorch/torchrec:enable_rw_feature_processor"):
+            if not self.is_pipelined and self._row_wise_sharded:
+                # transform input to support row based sharding when not pipelined
+                modify_input_for_feature_processor(
+                    features, self._feature_processors, self._is_collection
+                )
         return self._embedding_bag_collection.input_dist(ctx, features)

     def apply_feature_processors_to_kjt_list(self, dist_input: KJTList) -> KJTList:
@@ -166,6 +199,21 @@ def sharded_parameter_names(self, prefix: str = "") -> Iterator[str]:
     def _initialize_torch_state(self, skip_registering: bool = False) -> None:  # noqa
         self._embedding_bag_collection._initialize_torch_state(skip_registering)

+    def preprocess_input(
+        self, args: List[_T], kwargs: Mapping[str, _T]
+    ) -> Tuple[List[_T], Mapping[str, _T]]:
+        if self._row_wise_sharded:
+            # only if we are row wise sharded we will pre process the KJT
+            # and call the associated preprocess methods in feature processor
+            for x in args + list(kwargs.values()):
+                if isinstance(x, KeyedJaggedTensor):
+                    modify_input_for_feature_processor(
+                        features=x,
+                        feature_processors=self._feature_processors,
+                        is_collection=self._is_collection,
+                    )
+        return args, kwargs
+

 class FeatureProcessedEmbeddingBagCollectionSharder(
     BaseEmbeddingSharder[FeatureProcessedEmbeddingBagCollection]
@@ -236,4 +284,13 @@ def sharding_types(self, compute_device_type: str) -> List[str]:
             ShardingType.TABLE_COLUMN_WISE.value,
         ]

+        if justknobs.check("pytorch/torchrec:enable_rw_feature_processor"):
+            types.extend(
+                [
+                    ShardingType.TABLE_ROW_WISE.value,
+                    ShardingType.ROW_WISE.value,
+                    ShardingType.GRID_SHARD.value,
+                ]
+            )
+
         return types
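The summary notes that the input modification runs either at the pipelined forward call or inside input_dist, depending on the pipelining flag set through rewrite_model. A hedged sketch of that assumed dispatch follows; maybe_preprocess_input is a hypothetical wrapper, and the actual call site lives in torchrec's train pipeline, not in this diff.

```python
# Assumed call pattern, for illustration only; the real dispatch is performed by
# torchrec's train pipeline (the pipelining flag is set via rewrite_model).
from typing import Any, Dict, List, Tuple


def maybe_preprocess_input(
    sharded_fp_ebc: Any, args: List[Any], kwargs: Dict[str, Any]
) -> Tuple[List[Any], Dict[str, Any]]:
    if getattr(sharded_fp_ebc, "is_pipelined", False):
        # pipelined path: attach the indices to the KJT before the input is
        # split, via the preprocess_input hook added above (a no-op unless
        # row-wise sharding is present)
        return sharded_fp_ebc.preprocess_input(args, kwargs)
    # non-pipelined path: nothing to do here; input_dist itself calls
    # modify_input_for_feature_processor when row-wise sharding is present
    return args, kwargs
```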

torchrec/distributed/planner/enumerators.py

Lines changed: 32 additions & 4 deletions
@@ -54,6 +54,15 @@
     EmbeddingComputeKernel.DRAM_VIRTUAL_TABLE,
 }

+# sharding types that require explicit user specification for feature-processed modules
+# row wise sharding uses a different pipelined configuration for feature processing
+# to guard against areas that aren't well tested, user must specify row wise sharding
+GUARDED_SHARDING_TYPES_FOR_FP_MODULES: Set[str] = {
+    ShardingType.ROW_WISE.value,
+    ShardingType.TABLE_ROW_WISE.value,
+    ShardingType.GRID_SHARD.value,
+}
+

 class EmbeddingEnumerator(Enumerator):
     """
@@ -186,7 +195,7 @@ def enumerate(
                 sharding_options_per_table: List[ShardingOption] = []

                 for sharding_type in self._filter_sharding_types(
-                    name, sharder.sharding_types(self._compute_device)
+                    name, sharder.sharding_types(self._compute_device), sharder_key
                 ):
                     for compute_kernel in self._filter_compute_kernels(
                         name,
@@ -308,18 +317,37 @@ def populate_estimates(self, sharding_options: List[ShardingOption]) -> None:
             estimator.estimate(sharding_options, self._sharder_map)

     def _filter_sharding_types(
-        self, name: str, allowed_sharding_types: List[str]
+        self, name: str, allowed_sharding_types: List[str], sharder_key: str = ""
     ) -> List[str]:
         # GRID_SHARD is only supported if specified by user in parameter constraints
+        # ROW based shardings only supported for FP modules if specified by user in parameter constraints
         if not self._constraints or not self._constraints.get(name):
-            return [
+            filtered = [
                 t for t in allowed_sharding_types if t != ShardingType.GRID_SHARD.value
             ]
+            # For feature-processed modules, row-wise sharding types require explicit
+            # user specification due to potential issues with position weighted FPs
+            if "FeatureProcessedEmbeddingBagCollection" in sharder_key:
+                filtered = [
+                    t
+                    for t in filtered
+                    if t not in GUARDED_SHARDING_TYPES_FOR_FP_MODULES
+                ]
+            return filtered
         constraints: ParameterConstraints = self._constraints[name]
         if not constraints.sharding_types:
-            return [
+            filtered = [
                 t for t in allowed_sharding_types if t != ShardingType.GRID_SHARD.value
             ]
+            # For feature-processed modules, row-wise sharding types require explicit
+            # user specification due to potential issues with position weighted FPs
+            if "FeatureProcessedEmbeddingBagCollection" in sharder_key:
+                filtered = [
+                    t
+                    for t in filtered
+                    if t not in GUARDED_SHARDING_TYPES_FOR_FP_MODULES
+                ]
+            return filtered
         constrained_sharding_types: List[str] = constraints.sharding_types

         filtered_sharding_types = list(
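Because the enumerator now guards row-based sharding types for feature-processed modules, a user has to opt in through parameter constraints (and the enable_rw_feature_processor knob must be on for the sharder to offer these types). A minimal sketch of that opt-in, where the table name "fp_table_0" and the world size are placeholders:

```python
# Sketch of the user-side opt-in assumed by this change: row-wise sharding for
# a FeatureProcessedEmbeddingBagCollection table must be requested explicitly
# through ParameterConstraints, otherwise the enumerator filters it out.
from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
from torchrec.distributed.planner.types import ParameterConstraints
from torchrec.distributed.types import ShardingType

constraints = {
    "fp_table_0": ParameterConstraints(  # placeholder table name
        sharding_types=[ShardingType.ROW_WISE.value],
    ),
}

planner = EmbeddingShardingPlanner(
    topology=Topology(world_size=2, compute_device="cuda"),
    constraints=constraints,
)
# planner.plan(module, sharders) would then be allowed to pick ROW_WISE for
# "fp_table_0"; sharders would include the
# FeatureProcessedEmbeddingBagCollectionSharder from fp_embeddingbag.py above.
```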

torchrec/distributed/planner/tests/test_enumerators.py

Lines changed: 73 additions & 0 deletions
@@ -1115,3 +1115,76 @@ def test_throw_ex_no_sharding_option_for_table(self) -> None:
             "Module: torchrec.modules.embedding_modules.EmbeddingBagCollection, sharder: CWSharder, compute device: cuda. "
             "To debug, search above for warning logs about no available sharding types/compute kernels for table: table_1",
         )
+
+    def test_filter_sharding_types_fp_ebc_no_constraints(self) -> None:
+        """Test that row-wise sharding types are filtered out for FP modules without constraints."""
+        enumerator = EmbeddingEnumerator(
+            topology=MagicMock(),
+            batch_size=MagicMock(),
+            constraints=None,
+        )
+
+        # Test with FeatureProcessedEmbeddingBagCollection sharder key
+        # Without constraints, row-wise types should be filtered out
+        all_sharding_types = [
+            ShardingType.DATA_PARALLEL.value,
+            ShardingType.TABLE_WISE.value,
+            ShardingType.ROW_WISE.value,
+            ShardingType.TABLE_ROW_WISE.value,
+            ShardingType.COLUMN_WISE.value,
+            ShardingType.GRID_SHARD.value,
+        ]
+
+        allowed_sharding_types = enumerator._filter_sharding_types(
+            "table_0",
+            all_sharding_types,
+            "torchrec.modules.fp_embedding_modules.FeatureProcessedEmbeddingBagCollection",
+        )
+
+        # ROW_WISE, TABLE_ROW_WISE, and GRID_SHARD should be filtered out
+        self.assertEqual(
+            set(allowed_sharding_types),
+            {
+                ShardingType.DATA_PARALLEL.value,
+                ShardingType.TABLE_WISE.value,
+                ShardingType.COLUMN_WISE.value,
+            },
+        )
+
+    def test_filter_sharding_types_fp_ebc_with_row_wise_constraint(self) -> None:
+        """Test that row-wise sharding types are allowed for FP modules with explicit constraints."""
+        constraint = ParameterConstraints(
+            sharding_types=[
+                ShardingType.ROW_WISE.value,
+                ShardingType.TABLE_WISE.value,
+            ],
+        )
+        constraints = {"table_0": constraint}
+        enumerator = EmbeddingEnumerator(
+            topology=MagicMock(),
+            batch_size=MagicMock(),
+            constraints=constraints,
+        )
+
+        all_sharding_types = [
+            ShardingType.DATA_PARALLEL.value,
+            ShardingType.TABLE_WISE.value,
+            ShardingType.ROW_WISE.value,
+            ShardingType.TABLE_ROW_WISE.value,
+            ShardingType.COLUMN_WISE.value,
+        ]
+
+        # With explicit constraint specifying ROW_WISE, it should be allowed
+        allowed_sharding_types = enumerator._filter_sharding_types(
+            "table_0",
+            all_sharding_types,
+            "torchrec.modules.fp_embedding_modules.FeatureProcessedEmbeddingBagCollection",
+        )
+
+        self.assertEqual(
+            set(allowed_sharding_types),
+            {
+                ShardingType.ROW_WISE.value,
+                ShardingType.TABLE_WISE.value,
+            },
+        )
