Interpretability for Tree Ensembles

Biostat 212A

Author

Dr. Jin Zhou @ UCLA

Published

March 7, 2026

Credit: This note builds on tree/ensemble material from An Introduction to Statistical Learning: with Applications in R (ISLR2) and The Elements of Statistical Learning (ESL), and introduces modern interpretability tools, including permutation importance and SHAP.

Display system information for reproducibility.

sessionInfo()
R version 4.5.2 (2025-10-31)
Platform: aarch64-apple-darwin20
Running under: macOS Sequoia 15.7.4

Matrix products: default
BLAS:   /System/Library/Frameworks/Accelerate.framework/Versions/A/Frameworks/vecLib.framework/Versions/A/libBLAS.dylib 
LAPACK: /Library/Frameworks/R.framework/Versions/4.5-arm64/Resources/lib/libRlapack.dylib;  LAPACK version 3.12.1

locale:
[1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8

time zone: America/Los_Angeles
tzcode source: internal

attached base packages:
[1] stats     graphics  grDevices utils     datasets  methods   base     

loaded via a namespace (and not attached):
 [1] htmlwidgets_1.6.4 compiler_4.5.2    fastmap_1.2.0     cli_3.6.5        
 [5] tools_4.5.2       htmltools_0.5.8.1 rstudioapi_0.17.1 yaml_2.3.10      
 [9] rmarkdown_2.29    knitr_1.50        jsonlite_2.0.0    xfun_0.56        
[13] digest_0.6.37     rlang_1.1.6       evaluate_1.0.5   
import IPython
print(IPython.sys_info())
{'commit_hash': '5ed988a91',
 'commit_source': 'installation',
 'default_encoding': 'utf-8',
 'ipython_path': '/Users/jinjinzhou/.virtualenvs/r-tensorflow/lib/python3.10/site-packages/IPython',
 'ipython_version': '8.33.0',
 'os_name': 'posix',
 'platform': 'macOS-15.7.4-arm64-arm-64bit',
 'sys_executable': '/Users/jinjinzhou/.virtualenvs/r-tensorflow/bin/python',
 'sys_platform': 'darwin',
 'sys_version': '3.10.16 (main, Mar  3 2025, 20:01:33) [Clang 16.0.0 '
                '(clang-1600.0.26.6)]'}

1 The problem: ensembles are black boxes

A single decision tree is readable — you can trace the splitting rules and explain any prediction in a few steps.

But the models that win in practice are ensembles (random forests, boosting), which combine hundreds of trees. A single tree says: “If CRuns > 300 and CRBI > 200, predict high salary.” A 500-tree random forest just says: “The predicted salary is $842K” — but why?

We need tools to look inside. Two questions drive everything that follows:

  1. Which variables matter? (variable importance)
  2. How does each variable affect a specific prediction? (SHAP)
Note: Importance \(\neq\) causality

A variable can be “important” because it correlates with the outcome, correlates with a true cause, or even due to data leakage. Importance answers: “Does the model rely on this variable?” — not “Does this variable cause the outcome?”

2 Variable importance: two flavors

2.1 Split / impurity importance (built into trees)

Every time a tree splits on feature \(j\), the impurity (RSS, Gini, entropy) decreases by some amount \(\Delta I(t)\). Sum those decreases over all splits on feature \(j\), across all trees:

\[ VI_j = \sum_{t:\, \text{split variable}=j} \Delta I(t). \]

This is computed for free during training and available in every tree library. But it can be biased toward high-cardinality features and tends to split credit unpredictably among correlated predictors.
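The bookkeeping behind \(VI_j\) can be checked directly against scikit-learn, which exposes each fitted tree's internals. The sketch below (synthetic data, so the specific numbers are illustrative) re-derives `feature_importances_` for a single regression tree by summing the sample-weighted impurity decreases over all splits on each feature:

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor

rng = np.random.default_rng(0)
X = rng.normal(size=(300, 2))
y = 3 * X[:, 0] + rng.normal(scale=0.5, size=300)   # only feature 0 carries signal

tree = DecisionTreeRegressor(max_depth=3, random_state=0).fit(X, y)
t = tree.tree_

# VI_j: sum the (sample-weighted) impurity decrease over every split on feature j
vi = np.zeros(X.shape[1])
for node in range(t.node_count):
    left, right = t.children_left[node], t.children_right[node]
    if left == -1:            # leaf node: no split, nothing to add
        continue
    decrease = (t.weighted_n_node_samples[node] * t.impurity[node]
                - t.weighted_n_node_samples[left] * t.impurity[left]
                - t.weighted_n_node_samples[right] * t.impurity[right])
    vi[t.feature[node]] += decrease

vi /= vi.sum()                # sklearn normalizes importances to sum to 1
print(vi, tree.feature_importances_)   # the two vectors agree
```

For a forest, the same quantity is simply averaged over the trees, which is what `ranger` and `RandomForestRegressor` report.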

2.2 Permutation importance (model-agnostic)

A cleaner idea: measure how much the model’s performance drops when you destroy a feature’s information.

  1. Score the model on held-out data → baseline metric \(m_0\).
  2. Randomly shuffle feature \(j\) (break its link with \(y\)).
  3. Re-score → \(m_j\).
  4. Importance \(= m_0 - m_j\).

This works for any model and any metric. The main caveat: with correlated features, shuffling \(X_j\) may not hurt much because \(X_k\) substitutes for it.
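The four steps above fit in a few lines. A from-scratch sketch on synthetic data (feature 0 strong, feature 1 weak, feature 2 pure noise; R² as the metric, all choices illustrative):

```python
import numpy as np
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import r2_score
from sklearn.model_selection import train_test_split

rng = np.random.default_rng(0)
X = rng.normal(size=(600, 3))
y = 3 * X[:, 0] + 0.5 * X[:, 1] + rng.normal(scale=0.1, size=600)  # X[:, 2] is noise

X_tr, X_te, y_tr, y_te = train_test_split(X, y, random_state=0)
model = RandomForestRegressor(n_estimators=200, random_state=0).fit(X_tr, y_tr)

m0 = r2_score(y_te, model.predict(X_te))               # step 1: baseline on held-out data
importances = []
for j in range(X.shape[1]):
    drops = []
    for _ in range(10):                                # repeat shuffles to tame noise
        X_perm = X_te.copy()
        X_perm[:, j] = rng.permutation(X_perm[:, j])   # step 2: break X_j's link with y
        m_j = r2_score(y_te, model.predict(X_perm))    # step 3: re-score
        drops.append(m0 - m_j)                         # step 4: importance = m0 - m_j
    importances.append(np.mean(drops))

print(np.round(importances, 3))   # large for X0, small for X1, near 0 for X2
```

`sklearn.inspection.permutation_importance` (used below) does exactly this, with `n_repeats` controlling the number of shuffles.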

Warning: The correlated-features trap

If \(X_1\) and \(X_2\) carry nearly the same information, split importance shares credit roughly 50/50 (neither looks dominant), while permutation importance underestimates both (each substitutes for the other). No importance method fully resolves this ambiguity — keep a correlation matrix handy.
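To see the substitution effect concretely, here is a synthetic sketch (illustrative data, not Hitters): one forest is trained on a single informative feature, another on two noisy copies of the same feature. Shuffling one copy typically hurts much less than shuffling the lone feature, because the other copy fills in:

```python
import numpy as np
from sklearn.ensemble import RandomForestRegressor
from sklearn.inspection import permutation_importance

rng = np.random.default_rng(1)
n = 800
signal = rng.normal(size=n)
y = 2 * signal + rng.normal(scale=0.1, size=n)

X_solo = signal.reshape(-1, 1)                        # one copy of the signal
X_dup = np.column_stack([signal + rng.normal(scale=0.05, size=n),
                         signal + rng.normal(scale=0.05, size=n)])  # two noisy copies

rf_solo = RandomForestRegressor(n_estimators=200, random_state=0).fit(X_solo, y)
rf_dup  = RandomForestRegressor(n_estimators=200, random_state=0).fit(X_dup, y)

imp_solo = permutation_importance(rf_solo, X_solo, y, n_repeats=5,
                                  random_state=0).importances_mean
imp_dup  = permutation_importance(rf_dup, X_dup, y, n_repeats=5,
                                  random_state=0).importances_mean

print(imp_solo, imp_dup)   # the copies' importances understate the shared signal
```

Neither copy's importance reflects how much the underlying signal actually matters, which is exactly the trap.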

2.3 Worked example: Hitters salary data

library(tidymodels)
library(vip)
library(ISLR2)

Hitters2 <- Hitters %>% drop_na(Salary)

set.seed(1)
hitters_split <- initial_split(Hitters2, prop = 0.7)
hitters_train <- training(hitters_split)
hitters_test  <- testing(hitters_split)

# Two RF fits: one tracks impurity importance, one tracks permutation importance
rf_impurity_spec <-
  rand_forest(trees = 1000) %>%
  set_engine("ranger", importance = "impurity") %>%
  set_mode("regression")

rf_perm_spec <-
  rand_forest(trees = 1000) %>%
  set_engine("ranger", importance = "permutation") %>%
  set_mode("regression")

rf_impurity_fit <- workflow() %>% add_model(rf_impurity_spec) %>%
  add_formula(Salary ~ .) %>% fit(data = hitters_train)
rf_perm_fit <- workflow() %>% add_model(rf_perm_spec) %>%
  add_formula(Salary ~ .) %>% fit(data = hitters_train)

vip(extract_fit_engine(rf_impurity_fit), num_features = 10) +
  ggtitle("Split/impurity importance")

vip(extract_fit_engine(rf_perm_fit), num_features = 10) +
  ggtitle("Permutation importance (OOB)")

import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OneHotEncoder
from sklearn.impute import SimpleImputer
from sklearn.ensemble import RandomForestRegressor
from sklearn.inspection import permutation_importance

Hitters = pd.read_csv("../data/Hitters.csv")
Hitters2 = Hitters.dropna(subset=["Salary"]).copy()

y = Hitters2["Salary"].values
X = Hitters2.drop(columns=["Salary"])

cat_cols = X.select_dtypes(include=["object", "category"]).columns.tolist()
num_cols = [c for c in X.columns if c not in cat_cols]

try:
    ohe = OneHotEncoder(handle_unknown="ignore", sparse_output=False)
except TypeError:
    ohe = OneHotEncoder(handle_unknown="ignore", sparse=False)

preprocess = ColumnTransformer(
    transformers=[
        ("num", Pipeline([("impute", SimpleImputer(strategy="median"))]), num_cols),
        ("cat", Pipeline([
            ("impute", SimpleImputer(strategy="most_frequent")),
            ("onehot", ohe)
        ]), cat_cols),
    ],
    remainder="drop",
    verbose_feature_names_out=False
)

rf = RandomForestRegressor(n_estimators=1000, random_state=1, n_jobs=-1)
pipe = Pipeline(steps=[("prep", preprocess), ("model", rf)])

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=1
)
pipe.fit(X_train, y_train)
Pipeline(steps=[('prep',
                 ColumnTransformer(transformers=[('num',
                                                  Pipeline(steps=[('impute',
                                                                   SimpleImputer(strategy='median'))]),
                                                  ['AtBat', 'Hits', 'HmRun',
                                                   'Runs', 'RBI', 'Walks',
                                                   'Years', 'CAtBat', 'CHits',
                                                   'CHmRun', 'CRuns', 'CRBI',
                                                   'CWalks', 'PutOuts',
                                                   'Assists', 'Errors']),
                                                 ('cat',
                                                  Pipeline(steps=[('impute',
                                                                   SimpleImputer(strategy='most_frequent')),
                                                                  ('onehot',
                                                                   OneHotEncoder(handle_unknown='ignore',
                                                                                 sparse_output=False))]),
                                                  ['League', 'Division',
                                                   'NewLeague'])],
                                   verbose_feature_names_out=False)),
                ('model',
                 RandomForestRegressor(n_estimators=1000, n_jobs=-1,
                                       random_state=1))])
# Permutation importance on held-out data
perm = permutation_importance(pipe, X_test, y_test, n_repeats=10,
                              random_state=1, n_jobs=-1)

feat_names = pipe.named_steps["prep"].get_feature_names_out()
top_idx = np.argsort(perm.importances_mean)[::-1][:10]
list(zip(feat_names[top_idx], np.round(perm.importances_mean[top_idx], 2)))
[('CRBI', np.float64(0.11)), ('CHits', np.float64(0.1)), ('RBI', np.float64(0.04)), ('CAtBat', np.float64(0.03)), ('CHmRun', np.float64(0.03)), ('Walks', np.float64(0.03)), ('CRuns', np.float64(0.02)), ('AtBat', np.float64(0.01)), ('Assists', np.float64(0.0)), ('CWalks', np.float64(0.0))]

3 SHAP: local explanations that add up

3.1 The idea

Permutation importance gives one number per feature (global). SHAP gives one number per feature per observation (local). For a single prediction:

\[ \hat f(x) = \underbrace{\phi_0}_{\text{baseline}} + \phi_1 + \phi_2 + \cdots + \phi_p. \]

Each \(\phi_j\) says: “Feature \(j\) pushed this prediction up (or down) by this much.”

SHAP values come from Shapley values in game theory — a principled way to fairly divide credit among players in a cooperative game.

Group project analogy. Three students (Alice, Bob, Carla) earn a group grade. Different subsets would have earned different grades:

| Coalition           | Grade |
|---------------------|-------|
| nobody              | 50    |
| {Alice}             | 70    |
| {Bob}               | 60    |
| {Alice, Bob}        | 85    |
| {Alice, Bob, Carla} | 92    |

To fairly credit Alice: look at every possible ordering the students could have joined; record Alice’s marginal contribution (grade jump when she joins); average across all orderings. That average is her Shapley value.
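This averaging over orderings is easy to brute-force for three players. The table above omits three coalitions, so the grades for {Carla}, {Alice, Carla}, and {Bob, Carla} below are hypothetical fill-ins; everything else matches the table:

```python
from itertools import permutations

# Coalition "grades"; values marked hypothetical are not in the table above
v = {
    frozenset(): 50,
    frozenset({"Alice"}): 70,
    frozenset({"Bob"}): 60,
    frozenset({"Carla"}): 55,                 # hypothetical
    frozenset({"Alice", "Bob"}): 85,
    frozenset({"Alice", "Carla"}): 75,        # hypothetical
    frozenset({"Bob", "Carla"}): 65,          # hypothetical
    frozenset({"Alice", "Bob", "Carla"}): 92,
}

players = ["Alice", "Bob", "Carla"]
phi = {p: 0.0 for p in players}
orderings = list(permutations(players))
for order in orderings:
    joined = frozenset()
    for p in order:
        # marginal contribution: grade jump when p joins the current coalition
        phi[p] += v[joined | {p}] - v[joined]
        joined = joined | {p}
for p in phi:
    phi[p] /= len(orderings)   # average over all 3! = 6 orderings

print(phi)
print(sum(phi.values()))   # Shapley values sum to the total above baseline: 92 - 50 = 42
```

The sum-to-total property in the last line is the "efficiency" axiom, and it is precisely what makes the SHAP decomposition \(\hat f(x) = \phi_0 + \sum_j \phi_j\) exact.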

In ML: players = features, grade = model prediction. For tree ensembles, an algorithm called TreeSHAP computes this efficiently without enumerating all \(2^p\) subsets.

3.2 Key SHAP plots

  • Beeswarm: each dot = one observation; \(x\)-axis = SHAP value; color = feature value. Shows which features matter and the direction of their effect.
  • Waterfall: one observation explained step-by-step from baseline to final prediction.

3.3 Worked example (boosted trees → SHAP)

library(tidymodels)
library(shapviz)

set.seed(1)

hitters_xgb_rec <-
  recipe(Salary ~ ., data = hitters_train) %>%
  step_dummy(all_nominal_predictors()) %>%
  step_zv(all_predictors())

hitters_xgb_prep <- prep(hitters_xgb_rec)
p <- juice(hitters_xgb_prep) %>% select(-Salary) %>% ncol()

xgb_spec <-
  boost_tree(trees = 500, tree_depth = 3, learn_rate = 0.05,
             sample_size = 0.8, mtry = max(1, floor(0.8 * p))) %>%
  set_engine("xgboost") %>%
  set_mode("regression")

xgb_fit <- workflow() %>% add_recipe(hitters_xgb_rec) %>%
  add_model(xgb_spec) %>% fit(data = hitters_train)

xgb_engine <- extract_fit_engine(xgb_fit)
X_train_mat <- bake(hitters_xgb_prep, new_data = hitters_train) %>%
  select(-Salary) %>% as.matrix()

sv <- shapviz(xgb_engine, X_pred = X_train_mat)

# Global: which features matter most?
sv_importance(sv, kind = "beeswarm")

# Local: explain one player's predicted salary
sv_waterfall(sv, row_id = 1)

try:
    import shap
except ImportError:
    import sys, subprocess
    subprocess.check_call([sys.executable, "-m", "pip", "install", "shap"])
    import shap

from sklearn.ensemble import GradientBoostingRegressor

gbr = GradientBoostingRegressor(random_state=1)
pipe_gbr = Pipeline(steps=[("preprocess", preprocess), ("model", gbr)])
pipe_gbr.fit(X_train, y_train)
Pipeline(steps=[('preprocess',
                 ColumnTransformer(transformers=[('num',
                                                  Pipeline(steps=[('impute',
                                                                   SimpleImputer(strategy='median'))]),
                                                  ['AtBat', 'Hits', 'HmRun',
                                                   'Runs', 'RBI', 'Walks',
                                                   'Years', 'CAtBat', 'CHits',
                                                   'CHmRun', 'CRuns', 'CRBI',
                                                   'CWalks', 'PutOuts',
                                                   'Assists', 'Errors']),
                                                 ('cat',
                                                  Pipeline(steps=[('impute',
                                                                   SimpleImputer(strategy='most_frequent')),
                                                                  ('onehot',
                                                                   OneHotEncoder(handle_unknown='ignore',
                                                                                 sparse_output=False))]),
                                                  ['League', 'Division',
                                                   'NewLeague'])],
                                   verbose_feature_names_out=False)),
                ('model', GradientBoostingRegressor(random_state=1))])
# Extract transformed features for SHAP
X_train_trans = pipe_gbr.named_steps["preprocess"].transform(X_train)
X_test_trans  = pipe_gbr.named_steps["preprocess"].transform(X_test)
feat_names = pipe_gbr.named_steps["preprocess"].get_feature_names_out()

X_train_shap = pd.DataFrame(X_train_trans, columns=feat_names)
X_test_shap  = pd.DataFrame(X_test_trans,  columns=feat_names)

model = pipe_gbr.named_steps["model"]

# TreeSHAP: background = training data, explain = test observations
explainer = shap.Explainer(model, X_train_shap)
shap_values = explainer(X_test_shap)

# Global: beeswarm
shap.plots.beeswarm(shap_values, max_display=10)

# Local: waterfall for one test observation
shap.plots.waterfall(shap_values[0])

4 Quick reference

|                    | Split / Gain       | Permutation                   | SHAP                                     |
|--------------------|--------------------|-------------------------------|------------------------------------------|
| What it gives      | Importance ranking | Importance ranking            | Local + global explanations              |
| Global or local?   | Global             | Global                        | Both                                     |
| Model-agnostic?    | Trees only         | Any model                     | Practical for trees; possible for others |
| Correlation caveat | Shares credit      | Underestimates (substitution) | Shares credit                            |
| Speed              | Free               | Moderate                      | Fast for trees (TreeSHAP)                |