Transfer Learning End-to-End
Goal: Fine-tune a pretrained ResNet on CIFAR-10 with PyTorch. Feature extraction vs full fine-tuning. From loading the model to evaluating results.
Prerequisites: Transfer Learning, Convolutional Neural Networks, Optimizers, Data Augmentation
Why Transfer Learning?
Training a CNN from scratch on a small dataset gives ~60-70% accuracy. Using a model pretrained on ImageNet (1.2M images) as a starting point gives 90%+ with a fraction of the compute.
Setup
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models
import matplotlib.pyplot as plt
import numpy as np

# Run on GPU when available; every model and batch below is moved here.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using: {device}")
Data: CIFAR-10
10 classes, 32x32 images. We resize to 224x224 (ResNet’s expected input):
# ResNet expects 224x224 inputs, normalized with the ImageNet channel
# statistics the pretrained weights were trained under.
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]
# Training pipeline: upscale 32x32 -> 224, then light augmentation
# (random flip + padded crop), applied on-the-fly each epoch.
train_transform = transforms.Compose([
    transforms.Resize(224),
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(224, padding=8),
    transforms.ToTensor(),
    transforms.Normalize(imagenet_mean, imagenet_std),
])
# Test pipeline: deterministic — resize and normalize only, no augmentation.
test_transform = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize(imagenet_mean, imagenet_std),
])
train_data = datasets.CIFAR10(root="./data", train=True, download=True, transform=train_transform)
test_data = datasets.CIFAR10(root="./data", train=False, download=True, transform=test_transform)
# Use a subset for faster experimentation (5k train / 1k test of the
# full 50k/10k). NOTE(review): range(...) takes the first N samples,
# not a stratified sample — class balance is approximate.
train_subset = torch.utils.data.Subset(train_data, range(5000))
test_subset = torch.utils.data.Subset(test_data, range(1000))
train_loader = DataLoader(train_subset, batch_size=32, shuffle=True, num_workers=2)
test_loader = DataLoader(test_subset, batch_size=64, num_workers=2)
# Human-readable class names, shared by the prediction plots below.
classes = train_data.classes
print(f"Classes: {classes}")
print(f"Train: {len(train_subset)}, Test: {len(test_subset)}")
Method 1: Feature Extraction (Freeze Everything)
Use the pretrained model as a fixed feature extractor. Only train the final classification layer.
def create_model_frozen(n_classes=10):
    """Build a ResNet-18 used as a fixed feature extractor.

    The pretrained backbone is frozen; only the replacement `fc` head
    (randomly initialized, hence trainable by default) receives
    gradients, so an optimizer over `model.fc.parameters()` trains
    exactly the head.
    """
    model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
    # Freeze every pretrained parameter in one call.
    model.requires_grad_(False)
    # Swap the 1000-way ImageNet head for an n_classes head; the fresh
    # Linear layer comes with requires_grad=True.
    model.fc = nn.Linear(model.fc.in_features, n_classes)
    return model.to(device)
model_frozen = create_model_frozen()
# Count trainable parameters: with the backbone frozen, only the new fc
# head (in_features x 10 weights + 10 biases) should show up as trainable.
total = sum(p.numel() for p in model_frozen.parameters())
trainable = sum(p.numel() for p in model_frozen.parameters() if p.requires_grad)
print(f"Total params: {total:,}")
print(f"Trainable: {trainable:,} ({100*trainable/total:.1f}%)")
Training loop
def train_epoch(model, loader, criterion, optimizer, device=None):
    """Run one training epoch; return (mean per-sample loss, accuracy).

    Args:
        model: network to train (updated in place via `optimizer`).
        loader: iterable of (inputs, targets) batches.
        criterion: loss function, e.g. nn.CrossEntropyLoss.
        optimizer: optimizer over the trainable parameters.
        device: target device for the batches. Defaults to the device of
            the model's parameters — backward-compatible with the old
            reliance on the module-level ``device`` global, since every
            model in this file is moved there with ``.to(device)``.

    Returns:
        (avg_loss, accuracy): loss averaged over samples and the
        fraction of correct predictions (both measured pre-step, on the
        forward pass that produced the gradients).
    """
    if device is None:
        device = next(model.parameters()).device
    model.train()
    total_loss, correct, total = 0.0, 0, 0
    for X, y in loader:
        X, y = X.to(device), y.to(device)
        optimizer.zero_grad()
        out = model(X)
        loss = criterion(out, y)
        loss.backward()
        optimizer.step()
        # Weight by batch size so the mean is per-sample even when the
        # last batch is smaller than the rest.
        total_loss += loss.item() * len(y)
        correct += (out.argmax(1) == y).sum().item()
        total += len(y)
    return total_loss / total, correct / total
def evaluate(model, loader, device=None):
    """Return classification accuracy of `model` over `loader`.

    Runs in eval mode with gradients disabled. `device` defaults to the
    device of the model's parameters (backward-compatible with the old
    module-level ``device`` global, since the models here live there).
    """
    if device is None:
        device = next(model.parameters()).device
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for X, y in loader:
            X, y = X.to(device), y.to(device)
            out = model(X)
            correct += (out.argmax(1) == y).sum().item()
            total += len(y)
    return correct / total
# Method 1 driver: optimize only the new head — the frozen backbone
# produces no gradients, so each epoch is cheap.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model_frozen.fc.parameters(), lr=0.001)
print("Feature extraction (frozen backbone):")
frozen_history = []
for epoch in range(5):
    loss, train_acc = train_epoch(model_frozen, train_loader, criterion, optimizer)
    test_acc = evaluate(model_frozen, test_loader)
    # Record (loss, train_acc, test_acc) for the comparison plot later.
    frozen_history.append((loss, train_acc, test_acc))
    print(f" Epoch {epoch+1}: loss={loss:.4f}, train_acc={train_acc:.3f}, test_acc={test_acc:.3f}")
Method 2: Fine-Tuning (Unfreeze and Train)
Start from the feature-extracted model and unfreeze layers gradually:
def create_model_finetune(n_classes=10):
    """Pretrained ResNet-18 with a fresh n_classes head; nothing frozen.

    Every parameter keeps requires_grad=True, so the whole network is
    fine-tuned — learning-rate control happens in the optimizer.
    """
    net = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
    net.fc = nn.Linear(net.fc.in_features, n_classes)
    return net.to(device)
model_finetune = create_model_finetune()
# Discriminative learning rates: the randomly initialized head needs
# larger steps than the already-good pretrained backbone.
# FIX: split parameters by name instead of the positional slice
# `list(model.parameters())[:-2]` — that slice silently assumed the head
# was exactly the last two tensors and would break (or overlap groups)
# for any other head architecture.
head_params = [p for n, p in model_finetune.named_parameters() if n.startswith("fc.")]
backbone_params = [p for n, p in model_finetune.named_parameters() if not n.startswith("fc.")]
optimizer = optim.Adam([
    {"params": head_params, "lr": 1e-3},      # new head: fast
    {"params": backbone_params, "lr": 1e-5},  # backbone: slow
])
print("\nFull fine-tuning (different LRs):")
finetune_history = []
for epoch in range(5):
    loss, train_acc = train_epoch(model_finetune, train_loader, criterion, optimizer)
    test_acc = evaluate(model_finetune, test_loader)
    # Record (loss, train_acc, test_acc) for the comparison plot later.
    finetune_history.append((loss, train_acc, test_acc))
    print(f" Epoch {epoch+1}: loss={loss:.4f}, train_acc={train_acc:.3f}, test_acc={test_acc:.3f}")
Method 3: Gradual Unfreezing
Unfreeze layers one at a time, from top to bottom:
def create_model_gradual(n_classes=10):
    """Frozen pretrained ResNet-18 plus a fresh head, for phased unfreezing.

    Starts identical to the feature-extraction setup (backbone frozen,
    trainable head); later phases flip requires_grad back on layer by layer.
    """
    net = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
    net.requires_grad_(False)  # freeze all pretrained weights
    net.fc = nn.Linear(net.fc.in_features, n_classes)  # new head: trainable
    return net.to(device)
# Method 3 driver: open up the network top-down in three phases, each
# with a fresh optimizer at a smaller learning rate than the last.
model_gradual = create_model_gradual()

# Phase 1: only the new classification head is trainable.
optimizer = optim.Adam(model_gradual.fc.parameters(), lr=1e-3)
print("\nGradual unfreezing:")
print("Phase 1: head only")
for epoch in range(3):
    loss, train_acc = train_epoch(model_gradual, train_loader, criterion, optimizer)
    test_acc = evaluate(model_gradual, test_loader)
    print(f" Epoch {epoch+1}: loss={loss:.4f}, train_acc={train_acc:.3f}, test_acc={test_acc:.3f}")

# Phase 2: additionally adapt the last residual block, at a lower LR.
model_gradual.layer4.requires_grad_(True)
optimizer = optim.Adam([p for p in model_gradual.parameters() if p.requires_grad], lr=1e-4)
print("Phase 2: + layer4")
for epoch in range(3):
    loss, train_acc = train_epoch(model_gradual, train_loader, criterion, optimizer)
    test_acc = evaluate(model_gradual, test_loader)
    print(f" Epoch {epoch+1}: loss={loss:.4f}, train_acc={train_acc:.3f}, test_acc={test_acc:.3f}")

# Phase 3: everything trainable, smallest LR of all.
model_gradual.requires_grad_(True)
optimizer = optim.Adam(model_gradual.parameters(), lr=1e-5)
print("Phase 3: everything")
for epoch in range(3):
    loss, train_acc = train_epoch(model_gradual, train_loader, criterion, optimizer)
    test_acc = evaluate(model_gradual, test_loader)
    print(f" Epoch {epoch+1}: loss={loss:.4f}, train_acc={train_acc:.3f}, test_acc={test_acc:.3f}")
Compare Methods
# Side-by-side curves: training loss (left) and test accuracy (right)
# for the frozen-backbone and fully fine-tuned runs.
fig, axes = plt.subplots(1, 2, figsize=(12, 5))
for history, label in [(frozen_history, "Frozen"), (finetune_history, "Fine-tune")]:
    axes[0].plot([h[0] for h in history], label=label)  # h[0] = train loss
    axes[1].plot([h[2] for h in history], label=label)  # h[2] = test accuracy
axes[0].set_xlabel("Epoch"); axes[0].set_ylabel("Loss"); axes[0].legend()
axes[1].set_xlabel("Epoch"); axes[1].set_ylabel("Test Accuracy"); axes[1].legend()
axes[0].set_title("Training loss"); axes[1].set_title("Test accuracy")
plt.tight_layout()
plt.show()
Visualize Predictions
# Show the first ten test images with predicted vs. true labels
# (title green when correct, red when wrong).
model_finetune.eval()
images, labels = next(iter(test_loader))
with torch.no_grad():
    preds = model_finetune(images.to(device)).argmax(1).cpu()
fig, axes = plt.subplots(2, 5, figsize=(14, 6))
for i, ax in enumerate(axes.ravel()):
    # CHW -> HWC, then undo the ImageNet normalization so colors render.
    img = images[i].permute(1, 2, 0).numpy()
    img = img * np.array(imagenet_std) + np.array(imagenet_mean)
    ax.imshow(np.clip(img, 0, 1))
    correct = preds[i] == labels[i]
    ax.set_title(f"pred={classes[preds[i]]}\ntrue={classes[labels[i]]}",
                 color="green" if correct else "red", fontsize=9)
    ax.axis("off")
plt.tight_layout()
plt.show()
When to Use Which Strategy
| Data size | Similarity to ImageNet | Strategy |
|---|---|---|
| Small (<1k) | High | Feature extraction (freeze all) |
| Small (<1k) | Low | Feature extraction, maybe unfreeze last block |
| Medium (1k-10k) | High | Fine-tune with small LR |
| Medium (1k-10k) | Low | Gradual unfreezing |
| Large (>10k) | Any | Fine-tune everything |
Exercises

1. Without pretrained weights: Train the same ResNet18 from scratch (random init). Compare accuracy after 10 epochs. This shows exactly what pretraining gives you.

2. Different backbone: Replace ResNet18 with EfficientNet-B0 (models.efficientnet_b0). Is it better? Smaller?

3. Learning rate finder: For each unfreezing phase, sweep LR from 1e-6 to 1e-1 and plot loss vs LR to find the optimal rate.

4. Feature visualization: Extract features from the penultimate layer (before fc) and apply t-SNE. Do the 10 classes form clusters? Compare frozen vs fine-tuned features.
Next: 13 - Fine-Tune BERT for Classification — transfer learning for text.