2024-08-11 17:13:55 -04:00

159 lines
4.1 KiB
Plaintext

{
"cells": [
{
"cell_type": "markdown",
"id": "134e7f9d",
"metadata": {},
"source": [
"# API 9: Videos\n",
"\n",
"We have shown one can visualize KAN with the plot() method. If one wants to save the training dynamics of KAN plots, one only needs to pass argument save_video = True to train() method (and set some video related parameters)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "2075ef56",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"cuda\n",
"checkpoint directory created: ./model\n",
"saving model version 0.0\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"| train_loss: 2.89e-01 | test_loss: 2.96e-01 | reg: 1.31e+01 | : 100%|█| 5/5 [00:09<00:00, 1.94s/it"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"saving model version 0.1\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"from kan import *\n",
"import torch\n",
"\n",
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
"print(device)\n",
"\n",
"# create a KAN: 2D inputs, 1D output, and 5 hidden neurons. cubic spline (k=3), 5 grid intervals (grid=5).\n",
"model = KAN(width=[4,2,1,1], grid=3, k=3, seed=1, device=device)\n",
"f = lambda x: torch.exp((torch.sin(torch.pi*(x[:,[0]]**2+x[:,[1]]**2))+torch.sin(torch.pi*(x[:,[2]]**2+x[:,[3]]**2)))/2)\n",
"dataset = create_dataset(f, n_var=4, train_num=3000, device=device)\n",
"\n",
"image_folder = 'video_img'\n",
"\n",
"# train the model\n",
"#model.train(dataset, opt=\"LBFGS\", steps=20, lamb=1e-3, lamb_entropy=2.);\n",
"model.fit(dataset, opt=\"LBFGS\", steps=5, lamb=0.001, lamb_entropy=2., save_fig=True, beta=10, \n",
" in_vars=[r'$x_1$', r'$x_2$', r'$x_3$', r'$x_4$'],\n",
" out_vars=[r'${\\rm exp}({\\rm sin}(x_1^2+x_2^2)+{\\rm sin}(x_3^2+x_4^2))$'],\n",
" img_folder=image_folder);\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "c18245a3",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Moviepy - Building video video.mp4.\n",
"Moviepy - Writing video video.mp4\n",
"\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" \r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Moviepy - Done !\n",
"Moviepy - video ready video.mp4\n"
]
}
],
"source": [
"import os\n",
"import numpy as np\n",
"import moviepy.video.io.ImageSequenceClip # moviepy == 1.0.3\n",
"\n",
"video_name='video'\n",
"fps=5\n",
"\n",
"fps = fps\n",
"files = os.listdir(image_folder)\n",
"train_index = []\n",
"for file in files:\n",
" if file[0].isdigit() and file.endswith('.jpg'):\n",
" train_index.append(int(file[:-4]))\n",
"\n",
"train_index = np.sort(train_index)\n",
"\n",
"image_files = [image_folder+'/'+str(train_index[index])+'.jpg' for index in train_index]\n",
"\n",
"clip = moviepy.video.io.ImageSequenceClip.ImageSequenceClip(image_files, fps=fps)\n",
"clip.write_videofile(video_name+'.mp4')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "88d0d737",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.16"
}
},
"nbformat": 4,
"nbformat_minor": 5
}