Skip to content

Commit 173dadd

Browse files
authored
Merge pull request #426 from ironjr/master
Fixed device mismatch error of Symbolic_KANLayer
2 parents c0d9981 + e29c384 commit 173dadd

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

kan/Symbolic_KANLayer.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -216,9 +216,9 @@ def fix_symbolic(self, i, j, fun_name, x=None, y=None, random=False, a_range=(-1
216216
self.funs[j][i] = fun
217217
self.funs_avoid_singularity[j][i] = fun_avoid_singularity
218218
if random == False:
219-
self.affine.data[j][i] = torch.tensor([1.,0.,1.,0.])
219+
self.affine.data[j][i] = torch.tensor([1.,0.,1.,0.], device=self.device)
220220
else:
221-
self.affine.data[j][i] = torch.rand(4,) * 2 - 1
221+
self.affine.data[j][i] = torch.rand(4, device=self.device) * 2 - 1
222222
return None
223223
else:
224224
#initialize from x & y and fun
@@ -237,9 +237,9 @@ def fix_symbolic(self, i, j, fun_name, x=None, y=None, random=False, a_range=(-1
237237
self.funs[j][i] = fun
238238
self.funs_avoid_singularity[j][i] = fun
239239
if random == False:
240-
self.affine.data[j][i] = torch.tensor([1.,0.,1.,0.])
240+
self.affine.data[j][i] = torch.tensor([1.,0.,1.,0.], device=self.device)
241241
else:
242-
self.affine.data[j][i] = torch.rand(4,) * 2 - 1
242+
self.affine.data[j][i] = torch.rand(4, device=self.device) * 2 - 1
243243
return None
244244

245245
def swap(self, i1, i2, mode='in'):

0 commit comments

Comments
 (0)