@@ -1075,10 +1075,12 @@ def _check_wrapped(cls, begin_module, check_fn, err_fn):
                 raise ValueError(err_fn(mod))
 
     @property
-    def module(self) -> FlattenParamsWrapper:
-        """make model.module accessible, just like DDP."""
+    def module(self) -> nn.Module:
+        """Make model.module accessible, just like DDP. Return the
+        underlying module without the flatten_params_wrapper
+        """
         assert isinstance(self._fsdp_wrapped_module, FlattenParamsWrapper)
-        return self._fsdp_wrapped_module
+        return self._fsdp_wrapped_module.module
 
     def check_is_root(self) -> bool:
         self._lazy_init()
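The new module property is the user-visible part of this change: it unwraps the FlattenParamsWrapper, so fsdp_model.module returns the original module, mirroring DDP. A minimal sketch of the resulting behavior; the helper below is hypothetical and assumes an already-constructed FSDP instance, which in turn requires an initialized process group.

import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


def assert_unwrapped(fsdp_model: FSDP) -> None:
    # After this change, .module returns the user's original nn.Module
    # (as DistributedDataParallel does), not the internal FlattenParamsWrapper.
    inner = fsdp_model.module
    assert isinstance(inner, nn.Module)
    assert type(inner).__name__ != "FlattenParamsWrapper"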
@@ -1433,11 +1435,11 @@ def __getattr__(self, name: str) -> Any:
         try:
             return super().__getattr__(name)  # defer to nn.Module's logic
         except AttributeError:
-            return getattr(self.module, name)
+            return getattr(self._fsdp_wrapped_module, name)
 
     def __getitem__(self, key: int) -> Any:
         """Forward indexing calls in case the module is a nn.Sequential."""
-        return self.module.__getitem__(key)  # type: ignore[operator]
+        return self._fsdp_wrapped_module.__getitem__(key)  # type: ignore[operator]
 
     def _reset_lazy_init(self) -> None:
         """
@@ -1824,14 +1826,14 @@ def _local_post_state_dict_hook(
         will happen. The underlying storage is the same.
         """
         _replace_by_prefix(state_dict, f"{prefix}{FSDP_WRAPPED_MODULE}.", prefix)
-        if self.module.no_params:
+        if self._fsdp_wrapped_module.no_params:
             return state_dict
 
         # state_dict[f"{prefix}{FLAT_PARAM}"] exists and has the same tensor
         # value as the flat_param but it is a pure Tensor because
         # nn.Module.state_dict() will detach the parameter. Therefore, we need
         # to get flat_param from the FlattenParamsWrapper to get the metadata.
-        flat_param = getattr(self.module, FLAT_PARAM, None)
+        flat_param = getattr(self._fsdp_wrapped_module, FLAT_PARAM, None)
         # Construct a ShardedTensor from the flat_param.
         full_numel = flat_param.full_numel
         shard_offset = flat_param.numel() * self.rank
@@ -1858,10 +1860,10 @@ def _sharded_post_state_dict_hook(
         with a unflattened, sharded parameter (a ShardedTensor).
         """
         _replace_by_prefix(state_dict, f"{prefix}{FSDP_WRAPPED_MODULE}.", prefix)
-        if self.module.no_params:
+        if self._fsdp_wrapped_module.no_params:
             return state_dict
 
-        for module_name, _, param_name in self.module.orig_flat_param[0].param_info:
+        for module_name, _, param_name in self._fsdp_wrapped_module.orig_flat_param[0].param_info:
             module_name = module_name.replace(f"{FPW_MODULE}.", "")
             module_name = module_name.replace(f"{FPW_MODULE}", "")
             if module_name:
@@ -1989,8 +1991,8 @@ def state_dict(self, *args, **kwargs):
 
         elif self._state_dict_type == StateDictType.LOCAL_STATE_DICT:
             if (
-                self.module.flat_param is not None and
-                not self.module.flat_param._is_sharded
+                self._fsdp_wrapped_module.flat_param is not None and
+                not self._fsdp_wrapped_module.flat_param._is_sharded
             ):
                 raise RuntimeError(
                     "local_state_dict can only be called "
@@ -2065,8 +2067,8 @@ def _local_pre_load_state_dict_hook(
         _replace_by_prefix(state_dict, prefix, f"{prefix}{FSDP_WRAPPED_MODULE}.")
         fqn = f"{prefix}{FSDP_WRAPPED_MODULE}.{FLAT_PARAM}"
         if fqn not in state_dict:
-            assert getattr(self.module, FLAT_PARAM, None) is None, (
-                "No flat parameter in state_dict but self.module.flat_param is not None"
+            assert getattr(self._fsdp_wrapped_module, FLAT_PARAM, None) is None, (
+                "No flat parameter in state_dict but self._fsdp_wrapped_module.flat_param is not None"
             )
             return
         load_tensor = state_dict[fqn]
@@ -2081,7 +2083,7 @@ def _local_pre_load_state_dict_hook(
 
         # Get the metada of the flat_param to decide whether to pad the loaded
         # tensor.
-        flat_param = self.module.flat_param
+        flat_param = self._fsdp_wrapped_module.flat_param
         assert flat_param is not None
         if flat_param.num_padded not in (0, flat_param.numel()):
             assert load_tensor.numel() < flat_param.numel(), (
@@ -2104,10 +2106,10 @@ def _sharded_pre_load_state_dict_hook(
         a new FlatParameter and shards the new FlatParameter to the local chunk.
         """
         _replace_by_prefix(state_dict, prefix, prefix + f"{FSDP_WRAPPED_MODULE}.")
-        if self.module.no_params:
+        if self._fsdp_wrapped_module.no_params:
             return
 
-        if not self.module.flat_param._is_sharded:
+        if not self._fsdp_wrapped_module.flat_param._is_sharded:
             raise RuntimeError(
                 "load_sharded_state_dict can only be called when parameters "
                 "are flatten and sharded."
@@ -2118,7 +2120,7 @@ def _sharded_pre_load_state_dict_hook(
         # gather all the parameters in this layer. This can be achieved by
         # concatenated all the local shards and then append the padding.
         # https://github.com/pytorch/pytorch/issues/77461
-        for module_name, _, param_name in self.module.flat_param._param_infos:
+        for module_name, _, param_name in self._fsdp_wrapped_module.flat_param._param_infos:
             module_name = module_name.replace(f"{FPW_MODULE}.", "")
             module_name = module_name.replace(f"{FPW_MODULE}", "")
             if module_name:
@@ -2145,7 +2147,7 @@ def _sharded_pre_load_state_dict_hook(
             nonsharded_tensors.append(tensor)
 
         # Create a new flat_param from the loaded, non-sharded tensors.
-        flat_param = self.module.flat_param
+        flat_param = self._fsdp_wrapped_module.flat_param
         loaded_flat_param = FlatParameter(nonsharded_tensors, requires_grad=False)
 
         # Get the chunk from the loaded flat_param for the local rank.
@@ -2293,7 +2295,7 @@ def forward(self, *args: Any, **kwargs: Any) -> Any:
         # These need to be re-registered every forward pass in some cases where grad_fn
         # is mutated.
         self._register_post_backward_hooks()
-        outputs = self.module(*args, **kwargs)
+        outputs = self._fsdp_wrapped_module(*args, **kwargs)
 
         if self not in self._fsdp_graph_order:
             self._my_fsdp_idx_in_graph = len(self._fsdp_graph_order)
@@ -2438,7 +2440,7 @@ def _free_full_params_and_use_local_shard(params_to_free):
             # full parameters.
             with contextlib.ExitStack() as stack:
                 # Invariant: rank == 0 or !rank0_only
-                stack.enter_context(self.module.unflatten_params())
+                stack.enter_context(self._fsdp_wrapped_module.unflatten_params())
                 try:
                     yield
                 finally: