File tree 7 files changed +8
-8
lines changed
7 files changed +8
-8
lines changed Original file line number Diff line number Diff line change @@ -324,7 +324,7 @@ def __init__(
324
324
try :
325
325
device = next (self .parameters ()).device
326
326
except AttributeError :
327
- device = torch .device ("cpu" )
327
+ device = getattr ( torch , "get_default_device" , lambda : torch .device ("cpu" ))( )
328
328
self .register_buffer ("alpha_init" , torch .tensor (alpha_init , device = device ))
329
329
if bool (min_alpha ) ^ bool (max_alpha ):
330
330
min_alpha = min_alpha if min_alpha else 0.0
Original file line number Diff line number Diff line change @@ -307,7 +307,7 @@ def __init__(
307
307
try :
308
308
device = next (self .parameters ()).device
309
309
except AttributeError :
310
- device = torch .device ("cpu" )
310
+ device = getattr ( torch , "get_default_device" , lambda : torch .device ("cpu" ))( )
311
311
self .register_buffer ("alpha_init" , torch .tensor (alpha_init , device = device ))
312
312
if bool (min_alpha ) ^ bool (max_alpha ):
313
313
min_alpha = min_alpha if min_alpha else 0.0
Original file line number Diff line number Diff line change @@ -104,7 +104,7 @@ def __init__(
104
104
try :
105
105
device = next (self .parameters ()).device
106
106
except AttributeError :
107
- device = torch .device ("cpu" )
107
+ device = getattr ( torch , "get_default_device" , lambda : torch .device ("cpu" ))( )
108
108
109
109
self .register_buffer ("alpha_init" , torch .tensor (alpha_init , device = device ))
110
110
if bool (min_alpha ) ^ bool (max_alpha ):
Original file line number Diff line number Diff line change @@ -203,7 +203,7 @@ def __init__(
203
203
try :
204
204
device = next (self .parameters ()).device
205
205
except AttributeError :
206
- device = torch .device ("cpu" )
206
+ device = getattr ( torch , "get_default_device" , lambda : torch .device ("cpu" ))( )
207
207
208
208
self .register_buffer ("alpha_init" , torch .as_tensor (alpha_init , device = device ))
209
209
self .register_buffer (
Original file line number Diff line number Diff line change @@ -393,7 +393,7 @@ def __init__(
393
393
try :
394
394
device = next (self .parameters ()).device
395
395
except (AttributeError , StopIteration ):
396
- device = torch .device ("cpu" )
396
+ device = getattr ( torch , "get_default_device" , lambda : torch .device ("cpu" ))( )
397
397
398
398
self .register_buffer ("entropy_coef" , torch .tensor (entropy_coef , device = device ))
399
399
if critic_coef is not None :
Original file line number Diff line number Diff line change @@ -319,7 +319,7 @@ def __init__(
319
319
try :
320
320
device = next (self .parameters ()).device
321
321
except AttributeError :
322
- device = torch .device ("cpu" )
322
+ device = getattr ( torch , "get_default_device" , lambda : torch .device ("cpu" ))( )
323
323
324
324
self .register_buffer ("alpha_init" , torch .tensor (alpha_init , device = device ))
325
325
self .register_buffer (
Original file line number Diff line number Diff line change @@ -394,7 +394,7 @@ def __init__(
394
394
try :
395
395
device = next (self .parameters ()).device
396
396
except AttributeError :
397
- device = torch .device ("cpu" )
397
+ device = getattr ( torch , "get_default_device" , lambda : torch .device ("cpu" ))( )
398
398
self .register_buffer ("alpha_init" , torch .tensor (alpha_init , device = device ))
399
399
if bool (min_alpha ) ^ bool (max_alpha ):
400
400
min_alpha = min_alpha if min_alpha else 0.0
@@ -1121,7 +1121,7 @@ def __init__(
1121
1121
try :
1122
1122
device = next (self .parameters ()).device
1123
1123
except AttributeError :
1124
- device = torch .device ("cpu" )
1124
+ device = getattr ( torch , "get_default_device" , lambda : torch .device ("cpu" ))( )
1125
1125
1126
1126
self .register_buffer ("alpha_init" , torch .tensor (alpha_init , device = device ))
1127
1127
if bool (min_alpha ) ^ bool (max_alpha ):
You can’t perform that action at this time.
0 commit comments