
Not working with Double Precision Networks #44

@cSchubes


When I call the summary function on a network after calling net.double(), it raises an error:

RuntimeError: Expected object of type torch.FloatTensor but found type torch.DoubleTensor for argument #2 'mat2'

The calling code is:

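import torch
from torchsummary import summary

# FCNet, buildNetArch, checkpoint, netConfig and TENSOR_TYPE are defined elsewhere in my project.
net = FCNet(obs_space=checkpoint['env'].observation_space.shape[0], action_space=checkpoint['env'].action_space.n,
          shape=netConfig['shape'], activation=netConfig['activation'], dropout_rate=netConfig['dropout'],
          bias=netConfig['bias'])

# Cast the network's parameters to the configured precision.
if TENSOR_TYPE == torch.FloatTensor:
    net = net.float()
elif TENSOR_TYPE == torch.DoubleTensor:
    net = net.double()
netArch = buildNetArch(net)
# With TENSOR_TYPE == torch.DoubleTensor, this call raises the RuntimeError above.
summary(net, input_size=(1, checkpoint['env'].observation_space.shape[0]))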

I have seen this error outside of torchsummary before, when passing tensors of the wrong type into a network. It looks like torchsummary feeds a FloatTensor through the network to collect its metrics, which fails for a network whose parameters are DoubleTensors. Is that correct?
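The hypothesis can be checked without torchsummary at all. A minimal sketch, assuming only PyTorch: a `Linear` layer converted with `.double()` rejects a float32 input, and accepts the same input once it is cast to the dtype of the layer's parameters. The layer sizes are arbitrary and stand in for the FCNet above.

```python
import torch
import torch.nn as nn

# A small double-precision layer standing in for FCNet after net.double().
net = nn.Linear(4, 2).double()

# torch.rand returns float32 by default, i.e. the kind of input a float-only summary builds.
x_float = torch.rand(2, 4)

try:
    net(x_float)
    mismatch_raised = False
except RuntimeError as err:
    # The exact wording of the dtype-mismatch message differs between PyTorch versions.
    mismatch_raised = True
    print("RuntimeError:", err)

# Casting the input to the parameters' dtype lets the forward pass succeed.
param_dtype = next(net.parameters()).dtype
out = net(x_float.to(param_dtype))
print(out.dtype)
```

If the `try` branch raises and the cast version runs, the problem is purely the dtype of the dummy input, not the network itself.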

If so, I could try to fix it myself.
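One possible direction for a fix, sketched as a standalone helper rather than a patch to torchsummary: infer the dtype and device from the model's own parameters and build the dummy input with them, then record per-layer output shapes with forward hooks. The name `summary_with_dtype` is hypothetical and this is not torchsummary's actual code; it also assumes the model has at least one parameter and that every leaf module returns a single tensor.

```python
import torch
import torch.nn as nn


def summary_with_dtype(model, input_size, batch_size=2):
    """Hypothetical helper (not part of torchsummary): print per-layer output
    shapes and parameter counts, using a dummy input that matches the model's dtype."""
    # Take dtype and device from the first parameter instead of hard-coding FloatTensor.
    # Assumes the model has at least one parameter.
    first_param = next(model.parameters())
    dtype, device = first_param.dtype, first_param.device

    rows = []
    hooks = []

    def hook(module, inputs, output):
        # Count only the parameters owned directly by this module.
        n_params = sum(p.numel() for p in module.parameters(recurse=False))
        # Assumes the module returns a single tensor.
        rows.append((module.__class__.__name__, tuple(output.shape), n_params))

    # Hook only leaf modules so container modules are not listed twice.
    for module in model.modules():
        if not list(module.children()):
            hooks.append(module.register_forward_hook(hook))

    x = torch.rand(batch_size, *input_size, dtype=dtype, device=device)
    try:
        with torch.no_grad():
            model(x)
    finally:
        # Always detach the hooks, even if the forward pass fails.
        for h in hooks:
            h.remove()

    for name, shape, n_params in rows:
        print(f"{name:<15} {str(shape):<20} {n_params}")
    total = sum(p.numel() for p in model.parameters())
    print(f"Total params: {total}")
    return rows


# Usage: a double-precision network that a float-only summary would reject.
net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2)).double()
rows = summary_with_dtype(net, input_size=(4,))
```

Reading the dtype from the parameters means the same call works after either net.float() or net.double() without any extra argument; an explicit dtype parameter would be an alternative design if a model mixes precisions.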
