
Commit c29df68

LinjianMa authored and pytorchmergebot committed
[FSDP] Return original module when fsdp wrapped model call .module (pytorch#78671)
Fixes pytorch#78607

Pull Request resolved: pytorch#78671
Approved by: https://github.com/awgu, https://github.com/rohan-varma
1 parent 1884d7f commit c29df68
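
For context, here is a minimal sketch of what the change means for user code. The Net model, its trunk/head submodules, and the single-process setup are illustrative assumptions, not part of the PR; the point is that .module now returns the original module directly (as DDP does) instead of the internal FlattenParamsWrapper, so the .module.module double hop removed from the tests below is no longer needed.

# Illustrative sketch only; assumes torch.distributed is already initialized
# (e.g. via init_process_group) and that FSDP's other requirements are met.
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

class Net(nn.Module):  # hypothetical user model, not from the PR
    def __init__(self):
        super().__init__()
        self.trunk = nn.Linear(8, 8)
        self.head = nn.Linear(8, 2)

    def forward(self, x):
        return self.head(self.trunk(x))

fsdp_model = FSDP(Net())

# Before this PR, fsdp_model.module was the FlattenParamsWrapper, so the
# original Net was only reachable as fsdp_model.module.module.
# After this PR, .module returns the original Net, matching DDP:
assert isinstance(fsdp_model.module, Net)
for param in fsdp_model.module.trunk.parameters():
    param.grad = None  # e.g. drop the trunk's gradients, as in the updated freezing test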

4 files changed (+27 -29 lines)


test/distributed/fsdp/test_fsdp_freezing_weights.py

Lines changed: 2 additions & 6 deletions
@@ -134,12 +134,8 @@ def _dist_train(
         optimizer.zero_grad()
         fake_loss.backward()
         if freezing_method == FreezingMethod.GradToNone:
-            if with_fsdp:
-                for param in model.module.module.trunk.parameters():
-                    param.grad = None
-            else:
-                for param in model.module.trunk.parameters():
-                    param.grad = None
+            for param in model.module.trunk.parameters():
+                param.grad = None
         optimizer.step()
 
     if with_fsdp:

test/distributed/fsdp/test_fsdp_mixed_precision.py

Lines changed: 1 addition & 1 deletion
@@ -560,7 +560,7 @@ def test_mixed_precision_resnet(self):
         # in original resnet model.
         fsdp_bn = 0
         for module in fsdp.fsdp_modules(fsdp):
-            wrapped_module = module.module.module
+            wrapped_module = module.module
             if isinstance(wrapped_module, _BatchNorm):
                 fsdp_bn += 1

test/distributed/fsdp/test_wrap.py

Lines changed: 2 additions & 2 deletions
@@ -375,9 +375,9 @@ def test_transformer_auto_wrap_policy(self):
             auto_wrap_policy=my_auto_wrap_policy
         )
         self.assertTrue(isinstance(fsdp_model, FSDP))
-        for layer in fsdp_model.module.module.transformer.encoder.layers:
+        for layer in fsdp_model.module.transformer.encoder.layers:
             self.assertTrue(isinstance(layer, FSDP))
-        for layer in fsdp_model.module.module.transformer.decoder.layers:
+        for layer in fsdp_model.module.transformer.decoder.layers:
             self.assertTrue(isinstance(layer, FSDP))
 
     @unittest.skipIf(torch.cuda.device_count() < 2, "Requires at least 2 GPUs")

torch/distributed/fsdp/fully_sharded_data_parallel.py

Lines changed: 22 additions & 20 deletions
@@ -1075,10 +1075,12 @@ def _check_wrapped(cls, begin_module, check_fn, err_fn):
                 raise ValueError(err_fn(mod))
 
     @property
-    def module(self) -> FlattenParamsWrapper:
-        """make model.module accessible, just like DDP."""
+    def module(self) -> nn.Module:
+        """Make model.module accessible, just like DDP. Return the
+        underlying module without the flatten_params_wrapper
+        """
         assert isinstance(self._fsdp_wrapped_module, FlattenParamsWrapper)
-        return self._fsdp_wrapped_module
+        return self._fsdp_wrapped_module.module
 
     def check_is_root(self) -> bool:
         self._lazy_init()
@@ -1433,11 +1435,11 @@ def __getattr__(self, name: str) -> Any:
         try:
             return super().__getattr__(name)  # defer to nn.Module's logic
         except AttributeError:
-            return getattr(self.module, name)
+            return getattr(self._fsdp_wrapped_module, name)
 
     def __getitem__(self, key: int) -> Any:
         """Forward indexing calls in case the module is a nn.Sequential."""
-        return self.module.__getitem__(key)  # type: ignore[operator]
+        return self._fsdp_wrapped_module.__getitem__(key)  # type: ignore[operator]
 
     def _reset_lazy_init(self) -> None:
         """
@@ -1824,14 +1826,14 @@ def _local_post_state_dict_hook(
         will happen. The underlying storage is the same.
         """
         _replace_by_prefix(state_dict, f"{prefix}{FSDP_WRAPPED_MODULE}.", prefix)
-        if self.module.no_params:
+        if self._fsdp_wrapped_module.no_params:
            return state_dict
 
        # state_dict[f"{prefix}{FLAT_PARAM}"] exists and has the same tensor
        # value as the flat_param but it is a pure Tensor because
        # nn.Module.state_dict() will detach the parameter. Therefore, we need
        # to get flat_param from the FlattenParamsWrapper to get the metadata.
-        flat_param = getattr(self.module, FLAT_PARAM, None)
+        flat_param = getattr(self._fsdp_wrapped_module, FLAT_PARAM, None)
         # Construct a ShardedTensor from the flat_param.
         full_numel = flat_param.full_numel
         shard_offset = flat_param.numel() * self.rank
@@ -1858,10 +1860,10 @@ def _sharded_post_state_dict_hook(
         with a unflattened, sharded parameter (a ShardedTensor).
         """
         _replace_by_prefix(state_dict, f"{prefix}{FSDP_WRAPPED_MODULE}.", prefix)
-        if self.module.no_params:
+        if self._fsdp_wrapped_module.no_params:
             return state_dict
 
-        for module_name, _, param_name in self.module.orig_flat_param[0].param_info:
+        for module_name, _, param_name in self._fsdp_wrapped_module.orig_flat_param[0].param_info:
             module_name = module_name.replace(f"{FPW_MODULE}.", "")
             module_name = module_name.replace(f"{FPW_MODULE}", "")
             if module_name:
@@ -1989,8 +1991,8 @@ def state_dict(self, *args, **kwargs):
 
         elif self._state_dict_type == StateDictType.LOCAL_STATE_DICT:
             if (
-                self.module.flat_param is not None and
-                not self.module.flat_param._is_sharded
+                self._fsdp_wrapped_module.flat_param is not None and
+                not self._fsdp_wrapped_module.flat_param._is_sharded
             ):
                 raise RuntimeError(
                     "local_state_dict can only be called "
@@ -2065,8 +2067,8 @@ def _local_pre_load_state_dict_hook(
         _replace_by_prefix(state_dict, prefix, f"{prefix}{FSDP_WRAPPED_MODULE}.")
         fqn = f"{prefix}{FSDP_WRAPPED_MODULE}.{FLAT_PARAM}"
         if fqn not in state_dict:
-            assert getattr(self.module, FLAT_PARAM, None) is None, (
-                "No flat parameter in state_dict but self.module.flat_param is not None"
+            assert getattr(self._fsdp_wrapped_module, FLAT_PARAM, None) is None, (
+                "No flat parameter in state_dict but self._fsdp_wrapped_module.flat_param is not None"
             )
             return
         load_tensor = state_dict[fqn]
@@ -2081,7 +2083,7 @@ def _local_pre_load_state_dict_hook(
 
         # Get the metada of the flat_param to decide whether to pad the loaded
         # tensor.
-        flat_param = self.module.flat_param
+        flat_param = self._fsdp_wrapped_module.flat_param
         assert flat_param is not None
         if flat_param.num_padded not in (0, flat_param.numel()):
             assert load_tensor.numel() < flat_param.numel(), (
@@ -2104,10 +2106,10 @@ def _sharded_pre_load_state_dict_hook(
         a new FlatParameter and shards the new FlatParameter to the local chunk.
         """
         _replace_by_prefix(state_dict, prefix, prefix + f"{FSDP_WRAPPED_MODULE}.")
-        if self.module.no_params:
+        if self._fsdp_wrapped_module.no_params:
             return
 
-        if not self.module.flat_param._is_sharded:
+        if not self._fsdp_wrapped_module.flat_param._is_sharded:
             raise RuntimeError(
                 "load_sharded_state_dict can only be called when parameters "
                 "are flatten and sharded."
@@ -2118,7 +2120,7 @@ def _sharded_pre_load_state_dict_hook(
         # gather all the parameters in this layer. This can be achieved by
         # concatenated all the local shards and then append the padding.
         # https://github.com/pytorch/pytorch/issues/77461
-        for module_name, _, param_name in self.module.flat_param._param_infos:
+        for module_name, _, param_name in self._fsdp_wrapped_module.flat_param._param_infos:
             module_name = module_name.replace(f"{FPW_MODULE}.", "")
             module_name = module_name.replace(f"{FPW_MODULE}", "")
             if module_name:
@@ -2145,7 +2147,7 @@ def _sharded_pre_load_state_dict_hook(
             nonsharded_tensors.append(tensor)
 
         # Create a new flat_param from the loaded, non-sharded tensors.
-        flat_param = self.module.flat_param
+        flat_param = self._fsdp_wrapped_module.flat_param
         loaded_flat_param = FlatParameter(nonsharded_tensors, requires_grad=False)
 
         # Get the chunk from the loaded flat_param for the local rank.
@@ -2293,7 +2295,7 @@ def forward(self, *args: Any, **kwargs: Any) -> Any:
         # These need to be re-registered every forward pass in some cases where grad_fn
         # is mutated.
         self._register_post_backward_hooks()
-        outputs = self.module(*args, **kwargs)
+        outputs = self._fsdp_wrapped_module(*args, **kwargs)
 
         if self not in self._fsdp_graph_order:
             self._my_fsdp_idx_in_graph = len(self._fsdp_graph_order)
@@ -2438,7 +2440,7 @@ def _free_full_params_and_use_local_shard(params_to_free):
             # full parameters.
             with contextlib.ExitStack() as stack:
                 # Invariant: rank == 0 or !rank0_only
-                stack.enter_context(self.module.unflatten_params())
+                stack.enter_context(self._fsdp_wrapped_module.unflatten_params())
                 try:
                     yield
                 finally:
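
In short, the public module property now skips the FlattenParamsWrapper, and every internal call site that needs the flattening machinery (flat_param, no_params, unflatten_params, and the forwarding in __getattr__/__getitem__/forward) switches to self._fsdp_wrapped_module. Below is a minimal sketch of the resulting relationship; it assumes an initialized default process group and uses the FlattenParamsWrapper import path as it existed at the time of this commit.

# Sketch of the wrapping hierarchy after this PR; nn.Linear is just a stand-in
# wrapped module, and torch.distributed is assumed to be initialized.
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.flatten_params_wrapper import FlattenParamsWrapper

wrapped = FSDP(nn.Linear(4, 4))

# Internal handle: the FlattenParamsWrapper that owns flat_param and
# unflatten_params(); the hooks above now reach it via _fsdp_wrapped_module.
assert isinstance(wrapped._fsdp_wrapped_module, FlattenParamsWrapper)

# Public handle: .module unwraps one level further and returns the original
# nn.Linear, mirroring DistributedDataParallel's .module semantics.
assert isinstance(wrapped.module, nn.Linear)
assert wrapped.module is wrapped._fsdp_wrapped_module.module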
