diff --git a/torchsummary/torchsummary.py b/torchsummary/torchsummary.py
index 1ed065f..440702b 100644
--- a/torchsummary/torchsummary.py
+++ b/torchsummary/torchsummary.py
@@ -1,9 +1,12 @@
+import typing
+from collections import OrderedDict
+
+import numpy as np
 import torch
 import torch.nn as nn
 from torch.autograd import Variable
+from torch.utils._pytree import tree_map
 
-from collections import OrderedDict
-import numpy as np
 
 
 def summary(model, input_size, batch_size=-1, device=torch.device('cuda:0'), dtypes=None):
@@ -14,6 +17,12 @@ def summary(model, input_size, batch_size=-1, device=torch.device('cuda:0'), dty
     return params_info
 
 
+def tensor_size(tensor: typing.Any) -> typing.Optional[typing.List[int]]:
+    if not isinstance(tensor, torch.Tensor):
+        return None
+    return list(tensor.size())
+
+
 def summary_string(model, input_size, batch_size=-1, device=torch.device('cuda:0'), dtypes=None):
     if dtypes == None:
         dtypes = [torch.FloatTensor]*len(input_size)
@@ -27,22 +36,17 @@ def hook(module, input, output):
 
             m_key = "%s-%i" % (class_name, module_idx + 1)
             summary[m_key] = OrderedDict()
-            summary[m_key]["input_shape"] = list(input[0].size())
-            summary[m_key]["input_shape"][0] = batch_size
-            if isinstance(output, (list, tuple)):
-                summary[m_key]["output_shape"] = [
-                    [-1] + list(o.size())[1:] for o in output
-                ]
-            else:
-                summary[m_key]["output_shape"] = list(output.size())
-                summary[m_key]["output_shape"][0] = batch_size
+            input_shape = tree_map(tensor_size, input)
+            if len(input_shape) == 1:  # backwards compatibility
+                input_shape = input_shape[0]
+            summary[m_key]["input_shape"] = input_shape
+            summary[m_key]["output_shape"] = tree_map(tensor_size, output)
 
             params = 0
-            if hasattr(module, "weight") and hasattr(module.weight, "size"):
-                params += torch.prod(torch.LongTensor(list(module.weight.size())))
-                summary[m_key]["trainable"] = module.weight.requires_grad
-            if hasattr(module, "bias") and hasattr(module.bias, "size"):
-                params += torch.prod(torch.LongTensor(list(module.bias.size())))
+            summary[m_key]["trainable"] = False
+            for p in module.parameters(recurse=False):
+                params += np.prod(list(p.size()))
+                summary[m_key]["trainable"] |= p.requires_grad
             summary[m_key]["nb_params"] = params
 
         if (
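
For reviewers, a minimal sketch (illustrative only, not part of the patch; the sample `output` dict is made up) of why the hook can drop its old list/tuple special-casing: `tree_map` walks any nested structure of tuples, lists, and dicts and applies the new `tensor_size` helper to each leaf, returning `None` for non-tensor leaves.

```python
import typing

import torch
from torch.utils._pytree import tree_map


def tensor_size(tensor: typing.Any) -> typing.Optional[typing.List[int]]:
    # Mirrors the helper added in this patch: shape for tensors, None otherwise.
    if not isinstance(tensor, torch.Tensor):
        return None
    return list(tensor.size())


# Hypothetical module output: a dict mixing tensors and a plain int.
output = {"logits": torch.zeros(2, 10), "mask": torch.zeros(2, 1), "aux": 3}
print(tree_map(tensor_size, output))
# {'logits': [2, 10], 'mask': [2, 1], 'aux': None}

# A hook's `input` is always a tuple; the single-element case is unwrapped
# by the `len(input_shape) == 1` branch for backwards compatibility.
print(tree_map(tensor_size, (torch.zeros(2, 3),)))
# ([2, 3],)
```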