{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## Linear regression with PyTorch\n", "\n", "Generate noisy data from y = 10x + 5, fit it with scikit-learn's closed-form `LinearRegression`, then recover the same line with gradient descent in PyTorch (full-batch SGD, a hand-written update, mini-batch SGD and AdamW). The notebook ends with a short tour of basic tensor operations." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# %pip installs into the environment of the running kernel (!pip may target a different interpreter);\n", "# -q keeps the install log, which lists local site-packages paths, out of the saved notebook.\n", "# Last run with scikit-learn 1.3.2, matplotlib 3.7.3, torch 2.0.1.\n", "%pip install -q scikit-learn matplotlib torch" ] },
{ "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import torch\n", "\n", "# Seed the global generator before any random draw so epsilon below (and the torch.randn cells later) are reproducible.\n", "# The trailing ';' suppresses the echo of the returned torch.Generator object.\n", "torch.manual_seed(1024);" ] },
{ "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "x = torch.linspace(100, 300, 200)\n", "# Standardize x to zero mean and unit standard deviation.\n", "x = (x - torch.mean(x)) / torch.std(x)\n", "# Ground truth: slope 10, intercept 5, plus standard-normal noise.\n", "epsilon = torch.randn(x.shape)\n", "y = 10 * x + 5 + epsilon" ] },
{ "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAXkAAAD4CAYAAAAJmJb0AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/OQEPoAAAACXBIWXMAAAsTAAALEwEAmpwYAAAjCklEQVR4nO3df5Bc1XUn8O+Z1gO18BYtgoKlNmNBipUWLCNFKkKirMuSvcgxMQiwg73ZBFe8UVIbVy0qampHMRUEiYtZKxVSu0kq0SauxRWCByMxloGssC0l1FIRycgz8iAjrbFBgobAOEwrxtNAz8zZP/q91ps3974f3e/1z++nSqWZ/vHeZVCdvnPuueeKqoKIiHrTQLsHQERE2WGQJyLqYQzyREQ9jEGeiKiHMcgTEfWwZe0egN+ll16qa9eubfcwiIi6yvHjx3+kqqtMz3VUkF+7di3Gx8fbPQwioq4iImdszzFdQ0TUwxjkiYh6GIM8EVEPY5AnIuphDPJERD2so6priIj6xdhECfsOn8ar5QrWFPIY2rEOOzcVU78PgzwRUYuNTZSw5+AUKtV5AECpXMGeg1MAkHqgZ7qGiKjF9h0+XQ/wnkp1HvsOn079XgzyREQt9mq5kujxZjDIExG12JpCPtHjzWCQJyJqsaEd65B3coseyzs5DO1Yl/q9uPBKRJQCW7VMWBXNvsOnUSpXkBNZlJNPc/GVQZ6IqEm2apnxM2/iwPFSaBVN1lU2DPJERE269xsnjdUyDz/7MuZVlzx+1yMnAIRX2TDIExF1gLGJEmZmq8bnggHe/7h/Bh+UZpUNgzwRkUHcHalhte05EWugr1Tnrc+nWWXD6hoiogAvx14qV6Co5cp3j07i7rGpJa8Nm3V/5ucuX1JF4zevmnmVDYM8EVGAKVeuAB46dhZjE6VFj9tm3YW8gz/YuQH337oBORHja4qFPO6/dQOKhTzE9z2ra4iIMjI2UULJMjtXYMmi6NCOdUvy6wKgXKli471PQaQ2Yxf3/R5vxr5zUzGTxmQezuSJiFxemiZMqVzB1pEj9Rn9zk3F+mwcwKJgXq5U64uywcz7cqc14ZdBnojIZUrTmJTKFdw5OolN9z2FsYkSdm4q4pnh7SgW8kuCuc3MbBV7Dk4tSf+kjUGeiMiVtHQxGKiTvj+rzpN+TQd5EblcRI6KyPdE5KSI/Ff38UtE5Jsi8n3375XND5eIKDuNlC76A3Uj78+i86RfGjP5OQB3qerVAK4H8DsicjWAYQDfVtWrAHzb/Z6IqGOZGoeZ62IW8wL10I51sV7vNyCSacqm6SCvqq+p6nfcr38M4HkARQA3A3jQfdmDAHY2ey8ioiz5F1G9ksY4OXYFsHXkCADgV68fNAb6Fc4AnNzSZ7zdr1kFelHLbqyGLiayFsDTAD4A4KyqFtzHBcCM933gPbsA7AKAwcHBzWfOnEltPEREzdo6csRaUhmUd3K4/9YNAGDtSHnXIyeMu1yLhTyeGd7e0BhF5LiqbjE+l1aQF5H3APh7AF9U1YMiUvYHdRGZUdXQvPyWLVt0fHw8lfEQEaUh2GESwJKad7+oYH3F8BPG9wqAF0dubGiMYUE+leoaEXEAHADwkKoedB9+XURWu8+vBvBGGvciIkrb2EQJW0eO4IrhJxbVwAPmFM4Dt2+0XitqIbWVp0IBKex4dVMxfwXgeVX9I99ThwDcAWDE/fvrzd6LiCgN/uZjF+cd/OTdOVTna/NrU0/34K7UsYmSdTYfFaxNO2SzOhUKSGcmvxXArwHYLiKT7p+Poxbc/4OIfB/AR93viYjaKth8rFyp1gO8J6p+fd/h09aUS1SwNv1mkHa/Gr+mZ/Kq+n9hrzL6SLPXJyJKU9xdrWFpF9tzingnOmXdr8aPDcqIqGcF0zLvzs1jtroQ671haZc1hbyx4qaYUV69GWxrQERdzbZoakrLxA3wUTly06apLPPqzeBMnoi6lu0AbSB+WgYAnAHBe5YvQ3m2Gno
KlMd7Ls7JUe2W6maoZrFOnoiSsG1UKhbyeNWdwccxIMCC1g76EEHsYN8pwurkOZMnoq5lWwCNu0PVs+B+GpQr5w/kNpVSdiPm5Imoa2W1gcjTilbAWWOQJ6KutW39qthdHwt5BysaOI0p61bAWWO6hoi60thECQeOl2Ll3QXA5D031L9P0nQs698WssaZPBF1Ha+bY9zqmQGRRSWWphJIk04ti0yCQZ6IuopXNmlq12szrwrF4sVU/+HbOaklfQp5BytXOC1pN9AqTNcQUUfx71I1lTFG1b8X8g4uunAZXi1XMCCy5MPAW0x9Znh71wfwODiTJ6KOEdylWipXcOfoJDbd91Ssw7LzTg57b7oGQzvWYU0hb53td/tiahKcyRNRx7DN0mdmq/U0i61vjABY7gzgztHJ0EM9vGv0C87kiahjhM2wvTSLadHUGRAsywlmZmubmcICfC8spibBmTwRtYUp926bpXtK5Qp2j07i4ryD5c5Avf3A7Ltz9QAfpthFrQrSwpk8EbWcKfe+5+AUtq1fFVna6HWUnJmtorDCwdCOdSjHDPD9stjqxyBPRC1nyr1XqvM4emoa99+6AYW8E+s6Xq6+sCL89f2WovFjkCeilrPl3l8tV7BzUxGT99yAlRGB21OpzkMVS34D8Nod9Eq9e6NSycmLyJcB/DKAN1T1A+5jewH8JoBp92W/q6pPpnE/Iuo8UfXtfrbcu7/qJU4KxnOuUsUDt2/siv7urZbWwuv/BvAnAL4SePwBVf3DlO5BRB0q7PAOU6Ad2rFu0euB2sy7VK5g68iRWIuwfmsK+Zaem9pNUknXqOrTAN5M41pE1H1sOfa7Hjmx5Fg+oBb4/W0F/HXtSRZhgf7Ot8eR2slQIrIWwOOBdM1nAfwrgHEAd6nqjOF9uwDsAoDBwcHNZ86cSWU8RJSN4OHYIohVvph3csbceNjpTkM71uGuR05Yd672Y0mkSdjJUFkG+csA/Ai1D+jfB7BaVX8j7Bo8/o+oswXTMknlRLCguihnfsXwE8bNSwLgxZEbI5+nNh3/p6qv+wbwvwA8ntW9iKg1khyObeLNyP05e1vufUAEYxOlWIu0ZJdZCaWIrPZ9ewuA57K6FxG1RpqNvcLaFAC1DwRbbp55+PhSCfIi8jCAfwCwTkReEZHPAfiSiEyJyHcBbAOwO417EVH7xJk9e73Z4/Dq4u+/dYPxff4NUsVCvqf6vLdKKukaVf2M4eG/SuPaRNQ5TKWPfnknlyid431o7NxUxO7RSeNrvA8CBvXGcMcrEcW2c1MRt20uWg/PFigGYk7kgykX228JzL03h0GeiDA2UcLWkSPGmvago6emra18Z6sLWIhZsBdMuZhy88y9N4+thon6XNLdqmksvhbdHap+3vdsTZAuBnmiPmfbrbrv8GljgE3SbsAk2L7Afw/m3tPHIE/U58I6QpqajkUtvgZ5B2uXyhVj+wLA/BsDpYM5eaI+Z1vYvDjvGA/2ALCk74yNd7D2M8PbUSzkl+Tyvd8YKDuptTVIA9saELWeqVVB3slhuTNg7EnjnbDkf3+wl413LJ8/HWNrT+Bdk3n4xrWlrQERdaZgCmbb+lW4cNlAPcivXOHgnk9cgzstdeslN43jBeK4eXRbLt/L0XvXZgonXUzXEPUR09mqf33sLMqV8zP2t6sLGD/zZmgaZvfoJO4em0p0b1OJpD9H72EKJ10M8kR9JE6DsUp1Hg8/+7I1tQLUAvNDx86G1tMH+XvIe+0JbPdIs0dOv2O6hqiPxA2etv7tfgpYyyxtgqkdWy957nJND2fyRH0k7eDZ6Izb22HrlVX6cZdrujiTJ+picQ/P9l7XzCYmk0Y+NILVPIrzuXme9JQ+BnmiLhW3HUGzpznZNDrjNq0LeAHeX5pJ6WC6hqhLhbUjiHpdUgLgj2/fmEpP97AdtpQ+zuSJulRUsEwzRbPGbSiWRhqFx/m1FmfyRF0qrP/63WNT2D06mTjA552BzBdC2VK
4tRjkibqULVhuW78KDx07G1rnbpJ3crj/1g/igZTSMjamenke55edVHrXiMiXAfwygDdU9QPuY5cAGAWwFsBLAH5FVWfCrsPeNUTJmPrGmPrNmBRC+sxQdwnrXZNWkP8QgLcAfMUX5L8E4E1VHRGRYQArVfW/hV2HQZ4ofllk8D1JKmhYydJbwoJ8KukaVX0awJuBh28G8KD79YMAdqZxL6JeZuots+fgVGT7gCQVNMx/95csc/KXqepr7tf/DOAy04tEZJeIjIvI+PT0dIbDIep8ccsig+KWHzL/3X9aUkKpqioixryQqu4HsB+opWtaMR6iThW3hjyY0imscKy5eO4i7W9ZBvnXRWS1qr4mIqsBvJHhvYh6Qpwa8rvHphZVz5TKFTgDAicnqM6fnyfVqmU4a+93WaZrDgG4w/36DgBfz/BeRD3B1nPdO/g6GOA91QXFRRcsY1kiLZHKTF5EHgbwYQCXisgrAO4BMALgERH5HIAzAH4ljXsR9TIvKPt3qvpn7H997Kz1vecqVUzec0PWQ6Quk0qQV9XPWJ76SBrXJ+pGUaWQtue91+wenUy0oUlR68/O/Dv5sXcNUQaiOkRGPb/v8OnEO1ZN1yFiWwOiDESVQkY930xHRp6RSn4M8kQZiCqFtD3v5eGb7cjItr3kYZAnykBhhWN83AvetiAuqKV6TFU2fnknV+/vbuLl55MctE29iUGeKGVjEyW89fbcksednNTbCQztWLekpS+w+HBsf6fGQt7ByhXOkvLIsA+DuC0RqLdx4ZUoZfsOn0Z1Yemy6UUXLKsvuoYtrHqpljiHdJhKLv28/DwXYfsXZ/JEKbPlw89VqosakNkkzcfv3FTEM8Pbjb8ZhI2H+gODPFHKwk5sitMtslSuYNN9TyVOs4Tdl/oXgzxRykx5cmdAMPvuXOzj+GZmqxh69ESiQM9j9ciEQZ4oZaZFUyQ4sclTnddE9e48Vo9MuPBKlAH/ounWkSMoV5IFeE/SfHqcxVrqLwzyRA0wna1qOys1LFCvdOvpbbN8fz69kWMBiZiuIUooeERfuVLFzGzVelxf2MLn29UF3PjB1XAGltbG+OvqGz0WkIhBniihqAqZYO+YsA1Lleo8jp6axr5PXVvL3btWrnCw75PXLqqDb+RYQCKma4giBNMkcSpkSuUKxiZK9SB94bIB6wfDq+VKZC497rGAREEM8kSw57tNLYEFiNUGeM/BKYyfeRMHjpdCZ/5x6tjjHAtIZMIgT30hbNEyrLe7KU0St897pTqPh599GfNqf0fcOvahHesWjTHJe6m/MchTz4tzQIct3x2WDvFm9GEz+7AAX0xQIePPzbO6hpLIPMiLyEsAfgxgHsCcqm7J+p5EfmFBfOemYmi+OywHHxXgASAnYgz0xUIezwxvj/lfUMMaeGpEq6prtqnqRgZ4aoeoRcuwni9Rfd3DArwAuP7KlWw1QG3FEkrqeVGNu8J6vnitAnJi6/FopwC+c/YcbttcZKsBaptW5OQVwFMiogD+QlX3+58UkV0AdgHA4OBgC4ZD/SZq0TIq3+39HbxGnCobrw4+aWqGKC2tCPK/qKolEflpAN8UkVOq+rT3pBv09wPAli1bGjmgniiULYgDtb4ycRYyTdfYtn5VZHkkwFp2aq/Mg7yqlty/3xCRxwBcB+Dp8HcRLdVM75bgomVUxU2cawDAlvdfUh/TgGWRlbXs1E6ZBnkRuQjAgKr+2P36BgD3ZXlP6k1hQRlIXlpoq7i565ETuHN0sl4VYypzjLtxCuAiK7Vf1jP5ywA8JrVFq2UA/kZV/0/G96QeZAvKew+dxDtzC4lm5IA9heLNxL2/g9eL8xsAa9mpk2Qa5FX1hwCuzfIe1B9sQdnUpz3O4dVxe9AErxdVc89aduo0LKGkrpA0rx222Dk2UcJP3plLdL1SuYKtI0esHwxcXKVOxSBPXcFWy+4duhFk+1C4e2wKu0cnGzqpyWtOluR+RO3G3jXUFcLKIOM
udo5NlPDQsbOxG4yZmFoZcHGVOhmDPHUNUxmklyOPUw0TNwcfRVHbucrFVeoGDPLU8Uwli8DiGfy86qJWBP73Bmf6YfJOLvK1jTQXI2oX5uSpo5nONt09Ook7RydjHYcXdVSfn9dXphiSX2dqhroNZ/LU0ZIe2hGscomqehEAv3r9IP5g54ZFj5tm/ytXOLjnE9cwNUNdhUGeOlrS0kSvysVL8YR9INgO7eCmJuolDPLU0ZJsWvJSKVF5+LyTi2z3y01N1CuYk6eOFnVoh0cA3La5aN2V6lm5wmE/d+ornMlTW0V1lvSnTrzNSKYUjAI4emoaYxOl0Jn/iguWMcBTXxENOWi41bZs2aLj4+PtHga1iK1rY9hMe2yihDtHJ63XjCqBFAAvjtzY6JCJOpKIHLcdr8qZPGUmapYe1lky7JQm28amnEhkuSTbD1C/YU6eMmGqb99zcApjE6X6a8I6S4a9z9bHxnRgR/A1rHGnfsMgT5kIa8nriTurDr7PO1zbfzj2bZuL1uZhAA/Qpv7FdA1lwjZLL5Ur2HjvUxABZmarsQ7DNl0vWOK4deSI8ToC4IHbNzK4U9/iTJ4yETZLL1eqmJmttfr1ujpGUdQCuT9t42f7UFGEnxBF1OsY5CkTcevbATfQx4j0pvy8x/ahEtaHhqgfZB7kReRjInJaRF4QkeGs70edwZ83jyNuJa+pCRlgX4zlQiv1u0yDvIjkAPwpgF8CcDWAz4jI1VnekzrHzk1FPDO8PfXZtCk1Y1qM5UIrUfYLr9cBeME90Bsi8lUANwP4Xsb3pYxF1cD7De1Yl6inexRbaob9ZoiWyjrIFwG87Pv+FQA/53+BiOwCsAsABgcHMx4OpSG4U7VUrmDoaydw7zdOojxbxcV5ByJAebZa/wC4/9YNoTtV42IKhiiZtpdQqup+APuBWluDNg+HYjDVwFcXtF4x4z8k2/sAeM/yZP/UnAHBe5Yvw8xsNfRoPyIKl3WQLwG43Pf9+9zHqIsl7fHu/wCIg8GcKD1ZB/l/AnCViFyBWnD/NID/mPE9KSNxDuJIIrgRKk6fdyJKJtMgr6pzIvJ5AIcB5AB8WVVPZnlPaszYRAl7D52sp1qCR90lPRA7ircTlacvEWUr85y8qj4J4Mms70ONG5soYehrJ1BdOD+vnpmtYujREwAQeRBHI9YU8qyGIWoB7ngl7Dt8elGA91Tntb7xKGkePgwrZIhah0GeQgO495ytNr2Qd0LbF+SdHP7T9YPcpETUJm0voaTWC25kKqxwrNUvXnA3bWjKOznsvekaAKhfz1Qjz4BO1D4M8n3GtJHJGRAMCBDM2Dg5qadV/Get2k5sIqLOwzNe+8zWkSPGo/MKeQfA+Y1MK5wBXOjkQmfkSVobEFF2eMYr1dny7+cq1foB195s30vheC1+Pd4Zq/46d/9rGOiJOgeDfJ9ZU8gbZ/L+hdWwA7bfmVuoPxf8HdBrA8wgT9Q5WF3TZ+L0XQ87YDuqVj7NUksiah6DfJ+J03c97gHbJs28l4jSxyDfh7zDPB64fSMAYPfo5KLzU5Mc3efHTU5EnYc5+T5lKqUMLpx6C6xhvMVXdo4k6kwM8n3KtrjqLZx6wXr36KS16yQDO1HnY5DvU7YFUv/jtrbCXgdJBneizsecfJ+yLZCuKeQxNlGybpoCaukZBnii7sCZfA/y70T195IJfu3kBNX583P1vJPDtvWrIvvGF1lBQ9Q1OJPvMd6CaqlcgaJW2z4zWzV+Da0dDuIvpTx6ajo0wLOChqi7cCbfY5Ic7lFdUKy4YBkmfu+G+mO7Ryetr+dCK1H3YZDvEV6KJqrkMSi4AGtre1As5PHM8PamxkhErZdZkBeRvQB+E8C0+9DvukcBUsruHpvCQ8fONnTA9sV5B1tHjtQ7SW5bvwoHjpeW9I1nioaoO2U9k39AVf8w43v0vLCWvmMTpYYDPAD85N25envhUrmCA8dLuG1zEUdPTbOFMFEPYLqmw0XtTLX
VsschgkXVNUBtQ9TRU9NMzRD1iKyD/OdF5NcBjAO4S1VnMr5fR2vkkI2wtr+N5OA9eSdnXaBlJ0mi3tHUyVAi8i0A7zU89QUAxwD8CLW9M78PYLWq/obhGrsA7AKAwcHBzWfOnGl4PJ0sOCMHaoH2ts1FPH7itXrKZOUKB/d84vy5qY0G8TBelYzt+lxkJeouYSdDteT4PxFZC+BxVf1A2Ot6+fi/sB2kQQMC5AZkSSqlWXknt6itsO2DJ9h6mIg6W1uO/xOR1ar6mvvtLQCey+pe3SBJCmRBgYWUA7ypxj3qcG4i6n5Z5uS/JCIbUUvXvATgtzK8V8ez1Z8nVSzkMfvuXP38Vb+cCOYNv5mFpV/8HSeJqPdk1tZAVX9NVTeo6gdV9SbfrL4vNXoQh583Gzdl2ATA9VeujDzaj4j6C3vXtIj/2L0oAwI4OVn0mL95mLdI66cAvnP2HG7bXAw92o+I+gvr5DNmKpsMq5oJVtcE3xfWl4Y17kQU1JLqmrh6rbomrGzS1DrANOv2f0jE+T8lAF4cuTGl/wIi6gZh1TVM12TItpHp6KnpeuomLK0SbBsch+0wECLqT0zXNCjO7tWwI/biVLUkaRsMcJGViJZikG9AVD8Zj61sMtj50VabHqe2XgSAgjXuRGTEIJ/A2EQJew+dNFa3VKrzuHN0EvsOn64H223rVxk7RJYr1UWdH00fEEDM2nplDp6I7BjkYxqbKGHoaydQXQjPjntBe/zMmzhwvBQrl16pzmPf4dMAFlfUmHq7BzEHT0RhuPAa077DpyMDvKdSncfDz76cKJ/ufTh4i6z+3u5ebb0E3sMcPBFF4Uw+pqTtd03tBcLkRKyVOF7deyOtiomovzHIx5S094wIjO0HTOL2dmefGSJKiumamIZ2rIMzEEyY1FoQGB6GYGlrguDzwPkaeVu7A+bciagZnMkH2FIi3gzaX13jtSC49xsnl3SF9NoFe50hC3kHIkB5tmpNtZh2xzLnTkTNYJD3iap/t6VLdo9OWq85r4q8k8Pem64JTbWwtzsRZYHpGh9bGwKvvNEmKqUS5xpALdA/M7wdD9y+EUDtw2PryBGMTZQi30tEZNKTM/lGq1BsFTSlcgWb7nvKmmoZ2rFuSaol7rVNY4+zm5aIKI6eC/JxgmRw56qXWw+roPFy7v7rAefTKxfnHSx3BownNgHxF1DDfptgkCeipHouXROVcvF2rvpbE8zMVjH06AlsW78q1ulNleo89h46uWjzUrlSxVtvzxkrbZycxF5ADWtqRkSUVM8F+aggadu5Wp3XegvgOMqV6pIPk+qCwrQp9qILlsWehdtm/CylJKJGNBXkReRTInJSRBZEZEvguT0i8oKInBaRHc0NMz5bMFQAP7PnydANTaVyBfsOn8bKFU6qYzpnaGhmYzoLlqWURNSoZmfyzwG4FcDT/gdF5GoAnwZwDYCPAfgzEWnuFOuYwg7MjtNqoFSu4K2350Jfk3dyiT4IkszC/WfB8pxWImpWUwuvqvo8AIgsSUTfDOCrqvoOgBdF5AUA1wH4h2buF4dp01JSUY3IvJROsKLGGRBAaqkfTyOzcLYvIKK0ZJWTLwJ42ff9K+5jS4jILhEZF5Hx6enpVG6+c1MRF12YTeFQsZCvB+HgjHvfp67Fvk9ey1k4EXWMyEgoIt8C8F7DU19Q1a83OwBV3Q9gP1A7yLvR6wRr4+M0E/P6xZheW8g7eGduIbTNgG3GzaBORJ0iMsir6kcbuG4JwOW+79/nPpYJU228eyqelT9gm3rG7L3pGgBsM0BE3S2rzVCHAPyNiPwRgDUArgLwjxndy1gbr4A10Bd9AXtsooQLlw3U3+9tjPKCOYM6EXWzpoK8iNwC4H8CWAXgCRGZVNUdqnpSRB4B8D0AcwB+R1XjH5OUkK02XlEL6LaZePA3AKDWJTJ4VisRUbdqtrrmMQCPWZ77IoAvNnP9uGw5+GIhXz9VycT2GwDAnjFE1Bt6YserrTZ
+9t250A6OUa0C4naPJCLqVD0R5L1yxkJ+8Qalmdkq9hycsgb6OJuU2DOGiLpZTwR5wF4bH2xOtnXkCK4YfgJbR47EakjGnjFE1M16JsgD4c3JvEVWr2tkqVzBgeMl3La5WK+XD+7bZc8YIup2PdVP3rYAu6aQt7YgPnpqur442+hhI0REnaqngrzphCZvNm47h9U/+2fPGCLqNT2Vrgnr4Mg+7UTUj3pqJg/YZ+Nhs3wiol7Vc0Hexgv8zLkTUT/piSAfd8GUOXci6jddH+RNHSjZjoCIqKbrF15tpZFsR0BE1ANBPmwDFBFRv+v6IM/SSCIiu64P8qYOlCyNJCKq6fqFV5ZGEhHZdX2QB1gaSURk0/XpGiIismsqyIvIp0TkpIgsiMgW3+NrRaQiIpPunz9vfqhERJRUs+ma5wDcCuAvDM/9QFU3Nnl9IiJqQrMHeT8PACLB4zaIiKgTZJmTv0JEJkTk70Xk39teJCK7RGRcRManp6czHA4RUf+JnMmLyLcAvNfw1BdU9euWt70GYFBV/0VENgMYE5FrVPVfgy9U1f0A9rv3mhaRM/GHn6lLAfyo3YNIgOPNXreNmePNXqeM+f22JyKDvKp+NOndVPUdAO+4Xx8XkR8A+LcAxiPetyrpvbIiIuOquiX6lZ2B481et42Z481eN4w5k3SNiKwSkZz79ZUArgLwwyzuRUREds2WUN4iIq8A+HkAT4jIYfepDwH4rohMAngUwG+r6ptNjZSIiBJrtrrmMQCPGR4/AOBAM9fuAPvbPYCEON7sdduYOd7sdfyYRVXbPQYiIsoI2xoQEfUwBnkioh7GIO+y9eExvO4lEZlye/KEloRmKcF4PyYip0XkBREZbuUYA+O4RES+KSLfd/9eaXndvK/n0aE2jDP05yUiF4rIqPv8syKyttVjNIwpasyfdfegeD/X/9yOcbpj+bKIvCEiz1meFxH5H+5/y3dF5GdbPUbDmKLG/GEROef7+f5eq8cYSlX5p7Yu8e8ArAPwdwC2hLzuJQCXdsN4AeQA/ADAlQAuAHACwNVtGu+XAAy7Xw8D+O+W173Vxp9p5M8LwH8B8Ofu158GMNrmfwdxxvxZAH/SznH6xvIhAD8L4DnL8x8H8LcABMD1AJ7tgjF/GMDj7R6n7Q9n8i5VfV5Vu+b075jjvQ7AC6r6Q1V9F8BXAdyc/eiMbgbwoPv1gwB2tmkcYeL8vPz/HY8C+Ii0t3lTJ/0/jqSqTwMIK6e+GcBXtOYYgIKIrG7N6MxijLmjMcgnpwCeEpHjIrKr3YOJUATwsu/7V9zH2uEyVX3N/fqfAVxmed1yt5fRMRHZ2Zqh1cX5edVfo6pzAM4B+KmWjM4s7v/j29z0x6MicnlrhtaQTvo3m8TPi8gJEflbEbmm3YPx64mToeJqsA9P0C+qaklEfhrAN0XklPtJn7qUxtsyYeP1f6OqKiK22t33uz/fKwEcEZEpVf1B2mPtM98A8LCqviMiv4XabyLb2zymXvId1P7dviUiHwcwhtou/47QV0FeG+jDY7hGyf37DRF5DLVflzMJ8imMtwTAP2t7n/tYJsLGKyKvi8hqVX3N/fX7Dcs1vJ/vD0Xk7wBsQi3n3Apxfl7ea14RkWUALgbwL60ZnlHkmFXVP76/RG19pFO19N9sGtTXeFFVnxSRPxORS1W1ExqXMV2ThIhcJCL/xvsawA2oHZzSqf4JwFUicoWIXIDaQmHLK1ZchwDc4X59B4Alv4mIyEoRudD9+lIAWwF8r2UjjPfz8v93fBLAEXVX39okcsyBnPZNAJ5v4fiSOgTg190qm+sBnPOl+TqSiLzXW5cRketQi6vt/OBfrN0rv53yB8AtqOX/3gHwOoDD7uNrADzpfn0latULJwCcRC1t0rHjdb//OID/h9psuJ3j/SkA3wbwfQDfAnCJ+/gWAH/pfv0LAKbcn+8UgM+1YZxLfl4A7gNwk/v1cgBfA/ACgH8EcGU7/93GHPP97r/XEwC
OAljfxrE+jFor8qr77/dzAH4btf5WQK2q5k/d/5YphFS6ddCYP+/7+R4D8AvtHrP/D9saEBH1MKZriIh6GIM8EVEPY5AnIuphDPJERD2MQZ6IqIcxyBMR9TAGeSKiHvb/AeEdc3jCYiI9AAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "import matplotlib.pyplot as plt\n", "\n", "plt.scatter(x, y)" ] },
{ "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "from sklearn import linear_model" ] },
{ "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "# Closed-form least-squares baseline to compare the gradient-descent fits against.\n", "sk_model = linear_model.LinearRegression()" ] },
{ "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(array([9.934817], dtype=float32), 5.093296)" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# LinearRegression expects a 2-D (n_samples, n_features) input, hence view(-1, 1).\n", "sk_model.fit(x.view(-1, 1), y)\n", "sk_model.coef_, sk_model.intercept_" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "#### 梯度下降法 (gradient descent)" ] },
{ "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "import torch.nn as nn\n", "\n", "\n", "class Linear(nn.Module):\n", "    \"\"\"Univariate linear model y = a * x + b with two scalar parameters.\"\"\"\n", "\n", "    def __init__(self):\n", "        super().__init__()\n", "        # Model parameters: 0-d tensors, both initialised to zero.\n", "        self.a = nn.Parameter(torch.zeros(()))\n", "        self.b = nn.Parameter(torch.zeros(()))\n", "\n", "    def forward(self, x):\n", "        # Forward pass, applied elementwise to x.\n", "        return self.a * x + self.b\n", "\n", "    def string(self):\n", "        \"\"\"Return the current fit formatted as 'y = a * x + b' with two decimals.\"\"\"\n", "        return f'y = {self.a.item():.2f} * x + {self.b.item():.2f}'" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "model = Linear()\n", "# Only show the first few predictions: with a = b = 0 every prediction is 0.\n", "model(x)[:5]" ] },
{ "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[Parameter containing:\n", " tensor(0., requires_grad=True),\n", " Parameter containing:\n", " tensor(0., requires_grad=True)]" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "list(model.parameters())" ] },
{ "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "y = 1.98 * x + 1.02\n", "y = 3.56 * x + 1.83\n", "y = 4.83 * x + 2.49\n", "y = 5.85 * x + 3.01\n", "y = 6.66 * x + 3.42\n", "y = 7.31 * x + 3.76\n", "y = 7.83 * x + 4.03\n", "y = 8.25 * x + 4.24\n", "y = 8.59 * x + 4.41\n", "y = 8.85 * x + 4.55\n", "y = 9.07 * x + 4.66\n", "y = 9.24 * x + 4.74\n", "y = 9.38 * x + 4.81\n", "y = 9.49 * x + 4.87\n", "y = 9.58 * x + 4.91\n", "y = 9.65 * x + 4.95\n", "y = 9.71 * x + 4.98\n", "y = 9.75 * x + 5.00\n", "y = 9.79 * x + 5.02\n", "y = 9.82 * x + 5.03\n" ] } ], "source": [ "import torch.optim as optim\n", "\n", "learning_rate = 0.1\n", "model = Linear()\n", "optimizer = optim.SGD(model.parameters(), lr=learning_rate)\n", "\n", "for t in range(20):\n", "    y_pred = model(x)\n", "    # mean squared error over the full dataset\n", "    loss = (y - y_pred).pow(2).mean()\n", "    # backward() accumulates into .grad, so clear the previous step's gradients first\n", "    optimizer.zero_grad()\n", "    # compute d(loss)/da and d(loss)/db\n", "    loss.backward()\n", "    # take one gradient step on a and b\n", "    optimizer.step()\n", "    print(model.string())" ] },
{ "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "y = 1.98 * x + 1.02\n", "y = 3.56 * x + 1.83\n", "y = 4.83 * x + 2.49\n", "y = 5.85 * x + 3.01\n", "y = 6.66 * x + 3.42\n", "y = 7.31 * x + 3.76\n", "y = 7.83 * x + 4.03\n", "y = 8.25 * x + 4.24\n", "y = 8.59 * x + 4.41\n", "y = 8.85 * x + 4.55\n", "y = 9.07 * x + 4.66\n", "y = 9.24 * x + 4.74\n", "y = 9.38 * x + 4.81\n", "y = 9.49 * x + 4.87\n", "y = 9.58 * x + 4.91\n", "y = 9.65 * x + 4.95\n", "y = 9.71 * x + 4.98\n", "y = 9.75 * x + 5.00\n", "y = 9.79 * x + 5.02\n", "y = 9.82 * x + 5.03\n" ] } ], "source": [ "learning_rate = 0.1\n", "model = Linear()\n", "\n", "# Same full-batch loop, with optimizer.zero_grad() / optimizer.step() replaced by a hand-written SGD update.\n", "for t in range(20):\n", "    y_pred = model(x)\n", "    loss = (y - y_pred).pow(2).mean()\n", "    loss.backward()\n", "    # The parameter update itself must not be recorded by autograd.\n", "    with torch.no_grad():\n", "        for param in model.parameters():\n", "            param -= learning_rate * param.grad\n", "            # Dropping the gradient plays the role of optimizer.zero_grad().\n", "            param.grad = None\n", "    print(model.string())" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "#### 随机梯度下降法 (mini-batch stochastic gradient descent)\n", "\n", "Each step fits a random mini-batch. `x` is sorted (it comes from `linspace`), so taking consecutive slices would show the model one narrow x-range at a time; indices are therefore sampled at random. A dedicated `torch.Generator` is used so these draws do not shift the global random stream that the later `torch.randn` cells depend on." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learning_rate = 0.1\n", "batch_size = 20\n", "model = Linear()\n", "optimizer = optim.SGD(model.parameters(), lr=learning_rate)\n", "batch_rng = torch.Generator().manual_seed(1024)\n", "\n", "for t in range(20):\n", "    # batch_size distinct indices drawn uniformly at random\n", "    ix = torch.randperm(len(x), generator=batch_rng)[:batch_size]\n", "    xx, yy = x[ix], y[ix]\n", "    yy_pred = model(xx)\n", "    # mean squared error over the mini-batch\n", "    loss = (yy - yy_pred).pow(2).mean()\n", "    optimizer.zero_grad()\n", "    loss.backward()\n", "    optimizer.step()\n", "    print(model.string())" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "#### 随机梯度下降 + AdamW (mini-batch training with the AdamW optimizer)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learning_rate = 1\n", "batch_size = 20\n", "model = Linear()\n", "# AdamW's decoupled weight decay (default 0.01) shrinks a and b toward 0 on every step;\n", "# turn it off so the fit is not biased away from the least-squares solution.\n", "optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=0.0)\n", "batch_rng = torch.Generator().manual_seed(1024)\n", "\n", "for t in range(20):\n", "    # batch_size distinct indices drawn uniformly at random\n", "    ix = torch.randperm(len(x), generator=batch_rng)[:batch_size]\n", "    xx, yy = x[ix], y[ix]\n", "    yy_pred = model(xx)\n", "    loss = (yy - yy_pred).pow(2).mean()\n", "    optimizer.zero_grad()\n", "    loss.backward()\n", "    optimizer.step()\n", "    print(model.string())" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "#### 张量的基本操作 (basic tensor operations)" ] },
{ "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[0., 0., 0.],\n", " [0., 0., 0.]])" ] }, "execution_count": 22, "metadata": {}, "output_type": "execute_result" } ], "source": [ "torch.zeros(2, 3)" ] },
{ "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[ 1.9050, 0.1757],\n", " [ 1.2764, 0.6187],\n", " [ 3.2715, 0.6103],\n", " [-1.1903, 1.0333]])" ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "source": [ "torch.randn(4, 2)" ] }, { 
"cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([2, 3, 4])" ] }, "execution_count": 24, "metadata": {}, "output_type": "execute_result" } ], "source": [ "a = torch.randn(2, 3, 4)\n", "a.shape" ] },
{ "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([1, 2, 3, 4])" ] }, "execution_count": 25, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# unsqueeze(0) inserts a new size-1 dimension at position 0\n", "a.unsqueeze(0).shape" ] },
{ "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[[[ 0.6836, -0.8168, 0.0590, -1.3575],\n", " [ 2.8567, 1.0398, -0.6034, 0.3212],\n", " [-0.6649, 0.0157, -1.1210, 0.5838]],\n", "\n", " [[-0.3839, -0.6906, 1.4496, -0.3944],\n", " [ 0.7254, -1.0734, 0.9207, 0.6957],\n", " [ 0.0532, 1.7621, 0.2933, 2.3150]]]])" ] }, "execution_count": 26, "metadata": {}, "output_type": "execute_result" } ], "source": [ "b = a.unsqueeze(0)\n", "b" ] },
{ "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([1, 2, 3, 4])" ] }, "execution_count": 27, "metadata": {}, "output_type": "execute_result" } ], "source": [ "b.shape" ] },
{ "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([2, 3, 4])" ] }, "execution_count": 29, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# squeeze(0) removes dimension 0 because its size is 1\n", "b.squeeze(0).shape" ] },
{ "cell_type": "code", "execution_count": 31, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([1, 2, 3, 4])" ] }, "execution_count": 31, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# dimension 1 has size 2, so squeeze(1) leaves the shape unchanged\n", "b.squeeze(1).shape" ] },
{ "cell_type": "code", "execution_count": 34, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([10])" ] }, "execution_count": 34, "metadata": {}, "output_type": "execute_result" } ], "source": [ "data = torch.arange(10)\n", "data.shape" ] },
{ "cell_type": "code", "execution_count": 35, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([2, 5])" ] }, "execution_count": 35, "metadata": {}, "output_type": "execute_result" } ], "source": [ "data.view(2, 5).shape" ] },
{ "cell_type": "code", "execution_count": 37, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([5, 2])" ] }, "execution_count": 37, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# -1 lets view() infer that dimension from the total number of elements\n", "data.view(5, -1).shape" ] },
{ "cell_type": "code", "execution_count": 38, "metadata": {}, "outputs": [], "source": [ "d1 = data.view(2, 5)\n", "# .T returns a transposed view that shares d1's storage, so t1 is not contiguous\n", "t1 = d1.T" ] },
{ "cell_type": "code", "execution_count": 39, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([1, 10])" ] }, "execution_count": 39, "metadata": {}, "output_type": "execute_result" } ], "source": [ "d1.view(1, 10).shape" ] },
{ "cell_type": "code", "execution_count": 40, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.\n" ] } ], "source": [ "# view() never copies, so it cannot flatten the non-contiguous t1 and raises; reshape() copies when needed.\n", "# The error is caught so the notebook still runs top to bottom.\n", "try:\n", "    t1.view(1, 10).shape\n", "except RuntimeError as err:\n", "    print(f'RuntimeError: {err}')" ] },
{ "cell_type": "code", "execution_count": 42, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(tensor([[0, 5, 1, 6, 2, 7, 3, 8, 4, 9]]),\n", " tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]))" ] }, "execution_count": 42, "metadata": {}, "output_type": "execute_result" } ], "source": [ "t1.reshape(1, 10), d1.view(1, 10)" ] },
{ "cell_type": "code", "execution_count": 43, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(torch.Size([2, 3]), torch.Size([2, 3]))" ] }, "execution_count": 43, "metadata": {}, "output_type": "execute_result" } ], "source": [ "a = torch.arange(1, 7).view(2, 3)\n", "b = torch.arange(11, 17).view(2, 3)\n", "a.shape, b.shape" ] },
{ "cell_type": "code", "execution_count": 44, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(tensor([[1, 2, 3],\n", " [4, 5, 6]]),\n", " tensor([[11, 12, 13],\n", " [14, 15, 16]]),\n", " tensor([[11, 24, 39],\n", " [56, 75, 96]]))" ] }, "execution_count": 44, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# * is elementwise multiplication, not a matrix product\n", "a, b, a * b" ] },
{ "cell_type": "code", "execution_count": 45, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([3])" ] }, "execution_count": 45, "metadata": {}, "output_type": "execute_result" } ], "source": [ "b = torch.arange(1, 4)\n", "b.shape" ] },
{ "cell_type": "code", "execution_count": 46, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[ 1, 4, 9],\n", " [ 4, 10, 18]])" ] }, "execution_count": 46, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# a has shape (2, 3) and b has shape (3,): b is broadcast across both rows of a\n", "a * b" ] },
{ "cell_type": "code", "execution_count": 47, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([3, 5])" ] }, "execution_count": 47, "metadata": {}, "output_type": "execute_result" } ], "source": [ "mat1 = torch.randn(3, 4)\n", "mat2 = torch.randn(4, 5)\n", "# matrix product: (3, 4) @ (4, 5) -> (3, 5)\n", "(mat1 @ mat2).shape" ] },
{ "cell_type": "code", "execution_count": 48, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([5, 8, 3, 5])" ] }, "execution_count": 48, "metadata": {}, "output_type": "execute_result" } ], "source": [ "mat1 = torch.randn(5, 1, 3, 4)  # batch dims (5, 1), matrix (3, 4)\n", "mat2 = torch.randn(8, 4, 5)     # batch dims (8,),   matrix (4, 5)\n", "# batch dims broadcast to (5, 8); the last two dims multiply: (3, 4) @ (4, 5) -> (3, 5)\n", "(mat1 @ mat2).shape" ] } ],
"metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.5" } }, "nbformat": 4, "nbformat_minor": 4 }