
Commit 89df447

TroyGarden authored and facebook-github-bot committed
use module object id to cache the sharded modules (#3591)
Summary:

# context
* TorchRec relies on `nn.Module.named_children()` to traverse the model and find the sparse modules to shard.
* In the normal case, every sparse module appears only once in the model hierarchy, i.e., it has **only one** parent module.
* However, in some corner cases a sparse module can have multiple parent modules. This confuses the TorchRec sharder's traversal logic: the same sparse module has multiple FQNs and is therefore sharded multiple times (one sharded module is created per FQN). {F1983896519}

# solution
* Cache the sharded module keyed by the original sparse module's object id.
* When a sparse module has multiple FQNs, a sharded module is created only on the first visit during the `named_children()` traversal; later visits reuse the cached sharded module.

# changes
* The change is protected by a KillSwitch: [enable_module_id_cache_for_dmp_shard_modules](https://www.internalfb.com/intern/justknobs/?name=pytorch%2Ftorchrec#enable_module_id_cache_for_dmp_shard_modules) (see https://fb.workplace.com/groups/429376538334034/permalink/1343336826937996/)

Reviewed By: malaybag, iamzainhuda

Differential Revision: D88218200
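The corner case described above is easy to reproduce with plain PyTorch. The snippet below is an illustration only, not TorchRec code: the `Parent`, `Model`, and `collect_fqns` names are made up. It shares one `nn.Embedding` between two parents and shows that a recursive `named_children()` walk reaches the same object under two different FQNs.

```python
import torch.nn as nn

# One embedding module owned by two parents: a single object, two FQNs.
shared = nn.Embedding(10, 4)


class Parent(nn.Module):
    def __init__(self, emb: nn.Embedding) -> None:
        super().__init__()
        self.emb = emb


class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.p1 = Parent(shared)
        self.p2 = Parent(shared)


def collect_fqns(module: nn.Module, target: nn.Module, path: str = "") -> list:
    # Recursive named_children() walk: it reaches the shared module once per parent.
    found = [path] if module is target else []
    for name, child in module.named_children():
        child_path = f"{path}.{name}" if path else name
        found.extend(collect_fqns(child, target, child_path))
    return found


print(collect_fqns(Model(), shared))  # ['p1.emb', 'p2.emb']
```

An FQN-keyed sharding pass would therefore produce two sharded modules for one eager module, which is exactly the duplication the object-id cache in this commit prevents.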
1 parent 5cf0f0d commit 89df447

2 files changed: 244 additions and 6 deletions


torchrec/distributed/model_parallel.py

Lines changed: 40 additions & 5 deletions
@@ -385,7 +385,13 @@ def copy(
         return copy_dmp
 
     def _init_dmp(self, module: nn.Module) -> nn.Module:
-        return self._shard_modules_impl(module)
+        if torch._utils_internal.justknobs_check(
+            "pytorch/torchrec:enable_module_id_cache_for_dmp_shard_modules"
+        ):
+            module_id_cache: Optional[Dict[int, ShardedModule]] = {}
+        else:
+            module_id_cache = None
+        return self._shard_modules_impl(module, module_id_cache=module_id_cache)
 
     def _init_delta_tracker(
         self, delta_tracker_config: DeltaTrackerConfig, module: nn.Module
@@ -435,28 +441,48 @@ def _shard_modules_impl(
         self,
         module: nn.Module,
         path: str = "",
+        module_id_cache: Optional[Dict[int, ShardedModule]] = None,
     ) -> nn.Module:
         # pre-sharded module
         if isinstance(module, ShardedModule):
             return module
 
+        if module_id_cache is not None:
+            module_id = id(module)
+            if module_id in module_id_cache:
+                """
+                This is likely due to a single sparse module being used in multiple places in the model,
+                which results in multiple FQNs for the same sparse module. The dedup logic is applied to
+                the sharded module, i.e., multiple FQNs will refer to the same sharded module, just as they
+                refer to the same eager-mode sparse module. However, there could be potential issues in
+                other places where the model is traversed via `named_children()`: the same sparse module
+                will be visited multiple times again.
+                """
+                logger.error(
+                    f"Module {path} is already in cache (replaced by sharded module already)"
+                )
+                return module_id_cache[module_id]
+
         # shardable module
         module_sharding_plan = self._plan.get_plan_for_module(path)
         if module_sharding_plan:
             sharder_key = type(module)
-            module = self._sharder_map[sharder_key].shard(
+            sharded_module = self._sharder_map[sharder_key].shard(
                 module,
                 module_sharding_plan,
                 self._env,
                 self.device,
                 path,
             )
-            return module
+            if module_id_cache is not None:
+                module_id_cache[module_id] = sharded_module
+            return sharded_module
 
         for name, child in module.named_children():
             child = self._shard_modules_impl(
                 child,
                 path + "." + name if path else name,
+                module_id_cache,
             )
             setattr(module, name, child)

@@ -1002,12 +1028,18 @@ def _shard_modules_impl(
         self,
         module: nn.Module,
         path: str = "",
+        module_id_cache: Optional[Dict[int, ShardedModule]] = None,
     ) -> nn.Module:
 
         # pre-sharded module
         if isinstance(module, ShardedModule):
             return module
 
+        if module_id_cache is not None:
+            module_id = id(module)
+            if module_id in module_id_cache:
+                return module_id_cache[module_id]
+
         # shardable module
         module_sharding_plan = self._plan.get_plan_for_module(path)
         if module_sharding_plan:
@@ -1027,19 +1059,22 @@ def _shard_modules_impl(
                 )
                 break
 
-            module = self._sharder_map[sharder_key].shard(
+            sharded_module = self._sharder_map[sharder_key].shard(
                 module,
                 module_sharding_plan,
                 env,
                 self.device,
                 path,
             )
-            return module
+            if module_id_cache is not None:
+                module_id_cache[module_id] = sharded_module
+            return sharded_module
 
         for name, child in module.named_children():
             child = self._shard_modules_impl(
                 child,
                 path + "." + name if path else name,
+                module_id_cache,
             )
             setattr(module, name, child)
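
The diff above applies the cache inside TorchRec's `_shard_modules_impl`. As a standalone sketch of the same id-keyed dedup pattern (hedged illustration only: `dedup_replace`, `should_replace`, and `make_replacement` are hypothetical names standing in for the sharding-plan lookup and the `sharder.shard(...)` call, and the replacement here is an arbitrary `nn.EmbeddingBag` rather than a real sharded module):

```python
from typing import Callable, Dict, Optional

import torch.nn as nn


def dedup_replace(
    module: nn.Module,
    should_replace: Callable[[nn.Module], bool],
    make_replacement: Callable[[nn.Module], nn.Module],
    cache: Optional[Dict[int, nn.Module]] = None,
) -> nn.Module:
    # Simplified sketch of the id-keyed cache idea (not the real TorchRec code):
    # a module shared by several parents is replaced once, and every later FQN
    # that reaches it reuses the cached replacement.
    if cache is None:
        cache = {}
    module_id = id(module)
    if module_id in cache:
        return cache[module_id]  # second (or later) FQN of a shared module
    if should_replace(module):
        replacement = make_replacement(module)
        cache[module_id] = replacement
        return replacement
    for name, child in module.named_children():
        setattr(module, name, dedup_replace(child, should_replace, make_replacement, cache))
    return module


# Toy usage: one nn.Embedding with two parents ends up as one shared replacement.
shared = nn.Embedding(10, 4)
root = nn.Module()
root.p1 = nn.Sequential(shared)
root.p2 = nn.Sequential(shared)
dedup_replace(
    root,
    should_replace=lambda m: isinstance(m, nn.Embedding),
    make_replacement=lambda m: nn.EmbeddingBag(10, 4),
)
assert root.p1[0] is root.p2[0]  # both parents point at the same replacement
```

Note that `id()` keys are only meaningful while the original modules stay alive; that holds in the DMP case because the traversal completes before any original module can be garbage collected.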

torchrec/distributed/tests/test_model_parallel_nccl_single_rank.py

Lines changed: 204 additions & 1 deletion
@@ -7,15 +7,218 @@
 
 # pyre-strict
 
+from unittest.mock import patch
+
+import torch
+import torch.nn as nn
+from torchrec.distributed.model_parallel import DistributedModelParallel
 from torchrec.distributed.test_utils.test_model_parallel_base import (
     ModelParallelSparseOnlyBase,
     ModelParallelStateDictBase,
 )
+from torchrec.distributed.types import ShardedModule
+from torchrec.modules.embedding_configs import EmbeddingBagConfig, EmbeddingConfig
+from torchrec.modules.embedding_modules import (
+    EmbeddingBagCollection,
+    EmbeddingCollection,
+)
+from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
 
 
 class ModelParallelStateDictTestNccl(ModelParallelStateDictBase):
     pass
 
 
+class SparseArch(nn.Module):
+    def __init__(
+        self,
+        ebc: EmbeddingBagCollection,
+        ec: EmbeddingCollection,
+    ) -> None:
+        super().__init__()
+        self.ebc = ebc
+        self.ec = ec
+
+    def forward(self, features: KeyedJaggedTensor) -> tuple[torch.Tensor, torch.Tensor]:
+        ebc_out = self.ebc(features)
+        ec_out = self.ec(features)
+        return ebc_out.values(), ec_out.values()
+
+
+# Create a model with two sparse architectures sharing the same modules
+class TwoSparseArchModel(nn.Module):
+    def __init__(
+        self,
+        sparse1: SparseArch,
+        sparse2: SparseArch,
+    ) -> None:
+        super().__init__()
+        # Both architectures share the same EBC and EC instances
+        self.sparse1 = sparse1
+        self.sparse2 = sparse2
+
+    def forward(
+        self, features: KeyedJaggedTensor
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+        ebc1_out, ec1_out = self.sparse1(features)
+        ebc2_out, ec2_out = self.sparse2(features)
+
+        return ebc1_out, ec1_out, ebc2_out, ec2_out
+
+
 class ModelParallelSparseOnlyTestNccl(ModelParallelSparseOnlyBase):
-    pass
+    def test_shared_sparse_module_in_multiple_parents(self) -> None:
+        """
+        Test that the module ID cache correctly handles the same sparse module
+        being used in multiple parent modules. This tests the caching behavior
+        when a single EmbeddingBagCollection and EmbeddingCollection are shared
+        across two different parent sparse architectures.
+        """
+
+        # Setup: Create shared embedding modules that will be reused
+        ebc = EmbeddingBagCollection(
+            device=torch.device("meta"),
+            tables=[
+                EmbeddingBagConfig(
+                    name="ebc_table",
+                    embedding_dim=64,
+                    num_embeddings=100,
+                    feature_names=["ebc_feature"],
+                ),
+            ],
+        )
+        ec = EmbeddingCollection(
+            device=torch.device("meta"),
+            tables=[
+                EmbeddingConfig(
+                    name="ec_table",
+                    embedding_dim=32,
+                    num_embeddings=50,
+                    feature_names=["ec_feature"],
+                ),
+            ],
+        )
+
+        # Create the model with shared modules
+        sparse1 = SparseArch(ebc, ec)
+        sparse2 = SparseArch(ebc, ec)
+        model = TwoSparseArchModel(sparse1, sparse2)
+
+        # Execute: Shard the model with DistributedModelParallel
+        dmp = DistributedModelParallel(model, device=self.device)
+
+        # Assert: Verify that the shared modules are properly handled
+        self.assertIsNotNone(dmp.module)
+
+        # Verify that the same module instances are reused (cached behavior)
+        wrapped_module = dmp.module
+        self.assertIs(
+            wrapped_module.sparse1.ebc,
+            wrapped_module.sparse2.ebc,
+            "ebc1 and ebc2 should be the same sharded instance",
+        )
+        self.assertIs(
+            wrapped_module.sparse1.ec,
+            wrapped_module.sparse2.ec,
+            "ec1 and ec2 should be the same sharded instance",
+        )
+        self.assertIsInstance(
+            wrapped_module.sparse1.ebc,
+            ShardedModule,
+            "ebc1 should be sharded",
+        )
+        self.assertIsInstance(
+            wrapped_module.sparse1.ec,
+            ShardedModule,
+            "ec1 should be sharded",
+        )
+
+    def test_shared_sparse_module_in_multiple_parents_negative(self) -> None:
+        """
+        Test that when module ID caching is disabled (module_id_cache=None),
+        the same module instance gets sharded multiple times, resulting in
+        different sharded instances. This validates the behavior without caching.
+        """
+
+        def mock_init_dmp(
+            self_dmp: DistributedModelParallel, module: nn.Module
+        ) -> nn.Module:
+            """Override _init_dmp to always set module_id_cache to None"""
+            # Call _shard_modules_impl with module_id_cache=None (caching disabled)
+            return self_dmp._shard_modules_impl(module, module_id_cache=None)
+
+        # Setup: Create shared embedding modules that will be reused
+        ebc = EmbeddingBagCollection(
+            device=torch.device("meta"),
+            tables=[
+                EmbeddingBagConfig(
+                    name="ebc_table",
+                    embedding_dim=64,
+                    num_embeddings=100,
+                    feature_names=["ebc_feature"],
+                ),
+            ],
+        )
+        ec = EmbeddingCollection(
+            device=torch.device("meta"),
+            tables=[
+                EmbeddingConfig(
+                    name="ec_table",
+                    embedding_dim=32,
+                    num_embeddings=50,
+                    feature_names=["ec_feature"],
+                ),
+            ],
+        )
+
+        # Create the model with shared modules
+        sparse1 = SparseArch(ebc, ec)
+        sparse2 = SparseArch(ebc, ec)
+        model = TwoSparseArchModel(sparse1, sparse2)
+
+        # Execute: Mock _init_dmp to disable caching, then shard the model
+        with patch.object(
+            DistributedModelParallel,
+            "_init_dmp",
+            mock_init_dmp,
+        ):
+            dmp = DistributedModelParallel(model, device=self.device)
+
+        # Assert: Verify that modules are NOT cached (different instances)
+        self.assertIsNotNone(dmp.module)
+        wrapped_module = dmp.module
+
+        # Without caching, the same module should be sharded twice,
+        # resulting in different sharded instances
+        self.assertIsNot(
+            wrapped_module.sparse1.ebc,
+            wrapped_module.sparse2.ebc,
+            "Without caching, ebc1 and ebc2 should be different sharded instances",
+        )
+        self.assertIsNot(
+            wrapped_module.sparse1.ec,
+            wrapped_module.sparse2.ec,
+            "Without caching, ec1 and ec2 should be different sharded instances",
+        )
+
+        # Both should still be properly sharded, just not cached
+        self.assertIsInstance(
+            wrapped_module.sparse1.ebc,
+            ShardedModule,
+            "ebc1 should be sharded",
+        )
+        self.assertIsInstance(
+            wrapped_module.sparse1.ec,
+            ShardedModule,
+            "ec1 should be sharded",
+        )
+        self.assertIsInstance(
+            wrapped_module.sparse2.ebc,
+            ShardedModule,
+            "ebc2 should be sharded",
+        )
+        self.assertIsInstance(
+            wrapped_module.sparse2.ec,
+            ShardedModule,
+            "ec2 should be sharded",
+        )
