{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "import numpy as np\n",
    "import torch\n",
    "torch.set_printoptions(edgeitems=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Paired samples (11 each): t_c holds the targets, t_u the raw model inputs.\n",
    "t_c = torch.tensor([0.5, 14.0, 15.0, 28.0, 11.0, 8.0, 3.0, -4.0, 6.0, 13.0, 21.0])\n",
    "t_u = torch.tensor([35.7, 55.9, 58.2, 81.9, 56.3, 48.9, 33.9, 21.8, 48.4, 60.4, 68.4])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def model(t_u, w, b):\n",
    "    \"\"\"Linear model: scale the input `t_u` by weight `w`, then add bias `b`.\"\"\"\n",
    "    scaled = w * t_u\n",
    "    return scaled + b"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def loss_fn(t_p, t_c):\n",
    "    \"\"\"Mean of the squared element-wise differences between `t_p` and `t_c`.\"\"\"\n",
    "    residuals = t_p - t_c\n",
    "    return (residuals ** 2).mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([35.7000, 55.9000, 58.2000, 81.9000, 56.3000, 48.9000, 33.9000, 21.8000,\n",
       " 48.4000, 60.4000, 68.4000])"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Initial guess: unit weight and zero bias, so the prediction equals the input.\n",
    "w, b = torch.ones(1), torch.zeros(1)\n",
    "\n",
    "t_p = model(t_u, w, b)\n",
    "t_p"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(1763.8846)"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Loss of the initial guess.\n",
    "loss = loss_fn(t_p, t_c)\n",
    "loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "delta = 0.1\n",
    "\n",
    "# Central finite difference of the loss with respect to w.\n",
    "loss_w_plus = loss_fn(model(t_u, w + delta, b), t_c)\n",
    "loss_w_minus = loss_fn(model(t_u, w - delta, b), t_c)\n",
    "loss_rate_of_change_w = (loss_w_plus - loss_w_minus) / (2.0 * delta)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "learning_rate = 1e-2\n",
    "\n",
    "# One gradient-descent step on w; re-running this cell applies a further step.\n",
    "w = w - learning_rate * loss_rate_of_change_w"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs":
[],
   "source": [
    "# Central finite difference of the loss with respect to b (uses the already-updated w).\n",
    "loss_b_plus = loss_fn(model(t_u, w, b + delta), t_c)\n",
    "loss_b_minus = loss_fn(model(t_u, w, b - delta), t_c)\n",
    "loss_rate_of_change_b = (loss_b_plus - loss_b_minus) / (2.0 * delta)\n",
    "\n",
    "# One gradient-descent step on b; re-running this cell applies a further step.\n",
    "b = b - learning_rate * loss_rate_of_change_b"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "def dloss_fn(t_p, t_c):\n",
    "    \"\"\"Element-wise derivative of `(t_p - t_c)**2` with respect to `t_p`.\n",
    "\n",
    "    No averaging over samples happens here; `grad_fn` applies the mean.\n",
    "    \"\"\"\n",
    "    return 2 * (t_p - t_c)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "def dmodel_dw(t_u, w, b):\n",
    "    \"\"\"Derivative of `w * t_u + b` with respect to `w`: the input itself.\n",
    "\n",
    "    `w` and `b` are unused; they keep the signature parallel to `model`.\n",
    "    \"\"\"\n",
    "    return t_u"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "def dmodel_db(t_u, w, b):\n",
    "    \"\"\"Derivative of `w * t_u + b` with respect to `b`: the constant 1.0.\"\"\"\n",
    "    return 1.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "def grad_fn(t_u, t_c, t_p, w, b):\n",
    "    \"\"\"Gradient of the mean squared loss with respect to `(w, b)`.\n",
    "\n",
    "    Applies the chain rule per sample, averages over the samples, and\n",
    "    returns the two averages stacked as a tensor `[dloss/dw, dloss/db]`.\n",
    "    \"\"\"\n",
    "    dloss_dtp = dloss_fn(t_p, t_c)  # shared factor, computed once\n",
    "    dloss_dw = dloss_dtp * dmodel_dw(t_u, w, b)\n",
    "    dloss_db = dloss_dtp * dmodel_db(t_u, w, b)\n",
    "    return torch.stack([dloss_dw.mean(), dloss_db.mean()])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "def training_loop_basic(n_epochs, learning_rate, params, t_u, t_c):\n",
    "    \"\"\"Minimal gradient-descent loop that prints the loss at every epoch.\n",
    "\n",
    "    Has its own name so that the instrumented `training_loop` defined in the\n",
    "    next cell does not silently shadow it.\n",
    "    \"\"\"\n",
    "    for epoch in range(1, n_epochs + 1):\n",
    "        w, b = params\n",
    "\n",
    "        t_p = model(t_u, w, b)  # forward pass\n",
    "        loss = loss_fn(t_p, t_c)\n",
    "        grad = grad_fn(t_u, t_c, t_p, w, b)  # analytic gradient w.r.t. (w, b)\n",
    "\n",
    "        params = params - learning_rate * grad\n",
    "\n",
    "        print('Epoch %d, Loss %f' % (epoch, float(loss)))\n",
    "\n",
    "    return params"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "def training_loop(n_epochs, learning_rate, params, t_u, t_c, print_params=True):\n",
    "    \"\"\"Gradient descent on `params = [w, b]` with sparse progress logging.\n",
    "\n",
    "    The loss is printed only at a fixed set of epochs; when `print_params`\n",
    "    is true the updated params and the gradient are printed as well. The\n",
    "    loop stops early once the loss is no longer finite. Returns the params\n",
    "    after the last update that was performed.\n",
    "    \"\"\"\n",
    "    logged_epochs = {1, 2, 3, 10, 11, 99, 100, 4000, 5000}\n",
    "    gap_epochs = {4, 12, 101}  # first epoch of each unlogged stretch\n",
    "\n",
    "    for epoch in range(1, n_epochs + 1):\n",
    "        w, b = params\n",
    "\n",
    "        t_p = model(t_u, w, b)  # forward pass\n",
    "        loss = loss_fn(t_p, t_c)\n",
    "        grad = grad_fn(t_u, t_c, t_p, w, b)  # analytic gradient w.r.t. (w, b)\n",
    "\n",
    "        params = params - learning_rate * grad\n",
    "\n",
    "        if epoch in logged_epochs:\n",
    "            print('Epoch %d, Loss %f' % (epoch, float(loss)))\n",
    "            if print_params:\n",
    "                print(' Params:', params)\n",
    "                print(' Grad: ', grad)\n",
    "        if epoch in gap_epochs:\n",
    "            print('...')\n",
    "\n",
    "        if not torch.isfinite(loss).all():\n",
    "            break  # diverged: stop instead of iterating on inf/nan values\n",
    "\n",
    "    return params"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1, Loss 1763.884644\n",
      " Params: tensor([-44.1730, -0.8260])\n",
      " Grad: tensor([4517.2964, 82.6000])\n",
      "Epoch 2, Loss 5802484.500000\n",
      " Params: tensor([2568.4011, 45.1637])\n",
      " Grad: tensor([-261257.4062, -4598.9707])\n",
      "Epoch 3, Loss 19408031744.000000\n",
      " Params: tensor([-148527.7344, -2616.3933])\n",
      " Grad: tensor([15109615.0000, 266155.7188])\n",
      "...\n",
      "Epoch 10, Loss 90901075478458130961171361977860096.000000\n",
      " Params: tensor([3.2144e+17, 5.6621e+15])\n",
      " Grad: tensor([-3.2700e+19, -5.7600e+17])\n",
      "Epoch 11, Loss inf\n",
      " Params: tensor([-1.8590e+19, -3.2746e+17])\n",
      " Grad: tensor([1.8912e+21, 3.3313e+19])\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "tensor([-1.8590e+19, -3.2746e+17])"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# With learning_rate=1e-2 on the unscaled inputs the recorded run diverges\n",
    "# (the loss reaches inf at epoch 11 and the loop stops early).\n",
    "training_loop(\n",
    "    n_epochs=100,\n",
    "    learning_rate=1e-2,\n",
    "    params=torch.tensor([1.0, 0.0]),\n",
    "    t_u=t_u,\n",
    "    t_c=t_c,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1, Loss 1763.884644\n",
      " Params: tensor([ 0.5483, -0.0083])\n",
      " Grad: tensor([4517.2964, 82.6000])\n",
      "Epoch 2, Loss 323.090546\n",
      " Params: tensor([ 0.3623, -0.0118])\n",
      " Grad: tensor([1859.5493, 35.7843])\n",
      "Epoch 3, Loss 78.929634\n",
      " Params: tensor([ 0.2858, -0.0135])\n",
      " Grad: tensor([765.4666, 16.5122])\n",
      "...\n",
      "Epoch 10, Loss 29.105242\n",
      " Params: tensor([ 0.2324, -0.0166])\n",
      " Grad: tensor([1.4803, 3.0544])\n",
      "Epoch 11, Loss 29.104168\n",
      " Params: tensor([ 0.2323,
-0.0169])\n",
      " Grad: tensor([0.5780, 3.0384])\n",
      "...\n",
      "Epoch 99, Loss 29.023582\n",
      " Params: tensor([ 0.2327, -0.0435])\n",
      " Grad: tensor([-0.0533, 3.0226])\n",
      "Epoch 100, Loss 29.022669\n",
      " Params: tensor([ 0.2327, -0.0438])\n",
      " Grad: tensor([-0.0532, 3.0226])\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "tensor([ 0.2327, -0.0438])"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Same run with a 100x smaller step (1e-4): the recorded loss now decreases.\n",
    "training_loop(\n",
    "    n_epochs=100,\n",
    "    learning_rate=1e-4,\n",
    "    params=torch.tensor([1.0, 0.0]),\n",
    "    t_u=t_u,\n",
    "    t_c=t_c,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inputs rescaled by a factor of 0.1; used in place of t_u by the runs below.\n",
    "t_un = 0.1 * t_u"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1, Loss 80.364342\n",
      " Params: tensor([1.7761, 0.1064])\n",
      " Grad: tensor([-77.6140, -10.6400])\n",
      "Epoch 2, Loss 37.574917\n",
      " Params: tensor([2.0848, 0.1303])\n",
      " Grad: tensor([-30.8623, -2.3864])\n",
      "Epoch 3, Loss 30.871077\n",
      " Params: tensor([2.2094, 0.1217])\n",
      " Grad: tensor([-12.4631, 0.8587])\n",
      "...\n",
      "Epoch 10, Loss 29.030487\n",
      " Params: tensor([ 2.3232, -0.0710])\n",
      " Grad: tensor([-0.5355, 2.9295])\n",
      "Epoch 11, Loss 28.941875\n",
      " Params: tensor([ 2.3284, -0.1003])\n",
      " Grad: tensor([-0.5240, 2.9264])\n",
      "...\n",
      "Epoch 99, Loss 22.214186\n",
      " Params: tensor([ 2.7508, -2.4910])\n",
      " Grad: tensor([-0.4453, 2.5208])\n",
      "Epoch 100, Loss 22.148710\n",
      " Params: tensor([ 2.7553, -2.5162])\n",
      " Grad: tensor([-0.4445, 2.5165])\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "tensor([ 2.7553, -2.5162])"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Back to learning_rate=1e-2, this time on the rescaled inputs.\n",
    "training_loop(\n",
    "    n_epochs=100,\n",
    "    learning_rate=1e-2,\n",
    "    params=torch.tensor([1.0, 0.0]),\n",
    "    t_u=t_un,  # rescaled inputs instead of the raw t_u\n",
    "    t_c=t_c,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
"name": "stdout", "output_type": "stream", "text": [ "Epoch 1, Loss 80.364342\n", "Epoch 2, Loss 37.574917\n", "Epoch 3, Loss 30.871077\n", "...\n", "Epoch 10, Loss 29.030487\n", "Epoch 11, Loss 28.941875\n", "...\n", "Epoch 99, Loss 22.214186\n", "Epoch 100, Loss 22.148710\n", "...\n", "Epoch 4000, Loss 2.927680\n", "Epoch 5000, Loss 2.927648\n" ] }, { "data": { "text/plain": [ "tensor([ 5.3671, -17.3012])" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "params = training_loop(\n", " n_epochs = 5000, \n", " learning_rate = 1e-2, \n", " params = torch.tensor([1.0, 0.0]), \n", " t_u = t_un, \n", " t_c = t_c,\n", " print_params = False)\n", "\n", "params" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[]" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXYAAAD8CAYAAABjAo9vAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAHyFJREFUeJzt3Xl8FdX9//HXh30HEUQU0oiCogiCEUUUwbCJUqutVGvVqhVbpW7VqqCIiMjXDevWNu62bqnKT1FUFkEEKhIQA7IoIrKILIphl5Cc3x/3MhAIJPfeuZl7J+/n4+Ej+ZzMnfnc0rwZzpk7Y845REQkPKoE3YCIiPhLwS4iEjIKdhGRkFGwi4iEjIJdRCRkFOwiIiGjYBcRCRkFu4hIyCjYRURCploQB23SpInLzMwM4tAiImlr9uzZ651zTcvaLpBgz8zMJC8vL4hDi4ikLTP7tjzbaSpGRCRkFOwiIiGjYBcRCRkFu4hIyCjYRURCRsEuIlIR8nNhdDsY1ijyNT83aYcK5HJHEZFKJT8Xxl4HhdsidcGKSA3QfoDvh9MZu4hIsk0avjvUdyncFhlPAgW7iEiyFayMbTxBCnYRkWRr2CK28QQp2EVEki17KFSvXXKseu3IeBIo2EVEkq39AOj/KDRsCVjka/9Hk7JwCroqRkSkYrQfkLQg35vO2EVEQkbBLiISMgp2EZEKMOPr9dz6ej7bC4uSfizNsYuIJNG2HUV0GTWJn7YWAnBDr9Y0b1i7jFclRsEuIpIkT3+8lBHvLvTqN685NemhDgp2ERHfrfhxK6ffP9mrB2S14P7fdKiw4yvYRUR84pzjqhdnM3HhGm/s0yHZHFK/VoX2oWAXEfHBx1+t45JnPvXq//v18fz2pIxAelGwi4gkYOuOnWSNmMjWHZGrXVo1rcv713ejRrXgLjosd7CbWUvgReBQoBjIcc793cyGAVcB66KbDnbOjfO7URGRVPPklCXc//5ir37r2q50aNkowI4iYjlj3wn81Tk3x8zqA7PNbEL0Z6Odcw/6356ISOpZtn4L3R+c4t
UXn5zBvecdH1xDeyl3sDvnVgOro99vMrOFwOHJakxEJNU457jsuVlM/XKdN5Z3R0+a1KsZYFf7imuO3cwygY7ATKArMMjMLgXyiJzVb/CrQRGRVDB58Vouf26WVz88oAPnd0rO/dQTFXOwm1k94A3gBufcRjP7B3AP4KJfHwKuKOV1A4GBABkZwawUi4jEavPPO+k4fDyFRQ6AYw6tz9i/nEb1qql7R5aYgt3MqhMJ9Zecc28COOfW7PHzp4B3Snutcy4HyAHIyspy8TYsIlJR/j7xK0ZP/NKr3/nLabQ7vGGAHZVPLFfFGPAMsNA59/Ae482j8+8A5wHz/W1RRKRifb1uM9kPfeTVl3fN5K7+xwXYUWxiOWPvClwCzDOzudGxwcBFZnYCkamYZcDVvnYoIlJBiosdFz89k/8t/cEbm3NnLxrXrRFgV7GL5aqYaYCV8iNdsy4iaW/CgjVc9WKeVz96UUd+2eGwADuKnz55KiKV2sbthbQfNt6r27doyJt/PpVqKbw4WhYFu4hUWg+NX8xjHy7x6veuP522zRsE2JE/FOwiUul8tWYTvUZP9eqru7Xi9n5tA+zIXwp2Eak0ioodA/71P2Z/u/szlHOH9qJRnfRaHC2Lgl1EKoX35q3mzy/N8eonL+5Ev+ObB9hR8ijYRSTUCrYW0mH47sXRE39xELlXd6FqldIu8gsHBbuIhNZ97y3kXx8t9erxN3ajTbP6AXZUMRTsIhI6i77fSN9HPvbqQT2O4uY+RwfYUcVSsItIaBQVO857cjr5Kwu8sc/v6k3D2tUD7KriKdhFJBTGfv4df3nlM6/OueREeh93aIAdBUfBLiJpbcOWHXS8Z4JXn9KqMS//8RSqhHhxtCwKdhFJW8PHLuDZ6d949cSbzuCoQ+oF2FFqULCLSNqZv6qAcx6b5tU39GzNDT3bBNhRalGwi0ja2FlUzDmPTWPR95sAqFbF+GxoL+rXqlyLo2VRsItIWhjz2UpufO1zr372D1mceUyzADtKXQp2EfFffi5MGg4FK6FhC8geCu0HxLWrHzb/zIkjJnp1tzZNeeHyk4g81E1Ko2AXEX/l58LY66BwW6QuWBGpIeZwv/P/zeffn3zr1VNu7k5mk7p+dRpaCnYR8dek4btDfZfCbZHxcgZ7/sqf+OXj0736lj5Hc22Po/zsMtQU7CLir4KVsY3vobComL6PTOXrdVsAqFOjKrOG9KRuTUVVLPS/loj4q2GLyPRLaeMHkJu3gr+9nu/VL17RmW5tmvrdXaWgYBcRf2UPLTnHDlC9dmS8FGs3bafzvZO8umfbQ3jq0iwtjiag3MFuZi2BF4FDgWIgxzn3dzNrDLwGZALLgAHOuQ3724+IhNyuefRyXBVz2xv5vDpr99n9x3/rQcvGdSqq09Ay51z5NjRrDjR3zs0xs/rAbOBXwB+AH51zo8zsNuAg59ytB9pXVlaWy8vLS6xzEUlbc5Zv4PwnZ3j14H7HMLDbkQF2lB7MbLZzLqus7cp9xu6cWw2sjn6/ycwWAocD5wLdo5u9AEwBDhjsIlI57dhZTPbDU1jxY2SapmHt6nxyeza1a1QNuLNwiWuO3cwygY7ATKBZNPRxzq02s0P285qBwECAjIyMeA4rImns5ZnLGTxm3u76jydz6lFNAuwovGIOdjOrB7wB3OCc21jeBQ7nXA6QA5GpmFiPKyLpac3G7Zw8cvfi6FntDuXJiztpcTSJYgp2M6tOJNRfcs69GR1eY2bNo2frzYG1fjcpIunHOccVz89i8uJ13ti0W3vQ4iAtjiZbLFfFGPAMsNA59/AeP3obuAwYFf36lq8dikjaeWvuKq5/da5X39X/WC7vekSAHVUusZyxdwUuAeaZ2a4/scFEAj3XzK4ElgMX+NuiiKSLrTt2cuzQD0qMzb+7D/X0ydEKFctVMdOA/U2KZfvTjoikq5ty5/LmnFVe/fCADpzf6cCfNpXk0F+jIpKQL9dsovfoqV5ds1oVFt
3TV4ujAVKwi0hcnHMcOXgcxXtc4zbhxm60blY/uKYEULCLSBz+m7eCW/a4YdeArBbc/5sOAXYke1Kwi0i5bf55J+3uKrk4umB4H+rUUJSkEv1piEi5XPvyHN7NX+3Vj13Ukf4dDguwI9kfBbuIHNCC7zbS79GPvbpBrWrkD+sTYEdSFgW7iJTKOccRt48rMfbhX8+gVdN6AXUk5aVgF5F97H3Drt+fksGIXx0fYEcSCwW7iHgKthXS4e7xJcYW3dOXWtV1W910omAXEQCuejGPCQvWePU/f9+Jvu2aB9iRxEvBLlLJ5a/8iV8+Pt2rD6lfk0+H9AywI0mUgl2kkiptcXTqLT3IOFi31U13CnaRSuiFGcu46+0vvPqKrkcwtP+xAXYkflKwi1QiP23dwQnDJ5QYWzyiLzWraXE0TBTsIpXEpc9+ytQvdz/N6JnLsshu2yzAjiRZFOwiITdn+QbOf3KGV2c0rsPUv/UIsCNJNgW7SEgVFztaDS65ODr9tjM5vFHtgDqSiqJgFwmhpz9eyoh3F3r1n7sfya19jwmwI6lICnaREPlh88+cOGJiibEvR5xFjWpVAupIgqBgFwmJ3/7rf8z85kevfuGKzpzRpmmAHUlQyh3sZvYscA6w1jnXLjo2DLgK2LXUPtg5N670PYhIMsxa9iMX/PN/Xt2mWT3G33hGgB1J0GI5Y38eeBx4ca/x0c65B33rSKSyys+FScOhYCU0bAHZQ6H9gP1uXlQceebonj65PZtDG9ZKdqeS4sod7M65qWaWmbxWRCqx/FwYex0UbovUBSsiNZQa7k9MXsIDHyz26uuyW3NTrzYV0amkAT/m2AeZ2aVAHvBX59wGH/YpUrlMGr471Hcp3BYZ3yPY127aTud7J5XYbMm9Z1GtqhZHZbdEg/0fwD2Ai359CLiitA3NbCAwECAjIyPBw4qETMHKMsfPfXwan68s8OqX/3gypx7VJNmdSRpKKNidc97Nm83sKeCdA2ybA+QAZGVluUSOKxI6DVtEpl9KGZ/x9Xp+99RMb6hDi4a8Nei0CmxO0k1C/34zsz3vwn8eMD+xdkQqqeyhUL3kJ0Jd9dpct65/iVD/dEi2Ql3KFMvljq8A3YEmZrYSuAvobmYnEJmKWQZcnYQeRdJTLFe57BqPbr+xZjPu2HQ+bxdHQvyWPkdzbY+jKqhxSXfmXMXPimRlZbm8vLwKP65Ihdn7KheInJH3f/SAlzB+X7CdU+4ruTj69ch+VK1iyepU0oiZzXbOZZW1nT55KpIM5bzKZU99H5nKou83eXXu1V3ofETjZHYpIaVgF0mGclzlssvUL9dx6bOfenXnzMbk/qlLsjqTSkDBLpIMB7jKZZfComJaD3mvxI9n39GTg+vVTHZ3EnL6VINIMpRylQvVa0fGgfvfX1Qi1If0a8uyUWeXP9Tzc2F0OxjWKPI1P9evziUEdMYukgx7XeWy66qYVRn96XrbuyU2XTqyH1ViWRyN8fYDUvnoqhiRCtL9gcks+2GrV795zal0yjgo9h2NbrefaZ6WcKM+ShJmuipGJEV8uGgNVzy/+0Tm9NZN+PeVJ8e/wxgWZqVyUrCLJMnPO4s4+o73S4x9dmcvDqpbI7Edl2NhVio3LZ6KJMGIdxaUCPVh/Y9l2aizEw91KHNhVkRn7CI+Wv7DVro9MLnEWMyLo2XZz8KsFk5lFwW7iE9OGTmJ7zdu9+q3B3WlfYtGyTlY+wEKctkvBbtIgt6f/z1/+s9sr+7ZthlPX1bmhQsiSaNgF4nT9sIijrmz5OLo53f1pmHt6gF1JBKhYBeJw9C35vPi/7716pHnHc/vTtaTwSQ1KNhFYvDN+i30eHBKybH7+mGm2+pK6lCwi5TTCcPH89PWQq9+97rTOO6whmW/MJYHboj4QMEuUoZ38r9j0MufefXZxzfniYs7le/Fuq+LBEDBLrIf23YU0XZoycXRecN6U79WDIujcTxwQyRRCnaRUtz6ej6v5e3+2P
4Dv2nPBVktY9+R7usiAVCwi+xhydpN9Hx4qldXrWIsufes+BdHdV8XCUC5g93MngXOAdY659pFxxoDrwGZwDJggHNug/9tiiSXc462Q99ne2GxN/bBDd04+tD6ie04e2jpD7XWfV0kiWK5CdjzQN+9xm4DJjnnWgOTorVIWhnz2UqOuH2cF+rndzycZaPOTjzUITKP3v/RyL3SscjX/o9qfl2Sqtxn7M65qWaWudfwuUD36PcvAFOAW33oSyTptvy8k+Pu+qDE2Bd396FuTZ9nKHVfF6lgif4/uJlzbjWAc261mR3iQ08iSXfja3MZ89kqr37ktyfwq46HB9iRiH8qbPHUzAYCAwEyMvTRawnGou830veRj726bo2qzL+7jz45KqGSaLCvMbPm0bP15sDa/W3onMsBciDyzNMEjysSE+ccR9w+rsTYxJvO4KhD6gXUkUjyJPoEpbeBy6LfXwa8leD+RHyXm7eiRKhfeFJLlo06W6EuoRXL5Y6vEFkobWJmK4G7gFFArpldCSwHLkhGkyLx2LS9kOOHjS8xtnB4X2rXqBpQRyIVI5arYi7az4+yfepFxDfXvDSbcfO+9+onfteJs9s3D7AjkYqjT55KqMxfVcA5j03z6sZ1azDnzl4BdiRS8RTsEgqlLY5Oubk7mU3qBtSRSHAU7JL2/vPJt9zx/+Z79WVdfsHd57YLsCORYCnYJW0VbC2kw/CSi6OL7ulLrepaHJXKTcEuaenK52cxadHuj03kXHIivYumwuMd9KQiqfQU7JJWPl/xE+c+Md2rD2tYixm3Z+tJRSJ7ULBLWigudrQaXHJx9OO/9aBl4zqRQk8qEvEo2CXlPTf9G+4eu8Crrzr9CIacfWzJjfSkIhGPgl1S1oYtO+h4z4QSY4tH9KVmtVIWR/WkIhGPgl1S0u+fnsm0Jeu9+tk/ZHHmMc32/wI9qUjEo2CXlDL72w38+h8zvLpVk7p8eHP3sl+4ax590nBdFSOVnoJdUkJRsePIvRZHZ9x2Joc1ql3+nehJRSKAgl1SQM7Urxk5bpFXX9vjSG7pc0yAHYmkNwW7BGb95p/JGjGxxNhX955F9aqJPiZApHJTsEsgLvjnDGYt2+DV/76yM6e3bhpgRyLhoWCXCjVz6Q/8NucTr27bvAHvXX96gB2JhI+CXSpEaYujMwdn06xBrcR3np+rq2FE9qBgl6R7/MOveHD8l159Y882XN+ztT871z1iRPahYJekWbtxO51HTioxtuTes6jm5+Ko7hEjsg8FuyRF/8emMW9VgVe/ctUpdDnyYP8PpHvEiOxDwS6+mr5kPRc/PdOrO2Y0Ysw1XZN3QN0jRmQfvgS7mS0DNgFFwE7nXJYf+5X0UVhUTOsh75UYmzWkJ03r10zugXWPGJF9+HnG3sM5t77szSRsHh6/mEc/XOLVt/Y9hj93P7JiDq57xIjsQ1MxErfVBdvoct+HJca+HtmPqlWsYhvRPWJESvAr2B0w3swc8C/nXI5P+5UU1fPhj1iydrNX//dPXTgps3GAHYnILn4Fe1fn3HdmdggwwcwWOeem7rmBmQ0EBgJkZGT4dFipaJMXr+Xy52Z59SmtGvPqwC4BdiQie/Ml2J1z30W/rjWzMUBnYOpe2+QAOQBZWVnOj+NKxdmxs5g2d5RcHJ19R08OrpfkxVERiVnCwW5mdYEqzrlN0e97A8MT7kxSxn3vLeRfHy316jvObssfT28VYEciciB+nLE3A8aY2a79veyce9+H/UrAVm7Yymn/N7nE2NKR/ahS0YujIhKThIPdObcU6OBDL5JCTr//Q1b8uPva8DHXnErHjIMC7EhEykuXO0oJExas4aoX87z6jDZNeeGKzgF2JCKxUrALAD/vLOLoO0rOoM0d2otGdWoE1JGIxEvBLtw99guem77Mq+859zgu6ZIZWD8ikhgFeypL8gMkvv1hC2c8MKXE2Df39SO6EC4iaUrBnqqS/ACJk+6dyLpNP3v12EGncX
yLhgnvV0SCp8fBp6oDPUAiAe/NW03mbe96od772GYsG3W2Ql0kRHTGnqp8foDE9sIijrmz5OJo/rDeNKhVPa79iUjqUrCnKh8fIDF4zDxenrncq0edfzwXdtb9ekTCSsGeqnx4gMTX6zaT/dBHJca0OCoSfgr2VJXgAySOH/YBm7bv9Or3rj+dts0bJKNTEUkxCvZUFscDJN6au4rrX53r1f07HMZjF3X0uzMRSWEK9pDYumMnxw79oMTY/Lv7UK+m/ohFKhv91ofAsLe/4PkZy7z6oQs68OsTY19kFZFwULCnse9+2sapo3Y/c7RG1SosHtFXi6MilZyCPQ0557ju1bmM/fw7b2z6bWdyeKPaAXYlIqlCwZ5mPln6AxfmfOLVw889jkt1wy4R2YOCPU1sLyyi66gP+WHLDgCaN6zF5Ju7U6t61YA7E5FUo2BPA89O+4bh7yzw6v/+qQsnZTYOsCMRSWUK9hS29zNHf92pBQ8NKOUphEm+va+IpBcFewpyzvGn/8zmgy/WeGMzB2fTrEGtfTdO8u19RST9+HLbXjPra2aLzWyJmd3mxz4rq+lL1nPE7eO8UB953vEsG3V26aEOSbu9r4ikr4TP2M2sKvAE0AtYCcwys7edcwsO/ErZ07YdRXQeOdG7v0tG4zpMvOkMalQr4+9en2/vKyLpz4+pmM7AEufcUgAzexU4F1Cwl1PO1K8ZOW6RV795zal0yjiofC/28fa+IhIOfgT74cCeybISONmH/Ybe8h+20u2B3YujF57UklG/bh/bTny4va+IhIsfwV7a59fdPhuZDQQGAmRkVO6HPDjnuPKFPD5ctNYb+3RINofU3888+oEkeHtfEQkfP4J9JdByj7oF8N3eGznncoAcgKysrH2Cv7L46Mt1XPbsp159/2/aMyCr5QFeUQ5x3N5XRMLLj2CfBbQ2syOAVcCFwO982G+obPl5JyeOmMD2wmIAjmxal/dv6Eb1qnqeuIj4K+Fgd87tNLNBwAdAVeBZ59wXCXcWIk9MXsIDHyz26rcHdaV9i0YBdiQiYebLB5Scc+OAcX7sK0y+Wb+FHg9O8epLTvkF9/yqXXANiUiloE+eJkFxseOy5z7l46/We2Oz7+jJwfVqBtiViFQWCnafTVq4hitfyPPq0b/twHkddU25iFQcBbtPNm0v5IThEygqjlzw07Z5A8YO6ko1LY6KSAVTsPtg9IQv+fukr7z63etO47jDGgbYkYhUZgr2BCxZu5meD3/k1Vd0PYKh/Y8NsCMREQV7XIqLHRc99Qkzv/nRG/vszl4cVLdGgF2JiEQo2GP0wRffc/W/Z3v1Yxd1pH+HwwLsSESkJAV7ORVsK6TD3eO9ukOLhrx5TVeqVintVjkiIsFRsJfDAx8s4onJX3v1+zeczjGHNgiwIxGR/VOwH8CXazbRe/RUr776jFbcflbbADsSESmbgr0URcWOC/45gznLf/LGPh/am4Z1qgfYlYhI+SjY9zJu3mqueWmOV//z953o2655gB2JiMRGwR7109YdnDB8gldn/eIgXru6ixZHRSTtKNiBkeMWkjN1qVdPuLEbrZvVD7AjEZH4VepgX/DdRvo9+rFXD+pxFDf3OTrAjkREEpc+wZ6f69tzPXcWFXPekzOYt6pg9+6H9aZBLS2Oikj6S49gz8+FsddB4bZIXbAiUkPM4f7W3FVc/+pcr37q0ix6HdvMr05FRAKXHsE+afjuUN+lcFtkvJzB/uOWHXS6Z/fiaJdWB/PSH0+mihZHRSRk0iPYC1bGNr6XYW9/wfMzlnn1pL+ewZFN6/nQmIhI6kmPYG/YIjL9Utr4AcxfVcA5j03z6ht7tuH6nq397k5EJKWkR7BnDy05xw5QvXZkvBQ7i4o557FpLPp+EwA1qlZhztBe1KuZHm9XRCQRCSWdmQ0DrgLWRYcGO+fGJdrUPnbNo5fjqpg356zkptzPvfq5P5xEj2MO8b0lEZFU5ccp7Gjn3IM+7OfA2g
844ELp+s0/kzViolef0aYpz19+EmZaHBWRyiUUcxNDxszjpZnLvXrKzd3JbFI3wI5ERILjR7APMrNLgTzgr865DaVtZGYDgYEAGRkZPhwWPl/xE+c+Md2rb+lzNNf2OMqXfYuIpCtzzh14A7OJwKGl/GgI8AmwHnDAPUBz59wVZR00KyvL5eXlxd5t1I6dxfR9ZCpL128BoG6Nqnw6pCd1tTgqIiFmZrOdc1llbVdmEjrnepbzgE8B75Rn20TkzlrB397I9+oXr+hMtzZNk31YEZG0kehVMc2dc6uj5XnA/MRb2r/cvN2h3rNtM5669EQtjoqI7CXRuYv7zewEIlMxy4CrE+7oANo0q88JLRvx2EUdadm4TjIPJSKStsqcY0+GROfYRUQqo/LOsVepiGZERKTiKNhFREJGwS4iEjIKdhGRkFGwi4iEjIJdRCRkFOwiIiGjYBcRCZlAPqBkZuuAb0v5URMiNxULg7C8l7C8D9B7SUVheR9QMe/lF865Mm+OFUiw74+Z5ZXnU1XpICzvJSzvA/ReUlFY3gek1nvRVIyISMgo2EVEQibVgj0n6AZ8FJb3Epb3AXovqSgs7wNS6L2k1By7iIgkLtXO2EVEJEEpEexm1tLMJpvZQjP7wsyuD7qneJhZLTP71Mw+j76Pu4PuKVFmVtXMPjOzpD/2MJnMbJmZzTOzuWaWtg8DMLNGZva6mS2K/r50CbqneJjZ0dE/i13/bTSzG4LuKx5mdmP0932+mb1iZrUC7ykVpmLMrDmRB2HPMbP6wGzgV865BQG3FhOLPKevrnNus5lVB6YB1zvnPgm4tbiZ2U1AFtDAOXdO0P3Ey8yWAVnOubS+ZtrMXgA+ds49bWY1gDrOuZ+C7isRZlYVWAWc7Jwr7fMtKcvMDifye36sc26bmeUC45xzzwfZV0qcsTvnVjvn5kS/3wQsBA4PtqvYuYjN0bJ69L/g/+aMk5m1AM4Gng66FwEzawB0A54BcM7tSPdQj8oGvk63UN9DNaC2mVUD6gDfBdxPagT7nswsE+gIzAy2k/hEpy7mAmuBCc65tHwfUY8AfwOKg27EBw4Yb2azzWxg0M3EqRWwDnguOj32tJnVDbopH1wIvBJ0E/Fwzq0CHgSWA6uBAufc+GC7SrFgN7N6wBvADc65jUH3Ew/nXJFz7gSgBdDZzNoF3VM8zOwcYK1zbnbQvfikq3OuE3AWcK2ZdQu6oThUAzoB/3DOdQS2ALcF21JiotNJvwT+G3Qv8TCzg4BzgSOAw4C6Zvb7YLtKoWCPzkm/AbzknHsz6H4SFf0n8hSgb8CtxKsr8Mvo3PSrwJlm9p9gW4qfc+676Ne1wBigc7AdxWUlsHKPfwW+TiTo09lZwBzn3JqgG4lTT+Ab59w651wh8CZwasA9pUawRxcdnwEWOuceDrqfeJlZUzNrFP2+NpE/9EXBdhUf59ztzrkWzrlMIv9U/tA5F/iZSDzMrG50UZ7o1EVvYH6wXcXOOfc9sMLMjo4OZQNpdYFBKS4iTadhopYDp5hZnWiOZRNZIwxUtaAbiOoKXALMi85PAwx2zo0LsKd4NAdeiK7yVwFynXNpfZlgSDQDxkR+76gGvOycez/YluL2F+Cl6BTGUuDygPuJm5nVAXoBVwfdS7ycczPN7HVgDrAT+IwU+ARqSlzuKCIi/kmJqRgREfGPgl1EJGQU7CIiIaNgFxEJGQW7iEjIKNhFREJGwS4iEjIKdhGRkPn/erOFYOwD2HsAAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ],
   "source": [
    "%matplotlib inline\n",
    "from matplotlib import pyplot as plt\n",
    "\n",
    "# Predictions of the fitted model on the rescaled inputs; *params unpacks into (w, b).\n",
    "t_p = model(t_un, *params)\n",
    "\n",
    "# Both series share the same x values, so compute them once.\n",
    "x_scaled = 0.1 * t_u.numpy()\n",
    "plt.plot(x_scaled, t_p.detach().numpy())\n",
    "plt.plot(x_scaled, t_c.numpy(), 'o')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}