Compare commits

...

4 Commits

Author SHA1 Message Date
Ziming Liu
ecde4ec327
Merge pull request from taddyb/master
Added support for saving checkpoints with CUDA device numbers
2025-01-19 13:49:39 -05:00
Tadd Bindas
162598654b added change to fix device num checkpoint 2025-01-19 11:47:28 -06:00
Ziming Liu
0a452a0739
Merge pull request from Timoniche/patch-1
train -> fit update in the new KAN version
2025-01-12 17:00:03 -05:00
Dmitrii Dulaev
406d8904f2
train -> fit update in the new KAN version 2025-01-13 00:23:39 +03:00
2 changed files with 6 additions and 3 deletions

@ -208,7 +208,7 @@
],
"source": [
"# train the model\n",
"model.train(dataset, opt=\"LBFGS\", steps=20, lamb=0.01, lamb_entropy=10.);"
"model.fit(dataset, opt=\"LBFGS\", steps=20, lamb=0.01, lamb_entropy=10.);"
]
},
{
@ -324,7 +324,7 @@
}
],
"source": [
"model.train(dataset, opt=\"LBFGS\", steps=50);"
"model.fit(dataset, opt=\"LBFGS\", steps=50);"
]
},
{
@ -409,7 +409,7 @@
}
],
"source": [
"model.train(dataset, opt=\"LBFGS\", steps=50);"
"model.fit(dataset, opt=\"LBFGS\", steps=50);"
]
},
{

@ -534,6 +534,9 @@ class MultKAN(nn.Module):
round = model.round,
device = str(model.device)
)
if dic["device"].isdigit():
dic["device"] = int(model.device)
for i in range (model.depth):
dic[f'symbolic.funs_name.{i}'] = model.symbolic_fun[i].funs_name