Motivation
Currently we cannot use cuDNN-based modules inside loss modules, because they are incompatible with the vmap calls used in most of the losses.
For RNN modules in particular this leaves a lot of performance on the table, since the cuDNN kernels are highly optimized.
I compared a cuDNN-based LSTM module against the Python-based version in DiscreteSACLoss and, in my use case, saw a speed-up of almost 4x. To get there I replaced the vmapped calls to the Q network in the loss with a loop over regular calls, simply by subclassing DiscreteSACLoss like this:
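For reference, the incompatibility can be reproduced outside of torchrl with the standard torch.func ensembling pattern (a minimal sketch; sizes are arbitrary and it assumes a CUDA device so the cuDNN kernel is selected):

```python
import torch
from torch.func import functional_call, stack_module_state, vmap

# Two LSTMs whose parameters are stacked, mimicking the Q-network ensemble in the losses.
lstms = [torch.nn.LSTM(8, 16, batch_first=True).cuda() for _ in range(2)]
params, buffers = stack_module_state(lstms)
base = lstms[0]

def run(p, b, x):
    # Functional call with one slice of the stacked parameters; [0] keeps the output sequence.
    return functional_call(base, (p, b), (x,))[0]

x = torch.randn(4, 5, 8, device="cuda")
# vmap has no batching rule for the fused cuDNN RNN kernel, so this call is expected to
# error out, while the same pattern with a Python-level LSTM implementation goes through.
out = vmap(run, in_dims=(0, 0, None))(params, buffers, x)
```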
```python
import torch
from torchrl.objectives import DiscreteSACLoss


class CustomDiscreteSACLoss(DiscreteSACLoss):
    # Override the hook that builds the vmapped Q-network call.
    def _make_vmap(self):
        def customvmap(td, params):
            # Loop over the stacked Q-network parameters instead of vmapping,
            # so cuDNN-backed modules can run their fused kernels.
            td_out = []
            for p in params.unbind(0):
                with p.to_module(self.qvalue_network):
                    td_out.append(self.qvalue_network(td))
            return torch.stack(td_out, 0)

        self._vmap_qnetworkN0 = customvmap
```
I also tried compiling the Python-based module to improve performance, as proposed in vmoens' reply in #1717, but I could never match the performance of the cuDNN-based module.
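For reference, the compile attempt was essentially the following (a sketch from memory; the python_based flag and the key arguments of LSTMModule are assumptions here, and the sizes are placeholders):

```python
import torch
from torchrl.modules import LSTMModule

# Python-based (non-cuDNN) LSTM so that vmap works inside the loss.
lstm = LSTMModule(
    input_size=32,
    hidden_size=64,
    in_key="embed",
    out_key="features",
    python_based=True,
)
# Compiling removes some Python overhead, but in my runs it never reached cuDNN throughput.
lstm = torch.compile(lstm)
```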
Solution
I would suggest adding a parameter to the constructor of the affected losses that selects between the vmapped and non-vmapped code paths, so users can pick whichever performs best for them; as in my case, the speed-up from cuDNN modules can outweigh the performance lost by not using vmap.
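As a sketch of what I mean (the parameter name and wiring are illustrative, shown as a subclass rather than a change to the loss itself):

```python
import torch
from torchrl.objectives import DiscreteSACLoss


class MaybeVmappedDiscreteSACLoss(DiscreteSACLoss):
    def __init__(self, *args, use_vmap: bool = True, **kwargs):
        self.use_vmap = use_vmap          # plain attribute, set before the parent builds the vmap
        super().__init__(*args, **kwargs)

    def _make_vmap(self):
        if self.use_vmap:
            return super()._make_vmap()   # keep the current vmapped path

        def loop_over_qnets(td, params):
            # cuDNN-friendly fallback: iterate over the stacked Q-network parameters.
            out = []
            for p in params.unbind(0):
                with p.to_module(self.qvalue_network):
                    out.append(self.qvalue_network(td))
            return torch.stack(out, 0)

        self._vmap_qnetworkN0 = loop_over_qnets
```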
I am also happy to implement this in a PR if the feature is desired.
Checklist
I have checked that there is no similar issue in the repo (required)
You can but it's quite an undertaking!
I'm thinking of this API:
```python
@make_non_functional("_critic_loss_non_functional")
def critic_loss(self, tensordict):
    ...  # do stuff with functional modules

def _critic_loss_non_functional(self, tensordict):
    ...  # do stuff with regular modules
```
That way we can clearly separate the functional and non-functional code paths across losses.
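Something along these lines (very rough; the dispatch flag and where it lives are just placeholders):

```python
import functools


def make_non_functional(alternative_name):
    # Decorator that reroutes a loss method to a hand-written, non-vmapped
    # implementation when the module is not in functional mode.
    def decorator(func):
        @functools.wraps(func)
        def wrapper(self, tensordict, *args, **kwargs):
            if getattr(self, "functional", True):
                return func(self, tensordict, *args, **kwargs)
            return getattr(self, alternative_name)(tensordict, *args, **kwargs)
        return wrapper
    return decorator
```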
Wdyt?
If you think it'd be a cool feature, I'm happy to review a PR.
I'm also happy to do it myself but IDK when I will have the bandwidth to do that, probably not before the end of the week.