How to use find_usable_cuda_devices in Lightning CLI config.yaml?
#18622
Answered by mauvilsa
Jackiexiao asked this question in DDP / multi-GPU / multi-node
For example, I want to use find_usable_cuda_devices(1) instead of choosing an available GPU manually.
Answered by mauvilsa on Sep 26, 2023
Replies: 1 comment
Unfortunately, that is not supported. However, you could use the following workaround. First, subclass the trainer:

    import re

    from lightning.fabric.accelerators import find_usable_cuda_devices
    from lightning.pytorch import Trainer


    class CustomTrainer(Trainer):
        def __init__(self, *args, **kwargs):
            devices = kwargs.get("devices")
            # Only evaluate strings of the exact form find_usable_cuda_devices(<digits>)
            if (
                isinstance(devices, str)
                and re.match(r"^find_usable_cuda_devices\([0-9]*\)$", devices)
            ):
                kwargs["devices"] = eval(devices)
            super().__init__(*args, **kwargs)
Answer selected by Jackiexiao
Then provide trainer_class=CustomTrainer when instantiating the LightningCLI class. In config.yaml, it could then be written as devices: find_usable_cuda_devices(1).
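The workaround above uses eval, guarded by a regular expression. If you would rather avoid eval, the string can be parsed directly. The following is only a minimal sketch: the helper name resolve_devices and its finder parameter are hypothetical (not part of Lightning), and the finder is passed in so the parsing can be tried without a GPU or Lightning installed.

```python
import re
from typing import Any, Callable, List

# Matches strings such as "find_usable_cuda_devices(1)"; the digit count is optional.
_SPEC = re.compile(r"find_usable_cuda_devices\(([0-9]*)\)")


def resolve_devices(value: Any, finder: Callable[..., List[int]]) -> Any:
    """Hypothetical helper: call finder if value has the expected form, else return value unchanged."""
    if isinstance(value, str):
        match = _SPEC.fullmatch(value)
        if match:
            count = match.group(1)
            # Empty parentheses fall back to the finder's own default argument.
            return finder(int(count)) if count else finder()
    return value


if __name__ == "__main__":
    # Stand-in for find_usable_cuda_devices, used only for this demonstration.
    def fake_finder(num_devices: int = -1) -> List[int]:
        return [0] if num_devices == 1 else [0, 1]

    print(resolve_devices("find_usable_cuda_devices(1)", fake_finder))
    print(resolve_devices("auto", fake_finder))
```

Inside CustomTrainer.__init__, the eval line would then become kwargs["devices"] = resolve_devices(devices, find_usable_cuda_devices), with the rest of the workaround unchanged.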