Supervised learning: decision trees


SURE 2024

Department of Statistics & Data Science
Carnegie Mellon University

Background

Tree-based methods

  • Can be applied to both regression and classification problems
  • A single decision tree is simple and useful for interpretation

    • Main idea: stratify/segment the predictor space into a number of simple regions

    • The set of splitting rules used to segment the predictor space can be summarized in a tree

  • Bagging, random forests, and boosting: grow multiple trees which are then combined to yield a single consensus prediction

  • Combining a large number of trees can result in great improvements in prediction accuracy, at the expense of some loss of interpretability

Predictor space

  • Two predictor variables with binary response variable

  • Left: Linear boundaries that form rectangles will perform well in predicting the response

  • Right: Circular boundaries will perform better

Regression trees

Example: baseball salary

Regression trees

Fit a regression tree to predict the salary of a baseball player using

  • Number of years they played in the major leagues

  • Number of hits they made in the previous year
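
The slides do not show the code for this fit. Below is a minimal sketch using rpart, assuming the data are the Hitters data from the ISLR2 package (an assumption; the slides do not name the data source).

library(ISLR2)  # assumed source of the Hitters data
library(rpart)

# drop players with missing Salary
hitters <- na.omit(Hitters)

# regression tree for Salary (in thousands of dollars) using Years and Hits
salary_tree <- rpart(Salary ~ Years + Hits, data = hitters)
salary_tree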

Regression trees

  • At each internal node, the label (e.g., \(X_j < t_k\)) indicates the condition for taking the left branch from that split; the right branch corresponds to the opposite condition, e.g. \(X_j \geq t_k\)

  • The first internal node indicates that players on the left branch have fewer than 4.5 years in the major leagues, while those on the right have \(\geq\) 4.5 years

  • The number at the top of each node indicates the predicted Salary; for example, before any splitting, the average Salary for the whole dataset is 536 thousand dollars

  • This tree has two internal nodes and three terminal nodes

Plotting a regression tree
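
One way to produce a plot like the one on this slide is with the rpart.plot package (also used in the examples later). A minimal sketch, assuming the salary_tree fit from the sketch above:

library(rpart.plot)
# each node displays the predicted Salary and the fraction of observations it contains
rpart.plot(salary_tree)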

Regression trees: partitioning the feature space


Terminology

  • The final regions \(R_1\), \(R_2\), and \(R_3\) are called terminal nodes (or leaves)

  • Viewing the tree upside down, the leaves are at the bottom of the tree

  • The splits are called internal nodes

Interpretation of results

  • Years is the most important factor in determining Salary

    • Players with less experience earn lower salaries
  • Given that a player is less experienced, the number of Hits seems to play little role in determining their Salary
  • Among players who have been in the major leagues for 4.5 years or more, the number of Hits made in the previous year does affect Salary

    • Players with more Hits tend to have higher salaries
  • This is probably an oversimplification, but it is very easy to interpret

Decision tree: a more complex example

Decision tree structure


Predict the response value for an observation by following its path along the tree

(See previous baseball salary example)

  • Decision trees are very easy to explain to non-statisticians.

  • Easy to visualize and thus easy to interpret without assuming a parametric form

Tree-building process: the big picture

  • Divide the training data into distinct and non-overlapping regions

    • The regions are found using recursive binary splitting (i.e. asking a series of yes/no questions about the predictors)

    • Stop splitting the tree once a stopping criterion has been reached (e.g. maximum depth allowed)

  • For a given region, make the same prediction for all observations in that region (see the sketch after this list)

    • Regression tree: the average of the response values in the node

    • Classification tree: the most popular class in the node

  • Most popular algorithm: Classification and Regression Tree (CART)
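
As a quick illustration of the two prediction rules above, here is a minimal sketch for a single node; the response values and class labels are made up for illustration.

# response values for the training observations falling in one region (hypothetical)
y_reg   <- c(450, 700, 615, 520)                 # numeric response (regression)
y_class <- c("HR", "no HR", "no HR", "no HR")    # class labels (classification)

# regression tree: predict the average response in the node
mean(y_reg)                         # 571.25

# classification tree: predict the most popular class in the node
names(which.max(table(y_class)))    # "no HR"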

Tree-building process: more details

  • Divide the predictor space into high-dimensional rectangles (or boxes) - for simplicity and ease of interpretation
  • Goal: find regions \(R_1, \dots, R_J\) that minimize the RSS

\[ \sum_{j=1}^J \sum_{i \in R_j} (y_i - \hat y_{R_j})^2 \]

where \(\hat y_{R_j}\) is the mean response for the training observations within the \(j\)th region

  • Challenge: it is computationally infeasible to consider every possible partition of the feature space into \(J\) regions
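
As a sketch, the RSS objective above is easy to compute for any fixed assignment of observations to regions; the hard part is searching over partitions. The data and region labels below are hypothetical.

# RSS of a partition: sum of squared deviations from each region's mean
partition_rss <- function(y, region) {
  sum(tapply(y, region, function(y_j) sum((y_j - mean(y_j))^2)))
}

y      <- c(5, 7, 6, 20, 22, 21)   # hypothetical responses
region <- c(1, 1, 1, 2, 2, 2)      # hypothetical region assignments (J = 2)
partition_rss(y, region)           # 4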

Tree-building process: more details

Solution: recursive binary splitting (a top-down, greedy approach)

  • top-down

    • begin at the top of the tree

    • then successively split the predictor space

    • each split is indicated via two new branches further down on the tree

  • greedy

    • at each step of the tree-building process, the best split is made at that particular step

    • rather than looking ahead and picking a split that will lead to a better tree in some future step
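
A minimal sketch of one greedy step for a single predictor x: scan the candidate cutpoints s and keep the one whose split into \(\{x < s\}\) and \(\{x \ge s\}\) gives the smallest RSS. A full implementation would also loop over the predictors \(X_1, \dots, X_p\); the function name and data below are made up.

# one greedy step: best cutpoint s for a single predictor x
best_split <- function(x, y) {
  cutpoints <- sort(unique(x))
  rss <- sapply(cutpoints, function(s) {
    left  <- y[x < s]
    right <- y[x >= s]
    sum((left - mean(left))^2) + sum((right - mean(right))^2)
  })
  cutpoints[which.min(rss)]
}

x <- c(1, 2, 3, 4, 5, 6, 7, 8)       # hypothetical predictor
y <- c(5, 6, 5, 6, 20, 21, 20, 21)   # response jumps once x reaches 5
best_split(x, y)                     # 5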

Tree-building process: more details

  • First, select predictor \(X_j\) and cutpoint \(s\) such that

    • splitting the predictor space into the regions \(\{X \mid X_j < s\}\) and \(\{X \mid X_j \ge s\}\) will yield the greatest possible reduction in RSS
  • Next, repeat the process

    • looking for the best predictor and best cutpoint in order to split the data further

    • so as to minimize the RSS within each of the resulting regions

    • BUT… split one of the previously identified regions (instead of the entire predictor space)

  • Continue the process until a stopping criterion is reached (i.e. how complex should the tree be?)

    • maximum tree depth

    • minimum node size
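
In rpart (the CART implementation used in the examples later), stopping criteria like these are set through rpart.control. A sketch reusing the hitters data frame from the earlier sketch (assumed to be the ISLR2 Hitters data); the specific values are only for illustration.

library(rpart)
ctrl <- rpart.control(maxdepth = 4,    # maximum tree depth
                      minsplit = 20,   # minimum node size required to attempt a split
                      cp = 0.01)       # minimum relative improvement required for a split
salary_tree_shallow <- rpart(Salary ~ Years + Hits, data = hitters, control = ctrl)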

Pruning a tree

  • The process described above may produce good training performance but poor test set performance (i.e. overfitting)
  • Solution: tree pruning

    • Grow a very large complicated tree

    • Then prune back to an optimal subtree

  • Tuning parameter: cost complexity parameter \(\alpha\)

    • Minimize \[\text{RSS} + \alpha |T|\] where \(|T|\) is the number of terminal nodes of the tree \(T\)

    • Controls a trade-off between the subtree’s complexity and its fit to the training data

    • How do we select the optimal value?
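
In rpart, the cost complexity parameter is called cp, and the fitted object already stores a cross-validated error (xerror) for each candidate value. A common approach, sketched below for the salary_tree fit from the earlier sketch, is to prune back to the cp with the smallest cross-validated error.

printcp(salary_tree)   # cross-validated error for each candidate cp

# pick the cp with the smallest cross-validated error, then prune
cp_table <- salary_tree$cptable
best_cp  <- cp_table[which.min(cp_table[, "xerror"]), "CP"]
salary_tree_pruned <- prune(salary_tree, cp = best_cp)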


Classification trees

  • Predict that each observation belongs to the most commonly occurring class in the region to which it belongs
  • Just like regression trees, use recursive binary splitting to grow a classification tree
  • Instead of RSS, use the Gini index \[G = \sum_{k=1}^K \hat p_{jk} (1 - \hat p_{jk})\] where \(\hat p_{jk}\) is the proportion of observations in the \(j\)th region that are from the \(k\)th class

    • A measure of total variance across the \(K\) classes

    • A measure of node purity (or node impurity)

      • small value: a node contains mostly observations from a single class
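
A small sketch of the Gini index computed from a node's class proportions; the proportions below are hypothetical.

# Gini index of a node, given the vector of class proportions p_hat
gini_index <- function(p_hat) {
  sum(p_hat * (1 - p_hat))
}

gini_index(c(0.5, 0.5))   # impure node: 0.5
gini_index(c(0.9, 0.1))   # mostly one class: 0.18
gini_index(c(1, 0))       # pure node: 0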

Examples

Predicting MLB home run probability

library(tidyverse)
theme_set(theme_light())
batted_balls <- read_csv("https://raw.githubusercontent.com/36-SURE/36-SURE.github.io/main/data/batted_balls.csv")
glimpse(batted_balls)
Rows: 19,852
Columns: 12
$ is_hr         <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
$ launch_angle  <dbl> 14, 6, -31, 25, 5, 33, 42, 4, -10, 49, -22, -69, 45, 12,…
$ launch_speed  <dbl> 73.9, 104.6, 87.7, 92.9, 90.7, 102.0, 85.7, 111.9, 102.1…
$ bat_speed     <dbl> 69.25944, 69.95896, 75.73321, 73.28100, 67.64741, 73.959…
$ swing_length  <dbl> 7.66206, 7.63497, 6.75095, 7.58858, 6.61407, 7.15180, 7.…
$ plate_x       <dbl> -0.81, -0.11, -0.22, -0.30, -0.30, -0.28, 0.82, -0.37, 0…
$ plate_z       <dbl> 1.96, 3.03, 2.23, 2.92, 2.49, 2.23, 2.44, 2.39, 1.51, 2.…
$ inning        <dbl> 9, 9, 9, 9, 8, 8, 8, 8, 7, 7, 7, 7, 7, 6, 6, 6, 6, 6, 5,…
$ balls         <dbl> 0, 0, 0, 2, 2, 0, 2, 2, 2, 2, 0, 1, 0, 1, 1, 0, 1, 0, 2,…
$ strikes       <dbl> 0, 0, 0, 0, 0, 1, 1, 0, 2, 1, 1, 1, 1, 1, 0, 0, 2, 0, 2,…
$ is_stand_left <dbl> 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0,…
$ is_throw_left <dbl> 0, 0, 0, 0, 1, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1,…
set.seed(123)
# 50/50 split: sample half of the rows for training...
train <- batted_balls |> 
  slice_sample(prop = 0.5)
# ...and keep the remaining rows for testing
test <- batted_balls |> 
  anti_join(train)

Model training with caret

library(caret)
# 10-fold cross-validation over 20 candidate values of the complexity parameter cp
hr_tree <- train(as.factor(is_hr) ~ ., method = "rpart", tuneLength = 20,
                 trControl = trainControl(method = "cv", number = 10),
                 data = train)
# str(hr_tree)
ggplot(hr_tree)

Display the final tree model

library(rpart.plot)
hr_tree |> 
  pluck("finalModel") |> 
  rpart.plot()

Evaluate predictions

  • In-sample evaluation
train |> 
  mutate(pred = predict(hr_tree, newdata = train)) |> 
  summarize(correct = mean(is_hr == pred))
# A tibble: 1 × 1
  correct
    <dbl>
1   0.977
  • Out-of-sample evaluation
test |> 
  mutate(pred = predict(hr_tree, newdata = test)) |> 
  summarize(correct = mean(is_hr == pred))
# A tibble: 1 × 1
  correct
    <dbl>
1   0.972

Variable importance

library(vip)
hr_tree |> 
  vip()

Partial dependence plot

Partial dependence of home run outcome on launch speed and launch angle (individually)

library(pdp)
hr_tree |> 
  partial(pred.var = "launch_speed", 
          which.class = 2, 
          prob = TRUE) |> 
  autoplot()

hr_tree |> 
  partial(pred.var = "launch_angle", 
          which.class = 2, 
          prob = TRUE) |> 
  autoplot()

Partial dependence plot

Partial dependence of home run outcome on launch speed and launch angle (jointly)

hr_tree |>
  partial(pred.var = c("launch_speed", "launch_angle"), which.class = 2, prob = TRUE) |>
  autoplot(contour = TRUE)

Appendix: code to build dataset

savant <- read_csv("https://raw.githubusercontent.com/36-SURE/36-SURE.github.io/main/data/savant.csv")
batted_balls <- savant |> 
  filter(type == "X") |> 
  mutate(is_hr = as.numeric(events == "home_run"),
         is_stand_left = as.numeric(stand == "L"),
         is_throw_left = as.numeric(p_throws == "L")) |> 
  filter(!is.na(launch_angle), !is.na(launch_speed), !is.na(is_hr)) |> 
  select(is_hr, launch_angle, launch_speed, bat_speed, swing_length, 
          plate_x, plate_z, inning, balls, strikes, is_stand_left, is_throw_left) |> 
  drop_na()