Repository on GitHub

To cite: Kilpatrick, A. (2024-06-06). Decision Tree-Based Classification Algorithms. Retrieved [the day you viewed the site] from [this web address].

Introduction

In this tutorial, we will learn how to apply decision tree-based algorithms to classification problems. We begin by examining the structure and mechanics of decision trees. We then move into ensemble learning with the random forest algorithm and introduce the hyperparameter tuning process. Finally, we will learn how to construct fully tuned extreme gradient boosting (XGBoost) models, which build on decision trees to deliver strong performance and efficiency in classification tasks.

# Modify the base directory to where you have stored the files.
base_dir <- "/Users/nucb/Documents/GitHub/Decision-Tree-Based-Classification-Algorithms"
setwd(base_dir)

Required Packages

Before we begin, let’s ensure we have the necessary R packages installed and updated. The following code will install any missing packages required for this tutorial. At the time of writing, mlr has been superseded by mlr3; however, mlr is still a dependency of the tuning tools used here, so you will likely receive a maintenance-mode warning when it loads.

# Set CRAN mirror
options(repos = c(CRAN = "https://cran.rstudio.com"))

# Define required packages
required_packages <- c(
  "rpart", 
  "rpart.plot", 
  "ranger", 
  "tuneRanger", 
  "xgboost", 
  "caret", 
  "ggplot2",
  "Rcpp"
)

# Install and load the required packages
for (pkg in required_packages) {
  if (!requireNamespace(pkg, quietly = TRUE)) {
    install.packages(pkg, dependencies = TRUE, type = "binary")
  }
  library(pkg, character.only = TRUE)
}
## Loading required package: mlrMBO
## Loading required package: mlr
## Loading required package: ParamHelpers
## Warning message: 'mlr' is in 'maintenance-only' mode since July 2019.
## Future development will only happen in 'mlr3'
## (<https://mlr3.mlr-org.com>). Due to the focus on 'mlr3' there might be
## uncaught bugs meanwhile in {mlr} - please consider switching.
## Loading required package: smoof
## Loading required package: checkmate
## Loading required package: parallel
## Loading required package: lubridate
## 
## Attaching package: 'lubridate'
## The following objects are masked from 'package:base':
## 
##     date, intersect, setdiff, union
## Loading required package: lhs
## Loading required package: ggplot2
## Loading required package: lattice
## 
## Attaching package: 'caret'
## The following object is masked from 'package:mlr':
## 
##     train
# Update Rcpp
if (!requireNamespace("Rcpp", quietly = TRUE)) {
  install.packages("Rcpp", dependencies = TRUE, type = "binary")
} else {
  update.packages("Rcpp", ask = FALSE)
}
library(Rcpp)

Why Use Machine Learning?

In the context of classification tasks such as the one discussed in this tutorial, machine learning techniques offer distinct advantages over traditional statistical methods like regression analysis. While regression analysis is well suited to exploring relationships between variables and making predictions about continuous outcomes, machine learning excels at handling complex, high-dimensional data and making predictions or classifications based on patterns learned from the data itself.

In the case of Pokémon evolution classification, machine learning algorithms such as decision trees, random forests, and XGBoost can handle the intricate relationships between Pokémon features (such as sound symbolism) and their evolutionary stages without making strict assumptions about the underlying data distribution. They can also capture non-linear relationships and interactions between features, which may be challenging to model with traditional regression, and they are often more robust to outliers and noise, making them well suited to real-world datasets where the assumptions of classical statistical methods may not hold.

Therefore, in scenarios where the goal is to accurately classify or predict categorical outcomes from complex data structures, machine learning approaches provide a powerful alternative to traditional statistical methods.

Processing Time

One of the downsides of machine learning is that it is computationally intensive: constructing models takes a long time and can consume a lot of your computer’s resources. We can speed up model construction by allocating more cores from your central processing unit (CPU). You may also allocate cores from your graphics processing unit (GPU); however, that process is more complicated and varies depending on your GPU, so I have not included it in this tutorial. The following code checks how many cores your CPU has and allocates 80% of them to constructing the models. You can increase or decrease this allocation depending on what other programs you are running and whether you want to keep using your computer while the models are being constructed. Another option is to reduce the size of the algorithms; for example, the algorithms in this tutorial consist of 100 decision trees, which is a very small number for these types of models.

# Determine the number of CPU cores
num_cores <- parallel::detectCores()
# Calculate the number of cores to be used (80% of available cores)
num_cores_to_use <- ceiling(0.8 * num_cores)
# Set the number of cores for parallel processing
options(mc.cores = num_cores_to_use)
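
As a small aside, you can confirm the allocation you just made, and if you prefer to set the thread count per call rather than globally, note that ranger(), for example, also accepts a num.threads argument. A minimal check:

# Report the core allocation we just made
cat("Detected ", num_cores, " cores; using ", getOption("mc.cores"),
    " for model construction.\n", sep = "")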

The Data

The data for this tutorial comes from a sound symbolism study (Kilpatrick, Ćwiek & Kawahara, 2023) where my colleagues and I explored how Pokémon evolution was expressed in Pokémon names. We took the names of Pokémon and converted them into a count of the number of times each sound occurs in each name. We discarded any Pokémon that was not part of an evolutionary tree and any mid-stage Pokémon because we were only interested in classifying pre- and post-evolution Pokémon. We also ran an elicitation experiment where we asked Japanese university students to name previously unseen Pokémon-like images like those below.

Figure 1: Samples of the images used in the elicitation experiment. Note that the middle image was not used because that is an example of a mid-stage variant. Images presented here with permission of the artist Deviant Art user: Involuntary-Twitch.

Jpoke <- read.csv(file.path(base_dir, "Japanese_Pokemon.csv"))
Jpoke$Evolution <- factor(Jpoke$Evolution)
options(width = 150)
head(Jpoke)
##   Designation       Name     Evolution a i ɯ e o ɸ ts b d g gj z dʒ p t C k kj s ʃ h m mj n nj ɾ j w Q N Long
## 1    Official ɸuʃigidane pre-evolution 1 2 1 1 0 1  0 0 1 1  0 0  0 0 0 0 0  0 0 1 0 0  0 1  0 0 0 0 0 0    0
## 2    Official   hitokage pre-evolution 1 1 0 1 1 0  0 0 0 1  0 0  0 0 1 0 1  0 0 0 1 0  0 0  0 0 0 0 0 0    0
## 3    Official   zenigame pre-evolution 1 1 0 2 0 0  0 0 0 1  0 1  0 0 0 0 0  0 0 0 0 1  0 1  0 0 0 0 0 0    0
## 4    Official   kjatapi: pre-evolution 2 1 0 0 0 0  0 0 0 0  0 0  0 1 1 0 0  1 0 0 0 0  0 0  0 0 0 0 0 0    1
## 5    Official    bi:doru pre-evolution 0 1 1 0 1 0  0 1 1 0  0 0  0 0 0 0 0  0 0 0 0 0  0 0  0 1 0 0 0 0    1
## 6    Official      poQpo pre-evolution 0 0 0 0 2 0  0 0 0 0  0 0  0 2 0 0 0  0 0 0 0 0  0 0  0 0 0 0 1 0    0

Decision Trees

In the first part of this tutorial, we will learn how to construct decision trees using the rpart and rpart.plot packages. While decision trees serve as effective and intuitive tools for data visualization, they are not appropriate for hypothesis testing. Decision trees are included in this tutorial to provide the foundation for the random forest and XGBoost models in the following sections.

Decision trees select features by evaluating the data at each node and choosing the feature that best splits the data into distinct groups according to criteria such as maximizing homogeneity within groups and maximizing separation between groups. This process continues recursively until a stopping criterion is met, resulting in a tree structure that represents the decision-making process.
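
To make the splitting criterion concrete, here is a minimal sketch of Gini impurity, the measure rpart uses by default for classification. The gini() helper below is written purely for this illustration and is not part of the rpart package.

# Gini impurity: 0 for a perfectly pure node, 0.5 for a 50/50 two-class node
gini <- function(labels) {
  p <- table(labels) / length(labels)
  1 - sum(p^2)
}
gini(c(rep("pre", 5), rep("post", 5)))  # 0.5  (maximally impure)
gini(c(rep("pre", 9), rep("post", 1)))  # 0.18 (much purer)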

library(rpart)
library(rpart.plot)
Jtree <- rpart(Evolution ~ . - Designation - Name, data = Jpoke)
rpart.plot(Jtree)

Each rounded rectangle within the decision tree represents a node, serving as a decision point in the classification process. At the top sits the root node, branching out into internal nodes and eventually leading to the leaf nodes at the bottom. The label on the root node reads “post-evolution”, followed by the number 0.49 and then 100%. The 0.49 denotes the proportion of pre-evolution samples in the group, while the 100% tells us that all of the samples pass through this node. If the classification were based solely on this root node, all samples would be deemed post-evolution Pokémon, given the slight skew towards post-evolution in the data set. Hence, the root node is designated post-evolution.

The initial decision, represented by “d >= 1”, tells us that “d” is the most important sound according to the model. If a sample contains one or more “d” sounds, it is automatically classified as a post-evolution Pokémon. The branch to the far left leads to a terminal node labeled “post-evolution”, with the numbers 0.30 and 14%. Here, 14% of the samples contain one or more “d” sounds; of these, 30% are pre-evolution Pokémon, so 70% of the samples with at least one “d” sound are post-evolution.

For samples lacking a “d” sound, the journey continues along the branch stemming from the right side of the root node. Here, an internal node evaluates the presence of one or more “g” sounds. Those samples with at least one “g” are also categorized as post-evolution Pokémon, constituting 16% of the total samples and classified with 65% accuracy. Further traversal to the right of the “g” node unveils additional internal nodes, including those scrutinizing the presence of long vowels (Long), each contributing to the classification process.

The color scheme within the decision tree illustrates the dominant classification passing through each node and the final decision in the leaf nodes. Darker blues signify a predominant distribution towards the post-evolution category, while greener shades indicate a tilt towards the pre-evolution category. The darker the color, the more accurate the classification.

We can evaluate the classification accuracy of the decision tree using the predict function, which applies the trained model to the data and generates predicted class labels. By comparing these predicted labels to the actual labels in the data set, we can assess the accuracy of the model’s predictions.

predictions <- predict(Jtree, Jpoke, type = "class")
accuracy <- mean(predictions == Jpoke$Evolution)
cat("Model Accuracy: ", round(accuracy * 100), "%", sep="")
## Model Accuracy: 65%

Prediction accuracy might be thought of as a measure of effect size; however, individual decision trees should not be used for hypothesis testing. Decision trees are prone to overfitting and are susceptible to outliers. Random forests were developed to resolve these issues.

Random Forests

In this section, we construct random forests (Breiman, 2001), a powerful ensemble learning technique built upon the foundation of decision trees. Random forests counter the overfitting seen in individual decision trees by constructing many randomized trees. Each tree is built from a random selection of samples (bootstrap aggregating typically includes around 63.2% of the unique samples; Breiman, 1996) and a random subset of features (the default for classification is the square root of the number of features; the random subspace method: Ho, 1998). By randomizing both dimensions and then combining the trees, random forests reduce the impact of outlier samples.
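
If you are wondering where the 63.2% figure comes from: when n rows are sampled with replacement from n rows, each row appears at least once with probability 1 - (1 - 1/n)^n, which approaches 1 - 1/e ≈ 0.632 as n grows. A quick illustrative check (n = 1116 is the size of the training set we create below):

n <- 1116                          # training-set size used later in this tutorial
1 - (1 - 1/n)^n                    # roughly 0.632
set.seed(1)                        # empirical check via simulation
mean(replicate(1000, length(unique(sample(n, n, replace = TRUE))) / n))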

Another way that random forests address the issue of overfitting is by withholding a subset of the data from the training process. This held-out data is then used to test the classification accuracy of the model with the same predict function we used for the decision tree. In the following code, we load the required packages, set the randomization seed so that our results are replicable, split the data into training and testing subsets, and construct an untuned random forest.

Take note that we also modify the data because the tuneRanger package has stricter limitations on symbols that can be used for column names.

# Drop the identifier columns and rename the IPA-symbol columns that
# tuneRanger cannot handle as column names
Jpoke <- Jpoke[, !(names(Jpoke) %in% c("Designation", "Name"))]
colnames(Jpoke)[4] <- "u"   # ɯ -> u
colnames(Jpoke)[7] <- "f"   # ɸ -> f
colnames(Jpoke)[14] <- "G"  # dʒ -> G
colnames(Jpoke)[21] <- "S"  # ʃ -> S
colnames(Jpoke)[27] <- "R"  # ɾ -> R

library(ranger)
library(tuneRanger)

set.seed(1)
indices <- sample(1:nrow(Jpoke), size = 0.7 * nrow(Jpoke))
train_data <- Jpoke[indices, ]
test_data <- Jpoke[-indices, ]

Untuned_Forest <- ranger(Evolution ~ ., 
                         num.trees = 100,
                         data = train_data)
print(Untuned_Forest)
## Ranger result
## 
## Call:
##  ranger(Evolution ~ ., num.trees = 100, data = train_data) 
## 
## Type:                             Classification 
## Number of trees:                  100 
## Sample size:                      1116 
## Number of independent variables:  31 
## Mtry:                             5 
## Target node size:                 1 
## Variable importance mode:         none 
## Splitrule:                        gini 
## OOB prediction error:             36.02 %

This is the result of training a random forest model with default hyperparameter values on the training data. The model consists of 100 trees. The “Mtry” parameter, set to 5, controls the number of randomly selected variables considered at each split. The target node size is set to 1, meaning each terminal node in the trees will have at least one observation. The split rule is Gini impurity, a measure of how often a randomly chosen element would be incorrectly labeled if it were labeled randomly according to the distribution of labels in the node. The out-of-bag (OOB) prediction error provides an estimate of the model’s accuracy without using a separate testing data set; here it suggests that the model should achieve roughly 64% accuracy (100% minus the 36.02% OOB error). To validate this estimate, we apply the predict function to label the holdout data based on the model’s decisions and then compare the predicted labels with the actual designations in the test set.

predictions <- predict(Untuned_Forest, data = test_data)$predictions
accuracy <- mean(predictions == test_data$Evolution)
cat("Model Accuracy: ", round(accuracy * 100, 2), "%", sep="")
## Model Accuracy: 58.66%
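
For comparison, the OOB-based estimate can be pulled directly from the fitted model object rather than read off the printed summary; a one-line sketch:

# prediction.error stores the OOB misclassification rate as a proportion
cat("OOB-estimated accuracy: ",
    round((1 - Untuned_Forest$prediction.error) * 100, 2), "%", sep = "")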

Tuning Hyperparameters

Hyperparameters play a crucial role in algorithms like random forests, influencing their behavior and performance. In random forests, hyperparameters control aspects such as the number of trees in the forest, the number of features considered at each split, and the criteria for splitting nodes. Tuning these hyperparameters is essential for optimizing the model’s accuracy and robustness. For the Random Forest model, we will use the tuneRanger function to determine the optimal values for each hyperparameter.

Tune_Output <- makeClassifTask(data = train_data, target = "Evolution")
Tune_Output <- tuneRanger(Tune_Output, measure = list(multiclass.brier), num.trees = 100)
## Computing y column(s) for design. Not provided.
## [mbo] 0: mtry=7; min.node.size=16; sample.fraction=0.769 : y = 0.441 : 0.5 secs : initdesign
## [mbo] 0: mtry=29; min.node.size=84; sample.fraction=0.46 : y = 0.449 : 0.5 secs : initdesign
## [mbo] 0: mtry=25; min.node.size=7; sample.fraction=0.809 : y = 0.473 : 0.6 secs : initdesign
## [mbo] 0: mtry=10; min.node.size=2; sample.fraction=0.393 : y = 0.438 : 0.5 secs : initdesign
## [mbo] 0: mtry=27; min.node.size=3; sample.fraction=0.494 : y = 0.443 : 0.6 secs : initdesign
## [mbo] 0: mtry=5; min.node.size=33; sample.fraction=0.53 : y = 0.441 : 0.4 secs : initdesign
## [mbo] 0: mtry=18; min.node.size=27; sample.fraction=0.678 : y = 0.441 : 0.4 secs : initdesign
## [mbo] 0: mtry=12; min.node.size=10; sample.fraction=0.582 : y = 0.437 : 0.5 secs : initdesign
## [mbo] 0: mtry=15; min.node.size=5; sample.fraction=0.699 : y = 0.452 : 0.5 secs : initdesign
## [mbo] 0: mtry=18; min.node.size=58; sample.fraction=0.37 : y = 0.446 : 0.4 secs : initdesign
## [mbo] 0: mtry=17; min.node.size=11; sample.fraction=0.42 : y = 0.435 : 0.4 secs : initdesign
## [mbo] 0: mtry=4; min.node.size=68; sample.fraction=0.747 : y = 0.448 : 0.4 secs : initdesign
## [mbo] 0: mtry=20; min.node.size=105; sample.fraction=0.616 : y = 0.453 : 0.5 secs : initdesign
## [mbo] 0: mtry=22; min.node.size=3; sample.fraction=0.792 : y = 0.491 : 0.6 secs : initdesign
## [mbo] 0: mtry=26; min.node.size=13; sample.fraction=0.864 : y = 0.475 : 0.5 secs : initdesign
## [mbo] 0: mtry=6; min.node.size=20; sample.fraction=0.317 : y = 0.439 : 0.4 secs : initdesign
## [mbo] 0: mtry=2; min.node.size=7; sample.fraction=0.643 : y = 0.452 : 0.4 secs : initdesign
## [mbo] 0: mtry=30; min.node.size=155; sample.fraction=0.721 : y = 0.458 : 0.4 secs : initdesign
## [mbo] 0: mtry=24; min.node.size=8; sample.fraction=0.341 : y = 0.433 : 0.4 secs : initdesign
## [mbo] 0: mtry=15; min.node.size=160; sample.fraction=0.232 : y = 0.469 : 0.4 secs : initdesign
## [mbo] 0: mtry=13; min.node.size=43; sample.fraction=0.845 : y = 0.442 : 0.5 secs : initdesign
## [mbo] 0: mtry=12; min.node.size=4; sample.fraction=0.44 : y = 0.439 : 0.5 secs : initdesign
## [mbo] 0: mtry=3; min.node.size=220; sample.fraction=0.223 : y = 0.486 : 0.4 secs : initdesign
## [mbo] 0: mtry=22; min.node.size=25; sample.fraction=0.552 : y = 0.435 : 0.5 secs : initdesign
## [mbo] 0: mtry=21; min.node.size=2; sample.fraction=0.296 : y = 0.443 : 0.5 secs : initdesign
## [mbo] 0: mtry=28; min.node.size=2; sample.fraction=0.649 : y = 0.467 : 0.6 secs : initdesign
## [mbo] 0: mtry=11; min.node.size=124; sample.fraction=0.508 : y = 0.454 : 0.4 secs : initdesign
## [mbo] 0: mtry=1; min.node.size=3; sample.fraction=0.292 : y = 0.474 : 0.4 secs : initdesign
## [mbo] 0: mtry=8; min.node.size=5; sample.fraction=0.885 : y = 0.462 : 0.5 secs : initdesign
## [mbo] 0: mtry=31; min.node.size=45; sample.fraction=0.256 : y = 0.442 : 0.4 secs : initdesign
## [mbo] 1: mtry=31; min.node.size=8; sample.fraction=0.352 : y = 0.436 : 0.5 secs : infill_cb
## [mbo] 2: mtry=11; min.node.size=20; sample.fraction=0.485 : y = 0.438 : 0.4 secs : infill_cb
## [mbo] 3: mtry=26; min.node.size=18; sample.fraction=0.354 : y = 0.433 : 0.5 secs : infill_cb
## [mbo] 4: mtry=25; min.node.size=16; sample.fraction=0.358 : y = 0.434 : 0.5 secs : infill_cb
## [mbo] 5: mtry=20; min.node.size=21; sample.fraction=0.294 : y = 0.434 : 0.4 secs : infill_cb
## [mbo] 6: mtry=15; min.node.size=23; sample.fraction=0.331 : y = 0.436 : 0.4 secs : infill_cb
## [mbo] 7: mtry=26; min.node.size=11; sample.fraction=0.315 : y = 0.434 : 0.5 secs : infill_cb
## [mbo] 8: mtry=26; min.node.size=8; sample.fraction=0.468 : y = 0.44 : 0.5 secs : infill_cb
## [mbo] 9: mtry=24; min.node.size=17; sample.fraction=0.341 : y = 0.431 : 0.4 secs : infill_cb
## [mbo] 10: mtry=23; min.node.size=19; sample.fraction=0.351 : y = 0.436 : 0.6 secs : infill_cb
## [mbo] 11: mtry=10; min.node.size=23; sample.fraction=0.585 : y = 0.434 : 0.4 secs : infill_cb
## [mbo] 12: mtry=25; min.node.size=15; sample.fraction=0.319 : y = 0.437 : 0.5 secs : infill_cb
## [mbo] 13: mtry=24; min.node.size=18; sample.fraction=0.362 : y = 0.434 : 0.5 secs : infill_cb
## [mbo] 14: mtry=10; min.node.size=20; sample.fraction=0.538 : y = 0.434 : 0.4 secs : infill_cb
## [mbo] 15: mtry=27; min.node.size=20; sample.fraction=0.359 : y = 0.439 : 0.4 secs : infill_cb
## [mbo] 16: mtry=21; min.node.size=11; sample.fraction=0.344 : y = 0.431 : 0.5 secs : infill_cb
## [mbo] 17: mtry=21; min.node.size=10; sample.fraction=0.348 : y = 0.437 : 0.4 secs : infill_cb
## [mbo] 18: mtry=13; min.node.size=14; sample.fraction=0.385 : y = 0.434 : 0.4 secs : infill_cb
## [mbo] 19: mtry=11; min.node.size=20; sample.fraction=0.563 : y = 0.436 : 0.4 secs : infill_cb
## [mbo] 20: mtry=23; min.node.size=11; sample.fraction=0.34 : y = 0.435 : 0.4 secs : infill_cb
## [mbo] 21: mtry=22; min.node.size=16; sample.fraction=0.344 : y = 0.44 : 0.4 secs : infill_cb
## [mbo] 22: mtry=15; min.node.size=10; sample.fraction=0.396 : y = 0.433 : 0.5 secs : infill_cb
## [mbo] 23: mtry=15; min.node.size=9; sample.fraction=0.361 : y = 0.437 : 0.4 secs : infill_cb
## [mbo] 24: mtry=13; min.node.size=16; sample.fraction=0.414 : y = 0.436 : 0.5 secs : infill_cb
## [mbo] 25: mtry=31; min.node.size=5; sample.fraction=0.205 : y = 0.439 : 0.4 secs : infill_cb
## [mbo] 26: mtry=9; min.node.size=26; sample.fraction=0.612 : y = 0.433 : 0.5 secs : infill_cb
## [mbo] 27: mtry=9; min.node.size=30; sample.fraction=0.644 : y = 0.436 : 0.4 secs : infill_cb
## [mbo] 28: mtry=9; min.node.size=25; sample.fraction=0.565 : y = 0.437 : 0.5 secs : infill_cb
## [mbo] 29: mtry=24; min.node.size=11; sample.fraction=0.335 : y = 0.435 : 0.5 secs : infill_cb
## [mbo] 30: mtry=25; min.node.size=8; sample.fraction=0.333 : y = 0.433 : 0.4 secs : infill_cb
## [mbo] 31: mtry=15; min.node.size=9; sample.fraction=0.386 : y = 0.433 : 0.4 secs : infill_cb
## [mbo] 32: mtry=14; min.node.size=9; sample.fraction=0.391 : y = 0.438 : 0.4 secs : infill_cb
## [mbo] 33: mtry=25; min.node.size=10; sample.fraction=0.347 : y = 0.435 : 0.5 secs : infill_cb
## [mbo] 34: mtry=25; min.node.size=7; sample.fraction=0.33 : y = 0.433 : 0.5 secs : infill_cb
## [mbo] 35: mtry=25; min.node.size=7; sample.fraction=0.346 : y = 0.436 : 0.5 secs : infill_cb
## [mbo] 36: mtry=14; min.node.size=22; sample.fraction=0.54 : y = 0.435 : 0.4 secs : infill_cb
## [mbo] 37: mtry=10; min.node.size=23; sample.fraction=0.593 : y = 0.432 : 0.4 secs : infill_cb
## [mbo] 38: mtry=25; min.node.size=10; sample.fraction=0.311 : y = 0.433 : 0.5 secs : infill_cb
## [mbo] 39: mtry=25; min.node.size=10; sample.fraction=0.316 : y = 0.436 : 0.4 secs : infill_cb
## [mbo] 40: mtry=24; min.node.size=12; sample.fraction=0.328 : y = 0.438 : 0.5 secs : infill_cb
## [mbo] 41: mtry=11; min.node.size=27; sample.fraction=0.592 : y = 0.438 : 0.4 secs : infill_cb
## [mbo] 42: mtry=17; min.node.size=15; sample.fraction=0.367 : y = 0.433 : 0.5 secs : infill_cb
## [mbo] 43: mtry=17; min.node.size=17; sample.fraction=0.38 : y = 0.435 : 0.5 secs : infill_cb
## [mbo] 44: mtry=17; min.node.size=17; sample.fraction=0.375 : y = 0.437 : 0.4 secs : infill_cb
## [mbo] 45: mtry=9; min.node.size=22; sample.fraction=0.577 : y = 0.434 : 0.4 secs : infill_cb
## [mbo] 46: mtry=25; min.node.size=7; sample.fraction=0.344 : y = 0.434 : 0.5 secs : infill_cb
## [mbo] 47: mtry=25; min.node.size=7; sample.fraction=0.337 : y = 0.439 : 0.5 secs : infill_cb
## [mbo] 48: mtry=10; min.node.size=21; sample.fraction=0.543 : y = 0.435 : 0.4 secs : infill_cb
## [mbo] 49: mtry=9; min.node.size=19; sample.fraction=0.585 : y = 0.434 : 0.5 secs : infill_cb
## [mbo] 50: mtry=10; min.node.size=19; sample.fraction=0.567 : y = 0.437 : 0.4 secs : infill_cb
## [mbo] 51: mtry=16; min.node.size=11; sample.fraction=0.396 : y = 0.434 : 0.5 secs : infill_cb
## [mbo] 52: mtry=15; min.node.size=12; sample.fraction=0.394 : y = 0.433 : 0.4 secs : infill_cb
## [mbo] 53: mtry=16; min.node.size=11; sample.fraction=0.381 : y = 0.431 : 0.5 secs : infill_cb
## [mbo] 54: mtry=16; min.node.size=11; sample.fraction=0.372 : y = 0.437 : 0.4 secs : infill_cb
## [mbo] 55: mtry=16; min.node.size=10; sample.fraction=0.39 : y = 0.438 : 0.4 secs : infill_cb
## [mbo] 56: mtry=15; min.node.size=13; sample.fraction=0.397 : y = 0.433 : 0.4 secs : infill_cb
## [mbo] 57: mtry=15; min.node.size=12; sample.fraction=0.401 : y = 0.435 : 0.5 secs : infill_cb
## [mbo] 58: mtry=15; min.node.size=14; sample.fraction=0.388 : y = 0.434 : 0.4 secs : infill_cb
## [mbo] 59: mtry=15; min.node.size=13; sample.fraction=0.403 : y = 0.437 : 0.4 secs : infill_cb
## [mbo] 60: mtry=9; min.node.size=22; sample.fraction=0.622 : y = 0.43 : 0.4 secs : infill_cb
## [mbo] 61: mtry=9; min.node.size=23; sample.fraction=0.628 : y = 0.439 : 0.4 secs : infill_cb
## [mbo] 62: mtry=9; min.node.size=18; sample.fraction=0.596 : y = 0.435 : 0.4 secs : infill_cb
## [mbo] 63: mtry=14; min.node.size=18; sample.fraction=0.392 : y = 0.434 : 0.4 secs : infill_cb
## [mbo] 64: mtry=13; min.node.size=18; sample.fraction=0.392 : y = 0.434 : 0.5 secs : infill_cb
## [mbo] 65: mtry=15; min.node.size=16; sample.fraction=0.392 : y = 0.433 : 0.5 secs : infill_cb
## [mbo] 66: mtry=15; min.node.size=17; sample.fraction=0.394 : y = 0.434 : 0.5 secs : infill_cb
## [mbo] 67: mtry=16; min.node.size=16; sample.fraction=0.382 : y = 0.437 : 0.5 secs : infill_cb
## [mbo] 68: mtry=14; min.node.size=18; sample.fraction=0.401 : y = 0.435 : 0.4 secs : infill_cb
## [mbo] 69: mtry=10; min.node.size=13; sample.fraction=0.39 : y = 0.435 : 0.4 secs : infill_cb
## [mbo] 70: mtry=15; min.node.size=14; sample.fraction=0.389 : y = 0.436 : 0.4 secs : infill_cb
Tune_Output
## Recommended parameter settings: 
##   mtry min.node.size sample.fraction
## 1   16            17       0.4562907
## Results: 
##   multiclass.brier exec.time
## 1         0.431003      0.44

The tuneRanger package tunes hyperparameters using an iterative model-based optimization process, evaluating different combinations of hyperparameters to optimize model performance. mtry is the number of features examined at each split (random subspace method), min.node.size is the minimum size of the leaf nodes, and sample.fraction is the fraction of samples bootstrapped for each tree (bootstrap aggregating). The output of the tuning process provides us with optimized values for the hyperparameters in the random forest; however, the process is not perfect. Examine the mtry value, for example: the default mtry for a classification forest is the square root of the number of features, yet the tuning process has recommended examining 16 of the 31 features at each split because the data consist mostly of zero counts. This means that the benefits of the randomization process are somewhat limited in this scenario, which increases the risk of overfitting.
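
As a quick check on that observation, the default can be recomputed from the number of predictors in the training data; this snippet is just for comparison and is not part of the tuning output.

# Default mtry for classification: floor(sqrt(number of predictors))
floor(sqrt(ncol(train_data) - 1))   # 31 predictors, so 5 (the Mtry reported earlier)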

Constructing a Tuned Forest

Tuned_Forest <- ranger(Evolution ~ ., 
                       data = train_data, 
                       num.trees = 100, 
                       mtry = 16, 
                       min.node.size = 17, 
                       sample.fraction = 0.4562907)

predictions <- predict(Tuned_Forest, data = test_data)$predictions
accuracy <- mean(predictions == test_data$Evolution)
cat("Model Accuracy: ", round(accuracy * 100, 2), "%", sep="")
## Model Accuracy: 61.17%

The slightly improved accuracy of the tuned forest over the untuned forest is a typical outcome for the random forest algorithm, where tuning usually yields an increase of 1-3% depending on the predictive power of the features. Other than increasing the number of trees, this is about the limit of complexity for the random forest algorithm. In the following section, we turn our attention to exploring the dataset with the more advanced XGBoost algorithm.

The XGBoost Algorithm

XGBoost and random forest are both powerful machine learning algorithms widely used for classification and regression tasks. However, they employ distinct methodologies, leading to differences in their performance and behavior. XGBoost, short for eXtreme Gradient Boosting, is an ensemble learning technique that sequentially builds a series of weak decision trees with each subsequent tree correcting the errors of its predecessors. In contrast, random forest constructs multiple decision trees independently and combines their predictions through averaging or voting.
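
To make the contrast concrete before handing everything to xgboost, here is a deliberately simplified, hand-rolled sketch of the boosting idea (a toy regression example with plain rpart stumps and no regularization): each new tree is fitted to the residuals of the current prediction and added with a small learning rate. This illustrates the principle only and is not how xgboost is implemented internally.

library(rpart)
set.seed(1)
x <- runif(200, -3, 3)
y <- sin(x) + rnorm(200, sd = 0.2)
dat <- data.frame(x = x, y = y)

eta  <- 0.1                 # learning rate: shrink each tree's contribution
pred <- rep(mean(y), 200)   # start from a constant prediction
for (i in 1:50) {
  dat$resid <- y - pred                           # errors of the current model
  stump <- rpart(resid ~ x, data = dat,
                 control = rpart.control(maxdepth = 2))
  pred <- pred + eta * predict(stump, dat)        # small corrective step
}
cat("RMSE after boosting:", round(sqrt(mean((y - pred)^2)), 3), "\n")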

The following code loads the xgboost and caret packages. The caret (Classification And REgression Training) package serves as a versatile toolkit for streamlining machine learning workflows, facilitating cross-validation, hyperparameter tuning, and model performance evaluation.

library(xgboost)
library(caret)

The caret package offers additional protection against overfitting through built-in cross-validation. The following code tells the algorithm to split the data into 3 folds; the method name “cv” stands for cross-validation, and the number argument sets the number of (k) folds. Increasing the number of folds will reduce the likelihood of overfitting at the cost of increased processing time. Unless your data set is particularly small, there is not much benefit to increasing this value further.

ctrl <- trainControl(method = "cv", number = 3)

Another useful feature of the caret package is its built-in hyperparameter tuning, similar to that provided by tuneRanger. Note that the hyperparameter names differ from those of the random forest algorithm: “nrounds” is the number of trees, “max_depth” is the maximum depth of the trees, “eta” is the learning rate, “gamma” is the minimum loss reduction required to split a node, “colsample_bytree” is the fraction of features sampled for each tree, “min_child_weight” is the minimum number of instances (more precisely, the minimum sum of instance weights) required in a child node, and “subsample” is the fraction of samples used for each tree. Any of these hyperparameters can be adjusted, and additional candidate values can be added to the grid.

tuneGrid = expand.grid(nrounds = 100, 
                       max_depth = c(3, 6, 9), 
                       eta = c(0.01, 0.1, 0.3), 
                       gamma = 0, 
                       colsample_bytree = 1, 
                       min_child_weight = 1,
                       subsample = 1)

xgb_model <- train(Evolution ~ ., 
                   data = train_data, 
                   method = "xgbTree", 
                   trControl = ctrl, 
                   tuneGrid = tuneGrid)

print(xgb_model)
## eXtreme Gradient Boosting 
## 
## 1116 samples
##   31 predictor
##    2 classes: 'post-evolution', 'pre-evolution' 
## 
## No pre-processing
## Resampling: Cross-Validated (3 fold) 
## Summary of sample sizes: 744, 744, 744 
## Resampling results across tuning parameters:
## 
##   eta   max_depth  Accuracy   Kappa    
##   0.01  3          0.6424731  0.2804349
##   0.01  6          0.6424731  0.2809073
##   0.01  9          0.6308244  0.2588751
##   0.10  3          0.6559140  0.3096101
##   0.10  6          0.6523297  0.3035915
##   0.10  9          0.6406810  0.2802934
##   0.30  3          0.6621864  0.3232597
##   0.30  6          0.6317204  0.2628943
##   0.30  9          0.6326165  0.2642866
## 
## Tuning parameter 'nrounds' was held constant at a value of 100
## Tuning parameter 'gamma' was held constant at a value of 0
## Tuning
##  parameter 'colsample_bytree' was held constant at a value of 1
## Tuning parameter 'min_child_weight' was held constant at a value of 1
## 
## Tuning parameter 'subsample' was held constant at a value of 1
## Accuracy was used to select the optimal model using the largest value.
## The final values used for the model were nrounds = 100, max_depth = 3, eta = 0.3, gamma = 0, colsample_bytree = 1, min_child_weight = 1 and
##  subsample = 1.
predictions <- predict(xgb_model, newdata = test_data)
accuracy <- mean(predictions == test_data$Evolution)
cat("Model Accuracy: ", round(accuracy * 100, 2), "%", sep="")
## Model Accuracy: 63.47%

The XGBoost algorithm performs better than the random forests, yielding an accuracy of 63.47% on the test data. Below, we examine the confusion matrix and then identify which features (sounds) contribute most to the model.

The confusion matrix is a summary of prediction results on a classification problem. The matrix compares the actual target values with those predicted by the machine learning model. The entries in the matrix represent counts of predictions made by the model.

confusion_matrix <- confusionMatrix(predictions, test_data$Evolution)
print(confusion_matrix)
## Confusion Matrix and Statistics
## 
##                 Reference
## Prediction       post-evolution pre-evolution
##   post-evolution            167            75
##   pre-evolution             100           137
##                                           
##                Accuracy : 0.6347          
##                  95% CI : (0.5898, 0.6779)
##     No Information Rate : 0.5574          
##     P-Value [Acc > NIR] : 0.0003594       
##                                           
##                   Kappa : 0.2684          
##                                           
##  Mcnemar's Test P-Value : 0.0696424       
##                                           
##             Sensitivity : 0.6255          
##             Specificity : 0.6462          
##          Pos Pred Value : 0.6901          
##          Neg Pred Value : 0.5781          
##              Prevalence : 0.5574          
##          Detection Rate : 0.3486          
##    Detection Prevalence : 0.5052          
##       Balanced Accuracy : 0.6358          
##                                           
##        'Positive' Class : post-evolution  
## 

Although the true/false positive/negative terminology assumes that one class is designated “positive”, and in our two-class problem both categories are of equal interest, I am going to use that terminology here because you will likely encounter it if you continue down the machine learning rabbit hole. The top left value (167) represents the true positives: post-evolution samples correctly classified as post-evolution. The bottom right value (137) represents the true negatives: pre-evolution samples correctly classified as pre-evolution. The top right (false positives: 75) and bottom left (false negatives: 100) values represent the misclassified samples: 75 pre-evolution samples were classified as post-evolution, and 100 post-evolution samples were classified as pre-evolution.

Accuracy is the sum of the true positives and true negatives divided by the total number of samples. The confidence interval gives the range in which the true accuracy is expected to fall 95% of the time. The no-information rate (NIR) is the accuracy that could be achieved by always predicting the most frequent class. The p-value indicates whether the model’s accuracy is significantly better than the NIR; a p-value < 0.05 suggests that the model performs significantly better than always guessing the majority class.
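
As a quick check, both the accuracy and the no-information rate can be recomputed by hand from the counts in the matrix above:

# Accuracy: correct classifications divided by all classifications
(167 + 137) / (167 + 137 + 75 + 100)   # approximately 0.6347
# No-information rate: proportion of the most frequent reference class
(167 + 100) / (167 + 137 + 75 + 100)   # approximately 0.5574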

Feature Importance

In this sound symbolism study, feature importance represents the significance of each sound in determining whether a Pokémon is in its pre-evolution or post-evolution stage. The following analysis quantifies how important each sound is in the classification process.

feature_importance <- xgb.importance(model = xgb_model$finalModel)
print(feature_importance)
##     Feature        Gain       Cover   Frequency
##  1:       u 0.098108156 0.110605541 0.089075630
##  2:       a 0.086963561 0.107741262 0.095798319
##  3:    Long 0.080575250 0.078347732 0.057142857
##  4:       o 0.080403076 0.100952369 0.087394958
##  5:       N 0.071480849 0.053379373 0.042016807
##  6:       g 0.067627330 0.031370236 0.040336134
##  7:       d 0.052560269 0.022914027 0.020168067
##  8:       R 0.048769177 0.023236380 0.048739496
##  9:       e 0.046513483 0.046861285 0.055462185
## 10:       i 0.043773532 0.061468901 0.057142857
## 11:       k 0.039906745 0.046457293 0.053781513
## 12:       s 0.031753355 0.025576332 0.031932773
## 13:       h 0.027719244 0.030060784 0.035294118
## 14:       b 0.027409617 0.020892540 0.030252101
## 15:       z 0.025665459 0.013912170 0.020168067
## 16:       G 0.022503442 0.022309474 0.015126050
## 17:       n 0.017456357 0.016761287 0.023529412
## 18:       w 0.016310691 0.015155764 0.018487395
## 19:       Q 0.016063958 0.013800828 0.025210084
## 20:       f 0.015967531 0.012772748 0.013445378
## 21:       m 0.013828631 0.015096669 0.021848739
## 22:       S 0.013653332 0.020957315 0.020168067
## 23:       p 0.009942865 0.014253522 0.015126050
## 24:       j 0.009217188 0.014247532 0.011764706
## 25:       C 0.008141792 0.004061228 0.008403361
## 26:      nj 0.007541208 0.009895482 0.013445378
## 27:      gj 0.007151947 0.032742823 0.016806723
## 28:       t 0.006347482 0.020403034 0.018487395
## 29:      kj 0.005075274 0.008014626 0.008403361
## 30:      ts 0.001569197 0.005751443 0.005042017
##     Feature        Gain       Cover   Frequency
feature_importance <- feature_importance[order(-Gain)]
ggplot(data = feature_importance, aes(x = reorder(Feature, Gain), y = Gain)) +
  geom_bar(stat = "identity", fill = "skyblue") +
  labs(title = "Feature Importance",
       x = "Feature",
       y = "Gain") +
  theme(axis.text.x = element_text(angle = 45, hjust = 1))
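
If you prefer a ready-made chart, the xgboost package also provides its own importance-plotting helper, which works directly on the importance table; limiting it to the top features keeps the chart readable.

# Built-in alternative to the ggplot2 chart above
xgb.plot.importance(importance_matrix = feature_importance, top_n = 10)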

Gain refers to the improvement in accuracy brought by including the feature in the model. It is a measure of the relative importance of a feature, calculated as the average gain of each feature when it is used in trees within the model. Essentially, gain quantifies the contribution of each feature to the model’s predictive power.

To elaborate, the gain value for a feature indicates how much including that feature in a split helps in reducing the loss function, such as error or impurity, averaged over all the splits in which the feature is used. Higher gain values suggest that the feature is more influential in making accurate predictions, while lower gain values indicate less influence.

For example, if the gain of a particular sound in Pokémon names is high, it means that this sound plays a significant role in correctly classifying Pokémon into their pre- or post-evolution stages. Conversely, features with low gain have less impact on decisions.

Gain is not a percentage increase in accuracy, but rather an abstract measure of a feature’s contribution to the model’s overall performance. It helps identify which features are most important in the classification process, guiding feature selection and interpretation in machine learning models.

Conclusion

In conclusion, this tutorial explored the application of decision tree-based classification algorithms, starting with decision trees and progressing to more advanced techniques such as random forests and XGBoost. Through practical examples and code demonstrations, we learned how these algorithms can effectively handle complex classification tasks, such as determining Pokémon evolution stages based on the sounds that make up their names. By understanding the underlying principles and mechanics of these algorithms, as well as techniques for hyperparameter tuning and feature importance analysis, we can leverage decision tree-based models to achieve high accuracy and robustness in classification tasks across various domains. With their ability to handle non-linear relationships, high-dimensional data, and interactions between features, decision tree-based algorithms stand as powerful tools in the machine learning toolkit, offering valuable insights and predictive capabilities for diverse real-world applications.

References

Kilpatrick, A. J., Ćwiek, A., & Kawahara, S. (2023). Random forests, sound symbolism and Pokémon evolution. PLOS ONE, 18(1), e0279350.

Breiman, L. (2001). Random forests. Machine Learning, 45, 5-32.

Breiman, L. (1996). Bagging predictors. Machine Learning, 24, 123-140.

Ho, T. K. (1998). The random subspace method for constructing decision forests. IEEE Transactions on Pattern Analysis and Machine Intelligence, 20(8), 832-844.