Commit bd27208

Create own base class for Feature*Net extraction wrappers to avoid FSDP issues.
1 parent ae4d1bb commit bd27208

File tree

1 file changed: +61, -13 lines


timm/models/_features.py

Lines changed: 61 additions & 13 deletions
@@ -20,8 +20,8 @@
 from ._manipulate import checkpoint

 __all__ = [
-    'FeatureInfo', 'FeatureHooks', 'FeatureDictNet', 'FeatureListNet', 'FeatureHookNet', 'FeatureGetterNet',
-    'feature_take_indices'
+    'FeatureInfo', 'FeatureHooks', 'FeatureBase', 'FeatureDictNet', 'FeatureListNet', 'FeatureHookNet',
+    'FeatureGetterNet', 'feature_take_indices'
 ]


@@ -227,7 +227,59 @@ def _get_return_layers(feature_info, out_map):
     return return_layers


-class FeatureDictNet(nn.ModuleDict):
+class FeatureBase(nn.Module):
+    """ Base class for feature extraction wrappers
+
+    Provides dict-like interface without inheriting from nn.ModuleDict to avoid FSDP2 issues.
+    FSDP2's fully_shard has isinstance checks for (ModuleDict, ModuleList) that cause problems.
+
+    This class delegates dict operations to the underlying _modules OrderedDict.
+    """
+
+    def __init__(self):
+        super().__init__()
+        self.feature_info: Optional[FeatureInfo] = None
+        self.output_fmt: Optional[Format] = None
+        self.grad_checkpointing = False
+
+    def set_grad_checkpointing(self, enable: bool = True):
+        self.grad_checkpointing = enable
+
+    # Dict-like interface methods
+    def __getitem__(self, key: str) -> nn.Module:
+        return self._modules[key]
+
+    def __setitem__(self, key: str, module: nn.Module) -> None:
+        self.add_module(key, module)
+
+    def __delitem__(self, key: str) -> None:
+        del self._modules[key]
+
+    def __len__(self) -> int:
+        return len(self._modules)
+
+    def __iter__(self):
+        return iter(self._modules)
+
+    def __contains__(self, key: str) -> bool:
+        return key in self._modules
+
+    def keys(self):
+        return self._modules.keys()
+
+    def values(self):
+        return self._modules.values()
+
+    def items(self):
+        return self._modules.items()
+
+    def update(self, modules: Dict[str, nn.Module]) -> None:
+        """Update _modules with new modules."""
+        for key, module in modules.items():
+            self.add_module(key, module)
+
+
+class FeatureDictNet(FeatureBase):
     """ Feature extractor with OrderedDict return

     Wrap a model and extract features as specified by the out indices, the network is
@@ -264,7 +316,6 @@ def __init__(
         self.feature_info = _get_feature_info(model, out_indices)
         self.output_fmt = Format(output_fmt)
         self.concat = feature_concat
-        self.grad_checkpointing = False
         self.return_layers = {}

         return_layers = _get_return_layers(self.feature_info, out_map)
@@ -283,9 +334,6 @@ def __init__(
             f'Return layers ({remaining}) are not present in model'
         self.update(layers)

-    def set_grad_checkpointing(self, enable: bool = True):
-        self.grad_checkpointing = enable
-
     def _collect(self, x) -> (Dict[str, torch.Tensor]):
         out = OrderedDict()
         for i, (name, module) in enumerate(self.items()):
@@ -345,7 +393,7 @@ def forward(self, x) -> (List[torch.Tensor]):
         return list(self._collect(x).values())


-class FeatureHookNet(nn.ModuleDict):
+class FeatureHookNet(FeatureBase):
     """ FeatureHookNet

     Wrap a model and extract features specified by the out indices using forward/forward-pre hooks.
@@ -386,7 +434,6 @@ def __init__(
         self.feature_info = _get_feature_info(model, out_indices)
         self.return_dict = return_dict
         self.output_fmt = Format(output_fmt)
-        self.grad_checkpointing = False
         if no_rewrite is None:
             no_rewrite = not flatten_sequential
         layers = OrderedDict()
@@ -415,9 +462,6 @@ def __init__(
         self.update(layers)
         self.hooks = FeatureHooks(hooks, model.named_modules(), out_map=out_map)

-    def set_grad_checkpointing(self, enable: bool = True):
-        self.grad_checkpointing = enable
-
     def forward(self, x):
         for i, (name, module) in enumerate(self.items()):
             if self.grad_checkpointing and not torch.jit.is_scripting():
@@ -432,7 +476,7 @@ def forward(self, x):
         return out if self.return_dict else list(out.values())


-class FeatureGetterNet(nn.ModuleDict):
+class FeatureGetterNet(FeatureBase):
     """ FeatureGetterNet

     Wrap models with a feature getter method, like 'get_intermediate_layers'
@@ -472,6 +516,10 @@ def __init__(
         self.output_fmt = Format(output_fmt)
         self.norm = norm

+    def set_grad_checkpointing(self, enable: bool = True):
+        self.grad_checkpointing = enable
+        self.model.set_grad_checkpointing(enable=enable)
+
     def forward(self, x):
         features = self.model.forward_intermediates(
             x,
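The new base class keeps the dict-style access the wrappers have always exposed while dropping the nn.ModuleDict parent. A minimal sketch of the user-facing effect, assuming a timm build that includes this commit; 'resnet18' is only an arbitrary example model:

```python
# Sketch only: the features_only wrapper is no longer an nn.ModuleDict,
# but the dict-like access delegated to _modules keeps working.
import timm
import torch
import torch.nn as nn

net = timm.create_model('resnet18', features_only=True, pretrained=False)

print(isinstance(net, nn.ModuleDict))   # False after this change
print(isinstance(net, nn.Module))       # True; plain Module base via FeatureBase

# Dict-style interface still available.
print(len(net), list(net.keys())[:3])   # number of wrapped stages and first few names
first = next(iter(net))
print(first in net, isinstance(net[first], nn.Module))

# Feature extraction itself is unchanged.
feats = net(torch.randn(1, 3, 224, 224))
print([f.shape for f in feats])
```

Because __getitem__, keys(), values(), and items() simply delegate to the module's own _modules OrderedDict, existing code that treated the wrapper as a mapping should keep working unchanged.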

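For context on the FSDP2 note in the docstring: fully_shard refuses ModuleDict/ModuleList containers via an isinstance check, which previously caught these wrappers even though they implement forward. Below is a rough sketch of the kind of usage the new base class unblocks; it assumes a recent PyTorch exposing the FSDP2 fully_shard API (the import path has moved between releases) and a launch under torchrun with CUDA/NCCL available.

```python
# Sketch only: sharding a features_only wrapper with FSDP2, launched via torchrun.
# Assumes NCCL, CUDA, and a PyTorch release exposing fully_shard at this path
# (older releases use torch.distributed._composable.fsdp instead).
import os
import torch
import torch.distributed as dist
import timm
from torch.distributed.fsdp import fully_shard

torch.cuda.set_device(int(os.environ.get('LOCAL_RANK', 0)))
dist.init_process_group('nccl')

net = timm.create_model('resnet50', features_only=True, pretrained=False).cuda()

# Shard parameterized child stages first, then the wrapper itself; with the old
# nn.ModuleDict base the outer call tripped FSDP2's container isinstance check.
for child in net.values():
    if next(child.parameters(), None) is not None:
        fully_shard(child)
fully_shard(net)

feats = net(torch.randn(2, 3, 224, 224, device='cuda'))
print([f.shape for f in feats])

dist.destroy_process_group()
```

This is illustrative only; the per-stage sharding granularity shown here is a user choice, not something the commit prescribes.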