
[Feature Request] Allow CUDNN based modules in losses #2741

Open
@splatter96

Description

Motivation

Currently, CUDNN-based modules cannot be used inside loss modules because they are incompatible with the vmap calls that most of the losses rely on.
For RNN modules in particular this leaves a lot of performance on the table, since the CUDNN implementations are highly optimized.
I compared a CUDNN-based LSTM module against the Python-based version in DiscreteSACLoss and, in my use case, saw a performance increase of almost 4x. To do this, I replaced the vmapped calls to the Q network in the loss with a loop over regular calls, simply by subclassing DiscreteSACLoss like this:

import torch
from torchrl.objectives import DiscreteSACLoss

class CustomDiscreteSACLoss(DiscreteSACLoss):
    # override the default vmapped function
    def _make_vmap(self):
        def customvmap(td, params):
            td_out = []

            # loop over the stacked Q-network parameter sets instead of vmapping
            for p in params.unbind(0):
                with p.to_module(self.qvalue_network):
                    # copy the input so each call writes its outputs into a
                    # fresh tensordict instead of overwriting the previous one
                    td_out.append(self.qvalue_network(td.copy()))

            return torch.stack(td_out, 0)

        self._vmap_qnetworkN0 = customvmap
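For reference, the incompatibility itself can be reproduced outside of the loss modules. Below is a minimal sketch following the torch.func ensembling recipe (this assumes a CUDA device so that the fused CUDNN kernel is actually used; sizes are illustrative):

import copy
import torch
from torch.func import functional_call, stack_module_state, vmap

# two CUDNN-backed LSTMs whose stacked parameters we try to vmap over
lstms = [torch.nn.LSTM(8, 16, batch_first=True).cuda() for _ in range(2)]
params, buffers = stack_module_state(lstms)
base = copy.deepcopy(lstms[0]).to("meta")

def run(p, b, x):
    out, _ = functional_call(base, (p, b), (x,))
    return out

x = torch.randn(4, 5, 8, device="cuda")
# vmap has no batching rule for the fused CUDNN RNN kernel, so this
# typically raises instead of evaluating the ensemble
vmap(run, in_dims=(0, 0, None))(params, buffers, x)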

I also tried compiling the Python-based module to get better performance, as proposed in vmoens' reply in #1717, but I could never match the performance of the CUDNN-based module.
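This is roughly what that attempt looked like (a sketch; sizes and tensordict keys are illustrative, and python_based=True selects torchrl's pure-PyTorch, vmap-compatible LSTM implementation):

import torch
from torchrl.modules import LSTMModule

# a Python-based (vmap-compatible) LSTM instead of the CUDNN-backed one;
# sizes and keys are illustrative
lstm = LSTMModule(
    input_size=64,
    hidden_size=128,
    in_key="observation",
    out_key="embed",
    python_based=True,
)
# compile the module to try to recover some of the lost throughput
lstm = torch.compile(lstm)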

Solution

I would suggest adding a parameter to the constructor of the affected losses to select between vmapping and a plain loop, so users can choose whichever version performs best for them: as in my case, the performance advantage of the CUDNN modules may outweigh the performance lost by not using vmap.
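To make this concrete, here is a rough sketch of what the toggle could look like, building on the subclass above (the name use_vmap is hypothetical, not existing torchrl API):

import torch
from torchrl.objectives import DiscreteSACLoss

class ToggleVmapDiscreteSACLoss(DiscreteSACLoss):
    # sketch only: `use_vmap` is a hypothetical flag name
    def __init__(self, *args, use_vmap: bool = True, **kwargs):
        # set before super().__init__(), which is where _make_vmap() is called
        self.use_vmap = use_vmap
        super().__init__(*args, **kwargs)

    def _make_vmap(self):
        if self.use_vmap:
            # keep the stock vmap-based path
            super()._make_vmap()
            return

        def loop_over_qnets(td, params):
            # plain loop over the stacked Q-network parameter sets; no vmap,
            # so CUDNN-backed modules work
            td_out = []
            for p in params.unbind(0):
                with p.to_module(self.qvalue_network):
                    td_out.append(self.qvalue_network(td.copy()))
            return torch.stack(td_out, 0)

        self._vmap_qnetworkN0 = loop_over_qnets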

I would also be happy to implement this solution in a PR if the feature is desired.

Checklist

  • I have checked that there is no similar issue in the repo (required)

Labels

enhancement (New feature or request)
