Image Segmentation
What
Classify every pixel in an image. More precise than object detection (bounding boxes). The output is a mask the same size as the input, where each pixel gets a class label.
Types
| Type | What | Example |
|---|---|---|
| Semantic | Label every pixel by class | Road vs sidewalk vs building |
| Instance | Separate different objects of same class | This car vs that car |
| Panoptic | Semantic + instance combined | Full scene understanding |
Key architectures
- U-Net: encoder-decoder with skip connections. Standard for medical imaging
- DeepLab (v3+): uses atrous (dilated) convolutions to capture multi-scale context without losing resolution. ASPP module applies parallel dilated convolutions at different rates
- Mask R-CNN: extends Faster R-CNN with pixel masks. Instance segmentation
- SAM (Segment Anything): foundation model for segmentation. Zero-shot, prompt-based
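The atrous-convolution idea from the DeepLab bullet is easy to sketch in PyTorch: with a 3x3 kernel, setting padding equal to the dilation rate keeps the spatial size constant, and an ASPP-style module just runs several such convolutions at different rates in parallel and concatenates them. The channel counts and rates below are arbitrary illustrative choices, not DeepLab's actual configuration.

```python
import torch
import torch.nn as nn

x = torch.randn(1, 3, 64, 64)

# A single dilated 3x3 conv: padding == dilation keeps H x W unchanged,
# while the receptive field grows with the dilation rate.
conv = nn.Conv2d(3, 8, kernel_size=3, padding=2, dilation=2)
assert conv(x).shape == (1, 8, 64, 64)

# ASPP-style: parallel dilated convs at different rates, concatenated.
branches = nn.ModuleList(
    [nn.Conv2d(3, 8, kernel_size=3, padding=r, dilation=r) for r in (1, 6, 12)]
)
y = torch.cat([b(x) for b in branches], dim=1)
# multi-scale context, still at full resolution
```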
Loss functions
| Loss | What it does | When to use |
|---|---|---|
| Cross-entropy | Standard per-pixel classification loss | Balanced classes |
| Dice loss | 1 - (2*intersection)/(pred size + target size). Directly optimizes overlap | Imbalanced classes (medical) |
| IoU (Jaccard) loss | 1 - intersection/union | Similar to Dice, slightly harsher |
| Focal loss | Down-weights easy pixels, focuses on hard ones | Extreme class imbalance |
In practice, people often combine them: loss = CE + Dice works well.
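That combination can be sketched in PyTorch as follows. This is a minimal soft-Dice over softmax probabilities; the `eps` smoothing term and the 1:1 weighting between the two terms are arbitrary choices to tune, not a canonical recipe.

```python
import torch
import torch.nn.functional as F

def dice_loss(logits, targets, eps=1e-6):
    # logits: (B, C, H, W) raw scores; targets: (B, H, W) class indices
    num_classes = logits.shape[1]
    probs = torch.softmax(logits, dim=1)
    one_hot = F.one_hot(targets, num_classes).permute(0, 3, 1, 2).float()
    dims = (0, 2, 3)  # sum over batch and spatial dims, per class
    intersection = (probs * one_hot).sum(dims)
    total = probs.sum(dims) + one_hot.sum(dims)
    dice = (2 * intersection + eps) / (total + eps)
    return 1 - dice.mean()

def combined_loss(logits, targets, ce_weight=1.0, dice_weight=1.0):
    # CE handles per-pixel classification, Dice handles overlap/imbalance
    return (ce_weight * F.cross_entropy(logits, targets)
            + dice_weight * dice_loss(logits, targets))
```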
Evaluation metrics
- mIoU (mean Intersection over Union): the standard metric. Average IoU across all classes
- Pixel accuracy: fraction of correctly classified pixels (misleading with imbalanced classes)
- Dice coefficient: 2 * overlap / total. Same as F1 score for binary masks
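mIoU is simple enough to compute by hand. A minimal NumPy sketch, assuming integer label maps of the same shape; classes absent from both prediction and ground truth are skipped rather than counted as IoU 0 (conventions differ between benchmarks):

```python
import numpy as np

def mean_iou(pred, target, num_classes):
    # pred, target: integer arrays of the same shape (class indices per pixel)
    ious = []
    for c in range(num_classes):
        p, t = pred == c, target == c
        union = np.logical_or(p, t).sum()
        if union == 0:
            continue  # class absent in both masks; skip it
        inter = np.logical_and(p, t).sum()
        ious.append(inter / union)
    return float(np.mean(ious))
```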
Practical example
```python
from transformers import SegformerForSemanticSegmentation, SegformerImageProcessor
from PIL import Image
import torch

processor = SegformerImageProcessor.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512")
model = SegformerForSemanticSegmentation.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512")

image = Image.open("scene.jpg")  # placeholder path — any RGB image
inputs = processor(images=image, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)
# outputs.logits shape: (batch, num_classes, H/4, W/4)
# upsample to the original size, then argmax for the final mask
logits = torch.nn.functional.interpolate(
    outputs.logits, size=image.size[::-1], mode="bilinear", align_corners=False
)
mask = logits.argmax(dim=1)[0]  # (H, W) class indices
```
Training tips
- Class imbalance: background dominates. Use Dice/Focal loss, or oversample rare classes
- Data augmentation: random crops, flips, and color jitter work. Be careful with geometric transforms — apply the same transform to image AND mask
- Resolution matters: higher input resolution = better boundaries but more VRAM. Use mixed precision
- Pre-trained encoders: always start from ImageNet weights. Fine-tune the decoder first
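The augmentation tip above (same geometric transform for image AND mask, photometric transforms for the image only) can be sketched without any augmentation library. The flip probability and the brightness-jitter range below are arbitrary illustrative values:

```python
import torch

def paired_augment(image, mask, p_flip=0.5):
    # image: (C, H, W) float tensor; mask: (H, W) long tensor of class indices
    # Geometric transform: flip BOTH or NEITHER, so labels stay aligned.
    if torch.rand(1).item() < p_flip:
        image = torch.flip(image, dims=[-1])
        mask = torch.flip(mask, dims=[-1])
    # Photometric transform (brightness jitter): image only — never the mask.
    image = image * (0.8 + 0.4 * torch.rand(1).item())
    return image, mask
```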
Links
- Object Detection — bounding boxes instead of pixel masks
- Convolutional Neural Networks — the backbone architectures
- Computer Vision Roadmap — where segmentation fits in CV