Random forests are a widely used ensemble learning method for classification and regression tasks. However, they are typically used as a black box prediction method that offers little insight into its inner workings.
In this vignette, we illustrate how the stablelearner
package can be used to gain insight into this black box by visualizing
and summarizing the variable and cutpoint selection of the trees within
a random forest.
Recall that, in simple terms, a random forest is a tree ensemble, and
the forest is grown by resampling the training data and refitting trees
on the resampled data. Contrary to bagging, random forests have the
restriction that the number of feature variables randomly sampled as
candidates at each node of a tree (in implementations, this is typically
called the mtry argument) is smaller than the total number
of feature variables available. Random forests were introduced by Breiman (2001).
The stablelearner package was originally designed to
provide functionality for assessing the stability of tree learners and
other supervised statistical learners, both visually (Philipp, Zeileis, and Strobl 2016) and by means
of computing similarity measures (Philipp et al.
2018), on the basis of repeatedly resampling the training data
and refitting the learner.
However, in this vignette we are interested in visualizing the
variable and cutpoint selection of the trees within a random
forest. Therefore, contrary to the original design of the
stablelearner package, where the aim was to assess the
stability of a single original tree, we are not interested in
highlighting any single tree, but want all trees to be treated as equal.
As a result, some functions used later require setting the argument
original = FALSE. Moreover, this vignette does not cover
similarity measures for random forests, which are still a work in
progress.
In all sections of this vignette, we are going to work with credit
scoring data where applicants are rated as "good" or
"bad", which will be introduced in Section 1.
In Section 2 we will cover the stablelearner package and
how to fit a random forest using the stabletree() function
(Section 2.1). In Section 2.2 we show how to summarize and visualize the
variable and cutpoint selection of the trees in a random forest.
In the final Section 3, we will demonstrate how the same summary and
visualizations can be produced when working with random forests that
were already fitted via the cforest() function of the
partykit package, the cforest() function of
the party package, the randomForest() function
of the randomForest package, or the ranger()
function of the ranger package.
Note that in the following, functions will be specified with the
double colon notation, indicating the package they belong to, e.g.,
partykit::cforest() denoting the cforest()
function of the partykit package, while
party::cforest() denotes the cforest()
function of the party package.
In all sections we are going to work with the german
dataset, which is included in the rchallenge package (note
that this is a transformed version of the German Credit data set with
factors instead of dummy variables, and corrected as proposed by Grömping (2019)).
data("german", package = "rchallenge")
The dataset consists of 1000 observations on 21 variables. For a full
description of all variables, see ?rchallenge::german. The
random forest we are going to fit in this vignette predicts whether a
person was classified as "good" or "bad" with
respect to the credit_risk variable using all other
available variables as feature variables. To allow for a lower runtime
we only use a subsample of the data (500 persons):
set.seed(2409)
dat <- droplevels(german[sample(seq_len(NROW(german)), size = 500), ])
str(dat)
## 'data.frame':    500 obs. of  21 variables:
##  $ status                 : Factor w/ 4 levels "no checking account",..: 2 4 4..
##  $ duration               : int  30 21 18 24 12 15 7 9 27 12 ...
##  $ credit_history         : Factor w/ 5 levels "delay in paying off in the pa"..
##  $ purpose                : Factor w/ 10 levels "others","car (new)",..: 1 1 4..
##  $ amount                 : int  4249 5003 1505 5743 3331 3594 846 2753 2520 1..
##  $ savings                : Factor w/ 5 levels "unknown/no savings account",....
##  $ employment_duration    : Factor w/ 5 levels "unemployed","< 1 yr",..: 1 3 3..
##  $ installment_rate       : Ord.factor w/ 4 levels ">= 35"<"25 <= ... < 35"<....
##  $ personal_status_sex    : Factor w/ 4 levels "male : divorced/separated",..:..
##  $ other_debtors          : Factor w/ 3 levels "none","co-applicant",..: 1 1 1..
##  $ present_residence      : Ord.factor w/ 4 levels "< 1 yr"<"1 <= ... < 4 yrs"..
##  $ property               : Factor w/ 4 levels "unknown / no property",..: 3 2..
##  $ age                    : int  28 29 32 24 42 46 36 35 23 22 ...
##  $ other_installment_plans: Factor w/ 3 levels "bank","stores",..: 3 1 3 3 2 3..
##  $ housing                : Factor w/ 3 levels "for free","rent",..: 2 2 3 3 2..
##  $ number_credits         : Ord.factor w/ 4 levels "1"<"2-3"<"4-5"<..: 2 2 1 2..
##  $ job                    : Factor w/ 4 levels "unemployed/unskilled - non-re"..
##  $ people_liable          : Factor w/ 2 levels "3 or more","0 to 2": 2 2 2 2 2..
##  $ telephone              : Factor w/ 2 levels "no","yes (under customer name"..
##  $ foreign_worker         : Factor w/ 2 levels "yes","no": 2 2 2 2 2 2 2 2 2 2..
##  $ credit_risk            : Factor w/ 2 levels "bad","good": 1 1 2 2 2 2 2 2 1..
In our first approach, we want to grow a random forest directly in
stablelearner. This is possible using conditional inference
trees (Hothorn, Hornik, and Zeileis 2006)
as base learners relying on the function ctree() of the
partykit package. This procedure results in a forest equal
to a random forest fitted via partykit::cforest().
To achieve this, we have to make sure that our initial
ctree, which will be repeatedly refitted on the resampled
data, is specified correctly with respect to the resampling method and
the number of feature variables randomly sampled as candidates at each
node of a tree (argument mtry). By default,
partykit::cforest() uses subsampling with a fraction of
0.632 and sets mtry = ceiling(sqrt(nvar)). In
our example, this would be 5, as this dataset includes 20
feature variables. Note that setting mtry equal to the
number of all feature variables available would result in bagging. In a
real analysis mtry should be tuned by means of, e.g.,
cross-validation.
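As a small illustration of the default described above, the following base R sketch computes mtry for 20 feature variables and draws one set of candidate variables for a single node. The feature names x1 to x20 are hypothetical placeholders, and the snippet only illustrates the idea; it is not the code used inside partykit.

```r
nvar <- 20                     # number of feature variables in the example
mtry <- ceiling(sqrt(nvar))    # default described above
mtry
# [1] 5

set.seed(1)
features <- paste0("x", seq_len(nvar))        # hypothetical feature names
candidates <- sample(features, size = mtry)   # candidates considered at one node

# With mtry equal to nvar, every variable is a candidate at every node (bagging)
bagging_candidates <- sample(features, size = nvar)
```

In a forest, a fresh set of candidates would be drawn at every node of every tree.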
We now fit our initial tree, mimicking the defaults of
partykit::cforest() (see ?partykit::cforest
and ?partykit::ctree_control for a description of the
arguments teststat, testtype,
mincriterion and saveinfo). The formula
credit_risk ~ . simply indicates that we use all remaining
variables of dat as feature variables to predict the
credit_risk of a person.
set.seed(2906)
ct_partykit <- partykit::ctree(credit_risk ~ ., data = dat,
  control = partykit::ctree_control(mtry = 5, teststat = "quadratic",
    testtype = "Univariate", mincriterion = 0, saveinfo = FALSE))
We can now proceed to grow our forest based on this initial tree,
using stablelearner::stabletree(). We use subsampling with
a fraction of v = 0.632 and grow B = 100
trees. We set savetrees = TRUE, to be able to extract the
individual trees later:
set.seed(2907)
cf_stablelearner <- stablelearner::stabletree(ct_partykit,
  sampler = stablelearner::subsampling, savetrees = TRUE, B = 100, v = 0.632)
Internally, stablelearner::stabletree() does the
following: For each of the 100 trees to be generated, the dataset is
resampled according to the resampling method specified (in our case
subsampling with a fraction of v = 0.632) and the function
call of our initial tree (which we labeled ct_partykit) is
updated with respect to this resampled data and reevaluated, resulting
in a new tree. Together, the 100 trees form the forest.
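The resample-and-refit idea can be sketched in base R as follows. This is a conceptual illustration only, not the internal code of stablelearner: a linear model on the built-in mtcars data serves as a stand-in learner (an assumption made so that the snippet runs without additional packages), and the object names are hypothetical.

```r
# Stand-in learner: a linear model on a built-in dataset (assumption for illustration)
fit <- lm(mpg ~ wt + hp, data = mtcars)

B <- 100      # number of refits
v <- 0.632    # subsampling fraction
n <- nrow(mtcars)

set.seed(1)
refits <- lapply(seq_len(B), function(b) {
  idx <- sample(seq_len(n), size = floor(v * n))   # draw without replacement
  sub <- mtcars[idx, ]
  update(fit, data = sub)   # re-evaluate the original call on the subsample
})

length(refits)       # one refitted learner per subsample
nobs(refits[[1]])    # each refit uses floor(0.632 * 32) = 20 observations
```

The key step is update(), which reuses the stored call of the initial fit and only swaps the data argument, so all other settings of the learner stay unchanged across refits.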
The following summary prints the variable selection frequency
(freq) as well as the average number of splits in each
variable (mean) over all 100 trees. As we do not
want to focus on our initial tree (remember that we just grew a forest,
where all trees are of equal interest), we set
original = FALSE, as already mentioned in the
introduction:
summary(cf_stablelearner, original = FALSE)
## 
## Call:
## partykit::ctree(formula = credit_risk ~ ., data = dat, control = partykit::ctree_control(mtry = 5, 
##     teststat = "quadratic", testtype = "Univariate", mincriterion = 0, 
##     saveinfo = FALSE))
## 
## Sampler:
## B = 100 
## Method = Subsampling with 63.2% data
## 
## Variable selection overview:
## 
##                         freq mean
## status                  0.96 2.38
## duration                0.88 1.72
## credit_history          0.88 1.69
## employment_duration     0.87 1.62
## amount                  0.83 1.41
## age                     0.80 1.34
## property                0.75 1.19
## purpose                 0.72 1.23
## present_residence       0.71 1.11
## savings                 0.69 1.01
## housing                 0.69 1.00
## installment_rate        0.68 1.12
## number_credits          0.64 0.82
## job                     0.61 0.80
## personal_status_sex     0.60 0.87
## telephone               0.55 0.75
## other_installment_plans 0.53 0.65
## other_debtors           0.40 0.41
## people_liable           0.39 0.47
## foreign_worker          0.19 0.19
Looking at the status variable (status of the existing
checking account of a person) for example, this variable was selected in
almost all 100 trees (freq = 0.96). Moreover, this variable
was often selected more than once for a split because the average number
of splits is around 2.38.
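To make the two columns of the summary concrete, here is a minimal base R sketch that computes both quantities from a toy forest of three trees. The split variables listed for each tree are made up for illustration, and the code is not taken from stablelearner.

```r
# Hypothetical split variables of three small trees (one vector per tree)
splits <- list(
  c("status", "duration", "status"),
  c("status", "age"),
  c("duration")
)
vars <- c("status", "duration", "age")

# Rows = trees, columns = variables, entries = number of splits
counts <- t(vapply(splits, function(s) {
  vapply(vars, function(v) sum(s == v), numeric(1))
}, numeric(length(vars))))

freq <- colMeans(counts > 0)      # share of trees selecting the variable at least once
mean_splits <- colMeans(counts)   # average number of splits per tree
round(cbind(freq, mean = mean_splits), 2)
```

In this toy example, status is used in two of the three trees (freq of about 0.67) and split on three times in total, giving an average of one split per tree.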
Plotting the variable selection frequency is achieved via the
following command (note that cex.names allows us to specify
the relative font size of the x-axis labels):
barplot(cf_stablelearner, original = FALSE, cex.names = 0.6)
To get a more detailed view, we can also inspect the variable selection patterns displayed for each tree. The following plot shows us for each variable whether it was selected (colored in darkgrey) in each of the 100 trees within the forest, where the variables are ordered on the x-axis so that top ranking ones come first:
image(cf_stablelearner, original = FALSE, cex.names = 0.6)
This may allow for interesting observations, e.g., we observe that in
those trees where duration was not selected, both
credit_history and employment_duration were
almost always selected as splitting variables.
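An observation of this kind can also be checked numerically once the selection indicators are available as a logical matrix. The sketch below uses a small made-up matrix (four trees, three variables); how to obtain such a matrix from a fitted object is not shown here and would depend on the package internals.

```r
# Hypothetical selection indicators: rows = trees, columns = variables
sel <- rbind(
  c(duration = TRUE,  credit_history = FALSE, employment_duration = TRUE),
  c(duration = FALSE, credit_history = TRUE,  employment_duration = TRUE),
  c(duration = FALSE, credit_history = TRUE,  employment_duration = TRUE),
  c(duration = TRUE,  credit_history = TRUE,  employment_duration = FALSE)
)

# Restrict to trees without a split in duration
no_duration <- sel[!sel[, "duration"], , drop = FALSE]

# Share of those trees that select both other variables
both_share <- mean(no_duration[, "credit_history"] &
                   no_duration[, "employment_duration"])
both_share
```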
Finally, the plot() function allows us to inspect the
cutpoints and resulting partitions for each variable over all 100 trees.
Here we focus on the variables status,
present_residence, and duration:
plot(cf_stablelearner, original = FALSE,
  select = c("status", "present_residence", "duration"))
Looking at the first variable status (an unordered
categorical variable), we are given a so-called image plot, visualizing
the partition of this variable for each of the 100 trees. We observe
that the most frequent partition is to distinguish between persons with
the values no checking account and
... < 0 DM vs. persons with the values
0 <= ... < 200 DM and
... >= 200 DM / salary for at least 1 year, but also
other partitions occur. The light gray color is used when a category was
no longer represented by the observations left for partitioning in the
particular node.
For ordered categorical variables such as
present_residence, a barplot is given showing the frequency
of all possible cutpoints sorted on the x-axis in their natural order.
Here, the cutpoint between 1 <= ... < 4 years and
4 <= ... < 7 years is selected most frequently.
Lastly, for numerical variables a histogram is given, showing the
distribution of cutpoints. We observe that most cutpoints for the
variable duration occurred between 0 and 30, but there
appears to be high variance, indicating a smooth effect of this variable
rather than a pattern with a distinct change-point.
For a more detailed explanation of the different kinds of plots, see Section 3 of Philipp, Zeileis, and Strobl (2016).
To conclude, the summary table and different plots helped us to gain some insight into the variable and cutpoint selection of the 100 trees within our forest. Finally, in case we want to extract individual trees, e.g., the first tree, we can do this via:
cf_stablelearner$tree[[1]]
Note that, from a technical and performance perspective, there is
little reason to grow a forest directly in
stablelearner, as the cforest()
implementations in partykit and especially in
party are more efficient. Nevertheless, growing a forest directly in
stablelearner allows us to be more flexible with respect
to, e.g., the resampling method, as we could specify any method we want,
such as bootstrap, subsampling, samplesplitting, jackknife, splithalf, or
even custom samplers. For a discussion of why subsampling should be
preferred over bootstrap sampling, see Strobl et
al. (2007).
In this final section we cover how to work with random forests that
have already been fitted via the cforest() function of the
partykit package, the cforest() function of
the party package, the randomForest() function
of the randomForest package, or the ranger()
function of the ranger package.
Essentially, we just fit the random forest and then use
stablelearner::as.stabletree() to coerce the forest to a
stabletree object, which allows us to get the same summary
and plots as presented above.
Fitting a cforest with 100 trees using
partykit is straightforward:
set.seed(2908)
cf_partykit <- partykit::cforest(credit_risk ~ ., data = dat,
  ntree = 100, mtry = 5)
stablelearner::as.stabletree() then allows us to coerce
this cforest and we can produce summaries and plots as
above (note that for plotting, we can now omit
original = FALSE, because the coerced forest has no
particular initial tree):
cf_partykit_st <- stablelearner::as.stabletree(cf_partykit)
summary(cf_partykit_st, original = FALSE)
barplot(cf_partykit_st)
image(cf_partykit_st)
plot(cf_partykit_st, select = c("status", "present_residence", "duration"))
We do not observe substantial differences compared to growing the
forest directly in stablelearner (this is the
expected behavior, because we tried to mimic the algorithm of
partykit::cforest() in the previous section), so we
will not display the results again.
The variable importance reported by
partykit::varimp() shows substantial overlap with the selection
frequencies, e.g., the variable status has the highest importance:
partykit::varimp(cf_partykit)
##                  status                duration          credit_history 
##             0.247738374             0.034059673             0.077873676 
##                 purpose                  amount                 savings 
##            -0.004825073             0.031340832             0.040444307 
##     employment_duration        installment_rate     personal_status_sex 
##             0.016477162            -0.022483572            -0.004295483 
##           other_debtors       present_residence                property 
##             0.020315372             0.001464200             0.008682788 
##                     age other_installment_plans                 housing 
##             0.058371493             0.046699228             0.031434510 
##          number_credits                     job           people_liable 
##             0.011540270            -0.022146762             0.014946051 
##               telephone          foreign_worker 
##            -0.011154765             0.024520948
The coercing procedure described above is analogous for forests
fitted via party::cforest():
set.seed(2909)
cf_party <- party::cforest(credit_risk ~ ., data = dat,
  control = party::cforest_unbiased(ntree = 100, mtry = 5))
cf_party_st <- stablelearner::as.stabletree(cf_party)
summary(cf_party_st, original = FALSE)
barplot(cf_party_st)
image(cf_party_st)
plot(cf_party_st, select = c("status", "present_residence", "duration"))
Again, we do not observe substantial differences compared to
partykit::cforest(). This is the expected behavior, as
partykit::cforest() is a (pure R) reimplementation of
party::cforest() (implemented in C).
For forests fitted via randomForest::randomForest(), we
can do the same as above. However, as these forests do not use
conditional inference trees as base learners, we can expect some
differences with respect to the results:
set.seed(2910)
rf <- randomForest::randomForest(credit_risk ~ ., data = dat,
  ntree = 100, mtry = 5)
rf_st <- stablelearner::as.stabletree(rf)
summary(rf_st, original = FALSE)
## 
## Call:
## randomForest(formula = credit_risk ~ ., data = dat, ntree = 100, 
##     mtry = 5)
## 
## Sampler:
## B = 100 
## Method = randomForest::randomForest
## 
## Variable selection overview:
## 
##                         freq  mean
## status                  1.00  4.79
## duration                1.00  7.82
## credit_history          1.00  4.23
## purpose                 1.00  6.27
## amount                  1.00 10.81
## savings                 1.00  4.32
## employment_duration     1.00  5.64
## installment_rate        1.00  4.66
## property                1.00  4.50
## age                     1.00  9.08
## personal_status_sex     0.99  3.47
## present_residence       0.99  4.40
## job                     0.97  3.17
## other_installment_plans 0.94  2.26
## number_credits          0.94  2.26
## housing                 0.93  2.41
## telephone               0.85  1.90
## other_debtors           0.83  1.57
## people_liable           0.73  1.09
## foreign_worker          0.43  0.48
barplot(rf_st, cex.names = 0.6)
image(rf_st, cex.names = 0.6)
plot(rf_st, select = c("status", "present_residence", "duration"))
We observe that for numerical variables the average number of splits
is much higher now, e.g., amount is selected on average
about 11 times. This preference for variables offering many
cutpoints is a known drawback of Breiman and Cutler’s original Random
Forest algorithm, which random forests based on conditional inference
trees do not share. Note, however, that Breiman and Cutler did not
intend the variable selection frequencies to be used as a measure of the
relevance of the predictor variables, but have suggested a
permutation-based variable importance measure for this purpose. For more
details, see Hothorn, Hornik, and Zeileis
(2006), Strobl et al. (2007), and
Strobl, Malley, and Tutz (2009).
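The preference for variables with many cutpoints can be illustrated with a small base R simulation. The sketch below is a simplified stand-in, not the code of randomForest: it assumes an exhaustive split search that picks the variable with the largest Gini impurity reduction, and it compares a binary and a continuous predictor that are both unrelated to the response. The helper functions gini() and best_gain() are hypothetical names introduced for this example.

```r
# Gini impurity of a binary 0/1 response
gini <- function(y) {
  p <- mean(y)
  2 * p * (1 - p)
}

# Best impurity reduction over all cutpoints of one numeric predictor
best_gain <- function(x, y) {
  cuts <- sort(unique(x))
  cuts <- cuts[-length(cuts)]        # splitting at the maximum leaves one side empty
  if (length(cuts) == 0) return(0)   # constant predictor: no split possible
  gains <- vapply(cuts, function(cp) {
    left <- y[x <= cp]
    right <- y[x > cp]
    gini(y) - (length(left) * gini(left) + length(right) * gini(right)) / length(y)
  }, numeric(1))
  max(gains)
}

set.seed(1)
n <- 100
x2_wins <- replicate(200, {
  y  <- rbinom(n, 1, 0.5)   # response unrelated to both predictors
  x1 <- rbinom(n, 1, 0.5)   # binary noise: one possible cutpoint
  x2 <- runif(n)            # continuous noise: up to n - 1 cutpoints
  best_gain(x2, y) > best_gain(x1, y)
})
share_x2 <- mean(x2_wins)   # share of replications won by the continuous variable
share_x2
```

Although neither predictor carries any information, the continuous one is expected to win the split search in most replications, simply because the maximum is taken over many more candidate cutpoints.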
Finally, for forests fitted via ranger::ranger() (which
also implements Breiman and Cutler’s original algorithm), the coercing
procedure is again the same:
set.seed(2911)
rf_ranger <- ranger::ranger(credit_risk ~ ., data = dat,
  num.trees = 100, mtry = 5)
rf_ranger_st <- stablelearner::as.stabletree(rf_ranger)
summary(rf_ranger_st, original = FALSE)
barplot(rf_ranger_st)
image(rf_ranger_st)
plot(rf_ranger_st, select = c("status", "present_residence", "duration"))
As a final comment on computational performance, note that just as
stablelearner::stabletree(),
stablelearner::as.stabletree() allows for parallel
computation (see the arguments applyfun and
cores). This may be helpful when dealing with the coercion
of large random forests.