@@ -216,9 +216,9 @@ def fix_symbolic(self, i, j, fun_name, x=None, y=None, random=False, a_range=(-1
216
216
self .funs [j ][i ] = fun
217
217
self .funs_avoid_singularity [j ][i ] = fun_avoid_singularity
218
218
if random == False :
219
- self .affine .data [j ][i ] = torch .tensor ([1. ,0. ,1. ,0. ])
219
+ self .affine .data [j ][i ] = torch .tensor ([1. ,0. ,1. ,0. ], device = self . device )
220
220
else :
221
- self .affine .data [j ][i ] = torch .rand (4 ,) * 2 - 1
221
+ self .affine .data [j ][i ] = torch .rand (4 , device = self . device ) * 2 - 1
222
222
return None
223
223
else :
224
224
#initialize from x & y and fun
@@ -237,9 +237,9 @@ def fix_symbolic(self, i, j, fun_name, x=None, y=None, random=False, a_range=(-1
237
237
self .funs [j ][i ] = fun
238
238
self .funs_avoid_singularity [j ][i ] = fun
239
239
if random == False :
240
- self .affine .data [j ][i ] = torch .tensor ([1. ,0. ,1. ,0. ])
240
+ self .affine .data [j ][i ] = torch .tensor ([1. ,0. ,1. ,0. ], device = self . device )
241
241
else :
242
- self .affine .data [j ][i ] = torch .rand (4 ,) * 2 - 1
242
+ self .affine .data [j ][i ] = torch .rand (4 , device = self . device ) * 2 - 1
243
243
return None
244
244
245
245
def swap (self , i1 , i2 , mode = 'in' ):
0 commit comments