diff --git a/kan/MultKAN.py b/kan/MultKAN.py index fe76003..97d66ca 100644 --- a/kan/MultKAN.py +++ b/kan/MultKAN.py @@ -803,13 +803,8 @@ class MultKAN(nn.Module): train_id = np.random.choice(dataset['train_input'].shape[0], batch_size, replace=False) test_id = np.random.choice(dataset['test_input'].shape[0], batch_size_test, replace=False) -<<<<<<< HEAD if _ % grid_update_freq == 0 and _ < stop_grid_update_step and update_grid and _ >= start_grid_update_step: self.update_grid_from_samples(dataset['train_input'][train_id].to(device)) -======= - if _ % grid_update_freq == 0 and _ < stop_grid_update_step and update_grid and _ > start_grid_update_step: - self.update_grid_from_samples(dataset['train_input'][train_id].to(self.device)) ->>>>>>> e5abcb74c6cdc6af70d9665ec5a7b4ccb94ee564 if opt == "LBFGS": optimizer.step(closure) @@ -1512,4 +1507,4 @@ class MultKAN(nn.Module): plt.plot(inputs, outputs, marker="o") return inputs, outputs -KAN = MultKAN \ No newline at end of file +KAN = MultKAN