Update KANLayer.py

This commit is contained in:
Ziming Liu 2024-07-14 08:45:40 -04:00 committed by GitHub
parent d0ea020df6
commit 5e18f1b419
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -129,7 +129,7 @@ class KANLayer(nn.Module):
else:
mask = 1.
scale_base = scale_base.to(device)
#scale_base = scale_base.to(device)
self.scale_base = torch.nn.Parameter(torch.ones(in_dim, out_dim, device=device) * scale_base * mask).requires_grad_(sb_trainable) # make scale trainable
#else:
#self.scale_base = torch.nn.Parameter(scale_base.to(device)).requires_grad_(sb_trainable)