Transfer Learning
Introduction
Struggling to train complex models because high-quality data and compute are limited? 🤔 Don’t panic! The answer is Transfer Learning. This technique leverages pre-trained models, like BERT for NLP or ImageNet-trained networks for image classification, to significantly reduce training time. Think of it as teaching an old dog new tricks: you can easily adapt an ImageNet model to tasks like dog breed classification. And voila! You’ve just made quick progress even with scarce data.
How does Neural Network change in Transfer Learning?
Imagine a chef who is skilled in baking cakes. Suppose this chef needs to cook a new dish, like pasta. Instead of starting from scratch, they leverage their culinary skills, adjusting only where necessary for the pasta-specific nuances.
Similarly, two approaches exist in machine learning: “Training from Scratch” and “Transfer Learning”. In the former, a model such as a CNN is trained on a new dataset, like vehicles, without any prior knowledge. In the latter, the model leverages prior knowledge acquired from a different dataset, like animals, and adapts this understanding to the new task.
The image above illustrates this concept. As shown, a model trained from scratch (top) is set up to learn directly from the vehicle dataset. It starts with no inherent understanding of images and must learn from the ground up the features that differentiate one vehicle from another.
In contrast, a transfer learning model (bottom) begins with a network pre-trained on a different dataset, carrying pre-existing knowledge about different animals. This model is fine-tuned to distinguish different types of vehicles, typically achieving faster and more efficient results than training from scratch.
In essence, while both models aim to classify different types of vehicles, they learn differently: the model trained from scratch learns all features independently, like a chef learning a new dish from scratch, whereas the transfer learning model refines existing knowledge for the new task, similar to a chef adapting their existing skills to a new recipe.
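To make the contrast concrete, here is a minimal PyTorch sketch of the two setups. It uses torchvision’s ResNet-18 purely as an illustrative stand-in (the hands-on sections below use a much simpler model), and the 5-class vehicle head is a hypothetical example:

import torch.nn as nn
from torchvision import models

# Training from scratch: random weights, every feature must be learned
scratch_model = models.resnet18(weights=None)
scratch_model.fc = nn.Linear(scratch_model.fc.in_features, 5)  # e.g. 5 vehicle classes

# Transfer learning: start from ImageNet weights; optionally freeze the early
# layers (the general "prior knowledge") and train only the new head
pretrained_model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
for param in pretrained_model.parameters():
    param.requires_grad = False  # keep the learned features fixed
pretrained_model.fc = nn.Linear(pretrained_model.fc.in_features, 5)  # new, trainable head

Freezing is optional: the MNIST-to-FashionMNIST example later in this article simply continues training all weights, which is another common form of fine-tuning.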
Hands-On with MNIST
Let’s see transfer learning in action on the famous MNIST dataset, a large collection of handwritten digits. We will use PyTorch, a powerful open-source machine learning library for Python, as we have done in previous PyTorch materials.
Data Loader
First, we have to load the data and inspect it. But how can you load a large dataset in manageable batches? PyTorch’s DataLoader makes this possible and efficient, which is especially useful when the data is too large to fit into memory all at once.

import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Load MNIST data
mnist_train = datasets.MNIST(root='./data', train=True, transform=transforms.ToTensor(), download=True)
mnist_test = datasets.MNIST(root='./data', train=False, transform=transforms.ToTensor(), download=True)

# Inspect data
print(mnist_train)
print(mnist_test)

# Use Data Loader
train_loader = DataLoader(mnist_train, batch_size=100, shuffle=True)
test_loader = DataLoader(mnist_test, batch_size=100, shuffle=False)

# Iterate through data
for images, labels in train_loader:
    print('Image batch dimensions:', images.shape)
    print('Image label dimensions:', labels.shape)
    break
Visualizing the Data
Visualizing our data can help us understand it better. But what if you’re not sure how to display images from your dataset? No worries! With matplotlib, a powerful plotting library in Python, we can easily visualize our images.
import matplotlib.pyplot as plt
import numpy as np

# Define label mapping
classes = [str(i) for i in range(10)]  # because MNIST has 10 classes, digits 0 through 9

# Load a batch of images
images, labels = next(iter(train_loader))

# Convert images to numpy for visualization
images = images.numpy()

# Convert images from 1 channel to 3 channels for visualization
images = np.repeat(images, 3, axis=1)

# Plot the images with their labels
fig = plt.figure(figsize=(25, 4))

# Display 20 images
for idx in range(20):
    ax = fig.add_subplot(2, 20//2, idx+1, xticks=[], yticks=[])
    plt.imshow(np.transpose(images[idx], (1, 2, 0)))
    ax.set_title(classes[labels[idx]])
OK, now let’s build and train a model that we define ourselves.
import torch.nn as nn

# Define model
class CreateModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(784, 100)
        self.hidden = nn.Linear(100, 10)

    def forward(self, xb):
        # Flatten the image tensors using reshape
        xb = xb.reshape(-1, 784)
        out = self.linear(xb)
        out = self.hidden(out)
        return out
# Instantiate the model
model = CreateModel()

# Define loss function
loss_fn = nn.CrossEntropyLoss()

# Define optimizer
learning_rate = 1e-2
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

# Define accuracy function
def accuracy(outputs, labels):
    _, preds = torch.max(outputs, dim=1)
    return torch.tensor(torch.sum(preds == labels).item() / len(preds))
# Train
for epoch in range(20):
    for images, labels in train_loader:
        # Generate predictions
        outputs = model(images)
        loss = loss_fn(outputs, labels)
        # Perform gradient descent
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print('Epoch [{}/{}], Loss: {:.4f}'.format(epoch+1, 20, loss.item()))
# Evaluate
with torch.no_grad():
    accum_acc = 0
    for images, labels in test_loader:
        outputs = model(images)
        loss = loss_fn(outputs, labels)
        acc = accuracy(outputs, labels)
        accum_acc += acc

    print('Test loss: {:.4f}, Test accuracy: {:.4f}'.format(loss.item(), accum_acc/len(test_loader)))
Making Individual Predictions
Building and training a model is great, but how do we make predictions on individual images? What if you’re not sure how to use your newly trained model? The predict_image function provides a straightforward way to get your model’s prediction for a single image.
import matplotlib.pyplot as plt

def predict_image(img, model):
    xb = img.unsqueeze(0)            # add a batch dimension
    yb = model(xb)                   # get model outputs
    _, preds = torch.max(yb, dim=1)  # pick the most likely class
    return preds[0].item()

img, label = mnist_test[4]  # Explore the data with index
plt.imshow(img[0], cmap='gray')
print('Label:', label, ', Predicted:', predict_image(img, model))
Fine-tuning with FashionMNIST
Now, what if we want to classify different types of clothing, not just digits? Transfer learning to the rescue! We can fine-tune our pre-trained model on the new data to achieve great results.
That’s where the FashionMNIST dataset comes in. It’s a dataset of Zalando’s article images, consisting of a training set of 60,000 examples and a test set of 10,000 examples. Each example is a 28x28 grayscale image, associated with a label from 10 classes.
Sounds daunting? Don’t worry! With transfer learning, we can leverage the experience our model gained from MNIST to tackle this new task.
Data Loader
As a first step, we load the FashionMNIST data just as we did for MNIST.
# Load FashionMNIST data
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

fashion_train = datasets.FashionMNIST(root='./data', train=True, transform=transforms.ToTensor(), download=True)
fashion_test = datasets.FashionMNIST(root='./data', train=False, transform=transforms.ToTensor(), download=True)

# Inspect data
print(fashion_train)
print(fashion_test)

# Use Data Loader
fashion_train_loader = DataLoader(fashion_train, batch_size=100, shuffle=True)
fashion_test_loader = DataLoader(fashion_test, batch_size=100, shuffle=False)

# Iterate through data
for images, labels in fashion_train_loader:
    print('Image batch dimensions:', images.shape)
    print('Image label dimensions:', labels.shape)
    break
Visualizing the FashionMNIST Data with Class Labels
Visualizing data is helpful, but wouldn’t it be even better if we could visualize it with the corresponding class labels? What if you’re not sure how to map the labels of your data to their actual class names? That’s where the classes list comes into play.
import matplotlib.pyplot as plt
import numpy as np

# Define label mapping for FashionMNIST
classes = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
           'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']

images, labels = next(iter(fashion_train_loader))
images = images.numpy()
images = np.repeat(images, 3, axis=1)
fig = plt.figure(figsize=(25, 4))

# Display 20 images
for idx in range(20):
    ax = fig.add_subplot(2, 20//2, idx+1, xticks=[], yticks=[])
    plt.imshow(np.transpose(images[idx], (1, 2, 0)))
    ax.set_title(f"{labels[idx]}. {classes[labels[idx]]}")
Testing the Model’s Predictions on Unseen FashionMNIST Data
After training our model on the MNIST dataset, we might wonder how it would perform on the FashionMNIST dataset without any further training. What if we could use our model to make a prediction on a FashionMNIST image? By calling the predict_image function, we can do exactly that!
import matplotlib.pyplot as plt

img, label = fashion_test[6]  # Explore the data with index
plt.imshow(img[0], cmap='gray')
print('Label:', label, ', Predicted:', predict_image(img, model))
Unfortunately, without any further training, the model struggles to correctly classify images from the FashionMNIST dataset. This result may seem disappointing, but it is not entirely surprising. The MNIST and FashionMNIST datasets, despite sharing the same structure, represent completely different kinds of images (handwritten digits versus clothing items), and it’s a tough ask for a model trained specifically on digits to accurately classify clothing items.
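To quantify that struggle, here is a small optional check, a minimal sketch that reuses the model, accuracy function, and fashion_test_loader defined above to measure the MNIST-trained model’s accuracy on the FashionMNIST test set (for a digits-only model, this typically lands close to chance, around 0.1):

# Baseline: accuracy on FashionMNIST before any fine-tuning
with torch.no_grad():
    accum_acc = 0
    for images, labels in fashion_test_loader:
        outputs = model(images)
        accum_acc += accuracy(outputs, labels)
    print('Accuracy before fine-tuning: {:.4f}'.format(accum_acc/len(fashion_test_loader)))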
This is exactly where the power of transfer learning shines! Let’s start training with the model architecture that we used for MNIST but this time, we’ll fine-tune it on the FashionMNIST data. With this approach, our model can quickly learn to generalize from digits to clothing items, showcasing the strength of transfer learning in practice.
# Fine-tuning with FashionMNIST data
for epoch in range(10):
    for images, labels in fashion_train_loader:
        # Generate predictions
        outputs = model(images)
        loss = loss_fn(outputs, labels)

        # Perform gradient descent
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print('Epoch [{}/{}], Loss: {:.4f}'.format(epoch+1, 10, loss.item()))
# Evaluate with FashionMNIST test data
with torch.no_grad():
    accum_acc = 0
    for images, labels in fashion_test_loader:
        outputs = model(images)
        loss = loss_fn(outputs, labels)
        acc = accuracy(outputs, labels)
        accum_acc += acc

    print('Test loss: {:.4f}, Test accuracy: {:.4f}'.format(loss.item(), accum_acc/len(fashion_test_loader)))
Testing the Model’s Predictions After Fine-Tuning
After fine-tuning our model on the FashionMNIST dataset, we need to verify whether it improved the model’s performance. What if we could test our model’s prediction on a FashionMNIST image again, this time after fine-tuning? That’s precisely where the predict_image function comes in handy.
import matplotlib.pyplot as plt

label_map = {
    0: "T-shirt/top",
    1: "Trouser",
    2: "Pullover",
    3: "Dress",
    4: "Coat",
    5: "Sandal",
    6: "Shirt",
    7: "Sneaker",
    8: "Bag",
    9: "Ankle boot",
}

img, label = fashion_test[1]  # Explore the data with index
plt.imshow(img[0], cmap='gray')
pred = predict_image(img, model)
print(f"Label: {label} ({label_map[label]}), Predicted: {pred} ({label_map[pred]})")
Exploring Further Into Fine-Tuning and Use Cases
After testing our model on FashionMNIST data, we should now deepen our understanding of fine-tuning. This technique is not limited to image classification tasks; it can be applied across various domains, including:
- Natural Language Processing (NLP)
  - Text Classification
  - Summarization
  - And more
- Audio
  - Audio classification
  - Automatic speech recognition
  - And more
- Computer Vision
  - Image Classification
  - Object detection
  - And more
- Multimodal
  - Image Captioning
  - Document Question Answering
  - And more
One such example is audio classification. By using a pre-trained model like Wav2Vec2 and fine-tuning it on a specific dataset, we can create a powerful audio classification model even with less training data. This process includes several steps such as loading the dataset, preprocessing the data, setting up an evaluation metric, and finally training and evaluating the model.
Let’s see how we can apply fine-tuning to develop an audio classification model and test its performance!
Audio Classification
Imagine you are a wildlife conservationist with a library of sounds, and you want to classify which animal species each recording comes from. With so many audio files, analyzing them manually would take enormous time and effort!
But don’t worry! With the power of AI, we can make this process significantly easier and faster using a technique called Audio Classification.
Tasks: Audio Classification (source: youtube.com)
Audio Classification can be used for various applications, including detecting a speaker’s intent, classifying languages, or, in our case, identifying animal species by their sounds.
For this purpose, we are going to utilize the pre-trained Wav2Vec2 model and fine-tune it on the MInDS-14 dataset specifically designed for this task. This approach reduces our training time and improves the model’s performance even with less data!
Before we start, we need to make sure to have all the necessary Python libraries installed. Run this command in your Python environment:
%pip install transformers datasets evaluate
We also encourage you to log in to your Hugging Face account so you can upload and share your model with the community. When prompted, enter your token to log in:
from huggingface_hub import notebook_login
notebook_login()
Load MInDS-14 dataset
Now we are ready to load our dataset, MInDS-14, from the Datasets library:
from datasets import load_dataset, Audio
= load_dataset("PolyAI/minds14", name="en-US", split="train") minds
We will then split our dataset into a smaller training and testing set. This step allows us to experiment and validate our model before spending more time on the full dataset.
minds = minds.train_test_split(test_size=0.2)
Then take a look at the dataset:
minds
While our dataset contains a lot of useful information, we will focus on the audio and intent_class in this guide. Let’s remove the other columns:
minds = minds.remove_columns(["path", "transcription", "english_transcription", "lang_id"])
Now let’s take a look at an example in our dataset:
"train"][0] minds[
Great! Now we have our dataset loaded and ready. Its fields are:
- audio: a 1-dimensional array of the speech signal that we will use for training our model.
- intent_class: the class id of the speaker’s intent (or, in our case, the species of the animal).
To make it easier for the model to get the label name from the label id, we create a dictionary that maps the label name to an integer and vice versa:
= minds["train"].features["intent_class"].names
labels = dict(), dict()
label2id, id2label for i, label in enumerate(labels):
= str(i)
label2id[label] str(i)] = label id2label[
Now we can easily convert a label id to a label name:
id2label[str(2)]
Preprocess
Now comes the interesting part: Preprocessing. In this step, we are going to load a Wav2Vec2 feature extractor to process the audio signal:
from transformers import AutoFeatureExtractor

feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base")
The MInDS-14 dataset has a sampling rate of 8,000 Hz (8 kHz), as noted in its dataset card. The pre-trained Wav2Vec2 model, however, requires audio input with a sampling rate of 16,000 Hz (16 kHz). Therefore, we must resample our dataset from 8 kHz to 16 kHz to meet the model’s requirements.
minds = minds.cast_column("audio", Audio(sampling_rate=16_000))
minds["train"][0]
We then create a preprocessing function that applies the feature extractor to the audio arrays, using the sampling rate the model was pre-trained with and truncating longer inputs:
def preprocess_function(examples):
    audio_arrays = [x["array"] for x in examples["audio"]]
    inputs = feature_extractor(
        audio_arrays, sampling_rate=feature_extractor.sampling_rate, max_length=16000, truncation=True
    )
    return inputs
We use the Datasets map function to apply the preprocessing function across the complete dataset, and speed it up by enabling batched=True, which processes multiple dataset elements simultaneously. We also remove the raw audio column, which the model no longer needs:
encoded_minds = minds.map(preprocess_function, remove_columns="audio", batched=True)
After applying the preprocessing function, we rename the intent_class column to label, because the model expects this name:

encoded_minds = encoded_minds.rename_column("intent_class", "label")
Setting up Evaluation Metric
Now, how can we know if our model is performing well? We need a yardstick to measure it. That’s where Evaluation Metrics come into play. We will use the accuracy metric from the evaluate library for this task.
import evaluate

accuracy = evaluate.load("accuracy")
Then, we create a function that will take our model’s predictions and labels to calculate the accuracy:
import numpy as np

def compute_metrics(eval_pred):
    predictions = np.argmax(eval_pred.predictions, axis=1)
    return accuracy.compute(predictions=predictions, references=eval_pred.label_ids)
Now, our compute_metrics function is ready to be used for training.
Training and Evaluation
We’ve come a long way! It’s time to Train and Evaluate our model. We’re going to load the Wav2Vec2 model, specifying the number of expected labels and the label mappings:
from transformers import AutoModelForAudioClassification

num_labels = len(id2label)
model = AutoModelForAudioClassification.from_pretrained(
    "facebook/wav2vec2-base", num_labels=num_labels, label2id=label2id, id2label=id2label
)
During training, we set a learning rate and batch size as part of our optimization strategy. We also set the load_best_model_at_end=True option, which means our trainer will load the checkpoint with the highest accuracy at the end of training:
from transformers import TrainingArguments, Trainer

training_args = TrainingArguments(
    output_dir="model/audio_classification",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=3e-5,  # learning rate
    per_device_train_batch_size=32,  # training batch size
    gradient_accumulation_steps=4,
    per_device_eval_batch_size=32,
    num_train_epochs=10,
    warmup_ratio=0.1,
    logging_steps=10,
    load_best_model_at_end=True,  # reload the best checkpoint at the end of training
    metric_for_best_model="accuracy",
    push_to_hub=True,  # when True, the trained model is uploaded to the Hugging Face Model Hub
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=encoded_minds["train"],
    eval_dataset=encoded_minds["test"],
    tokenizer=feature_extractor,
    compute_metrics=compute_metrics,
)
trainer.train()
Once training is completed, we can share our model on the Hugging Face Model Hub so that everyone can use it:
trainer.push_to_hub()
Here is an example of a model that has been trained: https://huggingface.co/aditira/audio_classification
Inference
Finally, let’s test our model with a new audio file. This step is called Inference. We will load an audio file and run our model to classify it. Remember to resample the audio file to match the model’s sampling rate!
from datasets import load_dataset, Audio
from transformers import pipeline

dataset = load_dataset("PolyAI/minds14", name="en-US", split="train")
dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))
sampling_rate = dataset.features["audio"].sampling_rate
audio_file = dataset[0]["audio"]["path"]

classifier = pipeline("audio-classification", model="model/audio_classification")
classifier(audio_file)
Congratulations! You just classified an audio file using your fine-tuned model! Now, you can classify your entire library of animal sounds and make your work as a wildlife conservationist a bit easier.
Image Classification
Imagine you’ve just been hired as a data scientist at a health-focused tech startup. You’re tasked with developing a system that identifies food items from images uploaded by users. The goal is to enable users to track their nutrition intake simply by taking a picture. But how would you go about this?
Well, the good news is, with the advent of Image Classification in machine learning, this task is no longer a pipe dream, but a very achievable reality.
Tasks: Image Classification (source: youtube.com)
Image classification can be used in countless applications, ranging from detecting objects in satellite images to medical imaging. In our case, we aim to classify food items from given images. We will use a pre-trained model called the Vision Transformer, or ViT for short. ViT is a model that applies the transformer architecture, which was initially built for text data, to image data. We will fine-tune this model on the Food-101 dataset to classify images into 101 food categories.
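To give a feel for how ViT treats an image like text, here is a tiny illustrative sketch (not part of the fine-tuning pipeline below) of its core idea: the image is cut into fixed-size patches, and each flattened patch becomes one “token” that a standard transformer can consume:

import torch

# A 224x224 RGB image becomes a sequence of 16x16 patches ("visual tokens")
img = torch.randn(1, 3, 224, 224)
patches = img.unfold(2, 16, 16).unfold(3, 16, 16)  # (1, 3, 14, 14, 16, 16)
patches = patches.permute(0, 2, 3, 1, 4, 5).reshape(1, 14 * 14, 3 * 16 * 16)
print(patches.shape)  # torch.Size([1, 196, 768]): 196 tokens, each of dimension 768

Those 196 patch embeddings play the same role that word embeddings play in a text transformer, which is what lets ViT reuse the transformer architecture largely unchanged.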
Load Food-101 dataset
Let’s start with Loading our Food-101 dataset. We will load only a small subset of it to ensure everything works before committing to training on the full dataset:
from datasets import load_dataset

food = load_dataset("food101", split="train[:5000]")
We then split our dataset into a training set and a testing set:
food = food.train_test_split(test_size=0.2)
Let’s take a look at an example from our dataset:
"train"][0] food[
Each example in the dataset has two fields:
- image: a PIL image of the food item
- label: the label class of the food item
To make it easier for the model to get the label name from the label id, we again create a dictionary that maps the label name to an integer and vice versa:
= food["train"].features["label"].names
labels = dict(), dict()
label2id, id2label for i, label in enumerate(labels):
= str(i)
label2id[label] str(i)] = label id2label[
Now we can easily convert a label id to a label name:
id2label[str(79)]
Preprocess
The next step is Preprocessing. Here we load a ViT image processor to process the image into a tensor:
from transformers import AutoImageProcessor

checkpoint = "google/vit-base-patch16-224-in21k"
image_processor = AutoImageProcessor.from_pretrained(checkpoint)
To make our model robust against overfitting, we apply some image transformations to the images. Here we are using torchvision’s transforms module, but you can also use any image library you like. We crop a random part of the image, resize it, and normalize it with the image mean and standard deviation:
from torchvision.transforms import RandomResizedCrop, Compose, Normalize, ToTensor

normalize = Normalize(mean=image_processor.image_mean, std=image_processor.image_std)
size = (
    image_processor.size["shortest_edge"]
    if "shortest_edge" in image_processor.size
    else (image_processor.size["height"], image_processor.size["width"])
)
_transforms = Compose([RandomResizedCrop(size), ToTensor(), normalize])
Then we create a preprocessing function to apply the transforms and return the pixel_values (the inputs to the model) for each image:
def transforms(examples):
    examples["pixel_values"] = [_transforms(img.convert("RGB")) for img in examples["image"]]
    del examples["image"]
    return examples
To apply the preprocessing function over the entire dataset, we use the Datasets with_transform method. The transforms are applied on the fly when an element of the dataset is loaded:
food = food.with_transform(transforms)
Next, we create a batch of examples using DefaultDataCollator. Unlike other data collators, the DefaultDataCollator does not apply additional preprocessing such as padding.
from transformers import DefaultDataCollator

data_collator = DefaultDataCollator()
Setting up Evaluation Metric
Including a metric during training is beneficial for evaluating your model’s performance. We can load an evaluation method with the Evaluate library:
import evaluate

accuracy = evaluate.load("accuracy")
Now, we’ll create a function that will take our model’s predictions and labels to calculate the accuracy:
import numpy as np

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    return accuracy.compute(predictions=predictions, references=labels)
Our compute_metrics function is ready now, and we’ll use it when we set up our training.
Training and Evaluation
It’s time to Train and Evaluate our model. We’re going to load ViT with AutoModelForImageClassification, specifying the number of expected labels along with the label mappings:
from transformers import AutoModelForImageClassification

model = AutoModelForImageClassification.from_pretrained(
    checkpoint,
    num_labels=len(labels),
    id2label=id2label,
    label2id=label2id,
)
At this point, we’ll set our training hyperparameters in TrainingArguments. We’ll make sure to set remove_unused_columns=False to keep the image column, which is crucial for creating pixel_values. We’ll specify output_dir to save our model and enable push_to_hub to upload it to Hugging Face. The Trainer will evaluate the accuracy and save a checkpoint after each epoch:
from transformers import TrainingArguments, Trainer

training_args = TrainingArguments(
    output_dir="model/image_classification",
    remove_unused_columns=False,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=5e-5,  # learning rate
    per_device_train_batch_size=16,  # training batch size
    gradient_accumulation_steps=4,
    per_device_eval_batch_size=16,
    num_train_epochs=3,
    warmup_ratio=0.1,
    logging_steps=10,
    load_best_model_at_end=True,  # reload the best checkpoint at the end of training
    metric_for_best_model="accuracy",
    push_to_hub=True,  # when True, the trained model is uploaded to the Hugging Face Model Hub
)

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=food["train"],
    eval_dataset=food["test"],
    tokenizer=image_processor,
    compute_metrics=compute_metrics,
)
trainer.train()
Evaluating the Trained Model
After the training is completed, we need to evaluate the model. We’ll use the accuracy metric from the evaluate library to assess the model’s performance.
eval_result = trainer.evaluate(eval_dataset=food["test"])
print(eval_result)
The evaluation results will give us the accuracy of the test set. If this accuracy is satisfactory, we could decide to publish the model. If not, we might need to revisit the preprocessing, model architecture, or training process (e.g., tuning hyperparameters and increasing the number of epochs).
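If we do need another pass, a common first adjustment is to train longer with a smaller learning rate. As a purely illustrative sketch (these values are hypothetical, not tuned), we can update the existing training_args and rerun the Trainer:

# Illustrative second run: lower learning rate, more epochs
training_args.learning_rate = 2e-5
training_args.num_train_epochs = 6

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=food["train"],
    eval_dataset=food["test"],
    tokenizer=image_processor,
    compute_metrics=compute_metrics,
)
trainer.train()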
Once the training is completed, we can share our model on the Hugging Face Model Hub so it can be accessed by anyone:
trainer.push_to_hub()
Here is an example of a model that has been trained: https://huggingface.co/aditira/image_classification
Inference
With the model trained, we can now make use of it for Inference. Let’s load an image we’d like to run inference on:
= load_dataset("food101", split="validation[:10]")
ds = ds["image"][0] image
The simplest way to try out our fine-tuned model for inference is to use it in a pipeline(). Instantiate a pipeline for image classification with our model, and pass our image to it:
from transformers import pipeline

classifier = pipeline("image-classification", model="model/image_classification")
classifier(image)
Congratulations! You’ve managed to classify a food item from an image, bringing our startup’s vision one step closer to reality. Now you can continue to refine and apply this model, making nutrition tracking easier and more accessible for users around the world.
Conclusion: The Power of Fine-Tuning and the Right Choice of Pre-trained Models
Think of learning as building a house. Starting from scratch, you’d need to lay down the foundation, put up the walls, install the plumbing and wiring, and then finally add the finishing touches like paint and furniture.
Building a House (source: kutaisitoday.com)
Similarly, training a model from scratch means learning all the features and representations from the ground up, which often demands a large amount of data and computational resources, and that is not always feasible.
Now, what if you could start with a house that’s already built and just rearrange the furniture and repaint the walls to suit your taste? This is the idea behind transfer learning or fine-tuning. We start with a model that has already learned useful features from a large-scale dataset (the pre-built house), and then fine-tune it on our specific task (rearranging and repainting).
In our case, the ‘house’ is the Vision Transformer (ViT), pre-trained on the ImageNet dataset. It’s a robust and versatile model, well-suited for various image classification tasks. But it’s not specialized in identifying food items.
So, we give it a makeover. We fine-tune the ViT on our Food-101 dataset. The model keeps its acquired knowledge and tailors it to our specific purpose - classifying food images. It’s like hiring an interior designer (the fine-tuning process) to transform a generic house (the pre-trained ViT) into a gourmet restaurant (the food-classification model).
But why ViT? Why not another model? 🤔
The choice of a pre-trained model is crucial. It should ideally have an architecture that’s well-suited to your task and has been trained on a large, diverse dataset. ViT is a good choice because it brings the power of transformers, which can capture complex dependencies in data, into the realm of vision. It’s been pre-trained on ImageNet, a large-scale, diverse dataset, allowing it to learn a wide variety of features.
Yet it’s important to remember that there is no one-size-fits-all model. The choice would depend on your task, data, and computational resources.