Skip to content

Commit 2319cbb

Browse files
authored
Merge pull request #525 from amaarora/spp
Add `ActivationStatsHook` to allow extracting activation stats for Signal Propogation Plots
2 parents a2727c1 + 6b18061 commit 2319cbb

File tree

1 file changed

+77
-1
lines changed

1 file changed

+77
-1
lines changed

timm/utils/model.py

Lines changed: 77 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33
Hacked together by / Copyright 2020 Ross Wightman
44
"""
55
from .model_ema import ModelEma
6-
6+
import torch
7+
import fnmatch
78

89
def unwrap_model(model):
910
if isinstance(model, ModelEma):
@@ -14,3 +15,78 @@ def unwrap_model(model):
1415

1516
def get_state_dict(model, unwrap_fn=unwrap_model):
1617
return unwrap_fn(model).state_dict()
18+
19+
20+
def avg_sq_ch_mean(model, input, output):
21+
"calculate average channel square mean of output activations"
22+
return torch.mean(output.mean(axis=[0,2,3])**2).item()
23+
24+
25+
def avg_ch_var(model, input, output):
26+
"calculate average channel variance of output activations"
27+
return torch.mean(output.var(axis=[0,2,3])).item()\
28+
29+
30+
def avg_ch_var_residual(model, input, output):
31+
"calculate average channel variance of output activations"
32+
return torch.mean(output.var(axis=[0,2,3])).item()
33+
34+
35+
class ActivationStatsHook:
36+
"""Iterates through each of `model`'s modules and matches modules using unix pattern
37+
matching based on `hook_fn_locs` and registers `hook_fn` to the module if there is
38+
a match.
39+
40+
Arguments:
41+
model (nn.Module): model from which we will extract the activation stats
42+
hook_fn_locs (List[str]): List of `hook_fn` locations based on Unix type string
43+
matching with the name of model's modules.
44+
hook_fns (List[Callable]): List of hook functions to be registered at every
45+
module in `layer_names`.
46+
47+
Inspiration from https://docs.fast.ai/callback.hook.html.
48+
49+
Refer to https://gist.github.com/amaarora/6e56942fcb46e67ba203f3009b30d950 for an example
50+
on how to plot Signal Propogation Plots using `ActivationStatsHook`.
51+
"""
52+
53+
def __init__(self, model, hook_fn_locs, hook_fns):
54+
self.model = model
55+
self.hook_fn_locs = hook_fn_locs
56+
self.hook_fns = hook_fns
57+
if len(hook_fn_locs) != len(hook_fns):
58+
raise ValueError("Please provide `hook_fns` for each `hook_fn_locs`, \
59+
their lengths are different.")
60+
self.stats = dict((hook_fn.__name__, []) for hook_fn in hook_fns)
61+
for hook_fn_loc, hook_fn in zip(hook_fn_locs, hook_fns):
62+
self.register_hook(hook_fn_loc, hook_fn)
63+
64+
def _create_hook(self, hook_fn):
65+
def append_activation_stats(module, input, output):
66+
out = hook_fn(module, input, output)
67+
self.stats[hook_fn.__name__].append(out)
68+
return append_activation_stats
69+
70+
def register_hook(self, hook_fn_loc, hook_fn):
71+
for name, module in self.model.named_modules():
72+
if not fnmatch.fnmatch(name, hook_fn_loc):
73+
continue
74+
module.register_forward_hook(self._create_hook(hook_fn))
75+
76+
77+
def extract_spp_stats(model,
78+
hook_fn_locs,
79+
hook_fns,
80+
input_shape=[8, 3, 224, 224]):
81+
"""Extract average square channel mean and variance of activations during
82+
forward pass to plot Signal Propogation Plots (SPP).
83+
84+
Paper: https://arxiv.org/abs/2101.08692
85+
86+
Example Usage: https://gist.github.com/amaarora/6e56942fcb46e67ba203f3009b30d950
87+
"""
88+
x = torch.normal(0., 1., input_shape)
89+
hook = ActivationStatsHook(model, hook_fn_locs=hook_fn_locs, hook_fns=hook_fns)
90+
_ = model(x)
91+
return hook.stats
92+

0 commit comments

Comments
 (0)