From e29c38446c8a663a317c613a054fb60ed973bcf3 Mon Sep 17 00:00:00 2001 From: Jaerin Lee Date: Tue, 27 Aug 2024 09:15:30 +0900 Subject: [PATCH] fixed device mismatch error of symbolic fitting --- kan/Symbolic_KANLayer.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/kan/Symbolic_KANLayer.py b/kan/Symbolic_KANLayer.py index 51baf0a..3b19929 100644 --- a/kan/Symbolic_KANLayer.py +++ b/kan/Symbolic_KANLayer.py @@ -216,9 +216,9 @@ class Symbolic_KANLayer(nn.Module): self.funs[j][i] = fun self.funs_avoid_singularity[j][i] = fun_avoid_singularity if random == False: - self.affine.data[j][i] = torch.tensor([1.,0.,1.,0.]) + self.affine.data[j][i] = torch.tensor([1.,0.,1.,0.], device=self.device) else: - self.affine.data[j][i] = torch.rand(4,) * 2 - 1 + self.affine.data[j][i] = torch.rand(4, device=self.device) * 2 - 1 return None else: #initialize from x & y and fun @@ -237,9 +237,9 @@ class Symbolic_KANLayer(nn.Module): self.funs[j][i] = fun self.funs_avoid_singularity[j][i] = fun if random == False: - self.affine.data[j][i] = torch.tensor([1.,0.,1.,0.]) + self.affine.data[j][i] = torch.tensor([1.,0.,1.,0.], device=self.device) else: - self.affine.data[j][i] = torch.rand(4,) * 2 - 1 + self.affine.data[j][i] = torch.rand(4, device=self.device) * 2 - 1 return None def swap(self, i1, i2, mode='in'):