The classic “Hello World” of image classification is training a CNN on the MNIST dataset. The dataset contains 60,000 training images (plus 10,000 test images), where each image is a crude 28 x 28 pixel grayscale handwritten digit from “0” to “9.”
Regular NN on MNIST dataset
Now, can we create an image classifier using a regular neural network (without a CNN)? Yes, we can. In fact, we already did it back when we were studying Machine Learning/Deep Learning. Here’s the code:
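(The original snippet isn’t reproduced here, so below is a minimal sketch of such a network, assuming a single hidden layer of 128 units and the same SGD training setup as the CNN version later; the exact lesson code may have differed in its details.)

```python
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# A plain feedforward (fully connected) classifier for MNIST
class NNModel(nn.Module):
    def __init__(self):
        super(NNModel, self).__init__()
        self.fc1 = nn.Linear(28 * 28, 128)  # each 28x28 image is flattened into a 784-vector
        self.fc2 = nn.Linear(128, 10)       # one output per digit class

    def forward(self, x):
        x = x.view(x.size(0), -1)           # flatten: all spatial structure is discarded here
        x = torch.relu(self.fc1(x))
        return self.fc2(x)

transform = transforms.ToTensor()
mnist_train = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
mnist_test = datasets.MNIST(root='./data', train=False, transform=transform, download=True)
train_loader = DataLoader(mnist_train, batch_size=100, shuffle=True)
test_loader = DataLoader(mnist_test, batch_size=100, shuffle=False)

model = NNModel()
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

# Training loop
for epoch in range(5):
    for images, labels in train_loader:
        loss = loss_fn(model(images), labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

# Evaluate on the test set
correct = 0
with torch.no_grad():
    for images, labels in test_loader:
        preds = model(images).argmax(dim=1)
        correct += (preds == labels).sum().item()
print('Test accuracy: {:.2%}'.format(correct / len(mnist_test)))
```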
Hey, it works, and it even reaches about 91% accuracy. There’s no problem, right?
Well, on simple images like those in the MNIST dataset, which contain only black-and-white pixels and simple shapes, that’s true. However, the images we encounter in the real world are far more complex and diverse in terms of colors, textures, and objects.
To tackle these challenges effectively, specialized neural network architectures like Convolutional Neural Networks (CNNs) have emerged as the preferred choice, as they are designed to capture spatial hierarchies, local features, and patterns, making them well-suited for the diverse and intricate nature of real-world images.
Note:
- Spatial Hierarchies: The network can learn to recognize patterns, shapes, and structures in an image in a hierarchical manner. This involves identifying simple features (such as edges and corners) at lower levels and progressively combining them to recognize more complex and abstract objects or concepts at higher levels.
- Local Features: The network can identify and focus on specific regions or elements within an image that are relevant for recognition or classification. These local features can be small patterns, textures, or details within an image that contribute to the overall understanding of what the image represents. (A tiny convolution example illustrating this follows below.)
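To make “local features” concrete, here is a tiny standalone sketch (the 6 x 6 image and hand-crafted edge kernel are made up for illustration): a convolution only ever looks at a small neighborhood as it slides across the image, so a single 3 x 3 kernel can act as an edge detector.

```python
import torch
import torch.nn.functional as F

# A hand-crafted 3x3 vertical-edge kernel: it only ever "sees" a 3x3
# neighborhood as it slides over the image -- that's a local feature.
kernel = torch.tensor([[[[-1., 0., 1.],
                         [-1., 0., 1.],
                         [-1., 0., 1.]]]])  # shape: (out_ch=1, in_ch=1, 3, 3)

image = torch.zeros(1, 1, 6, 6)   # a dummy 6x6 grayscale "image"
image[:, :, :, 3:] = 1.0          # right half bright -> a vertical edge at column 3

features = F.conv2d(image, kernel)
print(features.squeeze())         # strong responses along the edge, zeros elsewhere
```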
CNN on MNIST dataset
Let’s try converting the code above to its CNN version:
```python
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Define a CNN model for MNIST
class CNNModel(nn.Module):
    def __init__(self):
        super(CNNModel, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
        self.fc1 = nn.Linear(64 * 5 * 5, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = torch.relu(self.conv1(x))   # 28x28 -> 26x26
        x = torch.max_pool2d(x, 2)      # 26x26 -> 13x13
        x = torch.relu(self.conv2(x))   # 13x13 -> 11x11
        x = torch.max_pool2d(x, 2)      # 11x11 -> 5x5
        x = x.view(x.size(0), -1)       # flatten to (batch, 64*5*5)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# Define data transformations
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# Load MNIST data
mnist_train = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
mnist_test = datasets.MNIST(root='./data', train=False, transform=transform, download=True)

# Use Data Loader
train_loader = DataLoader(mnist_train, batch_size=100, shuffle=True)
test_loader = DataLoader(mnist_test, batch_size=100, shuffle=False)

# Instantiate the CNN model
cnn_model = CNNModel()

# Define loss function and optimizer
loss_fn = nn.CrossEntropyLoss()
learning_rate = 0.01
optimizer = optim.SGD(cnn_model.parameters(), lr=learning_rate)

# Define accuracy function
def accuracy(outputs, labels):
    _, preds = torch.max(outputs, dim=1)
    return torch.tensor(torch.sum(preds == labels).item() / len(preds))

# Training loop
total_epochs = 5
for epoch in range(total_epochs):
    for images, labels in train_loader:
        outputs = cnn_model(images)
        loss = loss_fn(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print('Epoch [{}/{}], Loss: {:.4f}'.format(epoch + 1, total_epochs, loss.item()))

# Evaluation
cnn_model.eval()
with torch.no_grad():
    accum_acc = 0
    for images, labels in test_loader:
        outputs = cnn_model(images)
        loss = loss_fn(outputs, labels)
        acc = accuracy(outputs, labels)
        accum_acc += acc
    print('Test loss: {:.4f}, Test accuracy: {:.4f}'.format(loss.item(), accum_acc / len(test_loader)))
```
Our new code defines a CNN model with two convolutional layers followed by fully connected layers. It also normalizes the data using a mean of 0.5 and a standard deviation of 0.5, which maps the pixel values from [0, 1] to [-1, 1]. Normalizing ensures that the input values have a consistent scale, which helps stabilize training, as large or skewed input values can lead to unstable gradients during backpropagation.
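To see concretely what that Normalize transform does, here is a tiny standalone check (the sample pixel values are made up for illustration):

```python
import torch
from torchvision import transforms

norm = transforms.Normalize((0.5,), (0.5,))  # (x - 0.5) / 0.5, per channel
pixels = torch.tensor([[[0.0, 0.5, 1.0]]])   # ToTensor() output lies in [0, 1]
print(norm(pixels))                          # tensor([[[-1., 0., 1.]]]) -> range [-1, 1]
```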
Still not convinced? Well, you can try modifying the code above to use the CIFAR10 dataset, which you can find on Hugging Face. The CIFAR10 dataset presents a more complex challenge compared to MNIST due to its color images (32 x 32 pixel RGB) and diverse set of object categories (including animals, vehicles, and everyday objects).
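As a rough sketch of what that modification involves (using torchvision’s built-in CIFAR10 loader rather than the Hugging Face copy; the linked notebook may do this differently), the key changes are the number of input channels and the flattened feature size:

```python
import torch.nn as nn
from torchvision import datasets, transforms

# CIFAR10 images are RGB, so normalize all three channels
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

cifar_train = datasets.CIFAR10(root='./data', train=True, transform=transform, download=True)
cifar_test = datasets.CIFAR10(root='./data', train=False, transform=transform, download=True)

# In CNNModel, the first conv must accept 3 input channels, and fc1 must
# match the new feature-map size:
# 32x32 -> conv(3x3) -> 30x30 -> pool -> 15x15 -> conv(3x3) -> 13x13 -> pool -> 6x6
conv1 = nn.Conv2d(3, 32, kernel_size=3)
fc1 = nn.Linear(64 * 6 * 6, 128)
```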
We’ll skip the CIFAR10 notebook, but if you are interested in the result, you can visit this notebook: NN vs CNN CIFAR10