import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
import wandb
import random
= "cuda" if torch.cuda.is_available() else "cpu"
device
# Define CNN architecture (example - customize as needed)
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)        # Input channels, output channels, kernel size
        self.pool = nn.MaxPool2d(2, 2)         # Kernel size, stride (optional)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)  # Input features, output features
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)           # Output for 10 CIFAR-10 classes

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)  # Flatten for fully-connected layers
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
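The 16 * 5 * 5 input to fc1 follows from the shapes: 32x32 inputs shrink to 28x28 after conv1 (5x5 kernel, no padding), to 14x14 after pooling, to 10x10 after conv2, and to 5x5 after the second pool, with 16 channels. A quick sanity check with a dummy batch (shapes only, no training involved):

# Shape check with a random CIFAR-10-sized batch: 4 images, 3 channels, 32x32.
_model = CNN()
_dummy = torch.randn(4, 3, 32, 32)
print(_model(_dummy).shape)  # torch.Size([4, 10]) -- one logit per class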
def validate_model(model, valid_dl, loss_func, log_images=False, batch_idx=0):
    "Compute performance of the model on the validation dataset and log a wandb.Table"
    model.eval()
    val_loss = 0.
    with torch.inference_mode():
        correct = 0
        for i, (images, labels) in enumerate(valid_dl):
            images, labels = images.to(device), labels.to(device)

            # Forward pass ➡
            outputs = model(images)
            val_loss += loss_func(outputs, labels) * labels.size(0)

            # Compute accuracy and accumulate
            _, predicted = torch.max(outputs.data, 1)
            correct += (predicted == labels).sum().item()

            # Log one batch of images to the dashboard, always same batch_idx.
            if i == batch_idx and log_images:
                log_image_table(images, predicted, labels, outputs.softmax(dim=1))
    return val_loss / len(valid_dl.dataset), correct / len(valid_dl.dataset)
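Note that each batch's mean loss is scaled by labels.size(0) before being accumulated, so the final division by len(valid_dl.dataset) gives a true per-example average even when the last batch is smaller. A worked example with hypothetical batch sizes:

# Hypothetical: two batches of 128 and 72 examples with mean losses 0.9 and 1.1.
# Weighting by batch size before averaging:
total = 0.9 * 128 + 1.1 * 72   # 194.4
print(total / (128 + 72))      # 0.972, vs. the naive (0.9 + 1.1) / 2 = 1.0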
def log_image_table(images, predicted, labels, probs):
    "Log a wandb.Table with (img, pred, target, scores)"
    # 🐝 Create a wandb Table to log images, labels and predictions to
    table = wandb.Table(columns=["image", "pred", "target"] + [f"score_{i}" for i in range(10)])
    for img, pred, targ, prob in zip(images.to("cpu"), predicted.to("cpu"), labels.to("cpu"), probs.to("cpu")):
        table.add_data(wandb.Image(img[0].numpy() * 255), pred, targ, *prob.numpy())
    wandb.log({"predictions_table": table}, commit=False)
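As written, wandb.Image receives only channel 0 of a tensor that was normalized to [-1, 1], so the logged images are single-channel and off-scale. A minimal sketch of an alternative, assuming the Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) transform used below; the loop above would then call wandb.Image(to_viewable(img)) instead:

# Sketch: undo the (0.5, 0.5, 0.5) normalization and restore HWC layout
# so the full RGB image is logged. Assumes img is a normalized CHW tensor.
def to_viewable(img):
    img = img * 0.5 + 0.5                # [-1, 1] -> [0, 1]
    return img.permute(1, 2, 0).numpy()  # CHW -> HWC for wandb.Image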
import math

# Initialize Wandb project
wandb.init(
    project="cifar10-cnn-wandb",
    config={
        "epochs": 5,
        "batch_size": 128,
        "lr": 2e-3,
        "dropout": random.uniform(0.01, 0.80),
        "threshold_accuracy": 0.8
    },
    save_code=True)

config = wandb.config

# Load and transform CIFAR-10 data
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

trainloader = torch.utils.data.DataLoader(trainset, batch_size=config.batch_size, shuffle=True)
testloader = torch.utils.data.DataLoader(testset, batch_size=config.batch_size, shuffle=False)

n_steps_per_epoch = math.ceil(len(trainloader.dataset) / config.batch_size)
Finishing last run (ID:xivbf7yh) before initializing another...
View run eternal-pine-3 at: https://wandb.ai/ayamerushia/cifar10-cnn-wandb/runs/xivbf7yh
Synced 6 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)
Find logs at: ./wandb/run-20240307_093720-xivbf7yh/logs
Successfully finished last run (ID:xivbf7yh). Initializing new run:
Tracking run with wandb version 0.16.4
Run data is saved locally in /Users/fa-15566/ruangguru/github_rg/research-pribadi/bootcamp-ai/wandb/wandb/run-20240307_093756-f1m7acxw
View project at https://wandb.ai/ayamerushia/cifar10-cnn-wandb
Files already downloaded and verified
Files already downloaded and verified
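With CIFAR-10's 50,000 training images and a batch size of 128, n_steps_per_epoch works out to ceil(50000 / 128) = 391; a quick assertion makes the expectation explicit:

# Sanity check: 50,000 training images at batch_size=128 gives
# ceil(50000 / 128) = 391 steps per epoch (drop_last defaults to False).
assert n_steps_per_epoch == len(trainloader) == 391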
# Define model, optimizer, and loss function
model = CNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=config.lr)
# Training loop
for epoch in range(config.epochs):
    running_loss = 0.0
    example_ct = 0
    step_ct = 0
    for step, (images, labels) in enumerate(trainloader, 0):
        inputs, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)

        loss.backward()
        optimizer.step()

        running_loss += loss.item()

        example_ct += len(inputs)
        metrics = {"train/train_loss": loss,
                   "train/epoch": (step + 1 + (n_steps_per_epoch * epoch)) / n_steps_per_epoch,
                   "train/example_ct": example_ct}
        if step + 1 < n_steps_per_epoch:
            # 🐝 Log train metrics to wandb
            wandb.log(metrics)
        step_ct += 1

    val_loss, accuracy = validate_model(model, testloader, criterion, log_images=(epoch == (config.epochs - 1)))
    val_metrics = {
        "val/val_loss": val_loss,
        "val/val_accuracy": accuracy
    }
    # 🐝 Log the last batch of train metrics together with validation metrics
    wandb.log({**metrics, **val_metrics})
    print(f"Train Loss: {loss:.3f}, Valid Loss: {val_loss:.3f}, Accuracy: {accuracy:.2f}")

    # 🐝 Alert when validation accuracy drops below the configured threshold
    if accuracy < config.threshold_accuracy:
        wandb.alert(
            title='Low Accuracy',
            text=f'Accuracy {accuracy} at step {step_ct} is below the acceptable threshold, {config.threshold_accuracy}',
        )
Train Loss: 1.429, Valid Loss: 1.424600, Accuracy: 0.48
Train Loss: 1.476, Valid Loss: 1.267869, Accuracy: 0.54
Train Loss: 1.086, Valid Loss: 1.173016, Accuracy: 0.59
Train Loss: 1.096, Valid Loss: 1.134961, Accuracy: 0.60
Train Loss: 1.137, Valid Loss: 1.126435, Accuracy: 0.60
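The wandb.alert call above fires whenever validation accuracy falls below threshold_accuracy, which happens in every epoch of this run (all accuracies are under 0.8). W&B alerts also accept a severity level and a rate limit; a hedged sketch using the wandb.AlertLevel enum and the wait_duration parameter, with accuracy and config taken from the loop above:

from wandb import AlertLevel

# Sketch: tag the alert with a severity and rate-limit repeats to at
# most one every 300 seconds.
wandb.alert(
    title='Low Accuracy',
    text=f'Accuracy {accuracy:.3f} fell below {config.threshold_accuracy}',
    level=AlertLevel.WARN,
    wait_duration=300,
)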
wandb.finish()
Run history:
train/epoch | ▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███ |
train/example_ct | ▁▂▃▄▅▆▆▇▁▂▃▄▅▆▇█▁▂▃▄▅▆▇█▁▂▃▄▅▆▇█▁▂▃▄▅▆▇█ |
train/train_loss | █▆▅▅▅▄▄▄▃▄▄▄▄▄▄▂▃▃▃▃▂▃▃▄▂▁▃▂▂▃▂▂▂▂▁▃▃▂▂▂ |
val/val_accuracy | ▁▅▇██ |
val/val_loss | █▄▂▁▁ |
Run summary:
train/epoch | 5.0 |
train/example_ct | 50000 |
train/train_loss | 1.13669 |
val/val_accuracy | 0.6026 |
val/val_loss | 1.12644 |
View run polished-armadillo-4 at: https://wandb.ai/ayamerushia/cifar10-cnn-wandb/runs/f1m7acxw
Synced 6 W&B file(s), 1 media file(s), 129 artifact file(s) and 0 other file(s)
Find logs at: ./wandb/run-20240307_093756-f1m7acxw/logs