Supervised learning: variable selection


SURE 2024

Department of Statistics & Data Science
Carnegie Mellon University

Background

Setting

Suppose we wish to learn a linear model. Our estimate (denoted by hats) is \[ \hat{Y} = \hat{\beta}_0 + \hat{\beta}_1 X_1 + \cdots + \hat{\beta}_p X_p \]

Why would we attempt to select a subset of the \(p\) variables?

  • To improve prediction accuracy

    • Eliminating uninformative predictors can reduce the variance of our predictions, lowering test set MSE at the expense of a slight increase in bias (see the small simulation after this list)
  • To improve model interpretability

    • Eliminating uninformative predictors is obviously a good thing when your goal is to tell the story of how your predictors are associated with your response.
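
A small simulation sketch of this bias-variance point (not from the original slides; the data-generating setup and variable names below are made up for illustration): the response depends on only two predictors, and fitting 20 additional noise predictors typically inflates the holdout RMSE.

set.seed(2024)
n <- 200; p_noise <- 20
sim_data <- function(n) {
  # two informative predictors (x1, x2) plus p_noise uninformative ones (z1, z2, ...)
  x <- matrix(rnorm(n * (2 + p_noise)), nrow = n)
  colnames(x) <- c("x1", "x2", paste0("z", 1:p_noise))
  data.frame(y = 1 + 2 * x[, "x1"] - 3 * x[, "x2"] + rnorm(n), x)
}
train <- sim_data(n)
test <- sim_data(n)
small_fit <- lm(y ~ x1 + x2, data = train)  # informative predictors only
full_fit <- lm(y ~ ., data = train)         # plus 20 noise predictors
sqrt(mean((test$y - predict(small_fit, test))^2))  # holdout RMSE, small model
sqrt(mean((test$y - predict(full_fit, test))^2))   # usually a bit larger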

Best subset selection

  • Start with the null model \(\mathcal{M}_0\) (intercept-only) that has no predictors
    • Predict the sample mean for each observation
  • For \(k = 1, 2, \dots, p\) (each possible number of predictors)
    • Fit all \(\displaystyle \binom{p}{k} = \frac{p!}{k!(p-k)!}\) models with exactly \(k\) predictors
    • Pick the best among these \(\displaystyle \binom{p}{k}\) models, call it \(\mathcal{M}_k\)
      • “Best” can be based on cross-validation error, highest adjusted \(R^2\), etc.
      • “It depends on your loss function”
  • Select a single best model from among \(\mathcal{M}_0, \dots, \mathcal{M}_p\) (a quick code sketch of this procedure follows)
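
A minimal sketch of best subset selection (not in the original slides), assuming the leaps package is installed; the built-in mtcars data is used purely for illustration, and adjusted \(R^2\) is just one of several reasonable criteria.

library(leaps)
best_subsets <- regsubsets(mpg ~ ., data = mtcars, nvmax = ncol(mtcars) - 1)
best_summary <- summary(best_subsets)
best_summary$adjr2                       # adjusted R^2 of the best model of each size k
best_k <- which.max(best_summary$adjr2)  # one way to pick a single "best" size
coef(best_subsets, best_k)               # coefficients of the selected model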

Best subset selection

This is not typically used in research!

  • only practical for a small number of variables

  • computationally infeasible for a large number of predictors

  • relies on an arbitrary definition of “best” and ignores prior knowledge about potential predictors

Use the shoe leather approach

Do not turn off your brain!

  • algorithms can be tempting but they are NOT substitutes for careful thinking!
  • you should NOT avoid the hard work of EDA in your modeling efforts

Variable selection is a difficult problem!

  • Like much of statistics research, there is not one unique, correct answer

Justify which predictors are used in modeling based on:

  • domain knowledge
  • context
  • extensive EDA
  • model assessment based on holdout predictions

Covariance and correlation

  • Covariance is a measure of the linear dependence between two variables

    • To be “uncorrelated” is not the same as to be “independent”

    • Independence means there is no dependence, linear or otherwise

    • If two variables are independent, then they are also uncorrelated. However, if two variables are uncorrelated, they can still be dependent (see the small illustration after this list).

    • Recommended reading

  • Correlation is a normalized form of covariance that ranges from -1 to 1

    • -1 means a perfect negative linear relationship: as one variable increases, the other decreases

    • 0 means no linear dependence

    • 1 means a perfect positive linear relationship: as one variable increases, so does the other
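
A small numerical illustration with simulated data (not part of the original slides): below, y is completely determined by x, so the two are strongly dependent, yet their correlation is near zero because the relationship is not linear.

set.seed(1)
x <- rnorm(10000)
y <- x^2     # y is a deterministic function of x (strong dependence)
cor(x, y)    # near 0: no linear dependence
cov(x, y)    # covariance is also near 0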

Case study

Data: Hollywood Movies (2012-2018)

library(tidyverse)
theme_set(theme_light())
movies <- read_csv("https://raw.githubusercontent.com/36-SURE/36-SURE.github.io/main/data/movies.csv")
glimpse(movies)
Rows: 1,295
Columns: 15
$ Movie            <chr> "2016: Obama's America", "21 Jump Street", "A Late Qu…
$ LeadStudio       <chr> "Rocky Mountain Pictures", "Sony Pictures Releasing",…
$ RottenTomatoes   <dbl> 26, 85, 76, 90, 35, 27, 91, 56, 11, 44, 93, 63, 87, 9…
$ AudienceScore    <dbl> 73, 82, 71, 82, 51, 72, 62, 47, 47, 63, 82, 51, 63, 9…
$ Genre            <chr> "Documentary", "Comedy", "Drama", "Drama", "Horror", …
$ TheatersOpenWeek <dbl> 1, 3121, 9, 7, 3108, 3039, 132, 245, 2539, 3192, 3, 1…
$ OpeningWeekend   <dbl> 0.03, 36.30, 0.08, 0.04, 16.31, 24.48, 1.14, 0.70, 11…
$ BOAvgOpenWeekend <dbl> 30000, 11631, 8889, 5714, 5248, 8055, 8636, 2857, 449…
$ Budget           <dbl> 3.0, 42.0, NA, NA, 68.0, 12.0, NA, 7.5, 35.0, 50.0, 1…
$ DomesticGross    <dbl> 33.35, 138.45, 1.56, 1.55, 37.52, 70.01, 1.99, 3.01, …
$ WorldGross       <dbl> 33.35, 202.81, 6.30, 7.60, 137.49, 82.50, 3.59, 8.54,…
$ ForeignGross     <dbl> 0.00, 64.36, 4.74, 6.05, 99.97, 12.49, 1.60, 5.53, 9.…
$ Profitability    <dbl> 1334.00, 482.88, NA, NA, 202.19, 687.50, NA, 113.87, …
$ OpenProfit       <dbl> 1.20, 86.43, NA, NA, 23.99, 204.00, NA, 9.33, 32.57, …
$ Year             <dbl> 2012, 2012, 2012, 2012, 2012, 2012, 2012, 2012, 2012,…
movies <- movies |> 
  janitor::clean_names() |> 
  select(audience_score, rotten_tomatoes, theaters_open_week, opening_weekend, budget, domestic_gross, foreign_gross) |> 
  drop_na()

Modeling audience rating

Interested in modeling the audience rating of a movie

movies |> 
  ggplot(aes(x = audience_score)) +
  geom_histogram(fill = "gray", color = "white")

Correlation matrix of audience score and candidate predictors

  • Interested in audience_score relationships with the critics rating, opening weekend statistics, budget, and domestic/foreign gross income

  • Plot correlation matrix with ggcorrplot

# can also use corrr package
# library(corrr)
# movies |> 
#   correlate(diagonal = 1) |> # get correlation matrix
#   stretch() |>  # similar to pivot_longer
#   ggplot(aes(x, y, fill = r)) +
#   geom_tile()
library(ggcorrplot)
movies_cor <- cor(movies)
ggcorrplot(movies_cor)

Customize the appearance of the correlation matrix

  • Avoid redundancy by only using one half of matrix with type

  • Add correlation value labels using lab (but round first!)

  • Can arrange variables based on clustering…

movies_cor |> 
  round(2) |> 
  ggcorrplot(hc.order = TRUE, type = "lower", lab = TRUE)

Clustering variables using the correlation matrix

Apply hierarchical clustering to variables instead of observations

  • Select the explanatory variables of interest from our data
movies_feat <- movies |> 
  select(-audience_score)
  • Compute correlation matrix of these variables
feat_cor <- cor(movies_feat)
  • Correlations measure similarity and can be negative BUT distances measure dissimilarity and CANNOT

  • Convert your correlations to all be \(\geq 0\): e.g., \(1 - |\rho|\) (which drops the sign) or \(1 - \rho\)

cor_dist_matrix <- 1 - abs(feat_cor)
  • Convert to distance matrix before using hclust
cor_dist_matrix <- as.dist(cor_dist_matrix)

Clustering variables using the correlation matrix

  • Cluster variables using hclust() as before

  • Use ggdendro to quickly visualize dendrogram

library(ggdendro)
movies_feat_hc <- hclust(cor_dist_matrix, "complete")
ggdendrogram(movies_feat_hc,
             rotate = TRUE,
             size = 2)

Clustering variables using the correlation matrix

library(dendextend)
cor_dist_matrix |>
  hclust() |>
  as.dendrogram() |>
  set("branches_k_col", k = 2) |>
  set("labels_cex", 1) |>
  ggplot(horiz = TRUE)

Back to the response variable…

Pairs plot using GGally

  • always look at your data

  • correlation values alone are not enough!

  • what if a variable displayed a nonlinear (e.g. quadratic) relationship?

library(GGally)
ggpairs(movies)

Back to the response variable…

Which variables matter for modeling audience rating?

Use 10-fold cross-validation to assess how well different sets of variables perform in predicting audience_score

Create a column of test fold assignments to our dataset with the sample() function:

set.seed(100)
k <- 10
movies <- movies |>
  mutate(test_fold = sample(rep(1:k, length.out = n())))

# table(movies$test_fold)  

Always remember to set.seed() prior to performing \(k\)-fold cross-validation!

Writing a function for \(k\)-fold cross-validation

get_cv_pred <- function(model_formula, data = movies) {
  # generate holdout predictions for every row
  get_test_pred <- function(fold) {
  
    # separate test and training data
  
    test_data <- data |> filter(test_fold == fold)
    train_data <- data |> filter(test_fold != fold)
    train_fit <- lm(as.formula(model_formula), data = train_data)
  
    # return test results
    test_res <- tibble(test_pred = predict(train_fit, newdata = test_data),
                       test_actual = test_data$audience_score,
                       test_fold = fold)
    return(test_res)
  }
  
  test_pred <- map(1:k, get_test_pred) |> 
    bind_rows()
  
  return(test_pred)
}

The function enables easy generation of holdout predictions for different models

all_pred <- get_cv_pred(
  "audience_score ~ rotten_tomatoes + theaters_open_week + opening_weekend + budget + domestic_gross + foreign_gross"
)
all_no_critics_pred <- get_cv_pred(
  "audience_score ~ theaters_open_week + opening_weekend + budget + domestic_gross + foreign_gross"
)
critics_only_pred <- get_cv_pred("audience_score ~ rotten_tomatoes")
opening_only_pred <- get_cv_pred("audience_score ~ budget + theaters_open_week + opening_weekend + rotten_tomatoes")
gross_only_pred <- get_cv_pred("audience_score ~ domestic_gross + foreign_gross + rotten_tomatoes")
int_only_pred <- get_cv_pred("audience_score ~ 1")

Can then summarize together for a single plot:

bind_rows(
  mutate(all_pred, mod = "All"),
  mutate(all_no_critics_pred, mod = "All but critics"),
  mutate(critics_only_pred, mod = "Critics only"),
  mutate(opening_only_pred, mod = "Opening only"),
  mutate(gross_only_pred, mod = "Gross income only"),
  mutate(int_only_pred, mod = "Intercept only")
) |>
  group_by(mod) |>
  summarize(
    rmse = sqrt(mean((test_actual - test_pred)^2))
  ) |>
  mutate(mod = fct_reorder(mod, rmse)) |>
  ggplot(aes(x = rmse, y = mod)) +
  geom_point()

Fit selected model on all data and view summary

all_fit <- lm(
  audience_score ~ rotten_tomatoes + theaters_open_week + opening_weekend + budget + domestic_gross + foreign_gross, 
  data = movies
)
# summary(all_fit)
library(broom)
all_fit |> tidy()
# A tibble: 7 × 5
  term               estimate std.error statistic   p.value
  <chr>                 <dbl>     <dbl>     <dbl>     <dbl>
1 (Intercept)        40.5      1.35        29.9   1.66e-142
2 rotten_tomatoes     0.395    0.0160      24.7   1.74e-106
3 theaters_open_week -0.00214  0.000393    -5.44  6.55e-  8
4 opening_weekend    -0.114    0.0362      -3.16  1.62e-  3
5 budget              0.0168   0.0116       1.45  1.46e-  1
6 domestic_gross      0.0690   0.0125       5.52  4.16e-  8
7 foreign_gross       0.00415  0.00497      0.835 4.04e-  1
  • But… do NOT show a coefficients table in a presentation (well… it depends)

  • A nicely formatted table of the summary output is more appropriate in a written report

  • Packages that can take a model object and produce a neat table summary: kableExtra, texreg, modelsummary, gtsummary, huxtable, sjPlot

Coefficient plot (with uncertainty quantification)

all_fit |> 
  tidy(conf.int = TRUE) |> 
  filter(term != "(Intercept)") |> 
  ggplot(aes(x = estimate, y = term))  +
  geom_point(size = 4) +
  geom_errorbar(aes(xmin = conf.low, xmax = conf.high), width = 0.2) +
  geom_vline(xintercept = 0, linetype = "dashed", color = "red")

Interpretation

all_fit |> tidy(conf.int = TRUE)
# A tibble: 7 × 7
  term               estimate std.error statistic   p.value conf.low conf.high
  <chr>                 <dbl>     <dbl>     <dbl>     <dbl>    <dbl>     <dbl>
1 (Intercept)        40.5      1.35        29.9   1.66e-142 37.8      43.1    
2 rotten_tomatoes     0.395    0.0160      24.7   1.74e-106  0.363     0.426  
3 theaters_open_week -0.00214  0.000393    -5.44  6.55e-  8 -0.00291  -0.00137
4 opening_weekend    -0.114    0.0362      -3.16  1.62e-  3 -0.186    -0.0434 
5 budget              0.0168   0.0116       1.45  1.46e-  1 -0.00587   0.0395 
6 domestic_gross      0.0690   0.0125       5.52  4.16e-  8  0.0445    0.0935 
7 foreign_gross       0.00415  0.00497      0.835 4.04e-  1 -0.00560   0.0139 

For rotten_tomatoes:

  • Coefficient interpretation: Among the Hollywood movies in this data, each additional point in Rotten Tomatoes rating is associated with a 0.395-point higher audience rating, on average (95% CI [0.363, 0.426]); a quick numerical check of this follows below

  • Test for the coefficient: With \(t=24.7\) and \(p\)-value \(<0.001\), we have sufficient evidence that audience rating and Rotten Tomatoes critic rating are related, after accounting for other variables in the model (i.e. there was a statistically significant association between audience and critic ratings).
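
As a quick numerical check of the coefficient interpretation above (a sketch, not part of the original slides), we can rescale the fitted coefficient and its confidence interval; for instance, a 10-point increase in Rotten Tomatoes rating corresponds to roughly a 4-point higher expected audience score, holding the other predictors fixed.

# estimated change in audience_score for a 10-point increase in rotten_tomatoes,
# holding the other predictors fixed
coef(all_fit)["rotten_tomatoes"] * 10
# the corresponding 95% confidence interval, scaled the same way
confint(all_fit)["rotten_tomatoes", ] * 10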

Making tables

broom::tidy() and kable()

See also: kableExtra

all_fit |> 
  tidy() |> 
  knitr::kable(digits = 3,
               col.names = c("Term", "Estimate", "SE", "t", "p-value"))
Term                 Estimate      SE        t   p-value
(Intercept)            40.456   1.353   29.890     0.000
rotten_tomatoes         0.395   0.016   24.691     0.000
theaters_open_week     -0.002   0.000   -5.442     0.000
opening_weekend        -0.114   0.036   -3.161     0.002
budget                  0.017   0.012    1.454     0.146
domestic_gross          0.069   0.012    5.525     0.000
foreign_gross           0.004   0.005    0.835     0.404

broom::tidy() and gt()

library(gt)
all_fit |> 
  tidy() |> 
  gt() |> 
  fmt_number(columns = where(is.numeric), decimals = 2) |> 
  cols_label(term = "Term",
             estimate = "Estimate",
             std.error = "SE",
             statistic = "t",
             p.value = md("*p*-value"))
Term                 Estimate      SE        t   p-value
(Intercept)             40.46    1.35    29.89      0.00
rotten_tomatoes          0.39    0.02    24.69      0.00
theaters_open_week       0.00    0.00    −5.44      0.00
opening_weekend         −0.11    0.04    −3.16      0.00
budget                   0.02    0.01     1.45      0.15
domestic_gross           0.07    0.01     5.52      0.00
foreign_gross            0.00    0.00     0.84      0.40

gtsummary

Use tbl_regression() function

library(gtsummary)
all_fit |> 
  tbl_regression() |> 
  bold_p() |> 
  bold_labels()
Characteristic        Beta   95% CI¹         p-value
rotten_tomatoes       0.39    0.36, 0.43      <0.001
theaters_open_week    0.00    0.00, 0.00      <0.001
opening_weekend      -0.11   -0.19, -0.04      0.002
budget                0.02   -0.01, 0.04       0.15
domestic_gross        0.07    0.04, 0.09      <0.001
foreign_gross         0.00   -0.01, 0.01       0.4
¹ CI = Confidence Interval

gt (and gtExtras)

A gt example

For more in-depth tutorials, see here and here

# https://bradcongelio.com/nfl-analytics-with-r-book/04-nfl-analytics-visualization.html
cpoe <- read_csv("http://nfl-book.bradcongelio.com/pure-cpoe")
cpoe_gt <- cpoe |> 
  select(passer, season, total_attempts, mean_cpoe) |> 
  gt(rowname_col = "passer") |> 
  fmt_number(columns = c(mean_cpoe), decimals = 2) |>
  data_color(columns = c(mean_cpoe),
             fn = scales::col_numeric(palette = c("#FEE0D2", "#67000D"), domain = NULL)) |> 
  cols_align(align = "center", columns = c("season", "total_attempts")) |> 
  tab_stubhead(label = "Quarterback") |> 
  cols_label(season = "Season",
             total_attempts = "Attempts",
             mean_cpoe = "Mean CPOE") |> 
  tab_header(title = md("**Average CPOE in Pure Passing Situations**"),
             subtitle = md("*For seasons between 2010 and 2022*")) |> 
  tab_source_note(source_note = md("Example adapted from the book<br>*An Introduction to NFL Analytics with R*")) |> 
  gtExtras::gt_theme_espn()

# gtsave(cpoe_gt, "cpoe_gt.png")

A gt example

Average CPOE in Pure Passing Situations
For seasons between 2010 and 2022
Quarterback   Season   Attempts   Mean CPOE
P.Rivers        2013        291       10.89
P.Mahomes       2018        261        9.29
A.Rodgers       2020        242        8.57
R.Wilson        2013        231        8.47
R.Wilson        2018        277        8.45
D.Brees         2011        311        8.29
M.Ryan          2018        315        8.23
M.Ryan          2012        303        8.11
R.Wilson        2015        279        7.52
J.Burrow        2021        343        7.12
Example adapted from the book
An Introduction to NFL Analytics with R