FE581 – 0306

Tree Based Methods

The basics of decision trees

A decision tree is a type of machine learning model that is used for both classification and regression tasks. The basic idea behind a decision tree is to create a tree-like model of decisions and their possible consequences.

Here's a diagram that shows an example of a decision tree for predicting whether or not a person will buy a new product:

[Diagram: the root node splits on Gender. For males, the next split is on Age (<25 → Buy, >=25 → Not buy); for females, the next split is on Income (Low → Buy, High → Not buy).]

In this example, the first decision node is based on the person's gender, which can either be male or female. If the person is male, we then look at their age. If they are under 25, we predict that they will buy the product, but if they are 25 or older, we predict that they will not buy it.

If the person is female, we look at their income level instead. If their income is low, we predict that they will buy the product, but if their income is high, we predict that they will not buy it.

Decision trees can be a very powerful tool for making predictions based on a set of features or inputs. However, they can also be prone to overfitting and can be difficult to interpret in more complex scenarios.

Let's start by discussing decision trees for regression problems. In a regression problem, the target variable is continuous, so we want to predict a numerical value for each data point. The basic idea behind decision trees for regression is the same as for classification: we partition the predictor space into simple regions based on the features, and we use these regions to make predictions for new data points.

The difference is that in regression trees, we use the average of the response values in each region as the predicted value for new data points. More specifically, each leaf node in the tree contains a predicted value that is the average of the response values of the training data points that fall into that region. To make a prediction for a new data point, we simply traverse the tree from the root node to a leaf node based on the values of the features, and return the predicted value associated with that leaf node.

Here's an example of a decision tree for a regression problem:

[Diagram: regression tree. The root node asks "Is X < 0.5?"; the second-level nodes ask "Is Y < 1.0?" and "Is Y < 0.5?", and the four leaves predict the values 0.9, 1.1, 0.7, and 1.3.]

To build a regression tree, we use the same basic algorithm as for classification trees, but we choose the splitting feature and the splitting value based on the reduction in the mean squared error (MSE) of the response values in the two subsets created by the split. The splitting feature and value that produce the largest reduction in MSE are chosen as the splitting criteria.
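As a small illustration (not part of the lecture code), a brute-force search for the best split point of a single numeric predictor could look like the sketch below; the function name best_split is made up for this example.

```r
# Illustrative sketch: find the split value of one predictor x that most
# reduces the residual sum of squares (RSS) of the response y.
best_split <- function(x, y) {
  xs <- sort(unique(x))
  cuts <- (head(xs, -1) + tail(xs, -1)) / 2   # midpoints between distinct values
  rss <- sapply(cuts, function(cut) {
    left  <- y[x <  cut]
    right <- y[x >= cut]
    sum((left - mean(left))^2) + sum((right - mean(right))^2)
  })
  list(cut = cuts[which.min(rss)], rss = min(rss))
}
```

A regression tree applies this kind of search to every predictor at every node and keeps the split that yields the smallest resulting RSS.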

Regression Tree (Predicting Baseball Players’ Salaries)

In order to motivate regression trees, we begin with a simple example: predicting baseball players' salaries using regression trees. We use the Hitters data set to predict a baseball player's Salary based on Years (the number of years that he has played in the major leagues) and Hits (the number of hits that he made in the previous year).

The Hitters dataset is a classic dataset in statistics and is often used to illustrate the concepts of regression trees. Let's dive into this example.

Exploring the data

First, let's load the dataset and take a look at its structure:
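The loading code itself is not reproduced in these notes; a minimal sketch, assuming the ISLR package (which ships the Hitters data) is installed:

```r
library(ISLR)     # provides the Hitters data set
data(Hitters)
head(Hitters)     # inspect the first few rows
```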

|                   | AtBat | Hits | HmRun | Runs | RBI | Walks | Years | CAtBat | CHits | CHmRun | CRuns | CRBI | CWalks | League | Division | PutOuts | Assists | Errors | Salary | NewLeague |
|-------------------|-------|------|-------|------|-----|-------|-------|--------|-------|--------|-------|------|--------|--------|----------|---------|---------|--------|--------|-----------|
| -Andy Allanson    | 293   | 66   | 1     | 30   | 29  | 14    | 1     | 293    | 66    | 1      | 30    | 29   | 14     | A      | E        | 446     | 33      | 20     | NA     | A         |
| -Alan Ashby       | 315   | 81   | 7     | 24   | 38  | 39    | 14    | 3449   | 835   | 69     | 321   | 414  | 375    | N      | W        | 632     | 43      | 10     | 475    | N         |
| -Alvin Davis      | 479   | 130  | 18    | 66   | 72  | 76    | 3     | 1624   | 457   | 63     | 224   | 266  | 263    | A      | W        | 880     | 82      | 14     | 480    | A         |
| -Andre Dawson     | 496   | 141  | 20    | 65   | 78  | 37    | 11    | 5628   | 1575  | 225    | 828   | 838  | 354    | N      | E        | 200     | 11      | 3      | 500    | N         |
| -Andres Galarraga | 321   | 87   | 10    | 39   | 42  | 30    | 2     | 396    | 101   | 12     | 48    | 46   | 33     | N      | E        | 805     | 40      | 4      | 91.5   | N         |
| -Alfredo Griffin  | 594   | 169  | 4     | 74   | 51  | 35    | 11    | 4408   | 1133  | 19     | 501   | 336  | 194    | A      | W        | 282     | 421     | 25     | 750    | A         |

As we can see, the dataset contains information about several baseball players, including their Years of experience in the major leagues, the number of Hits they made in the previous year, and their Salary.

We can start by visualizing the relationship between Salary and Years and Salary and Hits:

[Figure: scatterplots of Salary versus Years and of Salary versus Hits]

These scatterplots show that there is some relationship between Salary and both Years and Hits, although the relationship is not linear.

To build a regression tree to predict Salary based on Years and Hits, we can use the rpart package in R:
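The actual code appears only as a screenshot below. A minimal sketch of what it might look like, assuming the response is log(Salary) (the fitted values around 5-6 quoted later suggest a log transform) and that players with missing Salary are dropped first; the object names hitters and fit are assumptions:

```r
library(rpart)

hitters <- na.omit(Hitters)               # drop rows with missing Salary
fit <- rpart(log(Salary) ~ Years + Hits,  # regression tree on two predictors
             data = hitters, method = "anova")
print(fit)                                # text representation of the fitted tree
```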

[Screenshot: the rpart() call and the printed structure of the fitted regression tree]

This will create a regression tree using Years and Hits as predictors and Salary as the response variable.

We can visualize the resulting tree using the rpart.plot package:
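A sketch of the plotting call; the display options chosen here are assumptions rather than the exact settings used for the screenshot:

```r
library(rpart.plot)
rpart.plot(fit, type = 4, extra = 1, roundint = FALSE)  # draw the fitted tree
```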

 

[Figure: the regression tree drawn with rpart.plot()]

This gives us a visualization of the regression tree.

The output also shows the structure of the tree itself. Each node is labeled with a number (in parentheses), which corresponds to its index in the tree. The root node is labeled as (1), and subsequent nodes are labeled in the order in which they were created.

Each node is split into two child nodes, based on a splitting criterion that is chosen to minimize the deviance of the response variable in the resulting subsets. The splitting criterion is typically based on the values of one or more predictor variables. In this example, the splitting criterion is based on the values of the Years and Hits variables.

For example, the root node (1) has a deviance of 207.1537 and a predicted value of 5.927222. It is split into two child nodes, based on whether Years is less than 4.5 or not. If Years is less than 4.5, we move down to node (2), which has a deviance of 42.35317 and a predicted value of 5.10679. If Years is greater than or equal to 4.5, we move down to node (3), which has a deviance of 72.70531 and a predicted value of 6.354036.

Node (2) is further split into two child nodes, based on whether Years is less than 3.5 or not. If Years is less than 3.5, we move down to node (4), which has a deviance of 23.00867 and a predicted value of 4.891812. If Years is greater than or equal to 3.5, we move down to node (5), which has a deviance of 10.13439 and a predicted value of 5.582812.

And so on. The tree continues to split until the deviance of each node falls below a certain threshold or until some other stopping criterion is met.

To make predictions for new data points using the tree, we simply start at the root node and follow the splits based on the values of Years and Hits for the new data point, until we reach a leaf node. The predicted Salary is then the mean Salary of the training data points that fall into the leaf node.
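For example, a prediction for a hypothetical player (the values 6 years and 100 hits are chosen purely for illustration) can be obtained with predict():

```r
# Predicted (log) salary for a player with 6 years of experience and 100 hits
predict(fit, newdata = data.frame(Years = 6, Hits = 100))
```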

A larger Tree

In the code above, splitting stopped according to rpart's default control parameters. We can change those parameters manually to obtain different, typically larger, trees:

The rpart() function in R allows you to specify two important parameters when building a regression tree: minsplit and minbucket.
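The code for the larger tree is not shown in these notes; a sketch consistent with the parameters discussed below (the object name fit.big is an assumption):

```r
fit.big <- rpart(log(Salary) ~ Years + Hits, data = hitters, method = "anova",
                 minsplit = 2,    # attempt splits on nodes with as few as 2 observations
                 minbucket = 1)   # allow terminal nodes with a single observation
rpart.plot(fit.big, type = 4, roundint = FALSE)
```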

In this code, we have set minsplit=2 and minbucket=1. This means that the tree will not attempt to split a node if it contains fewer than 2 observations, and it will create terminal nodes with only one observation. This can result in a very complex and overfitted model that may not generalize well to new data.

Each split in the tree partitions the data into two groups based on the value of a predictor variable. At each split, the model chooses the variable and split value that produce the greatest reduction in the residual sum of squares (RSS).

In this specific tree, the first split is based on the Years variable being less than 4.5 or greater than or equal to 4.5. If a player has fewer than 4.5 years of experience, the model then splits based on the Hits variable being less than 15.5 or greater than or equal to 15.5. If a player has 4.5 or more years of experience, the model then splits based on the Hits variable being less than 117.5 or greater than or equal to 117.5.

The tree continues to split until the minsplit or minbucket criteria are met. The minsplit argument specifies the minimum number of observations required to attempt a split, so if a node has fewer than 2 observations, it cannot be split any further. The minbucket argument specifies the minimum number of observations required in a terminal node, so if a split would result in a terminal node with fewer than 1 observation, the split is not performed and the node becomes a terminal node.

In this specific tree, several nodes terminate because of the minbucket argument. For example, node 28 terminates because it has only one observation and minbucket is set to 1. Other nodes terminate because no further split would reduce the RSS by enough to exceed the complexity parameter (cp) threshold; for example, node 12 terminates because splitting further on Years or Hits does not produce a sufficient reduction in the RSS.

[Figure: the larger regression tree grown with minsplit = 2 and minbucket = 1]

More complex model

 

[Figure: an even more complex regression tree]

Pruning the tree

One important issue in building regression trees is determining when to stop splitting the data. If we keep splitting the data until each leaf node contains only one training data point, we will likely overfit the model to the training data, which can lead to poor performance on new data. On the other hand, if we stop splitting too early, we may not capture all of the important patterns in the data.

One approach to dealing with this issue is to use a pruning algorithm to trim the tree after it has been built. This involves removing nodes from the tree that do not improve the overall predictive accuracy of the tree on a validation dataset.

The rpart package in R includes a built-in pruning algorithm that uses cross-validation to determine the optimal level of pruning.

To apply pruning to the tree we built earlier, we can use the following code:
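The code itself is shown only as a screenshot; a sketch of the steps it appears to carry out, assuming the overgrown tree from the previous section is stored in fit.big:

```r
printcp(fit.big)                                  # cross-validation (cp) table
plotcp(fit.big)                                   # cross-validation error versus cp
bestcp <- fit.big$cptable[which.min(fit.big$cptable[, "xerror"]), "CP"]
pruned <- prune(fit.big, cp = bestcp)             # prune at the best cp
rpart.plot(pruned, type = 4, roundint = FALSE)    # plot the pruned tree
```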

[Screenshot: code that computes the cross-validation (cp) table and prunes the regression tree]

|    | CP                  | nsplit | rel error         | xerror            | xstd              |
|----|---------------------|--------|-------------------|-------------------|-------------------|
| 1  | 0.246749964834682   | 0      | 1                 | 1.00794601007258  | 0.139013921546954 |
| 2  | 0.189905771439237   | 1      | 0.753250035165318 | 0.872149976002464 | 0.135568787856714 |
| 3  | 0.0685745463644453  | 2      | 0.563344263726081 | 0.610987157350381 | 0.104981342598344 |
| 4  | 0.0231403268677788  | 3      | 0.494769717361636 | 0.649504706761546 | 0.116569074229425 |
| 5  | 0.0177348849948363  | 5      | 0.448489063626078 | 0.73640079945181  | 0.123232838794308 |
| 6  | 0.0142809034585567  | 7      | 0.413019293636405 | 0.789205054211005 | 0.128852122709316 |
| 7  | 0.0132714052487796  | 8      | 0.398738390177849 | 0.774925527841304 | 0.121194566352022 |
| 8  | 0.0116254352676553  | 14     | 0.317099353485133 | 0.825033409022432 | 0.128956811587054 |
| 9  | 0.0111603729867776  | 15     | 0.305473918217478 | 0.807560947169134 | 0.120133060593691 |
| 10 | 0.0099636776106262  | 17     | 0.283153172243922 | 0.840242622789819 | 0.123937627375919 |
| 11 | 0.00981717196654987 | 18     | 0.273189494633296 | 0.863670491630661 | 0.128549290536426 |
| 12 | 0.00887129602725455 | 19     | 0.263372322666746 | 0.87480867075757  | 0.128838854429451 |
| 13 | 0.00734889647526427 | 21     | 0.245629730612237 | 0.938504530794627 | 0.138853341013866 |
| 14 | 0.00706472225139319 | 25     | 0.21623414471118  | 0.974891244717307 | 0.142584260606387 |
| 15 | 0.00695916804300275 | 26     | 0.209169422459787 | 0.974891244717307 | 0.142584260606387 |
| 16 | 0.00694178256429693 | 27     | 0.202210254416784 | 0.977318396355282 | 0.142613240190292 |
| 17 | 0.00675461351781338 | 28     | 0.195268471852487 | 0.968411129141948 | 0.142509912292911 |
| 18 | 0.00641203523636004 | 29     | 0.188513858334674 | 0.968411129141948 | 0.142509912292911 |
| 19 | 0.005               | 30     | 0.182101823098314 | 0.961688405119617 | 0.142452846332138 |

This prints the complexity parameter table, which shows the cross-validation error as a function of the complexity parameter (cp) of the tree, and identifies the value of cp that gives the minimum cross-validation error. The tree is then pruned at this value of cp and the pruned tree is plotted:

 

[Figure: the pruned regression tree]

We might interpret the regression tree as follows: Years is the most important factor in determining Salary, and players with less experience earn lower salaries than more experienced players. Given that a player is less experienced, the number of hits that he made in the previous year seems to play little role in his salary. But among players who have been in the major leagues for five or more years, the number of hits made in the previous year does affect salary, and players who made more hits last year tend to have higher salaries.

Classification Tree (With Heart Data)

Intro

A classification tree is used for predicting qualitative or categorical responses, while a regression tree is used for predicting quantitative responses.

In a classification tree, the predicted class for a new observation is determined by the most commonly occurring class of the training observations that fall into the same terminal node as the new observation. For example, let's say we are trying to predict whether a customer will make a purchase (yes or no) based on their age and income. A classification tree might split the data based on age, and then further split each age group based on income. The final terminal nodes might have different proportions of customers who made a purchase, and we would use the most common class in each terminal node to predict whether a new customer with those characteristics would make a purchase.

In interpreting the results of a classification tree, we are interested not only in the predicted class for a particular observation, but also in the class proportions among the training observations that fall into the same terminal node. This can give us a sense of how well the tree is able to distinguish between different classes, and can help us identify areas where the tree may be overfitting or underfitting the data.

To illustrate this, let's say we have a dataset with two classes (Class A and Class B), and we build a classification tree using some predictor variables. The tree might have several terminal nodes, each corresponding to a different region of the predictor space. Within each terminal node, we can calculate the proportion of training observations that belong to Class A and Class B. This information can help us understand the relative importance of different predictors in distinguishing between the two classes, and can also help us identify any regions where the tree may be making errors or misclassifying observations.

Here's an example of a simple classification tree with two terminal nodes, each corresponding to a different region of the predictor space:

[Diagram: the root node splits on Age <= 30; the Age <= 30 group then splits on Income, giving terminal node B (Income <= 50K: Class A 70%, Class B 30%) and terminal node C (Income > 50K: Class A 40%, Class B 60%).]

In this example, the tree splits the data based on age, and then further splits each age group based on income. The terminal node B corresponds to customers who are younger than 30 and have an income of 50K or less. Among the training observations in this node, 70% belong to Class A and 30% belong to Class B. The terminal node C corresponds to customers who are younger than 30 and have an income greater than 50K. Among the training observations in this node, 40% belong to Class A and 60% belong to Class B.

[Impurity] Evaluating the quality of each potential split

When building a classification tree, we need to decide which predictor variable to split on at each node, and we do this by evaluating the quality of each potential split. There are different measures of quality that we can use, but two common ones are the Gini index and the entropy.

The Gini index measures the impurity of a node, where a node is considered pure if all of the training observations in the node belong to the same class. The Gini index is 0 for a pure node and reaches its maximum (1 - 1/K for K classes, i.e. 0.5 in the two-class case) when the observations are evenly split between the classes. When evaluating a potential split, we calculate the weighted sum of the Gini indices for the child nodes, where the weights correspond to the proportion of observations in each child node.

The entropy is another measure of impurity that is often used in decision trees. It is based on information theory and measures the amount of uncertainty or randomness in a node. Again, a node is considered pure if all of the training observations in the node belong to the same class; with two classes (and base-2 logarithms) the entropy takes values between 0, indicating a pure node, and 1, indicating a node whose observations are evenly split between the classes. When evaluating a potential split, we calculate the weighted sum of the entropies for the child nodes.
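Concretely, if $\hat{p}_{mk}$ denotes the proportion of training observations in node $m$ that belong to class $k$, the two impurity measures can be written as follows (the base-2 logarithm is one common convention; it makes the two-class entropy range from 0 to 1):

$$
G_m = \sum_{k=1}^{K} \hat{p}_{mk}\,(1 - \hat{p}_{mk}), \qquad
D_m = -\sum_{k=1}^{K} \hat{p}_{mk}\,\log_2 \hat{p}_{mk}
$$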

Both the Gini index and entropy are more sensitive to node purity than the classification error rate, which simply measures the proportion of misclassified observations in a node. This is because the Gini index and entropy take into account the distribution of observations among the different classes, and can help us identify splits that lead to more homogeneous child nodes in terms of class membership.

Classification of Heart Disease

The Heart dataset is a binary classification problem where the goal is to predict the presence or absence of heart disease based on various patient characteristics.

To get started, we first need to load the dataset into R. The Heart data accompanies the ISLR textbook as a CSV file (Heart.csv) available from the book's website, so we can read it in with read.csv():
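A minimal sketch, assuming Heart.csv has been downloaded into the working directory:

```r
Heart <- read.csv("Heart.csv", stringsAsFactors = FALSE)  # read the CSV into a data frame
```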

This will load the dataset into the variable Heart. We can take a quick look at the dataset using the head() function:

|   | X | Age | Sex | ChestPain    | RestBP | Chol | Fbs | RestECG | MaxHR | ExAng | Oldpeak | Slope | Ca | Thal       | AHD |
|---|---|-----|-----|--------------|--------|------|-----|---------|-------|-------|---------|-------|----|------------|-----|
| 1 | 1 | 63  | 1   | typical      | 145    | 233  | 1   | 2       | 150   | 0     | 2.3     | 3     | 0  | fixed      | No  |
| 2 | 2 | 67  | 1   | asymptomatic | 160    | 286  | 0   | 2       | 108   | 1     | 1.5     | 2     | 3  | normal     | Yes |
| 3 | 3 | 67  | 1   | asymptomatic | 120    | 229  | 0   | 2       | 129   | 1     | 2.6     | 2     | 2  | reversable | Yes |
| 4 | 4 | 37  | 1   | nonanginal   | 130    | 250  | 0   | 0       | 187   | 0     | 3.5     | 3     | 0  | normal     | No  |
| 5 | 5 | 41  | 0   | nontypical   | 130    | 204  | 0   | 2       | 172   | 0     | 1.4     | 1     | 0  | normal     | No  |
| 6 | 6 | 56  | 1   | nontypical   | 120    | 236  | 0   | 0       | 178   | 0     | 0.8     | 1     | 0  | normal     | No  |

This will show the first few rows of the dataset, which should give us an idea of what variables are included and what the data looks like.

When building a classification tree, it is important to ensure that any categorical variables in the dataset are treated as factors, so that the tree knows to treat them as discrete categories rather than continuous variables.

In the case of the Heart dataset, there are several variables that are categorical, such as ChestPain, Thal, ExAng, and Sex. These variables should be converted to factors before building the tree.

Here's an example code that converts the categorical variables to factors:
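A sketch of the conversion, using the column names shown in the table above:

```r
Heart$ChestPain <- as.factor(Heart$ChestPain)
Heart$Thal      <- as.factor(Heart$Thal)
Heart$ExAng     <- as.factor(Heart$ExAng)
Heart$Sex       <- as.factor(Heart$Sex)
```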

In this code, we use the as.factor() function to convert each categorical variable to a factor.

The AHD variable in the Heart dataset is a binary outcome variable indicating the presence or absence of heart disease, and it is coded as a character variable with the values "Yes" and "No". In order to build a classification tree, we need to convert this variable to a binary factor variable with the values 0 and 1 (or equivalently, "No" and "Yes").

Here's an example code that converts the AHD variable to a binary factor variable:
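A one-line sketch of the recoding described below:

```r
Heart$AHD <- as.factor(ifelse(Heart$AHD == "Yes", 1, 0))  # "Yes" -> 1, "No" -> 0, then factor
```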

In this code, we use the ifelse() function to replace the "Yes" values in the AHD variable with 1, and the "No" values with 0. We then use the as.factor() function to convert the resulting variable to a binary factor variable.

Now the new data becomes:

| ...1 | Age | Sex | ChestPain    | RestBP | Chol | Fbs | RestECG | MaxHR | ExAng | Oldpeak | Slope | Ca | Thal       | AHD |
|------|-----|-----|--------------|--------|------|-----|---------|-------|-------|---------|-------|----|------------|-----|
| 1    | 63  | 1   | typical      | 145    | 233  | 1   | 2       | 150   | 0     | 2.3     | 3     | 0  | fixed      | 0   |
| 2    | 67  | 1   | asymptomatic | 160    | 286  | 0   | 2       | 108   | 1     | 1.5     | 2     | 3  | normal     | 1   |
| 3    | 67  | 1   | asymptomatic | 120    | 229  | 0   | 2       | 129   | 1     | 2.6     | 2     | 2  | reversable | 1   |
| 4    | 37  | 1   | nonanginal   | 130    | 250  | 0   | 0       | 187   | 0     | 3.5     | 3     | 0  | normal     | 0   |
| 5    | 41  | 0   | nontypical   | 130    | 204  | 0   | 2       | 172   | 0     | 1.4     | 1     | 0  | normal     | 0   |
| 6    | 56  | 1   | nontypical   | 120    | 236  | 0   | 0       | 178   | 0     | 0.8     | 1     | 0  | normal     | 0   |

Once we have converted the AHD variable to a binary factor variable, we can build the classification tree using the code:
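A sketch of the call; dropping the row-index column first is an assumption (its name may be X or ...1 depending on how the file was read):

```r
library(rpart)
Heart$X <- NULL                                                # drop the row-index column, if present
fit.default <- rpart(AHD ~ ., data = Heart, method = "class")  # default rpart settings
fit.default                                                    # print the tree structure
```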

We fit the classification tree using the rpart() function, specifying method = "class" to indicate that we are building a classification tree. Since we do not specify any other arguments, the function will use the default values for all other parameters.

Once the tree has been built, we can visualize it using the rpart.plot() function. The type = 4 argument labels all nodes and draws separate split labels for the left and right branches, the extra = 101 argument displays the number of observations of each class together with the percentage of observations in each node, the roundint = FALSE argument prevents split values on integer-valued variables from being rounded, and the nn = TRUE argument displays the node numbers. Finally, the main argument specifies the title of the plot.
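Putting those arguments together (the title string is an assumption):

```r
library(rpart.plot)
rpart.plot(fit.default, type = 4, extra = 101, roundint = FALSE, nn = TRUE,
           main = "Classification tree for the Heart data")
```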

[Figure: the default classification tree for the Heart data drawn with rpart.plot()]

 

This is the output of fit.default, which is an object of class rpart. The output shows the structure of the classification tree, including the splits and the terminal nodes.

The first line shows the root node of the tree, which contains all 303 observations. The second line shows the split criterion for the root node, which is the variable Thal. The third line shows the number of observations that belong to the root node, the number of misclassifications (loss), the predicted class (yval), and the proportion of observations in each class (yprob).

The fourth line shows the first child node of the root, which corresponds to the normal value of the Thal variable. This node contains 167 observations. The fifth line shows the split criterion for this node, which is the ChestPain variable. The sixth line shows the two child nodes of this node, which correspond to the two values of the ChestPain variable (nonanginal,nontypical and asymptomatic,typical).

The seventh line shows the first terminal node of the tree, which corresponds to the nonanginal,nontypical value of the ChestPain variable for the normal value of the Thal variable. This node contains 100 observations, all of which belong to the 0 class (no heart disease).

The eighth line shows the second child node of the normal value of the Thal variable, which corresponds to the asymptomatic,typical value of the ChestPain variable. This node contains 67 observations. The ninth line shows the split criterion for this node, which is the Ca variable. The tenth and eleventh lines show the two child nodes of this node, which correspond to the two possible values of the Ca variable (< 0.5 and >=0.5).

 

[Screenshot: printcp() output for fit.default]

printcp() is a function in the rpart package that prints the cross-validation results for a fitted rpart object. The output shows the complexity parameter (CP) used for pruning the tree, as well as the number of splits (nsplit), the relative error (rel error), the cross-validation error (xerror), and the standard deviation of the cross-validation error (xstd) for each value of the complexity parameter.

In this case, we applied printcp() to the fit.default object, which is a fitted rpart object obtained using the default parameters for the rpart() function.

Above the table, printcp() reports the root node error (here 0.45875), meaning that 45.875% of the observations would be misclassified if we simply predicted the majority class for every patient; the rel error, xerror, and xstd columns are all expressed relative to this baseline.

The second row shows the results after the first split (nsplit = 1), which uses the Thal variable. The relative error drops to 0.52518, meaning that the split improves the classification accuracy; the CP value of 0.474820 in the first row is precisely the improvement in relative error attributable to this split (1.00000 - 0.52518). The cross-validation error (xerror) is 0.61871, which is the average relative classification error across the cross-validation folds, and its standard deviation (xstd) is 0.056460, which gives an indication of the stability of the cross-validation estimate.

The third row corresponds to a tree with five splits (CP = 0.010791). The relative error decreases to 0.33813, and the cross-validation error (xerror) drops to 0.43165, well below the value for the one-split tree. The standard deviation of the cross-validation error (xstd) is 0.049905.

The fourth row corresponds to the full default tree with seven splits (CP = 0.010000, the default value of cp). The relative error decreases to 0.31655, and the cross-validation error (xerror) reaches its minimum of 0.41727, with a standard deviation (xstd) of 0.049268.

Based on these results, we can see that adding more splits to the tree decreases the relative error, but may increase the cross-validation error due to overfitting. Therefore, the optimal tree size may depend on the trade-off between classification accuracy and model complexity.

|   | CP       | nsplit | rel error | xerror  | xstd     |
|---|----------|--------|-----------|---------|----------|
| 1 | 0.474820 | 0      | 1.00000   | 1.00000 | 0.062401 |
| 2 | 0.046763 | 1      | 0.52518   | 0.61871 | 0.056460 |
| 3 | 0.010791 | 5      | 0.33813   | 0.43165 | 0.049905 |
| 4 | 0.010000 | 7      | 0.31655   | 0.41727 | 0.049268 |

The code fit.default$cptable returns a table with information about the cross-validation errors for each value of the complexity parameter. The which.min() function is used to identify the row in the table with the minimum cross-validation error. The fit.default$cptable[,"xerror"] argument extracts the column containing the cross-validation errors.
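The code being described is presumably along these lines:

```r
bestcp <- fit.default$cptable[which.min(fit.default$cptable[, "xerror"]), "CP"]
bestcp   # complexity parameter with the smallest cross-validation error
```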

The bestcp variable stores the value of the complexity parameter corresponding to the minimum cross-validation error. In this case, the value is 0.01, which corresponds to the tree with seven splits. This value is obtained from the CP column of the row with the minimum cross-validation error in the fit.default$cptable table.

A larger classification tree - heart

This code fits a larger classification tree to the Heart dataset using the rpart() function with the following arguments:
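A sketch consistent with the parameters mentioned below (minsplit = 2, minbucket = 1, cp = 0.001); the title string is an assumption:

```r
fit.larger <- rpart(AHD ~ ., data = Heart, method = "class",
                    minsplit = 2,    # attempt splits on nodes with as few as 2 observations
                    minbucket = 1,   # allow single-observation terminal nodes
                    cp = 0.001)      # keep splits that give only tiny improvements
rpart.plot(fit.larger, type = 4, extra = 101, roundint = FALSE, nn = TRUE,
           main = "Larger classification tree for the Heart data")
```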

The rpart.plot() function is used to plot the resulting tree. The type=4 argument labels all nodes and draws separate split labels for the left and right branches. The extra=101 argument displays the number of observations of each class together with the percentage of observations in each node. The roundint = FALSE argument prevents split values on integer-valued variables from being rounded to integers. The nn=TRUE argument displays the node numbers. The main argument specifies the main title for the plot.

By specifying minsplit=2 and minbucket=1, we are allowing the tree to be more complex than the default tree, which may lead to overfitting.

[Figure: the larger classification tree for the Heart data]

[Screenshot: printcp() output for fit.larger]

The printcp() function prints the cross-validation results for the fit.larger object, which is a fitted rpart object obtained using the larger tree with minsplit=2, minbucket=1, and cp=0.001.

As before, printcp() reports a root node error of 0.45875 above the table, meaning that 45.875% of the observations would be misclassified by simply predicting the majority class.

The second row shows the results after the first split, which uses the Thal variable. The relative error decreases to 0.5251799, meaning that the split improves the classification accuracy; the CP value of 0.4748201 in the first row is exactly this improvement (1 - 0.5251799). The cross-validation error (xerror) at this stage is 0.60432, which is still well above the best cross-validation error achieved by the default tree.

The third to tenth rows show the results for the subsequent splits. As the tree becomes more complex, the relative error decreases, but the cross-validation error increases, indicating overfitting.

The plotcp() function can be used to visualize the cross-validation error as a function of the complexity parameter. This can help identify the optimal value for the complexity parameter and the corresponding size of the tree.
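For reference, the two calls being discussed are simply:

```r
printcp(fit.larger)   # cross-validation table for the larger tree
plotcp(fit.larger)    # cross-validation error versus cp
```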

Finding the optimal CP

This code extracts the complexity parameter value that corresponds to the minimum cross-validation error from the cptable attribute of a fitted rpart object.

The which.min(fit.larger$cptable[,"xerror"]) code identifies the row in the cptable table with the smallest cross-validation error. The fit.larger$cptable[,"xerror"] expression returns a vector of cross-validation errors for each value of the complexity parameter, and which.min() finds the index of the minimum value.

The fit.larger$cptable[which.min(fit.larger$cptable[,"xerror"]),"CP"] code extracts the complexity parameter value from the CP column of the row with the minimum cross-validation error.
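In code, mirroring the earlier extraction for fit.default:

```r
bestcp <- fit.larger$cptable[which.min(fit.larger$cptable[, "xerror"]), "CP"]
```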

In summary, the workflow is:

1. Fit the classification tree.
2. Calculate the cross-validation errors for each value of cp.
3. Identify the row with the minimum cross-validation error.
4. Extract the optimal value of cp.
5. Prune the tree using the optimal cp.

To prune the tree using the optimal value of the complexity parameter, we can use the prune() function in R.

Here's the code to prune the fit.larger tree using the optimal value of bestcp:
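A sketch of the pruning step (the object name pruned.tree matches the description below; the plotting options are assumptions):

```r
pruned.tree <- prune(fit.larger, cp = bestcp)   # prune the larger tree at the best cp
rpart.plot(pruned.tree, type = 4, extra = 101, roundint = FALSE, nn = TRUE)
```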

The prune() function takes two arguments: the fitted rpart object, and the value of the complexity parameter that should be used for pruning the tree. In this case, bestcp is the value of the complexity parameter that minimizes the cross-validation error, and fit.larger is the larger classification tree we previously built.

The resulting pruned.tree object is a smaller version of the original tree, with some of the branches pruned based on the specified value of the complexity parameter. The pruned.tree object can be visualized using the rpart.plot() function, just like the previous trees we built.

[Figure: the pruned classification tree]

 

 

To be continued? No. Continue with 0313, where we apply all of these methods to the datasets provided by our instructor.