91 lines
2.5 KiB
Plaintext
91 lines
2.5 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "134e7f9d",
|
|
"metadata": {},
|
|
"source": [
|
|
"# API Demo 9: Videos of KAN training\n",
|
|
"\n",
|
|
"### We have shown one can visualize KAN with the plot() method. If one wants to save the training dynamics of KAN plots, one only needs to pass argument save_video = True to train() method (and set some video related parameters)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"id": "2075ef56",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"train loss: 6.39e-03 | test loss: 6.40e-03 | reg: 7.91e+00 : 100%|██| 50/50 [01:30<00:00, 1.81s/it]\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Moviepy - Building video video.mp4.\n",
|
|
"Moviepy - Writing video video.mp4\n",
|
|
"\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
" \r"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Moviepy - Done !\n",
|
|
"Moviepy - video ready video.mp4\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"from kan import KAN, create_dataset\n",
|
|
"import torch\n",
|
|
"\n",
|
|
"# create a KAN: 2D inputs, 1D output, and 5 hidden neurons. cubic spline (k=3), 5 grid intervals (grid=5).\n",
|
|
"model = KAN(width=[4,2,1,1], grid=3, k=3, seed=0)\n",
|
|
"f = lambda x: torch.exp((torch.sin(torch.pi*(x[:,[0]]**2+x[:,[1]]**2))+torch.sin(torch.pi*(x[:,[2]]**2+x[:,[3]]**2)))/2)\n",
|
|
"dataset = create_dataset(f, n_var=4, train_num=3000)\n",
|
|
"\n",
|
|
"# train the model\n",
|
|
"#model.train(dataset, opt=\"LBFGS\", steps=20, lamb=1e-3, lamb_entropy=2.);\n",
|
|
"model.train(dataset, opt=\"LBFGS\", steps=50, lamb=5e-5, lamb_entropy=2., save_video=True, beta=10, \n",
|
|
" in_vars=[r'$x_1$', r'$x_2$', r'$x_3$', r'$x_4$'],\n",
|
|
" out_vars=[r'${\\rm exp}({\\rm sin}(x_1^2+x_2^2)+{\\rm sin}(x_3^2+x_4^2))$'],\n",
|
|
" video_name='video', fps=5);"
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3 (ipykernel)",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.9.7"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
}
|