Handling Class Imbalance
Goal: Understand why accuracy lies when classes are imbalanced. Learn resampling, class weights, and proper evaluation.
Prerequisites: Evaluation Metrics, Loss Functions, Cross-Validation
The Problem
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import (classification_report, confusion_matrix,
roc_auc_score, precision_recall_curve, average_precision_score)
# 95% negative, 5% positive (e.g., fraud detection)
# flip_y=0.05 injects 5% label noise so the task isn't trivially separable.
X, y = make_classification(n_samples=2000, n_features=20, weights=[0.95, 0.05],
n_informative=5, random_state=42, flip_y=0.05)
# bincount gives the absolute count of each class label (index 0 and 1).
print(f"Class distribution: {np.bincount(y)}")
print(f"Positive rate: {y.mean():.3f}")
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3,
stratify=y, random_state=42)

The accuracy trap
# A model that always predicts "negative"
# The majority-class rate IS the accuracy of the trivial "always negative" model.
print(f"'Always negative' accuracy: {1 - y_test.mean():.3f}")
# Logistic regression — looks great by accuracy
model = LogisticRegression(max_iter=200)
model.fit(X_train, y_train)
preds = model.predict(X_test)
print(f"Model accuracy: {(preds == y_test).mean():.3f}")
# Recall on the positive class only — the number plain accuracy hides.
print(f"But recall on positive class: {(preds[y_test == 1] == 1).mean():.3f}")
print("\nConfusion matrix:")
print(confusion_matrix(y_test, preds))
print("\n" + classification_report(y_test, preds))

High accuracy, but the model misses most positives. Accuracy is useless here.
The Right Metrics
Precision-Recall Curve
# Predicted probability of the positive class (column 1 of predict_proba).
probs = model.predict_proba(X_test)[:, 1]
precision, recall, thresholds = precision_recall_curve(y_test, probs)
# Average precision summarizes the whole PR curve as one number.
ap = average_precision_score(y_test, probs)
plt.figure(figsize=(8, 5))
plt.plot(recall, precision, linewidth=2)
plt.xlabel("Recall"); plt.ylabel("Precision")
plt.title(f"Precision-Recall curve (AP={ap:.3f})")
# A random classifier's precision equals the positive rate — the baseline.
plt.axhline(y_test.mean(), color="gray", linestyle="--", label="Random baseline")
plt.legend()
plt.show()

Pick a threshold
Default threshold of 0.5 is arbitrary. Choose based on business needs:
# F1 at different thresholds
# precision/recall have one more entry than thresholds — drop the final
# (recall=0) point so the arrays align; 1e-10 avoids division by zero.
f1s = 2 * precision[:-1] * recall[:-1] / (precision[:-1] + recall[:-1] + 1e-10)
best_idx = np.argmax(f1s)
best_threshold = thresholds[best_idx]
print(f"Best threshold for F1: {best_threshold:.3f}")
print(f"At this threshold: precision={precision[best_idx]:.3f}, recall={recall[best_idx]:.3f}")
# Re-threshold the probabilities instead of using predict()'s default 0.5.
preds_tuned = (probs >= best_threshold).astype(int)
print("\nWith tuned threshold:")
print(classification_report(y_test, preds_tuned))

Method 1: Class Weights
Tell the model to penalize misclassifying the minority class more:
# Most sklearn models support class_weight
model_weighted = LogisticRegression(class_weight="balanced", max_iter=200)
model_weighted.fit(X_train, y_train)
preds_w = model_weighted.predict(X_test)
print("With class_weight='balanced':")
print(classification_report(y_test, preds_w))
print(f"AUC-ROC: {roc_auc_score(y_test, model_weighted.predict_proba(X_test)[:, 1]):.3f}")

"balanced" sets weight = n_samples / (n_classes * count_c) for each class c. The minority class gets a higher weight.
Method 2: Random Oversampling
Duplicate minority samples:
def random_oversample(X, y, seed=42):
    """Balance classes by duplicating minority-class rows.

    Rows of each under-represented class are resampled with replacement
    until every class matches the majority count. All original rows come
    first; the duplicates are appended after them.
    """
    rng = np.random.RandomState(seed)
    labels, freqs = np.unique(y, return_counts=True)
    target = freqs.max()
    feature_parts, label_parts = [X.copy()], [y.copy()]
    for label, freq in zip(labels, freqs):
        deficit = target - freq
        if deficit == 0:
            continue  # already the majority class — nothing to add
        pool = np.where(y == label)[0]
        picks = rng.choice(pool, deficit, replace=True)
        feature_parts.append(X[picks])
        label_parts.append(y[picks])
    return np.vstack(feature_parts), np.concatenate(label_parts)
# Resample only the training split — the test set must stay untouched.
X_over, y_over = random_oversample(X_train, y_train)
print(f"Before: {np.bincount(y_train)}")
print(f"After: {np.bincount(y_over)}")
model_over = LogisticRegression(max_iter=200)
model_over.fit(X_over, y_over)
print("\nWith oversampling:")
print(classification_report(y_test, model_over.predict(X_test)))

Method 3: Random Undersampling
Remove majority samples:
def random_undersample(X, y, seed=42):
    """Balance classes by discarding majority-class rows.

    Keeps a uniform random subset (without replacement) of each class,
    sized to the smallest class. Returned rows are grouped by class.
    """
    rng = np.random.RandomState(seed)
    labels, freqs = np.unique(y, return_counts=True)
    keep_per_class = freqs.min()
    kept = np.concatenate([
        rng.choice(np.where(y == label)[0], keep_per_class, replace=False)
        for label in labels
    ])
    return X[kept], y[kept]
# Undersample only the training split — evaluation data is untouched.
X_under, y_under = random_undersample(X_train, y_train)
print(f"Before: {np.bincount(y_train)}")
print(f"After: {np.bincount(y_under)}")
model_under = LogisticRegression(max_iter=200)
model_under.fit(X_under, y_under)
print("\nWith undersampling:")
print(classification_report(y_test, model_under.predict(X_test)))

Downside: throws away data. Only use when you have plenty.
Method 4: SMOTE (Synthetic Oversampling)
Create new minority samples by interpolating between existing ones:
def smote(X, y, k=5, seed=42):
    """Simplified SMOTE: balance classes by synthesizing minority samples.

    Each synthetic sample is a random linear interpolation between a
    minority point and one of its k nearest minority-class neighbors
    (Euclidean distance). Original rows are preserved; synthetic rows
    are appended.

    Parameters:
        X: array of shape (n_samples, n_features).
        y: integer class labels, shape (n_samples,).
        k: neighbors to interpolate toward (clamped to the minority
           size minus one, so small minority classes no longer crash).
        seed: RNG seed for reproducibility.

    Returns:
        (X_new, y_new) with the minority class grown to the majority count.
    """
    rng = np.random.RandomState(seed)
    classes, counts = np.unique(y, return_counts=True)
    minority_cls = classes[counts.argmin()]
    majority_count = counts.max()
    X_min = X[y == minority_cls]
    n_to_generate = majority_count - len(X_min)
    # Already balanced: nothing to synthesize. (The original crashed here —
    # np.vstack with an empty, wrongly-shaped synthetic array.)
    if n_to_generate <= 0:
        return X.copy(), y.copy()
    if len(X_min) == 1:
        # A lone minority sample has no neighbors to interpolate with;
        # fall back to plain duplication.
        synthetic = np.repeat(X_min, n_to_generate, axis=0)
    else:
        # Clamp k: there are only len(X_min) - 1 possible neighbors.
        # (The original raised inside NearestNeighbors when len(X_min) <= k.)
        k_eff = min(k, len(X_min) - 1)
        # k nearest neighbors within the minority class, numpy-only:
        # squared distances sort identically to true distances; column 0
        # of the argsort is each point's self-match, so drop it.
        diffs = X_min[:, None, :] - X_min[None, :, :]
        sq_dists = np.einsum("ijk,ijk->ij", diffs, diffs)
        neighbors = np.argsort(sq_dists, axis=1)[:, 1:k_eff + 1]
        synthetic = np.empty((n_to_generate, X.shape[1]))
        for i in range(n_to_generate):
            idx = rng.randint(len(X_min))
            neighbor_idx = rng.choice(neighbors[idx])
            lam = rng.random()  # interpolation weight in [0, 1)
            synthetic[i] = X_min[idx] + lam * (X_min[neighbor_idx] - X_min[idx])
    X_new = np.vstack([X, synthetic])
    y_new = np.concatenate([y, np.full(n_to_generate, minority_cls)])
    return X_new, y_new
# Synthesize minority samples on the training split only.
X_smote, y_smote = smote(X_train, y_train)
print(f"After SMOTE: {np.bincount(y_smote)}")
model_smote = LogisticRegression(max_iter=200)
model_smote.fit(X_smote, y_smote)
print("\nWith SMOTE:")
print(classification_report(y_test, model_smote.predict(X_test)))

Compare All Methods
# One fitted model per imbalance-handling strategy, keyed by display name.
methods = {
    "Baseline": model,
    "Class weights": model_weighted,
    "Oversampling": model_over,
    "Undersampling": model_under,
    "SMOTE": model_smote,
}

# Overlay every model's precision-recall curve for a side-by-side view.
fig, ax = plt.subplots(figsize=(10, 6))
for label, clf in methods.items():
    pos_probs = clf.predict_proba(X_test)[:, 1]
    prec, rec, _ = precision_recall_curve(y_test, pos_probs)
    ap_score = average_precision_score(y_test, pos_probs)
    ax.plot(rec, prec, label=f"{label} (AP={ap_score:.3f})")
ax.set_xlabel("Recall"); ax.set_ylabel("Precision")
ax.set_title("Precision-Recall comparison")
ax.legend()
plt.show()

Method 5: Focal Loss (for Neural Networks)
Down-weight easy examples, focus on hard ones:
def focal_loss(y_true, y_pred, gamma=2.0, alpha=0.25):
    """Binary focal loss — reduces loss for well-classified examples.

    Each example's cross-entropy is scaled by (1 - p_t)**gamma, so
    confident, correct predictions contribute almost nothing, and by an
    alpha / (1 - alpha) class-balancing factor.
    """
    eps = 1e-15
    p = np.clip(y_pred, eps, 1 - eps)  # keep log() finite at 0 and 1
    is_pos = y_true == 1
    p_t = np.where(is_pos, p, 1 - p)  # probability assigned to the true class
    weight = np.where(is_pos, alpha, 1 - alpha) * (1 - p_t) ** gamma
    return -np.mean(weight * np.log(p_t))
# Compare with standard BCE
from sklearn.metrics import log_loss
probs = model.predict_proba(X_test)[:, 1]
print(f"BCE loss: {log_loss(y_test, probs):.4f}")
print(f"Focal loss: {focal_loss(y_test, probs):.4f}")

Decision Guide
| Scenario | Best approach |
|---|---|
| Moderate imbalance (10-30% minority) | Class weights |
| Severe imbalance (<5% minority) | SMOTE + class weights |
| Large dataset | Undersampling or class weights |
| Small dataset | SMOTE (creates data) |
| Neural networks | Focal loss + class weights |
| Tree-based models | Often handle imbalance well natively |
Always: Use stratified CV and evaluate with precision-recall, not accuracy.
Exercises
1. Imbalance ratio sweep: Generate datasets with 1%, 5%, 10%, 25%, 50% positive rate. For each, compare class weights vs SMOTE vs baseline. At what ratio does imbalance handling stop mattering?

2. Cost-sensitive threshold: If false negatives cost 10x more than false positives (e.g., missing fraud), what threshold minimizes total cost?

3. Ensemble with undersampling: Train 10 models, each on a different random undersample. Predict by averaging probabilities. Compare with single model on full data.

4. Per-class F1: Compute macro vs weighted vs micro F1. When do they differ most? Which should you report?
Next: 12 - Transfer Learning End-to-End — use pretrained models instead of training from scratch.