Decision Trees: Splitting Criteria, Pruning, and Interpretability


Definition

Decision Trees are non-parametric supervised learning algorithms that model decisions through a hierarchical structure of internal nodes (feature tests), branches (test outcomes), and leaf nodes (predictions). The algorithm recursively partitions the feature space into axis-aligned rectangular regions. It predicts the majority class (classification) or the mean target value (regression) of the training samples in each region. At each node it chooses the split that maximizes information gain, which is the reduction in impurity. Impurity is measured with Gini impurity or entropy for classification, and with variance (equivalently, MSE) for regression. This greedy, top-down search makes locally optimal splits without backtracking, so the resulting tree is not guaranteed to be globally optimal, and a tree grown without limits overfits easily. Tree-based models are highly interpretable, because you can trace the exact decision path behind any prediction. That simplicity has two costs: instability (small changes in the data can produce a very different tree) and a bias toward features with many distinct values.
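The region-based view can be checked directly. The sketch below uses synthetic data and assumes scikit-learn's `DecisionTreeRegressor`. A depth-2 tree cuts the plane into at most four axis-aligned rectangles, so it should produce at most four distinct predictions on a dense grid. Each leaf's stored value should also equal the mean training target of the samples that reach it (`apply` returns each sample's leaf index).

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor

# Synthetic 2-D regression data on the unit square
rng = np.random.default_rng(0)
X = rng.uniform(0, 1, size=(500, 2))
y = np.sin(6 * X[:, 0]) + X[:, 1] + rng.normal(scale=0.1, size=500)

reg = DecisionTreeRegressor(max_depth=2, random_state=0).fit(X, y)

# Evaluate on a 100 x 100 grid covering the unit square
g = np.linspace(0, 1, 100)
xx, yy = np.meshgrid(g, g)
grid = np.column_stack([xx.ravel(), yy.ravel()])
preds = reg.predict(grid)
print("distinct predictions on grid:", len(np.unique(preds)))  # at most 2**2 = 4

# Each leaf stores the mean target of the training samples it contains
leaf_ids = reg.apply(X)
leaf_means_match = all(
    np.isclose(reg.tree_.value[leaf][0][0], y[leaf_ids == leaf].mean())
    for leaf in np.unique(leaf_ids)
)
print("leaf values equal training means:", leaf_means_match)
```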

Intuition


Think of decision trees like a game of 20 Questions. You ask yes/no questions about features to narrow down the answer. 'Is it bigger than a breadbox?' splits the possibilities. 'Is it alive?' further refines them. Each question should ideally eliminate as much uncertainty as possible. The tree learns the most informative questions to ask first, creating a flowchart that anyone can follow.

Mathematical Formula

Gini Impurity (Classification):
\[ Gini(t) = 1 - \sum_{i=1}^{C} (p_i)^2 \]
Entropy (Classification):
\[ Entropy(t) = -\sum_{i=1}^{C} p_i \log_2(p_i) \]
Information Gain:
\[ IG(T, a) = H(T) - \sum_{v \in values(a)} \frac{|T_v|}{|T|} H(T_v) \]
Mean Squared Error (Regression):
\[ MSE(node) = \frac{1}{n}\sum_{i=1}^{n}(y_i - \bar{y})^2 \]
Variance Reduction:
\[ \Delta Var = Var(parent) - \sum_{j} \frac{n_j}{n} Var(child_j) \]

Step-by-Step Explanation:

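  1. Gini impurity is the probability of misclassifying a randomly chosen sample from node t if it were labeled at random according to the node's class distribution. Here p_i is the fraction of samples in t that belong to class i, and C is the number of classes
  2. Entropy measures average information content, i.e. the uncertainty about the class of a sample in the node. It is 0 for a pure node and log₂ C when all classes are equally frequent
  3. Information gain is the reduction in impurity obtained by splitting T on attribute a
  4. H(T) is the impurity (here, entropy) of the parent node, and H(Tᵥ) is the impurity of the child Tᵥ that holds the samples where a takes value v. A numeric threshold split has just two children: ≤ threshold and > threshold
  5. MSE for regression measures the spread of targets around the node's mean prediction ȳ, so it equals the variance of the targets in the node
  6. Variance reduction quantifies how much a split decreases target variance, where n_j is the number of samples in child j. It is the regression analogue of information gain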
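The formulas can be checked by hand on a small split. The sketch below is written in NumPy. The helper names (`gini`, `entropy`, `information_gain`, `variance_reduction`) and the toy labels are illustrative. `information_gain` accepts either impurity function, which matches how H in the formula is used with both criteria.

```python
import numpy as np

def gini(labels):
    """Gini impurity: 1 - sum_i p_i^2 over the classes present in the node."""
    _, counts = np.unique(labels, return_counts=True)
    p = counts / counts.sum()
    return 1.0 - np.sum(p ** 2)

def entropy(labels):
    """Entropy in bits: -sum_i p_i log2 p_i (absent classes contribute 0)."""
    _, counts = np.unique(labels, return_counts=True)
    p = counts / counts.sum()
    return -np.sum(p * np.log2(p))

def information_gain(parent, children, impurity=entropy):
    """IG(T, a) = H(T) - sum_v |T_v|/|T| * H(T_v)."""
    n = len(parent)
    return impurity(parent) - sum(len(c) / n * impurity(c) for c in children)

def variance_reduction(parent, children):
    """Var(parent) - sum_j n_j/n * Var(child_j); np.var of a node is its MSE."""
    n = len(parent)
    return np.var(parent) - sum(len(c) / n * np.var(c) for c in children)

# Classification: five samples of each class, split into a pure left child
# (four 0s) and a mixed right child (one 0, five 1s)
parent = np.array([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])
left, right = parent[:4], parent[4:]
print(f"Gini(parent) = {gini(parent):.3f}")                                # 0.500
print(f"Entropy(parent) = {entropy(parent):.3f}")                          # 1.000
print(f"Gini gain = {information_gain(parent, [left, right], gini):.3f}")  # 0.333
print(f"Entropy gain = {information_gain(parent, [left, right]):.3f}")     # 0.610

# Regression: targets fall cleanly into a low group and a high group
y = np.array([1.0, 2.0, 3.0, 10.0, 11.0, 12.0])
print(f"Variance reduction = {variance_reduction(y, [y[:3], y[3:]]):.3f}")  # 20.250
```

The entropy criterion scores the same split higher than Gini does (0.610 vs 0.333) only because the two impurities live on different scales. Gains are compared within one criterion, never across criteria.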

Real-World Use Cases

Healthcare

Medical diagnosis decision support, e.g. 'Is the patient over 50? Yes → check cholesterol. No → check blood pressure.' Doctors can read and audit the rules directly.

Finance

Loan approval rules engine: the tree's explicit if-then rules can help meet regulatory requirements for explainable decisions.
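A rules engine needs the fitted tree as explicit conditions. The sketch below assumes scikit-learn. It walks the fitted tree's `tree_` arrays (`children_left`, `children_right`, `feature`, `threshold`, `value`) and emits one IF-THEN rule per leaf. The helper `tree_to_rules` is hypothetical. The loan features (`income_k`, `debt_ratio`) and the rule used to generate the labels are synthetic and also hypothetical. scikit-learn's `export_text` prints the same paths as indented text.

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier

def tree_to_rules(clf, feature_names):
    """Return one 'IF ... THEN ...' string per leaf of a fitted scikit-learn tree."""
    t = clf.tree_
    rules = []

    def walk(node, conditions):
        left, right = t.children_left[node], t.children_right[node]
        if left == right:  # leaf: both children are -1
            label = clf.classes_[int(np.argmax(t.value[node][0]))]
            rules.append(f"IF {' AND '.join(conditions) or 'TRUE'} THEN {label}")
            return
        name, thr = feature_names[t.feature[node]], t.threshold[node]
        walk(left, conditions + [f"{name} <= {thr:.2f}"])
        walk(right, conditions + [f"{name} > {thr:.2f}"])

    walk(0, [])
    return rules

# Hypothetical loan data: approve when income is high and debt ratio is low
rng = np.random.default_rng(0)
income_k = rng.uniform(20, 150, 300)   # annual income, thousands
debt_ratio = rng.uniform(0, 1, 300)    # debt-to-income ratio
X = np.column_stack([income_k, debt_ratio])
y = np.where((income_k > 60) & (debt_ratio < 0.4), "approve", "reject")

clf = DecisionTreeClassifier(max_depth=2, random_state=0).fit(X, y)
rules = tree_to_rules(clf, ["income_k", "debt_ratio"])
for rule in rules:
    print(rule)
```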

Retail

Customer segmentation for marketing: 'High income AND frequent purchaser → Premium tier'.

Manufacturing

Quality control: decision rules identifying defective products based on sensor thresholds.

Implementation

Manual Implementation (No Libraries)

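This implementation demonstrates the core recursive tree-building algorithm. For every feature, it tries each unique value as a candidate threshold and keeps the split with the highest information gain. It then partitions the data and recurses on each side until a stopping condition is met: maximum depth reached, fewer than min_samples_split samples, a pure node, or no split with positive gain. The tree is stored as nested dictionaries, so its structure can be inspected directly.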
import numpy as np
from collections import Counter

class DecisionTree:
    """
    Manual implementation of a Decision Tree classifier.
    Splits are chosen by information gain, measured with Gini impurity or entropy.
    Class labels must be non-negative integers (required by np.bincount).
    """
    
    def __init__(self, max_depth=5, min_samples_split=2, criterion='gini'):
        self.max_depth = max_depth
        self.min_samples_split = min_samples_split
        self.criterion = criterion
        self.tree = None
    
    def _gini(self, y):
        """Calculate Gini impurity."""
        if len(y) == 0:
            return 0
        proportions = np.bincount(y) / len(y)
        return 1 - np.sum(proportions ** 2)
    
    def _entropy(self, y):
        """Calculate entropy."""
        if len(y) == 0:
            return 0
        proportions = np.bincount(y) / len(y)
        # Drop empty classes: 0 * log2(0) is taken as 0 by convention
        proportions = proportions[proportions > 0]
        return -np.sum(proportions * np.log2(proportions))
    
    def _impurity(self, y):
        """Calculate impurity based on chosen criterion."""
        if self.criterion == 'gini':
            return self._gini(y)
        return self._entropy(y)
    
    def _information_gain(self, y, left_idx, right_idx):
        """Calculate information gain from a split."""
        parent_impurity = self._impurity(y)
        n = len(y)
        n_left = len(left_idx)
        n_right = len(right_idx)
        
        if n_left == 0 or n_right == 0:
            return 0
        
        child_impurity = (n_left / n) * self._impurity(y[left_idx]) + \
                         (n_right / n) * self._impurity(y[right_idx])
        
        return parent_impurity - child_impurity
    
    def _best_split(self, X, y):
        """Find the best feature and threshold to split on."""
        best_gain = -1
        best_feature = None
        best_threshold = None
        
        n_features = X.shape[1]
        
        for feature in range(n_features):
            thresholds = np.unique(X[:, feature])
            
            for threshold in thresholds:
                left_idx = np.where(X[:, feature] <= threshold)[0]
                right_idx = np.where(X[:, feature] > threshold)[0]
                
                gain = self._information_gain(y, left_idx, right_idx)
                
                if gain > best_gain:
                    best_gain = gain
                    best_feature = feature
                    best_threshold = threshold
        
        return best_feature, best_threshold, best_gain
    
    def _build_tree(self, X, y, depth=0):
        """Recursively build the decision tree."""
        n_samples = len(y)
        n_classes = len(np.unique(y))
        
        # Stopping conditions
        if (depth >= self.max_depth or 
            n_samples < self.min_samples_split or 
            n_classes == 1):
            return {'prediction': Counter(y).most_common(1)[0][0]}
        
        # Find best split
        feature, threshold, gain = self._best_split(X, y)
        
        if gain <= 0:
            return {'prediction': Counter(y).most_common(1)[0][0]}
        
        # Split data
        left_idx = X[:, feature] <= threshold
        right_idx = X[:, feature] > threshold
        
        # Recursively build left and right subtrees
        left_subtree = self._build_tree(X[left_idx], y[left_idx], depth + 1)
        right_subtree = self._build_tree(X[right_idx], y[right_idx], depth + 1)
        
        return {
            'feature': feature,
            'threshold': threshold,
            'left': left_subtree,
            'right': right_subtree
        }
    
    def fit(self, X, y):
        """Build the decision tree."""
        self.tree = self._build_tree(X, y)
        return self
    
    def _predict_single(self, x, node):
        """Predict for a single sample."""
        if 'prediction' in node:
            return node['prediction']
        
        if x[node['feature']] <= node['threshold']:
            return self._predict_single(x, node['left'])
        else:
            return self._predict_single(x, node['right'])
    
    def predict(self, X):
        """Predict for all samples."""
        return np.array([self._predict_single(x, self.tree) for x in X])

# Demonstration
if __name__ == '__main__':
    from sklearn.datasets import make_classification
    from sklearn.model_selection import train_test_split
    
    # Generate sample data
    X, y = make_classification(
        n_samples=200, n_features=4, n_redundant=0, 
        n_informative=4, n_clusters_per_class=1, random_state=42
    )
    
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=42
    )
    
    # Train tree
    tree = DecisionTree(max_depth=3, criterion='gini')
    tree.fit(X_train, y_train)
    
    # Predict
    y_pred = tree.predict(X_test)
    accuracy = np.mean(y_pred == y_test)
    
    print(f'Accuracy: {accuracy:.3f}')
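The DecisionTree class above handles classification only. The sketch below shows a regression counterpart under the same nested-dict layout. It swaps impurity for the variance reduction formula, and each leaf predicts the mean target of its samples. Unlike the class above, it uses midpoints between consecutive distinct values as candidate thresholds, which is the convention scikit-learn follows. The function names `best_split`, `build`, and `predict_one` are hypothetical.

```python
import numpy as np

def best_split(X, y):
    """Return (feature, threshold) with the largest variance reduction, or (None, None)."""
    n = len(y)
    parent_var = np.var(y)
    best_feature, best_threshold, best_gain = None, None, 1e-12  # ignore numerically-zero gains
    for f in range(X.shape[1]):
        values = np.unique(X[:, f])
        # Candidate thresholds: midpoints between consecutive distinct values
        for thr in (values[:-1] + values[1:]) / 2:
            mask = X[:, f] <= thr
            n_left = mask.sum()
            # Weighted average of child variances: sum_j n_j/n * Var(child_j)
            child_var = (n_left * np.var(y[mask]) + (n - n_left) * np.var(y[~mask])) / n
            gain = parent_var - child_var
            if gain > best_gain:
                best_feature, best_threshold, best_gain = f, thr, gain
    return best_feature, best_threshold

def build(X, y, depth=0, max_depth=3, min_samples_split=2):
    """Recursively grow a regression tree stored as nested dicts."""
    if depth < max_depth and len(y) >= min_samples_split:
        f, thr = best_split(X, y)
        if f is not None:
            mask = X[:, f] <= thr
            return {
                "feature": f,
                "threshold": float(thr),
                "left": build(X[mask], y[mask], depth + 1, max_depth, min_samples_split),
                "right": build(X[~mask], y[~mask], depth + 1, max_depth, min_samples_split),
            }
    # Leaf: predict the mean target of the samples that reached it
    return {"prediction": float(np.mean(y))}

def predict_one(node, x):
    """Follow the thresholds from the root down to a leaf."""
    while "prediction" not in node:
        node = node["left"] if x[node["feature"]] <= node["threshold"] else node["right"]
    return node["prediction"]

# Usage: a step function is recovered exactly by a single split
X = np.linspace(0, 10, 50).reshape(-1, 1)
y = np.where(X[:, 0] < 5, 1.0, 3.0)
tree = build(X, y, max_depth=1)
print(tree["threshold"], predict_one(tree, [2.0]), predict_one(tree, [8.0]))
```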

Using Libraries (scikit-learn, numpy, matplotlib)

from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor, export_text, plot_tree
from sklearn.model_selection import train_test_split, GridSearchCV, cross_val_score
from sklearn.datasets import load_iris, fetch_california_housing
from sklearn.metrics import accuracy_score, mean_squared_error, classification_report
import numpy as np
import matplotlib.pyplot as plt

# CLASSIFICATION EXAMPLE with Iris dataset
print('=== DECISION TREE CLASSIFICATION ===')
iris = load_iris()
X_cls, y_cls = iris.data, iris.target

X_train_c, X_test_c, y_train_c, y_test_c = train_test_split(
    X_cls, y_cls, test_size=0.2, random_state=42, stratify=y_cls
)

# Train with pruning parameters
clf = DecisionTreeClassifier(
    criterion='gini',          # 'gini' or 'entropy'
    max_depth=4,               # Limit tree depth
    min_samples_split=5,       # Min samples to split
    min_samples_leaf=2,        # Min samples in leaf
    random_state=42
)
clf.fit(X_train_c, y_train_c)

# Evaluate
y_pred_c = clf.predict(X_test_c)
print(f'Accuracy: {accuracy_score(y_test_c, y_pred_c):.3f}')
print('\nFeature Importances:')
for name, importance in zip(iris.feature_names, clf.feature_importances_):
    print(f'  {name}: {importance:.3f}')

# Visualize tree rules
print(f'\nTree Rules:\n{export_text(clf, feature_names=list(iris.feature_names))}')

# REGRESSION EXAMPLE with California Housing
print('\n=== DECISION TREE REGRESSION ===')
housing = fetch_california_housing()
X_reg, y_reg = housing.data[:1000], housing.target[:1000]  # First 1,000 rows for speed (not a random sample)

X_train_r, X_test_r, y_train_r, y_test_r = train_test_split(
    X_reg, y_reg, test_size=0.2, random_state=42
)

reg = DecisionTreeRegressor(
    max_depth=6,
    min_samples_split=10,
    min_samples_leaf=5,
    random_state=42
)
reg.fit(X_train_r, y_train_r)

y_pred_r = reg.predict(X_test_r)
print(f'R² Score: {reg.score(X_test_r, y_test_r):.3f}')
print(f'RMSE: {np.sqrt(mean_squared_error(y_test_r, y_pred_r)):.3f}')

# HYPERPARAMETER TUNING
print('\n=== HYPERPARAMETER TUNING ===')
param_grid = {
    'max_depth': [3, 5, 7, 10, None],
    'min_samples_split': [2, 5, 10],
    'min_samples_leaf': [1, 2, 4],
    'criterion': ['gini', 'entropy']
}

grid_search = GridSearchCV(
    DecisionTreeClassifier(random_state=42),
    param_grid,
    cv=5,
    scoring='accuracy',
    n_jobs=-1
)
grid_search.fit(X_train_c, y_train_c)

print(f'Best parameters: {grid_search.best_params_}')
print(f'Best CV accuracy: {grid_search.best_score_:.3f}')
print(f'Test accuracy: {accuracy_score(y_test_c, grid_search.predict(X_test_c)):.3f}')
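The imports above include `plot_tree` and matplotlib. The sketch below uses them to draw a depth-limited iris tree. It assumes a non-interactive environment, hence the Agg backend and `savefig`. The output filename is arbitrary.

```python
import matplotlib
matplotlib.use("Agg")  # render off-screen; drop this line in a notebook
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, plot_tree

iris = load_iris()
clf = DecisionTreeClassifier(max_depth=3, random_state=42).fit(iris.data, iris.target)

fig, ax = plt.subplots(figsize=(12, 6))
plot_tree(
    clf,
    feature_names=iris.feature_names,
    class_names=list(iris.target_names),
    filled=True,     # color nodes by majority class
    rounded=True,
    impurity=True,   # show the Gini value in each node
    ax=ax,
)
fig.savefig("iris_tree.png", dpi=150, bbox_inches="tight")
plt.close(fig)
```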

When to Use

✅ Appropriate Use Cases:

  • When interpretability is critical: you can explain exactly why a prediction was made
  • Small to medium datasets where model simplicity is valued
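  • Mixed data types: splits depend only on the ordering of values, so no feature scaling is needed. The algorithm can split on categorical features, though scikit-learn's implementation requires them to be numerically encoded first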
  • Feature selection: tree-based importance helps identify predictive features
  • Baseline model before trying ensemble methods
  • When decision rules need to be explicit (regulatory compliance)
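Both parts of the mixed-data point can be checked with the sketch below, which assumes scikit-learn. Tree splits depend only on the order of feature values, so rescaling a feature should not change the tree. String categories, however, still need a numeric encoding such as `OrdinalEncoder`. That encoder imposes an arbitrary order the tree can split on; one-hot encoding is the common alternative. The `color` column is a hypothetical example.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.preprocessing import OrdinalEncoder
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)

# 1) No feature scaling needed: splits depend only on the ordering of values,
#    so multiplying every feature by a positive constant yields the same partition.
raw = DecisionTreeClassifier(max_depth=3, random_state=0).fit(X, y)
scaled = DecisionTreeClassifier(max_depth=3, random_state=0).fit(X * 1000, y)
same_predictions = np.array_equal(raw.predict(X), scaled.predict(X * 1000))
print("Identical predictions after rescaling:", same_predictions)

# 2) Categorical strings must still be encoded as numbers for scikit-learn.
#    The 'color' column and its values are hypothetical.
color = np.array([["red"], ["green"], ["blue"], ["green"]])
encoded = OrdinalEncoder().fit_transform(color)
print(encoded.ravel())  # categories are sorted: blue=0, green=1, red=2
```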

❌ Avoid When:

  • Very large, complex datasets: fitting them well requires a very deep tree, which gives up the interpretability that is the main reason to choose a single tree
  • When prediction accuracy is paramount (use Random Forest or XGBoost instead)
  • High-dimensional sparse data (text, genomics): trees struggle with many features
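  • When you need probability estimates: predict_proba returns the class fractions in a leaf. These are coarse (exactly 0 or 1 in a fully grown tree) and often poorly calibrated
  • Extrapolation tasks: a regression tree predicts the mean target of a leaf, so it cannot output values outside the range of the training targets
  • Imbalanced datasets, unless class weights or resampling are tuned carefully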
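Two of the limits above can be seen in a few lines. The sketch below uses synthetic data and assumes scikit-learn. A regression tree trained on y = 2x for x in [0, 10] should return the same bounded value for any larger x. A fully grown classifier's `predict_proba` on its own training data should contain only 0 and 1. Calibration itself is not measured here; `sklearn.calibration.CalibratedClassifierCV` is one option when calibrated probabilities are needed.

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor

rng = np.random.default_rng(0)

# Extrapolation: learn y = 2x on x in [0, 10], then query x = 20 and x = 100
X_train = rng.uniform(0, 10, size=(200, 1))
y_train = 2 * X_train[:, 0]
reg = DecisionTreeRegressor(random_state=0).fit(X_train, y_train)
far = reg.predict([[20.0], [100.0]])
print(far)  # both land in the rightmost leaf, so neither exceeds max(y_train) < 20

# Coarse probabilities: a fully grown tree has pure leaves, so on its
# training data predict_proba returns only 0s and 1s
X_cls = rng.normal(size=(200, 2))
y_cls = (X_cls[:, 0] + X_cls[:, 1] > 0).astype(int)
clf = DecisionTreeClassifier(random_state=0).fit(X_cls, y_cls)
print(np.unique(clf.predict_proba(X_cls)))
```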

Common Pitfalls

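  • Overfitting: An unrestricted tree keeps splitting until its leaves are pure, effectively memorizing the training data
  • Instability: Small changes in the data can produce a completely different tree
  • Bias toward high-cardinality features: Features with many unique values offer more candidate thresholds, so impurity-based split selection and feature importances tend to favor them
  • Greedy splits: Locally optimal choices may not be globally optimal
  • Axis-aligned splits: Each split tests a single feature, so a diagonal decision boundary must be approximated by a staircase of many splits
  • Ignoring class weights: With skewed classes, consider class_weight='balanced' (or resampling) so the minority class is not ignored
  • Not pruning: Pre-pruning (max_depth, min_samples_leaf) or post-pruning (cost-complexity pruning via ccp_alpha) is usually needed to keep the tree from overfitting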
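Post-pruning is the one remedy that the code earlier in this article does not show. The sketch below uses scikit-learn's minimal cost-complexity pruning. `cost_complexity_pruning_path` lists the effective alphas at which subtrees are pruned away, and cross-validation on the training split picks one. The breast cancer dataset and the 5-fold CV choice are illustrative assumptions. The resulting alpha and accuracies will vary with the split.

```python
import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import cross_val_score, train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, stratify=y, random_state=0)

# Effective alphas at which subtrees get pruned away, weakest link first
path = DecisionTreeClassifier(random_state=0).cost_complexity_pruning_path(X_train, y_train)
alphas = path.ccp_alphas[:-1]  # the last alpha prunes the tree down to the root

# Pick the alpha with the best 5-fold CV accuracy on the training set
cv_scores = [
    cross_val_score(DecisionTreeClassifier(random_state=0, ccp_alpha=a),
                    X_train, y_train, cv=5).mean()
    for a in alphas
]
best_alpha = alphas[int(np.argmax(cv_scores))]

full = DecisionTreeClassifier(random_state=0).fit(X_train, y_train)
pruned = DecisionTreeClassifier(random_state=0, ccp_alpha=best_alpha).fit(X_train, y_train)
print(f"alpha={best_alpha:.5f}")
print(f"leaves: {full.get_n_leaves()} -> {pruned.get_n_leaves()}")
print(f"test accuracy: {full.score(X_test, y_test):.3f} -> {pruned.score(X_test, y_test):.3f}")
```

Larger alphas remove more subtrees. Choosing alpha by cross-validation, rather than on the test split, keeps the final test accuracy an honest estimate.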