
Remove references to Variable. Use float(loss) instead of loss.item()

Luca Antiga, 7 years ago
commit ec20acf706

1 changed file with 2 additions and 4 deletions:
    p1ch2/4_mnist.ipynb

p1ch2/4_mnist.ipynb  (+2 -4)

@@ -10,8 +10,7 @@
     "import torch.nn as nn\n",
     "import torch.nn.functional as F\n",
     "import torch.optim as optim\n",
-    "from torchvision import datasets, transforms\n",
-    "from torch.autograd import Variable"
+    "from torchvision import datasets, transforms"
    ]
   },
   {
@@ -89,11 +88,10 @@
    "source": [
     "for epoch in range(10):\n",
     "    for batch_idx, (data, target) in enumerate(train_loader):\n",
-    "        #data, target = Variable(data), Variable(target)\n",
     "        optimizer.zero_grad()\n",
     "        output = model(data)\n",
     "        loss = F.nll_loss(output, target)\n",
-    "        print('Current loss', loss.item())\n",
+    "        print('Current loss', float(loss))\n",
     "        loss.backward()\n",
     "        optimizer.step()"
    ]
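
For context, here is a sketch of what the edited training cell looks like after this commit, assembled from the diff above into plain Python. The model, optimizer, and train_loader definitions live in notebook cells not shown in this diff, so the setup used below is an assumption, not the notebook's actual code.

# Sketch of the updated cell from this diff, written out as plain Python.
# The data/model/optimizer setup is an assumption; those cells are not
# part of this diff.
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Hypothetical stand-ins for the notebook's earlier cells.
train_loader = DataLoader(
    datasets.MNIST('data', train=True, download=True,
                   transform=transforms.ToTensor()),
    batch_size=64, shuffle=True)
model = nn.Sequential(nn.Flatten(),
                      nn.Linear(28 * 28, 10),
                      nn.LogSoftmax(dim=1))
optimizer = optim.SGD(model.parameters(), lr=0.01)

for epoch in range(10):
    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()               # clear gradients from the previous step
        output = model(data)                # tensors carry autograd state; no Variable wrapper needed
        loss = F.nll_loss(output, target)   # negative log likelihood on log-probabilities
        print('Current loss', float(loss))  # float() reads the scalar out, same as loss.item()
        loss.backward()
        optimizer.step()

Both float(loss) and loss.item() pull a Python number out of the zero-dimensional loss tensor. The Variable wrapper was merged into Tensor in PyTorch 0.4, which is why the import and the commented-out wrapping line can be dropped without changing behavior.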