Skip to content

Commit c291c61

Browse files
committed
[Refactor] Use default device instead of CPU in losses
ghstack-source-id: d52131545a36592f1da500cba1a663052094b028 Pull Request resolved: #2687
1 parent e05b160 commit c291c61

File tree

7 files changed

+8
-8
lines changed

7 files changed

+8
-8
lines changed

torchrl/objectives/cql.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -324,7 +324,7 @@ def __init__(
324324
try:
325325
device = next(self.parameters()).device
326326
except AttributeError:
327-
device = torch.device("cpu")
327+
device = getattr(torch, "get_default_device", lambda: torch.device("cpu"))()
328328
self.register_buffer("alpha_init", torch.tensor(alpha_init, device=device))
329329
if bool(min_alpha) ^ bool(max_alpha):
330330
min_alpha = min_alpha if min_alpha else 0.0

torchrl/objectives/crossq.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -307,7 +307,7 @@ def __init__(
307307
try:
308308
device = next(self.parameters()).device
309309
except AttributeError:
310-
device = torch.device("cpu")
310+
device = getattr(torch, "get_default_device", lambda: torch.device("cpu"))()
311311
self.register_buffer("alpha_init", torch.tensor(alpha_init, device=device))
312312
if bool(min_alpha) ^ bool(max_alpha):
313313
min_alpha = min_alpha if min_alpha else 0.0

torchrl/objectives/decision_transformer.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ def __init__(
104104
try:
105105
device = next(self.parameters()).device
106106
except AttributeError:
107-
device = torch.device("cpu")
107+
device = getattr(torch, "get_default_device", lambda: torch.device("cpu"))()
108108

109109
self.register_buffer("alpha_init", torch.tensor(alpha_init, device=device))
110110
if bool(min_alpha) ^ bool(max_alpha):

torchrl/objectives/deprecated.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,7 @@ def __init__(
203203
try:
204204
device = next(self.parameters()).device
205205
except AttributeError:
206-
device = torch.device("cpu")
206+
device = getattr(torch, "get_default_device", lambda: torch.device("cpu"))()
207207

208208
self.register_buffer("alpha_init", torch.as_tensor(alpha_init, device=device))
209209
self.register_buffer(

torchrl/objectives/ppo.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -393,7 +393,7 @@ def __init__(
393393
try:
394394
device = next(self.parameters()).device
395395
except (AttributeError, StopIteration):
396-
device = torch.device("cpu")
396+
device = getattr(torch, "get_default_device", lambda: torch.device("cpu"))()
397397

398398
self.register_buffer("entropy_coef", torch.tensor(entropy_coef, device=device))
399399
if critic_coef is not None:

torchrl/objectives/redq.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -319,7 +319,7 @@ def __init__(
319319
try:
320320
device = next(self.parameters()).device
321321
except AttributeError:
322-
device = torch.device("cpu")
322+
device = getattr(torch, "get_default_device", lambda: torch.device("cpu"))()
323323

324324
self.register_buffer("alpha_init", torch.tensor(alpha_init, device=device))
325325
self.register_buffer(

torchrl/objectives/sac.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -394,7 +394,7 @@ def __init__(
394394
try:
395395
device = next(self.parameters()).device
396396
except AttributeError:
397-
device = torch.device("cpu")
397+
device = getattr(torch, "get_default_device", lambda: torch.device("cpu"))()
398398
self.register_buffer("alpha_init", torch.tensor(alpha_init, device=device))
399399
if bool(min_alpha) ^ bool(max_alpha):
400400
min_alpha = min_alpha if min_alpha else 0.0
@@ -1121,7 +1121,7 @@ def __init__(
11211121
try:
11221122
device = next(self.parameters()).device
11231123
except AttributeError:
1124-
device = torch.device("cpu")
1124+
device = getattr(torch, "get_default_device", lambda: torch.device("cpu"))()
11251125

11261126
self.register_buffer("alpha_init", torch.tensor(alpha_init, device=device))
11271127
if bool(min_alpha) ^ bool(max_alpha):

0 commit comments

Comments
 (0)