Merge pull request #426 from ironjr/master
Fixed device mismatch error of Symbolic_KANLayer
This commit is contained in:
commit
173dadda82
@ -216,9 +216,9 @@ class Symbolic_KANLayer(nn.Module):
|
||||
self.funs[j][i] = fun
|
||||
self.funs_avoid_singularity[j][i] = fun_avoid_singularity
|
||||
if random == False:
|
||||
self.affine.data[j][i] = torch.tensor([1.,0.,1.,0.])
|
||||
self.affine.data[j][i] = torch.tensor([1.,0.,1.,0.], device=self.device)
|
||||
else:
|
||||
self.affine.data[j][i] = torch.rand(4,) * 2 - 1
|
||||
self.affine.data[j][i] = torch.rand(4, device=self.device) * 2 - 1
|
||||
return None
|
||||
else:
|
||||
#initialize from x & y and fun
|
||||
@ -237,9 +237,9 @@ class Symbolic_KANLayer(nn.Module):
|
||||
self.funs[j][i] = fun
|
||||
self.funs_avoid_singularity[j][i] = fun
|
||||
if random == False:
|
||||
self.affine.data[j][i] = torch.tensor([1.,0.,1.,0.])
|
||||
self.affine.data[j][i] = torch.tensor([1.,0.,1.,0.], device=self.device)
|
||||
else:
|
||||
self.affine.data[j][i] = torch.rand(4,) * 2 - 1
|
||||
self.affine.data[j][i] = torch.rand(4, device=self.device) * 2 - 1
|
||||
return None
|
||||
|
||||
def swap(self, i1, i2, mode='in'):
|
||||
|
Loading…
x
Reference in New Issue
Block a user