Image Segmentation

What

Classify every pixel in an image. More precise than object detection (bounding boxes). The output is a mask the same size as the input, where each pixel gets a class label.
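To make the output format concrete, here is a toy sketch in NumPy (the sizes and class count are made up): a model produces per-pixel class scores, and the mask is the per-pixel argmax, with the same spatial shape as the input.

```python
import numpy as np

H, W = 4, 6                    # toy image size
num_classes = 3

# model output: per-pixel class scores (logits), shape (num_classes, H, W)
logits = np.random.randn(num_classes, H, W)

# the segmentation mask takes the argmax over classes at each pixel
mask = logits.argmax(axis=0)   # shape (H, W), integer class labels

print(mask.shape)              # (4, 6) -- same spatial size as the input
```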

Types

| Type | What | Example |
|---|---|---|
| Semantic | Label every pixel by class | Road vs sidewalk vs building |
| Instance | Separate different objects of the same class | This car vs that car |
| Panoptic | Semantic + instance combined | Full scene understanding |

Key architectures

  • U-Net: encoder-decoder with skip connections. Standard for medical imaging
  • DeepLab (v3+): uses atrous (dilated) convolutions to capture multi-scale context without losing resolution. ASPP module applies parallel dilated convolutions at different rates
  • Mask R-CNN: extends Faster R-CNN with pixel masks. Instance segmentation
  • SAM (Segment Anything): foundation model for segmentation. Zero-shot, prompt-based
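To see why atrous convolutions capture multi-scale context cheaply: a kernel of size k with dilation rate r covers an effective span of (k − 1)·r + 1 input positions while keeping the same k weights. A minimal 1D NumPy sketch (illustrative only, not DeepLab's actual implementation):

```python
import numpy as np

def dilated_conv1d(x, kernel, rate):
    """Valid-mode 1D dilated convolution (cross-correlation, no kernel flip)."""
    k = len(kernel)
    span = (k - 1) * rate + 1          # effective receptive field
    out = [sum(kernel[j] * x[i + j * rate] for j in range(k))
           for i in range(len(x) - span + 1)]
    return np.array(out), span

x = np.arange(16, dtype=float)
kernel = np.array([1.0, 1.0, 1.0])     # 3 weights, reused at every rate

_, span1 = dilated_conv1d(x, kernel, rate=1)
_, span2 = dilated_conv1d(x, kernel, rate=2)
_, span4 = dilated_conv1d(x, kernel, rate=4)
print(span1, span2, span4)             # 3 5 9 -- same 3 weights, growing context
```

Running the same kernel at several rates in parallel and concatenating the results is the idea behind the ASPP module.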

Loss functions

| Loss | What it does | When to use |
|---|---|---|
| Cross-entropy | Standard per-pixel classification loss | Balanced classes |
| Dice loss | 1 − 2·intersection / (pred + target sizes). Directly optimizes overlap | Imbalanced classes (medical) |
| IoU (Jaccard) loss | 1 − intersection/union | Similar to Dice, slightly harsher |
| Focal loss | Down-weights easy pixels, focuses on hard ones | Extreme class imbalance |

In practice, people often combine them: loss = CE + Dice works well.
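A minimal sketch of that combined loss for the binary case, in NumPy (variable names and the toy masks are made up; in a real training loop you would use the framework's differentiable ops):

```python
import numpy as np

def ce_plus_dice(probs, target, eps=1e-6):
    """Combined loss for binary segmentation.

    probs:  predicted foreground probabilities, shape (H, W), in (0, 1)
    target: ground-truth binary mask, shape (H, W), values {0, 1}
    """
    # per-pixel binary cross-entropy
    ce = -np.mean(target * np.log(probs + eps)
                  + (1 - target) * np.log(1 - probs + eps))
    # soft Dice: 2|P∩T| / (|P| + |T|), computed on probabilities
    intersection = (probs * target).sum()
    dice = (2 * intersection + eps) / (probs.sum() + target.sum() + eps)
    return ce + (1 - dice)

target = np.zeros((4, 4)); target[1:3, 1:3] = 1.0
good = np.clip(target * 0.9 + 0.05, 0.05, 0.95)  # prediction close to target
bad = np.full((4, 4), 0.5)                       # uninformative prediction
print(ce_plus_dice(good, target) < ce_plus_dice(bad, target))  # True
```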

Evaluation metrics

  • mIoU (mean Intersection over Union): the standard metric. Average IoU across all classes
  • Pixel accuracy: fraction of correctly classified pixels (misleading with imbalanced classes)
  • Dice coefficient: 2 * overlap / total. Same as F1 score for binary masks
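The metrics above can be computed in a few lines; this toy example (masks are made up) also shows why pixel accuracy misleads: missing part of a small class barely dents accuracy but clearly lowers mIoU.

```python
import numpy as np

def miou(pred, target, num_classes):
    """Mean IoU over classes; pred/target are (H, W) integer label maps."""
    ious = []
    for c in range(num_classes):
        inter = np.logical_and(pred == c, target == c).sum()
        union = np.logical_or(pred == c, target == c).sum()
        if union > 0:                  # skip classes absent from both maps
            ious.append(inter / union)
    return float(np.mean(ious))

target = np.array([[0, 0, 1, 1],
                   [0, 0, 1, 1]])
pred   = np.array([[0, 0, 1, 0],       # one pixel of class 1 missed
                   [0, 0, 1, 1]])

pixel_acc = (pred == target).mean()
print(round(pixel_acc, 3))             # 0.875 -- looks great
print(round(miou(pred, target, 2), 3)) # 0.775 -- penalizes the miss more
```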

Practical example

```python
from transformers import SegformerForSemanticSegmentation, SegformerImageProcessor
from PIL import Image
import torch

processor = SegformerImageProcessor.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512")
model = SegformerForSemanticSegmentation.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512")

image = Image.open("scene.jpg")  # any RGB image (placeholder path)
inputs = processor(images=image, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)
# outputs.logits shape: (batch, num_classes, H/4, W/4)
# upsample to original size for final mask
```
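In practice you would bilinearly upsample the logits with `torch.nn.functional.interpolate(outputs.logits, size=image.size[::-1], mode="bilinear")` before the argmax. A dependency-free NumPy sketch of the argmax + upsampling step (random logits stand in for `outputs.logits`; 150 classes matches the ADE20K checkpoint above):

```python
import numpy as np

# stand-in for outputs.logits.numpy(): (batch, num_classes, H/4, W/4)
logits = np.random.randn(1, 150, 128, 128)

mask_small = logits.argmax(axis=1)[0]              # (128, 128) class labels
# nearest-neighbor upsample by 4x back to the input resolution
mask_full = mask_small.repeat(4, axis=0).repeat(4, axis=1)
print(mask_full.shape)                             # (512, 512)
```

Note the order: interpolating logits bilinearly then taking the argmax gives smoother boundaries than upsampling the hard label map, which is why the nearest-neighbor version here is only a sketch.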

Training tips

  • Class imbalance: background dominates. Use Dice/Focal loss, or oversample rare classes
  • Data augmentation: random crops, flips, and color jitter work. Be careful with geometric transforms — apply the same transform to image AND mask
  • Resolution matters: higher input resolution = better boundaries but more VRAM. Use mixed precision
  • Pre-trained encoders: always start from ImageNet weights. Fine-tune the decoder first
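The "same transform to image AND mask" point from the augmentation tip can be sketched like this (a hypothetical horizontal-flip helper in NumPy; real pipelines use a library with joint transforms, e.g. albumentations or torchvision v2):

```python
import numpy as np

rng = np.random.default_rng(0)

def paired_hflip(image, mask, p=0.5):
    """Horizontal flip applied to image AND mask with the same coin toss."""
    if rng.random() < p:
        image = image[:, ::-1]   # flip along the width axis
        mask = mask[:, ::-1]     # same geometric transform on the labels
    return image, mask

image = np.arange(12).reshape(3, 4)
mask = image % 2                 # toy label map aligned with the image
img2, mask2 = paired_hflip(image, mask, p=1.0)
# alignment is preserved: the mask still labels the same pixels
print(np.array_equal(mask2, img2 % 2))  # True
```

Flipping only the image (or only the mask) would silently teach the model wrong labels, which is why photometric transforms (color jitter) go on the image alone but every geometric transform must be paired.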