Decision Tree from Scratch

Goal: Build a decision tree classifier from scratch — entropy, information gain, recursive splitting, and prediction. No sklearn.

Prerequisites: Decision Trees, Entropy, Probability Basics


Core Idea

A decision tree asks a series of yes/no questions about features, splitting data at each node to maximize information gain — the reduction in uncertainty.


Entropy

Measures uncertainty in a set of labels. For binary labels where $p$ is the proportion of class 1: $H(p) = -p\log_2 p - (1-p)\log_2(1-p)$.

import numpy as np
import matplotlib.pyplot as plt
 
def entropy(y):
    """Shannon entropy (in bits) of an integer label array; 0 for empty input."""
    n = len(y)
    if n == 0:
        return 0
    freq = np.bincount(y)
    # Drop zero-count classes so log2 never sees a zero probability.
    p = freq[freq > 0] / n
    return -(p * np.log2(p)).sum()
 
# Visualize entropy vs proportion
# Visualize binary entropy as a function of the class-1 proportion.
# Endpoints 0 and 1 are excluded because log2(0) is undefined.
p = np.linspace(0.01, 0.99, 100)
h = -(p * np.log2(p) + (1 - p) * np.log2(1 - p))
plt.plot(p, h)
plt.xlabel("P(class=1)")
plt.ylabel("Entropy (bits)")
plt.title("Entropy is maximized at 50/50 split")
plt.show()

Maximum entropy = maximum uncertainty = even split. Zero entropy = pure node = all same class.


Information Gain

How much entropy a split removes:

def information_gain(y, left_idx, right_idx):
    """Entropy reduction achieved by splitting y into two index groups.

    Returns 0 for a degenerate split (one side empty).
    """
    n_left, n_right = len(left_idx), len(right_idx)
    if n_left == 0 or n_right == 0:
        return 0
    n = len(y)
    # Children's entropy, weighted by the fraction of samples each receives.
    children = (n_left * entropy(y[left_idx]) + n_right * entropy(y[right_idx])) / n
    return entropy(y) - children

Finding the Best Split

For each feature, try every unique value as a threshold. Pick the split with highest information gain.

def best_split(X, y):
    """Exhaustively search all (feature, threshold) pairs.

    Returns (feature, threshold, gain) for the split with the highest
    information gain; feature/threshold are None if X has no columns.
    """
    best = (None, None, -1)  # (feature, threshold, gain)

    for j in range(X.shape[1]):
        column = X[:, j]
        # Every observed value is a candidate threshold for "x <= t".
        for t in np.unique(column):
            left = np.where(column <= t)[0]
            right = np.where(column > t)[0]
            gain = information_gain(y, left, right)
            if gain > best[2]:
                best = (j, t, gain)

    return best

The Tree

A tree is just a nested dict. Each node is either a split (internal) or a prediction (leaf).

def build_tree(X, y, depth=0, max_depth=5, min_samples=2):
    """Recursively build a decision tree as a nested dict.

    Internal nodes look like {"leaf": False, "feature", "threshold",
    "left", "right"}; leaves look like {"leaf": True, "class"}.

    Parameters
    ----------
    X : (n, d) feature matrix; y : (n,) integer labels.
    depth : current recursion depth (callers leave this at 0).
    max_depth : stop splitting at this depth.
    min_samples : stop splitting nodes smaller than this.

    Raises
    ------
    ValueError if y is empty (no class to predict).
    """
    if len(y) == 0:
        raise ValueError("cannot build a tree from an empty label array")

    # Stopping conditions
    if len(np.unique(y)) == 1:                        # pure node: all one class
        return {"leaf": True, "class": y[0]}
    if depth >= max_depth or len(y) < min_samples:    # depth/size limit reached
        return _majority_leaf(y)

    feature, threshold, ig = best_split(X, y)
    # "<= 0" rather than "== 0": guards against tiny negative float noise,
    # and a zero-gain split would recurse forever on the same data.
    if ig <= 0:
        return _majority_leaf(y)

    left_idx = np.where(X[:, feature] <= threshold)[0]
    right_idx = np.where(X[:, feature] > threshold)[0]

    return {
        "leaf": False,
        "feature": feature,
        "threshold": threshold,
        "left": build_tree(X[left_idx], y[left_idx], depth + 1, max_depth, min_samples),
        "right": build_tree(X[right_idx], y[right_idx], depth + 1, max_depth, min_samples),
    }


def _majority_leaf(y):
    """Leaf node predicting the most common class in y."""
    return {"leaf": True, "class": np.bincount(y).argmax()}

Prediction

Walk the tree from root to leaf:

def predict_one(tree, x):
    """Walk the nested-dict tree from root to a leaf and return its class."""
    node = tree
    while not node["leaf"]:
        # Go left when the feature value is at or below the threshold.
        side = "left" if x[node["feature"]] <= node["threshold"] else "right"
        node = node[side]
    return node["class"]
 
def predict(tree, X):
    """Return an array with one predicted class per row of X."""
    labels = [predict_one(tree, sample) for sample in X]
    return np.array(labels)

Full Example

from sklearn.datasets import make_classification

# Synthetic 2-D, 2-class dataset; fixed seed keeps the run reproducible.
X, y = make_classification(
    n_samples=300,
    n_features=2,
    n_informative=2,
    n_redundant=0,
    n_clusters_per_class=1,
    random_state=42,
)

tree = build_tree(X, y, max_depth=4)

# How well the tree fits the data it was trained on.
acc = np.mean(predict(tree, X) == y)
print(f"Training accuracy: {acc:.4f}")

Visualize the Decision Boundary

def plot_tree_boundary(tree, X, y):
    """Shade the 2-D plane by predicted class and overlay the training points."""
    pad = 1
    xs = np.linspace(X[:, 0].min() - pad, X[:, 0].max() + pad, 300)
    ys = np.linspace(X[:, 1].min() - pad, X[:, 1].max() + pad, 300)
    xx, yy = np.meshgrid(xs, ys)
    # Classify every point of the 300x300 grid to reveal the regions.
    grid = np.column_stack([xx.ravel(), yy.ravel()])
    zz = predict(tree, grid).reshape(xx.shape)

    plt.contourf(xx, yy, zz, alpha=0.4, cmap="bwr")
    plt.scatter(X[:, 0], X[:, 1], c=y, cmap="bwr", s=15, edgecolors="k", linewidth=0.3)
    plt.title("Decision tree boundary")
    plt.show()
 
plot_tree_boundary(tree, X, y)

Notice the axis-aligned rectangles — decision trees can only split parallel to feature axes.


def print_tree(tree, indent=""):
    """Pretty-print the nested-dict tree as an indented list of questions."""
    if tree["leaf"]:
        print(f"{indent}→ class {tree['class']}")
    else:
        print(f"{indent}feature[{tree['feature']}] <= {tree['threshold']:.3f}?")
        # "Yes" branch is the <= side, "No" the > side; children indent 4 spaces.
        for answer, child in (("Yes", tree["left"]), ("No", tree["right"])):
            print(f"{indent}  {answer}:")
            print_tree(child, indent + "    ")
 
print_tree(tree)

Overfitting Demo

# Depth controls the bias/variance trade-off: a very deep tree memorizes
# the training set, while depth 2 is too coarse to follow the boundary.
tree_deep = build_tree(X, y, max_depth=20)
tree_shallow = build_tree(X, y, max_depth=2)

fig, axes = plt.subplots(1, 2, figsize=(12, 5))
panels = [(axes[0], tree_deep, "max_depth=20 (overfit)"),
          (axes[1], tree_shallow, "max_depth=2 (underfit)")]
for ax, t, title in panels:
    xs = np.linspace(X[:, 0].min() - 1, X[:, 0].max() + 1, 300)
    ys = np.linspace(X[:, 1].min() - 1, X[:, 1].max() + 1, 300)
    xx, yy = np.meshgrid(xs, ys)
    grid = np.c_[xx.ravel(), yy.ravel()]
    zz = predict(t, grid).reshape(xx.shape)
    ax.contourf(xx, yy, zz, alpha=0.4, cmap="bwr")
    ax.scatter(X[:, 0], X[:, 1], c=y, cmap="bwr", s=15, edgecolors="k", linewidth=0.3)
    ax.set_title(title)
plt.tight_layout()
plt.show()

Exercises

  1. Gini impurity: Replace entropy with Gini impurity: $G = 1 - \sum_k p_k^2$, where $p_k$ is the proportion of class $k$ in the node. Rebuild the tree and compare the splits. Which criterion gives deeper trees?

  2. Regression tree: Modify the code to predict a continuous value. Use variance reduction instead of information gain. The leaf prediction is the mean of its samples.

  3. Random forest (simple): Build 10 trees, each on a random 70% subset of the data. Predict by majority vote. Compare accuracy to a single tree.

  4. Pruning: After building a full tree, implement post-pruning: for each internal node, check if replacing it with a leaf reduces validation error.


Next: 04 - K-Means from Scratch — unsupervised learning, no labels needed.