Transfer Learning End-to-End

Goal: Fine-tune a pretrained ResNet on CIFAR-10 with PyTorch. Feature extraction vs full fine-tuning. From loading the model to evaluating results.

Prerequisites: Transfer Learning, Convolutional Neural Networks, Optimizers, Data Augmentation


Why Transfer Learning?

Training a CNN from scratch on a small dataset gives ~60-70% accuracy. Using a model pretrained on ImageNet (1.2M images) as a starting point gives 90%+ with a fraction of the compute.


Setup

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models
import matplotlib.pyplot as plt
import numpy as np
 
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using: {device}")

Data: CIFAR-10

10 classes, 32x32 images. We resize to 224x224 (ResNet’s expected input):

# The pretrained ResNet is fed 224x224 inputs normalized with the
# per-channel ImageNet mean / standard deviation.
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]

# Training pipeline: upscale, then augment (random flip + padded random crop)
# before converting to a normalized tensor.
train_transform = transforms.Compose([
    transforms.Resize(224),
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(224, padding=8),
    transforms.ToTensor(),
    transforms.Normalize(imagenet_mean, imagenet_std),
])

# Evaluation pipeline: same preprocessing without any random augmentation.
test_transform = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize(imagenet_mean, imagenet_std),
])

# Both splits share the same on-disk location and are downloaded on demand.
cifar_location = {"root": "./data", "download": True}
train_data = datasets.CIFAR10(train=True, transform=train_transform, **cifar_location)
test_data = datasets.CIFAR10(train=False, transform=test_transform, **cifar_location)

# Use a subset for faster experimentation: the first 5000 training images
# and the first 1000 test images.
train_subset = torch.utils.data.Subset(train_data, range(5000))
test_subset = torch.utils.data.Subset(test_data, range(1000))

# Only the training loader shuffles; evaluation order is irrelevant.
train_loader = DataLoader(
    train_subset,
    batch_size=32,
    shuffle=True,
    num_workers=2,
)
test_loader = DataLoader(
    test_subset,
    batch_size=64,
    num_workers=2,
)

classes = train_data.classes
print(f"Classes: {classes}")
print(f"Train: {len(train_subset)}, Test: {len(test_subset)}")

Method 1: Feature Extraction (Freeze Everything)

Use the pretrained model as a fixed feature extractor. Only train the final classification layer.

def create_model_frozen(n_classes=10):
    """Build a ResNet-18 feature extractor with a fresh classification head.

    The default pretrained weights are loaded, every pretrained parameter is
    frozen, and the final ``fc`` layer is replaced by a new linear layer
    sized for ``n_classes``.

    Args:
        n_classes: Number of output classes for the new head.

    Returns:
        The model, moved to the module-level ``device``.
    """
    backbone = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)

    # Turn off gradients for everything that came with the checkpoint.
    for weight in backbone.parameters():
        weight.requires_grad_(False)

    # Swap in a new head (the pretrained one has 1000 ImageNet outputs).
    # A freshly constructed layer has requires_grad=True, so it stays trainable.
    feature_dim = backbone.fc.in_features
    backbone.fc = nn.Linear(feature_dim, n_classes)

    return backbone.to(device)
 
model_frozen = create_model_frozen()

# Tally parameters: all of them vs. only those that still receive gradients.
param_counts = [(p.numel(), p.requires_grad) for p in model_frozen.parameters()]
total = sum(count for count, _ in param_counts)
trainable = sum(count for count, is_trainable in param_counts if is_trainable)
print(f"Total params: {total:,}")
print(f"Trainable:    {trainable:,} ({100*trainable/total:.1f}%)")

Training loop

def train_epoch(model, loader, criterion, optimizer):
    """Run one optimization pass over ``loader`` and return epoch metrics.

    Each batch is moved to the device that holds the model's parameters, so
    the function works for any model without relying on a module-level
    ``device`` variable.

    Note: ``model.train()`` switches the whole model to training mode. In
    that mode BatchNorm layers update their running statistics on every
    forward pass, including layers whose parameters have
    ``requires_grad=False``.

    Args:
        model: ``nn.Module`` to train.
        loader: Iterable yielding ``(inputs, targets)`` batches.
        criterion: Callable ``criterion(outputs, targets)`` returning a
            scalar loss. It is weighted by batch size here, i.e. treated as
            a per-batch mean.
        optimizer: Optimizer holding the parameters to update.

    Returns:
        Tuple ``(mean_loss, accuracy)``, both averaged over all samples seen.

    Raises:
        ZeroDivisionError: If ``loader`` yields no samples.
    """
    model.train()

    # Follow the model's own placement; a parameter-free model leaves the
    # batches wherever the loader produced them.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else None

    total_loss, correct, total = 0, 0, 0
    for X, y in loader:
        if target_device is not None:
            X, y = X.to(target_device), y.to(target_device)

        optimizer.zero_grad()
        out = model(X)
        loss = criterion(out, y)
        loss.backward()
        optimizer.step()

        batch_size = len(y)
        # Re-weight the batch-mean loss so a smaller final batch does not
        # skew the epoch average.
        total_loss += loss.item() * batch_size
        correct += (out.argmax(1) == y).sum().item()
        total += batch_size
    return total_loss / total, correct / total
 
def evaluate(model, loader):
    """Return the classification accuracy of ``model`` over ``loader``.

    The model is switched to evaluation mode and gradients are disabled.
    Each batch is moved to the device that holds the model's parameters, so
    the function works for any model without relying on a module-level
    ``device`` variable.

    Args:
        model: ``nn.Module`` producing per-class scores of shape
            ``(batch, n_classes)``; the prediction is the argmax over dim 1.
        loader: Iterable yielding ``(inputs, targets)`` batches.

    Returns:
        Fraction of samples whose predicted class equals the target.

    Raises:
        ZeroDivisionError: If ``loader`` yields no samples.
    """
    model.eval()

    # Follow the model's own placement; a parameter-free model leaves the
    # batches wherever the loader produced them.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else None

    correct, total = 0, 0
    with torch.no_grad():
        for X, y in loader:
            if target_device is not None:
                X, y = X.to(target_device), y.to(target_device)
            predictions = model(X).argmax(1)
            correct += (predictions == y).sum().item()
            total += len(y)
    return correct / total
 
# Feature-extraction run: the optimizer is handed only the new head's
# parameters; the rest of the network was frozen when the model was built.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model_frozen.fc.parameters(), lr=1e-3)

print("Feature extraction (frozen backbone):")
frozen_history = []  # one (train loss, train accuracy, test accuracy) tuple per epoch
for epoch in range(5):
    loss, train_acc = train_epoch(model_frozen, train_loader, criterion, optimizer)
    test_acc = evaluate(model_frozen, test_loader)
    print(f"  Epoch {epoch+1}: loss={loss:.4f}, train_acc={train_acc:.3f}, test_acc={test_acc:.3f}")
    frozen_history += [(loss, train_acc, test_acc)]

Method 2: Fine-Tuning (Unfreeze and Train)

Load a fresh pretrained model, keep every layer trainable, and give the pretrained backbone a much smaller learning rate than the new head:

def create_model_finetune(n_classes=10):
    """Build a fully trainable pretrained ResNet-18 with a fresh head.

    Nothing is frozen: the pretrained layers keep their default trainable
    state and only the final ``fc`` layer is replaced.

    Args:
        n_classes: Number of output classes for the new head.

    Returns:
        The model, moved to the module-level ``device``.
    """
    net = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
    feature_dim = net.fc.in_features
    net.fc = nn.Linear(feature_dim, n_classes)
    net = net.to(device)
    return net
 
model_finetune = create_model_finetune()

# Different learning rates: small for pretrained, larger for new head.
# The two groups are split by parameter name, not by position: slicing
# parameters()[:-2] is only correct while the head is exactly one Linear
# layer with a bias. With any other head it would either drop a backbone
# tensor from training or list a head tensor in both groups.
head_params = list(model_finetune.fc.parameters())
backbone_params = [
    param
    for name, param in model_finetune.named_parameters()
    if not name.startswith("fc.")
]
optimizer = optim.Adam([
    {"params": head_params, "lr": 1e-3},      # new head: fast
    {"params": backbone_params, "lr": 1e-5},  # backbone: slow
])

print("\nFull fine-tuning (different LRs):")
finetune_history = []  # one (train loss, train accuracy, test accuracy) tuple per epoch
for epoch in range(5):
    loss, train_acc = train_epoch(model_finetune, train_loader, criterion, optimizer)
    test_acc = evaluate(model_finetune, test_loader)
    finetune_history.append((loss, train_acc, test_acc))
    print(f"  Epoch {epoch+1}: loss={loss:.4f}, train_acc={train_acc:.3f}, test_acc={test_acc:.3f}")

Method 3: Gradual Unfreezing

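Unfreeze the network in stages, from the top (the new head) down toward the input: first the head alone, then the last residual block, then every layer: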

def create_model_gradual(n_classes=10):
    """Build the starting point for gradual unfreezing.

    Loads the default pretrained ResNet-18, freezes every pretrained
    parameter, and attaches a new (trainable) linear head for ``n_classes``.

    Args:
        n_classes: Number of output classes for the new head.

    Returns:
        The model, moved to the module-level ``device``.
    """
    net = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
    # Freeze the whole pretrained network in one call.
    net.requires_grad_(False)
    # The replacement head is created after freezing, so it stays trainable.
    net.fc = nn.Linear(net.fc.in_features, n_classes)
    return net.to(device)
 
model_gradual = create_model_gradual()

# Each phase widens the trainable part of the network and lowers the
# learning rate: (label to print, module to unfreeze first, Adam lr).
# Phase 1 names the head, which is already trainable, so nothing changes there.
unfreeze_schedule = [
    ("Phase 1: head only", model_gradual.fc, 1e-3),
    ("Phase 2: + layer4", model_gradual.layer4, 1e-4),
    ("Phase 3: everything", model_gradual, 1e-5),
]

print("\nGradual unfreezing:")
for phase_label, module_to_unfreeze, phase_lr in unfreeze_schedule:
    module_to_unfreeze.requires_grad_(True)
    # Fresh optimizer per phase, covering exactly the parameters that are
    # trainable at this point.
    optimizer = optim.Adam(
        [p for p in model_gradual.parameters() if p.requires_grad],
        lr=phase_lr,
    )
    print(phase_label)
    for epoch in range(3):
        loss, train_acc = train_epoch(model_gradual, train_loader, criterion, optimizer)
        test_acc = evaluate(model_gradual, test_loader)
        print(f"  Epoch {epoch+1}: loss={loss:.4f}, train_acc={train_acc:.3f}, test_acc={test_acc:.3f}")

Compare Methods

# Side-by-side curves: training loss on the left, test accuracy on the right.
fig, axes = plt.subplots(1, 2, figsize=(12, 5))
loss_ax, acc_ax = axes

# Each history entry is a (train loss, train accuracy, test accuracy) tuple.
runs = {"Frozen": frozen_history, "Fine-tune": finetune_history}
for label, history in runs.items():
    losses = [epoch_loss for epoch_loss, _, _ in history]
    test_accs = [epoch_test_acc for _, _, epoch_test_acc in history]
    loss_ax.plot(losses, label=label)
    acc_ax.plot(test_accs, label=label)

loss_ax.set_xlabel("Epoch")
loss_ax.set_ylabel("Loss")
loss_ax.set_title("Training loss")
loss_ax.legend()

acc_ax.set_xlabel("Epoch")
acc_ax.set_ylabel("Test Accuracy")
acc_ax.set_title("Test accuracy")
acc_ax.legend()

plt.tight_layout()
plt.show()

Visualize Predictions

# Predict one test batch with the fine-tuned model and show the first ten images.
model_finetune.eval()
images, labels = next(iter(test_loader))
with torch.no_grad():
    logits = model_finetune(images.to(device))
preds = logits.argmax(1).cpu()

channel_std = np.array(imagenet_std)
channel_mean = np.array(imagenet_mean)

fig, axes = plt.subplots(2, 5, figsize=(14, 6))
for i, ax in enumerate(axes.ravel()):
    # Channels-first tensor -> channels-last array, which is what imshow takes.
    img = images[i].permute(1, 2, 0).numpy()
    # Undo the normalization, then clamp to the displayable [0, 1] range.
    img = np.clip(img * channel_std + channel_mean, 0, 1)
    ax.imshow(img)

    # Green title for a correct prediction, red for a miss.
    if preds[i] == labels[i]:
        color = "green"
    else:
        color = "red"
    ax.set_title(f"pred={classes[preds[i]]}\ntrue={classes[labels[i]]}", color=color, fontsize=9)
    ax.axis("off")
plt.tight_layout()
plt.show()

When to Use Which Strategy

| Data size | Similarity to ImageNet | Strategy |
|---|---|---|
| Small (<1k) | High | Feature extraction (freeze all) |
| Small (<1k) | Low | Feature extraction, maybe unfreeze last block |
| Medium (1k-10k) | High | Fine-tune with small LR |
| Medium (1k-10k) | Low | Gradual unfreezing |
| Large (>10k) | Any | Fine-tune everything |

Exercises

  1. Without pretrained weights: Train the same ResNet18 from scratch (random init). Compare accuracy after 10 epochs. This shows exactly what pretraining gives you.

  2. Different backbone: Replace ResNet18 with EfficientNet-B0 (models.efficientnet_b0). Is it better? Smaller?

  3. Learning rate finder: For each unfreezing phase, sweep LR from 1e-6 to 1e-1 and plot loss vs LR to find the optimal rate.

  4. Feature visualization: Extract features from the penultimate layer (before fc) and apply t-SNE. Do the 10 classes form clusters? Compare frozen vs fine-tuned features.


Next: 13 - Fine-Tune BERT for Classification — transfer learning for text.