2024-04-29 12:35:18 -04:00
{
"cells": [
{
"cell_type": "markdown",
"id": "5d904dee",
"metadata": {},
"source": [
"# Example 8: KANs' Scaling Laws"
]
},
{
"cell_type": "markdown",
"id": "6465ec94",
"metadata": {},
"source": [
"In this example, we study KAN's scaling laws with respect to model size (controlled here by the grid size) and training data size."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "a1c25e8a",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"data_size=100\n",
"grid_size=5\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
2024-07-13 22:17:48 -04:00
"train loss: 5.45e-03 | test loss: 7.44e-03 | reg: 0.00e+00 : 100%|██| 50/50 [00:07<00:00, 6.28it/s]\n"
2024-04-29 12:35:18 -04:00
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"grid_size=10\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
2024-07-13 22:17:48 -04:00
"train loss: 2.33e-04 | test loss: 1.38e-03 | reg: 0.00e+00 : 100%|██| 50/50 [00:04<00:00, 11.46it/s]\n"
2024-04-29 12:35:18 -04:00
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"grid_size=20\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
2024-07-13 22:17:48 -04:00
"train loss: 3.80e-05 | test loss: 7.60e-03 | reg: 0.00e+00 : 100%|██| 50/50 [00:03<00:00, 15.35it/s]\n"
2024-04-29 12:35:18 -04:00
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"grid_size=50\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
2024-07-13 22:17:48 -04:00
"train loss: 5.61e-01 | test loss: 1.51e+00 | reg: 0.00e+00 : 100%|██| 50/50 [00:07<00:00, 6.91it/s]\n"
2024-04-29 12:35:18 -04:00
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"grid_size=100\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
2024-07-13 22:17:48 -04:00
"train loss: 7.86e-02 | test loss: 1.19e+00 | reg: 0.00e+00 : 100%|██| 50/50 [00:07<00:00, 6.87it/s]\n"
2024-04-29 12:35:18 -04:00
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"data_size=300\n",
"grid_size=5\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
2024-07-13 22:17:48 -04:00
"train loss: 5.68e-03 | test loss: 6.18e-03 | reg: 0.00e+00 : 100%|██| 50/50 [00:04<00:00, 12.46it/s]\n"
2024-04-29 12:35:18 -04:00
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"grid_size=10\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
2024-07-13 22:17:48 -04:00
"train loss: 2.85e-04 | test loss: 3.33e-04 | reg: 0.00e+00 : 100%|██| 50/50 [00:04<00:00, 11.53it/s]\n"
2024-04-29 12:35:18 -04:00
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"grid_size=20\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
2024-07-13 22:17:48 -04:00
"train loss: 2.33e-05 | test loss: 3.69e-04 | reg: 0.00e+00 : 100%|██| 50/50 [00:04<00:00, 12.46it/s]\n"
2024-04-29 12:35:18 -04:00
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"grid_size=50\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
2024-07-13 22:17:48 -04:00
"train loss: 3.59e-06 | test loss: 4.51e-03 | reg: 0.00e+00 : 100%|██| 50/50 [00:07<00:00, 6.43it/s]\n"
2024-04-29 12:35:18 -04:00
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"grid_size=100\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
2024-07-13 22:17:48 -04:00
"train loss: 3.19e-06 | test loss: 3.36e-02 | reg: 0.00e+00 : 100%|██| 50/50 [00:08<00:00, 6.25it/s]\n"
2024-04-29 12:35:18 -04:00
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"data_size=1000\n",
"grid_size=5\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
2024-07-13 22:17:48 -04:00
"train loss: 6.09e-03 | test loss: 6.24e-03 | reg: 0.00e+00 : 100%|██| 50/50 [00:06<00:00, 8.19it/s]\n"
2024-04-29 12:35:18 -04:00
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"grid_size=10\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
2024-07-13 22:17:48 -04:00
"train loss: 2.75e-04 | test loss: 3.09e-04 | reg: 0.00e+00 : 100%|██| 50/50 [00:05<00:00, 9.25it/s]\n"
2024-04-29 12:35:18 -04:00
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"grid_size=20\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
2024-07-13 22:17:48 -04:00
"train loss: 1.70e-05 | test loss: 2.00e-05 | reg: 0.00e+00 : 100%|██| 50/50 [00:07<00:00, 6.64it/s]\n"
2024-04-29 12:35:18 -04:00
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"grid_size=50\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
2024-07-13 22:17:48 -04:00
"train loss: 1.42e-06 | test loss: 1.63e-06 | reg: 0.00e+00 : 100%|██| 50/50 [00:12<00:00, 4.10it/s]\n"
2024-04-29 12:35:18 -04:00
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"grid_size=100\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
2024-07-13 22:17:48 -04:00
"train loss: 9.83e-07 | test loss: 1.61e-06 | reg: 0.00e+00 : 100%|██| 50/50 [00:10<00:00, 4.91it/s]\n"
2024-04-29 12:35:18 -04:00
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"data_size=3000\n",
"grid_size=5\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
2024-07-13 22:17:48 -04:00
"train loss: 6.09e-03 | test loss: 6.01e-03 | reg: 0.00e+00 : 100%|██| 50/50 [00:13<00:00, 3.62it/s]\n"
2024-04-29 12:35:18 -04:00
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"grid_size=10\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
2024-07-13 22:17:48 -04:00
"train loss: 3.09e-04 | test loss: 3.20e-04 | reg: 0.00e+00 : 100%|██| 50/50 [00:22<00:00, 2.20it/s]\n"
2024-04-29 12:35:18 -04:00
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"grid_size=20\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
2024-07-13 22:17:48 -04:00
"train loss: 1.76e-05 | test loss: 1.92e-05 | reg: 0.00e+00 : 100%|██| 50/50 [00:22<00:00, 2.20it/s]\n"
2024-04-29 12:35:18 -04:00
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"grid_size=50\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
2024-07-13 22:17:48 -04:00
"train loss: 1.41e-06 | test loss: 8.81e-05 | reg: 0.00e+00 : 100%|██| 50/50 [00:40<00:00, 1.22it/s]\n"
2024-04-29 12:35:18 -04:00
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"grid_size=100\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
2024-07-13 22:17:48 -04:00
"train loss: 1.29e-06 | test loss: 6.64e-04 | reg: 0.00e+00 : 100%|██| 50/50 [00:31<00:00, 1.56it/s]\n"
2024-04-29 12:35:18 -04:00
]
}
],
"source": [
"from kan import *\n",
"\n",
"# Sweep training-set sizes (outer loop) and grid sizes (inner loop). For each data size,\n",
"# a fresh KAN is trained on the coarsest grid, then refined to each finer grid with\n",
"# model.refine and retrained. Final losses are stored per (data size, grid size).\n",
"data_sizes = np.array([100,300,1000,3000])\n",
"grids = np.array([5,10,20,50,100])\n",
"\n",
"train_losses = np.zeros((data_sizes.shape[0], grids.shape[0]))\n",
"test_losses = np.zeros((data_sizes.shape[0], grids.shape[0]))\n",
"steps = 50  # LBFGS steps per grid size\n",
"k = 3  # spline order (cubic)\n",
"\n",
"# target function; it does not depend on the data size, so it is defined once\n",
"f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)\n",
"\n",
"for j, data_size in enumerate(data_sizes):\n",
"    print(f'data_size={data_size}')\n",
"    dataset = create_dataset(f, n_var=2, train_num=data_size)\n",
"\n",
"    for i, grid in enumerate(grids):\n",
"        print(f'grid_size={grid}')\n",
"        if i == 0:\n",
"            # fresh model on the coarsest grid for every data size\n",
"            model = KAN(width=[2,1,1], grid=grid, k=k)\n",
"            model.speed()\n",
"        else:\n",
"            # record activations on this dataset, then refine the model to the finer grid\n",
"            model.save_plot_data = True\n",
"            model.get_act(dataset)\n",
"            model = model.refine(grid)\n",
"            model.speed()\n",
"        results = model.fit(dataset, opt=\"LBFGS\", steps=steps, stop_grid_update_step = 30)\n",
"        train_losses[j, i] = results['train_loss'][-1]\n",
"        test_losses[j, i] = results['test_loss'][-1]\n"
]
},
{
"cell_type": "markdown",
"id": "6be8ba55",
"metadata": {},
"source": [
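"Fix the data size and study how the loss scales with model (grid) size. The loss roughly displays $N^{-4}$ scaling, where $N$ is the grid size; this matches the exponent $k+1=4$ expected for cubic splines ($k=3$)."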
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "e05289dd",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Text(0.5, 0, 'grid size')"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
2024-07-13 22:17:48 -04:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAkIAAAG2CAYAAACTTOmSAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8o6BhiAAAACXBIWXMAAA9hAAAPYQGoP6dpAACWGElEQVR4nOzdd1yV9d/H8dcZ7A2yVEDcAk4c4bYUzZwNbbm1rH63mmVlauYos2nDkaaVWebPn2FmaqiJEwe4xYkoDoY42Oucc91/IEcRB+rBw/g87wf3D67re67zOQacN9/rO1SKoigIIYQQQlRCanMXIIQQQghhLhKEhBBCCFFpSRASQgghRKUlQUgIIYQQlZYEISGEEEJUWhKEhBBCCFFpSRASQgghRKUlQUgIIYQQlZYEISGEEEJUWlpzF1DWGQwGLl68iIODAyqVytzlCCGEEKIEFEUhPT2dqlWrolbfud9HgtA9XLx4ER8fH3OXIYQQQogHcO7cOapXr37H85UiCK1evZq33noLg8HAu+++y/Dhw0v8WAcHB6DgH9LR0bG0ShRCCCGECaWlpeHj42N8H78TVUXfdFWn0xEQEMCmTZtwdHSkWbNm7Nq1C1dX1xI9Pi0tDScnJ1JTUyUICSGEEOVESd+/K/xg6d27dxMYGEi1atVwcHCge/fu/PPPP+YuSwghhBBlQJkPQlu2bKFnz55UrVoVlUrFypUri7WZM2cO/v7+WFtbExwczNatW43nLl68SLVq1YxfV69enQsXLjyK0oUQQghRxpX5IJSZmUnjxo357rvvbnt+2bJljBkzhgkTJrBv3z7atWvHk08+SXx8PFAwavxWMvtLCCGEEFAOBks/+eSTPPnkk3c8/+WXXzJs2DDjAOhZs2bxzz//MHfuXGbMmEG1atWK9ACdP3+eVq1a3fF6ubm55ObmGr9OS0szwasQQgghRFlU5nuE7iYvL4/o6GhCQ0OLHA8NDWXHjh0AtGzZksOHD3PhwgXS09NZs2YNXbt2veM1Z8yYgZOTk/FDps4LIYQQFVe5DkIpKSno9Xo8PT2LHPf09CQxMREArVbLF198QadOnWjatCnjxo3Dzc3tjtccP348qampxo9z586V6msQQgghhPmU+VtjJXHrmB9FUYoc69WrF7169SrRtaysrLCysjJpfUIIIYQom8p1j1CVKlXQaDTG3p9CycnJxXqJhBBCCCFuVa6DkKWlJcHBwaxfv77I8fXr19O6dWszVSWEEKIs0BsUImMv8+f+C0TGXkZvqNDrB4sHVOZvjWVkZHDq1Cnj13Fxcezfvx9XV1d8fX0ZO3YsAwYMoHnz5oSEhDB//nzi4+MZOXKkGasWQghhTusOJzDlrxgSUnOMx7ydrJncM4BuQd5mrEyUNWV+i42IiAg6depU7PigQYP46aefgIIFFT/99FMSEhIICgriq6++on379iZ5ftliQwghypd1hxN4bclebn1zKxw5OvflZhKGKoGSvn+X+SBkbhKEhBCi/NAbFNrO/LdIT9DNVICXkzXb3n0cjVoW163IZK8xIYQQlc7uuCt3DEEACpCQmsPuuCuPrihRpkkQEkIIUWEkp985BN3s3NWsUq5ElBdlfrC0EEIIUVIeDtYlavf+H4dYeyiBroFePNHAE3cHWT+uspIgJIQQosJo6e+Ks40F17Lz79hGo1ahMyhsOn6JTccvoVIdormfC6EBXnQN9MLXzfYRVizMTYKQEEKICmNf/FUycnW3PVc4NPq7F5pS28Oe8Jgk/jmSyMHzqew5c5U9Z67y0Zqj1PdyIDTQi9AATwKrOhbbvUBULDJr7B5k1pgQQpQPZ1Iy6TtnO1ez8mlU3YnktFwS0+69jtDFa9msj0kiPCaRnaevFFl4sZqzDaGBnoQGeNGihgtajQytLS9k+ryJSBASQoiy72pmHs/M3cHplEwaVXfi91cew0qrYXfcFZLTc/BwsKalv+s9p8xfy8rj32
PJhB9JYvOJS2Tn643nXGwteKKBJ6EBnrSv6461haa0X5Z4CBKETESCkBBClG25Oj0DFu5md9wVqjnbEPZG6xIPmr6b7Dw9206l8M+RRDYeTeJq1o1xRzYWGtrXrULXQC8er++Bs63lQz+fMC0JQiYiQUgIIcouRVEY+98DhO27gIOVlv+91pp6Xg4mfx6d3kDU2av8cySR8CNJXLiWbTynUato5e9K10AvugR4UtXZxuTPL+6fBCETkSAkhBBl15frT/DNxpNo1Cp+GtKCdnXcS/05FUXhyMU0wmOSCD+SyLHE9CLnG1V3IjTAk66BXtT2sJfB1mYiQchEJAgJIUTZtCL6PG8tPwDAJ0835PmWvmap4+zlTMKPFAy2jjp7lZvfVf2r2BEa4ElooBdNfZxRy7Yej4wEIRORICSEEGVPZOxlBi7aRb5e4bWOtXi3W31zlwTApfRcNh5NIjwmiW0nU8jTG4zn3B2s6BJQMNi6da0qWGplBlppkiBkIhKEhBCibDmVnMHTc7aTlqPjqUbefPt80zLZ05KRq2Pz8Uv8cySRTceSSb9pfSMHKy0d63vQNdCTDnXdcbC2MGOlFZMEIRORICSEEGVHSkYufeds59yVbJr5OvPbiMfKxTT2PJ2ByNOXCT+SyPqYJJLTc43nLDVqWtd2o2ugF51luw+TkSBkIhKEhBCibMjJ1/PCgp3si7+Gr6stYa+3xs2+/IUGg0Fh//lrBeOKjiRyOiXTeE6lgmBfF+MijjWq2Jmx0vJNgpCJSBASQgjzMxgU/rN0L2sOJeJkY8Efr7emlru9uct6aIqiEHspg3+uh6ID51OLnK/n6UBoYMEMNNnu4/5IEDIRCUJCCGF+n6w9xrzNsVhoVPwyrBWP1XQzd0mlIiH1+nYfR5LYefoyulu2++gS4ElooCcta7jKdh/3IEHIRCQICSGEeS3dHc/4Pw4B8FX/xvRtWt3MFT0aqVn5/Hs8iX8OF9/uw9nWgifqF4Si9nXcsbEs++OkHjUJQiYiQUgIIcxny4lLDPlpD3qDwpjOdRjTua65SzKLnHw9W0+mEH4kkQ23bPdhbaGmfR13ugZ68UQD2e6jkAQhE5EgJIQQ5nEsMY1n50aSkavj6abV+KJfYxkjw43tPsKPJPHPkcTbbvcRGuBJl0AvqlXi7T4kCJmIBCEhhHj0ktNy6DN7OxdTc2jl78riYS2x0srtn1spikJMQpoxFN263UfDate3+wjyok4l2+5DgpCJSBASQohHKytPR//vd3LoQio13e3447XWcrunhOIvZxEeU7Ax7J6zV4ps91HDzZbQQC+6BnrS1MelTC5CaUoShExEgpAQQjw6eoPCq79EseFoMq52loS93ho/N1lL50GkZBRs9/HPkSS2nUohT3dju48q9lbGGWita7lVyN42CUImIkFICCEenSl/HeHH7Wew1KpZOqIVwX6u5i6pQijc7iM8JpF/jyWTnnNjuw97Ky0d6xUMtu5Yr+Js9yFByEQkCAkhxKPx0/Y4PvwrBoDvXmxKj0ZVzVxRxZSnM7Dz9GXjLbSbt/uw0KhoXatKwXYfAR54OFibsdKHI0HIRCQICSFE6dsQk8Qrv0RhUOCdbvV4vWNtc5dUKRgMCgfOXyM8pmCw9elLRbf7aObrQmiAJ6GBXviXs+0+JAiZiAQhIYQoXYcvpPLcvEiy8/U838KHGU83rFSzm8qSU8kZ/HMkkfCYJA6cu1bkXF1Pe0IDvOga6EVQtbK/3YcEIRORICSEEKXn4rVs+szeTnJ6Lu3qVGHR4BZYyNYRZUJCajYbYgoGW9+63UdVJ2tCA70IDfCkhb9rmfxvJkHIRCQICSFE6UjPyee5eZEcS0ynrqc9/3utNY4VZKBuRZOalc+m48n8cySRiONFt/twsrHgiQYehAZ40aFu2dnuQ4KQiUgQEkII09PpDQz9OYotJy7h7mBF2Outqe5ia+6yRAnk5OvZdjKF8JhENhxN5kpmnvGctY
WadoXbfdT3wMXuzus/6Q0Ku+OukJyeg4eDNS39XdGYcG0jCUImIkFICCFMS1EUJqw8zG+74rGx0LDs1cdoVN3Z3GW
2024-04-29 12:35:18 -04:00
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"for i in range(data_sizes.shape[0]):\n",
"    plt.plot(grids, train_losses[i,:], marker=\"o\")\n",
"plt.xscale('log')\n",
"plt.yscale('log')\n",
"# N^{-4} guide: evaluate it at the same x values it is drawn at, so its log-log slope\n",
"# is exactly -4, and anchor it at the first point of the largest dataset's curve\n",
"ref_grids = grids[[0, -1]].astype(float)\n",
"plt.plot(ref_grids, train_losses[-1, 0]*(ref_grids/ref_grids[0])**(-4.), ls=\"--\", color=\"black\")\n",
"plt.legend([f'data={data_sizes[i]}' for i in range(data_sizes.shape[0])]+[r'$N^{-4}$'])\n",
"plt.ylabel('train RMSE')\n",
"plt.xlabel('grid size')"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "6d15cc9e",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Text(0.5, 0, 'grid size')"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
2024-07-13 22:17:48 -04:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAkIAAAG1CAYAAAAV2Js8AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8o6BhiAAAACXBIWXMAAA9hAAAPYQGoP6dpAACVNUlEQVR4nOzdeVxU9f7H8dcw7LIJyqKyuKKAmru4r7iUpVm2uqRp1r0/t7LSUnMpLc1suWqaZV1bTL2W5V65ixuKG64I4sIiiOzrzPn9MTA4goo6MAN8nvfB4zrnfOfMZxKZN9/zXVSKoigIIYQQQlRBFqYuQAghhBDCVCQICSGEEKLKkiAkhBBCiCpLgpAQQgghqiwJQkIIIYSosiQICSGEEKLKkiAkhBBCiCpLgpAQQgghqixLUxdg7rRaLdevX8fR0RGVSmXqcoQQQghRCoqikJaWRq1atbCwuHu/jwSh+7h+/Tre3t6mLkMIIYQQD+HKlSvUqVPnruclCN2Ho6MjoPsP6eTkZOJqhBBCCFEaqampeHt76z/H70aC0H0U3g5zcnKSICSEEEJUMPcb1iKDpYUQQghRZUkQEkIIIUSVJUFICCGEEFWWBCEhhBBCVFkShIQQQghRZUkQEkIIIUSVJUFICCGEEFWWBCEhhBBCVFmyoKIQQohKSaNVOBR1k4S0bNwdbWlb1xW1hewZKQxJEBJCCFHpbDkVy8w/IohNydYf83K2ZcaAAPoGeZmwMmFu5NaYEEKISmXLqVheX3XUIAQBxKVk8/qqo2w5FWuiyoQ5kiAkhBCi0tBoFWb+EYFSwrnCYzP/iECjLamFqIrk1pgQQohK41DUzWI9QbdTgNiUbN78NZwG7g7YWqmxsVJjZ6XG1soCW0s1dtYFf7ZS679uP28h44yMwlzGcEkQEkIIUWkkpN09BN3ut/DrD/0a1pYW2FpaFAQmNbaWamyt1dhaWhiEJjtrNTaWdwSpgj/bWFkUHCt+3rbw+VZqLNWV88aNOY3hqhJB6M8//+TNN99Eq9Xyzjvv8Oqrr5q6JCGEEGUgNTu/VO36BHpQ3d6a7DwNWXkasvO0ZOdpCr60BceKHudqtPrn5uZryc3Xlvq1HoWlhcogHJUcmtTY3fHY1iBoWRSEL7VBj5f+vKUaW2sLrNUWqFRl3yNTOIbrzpuThWO4lrzcslzDkEpRlEp9ozQ/P5+AgAB27NiBk5MTLVu25ODBg7i6upbq+ampqTg7O5OSkoKTk1MZVyuEEOJh3MrM5ZOt5/jpYMw926kAT2db9r7T44Fuw2i0Cjn5GrJyNWTna3X/n6cpOFYQom47n11wPrvwfL5hsNK1uzOAFYUyU1CpKApKlhYFvVxFvVu2lncPWoY9W3eEr9vaWqsteOLLPcSl5pRcAw/391OS0n5+V/oeoUOHDhEYGEjt2rUB6N+/P1u3buWFF14wcWVCCCEelaIorDt6jbmbzpCUkQtAu7quHIy6iQoMeh0KP1ZnDAh44A9ZtYUKe2tL7K3L/mNTURRy8rUl9k5l5WnIMThW9Oec24LU7ef1QaswqN12LCtPQ+G4cUWBrIJjplI4hutQ1E2C67uVy2uafRDavXs38+fPJywsjNjYWNavX8/AgQMN2ixevJj58+cTGxtLYGAgixYtonPnzgBcv35dH4IA6tSpw7Vr18rzLQghhCgD5+LSmPbbKQ5F3wSgobsDswcG0b6eW4ljUDwryDpCKlXR7bCypigKeRqlqMeqoPcqS9+jpeu9yik4b9Ajll8QyvS9Wxqy7ujhMgxqhrcY76W0Y72MweyDUEZGBs2bN+eVV15h8ODBxc6vXr2aCRMmsHjxYjp27MjXX39Nv379iIiIwMfHh5Lu/JXHPVAhhBBlIyMnn8//vsCKvVFotAp2VmrG92rIyI51sbbUDS7uG+RF7wBPs5iVZM5UKhXWliqsLS1wsrUq89fbdzGRl745eN927o62ZV5LIbMPQv369a
Nfv353Pb9w4UJGjRqlHwC9aNEitm7dypIlS5g7dy61a9c26AG6evUq7dq1u+v1cnJyyMkpuneZmppqhHchhBDiUSmKwtbTcQY9PX0CPZg+IJDaLnbF2qstVOV2e0WUTvt6bng52xKXkl3iWk+FY4Ta1i3dOF5jqNDz8nJzcwkLCyMkJMTgeEhICPv37wegbdu2nDp1imvXrpGWlsamTZvo06fPXa85d+5cnJ2d9V/e3t5l+h6EEELc3+WkDF5ZeZixBStGe7va8e2I1nw9tHWJIUiYJ7WFihkDAoCiMVuFHmUM16Oo0EEoMTERjUaDh4eHwXEPDw/i4uIAsLS05NNPP6V79+60aNGCyZMn4+Z2998QpkyZQkpKiv7rypUrZfoehBBC3F12noYv/r5AyGe72XnuBlZqFf/XowHbJnSlR2OP+19AmJ2+QV4sebklns6Gt788nW3Lfeo8VIBbY6Vx55gfRVEMjj355JM8+eSTpbqWjY0NNjY2Rq1PCCHEg9tz4QbTfz9NVGIGAB0buDHrqSDq13QwcWXiUZnTGK4KHYRq1KiBWq3W9/4USkhIKNZLJIQQomKIS8lm9sYINp7QbY5a09GGaU8EMKCZl0x2qUTMZQxXhQ5C1tbWtGrViu3btzNo0CD98e3bt/PUU0+ZsDIhhBAPKl+jZeX+aD7bfp6MXA0WKhjewY+JvRuVy4wmUTWZfRBKT0/n4sWL+sdRUVGEh4fj6uqKj48PkyZNYujQobRu3Zrg4GCWLVtGTEwMY8eONWHVQgghHkTY5Zu8t/4UZ+PSAGjh48KcgUEE1nI2cWWisjP7IHTkyBG6d++ufzxp0iQAhg8fzsqVK3nuuedISkpi1qxZxMbGEhQUxKZNm/D19TVVyUIIIUopOSOXeZvPsvqIbmKKs50V7/ZrzHOtvWWXd1EuKv1eY49K9hoTQgjj02oV1oRdYd7msyRn5gEwpHUd3unbGDcHmbAiHp3sNSaEEMIsRVxP5f3fTnI05hYAjT0dmTMwiNZ+5beInhCFJAgJIYQoF2nZeXy2/QIr90ehVaCatZqJvRsxvIMfVuoKvaydqMAkCAkhhChTiqKw8WQss/+MID5Vt4XR4029eP+JJng5y6rQwrQkCAkhhCgzl26kM2PDafZcSATAz82emU8F0bVRTRNXJoSOBCEhhBBGl52nYfGOiyzddYlcjRZrSwve6FafsV3rY2ulNnV5QuhJEBJCCGFUO84lMOP308TczASgS6OazHoyEL8a1UxcmRDFSRASQghhFNdvZTHrjwi2nNZte+TpZMuMAQH0DfKUrTGE2ZIgJIQQ4pHkabR8uzeKz/++QGauBrWFipEd/RjfqxEONvIxI8ybfIcKIYR4aIeibvL+byc5H58OQGvf6swZFERjT1mAVlQMEoSEEEI8sMT0HOZuOsu6o1cBcK1mzZR+jRncso5sjSEqFAlCQgghSk2jVfj5UAyfbDlLanY+AC+09eHtPv5Ur2Zt4uqEeHAShIQQQpTKqWspvPfbKY5fuQVAgJcTcwYF0dKnumkLE+IRSBASQghxTylZeSzcdo7/HriMVgFHG0veDGnEy+19sZStMUQFJ0FICCFEiRRF4ffw68zZeIbEdN3WGE82r8X7jzfB3cnWxNUJYRwShIQQQhRzMSGNab+dJvRSEgD1alZj9lNBdGxQw8SVCWFcEoSEEELoZeVq+PKfCyzfc4k8jYKNpQX/16MBo7vUw8ZStsYQlY8EISGEEAD8FRHPjA2nuXYrC4Aejd2Z+WQg3q72Jq5MiLIjQUgIIaq4KzczmflHBH+diQegtosdMwYE0DvAQ7bGEJWeBCEhhKiicvO1LN9ziS//uUB2nhZLCxWju9Tj/3o0wN5aPh5E1SDf6UIIUQXtv5jItN9PEXkjA4B2dV2ZMzCIhh6OJq5MiPIlQUgIIaqQhLRsPtp4ht/CrwNQw8Ga9x5vwsDHasttMFElSRASQogqQKNVWHXgMgu2niMtJx
+VCl5u58tbffxxtrMydXlCmIwEISGEqOTCr9zi/d9OcupaKgDN6jgzZ2AQzeq4mLYwIcyABCEhhKikUjLz+GTrWX4
2024-04-29 12:35:18 -04:00
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"for i in range(data_sizes.shape[0]):\n",
"    plt.plot(grids, test_losses[i,:], marker=\"o\")\n",
"plt.xscale('log')\n",
"plt.yscale('log')\n",
"# N^{-4} guide: evaluate it at the same x values it is drawn at, so its log-log slope\n",
"# is exactly -4, and anchor it at the first point of the largest dataset's curve\n",
"ref_grids = grids[[0, -1]].astype(float)\n",
"plt.plot(ref_grids, test_losses[-1, 0]*(ref_grids/ref_grids[0])**(-4.), ls=\"--\", color=\"black\")\n",
"plt.legend([f'data={data_sizes[i]}' for i in range(data_sizes.shape[0])]+[r'$N^{-4}$'])\n",
"plt.ylabel('test RMSE')\n",
"plt.xlabel('grid size')"
]
},
{
"cell_type": "markdown",
"id": "18bcedfe",
"metadata": {},
"source": [
"Fix the model (grid) size and study how the loss scales with data size. There is no clear power-law scaling, but we observe that (1) increasing the data size does not hurt performance, and (2) more powerful models (larger grid sizes) benefit more from additional data. Ideally, one would increase data size and model size together so that their complexities always match."
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "0dd85c41",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Text(0.5, 0, 'data size')"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
2024-07-13 22:17:48 -04:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAkIAAAG1CAYAAAAV2Js8AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8o6BhiAAAACXBIWXMAAA9hAAAPYQGoP6dpAACi00lEQVR4nOzdd1xTVxsH8N/NYm9kCjhRKcoSxT1L1bq1dVfraB0dSm2rtVpXtdbdiqP6utqqtFpXtSpaB2pVZFXFKiIKMkQUCIQRSO77RyASZoCEG+D59pMP5Obk3CdIyZPnnHsOw7IsC0IIIYSQRojHdQCEEEIIIVyhRIgQQgghjRYlQoQQQghptCgRIoQQQkijRYkQIYQQQhotSoQIIYQQ0mhRIkQIIYSQRosSIUIIIYQ0WgKuA9B1crkcSUlJMDExAcMwXIdDCCGEEDWwLIusrCw4ODiAx6u47kOJUBWSkpLg5OTEdRiEEEIIqYGEhAQ0bdq0wscpEaqCiYkJAMUP0tTUlONoCCGEEKIOsVgMJycn5ft4RSgRqkLxcJipqSklQoQQQkg9U9W0FposTQghhJBGixIhQgghhDRaNDRGCCGkUZLJZCgoKOA6DFJDQqEQfD6/1v1QIkQIIaRRYVkWKSkpyMjI4DoUUkvm5uaws7Or1fI2lAgRQghpVIqTIBsbGxgaGtIacfUQy7LIyclBamoqAMDe3r7GfVEiRAghpNGQyWTKJMjKyorrcEgtGBgYAABSU1NhY2NT42EymixNCCGk0SieE2RoaMhxJEQTiv8dazPXixIhQgghjQ4NhzUMmvh3bBSJ0J9//ok2bdqgdevW2LVrF9fhEEIIIURHNPg5QoWFhQgICMDFixdhamoKb29vjBw5EpaWllyHRgghhBCONfiK0K1bt/DGG2/A0dERJiYmGDRoEM6ePct1WIQQQuo5mZzFP7EvcTwyEf/EvoRMznIdUhlPnjwBwzCIjIyssM2lS5fAMEyjXU5A5xOhK1euYMiQIXBwcADDMDh27FiZNlu3bkXz5s2hr68PHx8fhISEKB9LSkqCo6Oj8n7Tpk2RmJhYF6ETQghpoM7cTUb3NX9j3M4b+PRQJMbtvIHua/7GmbvJXIemwsnJCcnJyXB3d9faOfbu3QuGYcrc8vLytHZOTdL5REgikcDDwwNbtmwp9/GgoCDMnTsXixYtQkREBHr06IGBAwciPj4egGKtgdIqm1yVn58PsVisctOWp0+f4t69e1rrnxBCiOaduZuMWb+EIzlT9Y0+JTMPs34J15lkSCqVgs/nw87ODgKBdmfCmJqaIjk5WeWmr6+v1XNqis4nQgMHDsTKlSsxcuTIch/fsGEDpk2bhunTp6Ndu3bYtGkTnJycsG3bNgCAo6OjSgXo2bNnlS68tHr1apiZmSlvTk5Omn1BRQoKCjBu3Dj4+vpi165d5SZshBBCtI9lWeRIC9W6ZeUV4JsT91DeX+ziY0tPRCMrr6DKvqr7dz8rKwsTJkyAkZER7O3tsXHjRvTu3Rtz584FADRr1gwrV67ElClTYGZmhhkzZpQ7NHb69Gm4urrCwMAAffr0wZMnT2ryY1PBMAzs7OxUbvVFvZ4sLZVKERYWhgULFqgc9/f3x/Xr1wEAnTp1wt27d5GYmAhTU1OcPn0aS5YsqbDPhQsXIiAgQHlfLBZrJRmSSCQwMTFBbm4uZsyYgeDgYPz0008wMzPT+LkIIYRULLdABrclmpk7ygJIEeeh/dJzVbaNXv4WDEXqvw0HBATg2rVrOHHiBGxtbbFkyRKEh4fD09NT2Wbt2rVYvHgxvv7663L7SEhIwMiRIzFz5kzMmjULt2/fxmeffabSJj4+Hm5ubpXGMnHiRGzfvl15Pzs7Gy4uLpDJZPD09MSKFSvg5eWl9mvjUr1OhNLS0iCTyWBra6ty3NbWFikpKQAAgUCA9evXo0+fPpDL5fjiiy8qXU1UT08Penp6Wo0bUOyP8t
dff2HdunVYtGgRfvvtN4SGhuLgwYPo3Lmz1s9PCCGk/sjKysK+fftw4MAB9OvXDwCwZ88eODg4qLTr27cv5s+fr7xfutqzbds2tGjRAhs3bgTDMGjTpg3u3LmDNWvWKNs4ODhUOrkaUAyFFWvbti327t2L9u3bQywWY/PmzejWrRuioqLQunXrGr7iulOvE6Fipef8sCyrcmzo0KEYOnRotfoMDAxEYGAgZDKZRmIsD4/HwxdffIFevXph7NixiIuLQ/fu3bFq1Sp89tln4PF0fuSSEELqPQMhH9HL31Kr7a24V5iyJ7TKdnvf90Wn5pUv02IgVH9LiMePH6OgoACdOnVSHjMzM0ObNm1U2nXs2LHSfu7fvw8/Pz+V98guXbqotBEIBGjVqpXasfn5+cHPz095v1u3bvD29saPP/6IH374Qe1+uFKv32mtra3B5/OV1Z9iqampZapE1TVnzhxER0cjNLTqX/ja6ty5MyIiIvDOO++gsLAQBw4cqNVy4YQQQtTHMAwMRQK1bj1aN4G9mT4quuSGAWBvpo8erZtU2Vd1VkUunk9U3gf/koyMjNTqpzLx8fEwNjau9DZz5swKn8/j8eDr64uYmJgqz6UL6nVFSCQSwcfHB8HBwRgxYoTyeHBwMIYNG8ZhZNVnbm6OoKAg+Pv7o0ePHnUyPEcIIaR6+DwG3wxxw6xfwsEAKpOmi1OUb4a4gc/T7BYeLVu2hFAoxK1bt5TzVsViMWJiYtCrVy+1+3FzcyuzDM2NGzdU7ld3aKw0lmURGRmJ9u3bqx0Xl3Q+EcrOzsajR4+U9+Pi4hAZGQlLS0s4OzsjICAAkyZNQseOHdGlSxf89NNPiI+PrzRbVUddDI2VxjAMpk+frnJsxYoVkEgkWLFiBYRCYZ3FQgghpHwD3O2xbaI3lp2MVrmE3s5MH98MccMA94qvTK4pExMTTJ48GZ9//jksLS1hY2ODb775Bjwer1qVpZkzZ2L9+vUICAjAhx9+iLCwMOzdu1elTXWHxpYtWwY/Pz+0bt0aYrEYP/zwAyIjIxEYGKh2H5xiddzFixdZKJJuldvkyZOVbQIDA1kXFxdWJBKx3t7e7OXLlzV2/szMTBYAm5mZqbE+1fXw4UOWx+OxAFg/Pz/28ePHdR4DIYQ0JLm5uWx0dDSbm5tb674KZXL2+qM09ljEM/b6ozS2UCbXQIQVE4vF7Pjx41lDQ0PWzs6O3bBhA9upUyd2wYIFLMuyrIuLC7tx40aV58TFxbEA2IiICOWxkydPsq1atWL19PTYHj16sLt372YBsOnp6TWKa+7cuayzszMrEonYJk2asP7+/uz169dr+Cqrp7J/T3XfvxmWpQVsKiMWi2FmZobMzMxKS4Ha8vvvv2PGjBnIzMyEmZkZdu7ciXfeeafO4yCEkIYgLy8PcXFxyt0I6jOJRAJHR0esX78e06ZN4zocTlT276nu+3e9niytTYGBgXBzc4Ovr6/G+2ZlMkhu3kLmn6cguXkLbCXDb++88w4iIyPh5+eHzMxMvPvuu5g5cyZyc3M1HhchhBDdFRERgYMHDyI2Nhbh4eGYMGECANS7ObG6hhKhCmjrqjHxuXN41K8/4idPRtL8+YifPBmP+vWH+FzFi281a9YMV65cwcKFC8EwDHbs2IFevXpBLpdrNDZCCCG6bd26dfDw8ED//v0hkUgQEhICa2trrsOq13R+snRDIj53DomfzgVKjUYWPn+uOL55E0z9/ct9rlAoxKpVq9C3b19MnDgR06dPp3WGCCGkEfHy8kJYWBjXYTQ4lAjVEVYmw/NVq8skQYoHWYBh8HzVapj06weGX/EiW/3798f9+/dhbm6uPHbv3j04OjqqHCOEEEJI1aikUAFNzxHKuR2GwlILP6pgWRSmpCDndtXZvoWFhfJyyfT0dLz99tvw9PQssxYEIYQQQipHiVAFND1HqPDFC422K5aSkgI+n4+nT5+ie/fuWLNmDc0dIoQQQtREiVAdET
RpotF2xdq1a4fw8HCMGTMGMpkMCxYswIABA8psO0IIIYSQsigRqiOGHX0gsLMDKloBlGEgsLODYUefavdtZmaGgwc
2024-04-29 12:35:18 -04:00
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# one curve per grid size (columns of train_losses) against training-set size\n",
"for losses_at_grid in train_losses.T:\n",
"    plt.plot(data_sizes, losses_at_grid, marker=\"o\")\n",
"plt.xscale('log')\n",
"plt.yscale('log')\n",
"# dashed visual guide with slope -4 on log-log axes (x is the data size here)\n",
"ref_sizes = np.array([100,3000])\n",
"plt.plot(ref_sizes, 1e8*ref_sizes**(-4.), ls=\"--\", color=\"black\")\n",
"legend_labels = [f'grid={grid}' for grid in grids] + [r'$N^{-4}$']\n",
"plt.legend(legend_labels)\n",
"plt.ylabel('train RMSE')\n",
"plt.xlabel('data size')"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "107801f6",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Text(0.5, 0, 'data size')"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
2024-07-13 22:17:48 -04:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAkIAAAG1CAYAAAAV2Js8AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8o6BhiAAAACXBIWXMAAA9hAAAPYQGoP6dpAACiwElEQVR4nOzdd3xT1f/H8VeS7klLdykte8soe8kSUEFQVFQEFBz4UxQBlY0MQTYucCLgQHDxBUGGA8sUKEP2LqV07z2S3N8foaGlgwJpk7afp488au49uffTQfPuueeeo1IURUEIIYQQohpSm7sAIYQQQghzkSAkhBBCiGpLgpAQQgghqi0JQkIIIYSotiQICSGEEKLakiAkhBBCiGpLgpAQQgghqi0JQkIIIYSotqzMXYCl0+v1REZG4uzsjEqlMnc5QgghhCgDRVFIS0vDz88Ptbrkfh8JQrcRGRlJQECAucsQQgghxF24du0atWrVKnG/BKHbcHZ2BgxfSBcXFzNXI4QQQoiySE1NJSAgwPg+XhIJQreRfznMxcVFgpAQQghRydxuWIsMlhZCCCFEtSVBSAghhBDVlgQhIYQQQlRbEoSEEEIIUW1JEBJCCCFEtSVBSAghhBDVlgQhIYQQQlRbEoSEEEIIUW3JhIpmoOh0ZB4ORRsXh5WnJw5tg1FpNOYuSwghhKh2JAhVsNQdO4iZNx9tdLRxm5WPD95TJuPSt68ZKxNCCCGqH7k0VoFSd+zg+hvjCoUgAG1MDNffGEfqjh1mqkwIIYSoniQIVRBFpyNm3nxQlGJ2GrbFzJuPotNVcGVCCCFE9SVBqIJkHg419gQpqEiq0YBor2CSajRAQQWKgjY6mszDoWauVAghhKg+ZIxQBdHGxQEQ69GSC/WfIMfOzbjPNjuJBhd/xCv+uLGdEEIIIcpftegR+u2332jUqBENGjTgyy+/NEsNVp6exHq05GSzF8mxrVFoX45tDU42e5FYj5ZYeXqapT4hhBCiOqryQUir1TJ+/Hj++usvjhw5woIFC0hMTKzwOuzatOFCo6GGJypV4Z03nl+o/zhJGzeiS0mp4OqEEEKI6qnKB6GDBw/SrFkz/P39cXZ25qGHHmL79u0VXkf05TRyrF2LhqB8KhU5du6E/32SSw8PIHXbdpTiBlYLIYQQwmQsPgiFhIQwcOBA/Pz8UKlUbNy4sUibFStWUKdOHezs7AgODmb37t3GfZGRkfj7+xuf16pVi+vXr1dE6YVkpOaUqZ22VgN08fFcHzeOiNfGkhcTU86VCSGEENWXxQehjIwMWrZsyccff1zs/vXr1zNu3DimTp3K0aNH6datGw8++CDh4eEAxfaqqErqlSlHji62ZWp3vvYgoh5/lywnH9L//JPLDw8g6YcfUPT6cq5QCCGEqH4sPgg9+OCDzJ07l8cee6zY/UuXLmX06NG88MILNGnShOXLlxMQEMDKlSsB8Pf3L9QDFBERga+vb4nny8nJITU1tdDDFHwb1MCxxm3CkApyMrWcifdkf9vpHO8ymSj7hkTOeo+rI0aQc/mKSWoRQgghhIHFB6HS5ObmEhoaSt9blqbo27cv+/btA6B9+/acPHmS69evk5aWxtatW+nXr1+Jx5w/fz6urq7GR0BAgElqVatVdBvaoNQ2fUc3o9+LzQlo6g4qSLCuxalmo9nbeR7/JQVy4qmXiP/0U5TcXJPUJIQQQlR3lXoeofj4eHQ6Hd7e3oW2e3t7E31j8kIrKyuWLFlCz5490ev1vP3229SsWbPEY06ePJnx48cbn6empposDNVr7UX/l5uze/0FMpJvjhlycrOl65MNqNfaC4D6wV6kxmdxZl8UZ/ZFkZEM1wJ6cy2gN2f/uUTtv2fQ6u2ncQluaZK6hBBCiOqqUgehfLeO+VEUpdC2Rx55hEceeaRMx7K1tcXWtmzjee5GvdZe1GnpSdSFZDJSc3B0scW3QQ3U6sKfg4uHPR0eqUu7h4
MIP53I6T2RhP0XR4prPU5Qj9MrIwiscZo2L/bGu5FXudUrhBBCVGWVOgh5eHig0WiMvT/5YmNji/QSWRK1WoV/I7fbNwTUGjVBLTwIauFBRnIOp/++zMkdl8i0cuByuj2Xl53E3R1a9GtIg/Y+2NpX6m+pEEIIUaEq9RghGxsbgoOD2blzZ6HtO3fupHPnzmaqqvw41rCl3aNNeO6Th+nbW4NP6ilU+jwSE+GfdedZ/fYe/lxzmqiLyTIHkRBCCFEGFt99kJ6ezsWLF43Pr1y5wrFjx3B3d6d27dqMHz+e4cOH07ZtWzp16sTnn39OeHg4Y8aMMWPV5UulVtHgifup91BbIpat4OyuK0T6diHD0Zez+6M5uz8aN19HmnbxpVFHH+ydbMxdshBCCGGRVIqFdx3s2rWLnj17Ftk+cuRIVq9eDRgmVFy4cCFRUVE0b96cZcuW0b17d5OcPzU1FVdXV1JSUnBxcTHJMU0t6/hxIqdNJy5GS6RvF2J92qFTGTKu2kpF3VaeNO3iR61GbqjUFT+HkhBCCFHRyvr+bfFByNwqQxACUHJzSVi1ivhPVpCn1xBTqzNxzQeSmHGzN8jFw44mXfxo0sn39nMaCSGEEJWYBCETqSxBKF/O5ctETZ9BVmio4XmbniR0G8nlc9nkZusAw6W1wOY1adbVj9rN3FFrKvVQMSGEEKIICUImUtmCEICi15O8YQOxixajz8gAa2tqjH6RpODBnP03lqiLN1e3d6xhS5POvjTp7IuLh70ZqxZCCCFMR4KQiVTGIJQvLyaG6FmzSf/rLwBs6tbFd+4csn0bcXpvJOf2R5OdkWdorIKAxm407epPnZYeaKykl0gIIUTlJUHIRCpzEALD5JJp23cQPXcuuvh4ANyeeRrP8ePB1oHLx+M4vSeSiLNJxtfYOVnTuKMPTbv64ebjaK7ShRBCiLsmQchEKnsQyqdLSSFm0SJSfvoZACtvb3xmzsS5l+GOvNT4LE7vjeTsvigyUm6uZeZb35WmXf2o18YLaxuNWWoXQggh7pQEIROpKkEoX8aBA0TNmEleeDgAzg/2x2fqVKw8PADQ6/RcPWVY0uPqiXjyfzps7K1o2N6bpl398AxwNlf5QgghRJlIEDKRqhaEAPRZWcR/8gkJX68GnQ61qyveb7+N62OPFlqjLT0ph7P7ozi9N5K0hGzjdq9AZ5p08aNhO29sZEkPIYQQFkiCkIlUxSCUL+vUKaKmTyfn9BkAHDp1xHfWLGxq1y7UTtErRJxL4vSeSC4fi0OvM/zIWNmoadDW0EvkXcelyOK3QgghhLlIEDKRqhyEABStlsQ1a4j78COUnBxUdnZ4jh2L+8gRqKyK9vZkpeVy7t9oTu+JJCk607jd3c+Rpl38aNTBBzsn64r8FIQQQogiJAiZSFUPQvlyr14laua7ZB44AIBd06b4vjcXuyZNim2vKApRl1I4syeSi6GxaPP0AGis1NRt7UnTrn74N6ghS3oIIYQwCwlCJlJdghAYwk3KL78Qs2Ah+tRU0GioOep5PF59FbWdXYmvy8nM4/zBGE7vjST+Wrpxu4unPU27+NK4ky+OrrKkhxBCiIojQchEqlMQyqeNiyP6vXmkbdsGgHVgbXxnzcaxY4fbvjb2aiqn90Ry/lAMeQWW9AhqUZOmXf2o3awmauklEkIIUc4kCJlIdQxC+dL+/JPoWbPRxsYCUOOJx/GaOBGNq+ttX5uXo+NiaAyn90QSfTnVuN3JzZbG+Ut61JQlPYQQQpQPCUImUp2DEIAuLY3YpUtJXvcDABpPD3ymTce57wNlvkssITKdM3uiOPtvFDkZWsNGFdRu4k7Trn4E3SdLegghhDAtCUImUt2DUL7Mw4eJmj6D3CtXAHDq0xuf6dOx9vYu8zF0eXouH4vj1J5Irp+7uaSHvbM1jTv60rSrHzW8HUxeuxBCiOpHgpCJSBC6SZ+TQ/ynn5LwxZeg1aJ2csJr4kRqPPkEKvWd9eikxG
Vyem8UZ/dFkZl6c0kPvwY1DEt6tPbESpb0EEIIcZckCJmIBKGiss+dJ2r6dLL/+w8Ah7Zt8Zk9G9u6de74WDqdnqs
2024-04-29 12:35:18 -04:00
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# one curve per grid size (columns of test_losses) against training-set size\n",
"for losses_at_grid in test_losses.T:\n",
"    plt.plot(data_sizes, losses_at_grid, marker=\"o\")\n",
"plt.xscale('log')\n",
"plt.yscale('log')\n",
"# dashed visual guide with slope -4 on log-log axes (x is the data size here)\n",
"ref_sizes = np.array([100,3000])\n",
"plt.plot(ref_sizes, 1e5*ref_sizes**(-4.), ls=\"--\", color=\"black\")\n",
"legend_labels = [f'grid={grid}' for grid in grids] + [r'$N^{-4}$']\n",
"plt.legend(legend_labels)\n",
"plt.ylabel('test RMSE')\n",
"plt.xlabel('data size')"
]
2024-07-13 22:17:48 -04:00
},
{
"cell_type": "code",
"execution_count": null,
"id": "47bdf5af",
"metadata": {},
"outputs": [],
"source": []
2024-04-29 12:35:18 -04:00
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.7"
}
},
"nbformat": 4,
"nbformat_minor": 5
}