import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
import wandb
import random

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Define CNN architecture (example - customize as needed)
class CNN(nn.Module):
    """Small two-stage convolutional classifier with a three-layer head.

    Each stage is a 5x5 convolution followed by ReLU and a 2x2 max-pool.
    The flattened result (16 * 5 * 5 values per row) goes through three
    linear layers; the last emits 10 raw scores (no softmax is applied).
    A 3-channel 32x32 input yields exactly 16 * 5 * 5 features per sample.
    """

    def __init__(self):
        super().__init__()
        # Convolutional feature extractor: 3 -> 6 -> 16 channels, 5x5 kernels.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)  # shared by both stages
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        # Classifier head: 400 -> 120 -> 84 -> 10.
        self.fc1 = nn.Linear(in_features=16 * 5 * 5, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.fc3 = nn.Linear(in_features=84, out_features=10)

    def forward(self, x):
        """Return the 10 raw class scores for each row of the flattened batch."""
        for conv in (self.conv1, self.conv2):
            x = self.pool(F.relu(conv(x)))
        # Collapse everything into rows of fc1's input width (16 * 5 * 5).
        flat = x.view(-1, self.fc1.in_features)
        hidden = F.relu(self.fc1(flat))
        hidden = F.relu(self.fc2(hidden))
        return self.fc3(hidden)

def validate_model(model, valid_dl, loss_func, log_images=False, batch_idx=0):
    """Compute performance of the model on the validation dataset and log a wandb.Table.

    Args:
        model: ``nn.Module`` to evaluate. Batches are moved to the device of
            its first parameter (CPU if it has no parameters).
        valid_dl: iterable of ``(images, labels)`` batches that also exposes
            ``.dataset``; ``len(valid_dl.dataset)`` is used as the sample count.
        loss_func: callable ``loss_func(outputs, labels)``. Assumed to return
            the batch *mean* (the default for ``nn.CrossEntropyLoss``).
        log_images: when true, the batch at position ``batch_idx`` is passed
            to ``log_image_table``.
        batch_idx: index of the batch to log, so the same batch is logged on
            every call.

    Returns:
        ``(val_loss, accuracy)``: the per-sample mean loss (a 0-dim tensor
        once at least one batch was seen) and the fraction of correct
        predictions as a float.
    """
    # Follow the model's own device instead of depending on a module global.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device("cpu")

    # Evaluate in eval mode, then put the model back the way we found it so a
    # surrounding training loop is not silently left in eval mode.
    was_training = model.training
    model.eval()

    val_loss = 0.
    correct = 0
    try:
        with torch.inference_mode():
            for i, (images, labels) in enumerate(valid_dl):
                images, labels = images.to(target_device), labels.to(target_device)

                # Forward pass ➡
                outputs = model(images)
                # Weight the batch mean by the batch size so the final division
                # gives a true per-sample mean even when the last batch is short.
                val_loss += loss_func(outputs, labels) * labels.size(0)

                # Compute accuracy and accumulate
                predicted = outputs.argmax(dim=1)
                correct += (predicted == labels).sum().item()

                # Log one batch of images to the dashboard, always same batch_idx.
                if i == batch_idx and log_images:
                    log_image_table(images, predicted, labels, outputs.softmax(dim=1))
    finally:
        model.train(was_training)

    n_samples = len(valid_dl.dataset)
    return val_loss / n_samples, correct / n_samples

def log_image_table(images, predicted, labels, probs):
    """Log a wandb.Table with (img, pred, target, scores).

    Args:
        images: batch of image tensors; only channel 0 of each image is
            rendered (``img[0]``).
        predicted: predicted class index per image.
        labels: target class index per image.
        probs: per-class scores, one row per image; its second dimension
            determines how many ``score_<i>`` columns the table gets.

    The table is logged under ``"predictions_table"`` with ``commit=False``.
    """
    # 🐝 Create a wandb Table to log images, labels and predictions to
    n_classes = probs.shape[1]
    score_columns = [f"score_{i}" for i in range(n_classes)]
    table = wandb.Table(columns=["image", "pred", "target"] + score_columns)

    # Everything is moved to the CPU first because .numpy() needs host memory.
    rows = zip(images.to("cpu"), predicted.to("cpu"), labels.to("cpu"), probs.to("cpu"))
    for img, pred, targ, prob in rows:
        # NOTE(review): img[0] shows only the first channel, and any input
        # normalization is not undone before scaling by 255 — confirm this is
        # the intended rendering for multi-channel images.
        table.add_data(wandb.Image(img[0].numpy()*255), pred, targ, *prob.numpy())
    wandb.log({"predictions_table":table}, commit=False)
import math
# Initialize Wandb project
# NOTE(review): the opening of this call was missing from the source, so the
# project name below is a placeholder — replace it with the intended project.
wandb.init(
    project="cifar10-cnn",
    config={
        "epochs": 5,
        "batch_size": 128,
        "lr": 2e-3,
        "dropout": random.uniform(0.01, 0.80),
        "threshold_accuracy": 0.8,
    },
)

# Read hyperparameters back through wandb.config from here on.
config = wandb.config

# Load and transform CIFAR-10 data
transform = transforms.Compose([
    # Normalize operates on tensors, so convert the PIL images first.
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

# Shuffle only the training data; keep the test order fixed so that the same
# batch index always refers to the same test batch.
trainloader = torch.utils.data.DataLoader(trainset, batch_size=config.batch_size, shuffle=True)
testloader = torch.utils.data.DataLoader(testset, batch_size=config.batch_size, shuffle=False)

# Number of batches per epoch (the last batch may be smaller than batch_size).
n_steps_per_epoch = math.ceil(len(trainloader.dataset) / config.batch_size)
# Define model, optimizer, and loss function
model = CNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=config.lr)

# Training loop
# The counters are cumulative over the whole run, so they are initialised once
# here rather than being reset at the start of every epoch.
example_ct = 0
step_ct = 0
for epoch in range(config.epochs):
    model.train()
    running_loss = 0.0
    for step, (images, labels) in enumerate(trainloader, 0):
        inputs, labels = images.to(device), labels.to(device)

        # Standard optimisation step: clear stale gradients, forward,
        # backward, then update the weights.
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

        example_ct += len(inputs)
        metrics = {"train/train_loss": loss,
                    "train/epoch": (step + 1 + (n_steps_per_epoch * epoch)) / n_steps_per_epoch,
                    "train/example_ct": example_ct}
        if step + 1 < n_steps_per_epoch:
            # 🐝 Log train metrics to wandb
            # The last step of the epoch is skipped here; its metrics are
            # logged together with the validation metrics below.
            wandb.log(metrics)
        step_ct += 1

    val_loss, accuracy = validate_model(model, testloader, criterion, log_images=(epoch==(config.epochs-1)))
    val_metrics = {
        "val/val_loss": val_loss,
        "val/val_accuracy": accuracy,
    }
    wandb.log({**metrics, **val_metrics})
    print(f"Train Loss: {loss:.3f}, Valid Loss: {val_loss:3f}, Accuracy: {accuracy:.2f}")

    # Alert when accuracy falls *below* the configured threshold, which is
    # what the alert title and text describe.
    if accuracy < config.threshold_accuracy:
        wandb.alert(
            title='Low Accuracy',
            text=f'Accuracy {accuracy} at step {step_ct} is below the acceptable threshold, {config.threshold_accuracy}',
        )
# Sample output recorded from a run of this script (one line per epoch):
# Train Loss: 1.429, Valid Loss: 1.424600, Accuracy: 0.48
# Train Loss: 1.476, Valid Loss: 1.267869, Accuracy: 0.54
# Train Loss: 1.086, Valid Loss: 1.173016, Accuracy: 0.59
# Train Loss: 1.096, Valid Loss: 1.134961, Accuracy: 0.60
# Train Loss: 1.137, Valid Loss: 1.126435, Accuracy: 0.60

