Image Classification

What

Given an image, predict its class label. This task is often called the “hello world” of deep learning.

Approach

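import torch
import torchvision.models as models
import torchvision.transforms as T
from PIL import Image

# Load a ResNet-50 pretrained on ImageNet and switch to inference mode
model = models.resnet50(weights="IMAGENET1K_V2")
model.eval()

# Preprocess: resize, center-crop to 224x224, convert to a tensor,
# then normalize with the ImageNet channel means and standard deviations
transform = T.Compose([
    T.Resize(256),
    T.CenterCrop(224),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Load the input as a 3-channel PIL image ("example.jpg" is a placeholder path)
image = Image.open("example.jpg").convert("RGB")

# Predict: add a batch dimension and run the model without tracking gradients
with torch.no_grad():
    output = model(transform(image).unsqueeze(0))  # logits, shape (1, 1000)
    predicted_class = output.argmax(dim=1).item()  # ImageNet class index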
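
The argmax above returns only a single class index. To read a prediction, it often helps to turn the logits into probabilities and look at the top few classes. The sketch below does that arithmetic in plain Python on made-up logits; softmax, top_k, and the three labels are hypothetical names chosen for illustration, not part of torchvision.

```python
import math

def softmax(logits):
    """Convert raw scores (logits) into probabilities that sum to 1."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def top_k(logits, labels, k=3):
    """Return the k most likely (label, probability) pairs, best first."""
    probs = softmax(logits)
    ranked = sorted(zip(labels, probs), key=lambda pair: pair[1], reverse=True)
    return ranked[:k]

# Made-up logits for a hypothetical three-class problem
labels = ["cat", "dog", "bird"]
logits = [2.0, 0.5, -1.0]

for label, prob in top_k(logits, labels, k=2):
    print(f"{label}: {prob:.3f}")
```

On real model output the same result should be available with torch.softmax(output, dim=1).topk(5), and the ImageNet class names can be read from models.ResNet50_Weights.IMAGENET1K_V2.meta["categories"].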

Classic datasets for learning

Dataset | Contents | Size
MNIST | Handwritten digits (0-9) | 60k train, 10k test
CIFAR-10 | 10 object classes, 32×32 | 50k train, 10k test
ImageNet | 1000 classes, high-res | 1.2M train
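
The two small datasets can be fetched through torchvision. The following is a minimal sketch, assuming torchvision is installed and a writable ./data directory; DATASETS, load_split, and expected_size are hypothetical helper names, and the split sizes are copied from the table above.

```python
# Split sizes from the table above, plus the matching torchvision class names
DATASETS = {
    "MNIST": {"cls": "MNIST", "train": 60_000, "test": 10_000},
    "CIFAR-10": {"cls": "CIFAR10", "train": 50_000, "test": 10_000},
}

def expected_size(name, train=True):
    """Number of images the table lists for the requested split."""
    return DATASETS[name]["train" if train else "test"]

def load_split(name, root="./data", train=True):
    """Download (if needed) and return one split as a torchvision dataset."""
    # Imported lazily so the size lookup above works without torchvision
    import torchvision.datasets as datasets
    import torchvision.transforms as T

    dataset_cls = getattr(datasets, DATASETS[name]["cls"])
    return dataset_cls(root=root, train=train, download=True,
                       transform=T.ToTensor())
```

For example, len(load_split("MNIST")) should match expected_size("MNIST"). ImageNet is left out on purpose: torchvision does not download it automatically, so it has to be obtained separately.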