|
|
@@ -10,8 +10,7 @@
|
|
|
"import torch.nn as nn\n",
|
|
|
"import torch.nn.functional as F\n",
|
|
|
"import torch.optim as optim\n",
|
|
|
- "from torchvision import datasets, transforms\n",
|
|
|
- "from torch.autograd import Variable"
|
|
|
+ "from torchvision import datasets, transforms"
|
|
|
]
|
|
|
},
|
|
|
{
|
|
|
@@ -89,11 +88,10 @@
|
|
|
"source": [
|
|
|
"for epoch in range(10):\n",
|
|
|
" for batch_idx, (data, target) in enumerate(train_loader):\n",
|
|
|
- " #data, target = Variable(data), Variable(target)\n",
|
|
|
" optimizer.zero_grad()\n",
|
|
|
" output = model(data)\n",
|
|
|
" loss = F.nll_loss(output, target)\n",
|
|
|
- " print('Current loss', loss.item())\n",
|
|
|
+ " print('Current loss', float(loss))\n",
|
|
|
" loss.backward()\n",
|
|
|
" optimizer.step()"
|
|
|
]
|