Fine-Tune BERT for Classification

Goal: Fine-tune a pretrained BERT model for text classification using Hugging Face Transformers, covering every step from tokenization to evaluation.

Prerequisites: BERT and Masked Language Models, Text Classification, Transfer Learning, Text Preprocessing


Why Fine-Tune BERT?

BERT was pretrained on massive text corpora, so it already encodes general-purpose knowledge of language. Fine-tuning adapts it to your specific task using a small labeled dataset — the same idea as fine-tuning a pretrained ResNet for images.

[CLS] This movie was terrible [SEP]
  ↓
BERT (pretrained, 110M params)
  ↓
[CLS] hidden state → Linear → 2 classes (pos/neg)

Setup

import torch
import numpy as np
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from datasets import load_dataset
import matplotlib.pyplot as plt
 
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using: {device}")

Load Data: IMDB Reviews

dataset = load_dataset("imdb")

# Use a subset for speed. Shuffle (with a fixed seed, for reproducibility) before
# selecting, so the subset is not one contiguous run of the raw split — the raw
# splits may be ordered by label; verify if changing this.
train_data = dataset["train"].shuffle(seed=42).select(range(2000))
test_data = dataset["test"].shuffle(seed=42).select(range(500))

print(
    f"Train: {len(train_data)}, Test: {len(test_data)}",
    f"Labels: {set(train_data['label'])}",
    f"\nExample:\n{train_data[0]['text'][:200]}...",
    f"Label: {train_data[0]['label']}",
    sep="\n",
)

Tokenization

BERT requires its own tokenizer: text is split into WordPiece subwords, and the special tokens [CLS] (start of sequence) and [SEP] (separator / end of sequence) are added:

model_name = "distilbert-base-uncased"  # smaller, 66M params, 97% of BERT performance
tokenizer = AutoTokenizer.from_pretrained(model_name)

# See what tokenization looks like: ids, the subword strings they map to,
# and the attention mask (1 = real token).
example = "The movie was absolutely fantastic!"
tokens = tokenizer(example, return_tensors="pt")
print(
    f"Input:    '{example}'",
    f"Token IDs: {tokens['input_ids'][0].tolist()}",
    f"Decoded:   {tokenizer.convert_ids_to_tokens(tokens['input_ids'][0])}",
    f"Attention: {tokens['attention_mask'][0].tolist()}",
    sep="\n",
)

Tokenize the dataset

def tokenize_batch(batch):
    """Tokenize the ``"text"`` column of a ``datasets`` batch.

    Each sequence is truncated to at most 256 tokens and padded up to exactly
    256, so every row has the same length. When mapped with ``batched=True``,
    ``batch["text"]`` is a list of strings.
    """
    texts = batch["text"]
    encoded = tokenizer(texts, max_length=256, truncation=True, padding="max_length")
    return encoded
 
# Tokenize both splits (train first, then test) in batches of 64 rows.
train_tokenized, test_tokenized = (
    split.map(tokenize_batch, batched=True, batch_size=64)
    for split in (train_data, test_data)
)

# Convert to PyTorch format, exposing only the model inputs and the target.
train_tokenized.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])
test_tokenized.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])

# Only the training batches are shuffled; evaluation keeps dataset order.
train_loader = DataLoader(dataset=train_tokenized, batch_size=16, shuffle=True)
test_loader = DataLoader(dataset=test_tokenized, batch_size=32)

Load Pretrained Model

# Pretrained encoder with a 2-way classification head, placed on `device`.
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2).to(device)

# Count parameters: all of them, and those that will receive gradients.
total = sum(p.numel() for p in model.parameters())
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total params:     {total:,}", f"Trainable params: {trainable:,}", sep="\n")

Training Loop

# Full fine-tuning: every parameter of `model` is updated, with a small LR and weight decay.
optimizer = torch.optim.AdamW(params=model.parameters(), lr=2e-5, weight_decay=0.01)
 
def train_epoch(model, loader, optimizer, scheduler=None, max_grad_norm=1.0):
    """Run one training pass over ``loader``; return ``(mean_loss, accuracy)``.

    Args:
        model: classifier called as ``model(input_ids=..., attention_mask=...,
            labels=...)``. Its output must expose ``.loss`` and ``.logits``,
            as the Hugging Face models in this file do.
        loader: iterable of batches. Each batch is a dict with ``"input_ids"``,
            ``"attention_mask"`` and ``"label"`` tensors.
        optimizer: optimizer over (some of) ``model``'s parameters.
        scheduler: optional learning-rate scheduler, stepped once per batch right
            after ``optimizer.step()`` (e.g. ``get_linear_schedule_with_warmup``).
            ``None`` (the default) keeps the learning rate fixed.
        max_grad_norm: the global gradient norm is clipped to this value before
            every optimizer step. ``None`` disables clipping. Default 1.0.

    Returns:
        A tuple ``(mean_loss, accuracy)``:
        - ``mean_loss``: the loss averaged over all examples seen.
        - ``accuracy``: the fraction of examples whose argmax prediction equals the label.

        Both are accumulated while the weights are changing, so they lag
        slightly behind the end-of-epoch model.
    """
    model.train()  # enable dropout (from_pretrained hands back a model in eval mode)
    total_loss, correct, total = 0, 0, 0
    for batch in loader:
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["label"].to(device)

        optimizer.zero_grad()
        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss
        loss.backward()

        # Gradient clipping — BERT gradients can be large
        if max_grad_norm is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)
        optimizer.step()
        if scheduler is not None:
            scheduler.step()  # per-batch schedule, e.g. linear warmup + decay

        batch_size = len(labels)
        # Re-weight by batch size: this assumes outputs.loss is the batch mean
        # (the CrossEntropyLoss default), so the last short batch isn't over-counted.
        total_loss += loss.item() * batch_size
        preds = outputs.logits.argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += batch_size

    return total_loss / total, correct / total
 
def evaluate(model, loader):
    """Score ``model`` on every batch of ``loader`` without updating it.

    Returns:
        ``(accuracy, preds, labels)``. ``preds`` and ``labels`` are 1-D numpy
        arrays in loader order, and ``accuracy`` is the fraction of positions
        where they match. An empty loader raises ``ZeroDivisionError``.
    """
    model.eval()  # dropout off for scoring
    pred_chunks, label_chunks = [], []
    with torch.no_grad():
        for batch in loader:
            gold = batch["label"].to(device)
            logits = model(
                input_ids=batch["input_ids"].to(device),
                attention_mask=batch["attention_mask"].to(device),
            ).logits
            pred_chunks.append(logits.argmax(dim=1).cpu().numpy())
            label_chunks.append(gold.cpu().numpy())

    n_seen = sum(len(chunk) for chunk in label_chunks)
    n_right = sum(int((p == g).sum()) for p, g in zip(pred_chunks, label_chunks))
    accuracy = n_right / n_seen
    return accuracy, np.concatenate(pred_chunks), np.concatenate(label_chunks)

Train

# Full fine-tune: three passes over the training subset, scoring the test subset after each.
# `history` collects one (loss, train_acc, test_acc) tuple per epoch.
history = []
for epoch in range(3):
    loss, train_acc = train_epoch(model, train_loader, optimizer)
    test_acc, preds, labels = evaluate(model, test_loader)
    history += [(loss, train_acc, test_acc)]
    print(
        f"Epoch {epoch+1}: loss={loss:.4f}, train_acc={train_acc:.3f}, test_acc={test_acc:.3f}"
    )

Evaluate

from sklearn.metrics import classification_report, confusion_matrix

# Class names indexed by integer label (0 = negative, 1 = positive), plus the ids they name.
class_names = ["negative", "positive"]
class_ids = list(range(len(class_names)))

test_acc, preds, labels = evaluate(model, test_loader)
print(f"Test accuracy: {test_acc:.3f}")
# Pass labels= explicitly so the report (and the matrix below) keep one row per class
# even when a class is absent from both `labels` and `preds`. Without it,
# target_names no longer matches the class count and sklearn raises.
print(classification_report(labels, preds, labels=class_ids, target_names=class_names))

# Confusion matrix: rows = true class, columns = predicted class
cm = confusion_matrix(labels, preds, labels=class_ids)
plt.figure(figsize=(6, 5))
plt.imshow(cm, cmap="Blues")
threshold = cm.max() / 2  # white text on dark cells, black on light ones
for i in class_ids:
    for j in class_ids:
        plt.text(j, i, str(cm[i, j]), ha="center", va="center", fontsize=16,
                 color="white" if cm[i, j] > threshold else "black")
plt.xticks(class_ids, class_names)
plt.yticks(class_ids, class_names)
plt.xlabel("Predicted"); plt.ylabel("True")
plt.title("Confusion Matrix")
plt.colorbar()
plt.show()

Inference on New Text

def predict_sentiment(text):
    """Classify a single string with the fine-tuned ``model``.

    Returns:
        ``(label, probs)``. ``probs`` is a numpy array ``[p_negative, p_positive]``.
        ``label`` is ``"positive"`` when the positive probability is strictly
        greater than the negative one, otherwise ``"negative"``.
    """
    model.eval()
    # Same 256-token truncation as the training data; one input needs no padding.
    encoded = tokenizer(text, return_tensors="pt", truncation=True, max_length=256)
    inputs = {name: tensor.to(device) for name, tensor in encoded.items()}
    with torch.no_grad():
        probs = torch.softmax(model(**inputs).logits, dim=1)[0]
    is_positive = probs[1] > probs[0]
    label = "positive" if is_positive else "negative"
    return label, probs.cpu().numpy()
 
# Hand-written reviews, including mixed/neutral ones, to probe the classifier.
tests = [
    "This was the best movie I've ever seen!",
    "Terrible acting, waste of time.",
    "It was okay, nothing special.",
    "The cinematography was stunning but the plot was weak.",
]
# map() is lazy, so each prediction is made right before its line is printed.
for text, (label, probs) in zip(tests, map(predict_sentiment, tests)):
    print(f"{label:8s} (neg={probs[0]:.3f}, pos={probs[1]:.3f}) | {text}")

Freeze Strategies Comparison

Freeze everything except classifier head

def count_trainable(model):
    """Return the number of scalar parameters in ``model`` that have ``requires_grad`` set."""
    n_trainable = 0
    for param in model.parameters():
        if param.requires_grad:
            n_trainable += param.numel()
    return n_trainable
 
# Reset model: load a fresh pretrained copy so this run doesn't start from the fine-tuned weights
model_frozen = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2).to(device)

# Freeze everything
for param in model_frozen.parameters():
    param.requires_grad = False

# Unfreeze the whole classification head.
# In Hugging Face's DistilBertForSequenceClassification, the head is two linear layers:
# `pre_classifier` (dim -> dim) and then `classifier` (dim -> num_labels).
# from_pretrained reports both as newly initialized for distilbert-base-uncased.
# Unfreezing only `classifier` would leave a random, frozen projection in front of it.
# getattr keeps this working for heads without `pre_classifier` (e.g. bert-base-uncased).
head_layers = [getattr(model_frozen, "pre_classifier", None), model_frozen.classifier]
for layer in head_layers:
    if layer is None:
        continue
    for param in layer.parameters():
        param.requires_grad = True

print(f"Full fine-tune trainable: {count_trainable(model):,}")
print(f"Frozen backbone trainable: {count_trainable(model_frozen):,}")

# Train frozen version. The optimizer receives exactly the parameters left trainable.
# This head-only run uses a much larger LR (1e-3) than full fine-tuning (2e-5).
optimizer_frozen = torch.optim.Adam(
    [p for p in model_frozen.parameters() if p.requires_grad], lr=1e-3
)
for epoch in range(3):
    loss, train_acc = train_epoch(model_frozen, train_loader, optimizer_frozen)
    test_acc, _, _ = evaluate(model_frozen, test_loader)
    print(f"  Epoch {epoch+1}: loss={loss:.4f}, train_acc={train_acc:.3f}, test_acc={test_acc:.3f}")

Learning Rate Schedule

Linear warmup + decay — standard for transformer fine-tuning:

from transformers import get_linear_schedule_with_warmup
 
# Typical setup: warm up over the first 10% of optimizer steps, then decay linearly
num_epochs = 3
total_steps = len(train_loader) * num_epochs
warmup_steps = total_steps // 10

model_lr = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2).to(device)
optimizer_lr = torch.optim.AdamW(model_lr.parameters(), lr=2e-5)
scheduler = get_linear_schedule_with_warmup(optimizer_lr, warmup_steps, total_steps)

# Track the LR actually used by each optimizer step
lrs = []
model_lr.train()  # from_pretrained returns the model in eval mode; enable dropout for training
for epoch in range(num_epochs):
    for batch in train_loader:
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["label"].to(device)
        optimizer_lr.zero_grad()
        outputs = model_lr(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        outputs.loss.backward()
        # Same clipping as the main training loop, so the schedule is the only change
        torch.nn.utils.clip_grad_norm_(model_lr.parameters(), max_norm=1.0)
        # Record before stepping: scheduler.step() below moves the LR on to the next step's value
        lrs.append(optimizer_lr.param_groups[0]["lr"])
        optimizer_lr.step()
        scheduler.step()

plt.plot(lrs)
plt.xlabel("Step"); plt.ylabel("Learning rate")
plt.title("Linear warmup + decay schedule")
plt.show()

Exercises

  1. TF-IDF baseline: Train a simple TF-IDF + LogisticRegression on the same data. What’s the accuracy gap vs BERT? Is BERT worth the compute?

  2. Different models: Try bert-base-uncased (110M) vs distilbert-base-uncased (66M) vs albert-base-v2 (12M). Plot accuracy vs inference time.

  3. Multi-class: Fine-tune on AG News (4 classes: World, Sports, Business, Sci/Tech). The code barely changes: set num_labels=4 and replace the two-class label names and logic in the evaluation and inference code.

  4. Error analysis: Find the 20 examples with highest loss (most confident wrong predictions). What patterns do you see? Are they genuinely ambiguous?

  5. LoRA: Install peft and fine-tune with LoRA instead of full fine-tuning. Compare parameter count, training time, and accuracy.


Next: 14 - Build a RAG Pipeline — combine retrieval with generation for question answering.