From d24da94aeeb38c87ab00bcba2fa25c94cbb7bbdb Mon Sep 17 00:00:00 2001
From: Lucas Nestler <39779310+ClashLuke@users.noreply.github.com>
Date: Fri, 5 Aug 2022 11:46:41 +0200
Subject: [PATCH 1/3] feat: allow more input shapes and weight locations

---
 torchsummary/torchsummary.py | 35 +++++++++++++++++++----------------
 1 file changed, 19 insertions(+), 16 deletions(-)

diff --git a/torchsummary/torchsummary.py b/torchsummary/torchsummary.py
index 1ed065f..8034bae 100644
--- a/torchsummary/torchsummary.py
+++ b/torchsummary/torchsummary.py
@@ -1,9 +1,12 @@
+import typing
+from collections import OrderedDict
+
+import numpy as np
 import torch
 import torch.nn as nn
 from torch.autograd import Variable
+from torch.utils._pytree import tree_map
 
-from collections import OrderedDict
-import numpy as np
 
 
 def summary(model, input_size, batch_size=-1, device=torch.device('cuda:0'), dtypes=None):
@@ -14,6 +17,12 @@ def summary(model, input_size, batch_size=-1, device=torch.device('cuda:0'), dty
     return params_info
 
 
+def tensor_size(tensor: typing.Any) -> typing.Optional[typing.List[int]]:
+    if not isinstance(tensor, torch.Tensor):
+        return None
+    return list(tensor.size())
+
+
 def summary_string(model, input_size, batch_size=-1, device=torch.device('cuda:0'), dtypes=None):
     if dtypes == None:
         dtypes = [torch.FloatTensor]*len(input_size)
@@ -27,22 +36,16 @@ def hook(module, input, output):
 
             m_key = "%s-%i" % (class_name, module_idx + 1)
             summary[m_key] = OrderedDict()
-            summary[m_key]["input_shape"] = list(input[0].size())
-            summary[m_key]["input_shape"][0] = batch_size
-            if isinstance(output, (list, tuple)):
-                summary[m_key]["output_shape"] = [
-                    [-1] + list(o.size())[1:] for o in output
-                ]
-            else:
-                summary[m_key]["output_shape"] = list(output.size())
-                summary[m_key]["output_shape"][0] = batch_size
+            input_shape = tree_map(tensor_size, input)
+            if len(input_shape) == 1:  # backwards compatibility
+                input_shape = input_shape[0]
+            summary[m_key]["input_shape"] = input_shape
+            summary[m_key]["output_shape"] = tree_map(tensor_size, output)
 
             params = 0
-            if hasattr(module, "weight") and hasattr(module.weight, "size"):
-                params += torch.prod(torch.LongTensor(list(module.weight.size())))
-                summary[m_key]["trainable"] = module.weight.requires_grad
-            if hasattr(module, "bias") and hasattr(module.bias, "size"):
-                params += torch.prod(torch.LongTensor(list(module.bias.size())))
+            for p in module.parameters(recurse=False):
+                params += prod(list(p.size())))
+                summary[m_key]["trainable"] |= p.requires_grad
             summary[m_key]["nb_params"] = params
 
         if (

From fc6667dfb84b26b5aa3c566be39778b043c4549d Mon Sep 17 00:00:00 2001
From: Lucas Nestler <39779310+ClashLuke@users.noreply.github.com>
Date: Fri, 5 Aug 2022 11:52:50 +0200
Subject: [PATCH 2/3] fix: remove trailing bracket

---
 torchsummary/torchsummary.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torchsummary/torchsummary.py b/torchsummary/torchsummary.py
index 8034bae..d0670f8 100644
--- a/torchsummary/torchsummary.py
+++ b/torchsummary/torchsummary.py
@@ -44,7 +44,7 @@ def hook(module, input, output):
 
             params = 0
             for p in module.parameters(recurse=False):
-                params += prod(list(p.size())))
+                params += np.prod(list(p.size()))
                 summary[m_key]["trainable"] |= p.requires_grad
             summary[m_key]["nb_params"] = params
 

From 784f1e1b33c600fdbdb77931028629c7d0999798 Mon Sep 17 00:00:00 2001
From: Lucas Nestler <39779310+ClashLuke@users.noreply.github.com>
Date: Fri, 5 Aug 2022 11:55:02 +0200
Subject: [PATCH 3/3] fix: initialize trainable

---
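Notes (after the --- cut, so not part of the commit message): PATCH 1/3 left
summary[m_key]["trainable"] uninitialized, so the first `|=` in the hook
raises a KeyError for any module that owns parameters, and modules without
their own parameters get no "trainable" entry at all. Initializing the flag
to False before the loop fixes both. A minimal standalone sketch of the fixed
counting loop; the variable names are illustrative, not part of the patch:

    import numpy as np
    import torch.nn as nn

    module = nn.Linear(4, 2)
    params, trainable = 0, False  # initialize before OR-ing per parameter
    for p in module.parameters(recurse=False):
        params += np.prod(list(p.size()))
        trainable |= p.requires_grad
    print(params, trainable)  # 10 True (2*4 weights + 2 biases)
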
 torchsummary/torchsummary.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/torchsummary/torchsummary.py b/torchsummary/torchsummary.py
index d0670f8..440702b 100644
--- a/torchsummary/torchsummary.py
+++ b/torchsummary/torchsummary.py
@@ -43,6 +43,7 @@ def hook(module, input, output):
             summary[m_key]["output_shape"] = tree_map(tensor_size, output)
 
             params = 0
+            summary[m_key]["trainable"] = False
             for p in module.parameters(recurse=False):
                 params += np.prod(list(p.size()))
                 summary[m_key]["trainable"] |= p.requires_grad
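
Note on the tree_map change in PATCH 1/3: torch.utils._pytree.tree_map applies
a function to every leaf of an arbitrarily nested structure of tuples, lists,
and dicts, which is what lets the hook record shapes for models whose forward
takes more than a flat tuple of tensors. A hypothetical usage sketch of the
new tensor_size helper; the example inputs are made up for illustration:

    import torch
    from torch.utils._pytree import tree_map

    def tensor_size(tensor):
        if not isinstance(tensor, torch.Tensor):
            return None
        return list(tensor.size())

    inputs = (torch.zeros(2, 3), {"mask": torch.ones(2, 1)}, "metadata")
    print(tree_map(tensor_size, inputs))
    # -> ([2, 3], {'mask': [2, 1]}, None)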