Skip to content

Commit ecde4ec

Browse files
authored
Merge pull request #516 from taddyb/master
Added support for saving checkpoints with CUDA device numbers
2 parents 0a452a0 + 1625986 commit ecde4ec

File tree

1 file changed

+3
-0
lines changed

1 file changed

+3
-0
lines changed

kan/MultKAN.py

+3
Original file line numberDiff line numberDiff line change
@@ -534,6 +534,9 @@ def saveckpt(self, path='model'):
534534
round = model.round,
535535
device = str(model.device)
536536
)
537+
538+
if dic["device"].isdigit():
539+
dic["device"] = int(model.device)
537540

538541
for i in range (model.depth):
539542
dic[f'symbolic.funs_name.{i}'] = model.symbolic_fun[i].funs_name

0 commit comments

Comments
 (0)