Description
Motivation
Currently we cannot use CUDNN-based modules inside loss modules, as they are incompatible with the vmap calls used in most of the losses.
For RNN modules in particular this leaves a lot of performance on the table, since the CUDNN implementations are highly optimized.
I have compared the performance of a CUDNN-based LSTM module and the Python-based version in the DiscreteSACLoss and, in my use case, saw a performance increase of almost 4x. For this I replaced the vmapped calls to the Q network in the loss with a loop over regular calls, by simply subclassing DiscreteSACLoss like this:
```python
import torch
from torchrl.objectives import DiscreteSACLoss


class CustomDiscreteSACLoss(DiscreteSACLoss):
    # override the default vmapped call with a plain loop over the Q-network
    # parameter sets, so CUDNN-backed modules can be used inside the loss
    def _make_vmap(self):
        def custom_vmap(td, params):
            td_out = []
            for p in params.unbind(0):
                with p.to_module(self.qvalue_network):
                    # run on a copy so each pass does not overwrite the previous output
                    td_out.append(self.qvalue_network(td.copy()))
            return torch.stack(td_out, 0)

        self._vmap_qnetworkN0 = custom_vmap
```
I also tried compiling the Python-based module to get better performance, as proposed in vmoens' reply in #1717, but could never match the performance of the CUDNN-based module.
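For reference, a minimal sketch of that attempt, assuming TorchRL's `LSTMModule` with `python_based=True`; the sizes and tensordict keys below are placeholders, not my actual setup:

```python
import torch
from torchrl.modules import LSTMModule

# Python-based (non-CUDNN) LSTM cell, which stays vmap-compatible but is slower;
# torch.compile is applied on top in the hope of recovering some of the lost speed.
lstm = LSTMModule(
    input_size=32,
    hidden_size=128,
    in_keys=["observation", "rs_h", "rs_c"],
    out_keys=["embed", ("next", "rs_h"), ("next", "rs_c")],
    python_based=True,  # placeholder config; keys/sizes are illustrative
)
lstm = torch.compile(lstm)
```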
Solution
I would suggest adding a parameter to the constructor of the affected losses to select between using vmap or not, so that users can choose whichever version performs best for them: the performance advantage of CUDNN modules may outweigh the performance lost by not using vmap, as it did in my case.
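As a rough sketch of what that flag could look like (names such as `use_vmap` and `DiscreteSACLossWithFlag` are placeholders, not existing TorchRL API; the loop simply mirrors the workaround above):

```python
import torch
from torchrl.objectives import DiscreteSACLoss


class DiscreteSACLossWithFlag(DiscreteSACLoss):
    """Sketch of the proposed opt-out; `use_vmap` is a placeholder name."""

    def __init__(self, *args, use_vmap: bool = True, **kwargs):
        super().__init__(*args, **kwargs)
        if not use_vmap:
            # replace the vmapped Q-network call with a plain loop,
            # allowing CUDNN-backed Q networks to be used
            def loop_qnetwork(td, params):
                out = []
                for p in params.unbind(0):
                    with p.to_module(self.qvalue_network):
                        out.append(self.qvalue_network(td.copy()))
                return torch.stack(out, 0)

            self._vmap_qnetworkN0 = loop_qnetwork
```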
I am also happy to implement this solution in a PR if the feature is desired.
Checklist
- I have checked that there is no similar issue in the repo (required)