from ._manipulate import checkpoint

__all__ = [
-    'FeatureInfo', 'FeatureHooks', 'FeatureDictNet', 'FeatureListNet', 'FeatureHookNet', 'FeatureGetterNet',
-    'feature_take_indices'
+    'FeatureInfo', 'FeatureHooks', 'FeatureBase', 'FeatureDictNet', 'FeatureListNet', 'FeatureHookNet',
+    'FeatureGetterNet', 'feature_take_indices'
]


@@ -227,7 +227,59 @@ def _get_return_layers(feature_info, out_map):
    return return_layers


-class FeatureDictNet(nn.ModuleDict):
+class FeatureBase(nn.Module):
+    """ Base class for feature extraction wrappers.
+
+    Provides a dict-like interface without inheriting from nn.ModuleDict, to avoid FSDP2 issues:
+    FSDP2's fully_shard has isinstance checks for (ModuleDict, ModuleList) that cause problems.
+
+    This class delegates dict operations to the underlying _modules OrderedDict.
+    """
+
+    def __init__(self):
+        super().__init__()
+        self.feature_info: Optional[FeatureInfo] = None
+        self.output_fmt: Optional[Format] = None
+        self.grad_checkpointing = False
+
+    def set_grad_checkpointing(self, enable: bool = True):
+        self.grad_checkpointing = enable
+
+    # Dict-like interface methods
+    def __getitem__(self, key: str) -> nn.Module:
+        return self._modules[key]
+
+    def __setitem__(self, key: str, module: nn.Module) -> None:
+        self.add_module(key, module)
+
+    def __delitem__(self, key: str) -> None:
+        del self._modules[key]
+
+    def __len__(self) -> int:
+        return len(self._modules)
+
+    def __iter__(self):
+        return iter(self._modules)
+
+    def __contains__(self, key: str) -> bool:
+        return key in self._modules
+
+    def keys(self):
+        return self._modules.keys()
+
+    def values(self):
+        return self._modules.values()
+
+    def items(self):
+        return self._modules.items()
+
+    def update(self, modules: Dict[str, nn.Module]) -> None:
+        """Update _modules with new modules."""
+        for key, module in modules.items():
+            self.add_module(key, module)
+
+
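Note: a minimal sketch (illustrative names, not the timm source) of why this delegation works. Dict semantics come from `nn.Module._modules`, so the wrapper still registers and exposes sub-modules like a dict, but isinstance checks against `(ModuleDict, ModuleList)`, such as the one in FSDP2's `fully_shard`, no longer match:

```python
import torch.nn as nn

class DictLike(nn.Module):
    # Illustrative reduction of the FeatureBase idea: dict semantics come from
    # nn.Module._modules, not from nn.ModuleDict inheritance.
    def __setitem__(self, key: str, module: nn.Module) -> None:
        self.add_module(key, module)  # registers params/buffers as usual

    def __getitem__(self, key: str) -> nn.Module:
        return self._modules[key]

    def items(self):
        return self._modules.items()

net = DictLike()
net['stem'] = nn.Conv2d(3, 32, 3, stride=2)
net['act'] = nn.ReLU()
assert not isinstance(net, (nn.ModuleDict, nn.ModuleList))   # FSDP2's check no longer matches
assert [k for k, _ in net.items()] == ['stem', 'act']
assert sum(p.numel() for p in net.parameters()) > 0          # params still registered
```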
+class FeatureDictNet(FeatureBase):
    """ Feature extractor with OrderedDict return

    Wrap a model and extract features as specified by the out indices, the network is
@@ -264,7 +316,6 @@ def __init__(
        self.feature_info = _get_feature_info(model, out_indices)
        self.output_fmt = Format(output_fmt)
        self.concat = feature_concat
-        self.grad_checkpointing = False
        self.return_layers = {}

        return_layers = _get_return_layers(self.feature_info, out_map)
@@ -283,9 +334,6 @@ def __init__(
            f'Return layers ({remaining}) are not present in model'
        self.update(layers)

-    def set_grad_checkpointing(self, enable: bool = True):
-        self.grad_checkpointing = enable
-
    def _collect(self, x) -> (Dict[str, torch.Tensor]):
        out = OrderedDict()
        for i, (name, module) in enumerate(self.items()):
@@ -345,7 +393,7 @@ def forward(self, x) -> (List[torch.Tensor]):
        return list(self._collect(x).values())


-class FeatureHookNet(nn.ModuleDict):
+class FeatureHookNet(FeatureBase):
    """ FeatureHookNet

    Wrap a model and extract features specified by the out indices using forward/forward-pre hooks.
@@ -386,7 +434,6 @@ def __init__(
        self.feature_info = _get_feature_info(model, out_indices)
        self.return_dict = return_dict
        self.output_fmt = Format(output_fmt)
-        self.grad_checkpointing = False
        if no_rewrite is None:
            no_rewrite = not flatten_sequential
        layers = OrderedDict()
@@ -415,9 +462,6 @@ def __init__(
        self.update(layers)
        self.hooks = FeatureHooks(hooks, model.named_modules(), out_map=out_map)

-    def set_grad_checkpointing(self, enable: bool = True):
-        self.grad_checkpointing = enable
-
    def forward(self, x):
        for i, (name, module) in enumerate(self.items()):
            if self.grad_checkpointing and not torch.jit.is_scripting():
@@ -432,7 +476,7 @@ def forward(self, x):
        return out if self.return_dict else list(out.values())


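Note: the `grad_checkpointing` flag consumed in the forward loops above follows the usual per-submodule recompute pattern. A generic sketch, assuming `torch.utils.checkpoint` rather than timm's internal `checkpoint` helper, with a hypothetical function name:

```python
import torch
from torch.utils.checkpoint import checkpoint

def forward_with_checkpointing(modules, x, grad_checkpointing: bool = True):
    # Each submodule's activations are recomputed during backward instead of
    # being stored, trading compute for memory. Mirrors the forward loops above.
    for module in modules:
        if grad_checkpointing and not torch.jit.is_scripting():
            x = checkpoint(module, x, use_reentrant=False)
        else:
            x = module(x)
    return x
```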
-class FeatureGetterNet(nn.ModuleDict):
+class FeatureGetterNet(FeatureBase):
    """ FeatureGetterNet

    Wrap models with a feature getter method, like 'get_intermediate_layers'
@@ -472,6 +516,10 @@ def __init__(
        self.output_fmt = Format(output_fmt)
        self.norm = norm

+    def set_grad_checkpointing(self, enable: bool = True):
+        self.grad_checkpointing = enable
+        self.model.set_grad_checkpointing(enable=enable)
+
    def forward(self, x):
        features = self.model.forward_intermediates(
            x,
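Note: a rough usage sketch of the net effect (model name illustrative; assumes `features_only=True` returns one of the wrappers above). The extractor is now a plain `nn.Module` subclass rather than an `nn.ModuleDict`, `set_grad_checkpointing` comes from `FeatureBase` on all wrappers, and `FeatureGetterNet` additionally forwards it to the wrapped model:

```python
import timm
import torch

model = timm.create_model('resnet50', features_only=True, out_indices=(1, 2, 3))
assert not isinstance(model, torch.nn.ModuleDict)   # now FeatureBase -> nn.Module
model.set_grad_checkpointing(True)                  # provided by FeatureBase

feats = model(torch.randn(1, 3, 224, 224))
print(model.feature_info.channels())                # e.g. [256, 512, 1024] for resnet50
print([f.shape for f in feats])
```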