Merge pull request #516 from taddyb/master
Some checks failed
docs / Docs (push) Has been cancelled

Added support for saving checkpoints with CUDA device numbers
This commit is contained in:
Ziming Liu 2025-01-19 13:49:39 -05:00 committed by GitHub
commit ecde4ec327
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -534,6 +534,9 @@ class MultKAN(nn.Module):
round = model.round,
device = str(model.device)
)
if dic["device"].isdigit():
dic["device"] = int(model.device)
for i in range (model.depth):
dic[f'symbolic.funs_name.{i}'] = model.symbolic_fun[i].funs_name