# FE581 – 0313 – R Scripts Walkthrough

Contents:

- Classification Example (game.csv): rpart, Better fit, Make Prediction, Confusion Matrix, Accuracy, Precision, Recall
- Regression Example (Boston): Larger tree, Largest tree
- Classification Example (Iris) with tree
## game.csv

We have our data game.csv with contents as:
| Outlook | Temperature | Humidity | Windy | Play |
|---|---|---|---|---|
| sunny | hot | high | false | no |
| sunny | hot | high | true | no |
| overcast | hot | high | false | yes |
| rainy | mild | high | false | yes |
| rainy | cool | normal | false | yes |
| rainy | cool | normal | true | no |
| overcast | cool | normal | true | yes |
| sunny | mild | high | false | no |
| sunny | cool | normal | false | yes |
| rainy | mild | normal | false | yes |
| sunny | mild | normal | true | yes |
| overcast | mild | high | true | yes |
| overcast | hot | normal | false | yes |
| rainy | mild | high | true | no |
This data set represents a hypothetical game played outdoors and contains five variables: Outlook, Temperature, Humidity, Windy, and Play. The goal is to predict whether the game will be played or not based on the other variables.
To begin exploring this data set, let's load it into R. We can do this by first saving the contents into a file called "game.csv" and then using the read.csv() function to read the file:
```r
game <- read.csv("game.csv")
```

Now that we have loaded the data into R, let's take a look at its structure. We can use the str() function to do this:

```r
str(game)
```

This will output the following:

```r
> str(game)
'data.frame':	14 obs. of  5 variables:
 $ Outlook    : chr  "sunny" "sunny" "overcast" "rainy" ...
 $ Temperature: chr  "hot" "hot" "hot" "mild" ...
 $ Humidity   : chr  "high" "high" "high" "high" ...
 $ Windy      : chr  "false" "true" "false" "false" ...
 $ Play       : chr  "no" "no" "yes" "yes" ...
```
In R, factors are used to represent categorical variables, while characters are used to represent text data. The difference between them is that factors have a predefined set of possible values (i.e., levels), while characters can have any possible value.
Using factors instead of characters for categorical variables can have several advantages in data mining and statistical modeling:
- Factors take up less memory than characters, which can be important when working with large data sets.
- Factors can be ordered, which can be useful for variables that have a natural order (e.g., temperature categories: low, medium, high).
- Factors can be used to specify contrasts and reference levels in statistical models, which can help to avoid issues with multicollinearity.
- Functions that operate on factors (e.g., table(), summary()) automatically generate output that is formatted for categorical data, which can make it easier to understand and interpret.
In general, it's a good practice to use factors instead of characters for categorical variables in R.
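To make the distinction concrete, here is a small sketch (the vector names are hypothetical and not part of the game data):

```r
# Hypothetical example: the same values stored as character vs. ordered factor
temp_chr <- c("low", "high", "medium", "low")
temp_fct <- factor(temp_chr, levels = c("low", "medium", "high"), ordered = TRUE)

str(temp_chr)    # character vector: any string is allowed
str(temp_fct)    # ordered factor with three predefined levels
table(temp_fct)  # counts are reported per level
```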
Since the variables are categorical, they should be read as factors instead of characters. Here's the correct way to load the data into R:
```r
game <- read.csv("game.csv", stringsAsFactors = TRUE)
```

Now, if we run str(game), we should see that the variables are factors:

```r
> str(game)
'data.frame':	14 obs. of  5 variables:
 $ Outlook    : Factor w/ 3 levels "overcast","rainy",..: 3 3 1 2 2 2 1 3 3 2 ...
 $ Temperature: Factor w/ 3 levels "cool","hot","mild": 2 2 2 3 1 1 1 3 1 3 ...
 $ Humidity   : Factor w/ 2 levels "high","normal": 1 1 1 1 2 2 2 1 2 2 ...
 $ Windy      : Factor w/ 2 levels "false","true": 1 2 1 1 1 2 2 1 1 1 ...
 $ Play       : Factor w/ 2 levels "no","yes": 1 1 2 2 2 1 2 1 2 2 ...
```

We can see that there are 14 observations (rows) and 5 variables (columns), all of which are factors. Factors in R are used to represent categorical variables.
To get a summary of the data set, we can use the summary() function:
```r
> summary(game)
     Outlook  Temperature  Humidity   Windy     Play  
 overcast:4   cool:4      high  :7   false:8   no :5  
 rainy   :5   hot :4      normal:7   true :6   yes:9  
 sunny   :5   mild:6                                  
```

We can see that there are three levels of the Outlook variable (overcast, rainy, and sunny), three levels of the Temperature variable (cool, hot, and mild), two levels of the Humidity variable (high and normal), two levels of the Windy variable (false and true), and two levels of the Play variable (no and yes). We can also see that 9 out of the 14 games were played.
Now that we have a basic understanding of the data set, let's start exploring it in more detail using various data mining techniques.
## rpart

rpart() is a function in R that can be used to create decision trees. In this case, gametree1 is a decision tree model that predicts the Play variable (whether the game was played or not) based on the other variables in the game data set.
The ~. syntax in the formula specifies that all other variables in the data set should be used to predict Play. The method="class" argument tells rpart() to treat the Play variable as a categorical variable (i.e., a factor) instead of a continuous variable.
To fit the decision tree model to the data, we use the rpart() function with the formula and data as arguments:
```r
gametree1 <- rpart(Play ~ ., method = "class", data = game)
```

```r
> gametree1
n= 14 

node), split, n, loss, yval, (yprob)
      * denotes terminal node

1) root 14 5 yes (0.3571429 0.6428571) *
```

Here's a breakdown of what each line means:
- n=14: The total number of observations in the data set.
- node: The ID number of the current node in the tree.
- split: The variable and value that were used to split the data at this node.
- n: The number of observations at this node.
- loss: The number of misclassified observations at this node.
- yval: The predicted class for this node based on the majority class at the node.
- (yprob): The probabilities of each class at this node.
- *: Indicates that this is a terminal node, meaning that no further splits are made after this point.
In this case, the output shows that the root node (node 1) includes all 14 observations in the data set. The split variable and value are not shown because this is the initial node and no split has been made yet. The node predicts that the majority class is "yes" (i.e., the game will be played) with a probability of 0.64.
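We can verify the root node's yval and yprob directly from the class counts of Play:

```r
# The root's prediction is the majority class, and yprob are the class proportions
table(game$Play)              # no: 5, yes: 9
prop.table(table(game$Play))  # 0.357 for "no", 0.643 for "yes"
```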
This is just the beginning of the decision tree model. As we continue to split the data based on the other variables, the tree will become more complex and accurate. We can use the plot() function to visualize the entire tree.
```r
rpart.plot(gametree1, type = 4, extra = 101, roundint = FALSE, nn = TRUE,
           main = "Classification Tree for Game Prediction (gametree1)")
```

```r
printcp(gametree1)
```

The printcp() function prints the complexity parameter (CP) table for a fitted rpart model: for each candidate value of CP it reports the number of splits and the corresponding training and cross-validated error.
In this case, the output shows that gametree1 has not yet split on any variables, which is why there are no variables listed in the "Variables actually used in tree construction" line. The root node error (i.e., the misclassification rate at the root node) is 0.35714, which is calculated as the number of misclassified observations divided by the total number of observations in the data set.
The table contains only a single row (CP = 0.01, nsplit = 0), which means rpart never found a worthwhile split under its default settings; in particular, the default minsplit = 20 is larger than our 14 observations. With no splits, the relative error stays at 1.0: the tree simply predicts the majority class ("yes") for every observation, so there is no complexity reduction available for this model yet.
In general, we want to choose a CP value that balances model complexity with prediction accuracy. This can be done by selecting a CP value that results in a relatively small tree size while still maintaining a reasonable level of accuracy in predicting the outcome variable.
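The stopping rules and the CP grid come from rpart.control(). As a quick sketch, we can inspect the defaults that gametree1 was implicitly fitted with (minsplit = 20, cp = 0.01, 10-fold cross-validation, and so on), which also explains why a 14-row data set produces no splits at all:

```r
# Default control parameters used by rpart() when none are supplied
str(rpart.control())
```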
We can use this tree to make predictions for new data (or old data) by passing it to the predict() function:
```r
new_data <- data.frame(Outlook = "sunny", Temperature = "hot", Humidity = "high", Windy = "false")
predict(gametree1, new_data, type = "class")
```

```r
> predict(gametree1, new_data, type = "class")
[1] yes
Levels: no yes
```

The output of predict() is "yes", which indicates that the decision tree model predicts that the game will be played based on the values of the predictor variables in new_data.
```r
# Better fit
gametree2 = rpart(Play ~ ., method = "class", data = game, minbucket = 1, minsplit = 2)
```

gametree2 is another decision tree model that predicts the Play variable based on the other variables in the game data set, just like gametree1.
The main difference between gametree2 and gametree1 is that gametree2 includes two additional arguments: minbucket and minsplit.
minbucket specifies the minimum number of observations that must exist in a terminal node in order for it to be created. If the number of observations in a node is less than minbucket, the node will not be split, and it will become a terminal node.
minsplit specifies the minimum number of observations that must exist in a node in order for it to be considered for splitting. If the number of observations in a node is less than minsplit, the node will not be split, and it will become a terminal node.
In this case, minbucket = 1 and minsplit = 2, which means that any node with only one observation will become a terminal node, and any node with less than two observations will not be considered for splitting. These values are often used as starting points for tuning the model, and can be adjusted to find the optimal values for the data set.
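A quick sketch of the effect of these two arguments, using the leaf-counting idea that is wrapped into a helper function later in this walkthrough:

```r
# Compare tree sizes under the default and the relaxed stopping rules
default_tree <- rpart(Play ~ ., method = "class", data = game)  # default minsplit = 20
relaxed_tree <- rpart(Play ~ ., method = "class", data = game,
                      minsplit = 2, minbucket = 1)
sum(default_tree$frame$var == "<leaf>")  # 1: the root is never split
sum(relaxed_tree$frame$var == "<leaf>")  # 7: the relaxed tree splits down to tiny nodes
```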
```r
> gametree2
n= 14 

node), split, n, loss, yval, (yprob)
      * denotes terminal node

 1) root 14 5 yes (0.3571429 0.6428571)  
   2) Outlook=rainy,sunny 10 5 no (0.5000000 0.5000000)  
     4) Humidity=high 5 1 no (0.8000000 0.2000000)  
       8) Outlook=sunny 3 0 no (1.0000000 0.0000000) *
       9) Outlook=rainy 2 1 no (0.5000000 0.5000000)  
        18) Windy=true 1 0 no (1.0000000 0.0000000) *
        19) Windy=false 1 0 yes (0.0000000 1.0000000) *
     5) Humidity=normal 5 1 yes (0.2000000 0.8000000)  
      10) Windy=true 2 1 no (0.5000000 0.5000000)  
        20) Outlook=rainy 1 0 no (1.0000000 0.0000000) *
        21) Outlook=sunny 1 0 yes (0.0000000 1.0000000) *
      11) Windy=false 3 0 yes (0.0000000 1.0000000) *
   3) Outlook=overcast 4 0 yes (0.0000000 1.0000000) *
```

The output above is the printed structure of the decision tree model gametree2.
For example, at the root node, the yprob values are 0.36 for "no" and 0.64 for "yes", indicating that the majority class at this node is "yes".
```r
rpart.plot(gametree2, type = 4, extra = 101, roundint = FALSE, nn = TRUE,
           main = "Classification Tree for Game Prediction (gametree2)")
```
```r
printcp(gametree2)
```

```r
> printcp(gametree2)

Classification tree:
rpart(formula = Play ~ ., data = game, method = "class", minbucket = 1, 
    minsplit = 2)

Variables actually used in tree construction:
[1] Humidity Outlook  Windy   

Root node error: 5/14 = 0.35714

n= 14 

    CP nsplit rel error xerror    xstd
1 0.30      0       1.0    1.0 0.35857
2 0.10      2       0.4    1.4 0.37417
3 0.01      6       0.0    1.4 0.37417
```

In this case, the output shows that gametree2 was fit with minbucket = 1 and minsplit = 2.
The table shows three CP values: 0.30, 0.10, and 0.01. Each row represents a candidate tree, where nsplit is the number of splits in the tree, rel error is the relative error rate for the tree, xerror is the cross-validation error rate for the tree, and xstd is the standard deviation of the cross-validation error.
We can use this table to select the optimal tree for our data set. The idea is to choose a CP value that gives the smallest tree that still predicts the outcome variable reasonably well. Reading the table from bottom to top: the full tree (CP = 0.01) uses 6 splits and drives the relative training error to 0; pruning at CP = 0.10 leaves a 2-split tree with relative error 0.4; and CP = 0.30 prunes all the way back to the root. Note that the cross-validation error (xerror) actually increases from 1.0 to 1.4 as the tree grows, a sign that the larger trees overfit this very small data set.
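For example, to recover the intermediate 2-split tree we can prune gametree2 at any CP value between 0.10 and 0.30 (a sketch; gametree2_small is a hypothetical name):

```r
# Prune the full tree back to the 2-split subtree selected by CP in [0.10, 0.30)
gametree2_small <- prune(gametree2, cp = 0.15)
printcp(gametree2_small)  # should now report at most 2 splits
```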
To select the optimal tree, we can use the plotcp() function to visualize the relationship between CP values and tree size, and select the CP value that provides the optimal tradeoff between complexity and accuracy:
```r
plotcp(gametree2)
```
This will generate a plot of the CP values and tree size, which can be used to select the optimal CP value.
```r
bestcp <- gametree2$cptable[which.min(gametree2$cptable[, "xerror"]), "CP"]
bestcp
```

```r
> bestcp
[1] 0.3
```

You can also use the predict() function to obtain the predicted class labels for the observations in the original data set (game), like this:
```r
game_pred2 <- predict(gametree2, type = "class")
```

```r
> game_pred2
  1   2   3   4   5   6   7   8   9  10  11  12  13  14 
 no  no yes yes yes  no yes  no yes yes yes yes yes  no 
Levels: no yes
```

The output of predict(gametree2, type = "class") is a vector of predicted class labels for each observation in the original data set game; we store it in game_pred2 so it can be reused below.
The predicted class labels are "no" or "yes", which are the levels of the Play variable. The Levels: no yes part of the output indicates that the predicted class label can take on one of these two levels.
For example, the first observation in game has predictor variables Outlook = "sunny", Temperature = "hot", Humidity = "high", and Windy = "false". According to the decision tree model gametree2, this observation should be classified as "no", which means that the game will not be played.
The output shows that game_pred2 is a vector of predicted class labels for each observation in game. For example, the first predicted label is "no", which corresponds to the first observation in game.
```r
new_data <- data.frame(Outlook = "sunny", Temperature = "hot", Humidity = "high", Windy = "false")
predict(gametree2, new_data, type = "class")
```

The table() function can be used to create a contingency table that shows the actual and predicted class labels for a given set of data.
To create a contingency table for the original data set game, using the predicted class labels in game_pred2 obtained from the decision tree model gametree2, you can use the following code:
```r
table(game$Play, game_pred2)
```

This will create a 2x2 contingency table with the rows representing the actual class labels ("no" and "yes") and the columns representing the predicted class labels ("no" and "yes").
The output will look something like this:
```r
      game_pred2
       no yes
  no    5   0
  yes   0   9
```

This table shows that the decision tree model gametree2 correctly predicted all 5 observations in the "no" class and all 9 observations in the "yes" class, for a total accuracy rate of 100%.
The accuracy() function from the Metrics package can be used to calculate the accuracy of a classification model, by comparing the predicted class labels to the actual class labels.
To calculate the accuracy of the decision tree model gametree2, using the predicted class labels in game_pred2 and the actual class labels in game$Play, you can use the following code:
```r
accuracy(actual = game$Play, predicted = game_pred2)
```

This will calculate the accuracy of the model as the proportion of correctly classified observations out of the total number of observations.
The output will be a numeric value representing the accuracy rate, expressed as a proportion between 0 and 1. For example, an accuracy rate of 1.0 would indicate that all observations were correctly classified, while an accuracy rate of 0.5 would indicate that only half of the observations were correctly classified.
In this case, the output will be 1, indicating that all of the observations in game were correctly classified by the gametree2 model.
Accuracy = (5 + 9) / 14 = 1
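Equivalently, the accuracy can be read straight off the confusion matrix by dividing the diagonal (correct predictions) by the total count:

```r
# Manual accuracy check from the confusion matrix
cm <- table(game$Play, game_pred2)
sum(diag(cm)) / sum(cm)
# [1] 1
```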
The precision() function can be used to calculate the precision of a classification model: the proportion of observations predicted to be positive that actually are positive, i.e., TP / (TP + FP).
```r
precision(actual = game$Play, predicted = game_pred2)
Warning message:
In mean.default(actual[predicted == 1]) :
  argument is not numeric or logical: returning NA
```

In our case, we get a warning because precision() from the Metrics package expects numeric (0/1) or logical vectors, but game$Play and game_pred2 are factors. The mean() function used internally by precision() cannot average a factor, so it returns NA.
To avoid this warning and calculate the precision, we can convert both vectors to numeric. Here's how we can do this:

```r
game_pred2 = as.numeric(game_pred2)
game$Play  = as.numeric(game$Play)
precision(actual = game$Play, predicted = game_pred2)
# 1
```

Precision = 9 / (9 + 0) = 1
The recall() function in the Metrics package can be used to calculate the recall (also known as sensitivity or true positive rate) of a classification model. Recall is defined as the number of true positives divided by the sum of true positives and false negatives.
```r
recall(actual = game$Play, predicted = game_pred2)
# 1
```

Recall = 9 / (9 + 0) = 1
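As a sanity check, precision and recall for the positive class ("yes") can also be computed by hand from the true positive, false positive, and false negative counts. After the as.numeric() conversion above, "no" is coded as 1 and "yes" as 2:

```r
# Manual precision and recall, treating "yes" (coded 2) as the positive class
TP <- sum(game_pred2 == 2 & game$Play == 2)  # true positives  (9)
FP <- sum(game_pred2 == 2 & game$Play == 1)  # false positives (0)
FN <- sum(game_pred2 == 1 & game$Play == 2)  # false negatives (0)
TP / (TP + FP)  # precision = 1
TP / (TP + FN)  # recall    = 1
```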
## Regression Example (Boston)

The Boston dataset is a famous dataset in data mining used for regression analysis. It contains information collected by the U.S. Census Service concerning housing in the area of Boston, Massachusetts. The dataset has 506 instances and 14 variables, including the median value of owner-occupied homes in $1000s (the target variable) and other variables such as the crime rate, the average number of rooms per dwelling, and the proportion of non-retail business acres per town.
The Boston data set ships with the MASS package, so we load that first:

```r
library(MASS)
str(Boston)
```

Here is a brief description of each variable:
- crim: per capita crime rate by town
- zn: proportion of residential land zoned for lots over 25,000 sq.ft.
- indus: proportion of non-retail business acres per town
- chas: Charles River dummy variable (1 if tract bounds river; 0 otherwise)
- nox: nitric oxides concentration (parts per 10 million)
- rm: average number of rooms per dwelling
- age: proportion of owner-occupied units built prior to 1940
- dis: weighted distances to five Boston employment centers
- rad: index of accessibility to radial highways
- tax: full-value property-tax rate per $10,000
- ptratio: pupil-teacher ratio by town
- black: 1000(Bk - 0.63)^2 where Bk is the proportion of blacks by town
- lstat: lower status of the population (percent)
- medv: median value of owner-occupied homes in $1000s (this is the target variable)
|   | crim | zn | indus | chas | nox | rm | age | dis | rad | tax | ptratio | black | lstat | medv |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 1 | 0.00632 | 18 | 2.31 | 0 | 0.538 | 6.575 | 65.2 | 4.09 | 1 | 296 | 15.3 | 396.9 | 4.98 | 24 |
| 2 | 0.02731 | 0 | 7.07 | 0 | 0.469 | 6.421 | 78.9 | 4.9671 | 2 | 242 | 17.8 | 396.9 | 9.14 | 21.6 |
| 3 | 0.02729 | 0 | 7.07 | 0 | 0.469 | 7.185 | 61.1 | 4.9671 | 2 | 242 | 17.8 | 392.83 | 4.03 | 34.7 |
| 4 | 0.03237 | 0 | 2.18 | 0 | 0.458 | 6.998 | 45.8 | 6.0622 | 3 | 222 | 18.7 | 394.63 | 2.94 | 33.4 |
| 5 | 0.06905 | 0 | 2.18 | 0 | 0.458 | 7.147 | 54.2 | 6.0622 | 3 | 222 | 18.7 | 396.9 | 5.33 | 36.2 |
| 6 | 0.02985 | 0 | 2.18 | 0 | 0.458 | 6.43 | 58.7 | 6.0622 | 3 | 222 | 18.7 | 394.12 | 5.21 | 28.7 |
For convenience, we copy the Boston data into a variable named data:
```r
data <- Boston
```

Check for missing values in the dataset:

```r
# Check for missing values in the dataset
any(is.na(data))
```

```r
> any(is.na(data))
[1] FALSE
```

```r
model <- rpart(medv ~ ., data = data)
```

"model" is a user-defined name for the regression tree model object that will be created.
"rpart" is a function in R used to build regression trees.
"medv ~ ." specifies the formula for the model, where "medv" is the target variable and "." indicates that all other variables in the dataset should be used as predictors.
"data = data" specifies the dataset that will be used to build the model.
```r
> model
n= 506 

node), split, n, deviance, yval
      * denotes terminal node

 1) root 506 42716.3000 22.53281  
   2) rm< 6.941 430 17317.3200 19.93372  
     4) lstat>=14.4 175  3373.2510 14.95600  
       8) crim>=6.99237 74  1085.9050 11.97838 *
       9) crim< 6.99237 101  1150.5370 17.13762 *
     5) lstat< 14.4 255  6632.2170 23.34980  
      10) dis>=1.5511 248  3658.3930 22.93629  
        20) rm< 6.543 193  1589.8140 21.65648 *
        21) rm>=6.543 55   643.1691 27.42727 *
      11) dis< 1.5511 7  1429.0200 38.00000 *
   3) rm>=6.941 76  6059.4190 37.23816  
     6) rm< 7.437 46  1899.6120 32.11304  
      12) lstat>=9.65 7   432.9971 23.05714 *
      13) lstat< 9.65 39   789.5123 33.73846 *
     7) rm>=7.437 30  1098.8500 45.09667 *
```
```r
rpart.plot(model, type = 4, extra = 101, roundint = FALSE, nn = TRUE,
           main = "Regression Tree for Boston Dataset (model)")
```
The decision to stop splitting further in a decision tree is based on a stopping criterion. This criterion can be either a pre-defined threshold for a certain parameter, such as the minimum number of samples required in a node to allow a split, or a threshold for the improvement in model performance that is achieved by the split.
In the context of a decision tree, deviance is a measure of the impurity or error in a node that is used to determine the optimal split. Specifically, the deviance in a node is the sum of the squared differences between the observed values and the predicted values in that node. The deviance is calculated differently depending on whether the decision tree is a regression tree or a classification tree.
In regression trees, the deviance in a node is typically measured by the residual sum of squares (RSS), which is the sum of squared differences between the observed values and the node's predicted value (the mean of the responses in that node). The RSS is calculated as follows:

$$\mathrm{RSS} = \sum_{i \in \text{node}} (y_i - \bar{y})^2$$

where $y_i$ is the observed response for observation $i$ in the node and $\bar{y}$ is the mean response of the observations in that node.
In classification trees, the impurity in a node is typically measured by the Gini index or the cross-entropy, which reflect the probability of misclassifying a randomly chosen data point in the node. The Gini index is calculated as follows:

$$G = 1 - \sum_{i=1}^{K} p_i^2$$

where $p_i$ is the proportion of data points in the node that belong to the $i$-th class, and the sum is taken over all $K$ classes.
In the output of the rpart function in R, the deviance in each node is reported as the "deviance" or "deviance reduction" in that node, which represents the reduction in deviance achieved by splitting that node. The algorithm selects the split that maximizes the reduction in deviance, which leads to the best split in terms of reducing the impurity or error in the resulting nodes.
We can also manually calculate the deviance for the root:
```r
sum((Boston$medv - mean(Boston$medv))^2)
# [1] 42716.3
```

The output of the R command sum((Boston$medv - mean(Boston$medv))^2) is the total sum of squares (TSS) for the "medv" variable in the "Boston" dataset.
The TSS is calculated as the sum of the squared differences between each observation and the mean of the response variable. In this case, the TSS for the "medv" variable is 42716.3, which represents the total variation in the median value of owner-occupied homes in the "Boston" dataset.
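We can take this one step further and verify the first split in the printed tree (rm < 6.941) by hand: the deviances of the two child nodes match the values 17317.32 and 6059.419 shown in the output above, and their sum is well below the root's 42716.3, which is exactly the reduction rpart was maximizing.

```r
# Verify the deviance reduction achieved by the first split, rm < 6.941
left  <- Boston$medv[Boston$rm <  6.941]
right <- Boston$medv[Boston$rm >= 6.941]
dev_left  <- sum((left  - mean(left))^2)   # ~17317.32 (node 2 in the printed tree)
dev_right <- sum((right - mean(right))^2)  # ~6059.419 (node 3 in the printed tree)
dev_root  <- sum((Boston$medv - mean(Boston$medv))^2)
dev_root - (dev_left + dev_right)          # deviance reduction from the split
```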
```r
printcp(model)
```

```r
> printcp(model)

Regression tree:
rpart(formula = medv ~ ., data = data)

Variables actually used in tree construction:
[1] crim  dis   lstat rm   

Root node error: 42716/506 = 84.421

n= 506 

        CP nsplit rel error  xerror     xstd
1 0.452744      0   1.00000 1.00295 0.083032
2 0.171172      1   0.54726 0.64392 0.060341
3 0.071658      2   0.37608 0.43453 0.048783
4 0.036164      3   0.30443 0.35117 0.043788
5 0.033369      4   0.26826 0.32938 0.043434
6 0.026613      5   0.23489 0.33470 0.043560
7 0.015851      6   0.20828 0.31457 0.044202
8 0.010000      7   0.19243 0.28785 0.041753
```
The complexity parameter table shows the cross-validated error (xerror), the number of splits (nsplit), the relative error reduction (rel error), the complexity parameter (CP), and the standard error of the cross-validation error (xstd) for each candidate tree. The optimal tree size is selected based on the complexity parameter that minimizes the cross-validation error rate.
```r
reg <- model
pre = predict(reg)
```

The code reg <- model assigns the previously fitted regression tree model (model) to a new object called reg. This is a common practice in R to store the results of a previous computation for later use or comparison.
The code pre = predict(reg) creates a new object called pre containing the predicted values of the response variable (medv) from the regression tree model reg. The predict() function takes the fitted model object as its first argument and, optionally, a newdata argument with the data for which we want predictions. When newdata is omitted, as here, predict() returns the fitted values for the data the model was trained on, i.e., the data dataframe.
After executing this code, the object pre will contain a vector of predicted values for the response variable based on the regression tree model.
```r
rmse(actual = data$medv, predicted = pre)
mae(actual = data$medv, predicted = pre)
mape(actual = data$medv, predicted = pre)
```

```r
> rmse(actual = data$medv, predicted = pre)
[1] 4.030468
> mae(actual = data$medv, predicted = pre)
[1] 2.909702
> mape(actual = data$medv, predicted = pre)
[1] 0.1545698
```
The code rmse(actual = data$medv, predicted = pre) calls the rmse() function from the Metrics package to calculate the root mean squared error (RMSE) between the actual values of the response variable (medv) in the data dataframe and the predicted values from the regression tree model stored in the pre object.
The code mae(actual = data$medv,predicted = pre) calculates the mean absolute error (MAE) between the actual values of the response variable in the data dataframe and the predicted values from the regression tree model stored in the pre object.
The code mape(actual = data$medv,predicted = pre) calculates the mean absolute percentage error (MAPE) between the actual values of the response variable in the data dataframe and the predicted values from the regression tree model stored in the pre object.
We can also manually calculate them:
```r
sqrt(mean((data$medv - pre)^2))
```

The formula can be broken down into the following steps:
1. Calculate the difference between the predicted and actual values: data$medv - pre
2. Square the differences: (data$medv - pre)^2
3. Calculate the mean of the squared differences: mean((data$medv - pre)^2)
4. Take the square root of the mean squared difference to get the RMSE: sqrt(mean((data$medv - pre)^2))
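The MAE and MAPE can be checked manually in the same way:

```r
# Manual MAE and MAPE, following the same pattern as the RMSE above
mean(abs(data$medv - pre))              # MAE:  mean absolute error
mean(abs(data$medv - pre) / data$medv)  # MAPE: mean absolute percentage error
```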
```r
# save original results for later
prd <- pre
```

## Larger tree

```r
rplot <- function(model, main = "Tree") {
  rpart.plot(
    model,
    type = 4,
    extra = 101,
    roundint = FALSE,
    nn = TRUE,
    main = main
  )
}

df <- data
reg2 = rpart(medv ~ .,
             data = df,
             minsplit = 2,
             minbucket = 1)
reg2
rplot(reg2, "Larger Tree for Boston")
```
```r
> reg2
n= 506 

node), split, n, deviance, yval
      * denotes terminal node

 1) root 506 42716.3000 22.53281  
   2) rm< 6.941 430 17317.3200 19.93372  
     4) lstat>=14.4 175  3373.2510 14.95600  
       8) crim>=6.99237 74  1085.9050 11.97838 *
       9) crim< 6.99237 101  1150.5370 17.13762 *
     5) lstat< 14.4 255  6632.2170 23.34980  
      10) dis>=1.38485 250  3721.1630 22.90520  
        20) rm< 6.543 195  1636.0670 21.62974 *
        21) rm>=6.543 55   643.1691 27.42727 *
      11) dis< 1.38485 5   390.7280 45.58000 *
   3) rm>=6.941 76  6059.4190 37.23816  
     6) rm< 7.437 46  1899.6120 32.11304  
      12) crim>=7.393425 3    27.9200 14.40000 *
      13) crim< 7.393425 43   864.7674 33.34884 *
     7) rm>=7.437 30  1098.8500 45.09667  
      14) nox>=0.6825 1     0.0000 21.90000 *
      15) nox< 0.6825 29   542.2097 45.89655 *
```
```r
# Functions to reuse later in our code.
# azt::function wrap rpart.plot 
rplot <- function(model, main = "Tree") {
  rpart.plot(
    model,
    type = 4,
    extra = 101,
    roundint = FALSE,
    nn = TRUE,
    main = main
  )
}
# azt::function count leaf nodes of a given model
nleaf <- function(model) {
  sum(model$frame$var == "<leaf>")
}
```

## Largest tree

```r
## largest tree
reg3 = rpart(medv ~ ., df, minsplit = 2, minbucket = 1, cp = 0.001)
reg3
rplot(reg3, "largest tree")
nleaf(reg3)
```
```r
> reg3
n= 506 

node), split, n, deviance, yval
      * denotes terminal node

  1) root 506 42716.30000 22.532810  
    2) rm< 6.941 430 17317.32000 19.933720  
      4) lstat>=14.4 175  3373.25100 14.956000  
        8) crim>=6.99237 74  1085.90500 11.978380  
         16) nox>=0.6055 62   552.28840 11.077420  
           32) lstat>=19.645 44   271.79180  9.913636  
             64) nox>=0.675 34   160.86260  9.114706  
              128) crim>=13.2402 19    52.44737  8.052632 *
              129) crim< 13.2402 15    59.83600 10.460000 *
             65) nox< 0.675 10    15.44100 12.630000 *
           33) lstat< 19.645 18    75.23111 13.922220 *
         17) nox< 0.6055 12   223.26670 16.633330  
           34) rm< 6.8425 11    94.44727 15.645450  
             68) crim< 12.66115 6    29.22833 13.583330 *
             69) crim>=12.66115 5     9.08800 18.120000 *
           35) rm>=6.8425 1     0.00000 27.500000 *
        9) crim< 6.99237 101  1150.53700 17.137620  
         18) nox>=0.531 77   672.46310 16.238960  
           36) lstat>=18.885 24   188.67830 14.041670  
             72) indus>=26.695 2     0.60500  7.550000 *
             73) indus< 26.695 22    96.12773 14.631820 *
           37) lstat< 18.885 53   315.43890 17.233960  
             74) age>=85.2 41   187.72980 16.597560  
              148) tax>=291.5 37   113.79730 16.170270 *
              149) tax< 291.5 4     4.69000 20.550000 *
             75) age< 85.2 12    54.36917 19.408330 *
         19) nox< 0.531 24   216.37960 20.020830  
           38) dis>=5.57015 11   120.92180 18.327270  
             76) crim>=0.157295 8    23.72875 16.912500 *
             77) crim< 0.157295 3    38.48000 22.100000 *
           39) dis< 5.57015 13    37.21231 21.453850 *
      5) lstat< 14.4 255  6632.21700 23.349800  
       10) dis>=1.38485 250  3721.16300 22.905200  
         20) rm< 6.543 195  1636.06700 21.629740  
           40) lstat>=7.57 152  1204.37200 20.967760  
             80) tax>=208 147   860.82990 20.765990  
              160) rm< 6.0775 79   437.39950 19.997470  
                320) age>=69.1 24   187.83330 18.816670  
                  640) rm>=6.0285 2     3.38000 13.200000 *
                  641) rm< 6.0285 22   115.62360 19.327270 *
                321) age< 69.1 55   201.50110 20.512730 *
              161) rm>=6.0775 68   322.56470 21.658820  
                322) lstat>=9.98 44   131.91160 20.970450 *
                323) lstat< 9.98 24   131.57960 22.920830  
                  646) crim< 0.04637 9    29.68889 20.911110 *
                  647) crim>=0.04637 15    43.72933 24.126670 *
             81) tax< 208 5   161.60000 26.900000  
              162) crim>=0.068935 3    19.82000 22.900000 *
              163) crim< 0.068935 2    21.78000 32.900000 *
           41) lstat< 7.57 43   129.63070 23.969770 *
         21) rm>=6.543 55   643.16910 27.427270  
           42) tax>=269 38   356.88210 26.168420  
             84) nox>=0.526 9    38.10000 23.466670 *
             85) nox< 0.526 29   232.69860 27.006900  
              170) nox< 0.436 11    34.14545 24.563640 *
              171) nox>=0.436 18    92.76000 28.500000  
                342) lstat>=7.05 9    31.40222 26.855560 *
                343) lstat< 7.05 9    12.68222 30.144440 *
           43) tax< 269 17    91.46118 30.241180  
             86) ptratio>=17.85 7    12.07714 28.142860 *
             87) ptratio< 17.85 10    26.98900 31.710000 *
       11) dis< 1.38485 5   390.72800 45.580000  
         22) crim>=10.5917 1     0.00000 27.900000 *
         23) crim< 10.5917 4     0.00000 50.000000 *
    3) rm>=6.941 76  6059.41900 37.238160  
      6) rm< 7.437 46  1899.61200 32.113040  
       12) crim>=7.393425 3    27.92000 14.400000 *
       13) crim< 7.393425 43   864.76740 33.348840  
         26) dis>=1.88595 41   509.52240 32.748780  
           52) nox>=0.4885 14   217.65210 30.035710  
            104) rm< 7.121 7    49.62857 26.914290 *
            105) rm>=7.121 7    31.61714 33.157140 *
           53) nox< 0.4885 27   135.38670 34.155560  
            106) age< 11.95 2     0.18000 29.300000 *
            107) age>=11.95 25    84.28160 34.544000 *
         27) dis< 1.88595 2    37.84500 45.650000 *
      7) rm>=7.437 30  1098.85000 45.096670  
       14) nox>=0.6825 1     0.00000 21.900000 *
       15) nox< 0.6825 29   542.20970 45.896550  
         30) ptratio>=14.8 15   273.47730 43.653330  
           60) black>=385.48 10   164.76000 41.900000  
            120) crim>=0.06095 7    55.67714 39.942860 *
            121) crim< 0.06095 3    19.70667 46.466670 *
           61) black< 385.48 5    16.49200 47.160000 *
         31) ptratio< 14.8 14   112.38000 48.300000  
           62) rm< 7.706 4    37.84750 44.725000 *
           63) rm>=7.706 10     2.96100 49.730000 *
> rplot(reg3, "largest tree")
> nleaf(reg3)
[1] 44
```

```r
printcp(reg3)
plotcp(reg3)
```

We can use the printcp() function to print the complexity parameter table for the reg3 decision tree model. For each candidate subtree the table shows the complexity parameter (CP), the number of splits (nsplit), the relative training error (rel error), the cross-validated error (xerror), and its standard error (xstd); plotcp() shows the same information graphically.

Add another custom function:
```r
# azt::function -> Return the best CP value for given model
bestcp <- function(model){
  model$cptable[which.min(model$cptable[, "xerror"]), "CP"]
}
```

```r
### get the best cp value that minimizes the xerror
best = bestcp(reg3)
### get the optimal tree with best cp
reg3.pruned = prune.rpart(reg3, cp = best)
rplot(reg3.pruned, "Pruned Tree")
nleaf(reg3.pruned)
```

We use the prune.rpart() function to prune the reg3 decision tree model based on the best complexity parameter value, which we obtained using the bestcp() function.
The prune.rpart() function takes two arguments: the decision tree model to be pruned (reg3 in this case), and the best complexity parameter value (best in this case). The function returns a new decision tree model that has been pruned based on the specified complexity parameter value.
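To see what pruning buys us, we can compare the size and the training RMSE of the largest tree and its pruned version (a sketch; pre_full and pre_pruned are hypothetical names, and the exact numbers depend on the cross-validation folds used to pick the CP):

```r
# Compare the largest tree with the pruned tree on the training data
pre_full   <- predict(reg3)
pre_pruned <- predict(reg3.pruned)
c(full = nleaf(reg3), pruned = nleaf(reg3.pruned))      # number of leaves
c(full   = sqrt(mean((df$medv - pre_full)^2)),          # training RMSE
  pruned = sqrt(mean((df$medv - pre_pruned)^2)))
```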

## Classification Example (Iris) with tree

```r
# load the libraries
library(rpart)
library(rpart.plot)
library(caTools)
library(MASS)
library(caret)
library(Metrics)
library(tree)    # needed for the tree() function used below

# Functions to reuse later in our code.
## azt::function -> Wrap rpart.plot 
rplot <- function(model, main = "Tree") {
  rpart.plot(
    model,
    type = 4,
    extra = 101,
    roundint = FALSE,
    nn = TRUE,
    main = main
  )
}
## azt::function -> Count leaf nodes of a given model
nleaf <- function(model) {
  sum(model$frame$var == "<leaf>")
}
## azt::function -> Return the best CP value for given model
bestcp <- function(model){
  model$cptable[which.min(model$cptable[, "xerror"]), "CP"]
}
```
```r
df <- iris
```

The "iris" dataset is a built-in dataset in R that contains information on the length and width of sepals and petals for three different species of iris flowers: setosa, versicolor, and virginica.

|   | Sepal.Length | Sepal.Width | Petal.Length | Petal.Width | Species |
|---|---|---|---|---|---|
| 1 | 5.1 | 3.5 | 1.4 | 0.2 | setosa |
| 2 | 4.9 | 3 | 1.4 | 0.2 | setosa |
| 3 | 4.7 | 3.2 | 1.3 | 0.2 | setosa |
| 4 | 4.6 | 3.1 | 1.5 | 0.2 | setosa |
| 5 | 5 | 3.6 | 1.4 | 0.2 | setosa |
| 6 | 5.4 | 3.9 | 1.7 | 0.4 | setosa |
```r
> str(df)
'data.frame':	150 obs. of  5 variables:
 $ Sepal.Length: num  5.1 4.9 4.7 4.6 5 5.4 4.6 5 4.4 4.9 ...
 $ Sepal.Width : num  3.5 3 3.2 3.1 3.6 3.9 3.4 3.4 2.9 3.1 ...
 $ Petal.Length: num  1.4 1.4 1.3 1.5 1.4 1.7 1.4 1.5 1.4 1.5 ...
 $ Petal.Width : num  0.2 0.2 0.2 0.2 0.2 0.4 0.3 0.2 0.2 0.1 ...
 $ Species     : Factor w/ 3 levels "setosa","versicolor",..: 1 1 1 1 1 1 1 1 1 1 ...
```

```r
any(is.na(df))
```

```r
> any(is.na(df))
[1] FALSE
```
```r
# Create the model
cls = tree(Species ~ ., data = df)
plot(cls)
text(cls, cex = 0.8)
summary(cls)
```

The model is being trained to predict the "Species" variable based on the other variables in the "df" data frame.
After creating the model, we use the plot() function to visualize the resulting decision tree, and the text() function to label each node in the tree with the corresponding variable and split point. The cex = 0.8 argument adjusts the size of the text labels to make them easier to read.
Finally, we are using the summary() function to display some basic information about the tree model, including the number of terminal nodes (i.e., the number of leaf nodes), the residual mean deviance, and the percentage of correct classifications for each species.
```r
> nleaf(cls)
[1] 6
```

```r
> summary(cls)

Classification tree:
tree(formula = Species ~ ., data = df)
Variables actually used in tree construction:
[1] "Petal.Length" "Petal.Width"  "Sepal.Length"
Number of terminal nodes:  6 
Residual mean deviance:  0.1253 = 18.05 / 144 
Misclassification error rate: 0.02667 = 4 / 150 
```

The output of summary(cls) provides some useful information about the decision tree model that we have created.
The first line shows that the model was created using the tree() function, with the formula Species ~ . indicating that the "Species" variable is being predicted based on all other variables in the "df" data frame.
The second line indicates the variables that were actually used in the construction of the tree. In this case, the model selected "Petal.Length", "Petal.Width", and "Sepal.Length" as the most important variables for predicting the species of iris.
The third line shows the number of terminal nodes in the tree. In this case, the tree has six terminal nodes, meaning the predictor space is divided into six leaf regions, each of which is assigned one of the three species.
The fourth line shows the residual mean deviance of the model, which is a measure of how well the model fits the data. The lower the value, the better the fit. In this case, the residual mean deviance is 0.1253, which indicates a relatively good fit.
The fifth line shows the misclassification error rate of the model, which is the proportion of observations that are misclassified. In this case, the misclassification error rate is 0.02667 or 4/150, which means that the model misclassifies about 2.67% of the observations in the dataset.
The decision tree algorithm works by recursively partitioning the data based on the most informative variables. At each step of the algorithm, the variable that provides the most information gain (i.e., the most useful for predicting the target variable) is chosen as the splitting variable.
In the case of the iris dataset, the summary(cls) output indicates that only three variables were used in the tree construction: Petal.Length, Petal.Width, and Sepal.Length. This suggests that these variables are the most informative for predicting the species of iris in this dataset.
It is possible that the remaining variable (Sepal.Width) did not provide as much information gain as the three selected variables, or that it is highly correlated with them. Correlated variables can sometimes lead to overfitting and can make the model less generalizable to new data.
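A quick way to investigate this is to look at the correlation matrix of the four numeric predictors; Petal.Length and Petal.Width in particular are very strongly correlated, so once one of them is in the tree the other adds little:

```r
# Correlations among the four numeric predictors in the iris data
round(cor(df[, 1:4]), 2)
```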
```r
prd = predict(cls, newdata = df, type = "class")
prd
table(df$Species, prd)
```

We are using the predict() function to generate predictions for the "df" dataset using the decision tree model "cls". However, we have included the newdata = df argument to specify that we want to generate predictions for the same dataset that was used to train the model.
```r
> table(df$Species, prd)
            prd
             setosa versicolor virginica
  setosa         50          0         0
  versicolor      0         47         3
  virginica       0          1        49
```

The table() function with the true labels in the first argument and predicted labels in the second argument creates a confusion matrix that shows the number of correct and incorrect predictions for each class in the dataset.
In this output, the rows of the confusion matrix represent the true labels of the iris flowers ("setosa", "versicolor", and "virginica"), and the columns represent the predicted labels based on the decision tree model.
The diagonal elements of the matrix represent the number of correct predictions for each class, while the off-diagonal elements represent the number of incorrect predictions.
For example, in this output, the decision tree model correctly predicted all 50 instances of the "setosa" class, 47 out of 50 instances of the "versicolor" class, and 49 out of 50 instances of the "virginica" class.
Overall, the model achieved high accuracy in predicting the species of iris flowers, with only 4 misclassifications out of 150 total observations (as shown in the output of summary(cls) earlier).
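The same accuracy figure can be obtained directly from the confusion matrix:

```r
# Overall accuracy from the confusion matrix: (50 + 47 + 49) / 150
cm <- table(df$Species, prd)
sum(diag(cm)) / sum(cm)
# [1] 0.9733333
```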
In the context of decision trees, deviance is a measure of the goodness of fit of a model to the data. For a regression tree it is the sum of squared differences between the observed and predicted responses; for a classification tree it is based on the log-likelihood of the observed class labels under the predicted class probabilities.
In the case of a classification problem with K classes, the deviance is defined as:

$$D = -2 \sum_{i=1}^{N} \sum_{k=1}^{K} y_{ik} \, \log \hat{p}_{ik}$$

where $N$ is the number of observations, $K$ is the number of classes, $y_{ik}$ equals 1 if observation $i$ belongs to class $k$ and 0 otherwise, and $\hat{p}_{ik}$ is the predicted probability that observation $i$ belongs to class $k$.
In the context of decision trees, the deviance is used to measure how well the tree model fits the training data. The goal of building a decision tree is to minimize the deviance, which means finding the tree that best fits the data.
The residual mean deviance, which is the output of the summary() function for a decision tree model, is the deviance of the model divided by the degrees of freedom (i.e., the number of observations minus the number of terminal nodes in the tree). The residual mean deviance can be used to compare the goodness of fit of different tree models. A lower residual mean deviance indicates a better fit to the data.
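As a sketch, the residual mean deviance reported by summary(cls) can be reproduced from the tree's frame, which stores the deviance of every node (the leaf deviances sum to 18.05, and 150 observations minus 6 leaves gives 144 degrees of freedom):

```r
# Residual mean deviance = sum of leaf deviances / (n observations - n leaves)
is_leaf <- cls$frame$var == "<leaf>"
sum(cls$frame$dev[is_leaf]) / (nrow(df) - sum(is_leaf))
# should match the 0.1253 = 18.05 / 144 reported by summary(cls)
```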
In the case of the iris dataset, which has three classes ("setosa", "versicolor", and "virginica") with 50 observations each, the root node predicts probability 1/3 for every class, so its deviance is

$$D_{\text{root}} = -2 \sum_{k=1}^{3} n_k \log(p_k) = -2 \times \big(3 \times 50 \times \log(1/3)\big) \approx 329.58$$

where $n_k = 50$ is the number of observations in class $k$ and $p_k = 1/3$ is the predicted probability of class $k$ at the root. We can compute this in R:

```r
Dev = -2 * (3 * 50 * log(1/3))
Dev
# [1] 329.5837
```