176 lines
4.2 KiB
Plaintext
176 lines
4.2 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "134e7f9d",
|
|
"metadata": {},
|
|
"source": [
|
|
"# Demo 10: Device\n",
|
|
"\n",
|
|
"All other demos have by default used device = 'cpu'. In case we want to use cuda, we should pass the device argument to model and dataset."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"id": "7a4ac1e1-84ba-4bc3-91b6-a776a5e7711c",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"cpu\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"from kan import KAN, create_dataset\n",
|
|
"import torch\n",
|
|
"\n",
|
|
"\n",
|
|
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
|
|
"print(device)\n",
|
|
"\n",
|
|
"#device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
|
|
"device = 'cpu'\n",
|
|
"print(device)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 2,
|
|
"id": "2075ef56",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"checkpoint directory created: ./model\n",
|
|
"saving model version 0.0\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"| train_loss: 6.83e-01 | test_loss: 7.21e-01 | reg: 1.04e+03 | : 100%|█| 50/50 [00:19<00:00, 2.56it\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"saving model version 0.1\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"model = KAN(width=[4,100,100,100,1], grid=3, k=3, seed=0).to(device)\n",
|
|
"f = lambda x: torch.exp((torch.sin(torch.pi*(x[:,[0]]**2+x[:,[1]]**2))+torch.sin(torch.pi*(x[:,[2]]**2+x[:,[3]]**2)))/2)\n",
|
|
"dataset = create_dataset(f, n_var=4, train_num=1000, device=device)\n",
|
|
"\n",
|
|
"# train the model\n",
|
|
"#model.train(dataset, opt=\"LBFGS\", steps=20, lamb=1e-3, lamb_entropy=2.);\n",
|
|
"model.fit(dataset, opt=\"Adam\", lr=1e-3, steps=50, lamb=1e-3, lamb_entropy=5., update_grid=False);"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "2f182cc1-51bf-4151-a253-a52fe854919e",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 3,
|
|
"id": "f6f8125e-d26d-4c97-9e5f-988099bb4737",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"cuda\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"device = 'cuda'\n",
|
|
"print(device)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 4,
|
|
"id": "95017dfa-3a2a-43e0-8b68-fb220ca5abc9",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"checkpoint directory created: ./model\n",
|
|
"saving model version 0.0\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"| train_loss: 6.83e-01 | test_loss: 7.21e-01 | reg: 1.04e+03 | : 100%|█| 50/50 [00:01<00:00, 26.45it\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"saving model version 0.1\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"model = KAN(width=[4,100,100,100,1], grid=3, k=3, seed=0).to(device)\n",
|
|
"f = lambda x: torch.exp((torch.sin(torch.pi*(x[:,[0]]**2+x[:,[1]]**2))+torch.sin(torch.pi*(x[:,[2]]**2+x[:,[3]]**2)))/2)\n",
|
|
"dataset = create_dataset(f, n_var=4, train_num=1000, device=device)\n",
|
|
"\n",
|
|
"# train the model\n",
|
|
"#model.train(dataset, opt=\"LBFGS\", steps=20, lamb=1e-3, lamb_entropy=2.);\n",
|
|
"model.fit(dataset, opt=\"Adam\", lr=1e-3, steps=50, lamb=1e-3, lamb_entropy=5., update_grid=False);"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "8230d562-2635-4adc-b566-06ac679b166a",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3 (ipykernel)",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.9.16"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
}
|