Session 9. More on Neural Networks#
Image classification of fashion items 📷#
In this notebook, we will bring together all the components needed to train a neural network on a new image classification task.
%matplotlib inline
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda, Compose
import matplotlib.pyplot as plt
PyTorch offers domain-specific libraries such as TorchText, TorchVision, and TorchAudio, all of which include datasets. For this class, we'll be using a TorchVision dataset.
The torchvision.datasets module contains Dataset objects for many real-world vision datasets, such as CIFAR and COCO. In this class, we'll use the FashionMNIST dataset.
FashionMNIST was created by Zalando, and contains 70,000 grayscale images in 10 categories. The images show individual articles of clothing at low resolution (28 by 28 pixels), as seen here:
Each training and test example is assigned to one of the following labels:
| Label | Description |
|---|---|
| 0 | T-shirt/top |
| 1 | Trouser |
| 2 | Pullover |
| 3 | Dress |
| 4 | Coat |
| 5 | Sandal |
| 6 | Shirt |
| 7 | Sneaker |
| 8 | Bag |
| 9 | Ankle boot |
1. Set up the DataLoaders#
We’ll use the torchvision module to load and transform the FashionMNIST dataset. The code below downloads it. Instead of splitting the data into training and validation sets ourselves, this dataset comes already split: we’ll use the train set to train the model and the test set to evaluate it.
Then, we wrap each split in a DataLoader that batches the data (shuffling the training set at each epoch).
# Download training data from open datasets.
training_data = datasets.FashionMNIST(
root="data",
train=True,
download=True,
transform=ToTensor(),
)
# Download test data from open datasets.
test_data = datasets.FashionMNIST(
root="data",
train=False,
download=True,
transform=ToTensor(),
)
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to data/FashionMNIST/raw/train-images-idx3-ubyte.gz
100%|██████████| 26421880/26421880 [00:02<00:00, 11284071.54it/s]
Extracting data/FashionMNIST/raw/train-images-idx3-ubyte.gz to data/FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw/train-labels-idx1-ubyte.gz
100%|██████████| 29515/29515 [00:00<00:00, 937108.79it/s]
Extracting data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz
100%|██████████| 4422102/4422102 [00:00<00:00, 8142683.28it/s]
Extracting data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz
100%|██████████| 5148/5148 [00:00<00:00, 10336178.55it/s]
Extracting data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw
batch_size = 64
# Create data loaders; shuffle the training data so each epoch sees a different ordering.
train_dataloader = DataLoader(training_data, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=batch_size)
for X, y in test_dataloader:
print("Shape of X [N, C, H, W]: ", X.shape)
print("Shape of y: ", y.shape, y.dtype)
break
Shape of X [N, C, H, W]: torch.Size([64, 1, 28, 28])
Shape of y: torch.Size([64]) torch.int64
# Display sample data
figure = plt.figure(figsize=(10, 8))
cols, rows = 5, 5
for i in range(1, cols * rows + 1):
idx = torch.randint(len(test_data), size=(1,)).item()
img, label = test_data[idx]
figure.add_subplot(rows, cols, i)
plt.title(label)
plt.axis("off")
plt.imshow(img.squeeze(), cmap="gray")
plt.show()

2. Create the model#
To define a neural network in PyTorch, we create a class that inherits from nn.Module. We define the layers of the network in the __init__ method and specify how data will pass through the network in the forward method.
⚠️ Note that here we are using a new layer, nn.Flatten, which reshapes each 2D image into a 1D array so it can be multiplied as a vector in the following layers: https://pytorch.org/docs/stable/generated/torch.nn.Flatten.html
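To see concretely what nn.Flatten does, here is a minimal sketch using a random tensor as a stand-in for a minibatch of images:

batch = torch.rand(64, 1, 28, 28)  # fake minibatch: [N, C, H, W]
flatten = nn.Flatten()             # flattens dims 1..-1, keeping the batch dimension
print(flatten(batch).shape)        # torch.Size([64, 784])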
# Define model
class NeuralNetwork(nn.Module):
def __init__(self):
        super().__init__()
# The Flatten layer converts each 2D 28x28 image into a contiguous array of 784 pixel values (the minibatch dimension (at dim=0) is maintained).
self.flatten = nn.Flatten()
self.linear_1 = nn.Linear(28*28, 128)
self.linear_2 = nn.Linear(128, 64)
self.linear_3 = nn.Linear(64, 10)
    def forward(self, x):
        x = self.flatten(x)
        x = self.linear_1(x)
        x = torch.relu(x)  # functional ReLU; avoids constructing a new nn.ReLU module on every call
        x = self.linear_2(x)
        x = torch.relu(x)
        x = self.linear_3(x)
        # We return the raw logits: nn.CrossEntropyLoss applies log-softmax internally,
        # so no softmax is needed here. In the previous class we made it explicit for the sake of clarity.
        return x
model = NeuralNetwork()
print(model)
NeuralNetwork(
(flatten): Flatten(start_dim=1, end_dim=-1)
(linear_1): Linear(in_features=784, out_features=128, bias=True)
(linear_2): Linear(in_features=128, out_features=64, bias=True)
(linear_3): Linear(in_features=64, out_features=10, bias=True)
)
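As a quick sanity check on the architecture, we can count the trainable parameters; this is a small sketch using standard PyTorch calls:

# (784*128 + 128) + (128*64 + 64) + (64*10 + 10) = 109,386 trainable parameters
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(n_params)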
3. Train the model and then evaluate#
To train a model, we need a loss function and an optimizer. We’ll be using nn.CrossEntropyLoss for loss and Stochastic Gradient Descent for optimization.
loss_fn = nn.CrossEntropyLoss()  # expects raw logits; combines log-softmax and negative log-likelihood internally
learning_rate = 1e-2
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
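To make the logits-vs-probabilities point concrete, here is a minimal sketch (with made-up logits for a 3-class example) showing that nn.CrossEntropyLoss on raw logits matches log-softmax followed by the negative log-likelihood loss:

logits = torch.tensor([[2.0, 0.5, -1.0]])  # made-up logits for one example, 3 classes
target = torch.tensor([0])

ce = nn.CrossEntropyLoss()(logits, target)
nll = nn.NLLLoss()(torch.log_softmax(logits, dim=1), target)
print(ce.item(), nll.item())  # identical values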
In a single training loop, the model makes predictions on the training dataset (fed to it in batches), and back-propagates the prediction error to adjust the model’s parameters.
def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    model.train()  # switch back to training mode (test() below puts the model in eval mode)
    for batch, (X, y) in enumerate(dataloader):
# Compute prediction error
pred = model(X)
loss = loss_fn(pred, y)
# Backpropagation
optimizer.zero_grad()
loss.backward()
optimizer.step()
if batch % 100 == 0:
loss, current = loss.item(), batch * len(X)
print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")
We can also check the model’s performance against the test dataset to ensure it is learning.
def test(dataloader, model):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()  # evaluation mode
    test_loss, correct = 0, 0
    with torch.no_grad():  # no gradients needed for evaluation
        for X, y in dataloader:
            pred = model(X)
            test_loss += loss_fn(pred, y).item()  # loss_fn is taken from the enclosing scope
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches  # average the per-batch mean losses over the number of batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
The training process is conducted over several iterations (epochs). During each epoch, the model's parameters are updated to make better predictions. We print the model's accuracy and loss at each epoch; we'd like to see the accuracy increase and the loss decrease with every epoch.
epochs = 15
for t in range(epochs):
print(f"Epoch {t+1}\n-------------------------------")
train(train_dataloader, model, loss_fn, optimizer)
test(test_dataloader, model)
print("Done!")
Epoch 1
-------------------------------
loss: 2.289097 [ 0/60000]
loss: 2.204153 [ 6400/60000]
loss: 2.007140 [12800/60000]
loss: 1.802666 [19200/60000]
loss: 1.406623 [25600/60000]
loss: 1.252974 [32000/60000]
loss: 1.139174 [38400/60000]
loss: 1.009746 [44800/60000]
loss: 0.990271 [51200/60000]
loss: 0.889658 [57600/60000]
Test Error:
Accuracy: 67.9%, Avg loss: 0.013709
Epoch 2
-------------------------------
loss: 0.927257 [ 0/60000]
loss: 0.924171 [ 6400/60000]
loss: 0.681874 [12800/60000]
loss: 0.846820 [19200/60000]
loss: 0.691251 [25600/60000]
loss: 0.678255 [32000/60000]
loss: 0.771630 [38400/60000]
loss: 0.759616 [44800/60000]
loss: 0.723761 [51200/60000]
loss: 0.680475 [57600/60000]
Test Error:
Accuracy: 76.2%, Avg loss: 0.010562
Epoch 3
-------------------------------
loss: 0.618100 [ 0/60000]
loss: 0.711628 [ 6400/60000]
loss: 0.482961 [12800/60000]
loss: 0.697245 [19200/60000]
loss: 0.599393 [25600/60000]
loss: 0.565808 [32000/60000]
loss: 0.636101 [38400/60000]
loss: 0.662630 [44800/60000]
loss: 0.658170 [51200/60000]
loss: 0.592210 [57600/60000]
Test Error:
Accuracy: 79.2%, Avg loss: 0.009311
Epoch 4
-------------------------------
loss: 0.499205 [ 0/60000]
loss: 0.607933 [ 6400/60000]
loss: 0.407313 [12800/60000]
loss: 0.617116 [19200/60000]
loss: 0.550129 [25600/60000]
loss: 0.520347 [32000/60000]
loss: 0.568434 [38400/60000]
loss: 0.638106 [44800/60000]
loss: 0.633640 [51200/60000]
loss: 0.529101 [57600/60000]
Test Error:
Accuracy: 80.1%, Avg loss: 0.008692
Epoch 5
-------------------------------
loss: 0.432519 [ 0/60000]
loss: 0.554164 [ 6400/60000]
loss: 0.369696 [12800/60000]
loss: 0.566110 [19200/60000]
loss: 0.507076 [25600/60000]
loss: 0.491527 [32000/60000]
loss: 0.530983 [38400/60000]
loss: 0.634127 [44800/60000]
loss: 0.613094 [51200/60000]
loss: 0.486707 [57600/60000]
Test Error:
Accuracy: 80.8%, Avg loss: 0.008325
Epoch 6
-------------------------------
loss: 0.388891 [ 0/60000]
loss: 0.522734 [ 6400/60000]
loss: 0.346615 [12800/60000]
loss: 0.533556 [19200/60000]
loss: 0.476073 [25600/60000]
loss: 0.472929 [32000/60000]
loss: 0.504632 [38400/60000]
loss: 0.626713 [44800/60000]
loss: 0.593248 [51200/60000]
loss: 0.462189 [57600/60000]
Test Error:
Accuracy: 81.5%, Avg loss: 0.008075
Epoch 7
-------------------------------
loss: 0.358328 [ 0/60000]
loss: 0.502070 [ 6400/60000]
loss: 0.328464 [12800/60000]
loss: 0.513722 [19200/60000]
loss: 0.455949 [25600/60000]
loss: 0.460615 [32000/60000]
loss: 0.483159 [38400/60000]
loss: 0.615375 [44800/60000]
loss: 0.576040 [51200/60000]
loss: 0.448871 [57600/60000]
Test Error:
Accuracy: 81.7%, Avg loss: 0.007917
Epoch 8
-------------------------------
loss: 0.336749 [ 0/60000]
loss: 0.484792 [ 6400/60000]
loss: 0.314682 [12800/60000]
loss: 0.499265 [19200/60000]
loss: 0.437410 [25600/60000]
loss: 0.452041 [32000/60000]
loss: 0.465575 [38400/60000]
loss: 0.603855 [44800/60000]
loss: 0.562545 [51200/60000]
loss: 0.438371 [57600/60000]
Test Error:
Accuracy: 81.8%, Avg loss: 0.007762
Epoch 9
-------------------------------
loss: 0.320475 [ 0/60000]
loss: 0.470900 [ 6400/60000]
loss: 0.303407 [12800/60000]
loss: 0.489299 [19200/60000]
loss: 0.419739 [25600/60000]
loss: 0.447229 [32000/60000]
loss: 0.451173 [38400/60000]
loss: 0.591696 [44800/60000]
loss: 0.547607 [51200/60000]
loss: 0.430787 [57600/60000]
Test Error:
Accuracy: 82.2%, Avg loss: 0.007628
Epoch 10
-------------------------------
loss: 0.306368 [ 0/60000]
loss: 0.458059 [ 6400/60000]
loss: 0.292609 [12800/60000]
loss: 0.480203 [19200/60000]
loss: 0.404263 [25600/60000]
loss: 0.440233 [32000/60000]
loss: 0.437055 [38400/60000]
loss: 0.579413 [44800/60000]
loss: 0.535025 [51200/60000]
loss: 0.425971 [57600/60000]
Test Error:
Accuracy: 82.8%, Avg loss: 0.007468
Epoch 11
-------------------------------
loss: 0.294402 [ 0/60000]
loss: 0.445635 [ 6400/60000]
loss: 0.283190 [12800/60000]
loss: 0.470698 [19200/60000]
loss: 0.390130 [25600/60000]
loss: 0.434453 [32000/60000]
KeyboardInterrupt: training was interrupted manually during epoch 11.
Exercise: What is the best test accuracy you can achieve?
You can play with the architecture and/or the hyperparameters.
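As one possible starting point (a sketch, not the notebook's solution: the Adam optimizer and the learning rate below are illustrative choices), you could re-initialize the model and train with a different optimizer; you could similarly widen the hidden layers in NeuralNetwork:

# Illustrative variation: same training/evaluation loop, different optimizer.
model = NeuralNetwork()  # re-initialize the weights
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)  # illustrative learning rate

for t in range(10):
    train(train_dataloader, model, loss_fn, optimizer)
    test(test_dataloader, model)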
4. Making predictions#
Once the model is trained, we can use it to make predictions. We'll use the argmax function to get the index of the maximum value in the model's output; this index corresponds to the predicted class of clothing.
classes = [
"T-shirt/top",
"Trouser",
"Pullover",
"Dress",
"Coat",
"Sandal",
"Shirt",
"Sneaker",
"Bag",
"Ankle boot",
]
model.eval()
example_idx = 13
x, y = test_data[example_idx][0], test_data[example_idx][1]
with torch.no_grad():
pred = model(x)
predicted, actual = classes[pred[0].argmax(0)], classes[y]
print(f'Predicted: "{predicted}", Actual: "{actual}"')
# plot the image
plt.imshow(x.squeeze(), cmap="gray")
Predicted: "Sandal", Actual: "Sneaker"

Exercise: Can you find three examples of misclassified items?
Are the mistakes reasonable?
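A minimal sketch for this exercise (the variable names are illustrative): scan the test set, keep the first three indices where the prediction disagrees with the label, and plot them.

# Collect the first three misclassified test examples.
model.eval()
misclassified = []
with torch.no_grad():
    for idx in range(len(test_data)):
        x, y = test_data[idx]
        pred = model(x).argmax(1).item()
        if pred != y:
            misclassified.append((idx, pred, y))
        if len(misclassified) == 3:
            break

# Plot them with the predicted and actual class names.
for idx, pred, y in misclassified:
    img, _ = test_data[idx]
    plt.figure()
    plt.title(f'Predicted: "{classes[pred]}", Actual: "{classes[y]}"')
    plt.axis("off")
    plt.imshow(img.squeeze(), cmap="gray")
plt.show()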