Matplotlib Essentials

What

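The foundational Python plotting library. Many higher-level tools, such as seaborn and the pandas `.plot()` methods, are built on top of it, so knowing its core API helps when customising their output.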

Key patterns

import matplotlib.pyplot as plt
import numpy as np
 
# Basic line plot: sample sin(x) on [0, 10] and draw it with a title and axis labels
x = np.linspace(0, 10, 100)  # 100 evenly spaced samples, both endpoints included
y = np.sin(x)
plt.plot(x, y)
plt.xlabel("x")
plt.ylabel("sin(x)")
plt.title("Sine wave")
plt.show()
 
# Scatter plot coloured by a per-point value, with a matching colour scale
plt.figure()  # fresh figure: in interactive mode plt.show() leaves the previous figure current
# presumably `labels` holds one numeric value per point (e.g. class ids) — confirm for your data
points = plt.scatter(x_data, y_data, c=labels, cmap="viridis", alpha=0.6)
plt.colorbar(points)  # bind the colorbar explicitly to this scatter's colour mapping
 
# Histogram on its own figure, so it does not overlay whatever Axes is currently active
plt.figure()
plt.hist(data, bins=30, edgecolor="black")  # black bar edges keep adjacent bins distinguishable
 
# Subplots: one row of three Axes sharing a single 12x4-inch figure
fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(12, 4))
left, middle, right = axes  # a 1x3 grid comes back as a length-3 array of Axes
left.plot(x, y1)
middle.scatter(x, y2)
right.hist(y3)
fig.tight_layout()  # adjust spacing so titles/labels of neighbouring panels don't collide
 
# Save: call savefig on the Figure object itself (ideally before plt.show()),
# so the intended figure is written rather than whatever pyplot currently treats as "current"
fig.savefig("plot.png", bbox_inches="tight", dpi=150)

Common patterns in ML

# Loss curve: training vs. validation loss on a fresh figure.
# Without plt.figure() these lines would be drawn onto the last Axes created above.
# assumes train_losses / val_losses are per-epoch sequences of equal length — confirm
plt.figure()
# The x-axis is the list index (0-based); pass range(1, len(train_losses) + 1)
# as the first argument to plt.plot if you want epochs numbered from 1.
plt.plot(train_losses, label="train")
plt.plot(val_losses, label="val")
plt.legend()
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.show()
 
# Confusion matrix built directly from true and predicted labels via scikit-learn's display helper
from sklearn.metrics import ConfusionMatrixDisplay

cm_display = ConfusionMatrixDisplay.from_predictions(y_true, y_pred)