
[Feature Request] Allow CUDNN based modules in losses #2741

Open
1 task done
splatter96 opened this issue Feb 3, 2025 · 3 comments
Labels
enhancement New feature or request

@splatter96

Motivation

Currently we cannot use CUDNN-based modules in loss modules, as they are incompatible with the vmap calls used in most of the losses.
Particularly for RNN modules this leaves a lot of performance on the table, since the CUDNN implementations are highly optimized.
I compared the performance of a CUDNN-based LSTM module against the Python-based version in the DiscreteSACLoss and, in my use case, saw a speed-up of almost 4x. To achieve this I replaced the vmapped calls to the Q network in the loss with a loop over regular calls, simply by subclassing DiscreteSACLoss like this:

import torch
from torchrl.objectives import DiscreteSACLoss


class CustomDiscreteSACLoss(DiscreteSACLoss):
    # Override the default vmapped Q-network call with a plain Python loop,
    # so that CUDNN-based modules (which vmap cannot handle) can be used.
    def _make_vmap(self):
        def customvmap(td, params):
            td_out = []

            # Call the Q network once per parameter set instead of vmapping.
            for p in params.unbind(0):
                with p.to_module(self.qvalue_network):
                    td_out.append(self.qvalue_network(td))

            return torch.stack(td_out, 0)

        self._vmap_qnetworkN0 = customvmap
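The subclass is then a drop-in replacement for the original loss. A rough usage sketch (actor, qvalue, num_actions and sampled_tensordict are placeholders for an existing setup, not code from this issue):

# Hypothetical setup: `qvalue` is a TensorDictModule whose body contains a
# CUDNN-backed LSTM; the custom loss loops over the stacked Q-network
# parameters instead of vmapping over them.
loss_module = CustomDiscreteSACLoss(
    actor_network=actor,
    qvalue_network=qvalue,
    num_actions=num_actions,
    num_qvalue_nets=2,
)
loss_vals = loss_module(sampled_tensordict)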

I also tried to compile the Python-based module to gain better performance, as proposed in vmoens' reply in #1717, but could never match the performance of the CUDNN-based module.

Solution

I would suggest adding a parameter to the constructor of the affected losses to select between vmapping and no vmapping, so that users can pick whichever version performs best for them; the performance advantage of using CUDNN modules might outweigh the performance lost by not using vmap, as it did in my case.
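To make the idea concrete, something along these lines (use_vmap is only a placeholder name, not an existing parameter):

# Hypothetical constructor flag: with use_vmap=False the loss would fall back
# to a plain loop over the Q-network parameter sets, as in the subclass above.
loss_module = DiscreteSACLoss(
    actor_network=actor,
    qvalue_network=qvalue,
    num_actions=num_actions,
    use_vmap=False,  # placeholder name for the proposed opt-out
)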

I would also be happy to implement this solution in a PR if the feature is desired.

Checklist

  • I have checked that there is no similar issue in the repo (required)
@splatter96 splatter96 added the enhancement New feature or request label Feb 3, 2025
@vmoens
Contributor

vmoens commented Feb 3, 2025

Yeah, we should make all the losses optionally non-functional.
Thanks for pinging!

@splatter96
Author

Should I prepare a PR for this?

@vmoens
Contributor

vmoens commented Feb 3, 2025

You can, but it's quite an undertaking!
I'm thinking of this API:

@make_non_functional("_critic_loss_non_functional")
def critic_loss(self, tensordict):
    ...  # do stuff with functional (vmapped) modules

def _critic_loss_non_functional(self, tensordict):
    ...  # do stuff with regular modules

and the decorator would be defined by

import functools

def make_non_functional(alt_func):
    def decorator(func):
        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            # Dispatch on the loss's `functional` flag: run the vmapped
            # implementation when functional, otherwise the alternative method.
            if self.functional:
                return func(self, *args, **kwargs)
            return getattr(self, alt_func)(*args, **kwargs)
        return wrapper
    return decorator

That way we can clearly separate the functional and non-functional code paths across losses.
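For illustration only, a minimal self-contained sketch of how the dispatch would behave, assuming the decorator above (ToyLoss is just a stand-in, not a TorchRL class):

class ToyLoss:
    def __init__(self, functional: bool = True):
        # In a real loss this flag would come from the constructor argument
        # proposed above.
        self.functional = functional

    @make_non_functional("_critic_loss_non_functional")
    def critic_loss(self, tensordict):
        return "functional path"

    def _critic_loss_non_functional(self, tensordict):
        return "non-functional path"


print(ToyLoss(functional=True).critic_loss(None))   # functional path
print(ToyLoss(functional=False).critic_loss(None))  # non-functional path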

Wdyt?

If you think it'd be a cool feature, I'm happy to review a PR.
I'm also happy to do it myself, but IDK when I'll have the bandwidth for that, probably not before the end of the week.
