{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "095b0666",
   "metadata": {},
   "source": [
    "# Example 11: Encouraging linearity\n",
    "\n",
    "When we don't know how deep a KAN should be, one strategy is to start from small models and gradually make them wider/deeper until we find the smallest model that performs the task well. Another strategy is to start from a sufficiently large model and prune it down. This Jupyter notebook demonstrates the second strategy. Besides sparsity along width, we also want activation functions to become linear (a 'shortcut' along depth)."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ef047a0f",
   "metadata": {},
   "source": [
    "There are two relevant tricks:\n",
    "\n",
    "(1) set the base function `base_fun` to be linear (`base_fun='identity'`);\n",
    "\n",
    "(2) penalize the spline coefficients. With a linear base function, an activation whose spline coefficients are all zero is itself linear."
   ]
  },
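  {
   "cell_type": "markdown",
   "id": "3f6c2a1e",
   "metadata": {},
   "source": [
    "The following plain NumPy sketch (not pykan code) illustrates why the two tricks together make an activation linear. It assumes the activation form $\\phi(x)=w_b\\,b(x)+w_s\\,{\\rm spline}(x)$ from the KAN paper, with the spline written as $\\sum_i c_i B_i(x)$ over cubic B-splines on a uniform grid extended by $k$ knots on each side. The helpers `bspline_basis`, `kan_activation` and `max_deviation_from_line` are hypothetical. With an identity base and all $c_i=0$, $\\phi$ is exactly linear; a SiLU base or nonzero coefficients make it curved."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7b9d4e20",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def bspline_basis(x, grid_size=5, k=3, lo=-1.0, hi=1.0):\n",
    "    # Cox-de Boor recursion on a uniform grid extended by k knots per side;\n",
    "    # returns an array of shape (len(x), grid_size + k)\n",
    "    h = (hi - lo) / grid_size\n",
    "    knots = lo + h * np.arange(-k, grid_size + k + 1)\n",
    "    x = np.asarray(x, dtype=float)[:, None]\n",
    "    # degree 0: indicator of each knot interval\n",
    "    B = ((x >= knots[:-1]) & (x < knots[1:])).astype(float)\n",
    "    for p in range(1, k + 1):\n",
    "        left = (x - knots[:-(p + 1)]) / (knots[p:-1] - knots[:-(p + 1)])\n",
    "        right = (knots[p + 1:] - x) / (knots[p + 1:] - knots[1:-p])\n",
    "        B = left * B[:, :-1] + right * B[:, 1:]\n",
    "    return B\n",
    "\n",
    "def kan_activation(x, coef, base_fun, w_base=1.0, w_spline=1.0):\n",
    "    # phi(x) = w_b * b(x) + w_s * sum_i c_i B_i(x)\n",
    "    return w_base * base_fun(x) + w_spline * bspline_basis(x) @ coef\n",
    "\n",
    "def max_deviation_from_line(x, y):\n",
    "    # distance of y from its least-squares straight-line fit\n",
    "    slope, intercept = np.polyfit(x, y, 1)\n",
    "    return np.max(np.abs(y - (slope * x + intercept)))\n",
    "\n",
    "identity = lambda t: t\n",
    "silu = lambda t: t / (1.0 + np.exp(-t))\n",
    "\n",
    "x = np.linspace(-1.0, 1.0, 201)\n",
    "zero_coef = np.zeros(8)  # grid_size + k = 5 + 3 coefficients\n",
    "rand_coef = 0.1 * np.random.default_rng(0).standard_normal(8)\n",
    "\n",
    "cases = {\n",
    "    'identity base, zero coefficients': kan_activation(x, zero_coef, identity, w_base=0.7),\n",
    "    'SiLU base, zero coefficients': kan_activation(x, zero_coef, silu, w_base=0.7),\n",
    "    'identity base, random coefficients': kan_activation(x, rand_coef, identity, w_base=0.7),\n",
    "}\n",
    "for name, y in cases.items():\n",
    "    print(f'{name}: max deviation from a line = {max_deviation_from_line(x, y):.2e}')"
   ]
  },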
  {
   "cell_type": "markdown",
   "id": "91301ca0",
   "metadata": {},
   "source": [
    "Target function: $f(x)={\\rm sin}(\\pi x)$. Although a [1,1] KAN suffices, we pretend not to know this and use a [1,1,1,1] KAN instead."
   ]
  },
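  {
   "cell_type": "markdown",
   "id": "c4e81f37",
   "metadata": {},
   "source": [
    "Why linear activations act as a 'shortcut' along depth: composing affine maps gives another affine map, so once the extra activations in a [1,1,1,1] KAN become affine they can be folded into a single edge. This NumPy sketch uses hypothetical coefficients for an idealized width-1 chain (it ignores any extra node-level scales a real KAN may carry); `compose_affine`, `deep_chain` and `single_edge` are illustrative names only. With these particular numbers the folded map is, up to rounding, the identity, so the chain computes ${\\rm sin}(\\pi x)$ exactly as a [1,1] KAN would."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0a5d9b6c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def compose_affine(outer, inner):\n",
    "    # outer(inner(x)) with inner(x) = a1*x + b1 and outer(t) = a2*t + b2\n",
    "    # is again affine: (a2*a1)*x + (a2*b1 + b2)\n",
    "    a2, b2 = outer\n",
    "    a1, b1 = inner\n",
    "    return a2 * a1, a2 * b1 + b2\n",
    "\n",
    "# hypothetical [1,1,1,1] chain: two affine activations, then sin(pi * t)\n",
    "phi1 = (0.8, 0.1)      # phi1(x) = 0.8*x + 0.1\n",
    "phi2 = (1.25, -0.125)  # phi2(t) = 1.25*t - 0.125\n",
    "\n",
    "def deep_chain(x):\n",
    "    t = phi1[0] * x + phi1[1]\n",
    "    t = phi2[0] * t + phi2[1]\n",
    "    return np.sin(np.pi * t)\n",
    "\n",
    "# fold the two affine layers into one, leaving a single-edge ([1,1]) model\n",
    "a, b = compose_affine(phi2, phi1)\n",
    "\n",
    "def single_edge(x):\n",
    "    return np.sin(np.pi * (a * x + b))\n",
    "\n",
    "x = np.linspace(-1.0, 1.0, 101)\n",
    "print('folded affine map: a =', a, ', b =', b)\n",
    "print('max |deep - single|:', np.max(np.abs(deep_chain(x) - single_edge(x))))"
   ]
  },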
  {
   "cell_type": "markdown",
   "id": "77f9e16d",
   "metadata": {},
   "source": [
    "Without the tricks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "c881665b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cuda\n",
      "checkpoint directory created: ./model\n",
      "saving model version 0.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "| train_loss: 3.74e-04 | test_loss: 3.84e-04 | reg: 8.88e+00 | : 100%|█| 20/20 [00:05<00:00, 3.79it"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "saving model version 0.1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "from kan import *\n",
    "\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "print(device)\n",
    "\n",
    "# create dataset f(x) = sin(pi*x). This task can be achieved by a [1,1] KAN\n",
    "f = lambda x: torch.sin(torch.pi*x[:,[0]])\n",
    "dataset = create_dataset(f, n_var=1, device=device)\n",
    "\n",
    "# baseline: default base function, no spline-coefficient penalty\n",
    "model = KAN(width=[1,1,1,1], grid=5, k=3, seed=0, noise_scale=0.1, device=device)\n",
    "\n",
    "model.fit(dataset, opt=\"LBFGS\", steps=20);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "201ceacf",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "",
      "text/plain": [
       "<Figure size 500x600 with 7 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "model.plot()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "13c725a5",
   "metadata": {},
   "source": [
    "With the tricks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "a22ffff3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "checkpoint directory created: ./model\n",
      "saving model version 0.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "| train_loss: 8.89e-03 | test_loss: 8.40e-03 | reg: 1.83e+01 | : 100%|█| 20/20 [00:04<00:00, 4.20it"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "saving model version 0.1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "from kan import *\n",
    "\n",
    "# create dataset f(x) = sin(pi*x). This task can be achieved by a [1,1] KAN\n",
    "f = lambda x: torch.sin(torch.pi*x[:,[0]])\n",
    "dataset = create_dataset(f, n_var=1, device=device)\n",
    "\n",
    "# trick (1): set base_fun to be linear\n",
    "model = KAN(width=[1,1,1,1], grid=5, k=3, seed=0, base_fun='identity', noise_scale=0.1, device=device)\n",
    "\n",
    "# trick (2): penalize spline coefficients (lamb: overall regularization strength, lamb_coef: weight of the spline-coefficient penalty)\n",
    "model.fit(dataset, opt=\"LBFGS\", steps=20, lamb=1e-4, lamb_coef=10.0);"
   ]
  },
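  {
   "cell_type": "markdown",
   "id": "e2f7a8d1",
   "metadata": {},
   "source": [
    "To see why penalizing spline coefficients pushes an activation toward linearity, here is a toy NumPy sketch. It is not pykan's procedure: the notebook's `fit` call optimizes with LBFGS, whereas the sketch swaps in proximal gradient descent (ISTA), which sets small coefficients exactly to zero. It also assumes an identity base plus degree-1 B-splines (hat functions) instead of the cubic splines used above, a linear target $y=2x$, and an L1 penalty of weight `lam` on the spline coefficients only; `fit_identity_plus_spline` is a hypothetical helper. Because the hat functions can also represent $x$, the unpenalized fit (`lam=0`) splits the work between base and spline, while the penalized fit leaves it all to the linear base."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9c3b5e74",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def fit_identity_plus_spline(x, y, lam, steps=5000, grid_size=5):\n",
    "    # model: phi(x) = w_b * x + sum_i c_i * hat_i(x) on a uniform grid over [-1, 1]\n",
    "    grid = np.linspace(-1.0, 1.0, grid_size + 1)\n",
    "    h = grid[1] - grid[0]\n",
    "    hats = np.maximum(0.0, 1.0 - np.abs(x[:, None] - grid[None, :]) / h)\n",
    "    X = np.column_stack([x, hats])  # column 0: identity base, rest: spline basis\n",
    "    n = len(x)\n",
    "    # step size 1/L, where L is the Lipschitz constant of the MSE gradient\n",
    "    step = n / (2.0 * np.linalg.norm(X, 2) ** 2)\n",
    "    theta = np.zeros(X.shape[1])\n",
    "    for _ in range(steps):\n",
    "        # gradient step on the mean squared error\n",
    "        theta = theta - step * (2.0 / n) * X.T @ (X @ theta - y)\n",
    "        # proximal step for lam * sum|c_i|: soft-threshold the spline\n",
    "        # coefficients only; the base weight w_b is not penalized\n",
    "        c = theta[1:]\n",
    "        theta[1:] = np.sign(c) * np.maximum(np.abs(c) - step * lam, 0.0)\n",
    "    return theta[0], theta[1:]\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "x = rng.uniform(-1.0, 1.0, 200)\n",
    "y = 2.0 * x  # a target that a purely linear activation can fit exactly\n",
    "\n",
    "for lam in (0.0, 0.05):\n",
    "    w_b, coef = fit_identity_plus_spline(x, y, lam)\n",
    "    print(f'lam={lam}: w_b={w_b:.4f}, max |spline coef|={np.max(np.abs(coef)):.2e}')"
   ]
  },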
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "c82c8db5",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "",
      "text/plain": [
       "<Figure size 500x600 with 7 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "model.plot(beta=10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e3c92b0d",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}