{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import csv\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "\n",
    "# Show only 2 items at each edge when a tensor repr is summarized.\n",
    "torch.set_printoptions(edgeitems=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[ 7. , 0.27, 0.36, ..., 0.45, 8.8 , 6. ],\n",
       " [ 6.3 , 0.3 , 0.34, ..., 0.49, 9.5 , 6. ],\n",
       " [ 8.1 , 0.28, 0.4 , ..., 0.44, 10.1 , 6. ],\n",
       " ...,\n",
       " [ 6.5 , 0.24, 0.19, ..., 0.46, 9.4 , 6. ],\n",
       " [ 5.5 , 0.29, 0.3 , ..., 0.38, 12.8 , 7. ],\n",
       " [ 6. , 0.21, 0.38, ..., 0.32, 11.8 , 6. ]], dtype=float32)"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "wine_path = \"../data/p1ch4/tabular-wine/winequality-white.csv\"\n",
    "# The file is ';'-separated; skiprows=1 skips the header line so that only\n",
    "# numeric rows are parsed.\n",
    "wineq_numpy = np.loadtxt(wine_path, dtype=np.float32, delimiter=\";\", skiprows=1)\n",
    "wineq_numpy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((4898, 12),\n",
       " ['fixed acidity',\n",
       " 'volatile acidity',\n",
       " 'citric acid',\n",
       " 'residual sugar',\n",
       " 'chlorides',\n",
       " 'free sulfur dioxide',\n",
       " 'total sulfur dioxide',\n",
       " 'density',\n",
       " 'pH',\n",
       " 'sulphates',\n",
       " 'alcohol',\n",
       " 'quality'])"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Read just the header row to get the column names. The with-block guarantees\n",
    "# the file handle is closed once the row has been read; newline='' is the\n",
    "# mode the csv module expects for files it parses.\n",
    "with open(wine_path, newline='') as csv_file:\n",
    "    col_list = next(csv.reader(csv_file, delimiter=';'))\n",
    "\n",
    "wineq_numpy.shape, col_list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(torch.Size([4898, 12]), 'torch.FloatTensor')"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Turn the NumPy array into a PyTorch tensor.\n",
    "wineq = torch.from_numpy(wineq_numpy)\n",
    "\n",
    "wineq.shape, wineq.type()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([[ 7.0000, 0.2700, ..., 0.4500, 8.8000],\n",
       " [ 6.3000, 0.3000, ..., 0.4900, 9.5000],\n",
       " ...,\n",
       " [ 5.5000, 0.2900, ..., 0.3800, 12.8000],\n",
       " [ 6.0000, 0.2100, ..., 0.3200, 11.8000]]), torch.Size([4898, 11]))"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# All rows, every column except the last one (the quality score).\n",
    "data = wineq[:, :-1] # <1>\n",
    "data, data.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([6., 6., ..., 7., 6.]), torch.Size([4898]))"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Last column only: the quality score, still stored as float.\n",
    "target = wineq[:, -1] # <2>\n",
    "target, target.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([6, 6, ..., 7, 6])"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Same column, converted to integer labels.\n",
    "target = wineq[:, -1].long()\n",
    "target"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0., 0., ..., 0., 0.],\n",
       " [0., 0., ..., 0., 0.],\n",
       " ...,\n",
       " [0., 0., ..., 0., 0.],\n",
       " [0., 0., ..., 0., 0.]])"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# One-hot encoding with 10 columns: start from zeros, then for each row i\n",
    "# write 1.0 into column target[i]. The zeros tensor is rebuilt in this same\n",
    "# cell, so re-running it is safe even though scatter_ works in place.\n",
    "target_onehot = torch.zeros(target.shape[0], 10)\n",
    "\n",
    "target_onehot.scatter_(1, target.unsqueeze(1), 1.0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[6],\n",
       " [6],\n",
       " ...,\n",
       " [7],\n",
       " [6]])"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# unsqueeze(1) inserts a size-1 dimension at position 1, giving one label per row.\n",
    "target_unsqueezed = target.unsqueeze(1)\n",
    "target_unsqueezed"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([6.8548e+00, 2.7824e-01, 3.3419e-01, 6.3914e+00, 4.5772e-02, 3.5308e+01,\n",
       " 1.3836e+02, 9.9403e-01, 3.1883e+00, 4.8985e-01, 1.0514e+01])"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Per-column mean (dim=0 reduces over the rows).\n",
    "data_mean = torch.mean(data, dim=0)\n",
    "data_mean"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([7.1211e-01, 1.0160e-02, 1.4646e-02, 2.5726e+01, 4.7733e-04, 2.8924e+02,\n",
       " 1.8061e+03, 8.9455e-06, 2.2801e-02, 1.3025e-02, 1.5144e+00])"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Per-column variance (dim=0 reduces over the rows).\n",
    "data_var = torch.var(data, dim=0)\n",
    "data_var"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 1.7209e-01, -8.1764e-02, ..., -3.4914e-01, -1.3930e+00],\n",
       " [-6.5743e-01, 2.1587e-01, ..., 1.3467e-03, -8.2418e-01],\n",
       " ...,\n",
       " [-1.6054e+00, 1.1666e-01, ..., -9.6250e-01, 1.8574e+00],\n",
       " [-1.0129e+00, -6.7703e-01, ..., -1.4882e+00, 1.0448e+00]])"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Standardize every column: subtract its mean and divide by its standard\n",
    "# deviation (the square root of the variance computed above).\n",
    "data_normalized = (data - data_mean) / torch.sqrt(data_var)\n",
    "data_normalized"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(torch.Size([4898]), torch.uint8, tensor(20))"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Mask that is set for every row whose quality score is <= 3.\n",
    "bad_indexes = torch.le(target, 3)\n",
    "bad_indexes.shape, bad_indexes.dtype, bad_indexes.sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([20, 11])"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Indexing with the mask keeps only the selected rows.\n",
    "bad_data = data[bad_indexes]\n",
    "bad_data.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " 0 fixed acidity 7.60 6.89 6.73\n",
      " 1 volatile acidity 0.33 0.28 0.27\n",
      " 2 citric acid 0.34 0.34 0.33\n",
      " 3 residual sugar 6.39 6.71 5.26\n",
      " 4 chlorides 0.05 0.05 0.04\n",
      " 5 free sulfur dioxide 53.33 35.42 34.55\n",
      " 6 total sulfur dioxide 170.60 141.83 125.25\n",
      " 7 density 0.99 0.99 0.99\n",
      " 8 pH 3.19 3.18 3.22\n",
      " 9 sulphates 0.47 0.49 0.50\n",
      "10 alcohol 10.34 10.26 11.42\n"
     ]
    }
   ],
   "source": [
    "# Split the rows into three quality bands: <= 3, 4 to 6, and >= 7.\n",
    "bad_data = data[torch.le(target, 3)]\n",
    "mid_data = data[torch.gt(target, 3) & torch.lt(target, 7)] # <1>\n",
    "good_data = data[torch.ge(target, 7)]\n",
    "\n",
    "# Per-column mean of each band.\n",
    "bad_mean = torch.mean(bad_data, dim=0)\n",
    "mid_mean = torch.mean(mid_data, dim=0)\n",
    "good_mean = torch.mean(good_data, dim=0)\n",
    "\n",
    "# One line per feature: index, name, then the bad / mid / good means.\n",
    "# zip stops at the shortest input, so the extra 'quality' name is dropped.\n",
    "for i, args in enumerate(zip(col_list, bad_mean, mid_mean, good_mean)):\n",
    "    print('{:2} {:20} {:6.2f} {:6.2f} {:6.2f}'.format(i, *args))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(torch.Size([4898]), torch.uint8, tensor(2727))"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 141.83 is the mid-quality mean of total sulfur dioxide printed by the\n",
    "# previous cell; column 6 of `data` is that feature (see col_list).\n",
    "total_sulfur_threshold = 141.83\n",
    "total_sulfur_data = data[:,6]\n",
    "# Mask of rows whose total sulfur dioxide is below the threshold.\n",
    "predicted_indexes = torch.lt(total_sulfur_data, total_sulfur_threshold)\n",
    "\n",
    "predicted_indexes.shape, predicted_indexes.dtype, predicted_indexes.sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(torch.Size([4898]), torch.uint8, tensor(3258))"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Mask of rows whose quality score is above 5.\n",
    "actual_indexes = torch.gt(target, 5)\n",
    "\n",
    "actual_indexes.shape, actual_indexes.dtype, actual_indexes.sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(2018, 0.74000733406674, 0.6193984039287906)"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Rows selected by both masks, then the size of each mask on its own.\n",
    "n_matches = torch.sum(actual_indexes & predicted_indexes).item()\n",
    "n_predicted = torch.sum(predicted_indexes).item()\n",
    "n_actual = torch.sum(actual_indexes).item()\n",
    "\n",
    "# matches / predicted is the precision of the threshold rule,\n",
    "# matches / actual is its recall.\n",
    "n_matches, n_matches / n_predicted, n_matches / n_actual"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}