@@ -34,13 +34,13 @@ def avg_ch_var_residual(model, input, output):
3434
3535class ActivationStatsHook :
3636 """Iterates through each of `model`'s modules and matches modules using unix pattern
37- matching based on `layer_name ` and `layer_type`. If there is match, this class adds
38- creates a hook using `hook_fn` and adds it to the module.
37+ matching based on `hook_fn_locs ` and registers `hook_fn` to the module if there is
38+ a match.
3939
4040 Arguments:
4141 model (nn.Module): model from which we will extract the activation stats
42- layer_names ( str): The layer name to look for to register forward
43- hook. Example, 'stem', 'stages'
42+ hook_fn_locs (List[ str] ): List of `hook_fn` locations based on Unix type string
43+ matching with the name of model's modules.
4444 hook_fns (List[Callable]): List of hook functions to be registered at every
4545 module in `layer_names`.
4646
@@ -51,6 +51,9 @@ def __init__(self, model, hook_fn_locs, hook_fns):
5151 self .model = model
5252 self .hook_fn_locs = hook_fn_locs
5353 self .hook_fns = hook_fns
54+ if len (hook_fn_locs ) != len (hook_fns ):
55+ raise ValueError ("Please provide `hook_fns` for each `hook_fn_locs`, \
56+ their lengths are different." )
5457 self .stats = dict ((hook_fn .__name__ , []) for hook_fn in hook_fns )
5558 for hook_fn_loc , hook_fn in zip (hook_fn_locs , hook_fns ):
5659 self .register_hook (hook_fn_loc , hook_fn )
0 commit comments