import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
import wandb
import random
import math

device = "cuda" if torch.cuda.is_available() else "cpu"


# Define CNN architecture (example - customize as needed)
class CNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)  # Input channels, output channels, kernel size
        self.pool = nn.MaxPool2d(2, 2)  # Kernel size, stride (optional)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)  # Input features, output features
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10) # Output for 10 CIFAR-10 classes

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)  # Flatten all dimensions except batch for the fully-connected layers
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
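
# Why 16 * 5 * 5 inputs to fc1? A 32x32 CIFAR-10 image becomes 28x28 after
# conv1 (kernel 5, no padding), 14x14 after pooling, 10x10 after conv2, and
# 5x5 after the second pool. A quick sanity check of that arithmetic
# (illustrative only, not part of training):
with torch.no_grad():
    _net = CNN()
    _x = _net.pool(F.relu(_net.conv1(torch.zeros(1, 3, 32, 32))))
    _x = _net.pool(F.relu(_net.conv2(_x)))
    assert _x.shape == (1, 16, 5, 5)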

def validate_model(model, valid_dl, loss_func, log_images=False, batch_idx=0):
    "Compute performance of the model on the validation dataset and log a wandb.Table"
    model.eval()
    val_loss = 0.
    with torch.inference_mode():
        correct = 0
        for i, (images, labels) in enumerate(valid_dl):
            images, labels = images.to(device), labels.to(device)

            # Forward pass ➡
            outputs = model(images)
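            # Weight each batch's mean loss by its size so the final division by
            # len(dataset) gives a true per-example average (the last batch may be smaller)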
            val_loss += loss_func(outputs, labels).item() * labels.size(0)

            # Compute accuracy and accumulate
            _, predicted = torch.max(outputs.data, 1)
            correct += (predicted == labels).sum().item()

            # Log one batch of images to the dashboard, always same batch_idx.
            if i==batch_idx and log_images:
                log_image_table(images, predicted, labels, outputs.softmax(dim=1))
    return val_loss / len(valid_dl.dataset), correct / len(valid_dl.dataset)

def log_image_table(images, predicted, labels, probs):
    "Log a wandb.Table with (img, pred, target, scores)"
    # 🐝 Create a wandb Table to log images, labels and predictions to
    table = wandb.Table(columns=["image", "pred", "target"]+[f"score_{i}" for i in range(10)])
    for img, pred, targ, prob in zip(images.to("cpu"), predicted.to("cpu"), labels.to("cpu"), probs.to("cpu")):
        # Undo the (0.5, 0.5) normalization and convert CHW -> HWC for display
        table.add_data(wandb.Image((img * 0.5 + 0.5).permute(1, 2, 0).numpy() * 255), pred, targ, *prob.numpy())
    wandb.log({"predictions_table":table}, commit=False)
# Initialize a W&B run
wandb.init(
    project="cifar10-cnn-wandb",
    config={
        "epochs": 5,
        "batch_size": 128,
        "lr": 2e-3,
        "dropout": random.uniform(0.01, 0.80),  # logged for demonstration; not used by this CNN
        "threshold_accuracy": 0.8,
    },
    save_code=True,
)

config = wandb.config

# Load and transform CIFAR-10 data
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
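# With mean 0.5 and std 0.5 per channel, Normalize maps ToTensor's [0, 1] range
# to [-1, 1] via (x - 0.5) / 0.5; log_image_table undoes this with
# img * 0.5 + 0.5 before display.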

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

trainloader = torch.utils.data.DataLoader(trainset, batch_size=config.batch_size, shuffle=True)
testloader = torch.utils.data.DataLoader(testset, batch_size=config.batch_size, shuffle=False)

n_steps_per_epoch = math.ceil(len(trainloader.dataset) / config.batch_size)
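# With drop_last=False (the DataLoader default) this equals len(trainloader),
# so the following holds:
assert n_steps_per_epoch == len(trainloader)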
# Define model, optimizer, and loss function
model = CNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=config.lr)

# Training loop
for epoch in range(config.epochs):
    running_loss = 0.0
    example_ct = 0
    step_ct = 0
    for step, (images, labels) in enumerate(trainloader, 0):
        inputs, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

        example_ct += len(inputs)
        metrics = {"train/train_loss": loss,
                    "train/epoch": (step + 1 + (n_steps_per_epoch * epoch)) / n_steps_per_epoch,
                    "train/example_ct": example_ct}
        if step + 1 < n_steps_per_epoch:
            # 🐝 Log train metrics to wandb
            wandb.log(metrics)
        step_ct += 1

    val_loss, accuracy = validate_model(model, testloader, criterion, log_images=(epoch==(config.epochs-1)))
    val_metrics = {
        "val/val_loss": val_loss,
        "val/val_accuracy": accuracy
        }
    wandb.log({**metrics, **val_metrics})
    print(f"Train Loss: {loss:.3f}, Valid Loss: {val_loss:3f}, Accuracy: {accuracy:.2f}")

    if accuracy < config.threshold_accuracy:
        wandb.alert(
            title='Low Accuracy',
            text=f'Accuracy {accuracy:.2f} at step {step_ct} is below the acceptable threshold of {config.threshold_accuracy}',
        )
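        # As a sketch: wandb.alert also accepts an optional level (e.g.
        # wandb.AlertLevel.WARN) and a wait_duration in seconds to rate-limit
        # repeated alerts.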
Train Loss: 1.429, Valid Loss: 1.424600, Accuracy: 0.48
Train Loss: 1.476, Valid Loss: 1.267869, Accuracy: 0.54
Train Loss: 1.086, Valid Loss: 1.173016, Accuracy: 0.59
Train Loss: 1.096, Valid Loss: 1.134961, Accuracy: 0.60
Train Loss: 1.137, Valid Loss: 1.126435, Accuracy: 0.60
wandb.finish()

Run history:

train/epoch ▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/example_ct ▁▂▃▄▅▆▆▇▁▂▃▄▅▆▇█▁▂▃▄▅▆▇█▁▂▃▄▅▆▇█▁▂▃▄▅▆▇█
train/train_loss █▆▅▅▅▄▄▄▃▄▄▄▄▄▄▂▃▃▃▃▂▃▃▄▂▁▃▂▂▃▂▂▂▂▁▃▃▂▂▂
val/val_accuracy ▁▅▇██
val/val_loss █▄▂▁▁

Run summary:

train/epoch 5.0
train/example_ct 50000
train/train_loss 1.13669
val/val_accuracy 0.6026
val/val_loss 1.12644

View run polished-armadillo-4 at: https://wandb.ai/ayamerushia/cifar10-cnn-wandb/runs/f1m7acxw
Synced 6 W&B file(s), 1 media file(s), 129 artifact file(s) and 0 other file(s)
Find logs at: ./wandb/run-20240307_093756-f1m7acxw/logs