221 lines
41 KiB
Plaintext
221 lines
41 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "5d904dee",
|
|
"metadata": {},
|
|
"source": [
|
|
"# Example 7: Solving a Partial Differential Equation (PDE)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "7d568912",
|
|
"metadata": {},
|
|
"source": [
|
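|
"We aim to solve the 2D Poisson equation $\\nabla^2 f(x,y) = -2\\pi^2\\sin(\\pi x)\\sin(\\pi y)$ on the square $[-1,1]^2$, with the Dirichlet boundary conditions $f(-1,y)=f(1,y)=f(x,-1)=f(x,1)=0$. The ground-truth solution is $f(x,y)=\\sin(\\pi x)\\sin(\\pi y)$. We train a KAN by minimizing a weighted sum of the mean squared PDE residual on the interior points and the mean squared error on the boundary points."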
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"id": "0e2bc449",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"checkpoint directory created: ./model\n",
|
|
"saving model version 0.0\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"pde loss: 5.89e+00 | bc loss: 5.27e-02 | l2: 1.59e-02 : 100%|█████████| 1/1 [00:07<00:00, 7.07s/it]\n",
|
|
"pde loss: 1.42e+00 | bc loss: 1.99e-02 | l2: 5.40e-03 : 100%|█████████| 1/1 [00:07<00:00, 7.20s/it]\n",
|
|
"pde loss: 3.76e-01 | bc loss: 1.80e-02 | l2: 4.81e-03 : 100%|█████████| 1/1 [00:07<00:00, 7.52s/it]\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"from kan import KAN, LBFGS\n",
|
|
"import torch\n",
|
|
"import matplotlib.pyplot as plt\n",
|
|
"from torch import autograd\n",
|
|
"from tqdm import tqdm\n",
|
|
"\n",
|
|
"dim = 2\n",
|
|
"np_i = 51 # number of interior points (along each dimension)\n",
|
|
"np_b = 51 # number of boundary points (along each dimension); NOTE: not read anywhere in this notebook, the boundary points below are sliced from the np_i mesh\n",
|
|
"ranges = [-1, 1]\n",
|
|
"\n",
|
|
"\n",
|
|
"def batch_jacobian(func, x, create_graph=False):\n",
|
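|
" # x has shape (Batch, Length). Summing func(x) over the batch gives all per-sample Jacobians in one call (assumes func acts on each row independently); the result is permuted to (Batch, Out, Length)\n",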
|
|
" def _func_sum(x):\n",
|
|
" return func(x).sum(dim=0)\n",
|
|
" return autograd.functional.jacobian(_func_sum, x, create_graph=create_graph).permute(1,0,2)\n",
|
|
"\n",
|
|
"# ground-truth solution and source term (right-hand side of the Poisson equation)\n",
|
|
"sol_fun = lambda x: torch.sin(torch.pi*x[:,[0]])*torch.sin(torch.pi*x[:,[1]])\n",
|
|
"source_fun = lambda x: -2*torch.pi**2 * torch.sin(torch.pi*x[:,[0]])*torch.sin(torch.pi*x[:,[1]])\n",
|
|
"\n",
|
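|
"# interior points where the PDE residual is evaluated (in 'mesh' mode the grid also contains the boundary nodes)\n",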
|
|
"sampling_mode = 'mesh' # 'random' or 'mesh'\n",
|
|
"\n",
|
|
"x_mesh = torch.linspace(ranges[0],ranges[1],steps=np_i)\n",
|
|
"y_mesh = torch.linspace(ranges[0],ranges[1],steps=np_i)\n",
|
|
"X, Y = torch.meshgrid(x_mesh, y_mesh, indexing=\"ij\")\n",
|
|
"if sampling_mode == 'mesh':\n",
|
|
" # mesh: flatten the grid into a (np_i**2, 2) tensor of (x, y) points\n",
|
|
" x_i = torch.stack([X.reshape(-1,), Y.reshape(-1,)]).permute(1,0)\n",
|
|
"else:\n",
|
|
" # random: np_i**2 points drawn uniformly from [-1, 1)^2\n",
|
|
" x_i = torch.rand((np_i**2,2))*2-1\n",
|
|
"\n",
|
|
"# boundary points on the 4 sides (x=-1, x=1, y=-1, y=1), sliced from the mesh above\n",
|
|
"helper = lambda X, Y: torch.stack([X.reshape(-1,), Y.reshape(-1,)]).permute(1,0)\n",
|
|
"xb1 = helper(X[0], Y[0])\n",
|
|
"xb2 = helper(X[-1], Y[0])\n",
|
|
"xb3 = helper(X[:,0], Y[:,0])\n",
|
|
"xb4 = helper(X[:,0], Y[:,-1])\n",
|
|
"x_b = torch.cat([xb1, xb2, xb3, xb4], dim=0)\n",
|
|
"\n",
|
|
"alpha = 0.01\n",
|
|
"log = 1\n",
|
|
"\n",
|
|
"\n",
|
|
"grids = [5,10,20]\n",
|
|
"steps = 1\n",
|
|
"\n",
|
|
"pde_losses = []\n",
|
|
"bc_losses = []\n",
|
|
"l2_losses = []\n",
|
|
"\n",
|
|
"for grid in grids:\n",
|
|
" if grid == grids[0]:\n",
|
|
" model = KAN(width=[2,2,1], grid=grid, k=3, seed=3)\n",
|
|
" model = model.speed()\n",
|
|
" else:\n",
|
|
" model.save_act = True\n",
|
|
" model.get_act(x_i)\n",
|
|
" model = model.refine(grid)\n",
|
|
" model = model.speed()\n",
|
|
"\n",
|
|
" def train():\n",
|
|
" optimizer = LBFGS(model.parameters(), lr=1, history_size=10, line_search_fn=\"strong_wolfe\", tolerance_grad=1e-32, tolerance_change=1e-32, tolerance_ys=1e-32)\n",
|
|
"\n",
|
|
" pbar = tqdm(range(steps), desc='description', ncols=100)\n",
|
|
"\n",
|
|
" for _ in pbar:\n",
|
|
" def closure():\n",
|
|
" global pde_loss, bc_loss\n",
|
|
" optimizer.zero_grad()\n",
|
|
" # interior loss: mean squared PDE residual (Laplacian of the model output minus the source term)\n",
|
|
" sol = sol_fun(x_i)\n",
|
|
" sol_D1_fun = lambda x: batch_jacobian(model, x, create_graph=True)[:,0,:]\n",
|
|
" sol_D1 = sol_D1_fun(x_i)\n",
|
|
" sol_D2 = batch_jacobian(sol_D1_fun, x_i, create_graph=True)[:,:,:]\n",
|
|
" lap = torch.sum(torch.diagonal(sol_D2, dim1=1, dim2=2), dim=1, keepdim=True)\n",
|
|
" source = source_fun(x_i)\n",
|
|
" pde_loss = torch.mean((lap - source)**2)\n",
|
|
"\n",
|
|
" # boundary loss: mean squared error between the model and the exact solution on the boundary points\n",
|
|
" bc_true = sol_fun(x_b)\n",
|
|
" bc_pred = model(x_b)\n",
|
|
" bc_loss = torch.mean((bc_pred-bc_true)**2)\n",
|
|
"\n",
|
|
" loss = alpha * pde_loss + bc_loss\n",
|
|
" loss.backward()\n",
|
|
" return loss\n",
|
|
"\n",
|
|
" if _ % 5 == 0 and _ < 20:\n",
|
|
" model.update_grid_from_samples(x_i)\n",
|
|
"\n",
|
|
" optimizer.step(closure)\n",
|
|
" sol = sol_fun(x_i)\n",
|
|
" loss = alpha * pde_loss + bc_loss\n",
|
|
" l2 = torch.mean((model(x_i) - sol)**2)\n",
|
|
"\n",
|
|
" if _ % log == 0:\n",
|
|
" pbar.set_description(\"pde loss: %.2e | bc loss: %.2e | l2: %.2e \" % (pde_loss.cpu().detach().numpy(), bc_loss.cpu().detach().numpy(), l2.detach().numpy()))\n",
|
|
"\n",
|
|
" pde_losses.append(pde_loss.detach().numpy())\n",
|
|
" bc_losses.append(bc_loss.detach().numpy())\n",
|
|
" l2_losses.append(l2.detach().numpy())\n",
|
|
" \n",
|
|
" \n",
|
|
" train()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 2,
|
|
"id": "dcbfa677",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"<matplotlib.legend.Legend at 0x7faaf9228070>"
|
|
]
|
|
},
|
|
"execution_count": 2,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
},
|
|
{
|
|
"data": {
|
|
"image/png": "\n",
|
|
"text/plain": [
|
|
"<Figure size 640x480 with 1 Axes>"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
|
"plt.plot(pde_losses, marker='o')\n",
|
|
"plt.plot(bc_losses, marker='o')\n",
|
|
"plt.plot(l2_losses, marker='o')\n",
|
|
"plt.yscale('log')\n",
|
|
"plt.xlabel('steps')\n",
|
|
"plt.legend(['PDE loss', 'BC loss', 'L2 squared'])"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "bce40477",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3 (ipykernel)",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.9.7"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
}
|