Fine-Tune BERT for Classification
Goal: Fine-tune a pretrained BERT model for text classification using HuggingFace Transformers. From tokenization to evaluation.
Prerequisites: BERT and Masked Language Models, Text Classification, Transfer Learning, Text Preprocessing
Why Fine-Tune BERT?
BERT was pretrained on massive text corpora to understand language. Fine-tuning adapts it to your specific task with a small labeled dataset — same idea as fine-tuning ResNet for images.
[CLS] This movie was terrible [SEP]
↓
BERT (pretrained, 110M params)
↓
[CLS] hidden state → Linear → 2 classes (pos/neg)
Setup
# --- Setup ---
import torch
import numpy as np
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from datasets import load_dataset
import matplotlib.pyplot as plt

# Use the GPU when available; every model/tensor below is moved to this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using: {device}")

# --- Load Data: IMDB Reviews ---
# Load the IMDB movie-review dataset (binary sentiment: 0 = negative, 1 = positive).
dataset = load_dataset("imdb")

# Use a small fixed-seed subset so the demo trains quickly.
train_data = dataset["train"].shuffle(seed=42).select(range(2000))
test_data = dataset["test"].shuffle(seed=42).select(range(500))

print(f"Train: {len(train_data)}, Test: {len(test_data)}")
print(f"Labels: {set(train_data['label'])}")
print(f"\nExample:\n{train_data[0]['text'][:200]}...")
print(f"Label: {train_data[0]['label']}")

# --- Tokenization ---
BERT needs specific tokenization — WordPiece subwords, special tokens [CLS] and [SEP]:
# DistilBERT: 66M params, roughly 97% of BERT-base quality at a fraction of the size.
model_name = "distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Inspect what the tokenizer produces for a single sentence.
example = "The movie was absolutely fantastic!"
tokens = tokenizer(example, return_tensors="pt")
print(f"Input: '{example}'")
print(f"Token IDs: {tokens['input_ids'][0].tolist()}")
# convert_ids_to_tokens yields the WordPiece tokens themselves (not a decoded
# string), so label the output "Tokens" rather than "Decoded".
print(f"Tokens: {tokenizer.convert_ids_to_tokens(tokens['input_ids'][0])}")
print(f"Attention: {tokens['attention_mask'][0].tolist()}")

# --- Tokenize the dataset ---
def tokenize_batch(batch):
    """Tokenize a batch of raw review texts to fixed-length (256) padded encodings."""
    encoded = tokenizer(
        batch["text"],
        padding="max_length",
        truncation=True,
        max_length=256,
    )
    return encoded
# Tokenize once up front; batched map is much faster than per-example calls.
train_tokenized = train_data.map(tokenize_batch, batched=True, batch_size=64)
test_tokenized = test_data.map(tokenize_batch, batched=True, batch_size=64)

# Expose only the tensor columns the model consumes, as PyTorch tensors.
train_tokenized.set_format("torch", columns=["input_ids", "attention_mask", "label"])
test_tokenized.set_format("torch", columns=["input_ids", "attention_mask", "label"])

train_loader = DataLoader(train_tokenized, batch_size=16, shuffle=True)
test_loader = DataLoader(test_tokenized, batch_size=32)

# --- Load Pretrained Model ---
# num_labels=2 attaches a fresh, randomly initialized classification head
# on top of the pretrained encoder.
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)
model = model.to(device)

# Count parameters — with full fine-tuning, every parameter is trainable.
total = sum(p.numel() for p in model.parameters())
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total params: {total:,}")
print(f"Trainable params: {trainable:,}")

# --- Training Loop ---
# AdamW with a small learning rate (2e-5) — the standard recipe for BERT
# fine-tuning; much larger rates tend to wreck the pretrained weights.
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5, weight_decay=0.01)
def train_epoch(model, loader, optimizer):
    """Run one training pass over `loader`; return (mean loss, accuracy)."""
    model.train()
    running_loss = 0.0
    n_correct = 0
    n_seen = 0
    for batch in loader:
        ids = batch["input_ids"].to(device)
        mask = batch["attention_mask"].to(device)
        targets = batch["label"].to(device)

        optimizer.zero_grad()
        # Passing labels makes the HF model compute the loss internally.
        outputs = model(input_ids=ids, attention_mask=mask, labels=targets)
        outputs.loss.backward()
        # Gradient clipping — BERT gradients can be large
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        batch_size = len(targets)
        running_loss += outputs.loss.item() * batch_size
        n_correct += (outputs.logits.argmax(dim=1) == targets).sum().item()
        n_seen += batch_size
    return running_loss / n_seen, n_correct / n_seen
def evaluate(model, loader):
    """Evaluate `model` on `loader` without gradients.

    Returns:
        (accuracy, predictions ndarray, true-labels ndarray)
    """
    model.eval()
    correct, total = 0, 0
    all_preds, all_labels = [], []
    with torch.no_grad():
        for batch in loader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["label"].to(device)
            # No labels passed — we only need logits at eval time.
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            preds = outputs.logits.argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += len(labels)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
    return correct / total, np.array(all_preds), np.array(all_labels)

# --- Train ---
# Fine-tune for 3 epochs, tracking train/test metrics per epoch.
history = []
for epoch in range(3):
    loss, train_acc = train_epoch(model, train_loader, optimizer)
    # Only the accuracy is needed here; predictions are recomputed in the
    # final evaluation section below.
    test_acc, _, _ = evaluate(model, test_loader)
    history.append((loss, train_acc, test_acc))
    print(f"Epoch {epoch+1}: loss={loss:.4f}, train_acc={train_acc:.3f}, test_acc={test_acc:.3f}")

# --- Evaluate ---
from sklearn.metrics import classification_report, confusion_matrix

# Final held-out evaluation with per-class precision/recall/F1.
test_acc, preds, labels = evaluate(model, test_loader)
print(f"Test accuracy: {test_acc:.3f}")
print(classification_report(labels, preds, target_names=["negative", "positive"]))

# Confusion matrix: rows = true class, columns = predicted class.
cm = confusion_matrix(labels, preds)
plt.figure(figsize=(6, 5))
plt.imshow(cm, cmap="Blues")
for i in range(2):
    for j in range(2):
        plt.text(j, i, str(cm[i, j]), ha="center", va="center", fontsize=16)
plt.xticks([0, 1], ["negative", "positive"])
plt.yticks([0, 1], ["negative", "positive"])
plt.xlabel("Predicted"); plt.ylabel("True")
plt.title("Confusion Matrix")
plt.colorbar()
plt.show()

# --- Inference on New Text ---
def predict_sentiment(text):
    """Classify one string; return (label, probs) where probs = [p_neg, p_pos]."""
    model.eval()
    encoded = tokenizer(text, return_tensors="pt", truncation=True, max_length=256)
    encoded = {key: tensor.to(device) for key, tensor in encoded.items()}
    with torch.no_grad():
        logits = model(**encoded).logits
    probs = torch.softmax(logits, dim=1)[0]
    if probs[1] > probs[0]:
        label = "positive"
    else:
        label = "negative"
    return label, probs.cpu().numpy()
# Sanity-check the fine-tuned model on a few hand-written reviews,
# including deliberately ambiguous/mixed ones.
tests = [
    "This was the best movie I've ever seen!",
    "Terrible acting, waste of time.",
    "It was okay, nothing special.",
    "The cinematography was stunning but the plot was weak.",
]
for text in tests:
    label, probs = predict_sentiment(text)
    print(f"{label:8s} (neg={probs[0]:.3f}, pos={probs[1]:.3f}) | {text}")

# --- Freeze Strategies Comparison ---
Freeze everything except classifier head
def count_trainable(model):
    """Return the number of parameters that will receive gradient updates."""
    n_trainable = 0
    for param in model.parameters():
        if param.requires_grad:
            n_trainable += param.numel()
    return n_trainable
# Reset to fresh pretrained weights — the global `model` was already fine-tuned.
model_frozen = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2).to(device)

# Freeze the entire network...
for param in model_frozen.parameters():
    param.requires_grad = False
# ...then unfreeze just the classification head.
for param in model_frozen.classifier.parameters():
    param.requires_grad = True

print(f"Full fine-tune trainable: {count_trainable(model):,}")
print(f"Frozen backbone trainable: {count_trainable(model_frozen):,}")

# Train the frozen version — a fixed feature extractor tolerates a much
# larger learning rate than full fine-tuning.
optimizer_frozen = torch.optim.Adam(model_frozen.classifier.parameters(), lr=1e-3)
for epoch in range(3):
    loss, train_acc = train_epoch(model_frozen, train_loader, optimizer_frozen)
    test_acc, _, _ = evaluate(model_frozen, test_loader)
    print(f" Epoch {epoch+1}: loss={loss:.4f}, train_acc={train_acc:.3f}, test_acc={test_acc:.3f}")

# --- Learning Rate Schedule ---
Linear warmup + decay — standard for transformer fine-tuning:
from transformers import get_linear_schedule_with_warmup

# Linear warmup for the first 10% of steps, then linear decay to zero.
total_steps = len(train_loader) * 3  # 3 epochs
warmup_steps = total_steps // 10

model_lr = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2).to(device)
optimizer_lr = torch.optim.AdamW(model_lr.parameters(), lr=2e-5)
scheduler = get_linear_schedule_with_warmup(optimizer_lr, warmup_steps, total_steps)

# Record the LR at every optimizer step so we can plot the schedule.
lrs = []
for epoch in range(3):
    for batch in train_loader:
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["label"].to(device)
        optimizer_lr.zero_grad()
        outputs = model_lr(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        outputs.loss.backward()
        optimizer_lr.step()
        scheduler.step()  # this scheduler steps per batch, not per epoch
        lrs.append(optimizer_lr.param_groups[0]["lr"])

plt.plot(lrs)
plt.xlabel("Step"); plt.ylabel("Learning rate")
plt.title("Linear warmup + decay schedule")
plt.show()

# --- Exercises ---
-
TF-IDF baseline: Train a simple TF-IDF + LogisticRegression on the same data. What’s the accuracy gap vs BERT? Is BERT worth the compute?
-
Different models: Try bert-base-uncased (110M) vs distilbert-base-uncased (66M) vs albert-base-v2 (12M). Plot accuracy vs inference time.
-
Multi-class: Fine-tune on AG News (4 classes: World, Sports, Business, Sci/Tech). The code barely changes — just num_labels=4.
-
Error analysis: Find the 20 examples with highest loss (most confident wrong predictions). What patterns do you see? Are they genuinely ambiguous?
-
LoRA: Install peft and fine-tune with LoRA instead of full fine-tuning. Compare parameter count, training time, and accuracy.
Next: 14 - Build a RAG Pipeline — combine retrieval with generation for question answering.