import math
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
import wandb

device = "cuda" if torch.cuda.is_available() else "cpu"
# Define CNN architecture (example - customize as needed)
class CNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)        # in channels, out channels, kernel size
        self.pool = nn.MaxPool2d(2, 2)         # kernel size, stride
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)  # in features, out features
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)           # one output per CIFAR-10 class

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)  # flatten for the fully-connected layers
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
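
# Optional sanity check (added for illustration, not part of the training
# pipeline): a dummy CIFAR-10-sized batch confirms the 16 * 5 * 5 flatten size
# (32 -> 28 after conv1, -> 14 after pooling, -> 10 after conv2, -> 5 after pooling).
_dummy = torch.randn(1, 3, 32, 32)
assert CNN()(_dummy).shape == (1, 10)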
def validate_model(model, valid_dl, loss_func, log_images=False, batch_idx=0):
    "Compute performance of the model on the validation dataset and log a wandb.Table"
    model.eval()
    val_loss = 0.0
    correct = 0
    with torch.inference_mode():
        for i, (images, labels) in enumerate(valid_dl):
            images, labels = images.to(device), labels.to(device)
            # Forward pass ➡
            outputs = model(images)
            val_loss += loss_func(outputs, labels).item() * labels.size(0)
            # Compute accuracy and accumulate
            _, predicted = torch.max(outputs, 1)
            correct += (predicted == labels).sum().item()
            # Log one batch of images to the dashboard, always the same batch_idx
            if i == batch_idx and log_images:
                log_image_table(images, predicted, labels, outputs.softmax(dim=1))
    return val_loss / len(valid_dl.dataset), correct / len(valid_dl.dataset)
def log_image_table(images, predicted, labels, probs):
    "Log a wandb.Table with (img, pred, target, scores)"
    # 🐝 Create a wandb Table to log images, labels and predictions to
    table = wandb.Table(columns=["image", "pred", "target"] + [f"score_{i}" for i in range(10)])
    for img, pred, targ, prob in zip(images.to("cpu"), predicted.to("cpu"), labels.to("cpu"), probs.to("cpu")):
        # Un-normalize the 3-channel image back to [0, 1] so it displays correctly
        table.add_data(wandb.Image(img * 0.5 + 0.5), pred, targ, *prob.numpy())
    wandb.log({"predictions_table": table}, commit=False)  # commits with the next wandb.log call
# Initialize the W&B run
wandb.init(
    project="cifar10-cnn-wandb",
    config={
        "epochs": 5,
        "batch_size": 128,
        "lr": 2e-3,
        "dropout": random.uniform(0.01, 0.80),
        "threshold_accuracy": 0.8,
    },
    save_code=True,
)
config = wandb.config
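
# Note: config.dropout is tracked above but the CNN defined earlier never uses
# it. A minimal sketch of wiring it in, assuming dropout belongs between the
# fully-connected layers (an assumption; this subclass is illustrative, not
# part of the original model):
class DropoutCNN(CNN):
    def __init__(self, p):
        super().__init__()
        self.drop = nn.Dropout(p)  # dropout probability taken from the W&B config

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = self.drop(F.relu(self.fc1(x)))
        x = self.drop(F.relu(self.fc2(x)))
        return self.fc3(x)
# Hypothetical usage: model = DropoutCNN(config.dropout).to(device)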
# Load and transform CIFAR-10 data
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
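
# CIFAR-10's ten class names are fixed by the dataset; a lookup like this
# (not used by the original script, which logs raw label indices) can make
# the logged predictions easier to read:
classes = ("airplane", "automobile", "bird", "cat", "deer",
           "dog", "frog", "horse", "ship", "truck")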
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=config.batch_size, shuffle=True)
testloader = torch.utils.data.DataLoader(testset, batch_size=config.batch_size, shuffle=False)
n_steps_per_epoch = math.ceil(len(trainloader.dataset) / config.batch_size)
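# Sanity check (illustrative): with drop_last=False, the DataLoader length is
# already ceil(dataset_size / batch_size), so the two counts must agree.
assert n_steps_per_epoch == len(trainloader)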
# Define model, optimizer, and loss function
model = CNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=config.lr)
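
# 🐝 Optional: ask W&B to record gradients and parameter histograms during
# training. wandb.watch is part of the standard W&B API; log_freq=100 is an
# arbitrary choice here, not something the original script sets.
wandb.watch(model, criterion, log="all", log_freq=100)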
# Training loop
for epoch in range(config.epochs):
    running_loss = 0.0
    example_ct = 0
    step_ct = 0
    for step, (images, labels) in enumerate(trainloader):
        inputs, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        example_ct += len(inputs)
        metrics = {
            "train/train_loss": loss.item(),
            "train/epoch": (step + 1 + (n_steps_per_epoch * epoch)) / n_steps_per_epoch,
            "train/example_ct": example_ct,
        }
        if step + 1 < n_steps_per_epoch:
            # 🐝 Log train metrics to wandb; the final step of the epoch is
            # logged together with the validation metrics below
            wandb.log(metrics)
        step_ct += 1
    # Validate once per epoch; log prediction images only on the final epoch
    val_loss, accuracy = validate_model(model, testloader, criterion, log_images=(epoch == (config.epochs - 1)))
    val_metrics = {
        "val/val_loss": val_loss,
        "val/val_accuracy": accuracy,
    }
    # 🐝 Log the last train metrics and the validation metrics in one step
    wandb.log({**metrics, **val_metrics})
    print(f"Train Loss: {loss:.3f}, Valid Loss: {val_loss:.3f}, Accuracy: {accuracy:.2f}")
    # 🐝 Fire a W&B alert when accuracy drops below the configured threshold
    if accuracy < config.threshold_accuracy:
        wandb.alert(
            title="Low Accuracy",
            text=f"Accuracy {accuracy} at step {step_ct} is below the acceptable threshold, {config.threshold_accuracy}",
        )

Train Loss: 1.429, Valid Loss: 1.425, Accuracy: 0.48
Train Loss: 1.476, Valid Loss: 1.268, Accuracy: 0.54
Train Loss: 1.086, Valid Loss: 1.173, Accuracy: 0.59
Train Loss: 1.096, Valid Loss: 1.135, Accuracy: 0.60
Train Loss: 1.137, Valid Loss: 1.126, Accuracy: 0.60
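
# Before closing the run, the trained weights can be versioned with the W&B
# Artifacts API (a sketch; the artifact name "cifar10-cnn" and the file name
# "model.pth" are arbitrary choices, not from the original script):
torch.save(model.state_dict(), "model.pth")
artifact = wandb.Artifact("cifar10-cnn", type="model")
artifact.add_file("model.pth")
wandb.log_artifact(artifact)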
wandb.finish()

Run history:
| train/epoch      | ▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███ |
| train/example_ct | ▁▂▃▄▅▆▆▇▁▂▃▄▅▆▇█▁▂▃▄▅▆▇█▁▂▃▄▅▆▇█▁▂▃▄▅▆▇█ |
| train/train_loss | █▆▅▅▅▄▄▄▃▄▄▄▄▄▄▂▃▃▃▃▂▃▃▄▂▁▃▂▂▃▂▂▂▂▁▃▃▂▂▂ |
| val/val_accuracy | ▁▅▇██ |
| val/val_loss     | █▄▂▁▁ |

Run summary:
| train/epoch      | 5.0     |
| train/example_ct | 50000   |
| train/train_loss | 1.13669 |
| val/val_accuracy | 0.6026  |
| val/val_loss     | 1.12644 |
View run polished-armadillo-4 at: https://wandb.ai/ayamerushia/cifar10-cnn-wandb/runs/f1m7acxw
Synced 6 W&B file(s), 1 media file(s), 129 artifact file(s) and 0 other file(s)
Find logs at: ./wandb/run-20240307_093756-f1m7acxw/logs