Merge pull request #426 from ironjr/master

Fixed device mismatch error of Symbolic_KANLayer
This commit is contained in:
Ziming Liu 2024-08-27 23:01:01 -04:00 committed by GitHub
commit 173dadda82
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -216,9 +216,9 @@ class Symbolic_KANLayer(nn.Module):
self.funs[j][i] = fun
self.funs_avoid_singularity[j][i] = fun_avoid_singularity
if random == False:
self.affine.data[j][i] = torch.tensor([1.,0.,1.,0.])
self.affine.data[j][i] = torch.tensor([1.,0.,1.,0.], device=self.device)
else:
self.affine.data[j][i] = torch.rand(4,) * 2 - 1
self.affine.data[j][i] = torch.rand(4, device=self.device) * 2 - 1
return None
else:
#initialize from x & y and fun
@ -237,9 +237,9 @@ class Symbolic_KANLayer(nn.Module):
self.funs[j][i] = fun
self.funs_avoid_singularity[j][i] = fun
if random == False:
self.affine.data[j][i] = torch.tensor([1.,0.,1.,0.])
self.affine.data[j][i] = torch.tensor([1.,0.,1.,0.], device=self.device)
else:
self.affine.data[j][i] = torch.rand(4,) * 2 - 1
self.affine.data[j][i] = torch.rand(4, device=self.device) * 2 - 1
return None
def swap(self, i1, i2, mode='in'):