533 lines
243 KiB
Plaintext
Raw Normal View History

2024-04-29 12:35:18 -04:00
{
"cells": [
{
"cell_type": "markdown",
"id": "5d904dee",
"metadata": {},
"source": [
"# Example 8: KANs' Scaling Laws"
]
},
{
"cell_type": "markdown",
"id": "6465ec94",
"metadata": {},
"source": [
"In this example, we study KAN's scaling laws with respect to model size (grid size) and training-data size."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "a1c25e8a",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"data_size=100\n",
"grid_size=5\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"train loss: 5.22e-03 | test loss: 7.32e-03 | reg: 2.91e+00 : 100%|██| 50/50 [00:07<00:00, 7.10it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"grid_size=10\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"train loss: 2.20e-04 | test loss: 8.06e-04 | reg: 2.90e+00 : 100%|██| 50/50 [00:06<00:00, 7.48it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"grid_size=20\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"train loss: 7.01e-06 | test loss: 3.07e-04 | reg: 2.90e+00 : 100%|██| 50/50 [00:07<00:00, 6.73it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"grid_size=50\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"train loss: 1.97e-04 | test loss: 3.15e-02 | reg: 2.90e+00 : 100%|██| 50/50 [00:07<00:00, 6.45it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"grid_size=100\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"train loss: 2.00e-03 | test loss: 1.65e+00 | reg: 3.05e+00 : 100%|██| 50/50 [00:07<00:00, 6.35it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"data_size=300\n",
"grid_size=5\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"train loss: 5.80e-03 | test loss: 6.71e-03 | reg: 2.88e+00 : 100%|██| 50/50 [00:06<00:00, 7.23it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"grid_size=10\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"train loss: 3.19e-04 | test loss: 3.15e-04 | reg: 2.89e+00 : 100%|██| 50/50 [00:06<00:00, 7.31it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"grid_size=20\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"train loss: 1.96e-05 | test loss: 2.34e-05 | reg: 2.89e+00 : 100%|██| 50/50 [00:07<00:00, 6.67it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"grid_size=50\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"train loss: 2.08e-06 | test loss: 5.00e-06 | reg: 2.89e+00 : 100%|██| 50/50 [00:11<00:00, 4.37it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"grid_size=100\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"train loss: 2.81e-07 | test loss: 3.41e-02 | reg: 2.89e+00 : 100%|██| 50/50 [00:17<00:00, 2.83it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"data_size=1000\n",
"grid_size=5\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"train loss: 6.45e-03 | test loss: 6.44e-03 | reg: 2.91e+00 : 100%|██| 50/50 [00:07<00:00, 6.72it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"grid_size=10\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"train loss: 4.14e-04 | test loss: 3.76e-04 | reg: 2.94e+00 : 100%|██| 50/50 [00:07<00:00, 6.54it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"grid_size=20\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"train loss: 4.94e-05 | test loss: 4.69e-05 | reg: 2.93e+00 : 100%|██| 50/50 [00:14<00:00, 3.44it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"grid_size=50\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"train loss: 5.21e-06 | test loss: 1.30e-05 | reg: 2.93e+00 : 100%|██| 50/50 [00:49<00:00, 1.01it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"grid_size=100\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"train loss: 2.12e-06 | test loss: 1.05e-05 | reg: 2.93e+00 : 100%|██| 50/50 [01:15<00:00, 1.51s/it]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"data_size=3000\n",
"grid_size=5\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"train loss: 6.12e-03 | test loss: 6.77e-03 | reg: 2.79e+00 : 100%|██| 50/50 [00:16<00:00, 2.99it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"grid_size=10\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"train loss: 2.98e-04 | test loss: 3.35e-04 | reg: 2.78e+00 : 100%|██| 50/50 [00:34<00:00, 1.44it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"grid_size=20\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"train loss: 1.72e-05 | test loss: 1.86e-05 | reg: 2.78e+00 : 100%|██| 50/50 [00:38<00:00, 1.31it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"grid_size=50\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"train loss: 3.97e-07 | test loss: 4.93e-07 | reg: 2.78e+00 : 100%|██| 50/50 [00:51<00:00, 1.03s/it]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"grid_size=100\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"train loss: 2.61e-08 | test loss: 3.27e-08 | reg: 2.78e+00 : 100%|██| 50/50 [00:26<00:00, 1.85it/s]\n"
]
}
],
"source": [
"from kan import *\n",
"\n",
"# target function: f(x, y) = exp(sin(pi*x) + y^2); it does not depend on the loop, so define it once\n",
"f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)\n",
"\n",
"data_sizes = np.array([100,300,1000,3000])\n",
"grids = np.array([5,10,20,50,100])\n",
"\n",
"# final losses, indexed as [data-size index, grid index]\n",
"train_losses = np.zeros((data_sizes.shape[0], grids.shape[0]))\n",
"test_losses = np.zeros((data_sizes.shape[0], grids.shape[0]))\n",
"steps = 50\n",
"k = 3\n",
"\n",
"for j, data_size in enumerate(data_sizes):\n",
"    print(f'data_size={data_size}')\n",
"    # create dataset\n",
"    dataset = create_dataset(f, n_var=2, train_num=data_size)\n",
"\n",
"    for i, grid in enumerate(grids):\n",
"        print(f'grid_size={grid}')\n",
"        if i == 0:\n",
"            # coarsest grid: train a fresh KAN for this data size\n",
"            model = KAN(width=[2,1,1], grid=grid, k=k)\n",
"        else:\n",
"            # finer grid: warm-start from the model trained on the previous (coarser) grid\n",
"            model = KAN(width=[2,1,1], grid=grid, k=k).initialize_from_another_model(model, dataset['train_input'])\n",
"        results = model.train(dataset, opt=\"LBFGS\", steps=steps, stop_grid_update_step=30)\n",
"        train_losses[j, i] = results['train_loss'][-1]\n",
"        test_losses[j, i] = results['test_loss'][-1]\n"
]
},
{
"cell_type": "markdown",
"id": "6be8ba55",
"metadata": {},
"source": [
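"Fixing the data size, we study how the loss scales with model (grid) size $N$. The losses roughly follow $N^{-4}$ scaling as long as there is enough data; with small datasets, the test loss rises again at large grid sizes (overfitting)."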
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "e05289dd",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Text(0.5, 0, 'grid size')"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAkIAAAG1CAYAAAAV2Js8AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8o6BhiAAAACXBIWXMAAA9hAAAPYQGoP6dpAAC6f0lEQVR4nOzdd3hU1brH8e/MpFcS0ggkIQECCQktVOkdgiBNUQERPCiICiKKqOdcUemCiAYQAVFQRGkqhia9SW8hkJACBEjvPVP2/WPD0CGBSSYT1ud5uM9lz87MmhzM/LLW+66lkCRJQhAEQRAE4SmkNPYABEEQBEEQjEUEIUEQBEEQnloiCAmCIAiC8NQSQUgQBEEQhKeWCEKCIAiCIDy1RBASBEEQBOGpJYKQIAiCIAhPLRGEBEEQBEF4apkZewCVnU6n4/r169jb26NQKIw9HEEQBEEQSkGSJHJzc/H09ESpfPC8jwhCj3D9+nW8vLyMPQxBEARBEB5DQkICtWrVeuDjIgg9gr29PSB/Ix0cHIw8GkEQBEEQSiMnJwcvLy/95/iDVPkglJCQwPDhw0lJScHMzIz//ve/PP/886X++pvLYQ4ODiIICYIgCIKJeVRZS5UPQmZmZsyfP58mTZqQkpJCs2bNCA0NxdbW1thDEwRBEATByKp8EKpRowY1atQAwM3NDWdnZzIyMkQQEgRBEATB+O3ze/fupW/fvnh6eqJQKNi4ceM99yxcuBBfX1+srKwICQlh3759j/Vax44dQ6fTieJnQRAEQRCASjAjlJ+fT+PGjRk5ciSDBg265/E1a9YwYcIEFi5cSNu2bfnuu+/o3bs3kZGReHt7AxASEkJxcfE9X7tt2zY8PT0BSE9P55VXXmHp0qUPHU9xcfEdz5WTk/Mkb08QBEEQhEpMIUmSZOxB3KRQKNiwYQP9+/fXX2vVqhXNmjVj0aJF+msBAQH079+fGTNmlOp5i4uL6d69O6NHj2b48OEPvffTTz9l6tSp91zPzs4WxdKCIAiCYCJycnJwdHR85Oe30ZfGHqakpITjx4/To0ePO6736NGDgwcPluo5JEni1VdfpUuXLo8MQQBTpkwhOztb/ychIeGxxi4IgiAIQuVXqYNQWloaWq0Wd3f3O667u7uTlJRUquc4cOAAa9asYePGjTRp0oQmTZpw9uzZB95vaWmJg4MDK1eupHXr1nTt2vWJ3oMgCIIgCJWX0WuESuPuPQAkSSr1cRft2rVDp9OV+TXHjRvHuHHj9FNrgiAIgiBUPZU6CLm4uKBSqe6Z/UlJSblnlsiUaHVaTqScILUgFVcbV5q5NUOlVBl7WIIgCILw1KnUQcjCwoKQkBC2b9/OgAED9Ne3b9/Oc889V66vHRYWRlhYGFqt1qDP+8/lf5j17wyco5JwyoNMO8io78Hk1lPo5tPNoK8lCIIgCMLDGT0I5eXlERMTo/97fHw8p06dwtnZGW9vbyZOnMjw4cNp3rw5bdq0YcmSJVy5coUxY8aU67jKY2nsn8v/8Ot3E/jfdi0uubeup9lf58fuE+CN+SIMCYIgCEIFMnoQOnbsGJ07d9b/feLEiQCMGDGCFStWMGTIENLT0/nss89ITEwkKCiI8PBwfHx8jDXkx6LVadn8w/8xcf29M0zOuTBxvZZlFp/S+X+dxTKZIAiCIFSQSrWPUGVy+9JYdHT0E+8jdPTaYYr6v0r1XLhfmbcOyLAHq40raFGz1WO/jiAIgiAIVWQfIWMaN24ckZGRHD161CDPl3v0MC4PCEEg/w/hkivfJwiCIAhCxRBBqII45Zauhd8mu6icRyIIgiAIwk0iCD1AWFgYgYGBtGjRwiDPV9feplT3Lb36G79F/YZGpzHI6wqCIAiC8GAiCD2AoZfGbGo7orHR8ah5oUZn8pm36zOe/+t5Dlw7YJDXFgRBEATh/kQQqiAX8u3waZqFAom7q9Ml0F/tfhq+/l6Hx/5oxmx/gzH/jC
E2K7bCxysIgiAITwMRhCpIjE0webVsqNk2E3PrO1vozW001GqbiU+XNCwc1DjkS7zzl47/rpGIPbufQX8O4ot/vyCjKMNIoxcEQRCEqkm0zz+AodvnD8Wms2LZAhaZz0fSQVGaBZoiFWZWWqxcSlAo4VtNf0Y6nKT4WDppkfZIWgUaMwXr2ij4o7UCK2t7Xm/0OkMDhmKhsjDguxUEQRCEqqW07fMiCD1Cab+Rj6LVSbSbtZPGuXv5n/lPeCpuze5cl6ozVT2crbqWVLOU+DHgKA1OLSP5X3Pyk6wASHW1YGF3Ded8lNS0q8m7Ie/Sw6dHqQ+fFQRBEISniQhCBmKoIASwJSKRsatOoERHC+UF3MgihWoc1TVAixKf6jZcTi8AoF9tiRl2a9Ds2EbySUe0RfJu04ebWLOkYwm5NgqauDbhgxYfEOwa/MTvUxAEQRCqEhGEDMSQQQjkMDT1r0gSb9svqIajFf/XN5BuAe4sPxDP3G3RFGt0WJurmNcim+4xc0nbmURmjA2gQG1rwY+dFGwP1iApFPTx68P4puOpYVfjiccnCIIgCFWBCEIGYuggBPIy2ZH4DFJyi3Czt6KlrzMq5a0lrktp+Uxed4bD8fLyWQsvOxb6H8cufD6JhywozjIHIKWuEzM75XDVVYGlypIRDUfwWtBr2JiXbs8iQRAEQaiqRBB6QoYuli4rnU5i9dErzAi/QF6xBguVkg/bVWNE7nKy1v9NaoQ9kkaJpFTwb2cPwpqmUGKuwMXahbebvs1zdZ4Th7cKgiAITy0RhAykPGaEyuJ6ViEfbzjLrqhUAAJqOBDWrgivf/5L0rYk8q5ZA6BxdWB5H2v+qZEOQH2n+rzf4n1a1RAHuAqCIAhPHxGEDMTYQQhAkiT+OHWdqX+dI7NAjUqp4PV23kystp+ilbNIOmyGpsAMgPTWdZneJpUEq3wAOtXqxMTmE/F19DXK2AVBEATBGEQQMpDKEIRuSssr5tM/z7HpTCIAfi62zO3jSeOzc0n9OZyMaFuQFCiszTk1qAmzap1Go9BhpjBjSIMhjGk0hmpW1Yz6HgRBEAShIoggZCCVKQjdtO1cEp9sjCAltxiAV9r48GFwHsqfJ5K4OZmiDHmzRUXdmvw6uCbrFCcAcLBwYEzjMbxY/0XMVeZGG78gCIIglDcRhAykMgYhgOxCNdP/Ps+aYwkA1KxmzfT+gXTI+ZvMhdNJPW6GTq0EBRT1a8/skBQiCuUzy7ztvZnYfCJdvLqIDRkFQRCEKkkEoSdk7K6x0joQk8aH68+QkFEIwMBmNfm/rh7YbJ1KyorN5FyRi6nNnGyJf6Mv0612kl58oy3fowXvN3+fgOoBRhu/IAiCIJQHEYQMpLLOCN2uoETDl1uj+eFgPJIELnYWfPZcEKHVk8gLe4ukLSmo8+ViauuWDdk2tCnfpaynRFeCAgX96vTjnWbv4GbjZuR3IgiCIAiGIYKQgZhCELrp+OVMJq87Q0xKHgC9GnrwWb8AXC6sIW3eDNLPKkGnQGGuxPK1IXzXKJe/E7YAYG1mzcigkbza8FWszayN+TYEQRAE4YmJIGQgphSEAIo1Wr7dGcOi3bFodBIOVmb899lABje0o2T1RyT+sI3CVLmY2rKWM3lT3mZ28SZOp54GwM3GjQnNJtDHrw9KhdKYb0UQBEGowh51ysKTEkHIQEwtCN0UeT2HyevOcPZaNgDt67kwfUAwtYpiyJ79Jik7U9GWyEGnWq+2nHmtD/OiFnM9/zoAgdUDeb/5+zT3aG609yAIgiBUTQ87d7NXkGHOzRRByEBMNQgBaLQ6lu6PZ972aEo0OmwsVEzu1YDhrbzRHfqRlFkzyb4ohyGVrRnVJ41nY0P4PmIp+Wp5Q8Zu3t2YGDIRLwcvY74VQRAEoYrYEpHI2FUnuDt83JwLWjSsmUHCkAhCBmLKQeimuNQ8Plx3liOX5G6x5j5OzBrciDr2OvKXTiLpp12U5MjF1LYBNbCcNpMl2V
tZe3EtOkmHmdKMoQ2G8nrj13GwMM3vgSAIgmB8Wp1Eu1k775gJup0C8HC0Yv/kLk+8TFbaz29RBPIU8HO149fXW/P
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"for i in range(data_sizes.shape[0]):\n",
"    plt.plot(grids, train_losses[i,:], marker=\"o\")\n",
"plt.xscale('log')\n",
"plt.yscale('log')\n",
"# reference line with slope exactly -4 on log-log axes, anchored at 1e-3 on the smallest grid;\n",
"# y must be evaluated at the same x values the line is drawn at\n",
"ref_grids = grids[[0, -1]]\n",
"plt.plot(ref_grids, 1e-3*(ref_grids/ref_grids[0])**(-4.), ls=\"--\", color=\"black\")\n",
"plt.legend([f'data={data_sizes[i]}' for i in range(data_sizes.shape[0])]+[r'$N^{-4}$'])\n",
"plt.ylabel('train RMSE')\n",
"plt.xlabel('grid size')"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "6d15cc9e",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Text(0.5, 0, 'grid size')"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAkIAAAG1CAYAAAAV2Js8AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8o6BhiAAAACXBIWXMAAA9hAAAPYQGoP6dpAAClfklEQVR4nOzdd3xT5dvH8U+a7j2gC9qyy15lCMheAjIUBQHZOB4HAg6GgD/AgSzBgSiigMgeAsreyN6rlFlaRkv33sl5/kgbKLSlQNp0XG9ffdGenCRXase3577u+1YpiqIghBBCCFEKmRi7ACGEEEIIY5EgJIQQQohSS4KQEEIIIUotCUJCCCGEKLUkCAkhhBCi1JIgJIQQQohSS4KQEEIIIUotCUJCCCGEKLVMjV1AUafVarl37x52dnaoVCpjlyOEEEKIfFAUhfj4eDw9PTExyf26jwShJ7h37x5eXl7GLkMIIYQQz+D27duUL18+19slCD2BnZ0doPtE2tvbG7kaIYQQQuRHXFwcXl5e+t/juZEg9ARZw2H29vYShIQQQohi5kltLdIsLYQQQohSS4KQEEIIIUotCUJCCCGEKLUkCAkhhBCi1JIgJIQQQohSS4KQEEIIIUotCUJCCCGEKLUkCAkhhBCi1JIFFYUQQghR6DRaheOBUYTFp+BqZ0mTis6oTQp/T08JQkIIIYQoVNsuhjBlsz8hsSn6Yx4OlnzRvSYv1fYo1FpkaEwIIYQQhWbbxRD+b9npbCEIIDQ2hf9bdpptF0MKtR4JQkIIIYQoFBqtwpTN/ig53JZ1bMpmfzTanM4oGBKEhBBCCFEojgdGPXYl6GEKEBKbwvHAqEKrSYKQEEIIIQpFWHzuIehZzjMECUJCCCGEKBTm6vzFDlc7ywKu5AGZNSaEEEKIApWh0bLkSBBzdlzJ8zwV4O6gm0pfWErFFaF//vkHX19fqlatym+//WbscoQQQohS49jNSLp9/x/T/vEnMU2Dj4s1oAs9D8v6+IvuNQt1PaESf0UoIyODMWPGsHfvXuzt7WnYsCGvvvoqzs6FlzaFEEKI0iYsLoVvtgaw4cxdAJyszRj7UnX6NPJih3/oY+sIuRtpHaESH4SOHz9OrVq1KFeuHABdu3Zl+/bt9OvXz8iVCSGEECVP1jDYdzuvkpCagUoF/Zp482knX5xszAF4qbYHHWu6F4mVpYv80NiBAwfo3r07np6eqFQq/v7778fOmT9/PhUrVsTS0hI/Pz8OHjyov+3evXv6EARQvnx57t69WxilCyGEEKXK8cAoXv5BNwyWkJpBvfIObHy/BV+/UkcfgrKoTVQ0q+xCz/rlaFbZxSghCIrBFaHExETq1avH0KFD6d2792O3r1q1ilGjRjF//nxatGjBL7/8QpcuXfD398fb2xtFeXxRJpUq9092amoqqamp+o/j4uIM80KEEEKIEiosPoXpWwJYnzkM5pg5DNa3kRcmRgo4+VXkg1CXLl3o0qVLrrfPmTOH4cOHM2LECADmzp3L9u3b+fnnn/nmm28oV65ctitAd+7coWnTprk+3jfffMOUKVMM9wKEEEKIEipDo2Vp5jBYfC7DYEVdkR8ay0taWhqnTp2iU6dO2Y536tSJw4cPA9CkSRMuXrzI3bt3iY+PZ8uWLXTu3DnXxxw/fjyxsbH6t9u3bxfoaxBCCCGKo6xhsKn/+BOfmkHd8g78/V7Ow2BFWZG/IpSXiIgINBoNbm5u2Y67ubkRGhoKgKmpKbNnz6Zt27ZotVo+++wzXFxccn1MCwsLLCwsCrRuIYQQorjKaRjss87V6dvYy2h9Ps+jWAehLI/2/CiKku1Yjx496NGjR2GXJYQQQpQYGRotfx4NYs6OB8NgbzT25rPOxWcYLCfFOgiVKVMGtVqtv/qTJSws7LGrREIIIYR4NiduRTHp74sEhMYDULe8A1N71qa+l6NxCzOAYh2EzM3N8fPzY+fOnbzyyiv64zt37q
Rnz55GrEwIIYQo/sLjU/lm62XWny4Zw2A5KfJBKCEhgevXr+s/DgwM5OzZszg7O+Pt7c2YMWMYOHAgjRo1olmzZvz6668EBwfz7rvvGrFqIYQQovjK0GhZdjSI2dmGwbz4tHN1nIvxMFhOinwQOnnyJG3bttV/PGbMGAAGDx7M4sWL6du3L5GRkUydOpWQkBBq167Nli1b8PHxMVbJQgghRLF18lYUEx8aBqtTzoGpPWvRwNvJyJUVDJWS04qDQi8uLg4HBwdiY2Oxt7c3djlCCCFEgQiPT2X61gDWnb4DgIOVGZ929qVfE+9iOQyW39/fRf6KkBBCCCEKToZGy1/Hgpm14wrxKRmAbhjss5dK3jBYTiQICSGEEKXUyVtRTNp4icshuu2kSvowWE4kCAkhhBClTESCbhhs7amSMQz2PCQICSGEEKVETsNgfRt58dlLvrjYls5dFSQICSGEEKXAqaAoJv19Cf/MYbBanvZM61WbhqVoGCwnEoSEEEKIEiwiIZVvtwawJnMYzN7SlE9fqk7/UjgMlhMJQkIIIUQJpNEq/HUsiFnbrxAnw2C5kiAkhBBClDCngqKZvPEil+49GAab2rM2fj6lexgsJxKEhBBCiBIiMiGVb7cFsPrkQ8NgnX3p39RHhsFyIUFICCGEKOY0WoXlx4KY+dAwWJ9G5fnspeqUkWGwPEkQEkIIIYqx08G6YbCLd3XDYDU9dLPBZBgsfyQICSGEEMVQTsNgn3T2ZYAMgz0VCUJCCCFEMaLRKiw/HszMbQH6YbDX/coztosMgz0LCUJCCCFEMXEmOJpJjw2D1cLPx9nIlRVfEoSEEEKIIi4yIZUZ266w6uRtAOwsTfmkky8DmnpjqjYxcnXFmwQhIYQQoojSaBVWHA9m5vYrxCanA/CaX3nGyTCYwUgQEkIIIYqgR4fBanjYM61nLRpVkGEwQ5IgJIQQQhQhUYlpzNgWwMoTMgxWGCQICSGEEEWARquw8kQwM7Y9GAbr3VA3DFbWTobBCooEISGEEMLIzt6OYfLGi5y/EwvIMFhhkiAkhBBCGElUYhozt+uGwRQF7CxM+bhTNd58wUeGwQqJBCEhhBCikGUNg83cfoWYJN0w2KsNyzG+Sw0ZBitkEoSEEEKIQnTudgyTHhoGq+5ux7RetWksw2BGIUFICCGEKATRiWnM2H6FlSeC9cNgYzpVY6AMgxmVBCEhhBCiAGm1CitP3GbG9oBsw2DjulTH1c7SyNUJCUJCCCFEATmXORvs3EPDYFN71qZJRRkGKyokCAkhhBAGltMw2OiO1RjUTIbBihoJQkIIIYSBaLUKq07e5tttDw2DNSjHuK4yDFZUSRASQgghDOD8nRgmbbzEudsxgG4YbEqPWjSt5GLcwkSeJAgJIYQQzyEmKY2Z26+w/LhuGMzWwpQxMgxWbEgQEkIIIZ6BVquwOnMYLDpzGOyVBuUY36U6rvYyDFZcSBASQgghntKFO7FM3HhRPwzm62bH1J4yDFYcSRASQggh8imnYbBRHaoyuHkFzGQYrFiSICSEEEI8gVarsObUbaZvfTAM1qu+JxO61pBhsGJOgpAQQgiRh4t3Y5n490XOZg6DVXOzZWrP2rwgw2AlggQhIYQQIgcxSWnM2nGFv47JMFhJJkFICCGEeEjWMNi3264QlZgGQM/MYTA3GQYrcSQICSGEEJlyGgab0qM2zSrLMFhJJUFICCFEqReblM6sHVdYdiwIRQEbczWjO1aTYbBSQIKQEEKIUkurVVh7+g7TtwbIMFgpJUFICCFEqXTxbiyTNl7kTHAMAFVddbPBZBiskGg1EHQYEu6DrRv4NAcTdaGXIUFICCFEiaTRKhwPjCIsPgVXO0uaVHRGbaIiNimd2TuvsOxoENrMYbBRHaoxpIUMgxUa/02wbSzE3XtwzN4TXvoWavYo1FIkCAkhhChxtl0MYcpmf0JiU/TH3O0t6VjLjS3nQ4jMHAbrUU83DObuIMNghcZ/E6
weBCjZj8eF6I73WVqoYUiCkBBCiBJl28UQ/m/Z6Ud/zRIal8KfR4IAqOJqy9SetWheuUzhF1iaaTW6K0GP/d8h85g
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"for i in range(data_sizes.shape[0]):\n",
"    plt.plot(grids, test_losses[i,:], marker=\"o\")\n",
"plt.xscale('log')\n",
"plt.yscale('log')\n",
"# reference line with slope exactly -4 on log-log axes, anchored at 1e-3 on the smallest grid;\n",
"# y must be evaluated at the same x values the line is drawn at\n",
"ref_grids = grids[[0, -1]]\n",
"plt.plot(ref_grids, 1e-3*(ref_grids/ref_grids[0])**(-4.), ls=\"--\", color=\"black\")\n",
"plt.legend([f'data={data_sizes[i]}' for i in range(data_sizes.shape[0])]+[r'$N^{-4}$'])\n",
"plt.ylabel('test RMSE')\n",
"plt.xlabel('grid size')"
]
},
{
"cell_type": "markdown",
"id": "18bcedfe",
"metadata": {},
"source": [
"Fixing the model (grid) size, we study how the loss scales with data size. There is no clear power-law scaling, but we observe that (1) increasing the data size does not hurt performance, and (2) more powerful models (larger grid sizes) benefit more from additional data. Ideally, one would increase data size and model size together so that their complexities always match."
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "0dd85c41",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Text(0.5, 0, 'data size')"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAkIAAAG1CAYAAAAV2Js8AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8o6BhiAAAACXBIWXMAAA9hAAAPYQGoP6dpAACpMUlEQVR4nOzdd3xT9frA8c9J0r2gQCctu0iZpWWLskSKKNPrFURQpqKCZYvAZcseWqZX8afCxSvCdbBRtgoUUKDKLLSUDgq06R7J+f0RGprOlCZNx/d9X71tTr455ylIzpPnuyRZlmUEQRAEQRCqIYWlAxAEQRAEQbAUkQgJgiAIglBtiURIEARBEIRqSyRCgiAIgiBUWyIREgRBEASh2hKJkCAIgiAI1ZZIhARBEARBqLZEIiQIgiAIQrWlsnQAFZ1Wq+Xu3bs4OTkhSZKlwxEEQRAEwQiyLJOcnIyXlxcKRdF1H5EIleDu3bv4+PhYOgxBEARBEJ5AVFQUdevWLfJ5kQiVwMnJCdD9QTo7O1s4GkEQBEEQjKFWq/Hx8dHfx4siEqES5HaHOTs7i0RIEARBECqZkoa1iMHSRQgNDcXf35927dpZOhRBEARBEMxEErvPF0+tVuPi4kJSUpKoCAmCIAhCJWHs/Vt0jQmCIAjVkkajITs729JhCE/IysoKpVJZ5vOIREgQBEGoVmRZJjY2lsTEREuHIpRRjRo18PDwKNPyNiIREgRBEKqV3CTIzc0Ne3t7sUZcJSTLMmlpacTHxwPg6en5xOcSiZAgCIJQbWg0Gn0SVKtWLUuHI5SBnZ0dAPHx8bi5uT1xN5mYNSYIgiBUG7ljguzt7S0ciWAKuX+PZRnrJRIhQRAEodoR3WFVgyn+HqtFIvTjjz/StGlTmjRpwqeffmrpcARBEARBqCCq/BihnJwcQkJC+OWXX3B2dqZt27YMGjQIV1dXS4cmCIIgCIKFVfmK0OnTp2nevDne3t44OTnRt29f9u/fb+mwBEEQhEpOo5X59cZ9/nchml9v3EejrXjrE9+6dQtJkrhw4UKRbY4cOYIkSdV2OYEKnwgdO3aMF198ES8vLyRJYvfu3QXarF+/ngYNGmBra0tgYCDHjx/XP3f37l28vb31j+vWrUt0dHR5hC4IgiBUUfsuxfD00p95dctvTPzPBV7d8htPL/2ZfZdiLB2aAR8fH2JiYmjRooXZrrF161YkSSrwlZGRYbZrmlKFT4RSU1Np3bo1n3zySaHP79ixg0mTJjFr1izOnz9P165dCQ4OJjIyEtCtNZBfcYOrMjMzUavVBl/mcunSJX777TeznV8QBEEwvX2XYnjrq3PEJBne6GOTMnjrq3MVJhnKyspCqVTi4eGBSmXekTDOzs7ExMQYfNna2pr1mqZS4ROh4OBgFi5cyKBBgwp9ftWqVYwaNYrRo0fTrFkz1qxZg4+PDxs2bADA29vboAJ0586dYhdeWrJkCS4uLvovHx8f0/5Cj8iyzFtvvUWnTp14/fXXuXv3rlmuIwiCIBRPlmXSsnKM+krOyGbu95cprBMs99i/vg8nOSO7xHOVdqvP5ORkhg0bhoODA56enqxevZpu3boxadIkAOrXr8/ChQsZOXIkLi4ujBkzptCusT179uDn54ednR3du3fn1q1bT/LHZkCSJDw8PAy+KotKPVg6KyuLsLAwZsyYYXC8d+/enDp1CoD27dtz6dIloqOjcXZ2Zs+ePcyZM6fIc86cOZOQkBD9Y7VabZZkKCMjgyZNmnDixAm+/PJLvvvuOz788EMmTZpUabJoQRCEqiA9W4P/HNOMHZWBWHUGLf91oMS24fOfx97a+NtwSEgIJ0+e5Pvvv8fd3Z05c+Zw7tw52rRpo2+zfPlyZs+ezYcffljoOaKiohg0aBDjx4/nrbfe4uzZs0yePNmgTWRkJP7+/sXG8tprr7Fx40b945SUFOrVq4dGo6FNmzYsWLCAgIAAo383S6rUiV
BCQgIajQZ3d3eD4+7u7sTGxgKgUqlYuXIl3bt3R6vVMm3atGJXE7WxscHGxobQ0FBCQ0PRaDRmid3Ozo7PPvuMt956i/fee4/ffvuNmTNnsmXLFlatWsVLL70k1rkQBEEQAF016IsvvmDbtm307NkTgM8//xwvLy+Ddj169GDKlCn6x/mrPRs2bKBhw4asXr0aSZJo2rQpFy9eZOnSpfo2Xl5exQ6uBgx2c3/qqafYunUrLVu2RK1Ws3btWrp06cIff/xBkyZNnvA3Lj+VOhHKlT9hkGXZ4NhLL73ESy+9VKpzTpgwgQkTJqBWq3FxcTFJnIVp164dJ0+eZNu2bUyfPp2bN28yYMAAvvvuOwYOHGi26wqCIAg6dlZKwuc/b1Tb0xEPGPn5mRLbbX2jHe0bFL9Mi52V8VtC3Lx5k+zsbNq3b68/5uLiQtOmTQ3aBQUFFXuev/76i44dOxrcIzt16mTQRqVS0bhxY6Nj69ixIx07dtQ/7tKlC23btuXjjz9m3bp1Rp/HUip1IlS7dm2USqW++pMrPj6+QJWoIlMoFLz22msMGDCAJUuWcPDgQV588UX98/kTO0EQBMF0JEkyuouqa5M6eLrYEpuUUeg4IQnwcLGla5M6KBWme9/OHU9U2Af/vBwcHIw6T3GepGssL4VCQbt27bh27VqJ16oIKvxg6eJYW1sTGBjIwYMHDY4fPHiQzp07l+ncoaGh+Pv7065duzKdpzQcHR1ZtGgRv/76q36Ef0ZGBh06dGDDhg1m66YTBEEQjKNUSMx9UZck5E9zch/PfdHfpEkQQKNGjbCysuL06dP6Y2q1utTJhr+/f4HZyvkf53aNFfc1f/78Iq8hyzIXLlwo047w5anCV4RSUlK4fv26/nFERAQXLlzA1dUVX19fQkJCGD58OEFBQXTq1InNmzcTGRnJ+PHjy3Td8uoaK0zeHXQ/++wzzpw5w5kzZ9i4cSPr1q3j2WefLdd4BEEQhMf6tPBkw2ttmfdDuMEUeg8XW+a+6E+fFqZPAJycnBgxYgRTp07F1dUVNzc35s6di0KhKFWPwfjx41m5ciUhISGMGzeOsLAwtm7datCmtF1j8+bNo2PHjjRp0gS1Ws26deu4cOECoaGhRp/DouQK7pdffpHRDcQ3+BoxYoS+TWhoqFyvXj3Z2tpabtu2rXz06NEyX/eTTz6RmzVrJvv5+cmAnJSUVOZzPons7Gz5448/lmvWrKn/3V9++WX51q1bFolHEAShMktPT5fDw8Pl9PT0Mp8rR6OVT11PkHefvyOfup4g52i0JoiwaGq1Wh46dKhsb28ve3h4yKtWrZLbt28vz5gxQ5ZlWa5Xr568evVqg9dERETIgHz+/Hn9sR9++EFu3LixbGNjI3ft2lX+7LPPZEB++PDhE8U1adIk2dfXV7a2tpbr1Kkj9+7dWz516tQT/palU9zfZ1JSklH3b0mWS7mQQTWTWxFKSkoyGCVf3hISEpgzZw6bNm1Cq9Via2vLtGnT9J8IBEEQhJJlZGQQERGh342gMktNTcXb25uVK1cyatQoS4djEcX9fRp7/xZ30Eqidu3arF+/nvPnz9OtWzcyMjI4f/68SIIEQRCqifPnz7N9+3Zu3LjBuXPnGDZsGAD9+/e3cGSVW4UfI2Qp5l5H6Em1atWKn3/+me+++47WrVvrj8fFxXH37t1Ks4CVIAiCUHorVqzgypUr+slCx48fp3bt2pYOq1ITXWMlqChdYyUZNWoUn3/+OWPGjGHhwoXUqVPH0iEJgiBUOFWpa0wQXWPCIxqNhszMTGRZZvPmzTRp0oQ1a9aQnZ1t6dAEQRAEoUITiVAVoFQq+eqrrzh+/DgBAQEkJSXx/vvv06pVK/bt22fp8ARBEAShwhKJUBEssaBiWT399NOcOXOGLVu2UKdOHf7++2+Cg4PZvHmzpUMTBEEQhApJJEJFmDBhAuHh4Zw5U/KeMhWJUqlk9OjRXL16lZCQEDw9PXn55ZctHZ
YgCIIgVEgiEaqiatSowcqVK7l+/To1a9YEdMuev/rqq2zduhWtVmvhCAVBEATB8kQiVMXZ29vrf/7hhx/4z3/+wxt
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# one curve per grid size: final train loss as a function of training-set size\n",
"for losses_at_grid in train_losses.T:\n",
"    plt.plot(data_sizes, losses_at_grid, marker=\"o\")\n",
"plt.xscale('log')\n",
"plt.yscale('log')\n",
"ref_sizes = np.array([100,3000])\n",
"plt.plot(ref_sizes, 1e8*ref_sizes**(-4.), ls=\"--\", color=\"black\")\n",
"curve_labels = [f'grid={g}' for g in grids] + [r'$N^{-4}$']\n",
"plt.legend(curve_labels)\n",
"plt.ylabel('train RMSE')\n",
"plt.xlabel('data size')"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "107801f6",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Text(0.5, 0, 'data size')"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAkIAAAG1CAYAAAAV2Js8AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8o6BhiAAAACXBIWXMAAA9hAAAPYQGoP6dpAACnF0lEQVR4nOzdd3gUVdvH8e/sbnqDkF7pJXRC770oYHutiFIV1OcREQtil6KCgA1QVMCCYhcUEXik9957SUJ6QnrP7rx/LFmyEJIFNtmU++OVy+zM7My9Scj+cs6ZcxRVVVWEEEIIIWogja0LEEIIIYSwFQlCQgghhKixJAgJIYQQosaSICSEEEKIGkuCkBBCCCFqLAlCQgghhKixJAgJIYQQosaSICSEEEKIGktn6wIqO4PBQExMDG5ubiiKYutyhBBCCGEBVVXJyMggICAAjebG7T4ShMoQExNDcHCwrcsQQgghxC2IiooiKCjohvslCJXBzc0NMH4h3d3dbVyNEEIIISyRnp5OcHCw6X38RiQIlaGoO8zd3V2CkBBCCFHFlDWsRQZLCyGEEKLGkiAkhBBCiBpLgpAQQgghaiwJQkIIIYSosSQICSGEEKLGkiAkhBBCiBpLgpAQQgghaiwJQkIIIYSosWRCRRswGFRiz6SSlZ6Hi7sD/o1qodHIOmZCCCFERZMgVMHOHUhgy4ozZKXmmba51HKgx4ONaNDWx4aVCSGEEDWPdI1VoHMHEljz2VGzEASQlZrHms+Ocu5Ago0qE0IIIWomCUIVxGBQ2bLiTKnHbP3xDAaDWkEVCSGEEEKCUAWJPZN6XUvQtTJT8og9k1oxBQkhhBBCglBFyUovPQTd7HFCCCGEuH0ShCqIi7uDRced3BlHYmRGOVcjhBBCCJC7xiqMf6NauNRyKLN7LOrYZaKOXca/gQct+wRRv603Wq3kVSGEEKI81Ih32D///JMmTZrQqFEjvvjiC5vUoNEo9HiwUanHdL67Po06+KLRKMSeS2PtF8f45pXt7PnrAtnp+RVUqRBCCFFzKKqqVuvblAoLCwkLC2PDhg24u7vTrl07du3ahaenp0XPT09Px8PDg7S0NNzd3W+7npLmEXKt7UD3B67OI5SVlsexzdEc2xJjCkAanULDcB9a9Q7Gt97t1yGEEEJUZ5a+f1f7rrHdu3fTvHlzAgMDAbjjjjv4559/ePjhh21ST4O2PtRr7V3qzNIuHg50HFaf8CF1Obc/gcMbLhF/IZ3Tu+I5vSsen7rutOoTRMN2PmjtakSjnhBCCFEuKv276ObNmxk2bBgBAQEoisLvv/9+3TELFiygXr16ODo6Eh4ezpYtW0z7YmJiTCEIICgoiOjo6Ioo/YY0GoXAJrVp3MGPwCa1b7i8hlanoXFHP/7vpfbcP7U9TTr7odEpJFxMZ/2S4yx7ZRu7Vp4vc9yREEIIIUpW6YNQVlYWrVu35pNPPilx/4oVK5g0aRLTpk3jwIED9OjRgyFDhhAZGQlAST1/inLjdb3y8vJIT083+6gMfELd6T8qjMdndqPT8Pq41HIgJ6OAvasv8vUr2/nni6PEnk0t8fUKIYQQomSVvmtsyJAhDBky5Ib7586dy9ixYxk3bhwA8+fP559//mHhwoXMmjWLwMBAsxagS5cu0alTpxueb9asWbz11lvWewFW5uxuT/s76tJ2UAgXDiZxeEMUsWfTOLs3gbN7E/AKdqVVnyAatfdFZ6+1dblCCCFEpValBksrisJvv/3G3XffDUB+fj7Ozs789NNP3HPPPabjnn32WQ4ePMimTZsoLCykWbNmbNy40TRYeufOndSpU6fEa+Tl5ZGXd7WrKT09neDgYKsNli4PSZcyOLzhEqd3x6MvMADg6GJHWPcAWvQKxM3T0cYVCiGEEBWrRgyWTkpKQq/X4+vra7bd19eXuLg4AHQ6HR988AF9+vTBYDDw4osv3jAEATg4OODgYNnkh5WFV5AbfUc2o+
s9DTm+LYajm6LJuJzL/n8iOLA2gnptvGnVO4iAxrVK7RYUQgghapoqHYSKXPvmrqqq2bbhw4czfPjwii6rwjm62tFuUChtBoRw8XAShzdcIvpUCucPJHL+QCKeAS607B1Ek05+2DlIt5kQQghRpYOQl5cXWq3W1PpTJCEh4bpWoppEo1Go38ab+m28SY7J5OjGaE7uiuNyTBablp9i5+/naNrVn5a9gvDwdrJ1uUIIIYTNVPq7xkpjb29PeHg469atM9u+bt06unbtaqOqKpc6Aa70eqQJo2Z1pfv9jXD3diIvu5BD66P49vUd/PXpISKPJ8vdZkIIIWqkSt8ilJmZydmzZ02PL1y4wMGDB/H09CQkJITJkyczcuRI2rdvT5cuXfj888+JjIxkwoQJNqy68nFwtqN1v2Ba9Qki4lgyRzZcIvL4ZS4eSebikWRq+TrTsncQTbv4Ye9Y6X8shBBCCKuo9HeNbdy4kT59+ly3/fHHH2fp0qWAcULF999/n9jYWFq0aMG8efPo2bOnVa5v7SU2KpOUuCyObormxI5YCnL1ANg5amnaxZ9WvYOo5ets4wqFEEKIW2Pp+3elD0K2Vp2DUJH83EJO7YzjyMZLpMRlm7aHhHnSsk8Qoc3roNxg9mshhBCiMpIgZCU1IQgVUVWVSydSOLzxEhePJMGVnwx3byda9gqkWVd/HJztbFukEEIIYQEJQlZSk4JQcWmJORzddIkT22PJyy4EQOegpUknP1r1DsIzwMXGFQohhBA3JkHISmpqECpSkKfn9O44Dm+4xOWYLNP2wCa1adUniLqtvG64aKwQQghhKxKErKSmB6EiqqoSczqVwxsvceFgIkU/NW6ejrToFUhYtwAcXaXbTAghROUgQchKJAhdL+NyLkc3RXN8awy5WQUAaO00NO7oS6s+QXgFudm4QiGEEDWdBCErkSB0Y4X5es7sjefwhkskRWWatvs39KBVn2DqtfFCq63Sc3YKIYSooiQIWYkEobKpqkrcuTQOb7zE+f2JGAzGHymXWg606BlIWPcAnN3tbVylEEKImkSCkJVIELo5Wal5HN0czbEt0eRkGLvNNDqFRu2N3WY+ofI1FEIIUf4kCFmJBKFboy8wcHZ/Aoc3XCLhYrppu289d1r1CaJBOx+0Ouk2E0IIUT4kCFmJBKHbF38hncMbozi7NwGD3vjj5uxuT/MeATTvGYiLh4ONKxRCCFHdSBCyEglC1pOdns+xLdEc2xxNVlo+ABqNQoNwH1r1CcK3njuKInMSCSGEuH0ShKxEgpD16fUGzh9I5MiGS8SeSzNt9w5xo2XvIBp18EFnp7VhhUIIIao6CUJWIkGofCVGZnB44yXO7I5HX2gAwNHVjrDuAbToGYibp6ONKxRCCFEVSRCyEglCFSMnM5/jW2M4uimazJQ8ABSNQv3WXrTsE0RAo1rSbSaEEMJiEoSsRIJQxTLoDVw4nMSRjZeIPpVq2l4n0JWWvQNp3MkPO3vpNhNCCFE6CUJWIkHIdpKjMzmy8RKndsVRmG/sNnNw1tGsWwAtewXi7uVk4wqFEEJUVhKErESCkO3lZhVwckcsRzZeIj0p17hRgbotvWjVJ4igprWl20wIIYQZCUJWIkGo8jAYVCKPJnN44yWijl82ba/t50zL3kE06eyHvaPOhhUKIYSoLCQIWYkEocopJS6LIxujObkjloI8PQD2jlqadvWnZa8gavk627hCIYQQtiRByEokCFVu+TmFnNwZy5GN0aTGZ5u2hzSvQ6s+QYSEeaJopNtMCCFqGglCViJBqGpQDSpRJy5zeOMlIo4mw5Wfag8fJ1r2CqJpV38cnKTbTAghagoJQlYiQajqSU3I5uimaE5sjyU/pxAAnYOWpp39aNk7CE9/FxtXKIQQorxJELISCUJVV35uIad3x3N4wyVSYrNM24Oa1qZVnyBCW3qhkW4zIYSoliQIWYkEoapPVVWiT6VweMMlLh5Oougn3q2OIy17Bd
Gsmz+OLna2LVIIIYRVSRCyEglC1Ut6Ug5HN0dzfFsMeVlXus3sNDTu5EerPkHUCXS1cYVCCCGsQYKQlUgQqp4K8/W
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# one curve per grid size: final test loss as a function of training-set size\n",
"for losses_at_grid in test_losses.T:\n",
"    plt.plot(data_sizes, losses_at_grid, marker=\"o\")\n",
"plt.xscale('log')\n",
"plt.yscale('log')\n",
"ref_sizes = np.array([100,3000])\n",
"plt.plot(ref_sizes, 1e5*ref_sizes**(-4.), ls=\"--\", color=\"black\")\n",
"curve_labels = [f'grid={g}' for g in grids] + [r'$N^{-4}$']\n",
"plt.legend(curve_labels)\n",
"plt.ylabel('test RMSE')\n",
"plt.xlabel('data size')"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.7"
}
},
"nbformat": 4,
"nbformat_minor": 5
}