FE581 – 0306: Tree Based Methods

Contents: The basics of Decision Trees; Regression Tree (Predicting Baseball Players' Salaries); Exploring the data; A larger Tree; More complex model; Pruning the tree; Classification Tree (With Heart Data); Intro; [Impurity] Evaluating the quality of each potential split; Classification of Heart Disease; A larger classification tree (Heart); Finding the optimal CP
A decision tree is a type of machine learning model that is used for both classification and regression tasks. The basic idea behind a decision tree is to create a tree-like model of decisions and their possible consequences.
As an example, consider a decision tree for predicting whether or not a person will buy a new product.
In this example, the first decision node is based on the person's gender, which can either be male or female. If the person is male, we then look at their age. If they are under 25, we predict that they will buy the product, but if they are 25 or older, we predict that they will not buy it.
If the person is female, we look at their income level instead. If their income is low, we predict that they will buy the product, but if their income is high, we predict that they will not buy it.
Decision trees can be a very powerful tool for making predictions based on a set of features or inputs. However, they can also be prone to overfitting and can be difficult to interpret in more complex scenarios.
Let's start by discussing decision trees for regression problems. In a regression problem, the target variable is continuous, so we want to predict a numerical value for each data point. The basic idea behind decision trees for regression is the same as for classification: we partition the predictor space into simple regions based on the features, and we use these regions to make predictions for new data points.
The difference is that in regression trees, we use the average of the response values in each region as the predicted value for new data points. More specifically, each leaf node in the tree contains a predicted value that is the average of the response values of the training data points that fall into that region. To make a prediction for a new data point, we simply traverse the tree from the root node to a leaf node based on the values of the features, and return the predicted value associated with that leaf node.
We will build a concrete example of a regression tree below, using data on baseball players' salaries.
To build a regression tree, we use the same basic algorithm as for classification trees, but we choose the splitting feature and the splitting value based on the reduction in the mean squared error (MSE) of the response values in the two subsets created by the split. The splitting feature and value that produce the largest reduction in MSE are chosen as the splitting criteria.
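To make the split-selection rule concrete, here is a minimal sketch in R of scoring one candidate cutpoint of a numeric predictor by the reduction in the residual sum of squares (RSS). This is only an illustration of the idea, not the actual `rpart` implementation, and the toy vectors `x` and `y` are made up for the example.

```r
# Score a candidate split of a numeric predictor by the reduction in RSS.
rss <- function(y) sum((y - mean(y))^2)

split_gain <- function(x, y, cutpoint) {
  left  <- y[x <  cutpoint]
  right <- y[x >= cutpoint]
  rss(y) - (rss(left) + rss(right))   # larger gain = better split
}

# Made-up toy data: try every candidate cutpoint and keep the best one.
x <- c(1, 2, 3, 10, 11, 12)
y <- c(5, 6, 5, 20, 21, 19)
cutpoints <- sort(unique(x))[-1]             # skip the minimum so neither side is empty
gains <- sapply(cutpoints, function(cp) split_gain(x, y, cp))
cutpoints[which.max(gains)]                  # best cutpoint (here 10)
```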
In order to motivate regression trees, we begin with a simple example: predicting baseball players' salaries using regression trees. We use the Hitters data set to predict a baseball player's Salary based on Years (the number of years that he has played in the major leagues) and Hits (the number of hits that he made in the previous year).
The Hitters dataset is a classic dataset in statistics and is often used to illustrate the concepts of regression trees. Let's dive into this example.
First, let's load the dataset and take a look at its structure:
```r
library(ISLR)
data(Hitters)
head(Hitters)
```
| | AtBat | Hits | HmRun | Runs | RBI | Walks | Years | CAtBat | CHits | CHmRun | CRuns | CRBI | CWalks | League | Division | PutOuts | Assists | Errors | Salary | NewLeague |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| Andy Allanson | 293 | 66 | 1 | 30 | 29 | 14 | 1 | 293 | 66 | 1 | 30 | 29 | 14 | A | E | 446 | 33 | 20 | NA | A |
| Alan Ashby | 315 | 81 | 7 | 24 | 38 | 39 | 14 | 3449 | 835 | 69 | 321 | 414 | 375 | N | W | 632 | 43 | 10 | 475 | N |
| Alvin Davis | 479 | 130 | 18 | 66 | 72 | 76 | 3 | 1624 | 457 | 63 | 224 | 266 | 263 | A | W | 880 | 82 | 14 | 480 | A |
| Andre Dawson | 496 | 141 | 20 | 65 | 78 | 37 | 11 | 5628 | 1575 | 225 | 828 | 838 | 354 | N | E | 200 | 11 | 3 | 500 | N |
| Andres Galarraga | 321 | 87 | 10 | 39 | 42 | 30 | 2 | 396 | 101 | 12 | 48 | 46 | 33 | N | E | 805 | 40 | 4 | 91.5 | N |
| Alfredo Griffin | 594 | 169 | 4 | 74 | 51 | 35 | 11 | 4408 | 1133 | 19 | 501 | 336 | 194 | A | W | 282 | 421 | 25 | 750 | A |
As we can see, the dataset contains information about several baseball players, including their Years of experience in the major leagues, the number of Hits they made in the previous year, and their Salary.
We can start by visualizing the relationship between Salary and Years and Salary and Hits:
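The plots themselves are not reproduced here, but a sketch of how they could be produced with base R graphics is shown below (any plotting package would do; at this point `Salary` is still on its original scale and contains some missing values, which `plot()` simply skips).

```r
# Two quick scatterplots: Salary against Years and Salary against Hits.
par(mfrow = c(1, 2))
plot(Hitters$Years, Hitters$Salary, xlab = "Years", ylab = "Salary")
plot(Hitters$Hits,  Hitters$Salary, xlab = "Hits",  ylab = "Salary")
par(mfrow = c(1, 1))
```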
These scatterplots show that there is some relationship between Salary and both Years and Hits, although the relationship is not linear.
To build a regression tree to predict Salary based on Years and Hits, we can use the `rpart` package in R:
```r
# Remove missing values and log-transform Salary
Hitters <- na.omit(Hitters)
Hitters$Salary <- log(Hitters$Salary)
```
```r
library(rpart)
tree <- rpart(Salary ~ Years + Hits, data = Hitters)
```
```
> tree
n= 263

node), split, n, deviance, yval
      * denotes terminal node

 1) root 263 207.153700 5.927222
   2) Years< 4.5 90 42.353170 5.106790
     4) Years< 3.5 62 23.008670 4.891812
       8) Hits< 114 43 17.145680 4.727386 *
       9) Hits>=114 19 2.069451 5.263932 *
     5) Years>=3.5 28 10.134390 5.582812 *
   3) Years>=4.5 173 72.705310 6.354036
     6) Hits< 117.5 90 28.093710 5.998380
      12) Years< 6.5 26 7.237690 5.688925 *
      13) Years>=6.5 64 17.354710 6.124096
        26) Hits< 50.5 12 2.689439 5.730017 *
        27) Hits>=50.5 52 12.371640 6.215037 *
     7) Hits>=117.5 83 20.883070 6.739687 *
```
This will create a regression tree using Years and Hits as predictors and Salary as the response variable.
We can visualize the resulting tree using the `rpart.plot` package:
```r
library(rpart.plot)
rpart.plot(tree, type = 4, extra = 101, roundint = FALSE, nn = TRUE,
           main = "Regression Tree for Baseball Salaries")
```
This will give us a visualization of the regression tree:
- `n` is the number of training observations that fall into the node.
- `deviance` is a measure of the impurity or heterogeneity of the response values in the node. In a regression tree, deviance is the sum of squared errors (SSE) between the node's predicted value and the actual values in the node.
- `yval` is the predicted value for the response variable in the node. In a regression tree, this is the mean of the actual response values in the node (a quick numerical check for the root node is shown below).
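As a quick sanity check of the root-node numbers reported above (assuming the `na.omit()` and `log()` preprocessing from the earlier chunk has been applied), `yval` for the root is simply the overall mean of the response, and `deviance` is the sum of squared deviations from that mean:

```r
mean(Hitters$Salary)                              # about 5.927  (root yval)
sum((Hitters$Salary - mean(Hitters$Salary))^2)    # about 207.15 (root deviance)
```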
The output also shows the structure of the tree itself. Each node is labeled with a number that identifies its position in the tree: the root is node 1, and the two children of node k are numbered 2k and 2k+1. Each internal node is split into two child nodes based on a splitting criterion chosen to minimize the deviance of the response in the resulting subsets; the criterion is based on the value of a predictor variable, which in this example means `Years` or `Hits`.
For example, the root node 1 has a deviance of 207.1537 and a predicted value of 5.927222. It is split into two child nodes based on whether `Years` is less than 4.5 or not. If `Years` is less than 4.5, we move down to node 2, which has a deviance of 42.35317 and a predicted value of 5.10679. If `Years` is greater than or equal to 4.5, we move down to node 3, which has a deviance of 72.70531 and a predicted value of 6.354036.

Node 2 is further split into two child nodes based on whether `Years` is less than 3.5 or not. If `Years` is less than 3.5, we move down to node 4, which has a deviance of 23.00867 and a predicted value of 4.891812. If `Years` is greater than or equal to 3.5, we move down to node 5, which has a deviance of 10.13439 and a predicted value of 5.582812.
And so on. The tree continues to split until the deviance of each node falls below a certain threshold or until some other stopping criterion is met.
To make predictions for new data points using the tree, we simply start at the root node and follow the splits based on the values of Years and Hits for the new data point, until we reach a leaf node. The predicted Salary is then the mean Salary of the training data points that fall into the leaf node.
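For instance, we can ask the fitted tree for a prediction for a hypothetical player; the values below are made up for illustration.

```r
# A hypothetical new player with 5 years of experience and 100 hits last year.
new_player <- data.frame(Years = 5, Hits = 100)
predict(tree, newdata = new_player)       # falls in the Years>=4.5, Hits<117.5, Years<6.5 leaf,
                                          # so the prediction is that leaf's mean log-salary (about 5.69)
exp(predict(tree, newdata = new_player))  # back on the original scale (thousands of dollars), roughly 296
```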
In the code above, splitting stopped according to rpart's default control parameters. We can change those parameters manually to obtain different (typically larger) trees:
```r
tree2 <- rpart(Salary ~ Years + Hits, data = Hitters, minsplit = 2, minbucket = 1)
rpart.plot(tree2, type = 4, extra = 101, roundint = FALSE, nn = TRUE,
           main = "Regression Tree for Baseball Salaries (minsplit=2)")
```
The `rpart()` function in R allows you to specify two important parameters when building a regression tree: `minsplit` and `minbucket`.

`minsplit` controls the minimum number of observations that must exist in a node in order for a split to be attempted. A split will not be attempted if the number of observations in the node is less than `minsplit`.

`minbucket` controls the minimum number of observations that must exist in a terminal node. A terminal node is a node in the tree that cannot be split further, and it represents a prediction for the response variable based on the values of the predictor variables. If a split would result in a terminal node with fewer than `minbucket` observations, the split is not performed and the node becomes a terminal node.

In this code, we have set `minsplit = 2` and `minbucket = 1`. This means that the tree will attempt to split any node containing at least 2 observations and will allow terminal nodes with only one observation. This can result in a very complex, overfitted model that may not generalize well to new data.
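The overfitting risk can be checked empirically. The sketch below is an illustration under the assumption that the `na.omit()`/`log()` preprocessing above has been run; the object names are invented for this example. It fits the default tree and the deep tree on a random 70% training split and compares their errors on the held-out 30%; the deep tree usually fits the training data much more closely without doing better on the test data.

```r
set.seed(1)
train_idx <- sample(nrow(Hitters), floor(0.7 * nrow(Hitters)))
train <- Hitters[train_idx, ]
test  <- Hitters[-train_idx, ]

fit_default <- rpart(Salary ~ Years + Hits, data = train)
fit_deep    <- rpart(Salary ~ Years + Hits, data = train, minsplit = 2, minbucket = 1)

rmse <- function(obs, pred) sqrt(mean((obs - pred)^2))
c(train_default = rmse(train$Salary, predict(fit_default, train)),
  train_deep    = rmse(train$Salary, predict(fit_deep,    train)),
  test_default  = rmse(test$Salary,  predict(fit_default, test)),
  test_deep     = rmse(test$Salary,  predict(fit_deep,    test)))
```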
```
> tree2
n= 263

node), split, n, deviance, yval
      * denotes terminal node

 1) root 263 207.1537000 5.927222
   2) Years< 4.5 90 42.3531700 5.106790
     4) Hits>=15.5 88 32.6632500 5.058228
       8) Years< 3.5 60 11.2277800 4.813422
        16) Hits< 114 41 3.5150580 4.604649 *
        17) Hits>=114 19 2.0694510 5.263932 *
       9) Years>=3.5 28 10.1343900 5.582812 *
     5) Hits< 15.5 2 0.3513321 7.243499 *
   3) Years>=4.5 173 72.7053100 6.354036
     6) Hits< 117.5 90 28.0937100 5.998380
      12) Years< 6.5 26 7.2376900 5.688925 *
      13) Years>=6.5 64 17.3547100 6.124096
        26) Hits< 50.5 12 2.6894390 5.730017 *
        27) Hits>=50.5 52 12.3716400 6.215037 *
     7) Hits>=117.5 83 20.8830700 6.739687
      14) Hits>=208.5 3 4.9804600 6.231595
        28) Hits< 210.5 1 0.0000000 4.499810 *
        29) Hits>=210.5 2 0.4818414 7.097487 *
      15) Hits< 208.5 80 15.0991000 6.758740 *
```
- `node` indicates the node number.
- `split` indicates the variable and value used to split the node. For example, `Years< 4.5` indicates that the node was split on the `Years` variable being less than 4.5.
- `n` indicates the number of observations in the node.
- `deviance` indicates the residual sum of squares (RSS) for the node.
- `yval` indicates the predicted value of the response variable for the node.
- `*` indicates a terminal node, which cannot be split further.
Each split in the tree partitions the data into two groups based on the value of a predictor variable. At each split, the model chooses the variable and value that result in the greatest reduction in the RSS.
In this specific tree, the first split is based on the `Years` variable being less than 4.5 or greater than or equal to 4.5. If a player has fewer than 4.5 years of experience, the model then splits on the `Hits` variable being less than 15.5 or greater than or equal to 15.5. If a player has 4.5 or more years of experience, the model then splits on the `Hits` variable being less than 117.5 or greater than or equal to 117.5.

The tree continues to split until the `minsplit` or `minbucket` criteria are met. The `minsplit` argument specifies the minimum number of observations required to attempt a split, so with `minsplit = 2` a node with fewer than 2 observations cannot be split any further. The `minbucket` argument specifies the minimum number of observations required in a terminal node, so a split that would create a child node with fewer than `minbucket` observations is not performed and the node becomes a terminal node.

In this specific tree, several nodes terminate simply because they have become too small to split; node 28, for example, contains a single observation. In other cases, nodes terminate because further splits would not reduce the RSS by enough relative to the complexity parameter (cp); node 12, for example, is left as a terminal node because splitting it on `Years` or `Hits` does not produce a sufficient reduction in the RSS.
Lowering the complexity parameter `cp` to 0.005 allows even smaller improvements in fit to justify a split, producing a still larger tree:

```r
tree3 <- rpart(Salary ~ Years + Hits, data = Hitters, minsplit = 2, minbucket = 1, cp = 0.005)
rpart.plot(tree3, type = 4, extra = 101, roundint = FALSE, nn = TRUE,
           main = "Regression Tree for Baseball Salaries (minsplit=2, minbucket=1, cp=0.005)")
```
```
> tree3
n= 263

node), split, n, deviance, yval
      * denotes terminal node

 1) root 263 207.1537000 5.927222
   2) Years< 4.5 90 42.3531700 5.106790
     4) Hits>=15.5 88 32.6632500 5.058228
       8) Years< 3.5 60 11.2277800 4.813422
        16) Hits< 114 41 3.5150580 4.604649 *
        17) Hits>=114 19 2.0694510 5.263932 *
       9) Years>=3.5 28 10.1343900 5.582812
        18) Hits< 106 14 1.7919440 5.315652 *
        19) Hits>=106 14 6.3439520 5.849973
          38) Hits< 149.5 8 4.5064790 5.706729
            76) Hits>=131.5 2 0.2075980 4.976139 *
            77) Hits< 131.5 6 2.8755170 5.950259
             154) Hits< 122.5 5 1.2709500 5.718989 *
             155) Hits>=122.5 1 0.0000000 7.106606 *
          39) Hits>=149.5 6 1.4544520 6.040966 *
     5) Hits< 15.5 2 0.3513321 7.243499 *
   3) Years>=4.5 173 72.7053100 6.354036
     6) Hits< 117.5 90 28.0937100 5.998380
      12) Years< 6.5 26 7.2376900 5.688925
        24) Hits>=116.5 1 0.0000000 4.605170 *
        25) Hits< 116.5 25 6.0161850 5.732275 *
      13) Years>=6.5 64 17.3547100 6.124096
        26) Hits< 50.5 12 2.6894390 5.730017 *
        27) Hits>=50.5 52 12.3716400 6.215037
          54) Hits< 90.5 33 6.4539330 6.111435 *
          55) Hits>=90.5 19 4.9483060 6.394978
           110) Years< 11 9 3.6520620 6.210856
             220) Years>=9.5 1 0.0000000 4.605170 *
             221) Years< 9.5 8 0.7515576 6.411567 *
           111) Years>=11 10 0.7165383 6.560688 *
     7) Hits>=117.5 83 20.8830700 6.739687
      14) Hits>=208.5 3 4.9804600 6.231595
        28) Hits< 210.5 1 0.0000000 4.499810 *
        29) Hits>=210.5 2 0.4818414 7.097487 *
      15) Hits< 208.5 80 15.0991000 6.758740
        30) Hits< 185 76 13.3162400 6.727551 *
        31) Hits>=185 4 0.3042794 7.351330 *
```
One important issue in building regression trees is determining when to stop splitting the data. If we keep splitting the data until each leaf node contains only one training data point, we will likely overfit the model to the training data, which can lead to poor performance on new data. On the other hand, if we stop splitting too early, we may not capture all of the important patterns in the data.
One approach to dealing with this issue is to use a pruning algorithm to trim the tree after it has been built. This involves removing nodes from the tree that do not improve the overall predictive accuracy of the tree on a validation dataset.
The `rpart` package in R includes a built-in pruning algorithm that uses cross-validation to determine the optimal level of pruning.

To apply pruning to the tree we built earlier, we can use the following code:
```r
rtree <- rpart(Salary ~ Years + Hits, data = Hitters, minsplit = 2, minbucket = 1, cp = 0.005)
rtree$cptable
```
```
> rtree$cptable
            CP nsplit rel error    xerror      xstd
1  0.246749965      0 1.0000000 1.0079460 0.1390139
2  0.189905771      1 0.7532500 0.8721500 0.1355688
3  0.068574546      2 0.5633443 0.6109872 0.1049813
4  0.023140327      3 0.4947697 0.6495047 0.1165691
5  0.017734885      5 0.4484891 0.7364008 0.1232328
6  0.014280903      7 0.4130193 0.7892051 0.1288521
7  0.013271405      8 0.3987384 0.7749255 0.1211946
8  0.011625435     14 0.3170994 0.8250334 0.1289568
9  0.011160373     15 0.3054739 0.8075609 0.1201331
10 0.009963678     17 0.2831532 0.8402426 0.1239376
11 0.009817172     18 0.2731895 0.8636705 0.1285493
12 0.008871296     19 0.2633723 0.8748087 0.1288389
13 0.007348896     21 0.2456297 0.9385045 0.1388533
14 0.007064722     25 0.2162341 0.9748912 0.1425843
15 0.006959168     26 0.2091694 0.9748912 0.1425843
16 0.006941783     27 0.2022103 0.9773184 0.1426132
17 0.006754614     28 0.1952685 0.9684111 0.1425099
18 0.006412035     29 0.1885139 0.9684111 0.1425099
19 0.005000000     30 0.1821018 0.9616884 0.1424528
```
```r
bestcp <- rtree$cptable[which.min(rtree$cptable[, "xerror"]), "CP"]
```
The complexity parameter table shows the cross-validation error (`xerror`) as a function of the complexity parameter (cp) of the tree, and the line above picks out the value of cp that gives the minimum cross-validation error:
```
> bestcp
[1] 0.06857455
```
```r
rtree.pruned <- prune(rtree, cp = bestcp)
rpart.plot(rtree.pruned, type = 4, extra = 101, roundint = FALSE, nn = TRUE,
           main = "Regression Tree for Baseball Salaries (Pruned)")
```
This prunes the tree based on the selected value of cp and plots the pruned tree; printing `rtree.pruned` shows its structure:
```
> rtree.pruned
n=263 (59 observations deleted due to missingness)

node), split, n, deviance, yval
      * denotes terminal node

1) root 263 53319110 535.9259
  2) Years< 4.5 90 6769171 225.8315 *
  3) Years>=4.5 173 33393450 697.2467
    6) Hits< 117.5 90 5312120 464.9167 *
    7) Hits>=117.5 83 17955720 949.1708 *
```
We might interpret the regression tree as follows: Years is the most important factor in determining Salary, and players with less experience earn lower salaries than more experienced players. Given that a player is less experienced, the number of hits that he made in the previous year seems to play little role in his salary. But among players who have been in the major leagues for five or more years, the number of hits made in the previous year does affect salary, and players who made more hits last year tend to have higher salaries.
A classification tree is used for predicting qualitative or categorical responses, while a regression tree is used for predicting quantitative responses.
In a classification tree, the predicted class for a new observation is determined by the most commonly occurring class of the training observations that fall into the same terminal node as the new observation. For example, let's say we are trying to predict whether a customer will make a purchase (yes or no) based on their age and income. A classification tree might split the data based on age, and then further split each age group based on income. The final terminal nodes might have different proportions of customers who made a purchase, and we would use the most common class in each terminal node to predict whether a new customer with those characteristics would make a purchase.
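The majority-vote rule inside a single terminal node amounts to nothing more than taking the modal class of the training observations that landed there; a tiny illustration with made-up labels:

```r
# Hypothetical class labels of the training observations in one terminal node
node_labels <- c("yes", "no", "yes", "yes", "no")
table(node_labels)                    # class counts in the node: no = 2, yes = 3
names(which.max(table(node_labels)))  # predicted class for the node: "yes"
```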
In interpreting the results of a classification tree, we are interested not only in the predicted class for a particular observation, but also in the class proportions among the training observations that fall into the same terminal node. This can give us a sense of how well the tree is able to distinguish between different classes, and can help us identify areas where the tree may be overfitting or underfitting the data.
To illustrate this, let's say we have a dataset with two classes (Class A and Class B), and we build a classification tree using some predictor variables. The tree might have several terminal nodes, each corresponding to a different region of the predictor space. Within each terminal node, we can calculate the proportion of training observations that belong to Class A and Class B. This information can help us understand the relative importance of different predictors in distinguishing between the two classes, and can also help us identify any regions where the tree may be making errors or misclassifying observations.
Here's an example of a simple classification tree; consider two of its terminal nodes, each corresponding to a different region of the predictor space:
In this example, the tree splits the data based on age, and then further splits each age group based on income. The terminal node B corresponds to customers who are younger than 30 and have an income of 50K or less. Among the training observations in this node, 70% belong to Class A and 30% belong to Class B. The terminal node C corresponds to customers who are younger than 30 and have an income greater than 50K. Among the training observations in this node, 40% belong to Class A and 60% belong to Class B.
When building a classification tree, we need to decide which predictor variable to split on at each node, and we do this by evaluating the quality of each potential split. There are different measures of quality that we can use, but two common ones are the Gini index and the entropy.
The Gini index measures the impurity of a node, where a node is considered pure if all of the training observations in the node belong to the same class. The Gini index equals 0 for a pure node and reaches its maximum of 1 - 1/K (0.5 for a two-class problem) when the observations are evenly split across the K classes. When evaluating a potential split, we calculate the weighted sum of the Gini indices for the child nodes, where the weights correspond to the proportion of observations in each child node.
The entropy is another measure of impurity that is often used in decision trees. It is based on information theory and measures the amount of uncertainty or randomness in a node. Again, a node is considered pure if all of the training observations in the node belong to the same class; the entropy is 0 for a pure node and, for a two-class problem with base-2 logarithms, reaches its maximum of 1 when the classes are evenly split (in general the maximum is log2(K) for K classes). When evaluating a potential split, we calculate the weighted sum of the entropies for the child nodes.
Both the Gini index and entropy are more sensitive to node purity than the classification error rate, which simply measures the proportion of misclassified observations in a node. This is because the Gini index and entropy take into account the distribution of observations among the different classes, and can help us identify splits that lead to more homogeneous child nodes in terms of class membership.
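A minimal sketch of these three impurity measures, written as plain R functions of a vector of class proportions `p` (this is only to illustrate the formulas, not how `rpart` computes them internally):

```r
gini     <- function(p) 1 - sum(p^2)                          # Gini index
entropy  <- function(p) { p <- p[p > 0]; -sum(p * log2(p)) }  # entropy with base-2 logs
misclass <- function(p) 1 - max(p)                            # classification error rate

# Pure node versus an evenly split two-class node:
c(gini = gini(c(1, 0)), entropy = entropy(c(1, 0)), misclass = misclass(c(1, 0)))              # all 0
c(gini = gini(c(0.5, 0.5)), entropy = entropy(c(0.5, 0.5)), misclass = misclass(c(0.5, 0.5)))  # 0.5, 1, 0.5
```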
The Heart dataset is a binary classification problem where the goal is to predict the presence or absence of heart disease based on various patient characteristics.
To get started, we first need to load the dataset into R. The Heart data is not one of the datasets bundled in the `ISLR` package; it is distributed as a CSV file alongside the ISLR textbook, so we read it in with `read.csv()`:

```r
library(ISLR)   # not strictly needed to read the CSV
Heart <- read.csv("0306/Heart.csv")
```
This will load the dataset into the variable `Heart`. We can take a quick look at it using the `head()` function:
This will show the first few rows of the dataset, which should give us an idea of what variables are included and what the data looks like.
```
head(Heart)
# A tibble: 6 × 15
  ...1 Age Sex ChestPain    RestBP Chol Fbs RestECG MaxHR ExAng Oldpeak Slope Ca Thal       AHD
 <dbl> <dbl> <dbl> <chr>     <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>  <dbl> <dbl> <chr>   <chr>
1    1  63  1 typical         145   233   1   2   150   0   2.3   3   0 fixed      No
2    2  67  1 asymptomatic    160   286   0   2   108   1   1.5   2   3 normal     Yes
3    3  67  1 asymptomatic    120   229   0   2   129   1   2.6   2   2 reversable Yes
4    4  37  1 nonanginal      130   250   0   0   187   0   3.5   3   0 normal     No
5    5  41  0 nontypical      130   204   0   2   172   0   1.4   1   0 normal     No
6    6  56  1 nontypical      120   236   0   0   178   0   0.8   1   0 normal     No
```
When building a classification tree, it is important to ensure that any categorical variables in the dataset are treated as factors, so that the tree knows to treat them as discrete categories rather than continuous variables.
In the case of the Heart dataset, there are several variables that are categorical, such as `ChestPain`, `Thal`, `ExAng`, and `Sex`. These variables should be converted to factors before building the tree.
Here's an example code that converts the categorical variables to factors:
```r
Heart$ChestPain <- as.factor(Heart$ChestPain)
Heart$Thal      <- as.factor(Heart$Thal)
Heart$ExAng     <- as.factor(Heart$ExAng)   # note: the column is named "ExAng" in this CSV
Heart$Sex       <- as.factor(Heart$Sex)
```
In this code, we use the `as.factor()` function to convert each categorical variable to a factor.

The `AHD` variable in the Heart dataset is a binary outcome variable indicating the presence or absence of heart disease, and it is coded as a character variable with the values "Yes" and "No". To build a classification tree we need it to be a factor; here we also recode it as 0 and 1 (equivalently, "No" and "Yes").

Here's an example of code that converts the `AHD` variable to a binary factor variable:
```r
Heart$AHD <- ifelse(Heart$AHD == "Yes", 1, 0)
Heart$AHD <- as.factor(Heart$AHD)
```
In this code, we use the `ifelse()` function to replace the "Yes" values in the `AHD` variable with 1 and the "No" values with 0. We then use the `as.factor()` function to convert the resulting variable to a binary factor variable.
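As an aside, the same conversion can be done in a single step with `factor()`. This is an alternative to the two lines above (run one or the other, not both), and it assumes `AHD` still contains the original "No"/"Yes" strings.

```r
# One-step alternative: map "No" -> 0 and "Yes" -> 1 while creating the factor
Heart$AHD <- factor(Heart$AHD, levels = c("No", "Yes"), labels = c(0, 1))
```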
Now the new data becomes:
```
> head(Heart)
# A tibble: 6 × 15
  ...1 Age Sex ChestPain    RestBP Chol Fbs RestECG MaxHR ExAng Oldpeak Slope Ca Thal       AHD
 <dbl> <dbl> <fct> <fct>     <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>  <dbl> <dbl> <fct>   <fct>
1    1  63  1 typical         145   233   1   2   150   0   2.3   3   0 fixed      0
2    2  67  1 asymptomatic    160   286   0   2   108   1   1.5   2   3 normal     1
3    3  67  1 asymptomatic    120   229   0   2   129   1   2.6   2   2 reversable 1
4    4  37  1 nonanginal      130   250   0   0   187   0   3.5   3   0 normal     0
5    5  41  0 nontypical      130   204   0   2   172   0   1.4   1   0 normal     0
6    6  56  1 nontypical      120   236   0   0   178   0   0.8   1   0 normal     0
```
Once we have converted the `AHD` variable to a binary factor variable, we can build the classification tree using the following code:
```r
library(rpart)
library(rpart.plot)

# Fit classification tree
fit.default <- rpart(AHD ~ ., data = Heart, method = "class")
fit.default
rpart.plot(fit.default, type = 4, extra = 101, roundint = FALSE, nn = TRUE,
           main = "Classification Tree for Heart Disease (default)")
```
We fit the classification tree using the `rpart()` function, specifying `method = "class"` to indicate that we are building a classification tree. Since we do not specify any other arguments, the function uses its default values for all other parameters.

Once the tree has been built, we can visualize it using the `rpart.plot()` function. The `type = 4` argument labels all nodes and shows the split labels on both branches, `extra = 101` displays the number and percentage of observations in each node, `roundint = FALSE` prevents split values on integer-like variables from being rounded to integers, `nn = TRUE` displays the node numbers, and `main` specifies the title of the plot.
```
> fit.default
n= 303

node), split, n, loss, yval, (yprob)
      * denotes terminal node

 1) root 303 139 0 (0.5412541 0.4587459)
   2) Thal=normal 167 38 0 (0.7724551 0.2275449)
     4) ChestPain=nonanginal,nontypical 100 9 0 (0.9100000 0.0900000) *
     5) ChestPain=asymptomatic,typical 67 29 0 (0.5671642 0.4328358)
      10) Ca< 0.5 40 9 0 (0.7750000 0.2250000)
        20) X< 168.5 20 1 0 (0.9500000 0.0500000) *
        21) X>=168.5 20 8 0 (0.6000000 0.4000000)
          42) RestBP< 145 13 3 0 (0.7692308 0.2307692) *
          43) RestBP>=145 7 2 1 (0.2857143 0.7142857) *
      11) Ca>=0.5 27 7 1 (0.2592593 0.7407407) *
   3) Thal=fixed,reversable 136 35 1 (0.2573529 0.7426471)
     6) ChestPain=nonanginal,nontypical,typical 46 21 0 (0.5434783 0.4565217)
      12) Ca< 0.5 29 8 0 (0.7241379 0.2758621) *
      13) Ca>=0.5 17 4 1 (0.2352941 0.7647059) *
     7) ChestPain=asymptomatic 90 10 1 (0.1111111 0.8888889) *
```
This is the output of `fit.default`, which is an object of class `rpart`. The output shows the structure of the classification tree, including the splits and the terminal nodes. Each line reports the node number, the split that defines the node, the number of observations (`n`), the number of misclassified observations (`loss`), the predicted class (`yval`), and the proportions of the two classes in the node (`yprob`).

Node 1 is the root and contains all 303 observations. Its predicted class is 0, with 139 observations misclassified; the class proportions are 0.54 and 0.46. The root is split on the `Thal` variable.

Node 2 corresponds to the `normal` value of `Thal` and contains 167 observations. It is split on the `ChestPain` variable into two child nodes, corresponding to the value groups `nonanginal,nontypical` and `asymptomatic,typical`.

Node 4 is the first terminal node of the tree; it corresponds to `ChestPain = nonanginal,nontypical` within `Thal = normal`. It contains 100 observations, 91% of which belong to class 0 (no heart disease), so its predicted class is 0.

Node 5 corresponds to `ChestPain = asymptomatic,typical` within `Thal = normal` and contains 67 observations. It is split on the `Ca` variable into two child nodes, corresponding to `Ca < 0.5` and `Ca >= 0.5`.
…
```r
printcp(fit.default)
plotcp(fit.default)
```
```
> printcp(fit.default)

Classification tree:
rpart(formula = AHD ~ ., data = Heart, method = "class")

Variables actually used in tree construction:
[1] Ca ChestPain RestBP Thal X

Root node error: 139/303 = 0.45875

n= 303

        CP nsplit rel error  xerror     xstd
1 0.474820      0   1.00000 1.00000 0.062401
2 0.046763      1   0.52518 0.61871 0.056460
3 0.010791      5   0.33813 0.43165 0.049905
4 0.010000      7   0.31655 0.41727 0.049268
```
`printcp()` is a function in the `rpart` package that prints the cross-validation results for a fitted `rpart` object. The output shows the complexity parameter (`CP`) used for pruning the tree, as well as the number of splits (`nsplit`), the relative error (`rel error`), the cross-validation error (`xerror`), and the standard deviation of the cross-validation error (`xstd`) for each value of the complexity parameter.
In this case, we applied `printcp()` to the `fit.default` object, which is a fitted `rpart` object obtained using the default parameters of the `rpart()` function.
The first row corresponds to the unsplit tree (the root node). Its relative error is 1 by definition; all errors in the table are expressed relative to the root node error of 139/303 = 0.45875, i.e. 45.875% of the observations are misclassified when every observation is assigned to the majority class.

The second row (CP = 0.046763, nsplit = 1) corresponds to the tree with a single split, which uses the `Thal` variable. The relative error drops to 0.52518, meaning the split improves the training classification accuracy. The cross-validation error (`xerror`) is 0.61871, the average relative error across the cross-validation folds, and its standard deviation (`xstd`) is 0.056460, which gives an indication of the stability of the cross-validation estimate.

The third row (CP = 0.010791, nsplit = 5) corresponds to a tree with five splits. The relative error decreases further to 0.33813, and the cross-validation error falls to 0.43165 with a standard deviation of 0.049905.

The fourth row (CP = 0.010, nsplit = 7) corresponds to a tree with seven splits. The relative error decreases to 0.31655 and the cross-validation error to 0.41727, with a standard deviation of 0.049268.
Based on these results, we can see that adding more splits to the tree decreases the relative error, but may increase the cross-validation error due to overfitting. Therefore, the optimal tree size may depend on the trade-off between classification accuracy and model complexity.
```r
bestcp <- fit.default$cptable[which.min(fit.default$cptable[, "xerror"]), "CP"]
bestcp
# [1] 0.01
```
The code `fit.default$cptable` returns a table with information about the cross-validation errors for each value of the complexity parameter. The `which.min()` function is used to identify the row in the table with the minimum cross-validation error, and `fit.default$cptable[,"xerror"]` extracts the column containing the cross-validation errors.

The `bestcp` variable stores the value of the complexity parameter corresponding to the minimum cross-validation error. In this case, the value is 0.01, which corresponds to the tree with seven splits. This value is obtained from the `CP` column of the row with the minimum cross-validation error in the `fit.default$cptable` table.
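As an aside, instead of taking the cp with the smallest `xerror`, a common convention (not something `rpart` applies automatically) is the one-standard-error rule: keep the simplest tree whose cross-validation error is within one `xstd` of the minimum. A sketch using the cptable above:

```r
cptab     <- fit.default$cptable
min_row   <- which.min(cptab[, "xerror"])
threshold <- cptab[min_row, "xerror"] + cptab[min_row, "xstd"]
cp.1se    <- cptab[cptab[, "xerror"] <= threshold, "CP"][1]  # rows run from simplest to largest tree
cp.1se    # with the table shown above this is 0.010791, i.e. the five-split tree
```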
```r
fit.larger <- rpart(AHD ~ ., data = Heart, method = "class",
                    minsplit = 2, minbucket = 1, cp = 0.001)
rpart.plot(fit.larger, type = 4, extra = 101, roundint = FALSE, nn = TRUE,
           main = "Classification Tree for Heart Disease (fit.larger)")
fit.larger
printcp(fit.larger)
plotcp(fit.larger)
```
This code fits a larger classification tree to the `Heart` dataset using the `rpart()` function with the following arguments:

- `AHD ~ .`: this formula specifies that `AHD` is the response variable and all other variables in the dataset are predictors.
- `data = Heart`: the dataset to be used.
- `method = "class"`: build a classification tree.
- `minsplit = 2`: the minimum number of observations required in a node for a split to be attempted.
- `minbucket = 1`: the minimum number of observations allowed in a terminal node.
- `cp = 0.001`: the complexity parameter; a split is only kept if it improves the fit by at least this factor.

The `rpart.plot()` function is used to plot the resulting tree. As before, `type = 4` labels all nodes and shows the split labels on both branches, `extra = 101` displays the number and percentage of observations in each node, `roundint = FALSE` prevents split values on integer-like variables from being rounded to integers, `nn = TRUE` displays the node numbers, and `main` specifies the main title for the plot.

By specifying `minsplit = 2` and `minbucket = 1`, we are allowing the tree to become much more complex than the default tree, which may lead to overfitting.
```
n= 303

node), split, n, loss, yval, (yprob)
      * denotes terminal node

 1) root 303 139 0 (0.54125413 0.45874587)
   2) Thal=normal 167 38 0 (0.77245509 0.22754491)
     4) ChestPain=nonanginal,nontypical 100 9 0 (0.91000000 0.09000000)
       8) Oldpeak< 2.7 97 7 0 (0.92783505 0.07216495)
        16) Age< 56.5 70 2 0 (0.97142857 0.02857143)
          32) Chol>=153 68 1 0 (0.98529412 0.01470588)
            64) RestBP>=109 62 0 0 (1.00000000 0.00000000) *
            65) RestBP< 109 6 1 0 (0.83333333 0.16666667)
             130) Sex=0 4 0 0 (1.00000000 0.00000000) *
             131) Sex=1 2 1 0 (0.50000000 0.50000000)
               262) X< 134.5 1 0 0 (1.00000000 0.00000000) *
               263) X>=134.5 1 0 1 (0.00000000 1.00000000) *
          33) Chol< 153 2 1 0 (0.50000000 0.50000000)
            66) X< 163.5 1 0 0 (1.00000000 0.00000000) *
            67) X>=163.5 1 0 1 (0.00000000 1.00000000) *
        17) Age>=56.5 27 5 0 (0.81481481 0.18518519)
          34) X< 260.5 23 2 0 (0.91304348 0.08695652)
            68) Oldpeak< 1.7 22 1 0 (0.95454545 0.04545455)
             136) X>=36.5 19 0 0 (1.00000000 0.00000000) *
             137) X< 36.5 3 1 0 (0.66666667 0.33333333)
               274) X< 30 2 0 0 (1.00000000 0.00000000) *
               275) X>=30 1 0 1 (0.00000000 1.00000000) *
            69) Oldpeak>=1.7 1 0 1 (0.00000000 1.00000000) *
          35) X>=260.5 4 1 1 (0.25000000 0.75000000)
            70) Age>=62 1 0 0 (1.00000000 0.00000000) *
            71) Age< 62 3 0 1 (0.00000000 1.00000000) *
       9) Oldpeak>=2.7 3 1 1 (0.33333333 0.66666667)
        18) X< 35.5 1 0 0 (1.00000000 0.00000000) *
        19) X>=35.5 2 0 1 (0.00000000 1.00000000) *
     5) ChestPain=asymptomatic,typical 67 29 0 (0.56716418 0.43283582)
      10) Ca< 0.5 40 9 0 (0.77500000 0.22500000)
        20) RestBP< 154 35 5 0 (0.85714286 0.14285714)
          40) Age< 60.5 25 1 0 (0.96000000 0.04000000)
            80) X< 265 23 0 0 (1.00000000 0.00000000) *
            81) X>=265 2 1 0 (0.50000000 0.50000000)
             162) X>=273.5 1 0 0 (1.00000000 0.00000000) *
             163) X< 273.5 1 0 1 (0.00000000 1.00000000) *
          41) Age>=60.5 10 4 0 (0.60000000 0.40000000)
            82) X< 192 6 1 0 (0.83333333 0.16666667)
             164) Age>=61.5 5 0 0 (1.00000000 0.00000000) *
             165) Age< 61.5 1 0 1 (0.00000000 1.00000000) *
            83) X>=192 4 1 1 (0.25000000 0.75000000)
             166) Age>=69 1 0 0 (1.00000000 0.00000000) *
             167) Age< 69 3 0 1 (0.00000000 1.00000000) *
        21) RestBP>=154 5 1 1 (0.20000000 0.80000000)
          42) Age>=62 1 0 0 (1.00000000 0.00000000) *
          43) Age< 62 4 0 1 (0.00000000 1.00000000) *
      11) Ca>=0.5 27 7 1 (0.25925926 0.74074074)
        22) Sex=0 7 3 0 (0.57142857 0.42857143)
          44) Age>=63.5 3 0 0 (1.00000000 0.00000000) *
          45) Age< 63.5 4 1 1 (0.25000000 0.75000000)
            90) Age< 59.5 1 0 0 (1.00000000 0.00000000) *
            91) Age>=59.5 3 0 1 (0.00000000 1.00000000) *
        23) Sex=1 20 3 1 (0.15000000 0.85000000)
          46) ChestPain=typical 6 3 0 (0.50000000 0.50000000)
            92) X< 220.5 4 1 0 (0.75000000 0.25000000)
             184) Chol< 263 3 0 0 (1.00000000 0.00000000) *
             185) Chol>=263 1 0 1 (0.00000000 1.00000000) *
            93) X>=220.5 2 0 1 (0.00000000 1.00000000) *
          47) ChestPain=asymptomatic 14 0 1 (0.00000000 1.00000000) *
   3) Thal=fixed,reversable 136 35 1 (0.25735294 0.74264706)
     6) ChestPain=nonanginal,nontypical,typical 46 21 0 (0.54347826 0.45652174)
      12) Ca< 0.5 29 8 0 (0.72413793 0.27586207)
        24) X< 290 27 6 0 (0.77777778 0.22222222)
          48) ExAng< 0.5 22 3 0 (0.86363636 0.13636364)
            96) Chol< 228 9 0 0 (1.00000000 0.00000000) *
            97) Chol>=228 13 3 0 (0.76923077 0.23076923)
             194) X< 232 12 2 0 (0.83333333 0.16666667)
               388) Chol>=230.5 11 1 0 (0.90909091 0.09090909)
                 776) RestBP< 161 9 0 0 (1.00000000 0.00000000) *
                 777) RestBP>=161 2 1 0 (0.50000000 0.50000000)
                  1554) X>=163 1 0 0 (1.00000000 0.00000000) *
                  1555) X< 163 1 0 1 (0.00000000 1.00000000) *
               389) Chol< 230.5 1 0 1 (0.00000000 1.00000000) *
             195) X>=232 1 0 1 (0.00000000 1.00000000) *
          49) ExAng>=0.5 5 2 1 (0.40000000 0.60000000)
            98) Oldpeak< 1.5 2 0 0 (1.00000000 0.00000000) *
            99) Oldpeak>=1.5 3 0 1 (0.00000000 1.00000000) *
        25) X>=290 2 0 1 (0.00000000 1.00000000) *
      13) Ca>=0.5 17 4 1 (0.23529412 0.76470588)
        26) Slope< 1.5 5 2 0 (0.60000000 0.40000000)
          52) RestECG< 1 3 0 0 (1.00000000 0.00000000) *
          53) RestECG>=1 2 0 1 (0.00000000 1.00000000) *
        27) Slope>=1.5 12 1 1 (0.08333333 0.91666667)
          54) X< 99 4 1 1 (0.25000000 0.75000000)
           108) X>=75 1 0 0 (1.00000000 0.00000000) *
           109) X< 75 3 0 1 (0.00000000 1.00000000) *
          55) X>=99 8 0 1 (0.00000000 1.00000000) *
     7) ChestPain=asymptomatic 90 10 1 (0.11111111 0.88888889)
      14) Oldpeak< 0.55 22 8 1 (0.36363636 0.63636364)
        28) X< 55 3 0 0 (1.00000000 0.00000000) *
        29) X>=55 19 5 1 (0.26315789 0.73684211)
          58) MaxHR< 117 2 0 0 (1.00000000 0.00000000) *
          59) MaxHR>=117 17 3 1 (0.17647059 0.82352941)
           118) Chol< 233.5 7 3 1 (0.42857143 0.57142857)
             236) RestBP< 136 4 1 0 (0.75000000 0.25000000)
               472) X>=112 3 0 0 (1.00000000 0.00000000) *
               473) X< 112 1 0 1 (0.00000000 1.00000000) *
             237) RestBP>=136 3 0 1 (0.00000000 1.00000000) *
           119) Chol>=233.5 10 0 1 (0.00000000 1.00000000) *
      15) Oldpeak>=0.55 68 2 1 (0.02941176 0.97058824)
        30) Thal=fixed 11 2 1 (0.18181818 0.81818182)
          60) Age>=65.5 1 0 0 (1.00000000 0.00000000) *
          61) Age< 65.5 10 1 1 (0.10000000 0.90000000)
           122) RestBP< 112 2 1 0 (0.50000000 0.50000000)
             244) X>=162.5 1 0 0 (1.00000000 0.00000000) *
             245) X< 162.5 1 0 1 (0.00000000 1.00000000) *
           123) RestBP>=112 8 0 1 (0.00000000 1.00000000) *
        31) Thal=reversable 57 0 1 (0.00000000 1.00000000) *
```
```
> printcp(fit.larger)

Classification tree:
rpart(formula = AHD ~ ., data = Heart, method = "class", minsplit = 2,
    minbucket = 1, cp = 0.001)

Variables actually used in tree construction:
 [1] Age Ca ChestPain Chol ExAng MaxHR Oldpeak RestBP RestECG Sex Slope Thal
[13] X

Root node error: 139/303 = 0.45875

n= 303

          CP nsplit rel error  xerror     xstd
1  0.4748201      0 1.0000000 1.00000 0.062401
2  0.0467626      1 0.5251799 0.60432 0.056056
3  0.0215827      5 0.3381295 0.41007 0.048941
4  0.0143885      6 0.3165468 0.44604 0.050521
5  0.0107914      7 0.3021583 0.48921 0.052246
6  0.0071942     16 0.2014388 0.53957 0.054046
7  0.0047962     34 0.0719424 0.58273 0.055422
8  0.0035971     37 0.0575540 0.59712 0.055849
9  0.0023981     51 0.0071942 0.61871 0.056460
10 0.0010000     54 0.0000000 0.62590 0.056657
```
The `printcp()` function prints the cross-validation results for the `fit.larger` object, which is a fitted `rpart` object obtained using the larger tree with `minsplit = 2`, `minbucket = 1`, and `cp = 0.001`.
The first row corresponds to the unsplit tree (the root node); its relative error is 1 by definition, and the root node error is again 139/303 = 0.45875, i.e. 45.875% of the observations are misclassified when every observation is assigned to the majority class.
The second row (CP = 0.0467626, nsplit = 1) corresponds to the tree with a single split, which uses the `Thal` variable. The relative error drops to 0.5251799, meaning the split improves the training classification accuracy. The cross-validation error (`xerror`) is 0.60432 with a standard deviation of 0.056056, essentially the same as the corresponding value for the default tree (cross-validation error estimates vary somewhat from run to run).
The third to tenth rows show the results for the subsequent splits. As the tree becomes more complex, the relative error keeps decreasing, reaching 0 at 54 splits, but beyond about five splits the cross-validation error starts to increase again, indicating overfitting.
The `plotcp()` function can be used to visualize the cross-validation error as a function of the complexity parameter. This can help identify the optimal value of the complexity parameter and the corresponding size of the tree.
```r
bestcp <- fit.larger$cptable[which.min(fit.larger$cptable[, "xerror"]), "CP"]
bestcp
# [1] 0.02158273
```
This code extracts the complexity parameter value that corresponds to the minimum cross-validation error from the `cptable` attribute of a fitted `rpart` object.

The `which.min(fit.larger$cptable[,"xerror"])` call identifies the row in the `cptable` table with the smallest cross-validation error: `fit.larger$cptable[,"xerror"]` returns the vector of cross-validation errors for each value of the complexity parameter, and `which.min()` finds the index of the minimum value.

The expression `fit.larger$cptable[which.min(fit.larger$cptable[,"xerror"]),"CP"]` then extracts the complexity parameter value from the `CP` column of the row with the minimum cross-validation error.
To prune the tree using the optimal value of the complexity parameter, we can use the `prune()` function in R. Here's the code to prune the `fit.larger` tree using the optimal value `bestcp`:
```r
pruned.tree <- prune(fit.larger, cp = bestcp)
```
The `prune()` function takes two arguments: the fitted `rpart` object and the value of the complexity parameter to use for pruning. In this case, `bestcp` is the value of the complexity parameter that minimizes the cross-validation error, and `fit.larger` is the larger classification tree we built previously.
The resulting `pruned.tree` object is a smaller version of the original tree, with some of the branches removed according to the specified value of the complexity parameter. The `pruned.tree` object can be visualized using the `rpart.plot()` function, just like the previous trees we built.
```
> pruned.tree
n= 303

node), split, n, loss, yval, (yprob)
      * denotes terminal node

 1) root 303 139 0 (0.5412541 0.4587459)
   2) Thal=normal 167 38 0 (0.7724551 0.2275449)
     4) ChestPain=nonanginal,nontypical 100 9 0 (0.9100000 0.0900000) *
     5) ChestPain=asymptomatic,typical 67 29 0 (0.5671642 0.4328358)
      10) Ca< 0.5 40 9 0 (0.7750000 0.2250000) *
      11) Ca>=0.5 27 7 1 (0.2592593 0.7407407) *
   3) Thal=fixed,reversable 136 35 1 (0.2573529 0.7426471)
     6) ChestPain=nonanginal,nontypical,typical 46 21 0 (0.5434783 0.4565217)
      12) Ca< 0.5 29 8 0 (0.7241379 0.2758621) *
      13) Ca>=0.5 17 4 1 (0.2352941 0.7647059) *
     7) ChestPain=asymptomatic 90 10 1 (0.1111111 0.8888889) *
```
```r
rpart.plot(pruned.tree, type = 4, extra = 101, roundint = FALSE, nn = TRUE,
           main = "Classification Tree for Heart Disease (pruned.tree)")
```
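As a rough check of the pruned tree, we can compare its predictions with the observed `AHD` values on the training data (this is an optimistic, in-sample assessment, since no hold-out set is used here):

```r
pred <- predict(pruned.tree, Heart, type = "class")
table(Predicted = pred, Actual = Heart$AHD)   # confusion matrix on the training data
mean(pred == Heart$AHD)                       # training accuracy
```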
To be continued? No. Continue with 0313, where we apply all of these methods to the datasets provided by our instructor.