Decision Trees & Random Forests

Decision trees are a classic “nonparametric” technique for making predictions. Let’s quickly build one for the relationship between pubs and waiting times.

library(tidyverse)
library(knitr)
library(caret)
times <- read_csv('../data/waiting_times.csv')
times %>% head(5) %>% kable()
| waiting  | pub          | pub_number | long_wait | drinktype |
|----------|--------------|------------|-----------|-----------|
| 5.896570 | The Berkeley | 1          | TRUE      | Beer      |
| 8.822136 | The Berkeley | 1          | TRUE      | Beer      |
| 4.521891 | The Berkeley | 1          | TRUE      | Beer      |
| 6.656845 | The Berkeley | 1          | TRUE      | Cocktail  |
| 1.005883 | The Berkeley | 1          | TRUE      | Cocktail  |

Before, we fit the very simple regression model:

wait_model <- lm(waiting ~ 0 + pub, data=times)

Which I wrote to you as: \[y_i = \begin{cases} \alpha_b & \text{if at the Berkeley} \\ \alpha_g & \text{if at the Gallimaufry} \\ \alpha_m & \text{if at the Milk Thistle} \end{cases} + e_i\] And which you saw gives you the same predictions as:

times %>% group_by(pub) %>% summarize(mean(waiting))
## # A tibble: 3 x 2
##   pub              `mean(waiting)`
##   <chr>                      <dbl>
## 1 The Berkeley                4.70
## 2 The Gallimaufry             1.40
## 3 The Milk Thistle            6.27

And corresponds to predicting the means of the following sets of observations:

ggplot(times, aes(x=pub, y=waiting, color=pub)) + 
  geom_point()

Another way we can express this is as a tree:

library(rpart)
library(rpart.plot)
dtree <- rpart(waiting ~ pub, data=times, 
               control=list(minsplit=1, # override the default minimum split size
                            minbucket=1 # override the default minimum leaf size
                            )
               )
prp(dtree)

You read decision trees top to bottom. The top part, called the root, indicates the “first split” that is most useful in describing the data. Then, reading downwards, each “split” represents one point at which the data is divided in order to make a prediction. You can see that if the pub is the Gallimaufry, we immediately split off and predict a value of 1.4 for the waiting time. This is because the Galli has quite a different distribution from the other two pubs: its variance is way lower, and its waiting time is also quite low. Then, the tree splits on whether the pub is the Berkeley, predicting 4.7 if so and 6.3 if not. These, again, are the means of those groups of observations.
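We can double-check this by asking the fitted tree for a prediction at each pub; a minimal sketch, assuming the dtree and times objects built above:

predict(dtree, newdata = distinct(times, pub)) # returns the three group means from earlier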

So, decision trees are a way to “slice and dice” the data in order to make predictions about progressively smaller subsets of it. They are intrinsically nonlinear: they generally assume no functional form for the outcome variable, and their predictions can change abruptly based on very simple rules.

Here’s another, more complicated example: house prices over time.

library(sf)
## Linking to GEOS 3.9.1, GDAL 3.2.1, PROJ 7.2.1
weca = read_sf("../data/weca.gpkg") %>% 
  st_drop_geometry() %>% 
  pivot_longer(price_dec_1995:price_dec_2018, 
               names_to=c(NA, 'quarter', 'year'), 
               values_to='price', names_sep='_')
dtree_weca <- rpart(price ~ year + quarter + la_name, 
                    data=weca %>% mutate(price = price/1000) # make the plot easier to read
                    )
prp(dtree_weca, faclen=20)

In this tree, the model branches immediately on whether or not we’re predicting after 2003. If we’re predicting before 2003, we then branch on whether we’re before 2001: if so, we predict a median house price of £71,000, and if not, we predict £126,000. On the other side, the tree splits on whether the house was sold between 2004 and 2014: if yes, we enter the middle branch, and if not (it was sold after 2014), we enter the right branch. Both of those then split on whether the house is in Bath or in Bristol/South Gloucestershire. If it’s in Bristol or South Glos and sold between 2004 and 2014, the prediction is £185,000; if it’s in Bath during that period, the tree predicts £241,000. Alternatively, if it’s in Bath after 2014, we predict £329,000. This strategy works both for regression problems (where we predict the mean of the group of observations falling in each leaf) and for classification problems (where we predict a single class for each leaf).

The actual algorithms used to build decision trees vary from implementation to implementation. Often, they look for “binary recursive splits”: a single split on a single feature upon which to divide the dataset in two. A split is “good” if the variance of the two halves is substantially smaller than the variance of the data overall. Once a split is made, another split is identified for each of the two halves of the data, ad nauseam. (As an aside, the information about the sales quarter does not help the model much here, because quarter is a “seasonal” variate: it helps you make predictions around the yearly mean, but it is not useful by itself for predicting prices. All “March” sales won’t have similar absolute prices; they’ll only have similar prices relative to their year’s mean. Decision trees generally don’t like features like this.)
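To make the “good split” idea concrete, here is a toy sketch of the search (not rpart’s actual implementation; it assumes a single numeric feature and reuses the times data from earlier): scan every candidate cut point and keep the one that leaves the two halves with the smallest total within-group sum of squares.

best_split <- function(x, y){
  candidates <- head(sort(unique(x)), -1) # drop the largest value: cutting there leaves one side empty
  sse <- function(v) sum((v - mean(v))^2) # within-group sum of squared errors
  scores <- sapply(candidates, function(cut) sse(y[x <= cut]) + sse(y[x > cut]))
  list(cut = candidates[which.min(scores)],
       sse_before = sse(y),
       sse_after = min(scores))
}
best_split(times$pub_number, times$waiting) # e.g. the most useful single cut on pub_number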

Overfitting and random forests

As you can tell, this procedure can make the model a bit too sensitive to the input data. The precision with which data can be sliced and diced may give predictions with very low prediction bias, sure (remember: by “bias” we mean how far the predictions sit from reality, not that the predictions are systematically over or under the correct value). But the prediction variance can be immense. For example, in the weca tree above, just changing the year from 2000 to 2001 will nearly double the predicted house price!
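You can see this jump directly by asking the fitted tree for predictions that differ only in the year; a quick sketch, assuming the dtree_weca and weca objects from above (predictions are in £1000s):

jump <- weca %>%
  filter(year %in% c("2000", "2001")) %>% # the same places and quarters, one year apart
  distinct(la_name, quarter, year)
jump %>% mutate(predicted_price = predict(dtree_weca, newdata = jump))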

So, random forests provide a way to collect together decision trees to reduce this prediction variance, while keeping the bias acceptably low. Basically, a random forest works by fitting many trees (hence “forest,” a good name for a collection of trees 😏) and giving each tree a slightly different view of the data. It does this by randomizing the features that the decision trees “see.” For example, let’s think about a random forest trained to predict song genre using danceability, tempo, acousticness, instrumentalness, energy, and valence.

songs = read_csv("../data/midterm-songs.csv")
## Rows: 32833 Columns: 23
## -- Column specification --------------------------------------------------------
## Delimiter: ","
## chr (10): track_id, track_name, track_artist, track_album_id, track_album_na...
## dbl (13): track_popularity, danceability, energy, key, loudness, mode, speec...
## 
## i Use `spec()` to retrieve the full column specification for this data.
## i Specify the column types or set `show_col_types = FALSE` to quiet this message.
dance_rf = train(playlist_genre ~ danceability + tempo + acousticness + instrumentalness + energy + valence, 
              data=songs, 
              method='ranger', # a faster random forest than method='rf'
              importance='impurity', # save measures of how important each feature is
              trControl= trainControl(method='repeatedcv', repeats=1, number=5))
## Growing trees.. Progress: 87%. Estimated remaining time: 4 seconds.
## Growing trees.. Progress: 60%. Estimated remaining time: 20 seconds.
## Growing trees.. Progress: 97%. Estimated remaining time: 1 seconds.
## Growing trees.. Progress: 66%. Estimated remaining time: 15 seconds.
## Growing trees.. Progress: 89%. Estimated remaining time: 3 seconds.
## Growing trees.. Progress: 64%. Estimated remaining time: 17 seconds.
## Growing trees.. Progress: 91%. Estimated remaining time: 3 seconds.
## Growing trees.. Progress: 53%. Estimated remaining time: 27 seconds.
## Growing trees.. Progress: 93%. Estimated remaining time: 4 seconds.

In this forest, each decision tree “sees” a different set of variables. Some trees may look at danceability, tempo, and energy (like we’ve seen before), but others may look at valence, instrumentalness, and acousticness. By ensuring that the trees see many different perspectives of the data, we can ensure that they are relatively independent: they will not all use the same splits, nor even see the same basic picture of the dataset. This idea, that a collection of independent but “weak” prediction algorithms (like decision trees) can perform learning tasks well, is a central idea of ensemble learning, of which random forests are a central component.
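Under the hood, caret controls this randomization through the mtry tuning parameter: the number of randomly-chosen features each split is allowed to consider. You can inspect which values it tried and which won; a small sketch, assuming the dance_rf object trained above:

dance_rf$bestTune # the winning combination of tuning parameters
dance_rf$results %>% arrange(desc(Accuracy)) %>% head() %>% kable() # accuracy for each combination tried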

This (hopefully) makes it so that the decision trees don’t all make the same mistakes. In addition, another strategy used by default when training random forests is called “bootstrap aggregation,” or “bagging” for short. This is discussed in ISL 8.2.1 and, at its core, does to the rows of the training data what the feature randomization above does to the columns.
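To give a flavour of how bagging works, here is a hand-rolled sketch (not what ranger does internally), using the small pub dataset from earlier: fit many trees, each on a bootstrap resample of the rows, and average their predictions.

set.seed(1234)
bagged_predictions <- map(1:100, function(i){
  boot <- slice_sample(times, prop = 1, replace = TRUE) # resample the rows with replacement
  tree <- rpart(waiting ~ pub, data = boot, control = list(minsplit = 1, minbucket = 1))
  predict(tree, newdata = times) # each tree predicts for the full dataset
}) %>% reduce(`+`) / 100 # average the 100 trees' predictions
head(bagged_predictions)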

Importance

One very nice trait of random forests is that they give you a measure of how important each variable is for the prediction. This is kind of like an effect size from a regression framework. Usually, the “best” feature is rated 100. For any model fit in the caret framework, you can get feature importance using the varImp() function (short for “variable importance”):

varImp(dance_rf)
## ranger variable importance
## 
##                  Overall
## tempo             100.00
## danceability       75.82
## energy             51.12
## valence            47.44
## acousticness       43.33
## instrumentalness    0.00

Thus, you can see that the tempo variable is the most useful in predicting genre. The next most important feature is danceability. Energy, valence, and acousticness all hang around the same value. And, finally, we see that instrumentalness is not very useful at all in predicting the genre of a song’s playlist. A classic visualization of this data uses a lollipop plot:

varImp(dance_rf) %>% plot()
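If you want more control over this figure, the numbers behind it live in the importance element of the varImp() result, so you can rebuild the lollipop plot yourself in ggplot; a sketch, assuming the dance_rf fit above:

varImp(dance_rf)$importance %>%
  rownames_to_column('feature') %>% # the feature names are stored as row names
  ggplot(aes(x = Overall, y = reorder(feature, Overall))) +
  geom_segment(aes(xend = 0, yend = reorder(feature, Overall))) + # the lollipop "sticks"
  geom_point() + # the lollipop "heads"
  labs(x = 'Importance (best feature = 100)', y = NULL)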

Scale-free inference

It’s also useful to note that, unlike KNN methods, decision trees are generally not sensitive to re-scaling of the data. This is because the trees find cut points that reduce the variability of the data: whether a cut point is shifted left or right (as in centering) or stretched relative to zero (as in re-scaling) does not matter for finding it. You can show this for yourself using the preProcess= argument to train().
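For example, something like the following (a sketch, reusing the songs data with the same specification as dance_rf above) should report essentially the same cross-validated accuracy as the un-scaled fit:

dance_rf_scaled = train(playlist_genre ~ danceability + tempo + acousticness + instrumentalness + energy + valence, 
              data=songs, 
              method='ranger', 
              preProcess=c('center', 'scale'), # center and re-scale every feature before training
              trControl=trainControl(method='repeatedcv', repeats=1, number=5))
# compare dance_rf_scaled$results with dance_rf$results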

Your Turn

  1. Using the earnings.csv data (which describes yearly earnings in thousands of dollars as a function of demographic variables), fit a single decision tree (using rpart) to predict earnings using all of the variables in the data except age_band. (Note: you may need to drop NA values…)

  2. Train a random forest with the same model specification as in Part 1, using whatever cross-validation procedure you like.