
[ENH] Add device parameter to classes. #1647

Open
fnhirwa opened this issue Sep 4, 2024 · 3 comments
Labels
feature request New feature or request

Comments

fnhirwa (Member) commented Sep 4, 2024

  • PyTorch-Forecasting version: 1.0.0
  • PyTorch version: 2.4.0
  • Python version: 3.9.19
  • Operating System: macOS (Darwin)

Expected behavior

The current implementation of several classes handles devices in a way that is hard to follow, and users have no direct control over how to switch between devices.

https://github.com/jdb78/pytorch-forecasting/blob/81aee6650ed3de0c3071c9ce1fce19eec7fc24a7/pytorch_forecasting/metrics/distributions.py#L470

Here, for example, the device is taken from the input tensor, which makes the interface harder to reason about.

I suggest introducing a device parameter to these classes. This would give users explicit control over device placement and switching, rather than relying entirely on the device of the input data.
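
A minimal sketch of the idea (the class and its logic below are illustrative, not pytorch-forecasting's actual API): the metric accepts an optional device at construction time and only falls back to the input's device when none is given.

import torch

class DeviceAwareMetric:
    """Illustrative metric with an explicit, optional device parameter."""

    def __init__(self, device=None):
        # None preserves the current behaviour: follow the input tensors
        self.device = torch.device(device) if device is not None else None

    def __call__(self, y_pred, y_true):
        device = self.device if self.device is not None else y_pred.device
        y_pred, y_true = y_pred.to(device), y_true.to(device)
        # auxiliary tensors are created on the chosen device instead of
        # implicitly inheriting whatever device the caller happened to use
        scale = torch.ones(y_pred.shape[-1], device=device)
        return ((y_pred - y_true) * scale).abs().mean()

metric = DeviceAwareMetric(device="cpu")
loss = metric(torch.randn(8, 3), torch.randn(8, 3))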

fkiraly added the "feature request" (New feature or request) label Sep 4, 2024
XinyuWuu (Member) commented Sep 5, 2024

The distribution's device handling is done by Lightning; I am not sure whether we can set it manually:

import lightning.pytorch as pl

# the Trainer's accelerator argument decides which device training runs on
trainer = pl.Trainer(
    accelerator="cpu",
)
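
For context, and as far as I understand Lightning's behaviour: the Trainer moves the LightningModule and each batch onto the selected accelerator's device before the step hooks run, so tensors derived from the inputs inherit that device automatically. A toy module (hypothetical, purely to observe the placement) makes this visible:

import torch
import lightning.pytorch as pl

class Probe(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        (x,) = batch
        # Lightning has already moved both the module and the batch
        assert x.device == self.layer.weight.device
        return self.layer(x).mean()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

loader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(torch.randn(8, 4)), batch_size=4)
pl.Trainer(accelerator="cpu", max_epochs=1, logger=False,
           enable_checkpointing=False).fit(Probe(), loader)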

fnhirwa (Member, Author) commented Sep 5, 2024

I think we should find a better way to handle the accelerator case, as I don't think it would be sensible to hard-code these settings for users: https://github.com/jdb78/pytorch-forecasting/blob/bb6c8a2243c35ca35c2c0e14093d352430fee6d0/pytorch_forecasting/models/temporal_fusion_transformer/tuning.py#L177

We could pass it as an optional parameter instead; see the sketch below.
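
A sketch of that idea (tune_with_accelerator, its name, and its signature are hypothetical, not the current optimize_hyperparameters API): accept an optional accelerator and forward it to the Trainer rather than hard-coding it in the tuning routine.

import lightning.pytorch as pl

# hypothetical wrapper: the name, signature, and defaults are illustrative
def tune_with_accelerator(model, train_dataloader, val_dataloader,
                          accelerator="auto", **trainer_kwargs):
    # "auto" lets Lightning pick an available accelerator; callers who need
    # a specific device can pass accelerator="cpu", "gpu", or "mps"
    trainer = pl.Trainer(accelerator=accelerator, **trainer_kwargs)
    trainer.fit(model, train_dataloaders=train_dataloader,
                val_dataloaders=val_dataloader)
    return trainer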

fkiraly (Collaborator) commented Sep 5, 2024

This is the tuning code, which I think is extraneous to the models. As long as the models themselves do not hard-code the device, I consider this less of a problem. This one tuning routine does a lot of hard-coding and other things extraneous to an otherwise very consistent architectural design.
