In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

In [2]:
torch.manual_seed(4242)



In [3]:
train_loader = torch.utils.data.DataLoader(
 datasets.MNIST('../data/p1ch2/mnist', train=True, download=True,
 transform=transforms.Compose([
 transforms.ToTensor(),
 transforms.Normalize((0.1307,), (0.3081,))
 ])),
 batch_size=64, shuffle=True)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Processing...
Done!


In [4]:
class Net(nn.Module):
 def __init__(self):
 super(Net, self).__init__()
 self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
 self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
 self.conv2_drop = nn.Dropout2d()
 self.fc1 = nn.Linear(320, 50)
 self.fc2 = nn.Linear(50, 10)

 def forward(self, x):
 x = F.relu(F.max_pool2d(self.conv1(x), 2))
 x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
 x = x.view(-1, 320)
 x = F.relu(self.fc1(x))
 x = F.dropout(x, training=self.training)
 x = self.fc2(x)
 return F.log_softmax(x, dim=1)

In [5]:
model = Net()

In [6]:
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)

In [7]:
for epoch in range(10):
 for batch_idx, (data, target) in enumerate(train_loader):
 optimizer.zero_grad()
 output = model(data)
 loss = F.nll_loss(output, target)
 loss.backward()
 optimizer.step()
 print('Current loss', float(loss))

Current loss 0.4354310631752014
Current loss 0.23793256282806396
Current loss 0.382179856300354
Current loss 0.3900523781776428
Current loss 0.283257395029068
Current loss 0.1536979377269745
Current loss 0.10767409205436707
Current loss 0.14431846141815186
Current loss 0.30025267601013184
Current loss 0.18810895085334778


In [8]:
torch.save(model.state_dict(), '../data/p1ch2/mnist/mnist.pth')

In [9]:
pretrained_model = Net()
pretrained_model.load_state_dict(torch.load('../data/p1ch2/mnist/mnist.pth'))