Session 9. More on Neural Networks#

Image classification of fashion items 📷#

In this notebook we bring together all the components needed to train a neural network on a new image classification task.

%matplotlib inline
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda, Compose
import matplotlib.pyplot as plt

PyTorch offers domain-specific libraries such as TorchText, TorchVision, and TorchAudio, all of which include datasets. For this class, we’ll be using a TorchVision dataset.

The torchvision.datasets module contains Dataset objects for many real-world vision datasets, such as CIFAR and COCO. In this class, we’ll use the FashionMNIST dataset.

FashionMNIST was created by Zalando, and contains 70,000 grayscale images in 10 categories. The images show individual articles of clothing at low resolution (28 by 28 pixels), as seen here:

Fashion-MNIST samples

Each training and test example is assigned to one of the following labels:

| Label | Description |
|-------|-------------|
| 0     | T-shirt/top |
| 1     | Trouser     |
| 2     | Pullover    |
| 3     | Dress       |
| 4     | Coat        |
| 5     | Sandal      |
| 6     | Shirt       |
| 7     | Sneaker     |
| 8     | Bag         |
| 9     | Ankle boot  |

1. Set up the DataLoaders#

We’ll use the torchvision module to download and transform the FashionMNIST dataset. There is no need to split the data ourselves: it already comes divided into a training set and a test set. We’ll use the training set to fit the model and the test set to evaluate it.

After downloading, we wrap each split in a DataLoader that batches the data and shuffles the training set.

# Download training data from open datasets.
training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
)

# Download test data from open datasets.
test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor(),
)
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to data/FashionMNIST/raw/train-images-idx3-ubyte.gz
100%|██████████| 26421880/26421880 [00:02<00:00, 11284071.54it/s]
Extracting data/FashionMNIST/raw/train-images-idx3-ubyte.gz to data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw/train-labels-idx1-ubyte.gz
100%|██████████| 29515/29515 [00:00<00:00, 937108.79it/s]
Extracting data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz
100%|██████████| 4422102/4422102 [00:00<00:00, 8142683.28it/s] 
Extracting data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz
100%|██████████| 5148/5148 [00:00<00:00, 10336178.55it/s]
Extracting data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw

batch_size = 64

# Create data loaders.
train_dataloader = DataLoader(training_data, batch_size=batch_size, shuffle=True)  # reshuffle the training data every epoch
test_dataloader = DataLoader(test_data, batch_size=batch_size)

for X, y in test_dataloader:
    print("Shape of X [N, C, H, W]: ", X.shape)
    print("Shape of y: ", y.shape, y.dtype)
    break
Shape of X [N, C, H, W]:  torch.Size([64, 1, 28, 28])
Shape of y:  torch.Size([64]) torch.int64
# Display sample data
figure = plt.figure(figsize=(10, 8))
cols, rows = 5, 5
for i in range(1, cols * rows + 1):
    idx = torch.randint(len(test_data), size=(1,)).item()
    img, label = test_data[idx]
    figure.add_subplot(rows, cols, i)
    plt.title(label)
    plt.axis("off")
    plt.imshow(img.squeeze(), cmap="gray")
plt.show()
(Output: a 5x5 grid of random FashionMNIST test images, each titled with its numeric label.)
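The titles above are numeric labels. If you prefer to see the class names, a small optional tweak (a sketch using the names from the table above) is to define a list of names and use it when setting the title:

label_names = ["T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
               "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot"]
# ... then, inside the plotting loop above, replace plt.title(label) with:
# plt.title(label_names[label])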

2. Create the model#

To define a neural network in PyTorch, we create a class that inherits from nn.Module. We define the layers of the network in the __init__ method and specify how data passes through the network in the forward method.

# Define model
class NeuralNetwork(nn.Module):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        # The Flatten layer converts each 2D 28x28 image into a contiguous array of 784 pixel values; the minibatch dimension (dim=0) is preserved.
        self.flatten = nn.Flatten()
        self.linear_1 = nn.Linear(28*28, 128)
        self.linear_2 = nn.Linear(128, 64)
        self.linear_3 = nn.Linear(64, 10)
        
    def forward(self, x):
        x = self.flatten(x)
        x = torch.relu(self.linear_1(x))  # functional ReLU, equivalent to nn.ReLU()(x) but without building a new module on every call
        x = torch.relu(self.linear_2(x))
        x = self.linear_3(x)

        # We return raw logits here: nn.CrossEntropyLoss applies log-softmax internally, so no explicit softmax layer is needed. In the previous class we made it explicit for the sake of clarity.
        return x

model = NeuralNetwork()
print(model)
NeuralNetwork(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear_1): Linear(in_features=784, out_features=128, bias=True)
  (linear_2): Linear(in_features=128, out_features=64, bias=True)
  (linear_3): Linear(in_features=64, out_features=10, bias=True)
)
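Because the network outputs raw logits, you can turn them into class probabilities yourself at inference time by applying a softmax. A minimal sketch using the (still untrained) model above, so the probabilities will be close to uniform:

import torch.nn.functional as F

X, y = next(iter(test_dataloader))       # grab one batch of test images
with torch.no_grad():
    logits = model(X)                    # shape [batch_size, 10], raw scores
    probs = F.softmax(logits, dim=1)     # each row now sums to 1
print(probs[0])           # probability of each class for the first image
print(probs[0].argmax())  # index of the most likely class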

3. Train the model and then evaluate#

To train a model, we need a loss function and an optimizer. We’ll be using nn.CrossEntropyLoss for loss and Stochastic Gradient Descent for optimization.

loss_fn = nn.CrossEntropyLoss()  # expects raw logits; applies log-softmax internally before computing the loss
learning_rate = 1e-2
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

In a single training loop, the model makes predictions on the training dataset (fed to it in batches), and back-propagates the prediction error to adjust the model’s parameters.

def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    model.train()  # switch back to training mode (test() below puts the model in eval mode)
    for batch, (X, y) in enumerate(dataloader):
        
        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)
        
        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            loss, current = loss.item(), batch * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")

We can also check the model’s performance against the test dataset to ensure it is learning.

def test(dataloader, model):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()  # evaluation mode (affects layers like dropout and batch norm, if present)
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches  # average loss per batch
    correct /= size           # fraction of correctly classified examples
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

The training process is conducted over several iterations (epochs). During each epoch, the model learns parameters to make better predictions. We print the model’s accuracy and loss at each epoch; we’d like to see the accuracy increase and the loss decrease with every epoch.

epochs = 15
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train(train_dataloader, model, loss_fn, optimizer)
    test(test_dataloader, model)
print("Done!") 
Epoch 1
-------------------------------
loss: 2.289097  [    0/60000]
loss: 2.204153  [ 6400/60000]
loss: 2.007140  [12800/60000]
loss: 1.802666  [19200/60000]
loss: 1.406623  [25600/60000]
loss: 1.252974  [32000/60000]
loss: 1.139174  [38400/60000]
loss: 1.009746  [44800/60000]
loss: 0.990271  [51200/60000]
loss: 0.889658  [57600/60000]
Test Error: 
 Accuracy: 67.9%, Avg loss: 0.013709 

Epoch 2
-------------------------------
loss: 0.927257  [    0/60000]
loss: 0.924171  [ 6400/60000]
loss: 0.681874  [12800/60000]
loss: 0.846820  [19200/60000]
loss: 0.691251  [25600/60000]
loss: 0.678255  [32000/60000]
loss: 0.771630  [38400/60000]
loss: 0.759616  [44800/60000]
loss: 0.723761  [51200/60000]
loss: 0.680475  [57600/60000]
Test Error: 
 Accuracy: 76.2%, Avg loss: 0.010562 

Epoch 3
-------------------------------
loss: 0.618100  [    0/60000]
loss: 0.711628  [ 6400/60000]
loss: 0.482961  [12800/60000]
loss: 0.697245  [19200/60000]
loss: 0.599393  [25600/60000]
loss: 0.565808  [32000/60000]
loss: 0.636101  [38400/60000]
loss: 0.662630  [44800/60000]
loss: 0.658170  [51200/60000]
loss: 0.592210  [57600/60000]
Test Error: 
 Accuracy: 79.2%, Avg loss: 0.009311 

Epoch 4
-------------------------------
loss: 0.499205  [    0/60000]
loss: 0.607933  [ 6400/60000]
loss: 0.407313  [12800/60000]
loss: 0.617116  [19200/60000]
loss: 0.550129  [25600/60000]
loss: 0.520347  [32000/60000]
loss: 0.568434  [38400/60000]
loss: 0.638106  [44800/60000]
loss: 0.633640  [51200/60000]
loss: 0.529101  [57600/60000]
Test Error: 
 Accuracy: 80.1%, Avg loss: 0.008692 

Epoch 5
-------------------------------
loss: 0.432519  [    0/60000]
loss: 0.554164  [ 6400/60000]
loss: 0.369696  [12800/60000]
loss: 0.566110  [19200/60000]
loss: 0.507076  [25600/60000]
loss: 0.491527  [32000/60000]
loss: 0.530983  [38400/60000]
loss: 0.634127  [44800/60000]
loss: 0.613094  [51200/60000]
loss: 0.486707  [57600/60000]
Test Error: 
 Accuracy: 80.8%, Avg loss: 0.008325 

Epoch 6
-------------------------------
loss: 0.388891  [    0/60000]
loss: 0.522734  [ 6400/60000]
loss: 0.346615  [12800/60000]
loss: 0.533556  [19200/60000]
loss: 0.476073  [25600/60000]
loss: 0.472929  [32000/60000]
loss: 0.504632  [38400/60000]
loss: 0.626713  [44800/60000]
loss: 0.593248  [51200/60000]
loss: 0.462189  [57600/60000]
Test Error: 
 Accuracy: 81.5%, Avg loss: 0.008075 

Epoch 7
-------------------------------
loss: 0.358328  [    0/60000]
loss: 0.502070  [ 6400/60000]
loss: 0.328464  [12800/60000]
loss: 0.513722  [19200/60000]
loss: 0.455949  [25600/60000]
loss: 0.460615  [32000/60000]
loss: 0.483159  [38400/60000]
loss: 0.615375  [44800/60000]
loss: 0.576040  [51200/60000]
loss: 0.448871  [57600/60000]
Test Error: 
 Accuracy: 81.7%, Avg loss: 0.007917 

Epoch 8
-------------------------------
loss: 0.336749  [    0/60000]
loss: 0.484792  [ 6400/60000]
loss: 0.314682  [12800/60000]
loss: 0.499265  [19200/60000]
loss: 0.437410  [25600/60000]
loss: 0.452041  [32000/60000]
loss: 0.465575  [38400/60000]
loss: 0.603855  [44800/60000]
loss: 0.562545  [51200/60000]
loss: 0.438371  [57600/60000]
Test Error: 
 Accuracy: 81.8%, Avg loss: 0.007762 

Epoch 9
-------------------------------
loss: 0.320475  [    0/60000]
loss: 0.470900  [ 6400/60000]
loss: 0.303407  [12800/60000]
loss: 0.489299  [19200/60000]
loss: 0.419739  [25600/60000]
loss: 0.447229  [32000/60000]
loss: 0.451173  [38400/60000]
loss: 0.591696  [44800/60000]
loss: 0.547607  [51200/60000]
loss: 0.430787  [57600/60000]
Test Error: 
 Accuracy: 82.2%, Avg loss: 0.007628 

Epoch 10
-------------------------------
loss: 0.306368  [    0/60000]
loss: 0.458059  [ 6400/60000]
loss: 0.292609  [12800/60000]
loss: 0.480203  [19200/60000]
loss: 0.404263  [25600/60000]
loss: 0.440233  [32000/60000]
loss: 0.437055  [38400/60000]
loss: 0.579413  [44800/60000]
loss: 0.535025  [51200/60000]
loss: 0.425971  [57600/60000]
Test Error: 
 Accuracy: 82.8%, Avg loss: 0.007468 

Epoch 11
-------------------------------
loss: 0.294402  [    0/60000]
loss: 0.445635  [ 6400/60000]
loss: 0.283190  [12800/60000]
loss: 0.470698  [19200/60000]
loss: 0.390130  [25600/60000]
loss: 0.434453  [32000/60000]
KeyboardInterrupt: the run was stopped manually partway through epoch 11 (by epoch 10 the model had already reached 82.8% test accuracy).

Exercise What is the best test accuracy you can achieve?

You can play with the architecture and/or the hyperparameters.
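One direction you might try (a sketch, not a reference solution; the names WiderNetwork, model2 and optimizer2 are just illustrative): a wider network trained with the Adam optimizer, which often converges faster than plain SGD on this dataset.

# A possible variant to experiment with: wider hidden layers + the Adam optimizer.
class WiderNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.layers = nn.Sequential(
            nn.Linear(28*28, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 10),
        )

    def forward(self, x):
        return self.layers(self.flatten(x))

model2 = WiderNetwork()
optimizer2 = torch.optim.Adam(model2.parameters(), lr=1e-3)

# Reuse the train/test helpers defined above.
for t in range(5):
    print(f"Epoch {t+1}\n-------------------------------")
    train(train_dataloader, model2, loss_fn, optimizer2)
    test(test_dataloader, model2)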

4. Making predictions#

Once the model is trained, we can use it to make predictions. We take the argmax of the model's output to get the index of the largest logit; that index corresponds to the predicted class of clothing.

classes = [
    "T-shirt/top",
    "Trouser",
    "Pullover",
    "Dress",
    "Coat",
    "Sandal",
    "Shirt",
    "Sneaker",
    "Bag",
    "Ankle boot",
]

model.eval()

example_idx = 13
x, y = test_data[example_idx]
with torch.no_grad():
    pred = model(x)
    predicted, actual = classes[pred[0].argmax(0)], classes[y]
    print(f'Predicted: "{predicted}", Actual: "{actual}"')

    # plot the image
    plt.imshow(x.squeeze(), cmap="gray")
Predicted: "Sandal", Actual: "Sneaker"
(Output: the test image at index 13, a sneaker that the model misclassifies as a sandal.)

Exercise Can you look for three examples of misclassified items?

Are the mistakes reasonable?
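A minimal sketch for collecting a few misclassified test examples (the helper name find_misclassified is just illustrative):

# Scan the test set and collect the first few indices where the prediction disagrees with the label.
def find_misclassified(model, dataset, max_items=3):
    model.eval()
    mistakes = []
    with torch.no_grad():
        for idx in range(len(dataset)):
            img, label = dataset[idx]
            pred = model(img.unsqueeze(0)).argmax(1).item()
            if pred != label:
                mistakes.append((idx, pred, label))
            if len(mistakes) == max_items:
                break
    return mistakes

for idx, pred, label in find_misclassified(model, test_data):
    print(f"index {idx}: predicted {classes[pred]}, actual {classes[label]}")
    plt.figure()
    plt.imshow(test_data[idx][0].squeeze(), cmap="gray")
    plt.axis("off")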