33Hacked together by / Copyright 2020 Ross Wightman 
44""" 
55from  .model_ema  import  ModelEma 
6- 
6+ import  torch  
7+ import  fnmatch 
78
89def  unwrap_model (model ):
910    if  isinstance (model , ModelEma ):
@@ -14,3 +15,78 @@ def unwrap_model(model):
1415
1516def  get_state_dict (model , unwrap_fn = unwrap_model ):
1617    return  unwrap_fn (model ).state_dict ()
18+ 
19+ 
20+ def  avg_sq_ch_mean (model , input , output ): 
21+     "calculate average channel square mean of output activations" 
22+     return  torch .mean (output .mean (axis = [0 ,2 ,3 ])** 2 ).item ()
23+ 
24+ 
25+ def  avg_ch_var (model , input , output ): 
26+     "calculate average channel variance of output activations" 
27+     return  torch .mean (output .var (axis = [0 ,2 ,3 ])).item ()\
28+ 
29+ 
30+ def  avg_ch_var_residual (model , input , output ): 
31+     "calculate average channel variance of output activations" 
32+     return  torch .mean (output .var (axis = [0 ,2 ,3 ])).item ()
33+ 
34+ 
35+ class  ActivationStatsHook :
36+     """Iterates through each of `model`'s modules and matches modules using unix pattern  
37+     matching based on `hook_fn_locs` and registers `hook_fn` to the module if there is  
38+     a match.  
39+ 
40+     Arguments: 
41+         model (nn.Module): model from which we will extract the activation stats 
42+         hook_fn_locs (List[str]): List of `hook_fn` locations based on Unix type string  
43+             matching with the name of model's modules.  
44+         hook_fns (List[Callable]): List of hook functions to be registered at every 
45+             module in `layer_names`. 
46+      
47+     Inspiration from https://docs.fast.ai/callback.hook.html. 
48+ 
49+     Refer to https://gist.github.com/amaarora/6e56942fcb46e67ba203f3009b30d950 for an example  
50+     on how to plot Signal Propogation Plots using `ActivationStatsHook`. 
51+     """ 
52+ 
53+     def  __init__ (self , model , hook_fn_locs , hook_fns ):
54+         self .model  =  model 
55+         self .hook_fn_locs  =  hook_fn_locs 
56+         self .hook_fns  =  hook_fns 
57+         if  len (hook_fn_locs ) !=  len (hook_fns ):
58+             raise  ValueError ("Please provide `hook_fns` for each `hook_fn_locs`, \  
59+                  their lengths are different." )
60+         self .stats  =  dict ((hook_fn .__name__ , []) for  hook_fn  in  hook_fns )
61+         for  hook_fn_loc , hook_fn  in  zip (hook_fn_locs , hook_fns ): 
62+             self .register_hook (hook_fn_loc , hook_fn )
63+ 
64+     def  _create_hook (self , hook_fn ):
65+         def  append_activation_stats (module , input , output ):
66+             out  =  hook_fn (module , input , output )
67+             self .stats [hook_fn .__name__ ].append (out )
68+         return  append_activation_stats 
69+         
70+     def  register_hook (self , hook_fn_loc , hook_fn ):
71+         for  name , module  in  self .model .named_modules ():
72+             if  not  fnmatch .fnmatch (name , hook_fn_loc ):
73+                 continue 
74+             module .register_forward_hook (self ._create_hook (hook_fn ))
75+ 
76+ 
77+ def  extract_spp_stats (model , 
78+                       hook_fn_locs ,
79+                       hook_fns , 
80+                       input_shape = [8 , 3 , 224 , 224 ]):
81+     """Extract average square channel mean and variance of activations during  
82+     forward pass to plot Signal Propogation Plots (SPP). 
83+      
84+     Paper: https://arxiv.org/abs/2101.08692 
85+ 
86+     Example Usage: https://gist.github.com/amaarora/6e56942fcb46e67ba203f3009b30d950 
87+     """  
88+     x  =  torch .normal (0. , 1. , input_shape )
89+     hook  =  ActivationStatsHook (model , hook_fn_locs = hook_fn_locs , hook_fns = hook_fns )
90+     _  =  model (x )
91+     return  hook .stats 
92+     
0 commit comments