{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# White wine quality: tabular data as PyTorch tensors\n",
    "\n",
    "This notebook loads `winequality-white.csv` (4898 rows, 12 semicolon-separated columns), converts it to a `torch` tensor, separates the 11 input columns from the `quality` score, and finally tries a single-feature threshold on `total sulfur dioxide` to pick out wines with `quality > 5`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import csv\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "\n",
    "torch.set_printoptions(edgeitems=2, precision=2, linewidth=75)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load the data\n",
    "\n",
    "The numeric rows are read with NumPy; the header row is read separately with the `csv` module so the column names stay available."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[ 7.  ,  0.27,  0.36, ...,  0.45,  8.8 ,  6.  ],\n",
       "       [ 6.3 ,  0.3 ,  0.34, ...,  0.49,  9.5 ,  6.  ],\n",
       "       [ 8.1 ,  0.28,  0.4 , ...,  0.44, 10.1 ,  6.  ],\n",
       "       ...,\n",
       "       [ 6.5 ,  0.24,  0.19, ...,  0.46,  9.4 ,  6.  ],\n",
       "       [ 5.5 ,  0.29,  0.3 , ...,  0.38, 12.8 ,  7.  ],\n",
       "       [ 6.  ,  0.21,  0.38, ...,  0.32, 11.8 ,  6.  ]], dtype=float32)"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "wine_path = \"../data/p1ch4/tabular-wine/winequality-white.csv\"\n",
    "# skiprows=1 skips the header line; the column names are read in the next cell.\n",
    "wineq_numpy = np.loadtxt(wine_path, dtype=np.float32, delimiter=\";\",\n",
    "                         skiprows=1)\n",
    "wineq_numpy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((4898, 12),\n",
       " ['fixed acidity',\n",
       "  'volatile acidity',\n",
       "  'citric acid',\n",
       "  'residual sugar',\n",
       "  'chlorides',\n",
       "  'free sulfur dioxide',\n",
       "  'total sulfur dioxide',\n",
       "  'density',\n",
       "  'pH',\n",
       "  'sulphates',\n",
       "  'alcohol',\n",
       "  'quality'])"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Read only the header row; the context manager closes the file afterwards\n",
    "# instead of leaving the handle open for the rest of the kernel session.\n",
    "with open(wine_path, newline='') as csv_file:\n",
    "    col_list = next(csv.reader(csv_file, delimiter=';'))\n",
    "\n",
    "wineq_numpy.shape, col_list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(torch.Size([4898, 12]), torch.float32)"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "wineq = torch.from_numpy(wineq_numpy)\n",
    "\n",
    "wineq.shape, wineq.dtype"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Split inputs and target\n",
    "\n",
    "The last column (`quality`) is the score; the other 11 columns are the inputs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([[ 7.00,  0.27,  ...,  0.45,  8.80],\n",
       "         [ 6.30,  0.30,  ...,  0.49,  9.50],\n",
       "         ...,\n",
       "         [ 5.50,  0.29,  ...,  0.38, 12.80],\n",
       "         [ 6.00,  0.21,  ...,  0.32, 11.80]]), torch.Size([4898, 11]))"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Every row, every column except the last one (quality).\n",
    "data = wineq[:, :-1]\n",
    "data, data.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([6., 6.,  ..., 7., 6.]), torch.Size([4898]))"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Every row, last column only (quality), still as floats.\n",
    "target = wineq[:, -1]\n",
    "target, target.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([6, 6,  ..., 7, 6])"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Same column, re-derived from wineq and converted to integer labels;\n",
    "# the integer form is what scatter_ needs as an index further down.\n",
    "target = wineq[:, -1].long()\n",
    "target"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## One-hot encode the target"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0., 0.,  ..., 0., 0.],\n",
       "        [0., 0.,  ..., 0., 0.],\n",
       "        ...,\n",
       "        [0., 0.,  ..., 0., 0.],\n",
       "        [0., 0.,  ..., 0., 0.]])"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# One row per sample and ten columns, all zeros to start with.\n",
    "target_onehot = torch.zeros(target.shape[0], 10)\n",
    "\n",
    "# Along dim 1, write 1.0 into the column named by each sample's score.\n",
    "# The index tensor must have as many dimensions as target_onehot, hence\n",
    "# the unsqueeze that turns the 1-D scores into a single column.\n",
    "target_onehot.scatter_(1, target.unsqueeze(1), 1.0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[6],\n",
       "        [6],\n",
       "        ...,\n",
       "        [7],\n",
       "        [6]])"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# What the index passed to scatter_ looks like: one score per row.\n",
    "target_unsqueezed = target.unsqueeze(1)\n",
    "target_unsqueezed"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Normalize the inputs\n",
    "\n",
    "Each column is shifted by its mean and scaled by the square root of its variance."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([6.85e+00, 2.78e-01, 3.34e-01, 6.39e+00, 4.58e-02, 3.53e+01,\n",
       "        1.38e+02, 9.94e-01, 3.19e+00, 4.90e-01, 1.05e+01])"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# dim=0 reduces over the rows, leaving one mean per column.\n",
    "data_mean = torch.mean(data, dim=0)\n",
    "data_mean"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([7.12e-01, 1.02e-02, 1.46e-02, 2.57e+01, 4.77e-04, 2.89e+02,\n",
       "        1.81e+03, 8.95e-06, 2.28e-02, 1.30e-02, 1.51e+00])"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Likewise, one variance per column.\n",
    "data_var = torch.var(data, dim=0)\n",
    "data_var"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 1.72e-01, -8.18e-02,  ..., -3.49e-01, -1.39e+00],\n",
       "        [-6.57e-01,  2.16e-01,  ...,  1.35e-03, -8.24e-01],\n",
       "        ...,\n",
       "        [-1.61e+00,  1.17e-01,  ..., -9.63e-01,  1.86e+00],\n",
       "        [-1.01e+00, -6.77e-01,  ..., -1.49e+00,  1.04e+00]])"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# The per-column statistics broadcast across all rows of data.\n",
    "data_normalized = (data - data_mean) / torch.sqrt(data_var)\n",
    "data_normalized"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Compare quality groups\n",
    "\n",
    "Boolean masks built from `target` select rows of `data`; the per-column means of the bad (`<= 3`), mid and good (`>= 7`) groups are printed side by side."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(torch.Size([4898]), torch.bool, tensor(20))"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Element-wise comparison gives one boolean per row; summing counts the Trues.\n",
    "bad_indexes = target <= 3\n",
    "bad_indexes.shape, bad_indexes.dtype, bad_indexes.sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([20, 11])"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Indexing with the mask keeps only the rows where it is True.\n",
    "bad_data = data[bad_indexes]\n",
    "bad_data.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " 0 fixed acidity          7.60   6.89   6.73\n",
      " 1 volatile acidity       0.33   0.28   0.27\n",
      " 2 citric acid            0.34   0.34   0.33\n",
      " 3 residual sugar         6.39   6.71   5.26\n",
      " 4 chlorides              0.05   0.05   0.04\n",
      " 5 free sulfur dioxide   53.33  35.42  34.55\n",
      " 6 total sulfur dioxide 170.60 141.83 125.25\n",
      " 7 density                0.99   0.99   0.99\n",
      " 8 pH                     3.19   3.18   3.22\n",
      " 9 sulphates              0.47   0.49   0.50\n",
      "10 alcohol               10.34  10.26  11.42\n"
     ]
    }
   ],
   "source": [
    "bad_data = data[target <= 3]\n",
    "# Two masks combined element-wise with & select the scores strictly between 3 and 7.\n",
    "mid_data = data[(target > 3) & (target < 7)]\n",
    "good_data = data[target >= 7]\n",
    "\n",
    "bad_mean = torch.mean(bad_data, dim=0)\n",
    "mid_mean = torch.mean(mid_data, dim=0)\n",
    "good_mean = torch.mean(good_data, dim=0)\n",
    "\n",
    "# Columns printed: index, name, bad mean, mid mean, good mean.\n",
    "# zip stops after the 11 feature means, so the trailing 'quality' name is not printed.\n",
    "for i, args in enumerate(zip(col_list, bad_mean, mid_mean, good_mean)):\n",
    "    print('{:2} {:20} {:6.2f} {:6.2f} {:6.2f}'.format(i, *args))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Threshold on total sulfur dioxide\n",
    "\n",
    "In the table above, mean `total sulfur dioxide` falls from the bad group to the good group. Wines below the mid-group mean are predicted to be good and compared against the wines that actually scored above 5."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(torch.Size([4898]), torch.bool, tensor(2727))"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 141.83 is the mid-quality mean of total sulfur dioxide (column 6) from the printed table.\n",
    "total_sulfur_threshold = 141.83\n",
    "total_sulfur_data = data[:, 6]\n",
    "# True where a wine's total sulfur dioxide is strictly below the threshold.\n",
    "predicted_indexes = torch.lt(total_sulfur_data, total_sulfur_threshold)\n",
    "\n",
    "predicted_indexes.shape, predicted_indexes.dtype, predicted_indexes.sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(torch.Size([4898]), torch.bool, tensor(3258))"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Wines whose recorded score is above 5.\n",
    "actual_indexes = target > 5\n",
    "\n",
    "actual_indexes.shape, actual_indexes.dtype, actual_indexes.sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(2018, 0.74000733406674, 0.6193984039287906)"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Rows flagged by both masks, i.e. predicted good and actually good.\n",
    "n_matches = torch.sum(actual_indexes & predicted_indexes).item()\n",
    "n_predicted = torch.sum(predicted_indexes).item()\n",
    "n_actual = torch.sum(actual_indexes).item()\n",
    "\n",
    "# Match count, share of predictions that were right, share of good wines that were found.\n",
    "n_matches, n_matches / n_predicted, n_matches / n_actual"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}