Update MultKAN.py remove conflict

This commit is contained in:
Ziming Liu 2024-07-14 08:22:16 -04:00 committed by GitHub
parent 827eab4f90
commit 985c5eb71d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -803,13 +803,8 @@ class MultKAN(nn.Module):
train_id = np.random.choice(dataset['train_input'].shape[0], batch_size, replace=False)
test_id = np.random.choice(dataset['test_input'].shape[0], batch_size_test, replace=False)
<<<<<<< HEAD
if _ % grid_update_freq == 0 and _ < stop_grid_update_step and update_grid and _ >= start_grid_update_step:
self.update_grid_from_samples(dataset['train_input'][train_id].to(device))
=======
if _ % grid_update_freq == 0 and _ < stop_grid_update_step and update_grid and _ > start_grid_update_step:
self.update_grid_from_samples(dataset['train_input'][train_id].to(self.device))
>>>>>>> e5abcb74c6cdc6af70d9665ec5a7b4ccb94ee564
if opt == "LBFGS":
optimizer.step(closure)
@ -1512,4 +1507,4 @@ class MultKAN(nn.Module):
plt.plot(inputs, outputs, marker="o")
return inputs, outputs
KAN = MultKAN
KAN = MultKAN