FE581 – 0313 – R Scripts Walkthrough

Contents:
Classification Example (game.csv)
rpart
Better fit
Make Prediction
Confusion Matrix
Accuracy
Precision
Recall
Regression Example (Boston)
Larger tree
Largest tree
Classification Example (Iris) with tree
Classification Example (game.csv)
We have our data game.csv
with contents as:
Outlook | Temperature | Humidity | Windy | Play |
---|---|---|---|---|
sunny | hot | high | false | no |
sunny | hot | high | true | no |
overcast | hot | high | false | yes |
rainy | mild | high | false | yes |
rainy | cool | normal | false | yes |
rainy | cool | normal | true | no |
overcast | cool | normal | true | yes |
sunny | mild | high | false | no |
sunny | cool | normal | false | yes |
rainy | mild | normal | false | yes |
sunny | mild | normal | true | yes |
overcast | mild | high | true | yes |
overcast | hot | normal | false | yes |
rainy | mild | high | true | no |
This data set represents a hypothetical game played outdoors and contains five variables: Outlook, Temperature, Humidity, Windy, and Play. The goal is to predict whether the game will be played or not based on the other variables.
To begin exploring this data set, let's load it into R. We can do this by first saving the contents into a file called "game.csv" and then using the read.csv()
function to read the file:
game <- read.csv("game.csv")
Now that we have loaded the data into R, let's take a look at its structure. We can use the str() function to do this:
str(game)
This will output the following:
> str(game)
'data.frame': 14 obs. of 5 variables:
 $ Outlook    : chr "sunny" "sunny" "overcast" "rainy" ...
 $ Temperature: chr "hot" "hot" "hot" "mild" ...
 $ Humidity   : chr "high" "high" "high" "high" ...
 $ Windy      : chr "false" "true" "false" "false" ...
 $ Play       : chr "no" "no" "yes" "yes" ...
In R, factors are used to represent categorical variables, while characters are used to represent text data. The difference between them is that factors have a predefined set of possible values (i.e., levels), while characters can have any possible value.
Using factors instead of characters for categorical variables can have several advantages in data mining and statistical modeling:
Factors take up less memory than characters, which can be important when working with large data sets.
Factors can be ordered, which can be useful for variables that have a natural order (e.g., temperature categories: low, medium, high).
Factors can be used to specify contrasts and reference levels in statistical models, which can help to avoid issues with multicollinearity.
Functions that operate on factors (e.g., table(), summary()) automatically generate output that is formatted for categorical data, which can make it easier to understand and interpret.
In general, it's a good practice to use factors instead of characters for categorical variables in R.
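As a quick illustration (a small sketch, not part of the walkthrough script; the variable names are made up for the example), the snippet below contrasts a character vector with a factor built from the same values:

# Characters: free-form text with no predefined set of values
outlook_chr <- c("sunny", "rainy", "overcast", "sunny")
# Factors: a fixed set of levels, with output formatted for categorical data
outlook_fct <- factor(outlook_chr)
levels(outlook_fct)   # "overcast" "rainy" "sunny"
table(outlook_fct)    # counts per level
# An ordered factor for a variable with a natural order
temp <- factor(c("low", "high", "medium"),
               levels = c("low", "medium", "high"),
               ordered = TRUE)
temp < "high"         # comparisons respect the ordering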
Since the variables are categorical, they should be read as factors instead of characters. Here's the correct way to load the data into R:
game <- read.csv("game.csv", stringsAsFactors = TRUE)
Now, if we run str(game)
, we should see that the variables are factors:
> str(game)
'data.frame': 14 obs. of 5 variables:
 $ Outlook    : Factor w/ 3 levels "overcast","rainy",..: 3 3 1 2 2 2 1 3 3 2 ...
 $ Temperature: Factor w/ 3 levels "cool","hot","mild": 2 2 2 3 1 1 1 3 1 3 ...
 $ Humidity   : Factor w/ 2 levels "high","normal": 1 1 1 1 2 2 2 1 2 2 ...
 $ Windy      : Factor w/ 2 levels "false","true": 1 2 1 1 1 2 2 1 1 1 ...
 $ Play       : Factor w/ 2 levels "no","yes": 1 1 2 2 2 1 2 1 2 2 ...
We can see that there are 14 observations (rows) and 5 variables (columns), all of which are factors. Factors in R are used to represent categorical variables.
To get a summary of the data set, we can use the summary()
function:
> summary(game)
      Outlook  Temperature   Humidity   Windy     Play
 overcast:4    cool:4      high  :7   false:8   no :5
 rainy   :5    hot :4      normal:7   true :6   yes:9
 sunny   :5    mild:6
We can see that there are three levels of the Outlook variable (overcast, rainy, and sunny), three levels of the Temperature variable (cool, hot, and mild), two levels of the Humidity variable (high and normal), two levels of the Windy variable (false and true), and two levels of the Play variable (no and yes). We can also see that 9 out of the 14 games were played.
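If we want that 9-out-of-14 figure as a proportion, a one-liner such as the following (a small sketch, not part of the original script) does it:

table(game$Play)               # counts: no 5, yes 9
prop.table(table(game$Play))   # proportions: no ≈ 0.357, yes ≈ 0.643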
Now that we have a basic understanding of the data set, let's start exploring it in more detail using various data mining techniques.
rpart
rpart()
is a function in R that can be used to create decision trees. In this case, gametree1
is a decision tree model that predicts the Play
variable (whether the game was played or not) based on the other variables in the game
data set.
The ~.
syntax in the formula specifies that all other variables in the data set should be used to predict Play
. The method="class"
argument tells rpart()
to treat the Play
variable as a categorical variable (i.e., a factor) instead of a continuous variable.
To fit the decision tree model to the data, we use the rpart()
function with the formula and data as arguments:
library(rpart)   # provides rpart()
gametree1 <- rpart(Play ~ ., method = "class", data = game)
> gametree1
n= 14

node), split, n, loss, yval, (yprob)
      * denotes terminal node

1) root 14 5 yes (0.3571429 0.6428571) *
Here's a breakdown of what each line means:
n=14
: The total number of observations in the data set.
node
: The ID number of the current node in the tree.
split
: The variable and value that were used to split the data at this node.
n
: The number of observations at this node.
loss
: The number of misclassified observations at this node.
yval
: The predicted class for this node based on the majority class at the node.
(yprob)
: The probabilities of each class at this node.
*
: Indicates that this is a terminal node, meaning that no further splits are made after this point.
In this case, the output shows that the root node (node 1) includes all 14 observations in the data set. The split variable and value are not shown because this is the initial node and no split has been made yet. The node predicts that the majority class is "yes" (i.e., the game will be played) with a probability of 0.64.
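We can verify these root-node numbers directly from the data; this small check is not part of the original script:

table(game$Play)   # 5 "no", 9 "yes": the majority class is "yes" and the loss is 5
9 / 14             # 0.6428571, the yprob reported for "yes" at the root
5 / 14             # 0.3571429, the yprob reported for "no"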
With the default control settings (minsplit = 20), rpart() will not attempt to split a node with fewer than 20 observations, so with only 14 rows gametree1 consists of the root node alone. In the "Better fit" section below we relax these settings to grow a more complex tree. We can use the rpart.plot() function from the rpart.plot package to visualize the tree.
library(rpart.plot)   # provides rpart.plot()
rpart.plot(gametree1, type = 4, extra = 101, roundint = FALSE, nn = TRUE, main = "Classification Tree for Game Prediction (gametree1)")
printcp(gametree1)
The printcp() function prints the complexity parameter (CP) table for a fitted tree: for each CP value it reports the number of splits and the corresponding relative (training) error and cross-validated error.
In this case, the output shows that gametree1
has not yet split on any variables, which is why there are no variables listed in the "Variables actually used in tree construction" line. The root node error (i.e., the misclassification rate at the root node) is 0.35714, which is calculated as the number of misclassified observations divided by the total number of observations in the data set.
The table lists only one CP value (the default CP = 0.01) with zero splits, meaning no split passed the default complexity and minimum-size thresholds. With no splits, the relative error stays at 1.0 (errors in this table are scaled so that the root node has error 1.0), so this root-only tree offers no improvement over simply predicting the majority class.
In general, we want to choose a CP value that balances model complexity with prediction accuracy. This can be done by selecting a CP value that results in a relatively small tree size while still maintaining a reasonable level of accuracy in predicting the outcome variable.
We can use this tree to make predictions for new data (or old data) by passing it to the predict()
function:
new_data <- data.frame(Outlook = "sunny", Temperature = "hot", Humidity = "high", Windy = "false")
predict(gametree1, new_data, type = "class")

> predict(gametree1, new_data, type = "class")
[1] yes
Levels: no yes
The output of predict()
is "yes"
, which indicates that the decision tree model predicts that the game will be played based on the values of the predictor variables in new_data
.
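If we want class probabilities rather than a hard label, predict() also accepts type = "prob"; here is a minimal sketch using the same new_data:

predict(gametree1, new_data, type = "prob")
# For the root-only gametree1 this simply returns the overall class
# proportions, about 0.357 for "no" and 0.643 for "yes".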
Better fit

# Better fit
gametree2 <- rpart(Play ~ ., method = "class", data = game, minbucket = 1, minsplit = 2)
gametree2
is another decision tree model that predicts the Play
variable based on the other variables in the game
data set, just like gametree1
.
The main difference between gametree2
and gametree1
is that gametree2
includes two additional arguments: minbucket
and minsplit
.
minbucket
specifies the minimum number of observations that must exist in a terminal node in order for it to be created. If the number of observations in a node is less than minbucket
, the node will not be split, and it will become a terminal node.
minsplit
specifies the minimum number of observations that must exist in a node in order for it to be considered for splitting. If the number of observations in a node is less than minsplit
, the node will not be split, and it will become a terminal node.
In this case, minbucket = 1
and minsplit = 2
, which means that a terminal node may contain as few as one observation, and that any node with fewer than two observations will not be considered for splitting. These values effectively let the tree grow as deep as the data allow; they are often used as starting points for tuning the model and can be adjusted to find settings that suit the data set.
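As an aside (a sketch, not part of the original script), the same settings can equivalently be bundled through the control argument using rpart.control():

# Equivalent call: pass the tree-growing settings via rpart.control()
gametree2_alt <- rpart(Play ~ ., method = "class", data = game,
                       control = rpart.control(minsplit = 2, minbucket = 1))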
> gametree2
n= 14

node), split, n, loss, yval, (yprob)
      * denotes terminal node

 1) root 14 5 yes (0.3571429 0.6428571)
   2) Outlook=rainy,sunny 10 5 no (0.5000000 0.5000000)
     4) Humidity=high 5 1 no (0.8000000 0.2000000)
       8) Outlook=sunny 3 0 no (1.0000000 0.0000000) *
       9) Outlook=rainy 2 1 no (0.5000000 0.5000000)
        18) Windy=true 1 0 no (1.0000000 0.0000000) *
        19) Windy=false 1 0 yes (0.0000000 1.0000000) *
     5) Humidity=normal 5 1 yes (0.2000000 0.8000000)
      10) Windy=true 2 1 no (0.5000000 0.5000000)
        20) Outlook=rainy 1 0 no (1.0000000 0.0000000) *
        21) Outlook=sunny 1 0 yes (0.0000000 1.0000000) *
      11) Windy=false 3 0 yes (0.0000000 1.0000000) *
   3) Outlook=overcast 4 0 yes (0.0000000 1.0000000) *
The output above is the text representation of the decision tree model gametree2.
For example, at the root node, the yprob
values are 0.36 for "no" and 0.64 for "yes", indicating that the majority class at this node is "yes".
rpart.plot(gametree2, type = 4, extra = 101, roundint = FALSE, nn = TRUE, main = "Classification Tree for Game Prediction (gametree2)")
printcp(gametree2)

> printcp(gametree2)

Classification tree:
rpart(formula = Play ~ ., data = game, method = "class", minbucket = 1,
    minsplit = 2)

Variables actually used in tree construction:
[1] Humidity Outlook  Windy

Root node error: 5/14 = 0.35714

n= 14

    CP nsplit rel error xerror    xstd
1 0.30      0       1.0    1.0 0.35857
2 0.10      2       0.4    1.4 0.37417
3 0.01      6       0.0    1.4 0.37417
In this case, the output shows that gametree2
was fit with minbucket = 1
and minsplit = 2
.
The table shows three CP values: 0.30, 0.10, and 0.01. Each row represents a candidate tree, where nsplit
is the number of splits in the tree, rel error
is the relative error rate for the tree, xerror
is the cross-validation error rate for the tree, and xstd
is the standard deviation of the cross-validation error.
We can use this table to select a tree for our data set. The idea is to choose a CP value that gives the smallest tree that still predicts the outcome reasonably well. Reading the table from the bottom up: CP = 0.01 keeps all 6 splits (the fully grown gametree2, with 8 terminal nodes), CP = 0.10 prunes it back to 2 splits, and CP = 0.30 prunes it all the way back to the root (0 splits). Note that on this tiny data set the cross-validated error (xerror) is actually lowest for the root-only tree, a reminder that cross-validation estimates are very noisy with only 14 observations.
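If we wanted the intermediate two-split tree rather than the fully grown one, we could cut gametree2 back with rpart's prune() function; gametree2_pruned is a name introduced just for this sketch:

# Prune back to the subtree corresponding to CP = 0.10 (2 splits)
gametree2_pruned <- prune(gametree2, cp = 0.10)
gametree2_pruned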
To help choose, we can use the plotcp() function to visualize the cross-validated error against the CP values (and hence the tree size), and select the CP value that provides the best tradeoff between complexity and accuracy:
plotcp(gametree2)
This will generate a plot of the CP values and tree size, which can be used to select the optimal CP value.
bestcp <- gametree2$cptable[which.min(gametree2$cptable[, "xerror"]), "CP"]
bestcp

> bestcp
[1] 0.3
Consistent with the discussion above, bestcp comes out as 0.3 here, i.e. the root-only tree, because cross-validated error estimates are so noisy on a 14-row data set.

Make Prediction

You can also use the predict() function to obtain the predicted class labels for the observations in the original data set (game), like this:
game_pred1 <- predict(gametree2, type = "class")

> game_pred1
  1   2   3   4   5   6   7   8   9  10  11  12  13  14
 no  no yes yes yes  no yes  no yes yes yes yes yes  no
Levels: no yes
The output of predict(gametree2, type = "class")
is a vector of predicted class labels for each observation in the original data set game
.
The predicted class labels are "no"
or "yes"
, which are the levels of the Play
variable. The Levels: no yes
part of the output indicates that the predicted class label can take on one of these two levels.
For example, the first observation in game
has predictor variables Outlook = "sunny"
, Temperature = "hot"
, Humidity = "high"
, and Windy = "false"
. According to the decision tree model gametree2
, this observation should be classified as "no"
, which means that the game will not be played.
The output shows that game_pred1
is a vector of predicted class labels for each observation in game
. For example, the first predicted label is "no"
, which corresponds to the first observation in game
.
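A quick way to compare these fitted labels with the actual ones (a small sketch, not part of the original script):

# Side-by-side comparison and overall agreement on the training data
data.frame(actual = game$Play, predicted = game_pred1)
mean(game_pred1 == game$Play)   # 1: every training observation is reproduced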
new_data <- data.frame(Outlook = "sunny", Temperature = "hot", Humidity = "high", Windy = "false")
predict(gametree2, new_data, type = "class")
Confusion Matrix

The table() function can be used to create a contingency table (confusion matrix) that shows the actual and predicted class labels for a given set of data.
To create a contingency table for the original data set game, using the predicted class labels in game_pred1 obtained from the decision tree model gametree2, you can use the following code:
table(game$Play, game_pred1)
This will create a 2x2 contingency table with the rows representing the actual class labels ("no"
and "yes"
) and the columns representing the predicted class labels ("no"
and "yes"
).
The output will look something like this:
     game_pred1
      no yes
  no   5   0
  yes  0   9
This table shows that the decision tree model gametree2
correctly predicted all of the observations in the no
class, and 9 out of 9 observations in the yes
class, for a total accuracy rate of 100%.
Accuracy

The accuracy() function (from the Metrics package) can be used to calculate the accuracy of a classification model by comparing the predicted class labels to the actual class labels.
To calculate the accuracy of the decision tree model gametree2, using the predicted class labels in game_pred1 and the actual class labels in game$Play, you can use the following code:
library(Metrics)   # provides accuracy(), precision(), recall(), rmse(), mae(), mape()
accuracy(actual = game$Play, predicted = game_pred1)
This will calculate the accuracy of the model as the proportion of correctly classified observations out of the total number of observations.
The output will be a numeric value representing the accuracy rate, expressed as a proportion between 0 and 1. For example, an accuracy rate of 1.0 would indicate that all observations were correctly classified, while an accuracy rate of 0.5 would indicate that only half of the observations were correctly classified.
In this case, the output will be 1, indicating that all of the observations in game
were correctly classified by the gametree2
model.
Accuracy = (5 + 9) / 14 = 1
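The same number can be recovered by hand from the confusion matrix (a sketch; tab is a name introduced here, not in the original script):

tab <- table(game$Play, game_pred1)
sum(diag(tab)) / sum(tab)       # (5 + 9) / 14 = 1
mean(game_pred1 == game$Play)   # equivalent shortcut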
Precision

The precision() function can be used to calculate the precision of a classification model by comparing the predicted positive class labels to the actual positive class labels.
precision(actual = game$Play, predicted = game_pred1)
Warning message:
In mean.default(actual[predicted == 1]) :
  argument is not numeric or logical: returning NA
In our case we get a warning because Metrics::precision() expects actual and predicted to be numeric or logical vectors, not factors. Internally it calls mean(actual[predicted == 1]); passing factors with levels "no"/"yes" hands mean() an argument that is not numeric or logical, so it returns NA.
To avoid this warning and calculate the precision, we can convert the factors to numeric codes (as.numeric() turns the levels "no"/"yes" into 1/2). Here's how we can do this:
game_pred1 <- as.numeric(game_pred1)
game$Play  <- as.numeric(game$Play)
precision(actual = game$Play, predicted = game_pred1)
# 1
Precision = 9 / (9 + 0) = 1
Recall

The recall() function (here from the Metrics package, though caret provides one as well) can be used to calculate the recall (also known as sensitivity or true positive rate) of a classification model. Recall is defined as the number of true positives divided by the sum of true positives and false negatives.
recall(actual = game$Play, predicted = game_pred1)
# 1
Recall = 9 / (9 + 0) = 1
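Both numbers can also be recovered by hand, treating "yes" (numeric code 2 after the conversion above) as the positive class; tp, fp and fn are names introduced just for this sketch:

tp <- sum(game_pred1 == 2 & game$Play == 2)   # true positives
fp <- sum(game_pred1 == 2 & game$Play != 2)   # false positives
fn <- sum(game_pred1 != 2 & game$Play == 2)   # false negatives
tp / (tp + fp)   # precision = 9 / (9 + 0) = 1
tp / (tp + fn)   # recall    = 9 / (9 + 0) = 1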
Regression Example (Boston)

The Boston dataset is a classic dataset used for regression examples. It contains information collected by the U.S. Census Service concerning housing in the area of Boston, Massachusetts. The dataset has 506 observations and 14 variables, including the median value of owner-occupied homes in $1000s (the target variable) and predictors such as the crime rate, the average number of rooms per dwelling, and the proportion of non-retail business acres per town.
library(MASS)   # the Boston dataset ships with the MASS package
str(Boston)
Here is a brief description of each variable:
crim: per capita crime rate by town
zn: proportion of residential land zoned for lots over 25,000 sq.ft.
indus: proportion of non-retail business acres per town
chas: Charles River dummy variable (1 if tract bounds river; 0 otherwise)
nox: nitric oxides concentration (parts per 10 million)
rm: average number of rooms per dwelling
age: proportion of owner-occupied units built prior to 1940
dis: weighted distances to five Boston employment centers
rad: index of accessibility to radial highways
tax: full-value property-tax rate per $10,000
ptratio: pupil-teacher ratio by town
black: 1000(Bk - 0.63)^2 where Bk is the proportion of blacks by town
lstat: lower status of the population (percent)
medv: median value of owner-occupied homes in $1000s (this is the target variable)
 | crim | zn | indus | chas | nox | rm | age | dis | rad | tax | ptratio | black | lstat | medv |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
1 | 0.00632 | 18 | 2.31 | 0 | 0.538 | 6.575 | 65.2 | 4.09 | 1 | 296 | 15.3 | 396.9 | 4.98 | 24 |
2 | 0.02731 | 0 | 7.07 | 0 | 0.469 | 6.421 | 78.9 | 4.9671 | 2 | 242 | 17.8 | 396.9 | 9.14 | 21.6 |
3 | 0.02729 | 0 | 7.07 | 0 | 0.469 | 7.185 | 61.1 | 4.9671 | 2 | 242 | 17.8 | 392.83 | 4.03 | 34.7 |
4 | 0.03237 | 0 | 2.18 | 0 | 0.458 | 6.998 | 45.8 | 6.0622 | 3 | 222 | 18.7 | 394.63 | 2.94 | 33.4 |
5 | 0.06905 | 0 | 2.18 | 0 | 0.458 | 7.147 | 54.2 | 6.0622 | 3 | 222 | 18.7 | 396.9 | 5.33 | 36.2 |
6 | 0.02985 | 0 | 2.18 | 0 | 0.458 | 6.43 | 58.7 | 6.0622 | 3 | 222 | 18.7 | 394.12 | 5.21 | 28.7 |
For convenience, we copy the Boston data into a variable named data:
data <- Boston
Check for missing values in the dataset
# Check for missing values in the dataset
any(is.na(data))

> any(is.na(data))
[1] FALSE
model <- rpart(medv ~ ., data = data)
"model" is a user-defined name for the regression tree model object that will be created.
"rpart" is a function in R used to build regression trees.
"medv ~ ." specifies the formula for the model, where "medv" is the target variable and "." indicates that all other variables in the dataset should be used as predictors.
"data = data" specifies the dataset that will be used to build the model.
> model
n= 506

node), split, n, deviance, yval
      * denotes terminal node

 1) root 506 42716.3000 22.53281
   2) rm< 6.941 430 17317.3200 19.93372
     4) lstat>=14.4 175 3373.2510 14.95600
       8) crim>=6.99237 74 1085.9050 11.97838 *
       9) crim< 6.99237 101 1150.5370 17.13762 *
     5) lstat< 14.4 255 6632.2170 23.34980
      10) dis>=1.5511 248 3658.3930 22.93629
        20) rm< 6.543 193 1589.8140 21.65648 *
        21) rm>=6.543 55 643.1691 27.42727 *
      11) dis< 1.5511 7 1429.0200 38.00000 *
   3) rm>=6.941 76 6059.4190 37.23816
     6) rm< 7.437 46 1899.6120 32.11304
      12) lstat>=9.65 7 432.9971 23.05714 *
      13) lstat< 9.65 39 789.5123 33.73846 *
     7) rm>=7.437 30 1098.8500 45.09667 *
rpart.plot(model, type = 4, extra = 101, roundint = FALSE, nn = TRUE, main = "Regression Tree for Boston Dataset (model)")
The decision to stop splitting further in a decision tree is based on a stopping criterion. This criterion can be either a pre-defined threshold for a certain parameter, such as the minimum number of samples required in a node to allow a split, or a threshold for the improvement in model performance that is achieved by the split.
In the context of a decision tree, deviance is a measure of the impurity or error in a node that is used to determine the optimal split. Specifically, the deviance in a node is the sum of the squared differences between the observed values and the predicted values in that node. The deviance is calculated differently depending on whether the decision tree is a regression tree or a classification tree.
In regression trees, the deviance in a node is measured by the residual sum of squares (RSS), the sum of squared differences between the observed values and the node's predicted value (the mean of the responses in the node):

RSS = \sum_{i \in \text{node}} (y_i - \bar{y})^2

where y_i is the observed response for observation i and \bar{y} is the mean response of the observations in the node.
In classification trees, the impurity in a node is typically measured by a quantity such as the Gini index or cross-entropy, which reflects the probability of misclassifying a randomly chosen data point in the node. The Gini index is calculated as:

\text{Gini} = 1 - \sum_{k=1}^{K} p_k^2

where p_k is the proportion of data points in the node that belong to the k-th class, and the sum is taken over all K classes.
In the output of the rpart function in R, the deviance column shows the deviance within each node. When growing the tree, the algorithm considers candidate splits and selects the one that maximizes the reduction in deviance, i.e. the split that most reduces the impurity or error in the resulting child nodes.
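Using the node deviances printed above for model, we can check this by hand for the first split (rm < 6.941): the reduction in deviance, expressed as a fraction of the root deviance, should match the first CP value reported by printcp(model) below. This is only a verification sketch:

root_dev  <- 42716.30              # deviance of the root node
child_dev <- 17317.32 + 6059.419   # deviances of its two children (nodes 2 and 3)
root_dev - child_dev               # absolute reduction in deviance (about 19340)
(root_dev - child_dev) / root_dev  # about 0.4527, the first CP in printcp(model)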
We can also manually calculate the deviance for the root node:
sum((Boston$medv - mean(Boston$medv))^2)
# [1] 42716.3
The output of the R command sum((Boston$medv - mean(Boston$medv))^2) is the total sum of squares (TSS) for the "medv" variable in the "Boston" dataset.
The TSS is calculated as the sum of the squared differences between each observation and the mean of the response variable. In this case, the TSS for the "medv" variable is 42716.3, which represents the total variation in the median value of owner-occupied homes in the "Boston" dataset.
printcp(model)
> printcp(model)

Regression tree:
rpart(formula = medv ~ ., data = data)

Variables actually used in tree construction:
[1] crim  dis   lstat rm

Root node error: 42716/506 = 84.42

n= 506

        CP nsplit rel error  xerror     xstd
1 0.452744      0   1.00000 1.00295 0.083032
2 0.171172      1   0.54726 0.64392 0.060341
3 0.071658      2   0.37608 0.43453 0.048783
4 0.036164      3   0.30443 0.35117 0.043788
5 0.033369      4   0.26826 0.32938 0.043434
6 0.026613      5   0.23489 0.33470 0.043560
7 0.015851      6   0.20828 0.31457 0.044202
8 0.010000      7   0.19243 0.28785 0.041753
The complexity parameter table shows, for each candidate tree, the complexity parameter (CP), the number of splits (nsplit), the relative error (rel error, the training error scaled so that the root node has error 1), the cross-validated error (xerror), and the standard error of the cross-validated error (xstd). The optimal tree size is selected based on the complexity parameter that minimizes the cross-validated error.
reg <- model
pre = predict(reg)
The code reg <- model
assigns the previously fitted regression tree model (model
) to a new object called reg
. This is a common practice in R to store the results of a previous computation for later use or comparison.
The code pre = predict(reg)
creates a new object called pre
by predicting the response variable (medv
) for each observation in the data
dataframe using the previously fitted regression tree model (reg
). The predict()
function in R takes two arguments: the first argument is the fitted model object, and the second argument is the new data for which we want to make predictions. In this case, the new data is the same as the training data, i.e., the data
dataframe.
After executing this code, the object pre
will contain a vector of predicted values for the response variable based on the regression tree model.
rmse(actual = data$medv, predicted = pre)
mae(actual = data$medv, predicted = pre)
mape(actual = data$medv, predicted = pre)
> rmse(actual = data$medv, predicted = pre)
[1] 4.030468
> mae(actual = data$medv, predicted = pre)
[1] 2.909702
> mape(actual = data$medv, predicted = pre)
[1] 0.1545698
The code rmse(actual = data$medv,predicted = pre)
calculates the root mean squared error (RMSE) between the actual values of the response variable (medv
) in the data
dataframe and the predicted values from the regression tree model stored in the pre
object.
The code mae(actual = data$medv,predicted = pre)
calculates the mean absolute error (MAE) between the actual values of the response variable in the data
dataframe and the predicted values from the regression tree model stored in the pre
object.
The code mape(actual = data$medv,predicted = pre)
calculates the mean absolute percentage error (MAPE) between the actual values of the response variable in the data
dataframe and the predicted values from the regression tree model stored in the pre
object.
We can also manually calculate them:
sqrt(mean((data$medv - pre)^2))
The formula can be broken down into the following steps:
Calculate the difference between the predicted and actual values: data$medv - pre
Square the differences: (data$medv - pre)^2
Calculate the mean of the squared differences: mean((data$medv - pre)^2)
Take the square root of the mean squared difference to get the RMSE: sqrt(mean((data$medv - pre)^2))
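The MAE and MAPE can be written out by hand in the same way (a sketch mirroring the Metrics results above):

mean(abs(data$medv - pre))                 # MAE,  about 2.91
mean(abs((data$medv - pre) / data$medv))   # MAPE, about 0.155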
# save original results for later
prd <- pre
Larger tree

rplot <- function(model, main = "Tree") {
  rpart.plot(
    model,
    type = 4,
    extra = 101,
    roundint = FALSE,
    nn = TRUE,
    main = main
  )
}

df <- data
reg2 = rpart(medv ~ .,
             data = df,
             minsplit = 2,
             minbucket = 1)
reg2
rplot(reg2, "Larger Tree for Boston")
> reg2
n= 506

node), split, n, deviance, yval
      * denotes terminal node

 1) root 506 42716.3000 22.53281
   2) rm< 6.941 430 17317.3200 19.93372
     4) lstat>=14.4 175 3373.2510 14.95600
       8) crim>=6.99237 74 1085.9050 11.97838 *
       9) crim< 6.99237 101 1150.5370 17.13762 *
     5) lstat< 14.4 255 6632.2170 23.34980
      10) dis>=1.38485 250 3721.1630 22.90520
        20) rm< 6.543 195 1636.0670 21.62974 *
        21) rm>=6.543 55 643.1691 27.42727 *
      11) dis< 1.38485 5 390.7280 45.58000 *
   3) rm>=6.941 76 6059.4190 37.23816
     6) rm< 7.437 46 1899.6120 32.11304
      12) crim>=7.393425 3 27.9200 14.40000 *
      13) crim< 7.393425 43 864.7674 33.34884 *
     7) rm>=7.437 30 1098.8500 45.09667
      14) nox>=0.6825 1 0.0000 21.90000 *
      15) nox< 0.6825 29 542.2097 45.89655 *
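To see what the extra splits buy on the training data, we can compute the same error metrics for reg2 and compare them with the values obtained earlier for model; pre2 is a name introduced only for this sketch, and since deeper trees fit the training data more closely these numbers should come out somewhat below the 4.03 RMSE of the default tree:

pre2 <- predict(reg2)
rmse(actual = df$medv, predicted = pre2)   # training RMSE of the larger tree
mae(actual = df$medv, predicted = pre2)    # training MAE of the larger tree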
# Functions to reuse later in our code.
# azt::function wrap rpart.plot
rplot <- function(model, main = "Tree") {
  rpart.plot(
    model,
    type = 4,
    extra = 101,
    roundint = FALSE,
    nn = TRUE,
    main = main
  )
}
# azt::function count leaf nodes of a given model
nleaf <- function(model) {
  sum(model$frame$var == "<leaf>")
}
Largest tree

## largest tree
reg3 = rpart(medv ~ ., df, minsplit = 2, minbucket = 1, cp = 0.001)
reg3
rplot(reg3, "largest tree")
nleaf(reg3)
> reg3
n= 506

node), split, n, deviance, yval
      * denotes terminal node

   1) root 506 42716.30000 22.532810
     2) rm< 6.941 430 17317.32000 19.933720
       4) lstat>=14.4 175 3373.25100 14.956000
         8) crim>=6.99237 74 1085.90500 11.978380
          16) nox>=0.6055 62 552.28840 11.077420
            32) lstat>=19.645 44 271.79180 9.913636
              64) nox>=0.675 34 160.86260 9.114706
               128) crim>=13.2402 19 52.44737 8.052632 *
               129) crim< 13.2402 15 59.83600 10.460000 *
              65) nox< 0.675 10 15.44100 12.630000 *
            33) lstat< 19.645 18 75.23111 13.922220 *
          17) nox< 0.6055 12 223.26670 16.633330
            34) rm< 6.8425 11 94.44727 15.645450
              68) crim< 12.66115 6 29.22833 13.583330 *
              69) crim>=12.66115 5 9.08800 18.120000 *
            35) rm>=6.8425 1 0.00000 27.500000 *
         9) crim< 6.99237 101 1150.53700 17.137620
          18) nox>=0.531 77 672.46310 16.238960
            36) lstat>=18.885 24 188.67830 14.041670
              72) indus>=26.695 2 0.60500 7.550000 *
              73) indus< 26.695 22 96.12773 14.631820 *
            37) lstat< 18.885 53 315.43890 17.233960
              74) age>=85.2 41 187.72980 16.597560
               148) tax>=291.5 37 113.79730 16.170270 *
               149) tax< 291.5 4 4.69000 20.550000 *
              75) age< 85.2 12 54.36917 19.408330 *
          19) nox< 0.531 24 216.37960 20.020830
            38) dis>=5.57015 11 120.92180 18.327270
              76) crim>=0.157295 8 23.72875 16.912500 *
              77) crim< 0.157295 3 38.48000 22.100000 *
            39) dis< 5.57015 13 37.21231 21.453850 *
       5) lstat< 14.4 255 6632.21700 23.349800
        10) dis>=1.38485 250 3721.16300 22.905200
          20) rm< 6.543 195 1636.06700 21.629740
            40) lstat>=7.57 152 1204.37200 20.967760
              80) tax>=208 147 860.82990 20.765990
               160) rm< 6.0775 79 437.39950 19.997470
                 320) age>=69.1 24 187.83330 18.816670
                   640) rm>=6.0285 2 3.38000 13.200000 *
                   641) rm< 6.0285 22 115.62360 19.327270 *
                 321) age< 69.1 55 201.50110 20.512730 *
               161) rm>=6.0775 68 322.56470 21.658820
                 322) lstat>=9.98 44 131.91160 20.970450 *
                 323) lstat< 9.98 24 131.57960 22.920830
                   646) crim< 0.04637 9 29.68889 20.911110 *
                   647) crim>=0.04637 15 43.72933 24.126670 *
              81) tax< 208 5 161.60000 26.900000
               162) crim>=0.068935 3 19.82000 22.900000 *
               163) crim< 0.068935 2 21.78000 32.900000 *
            41) lstat< 7.57 43 129.63070 23.969770 *
          21) rm>=6.543 55 643.16910 27.427270
            42) tax>=269 38 356.88210 26.168420
              84) nox>=0.526 9 38.10000 23.466670 *
              85) nox< 0.526 29 232.69860 27.006900
               170) nox< 0.436 11 34.14545 24.563640 *
               171) nox>=0.436 18 92.76000 28.500000
                 342) lstat>=7.05 9 31.40222 26.855560 *
                 343) lstat< 7.05 9 12.68222 30.144440 *
            43) tax< 269 17 91.46118 30.241180
              86) ptratio>=17.85 7 12.07714 28.142860 *
              87) ptratio< 17.85 10 26.98900 31.710000 *
        11) dis< 1.38485 5 390.72800 45.580000
          22) crim>=10.5917 1 0.00000 27.900000 *
          23) crim< 10.5917 4 0.00000 50.000000 *
     3) rm>=6.941 76 6059.41900 37.238160
       6) rm< 7.437 46 1899.61200 32.113040
        12) crim>=7.393425 3 27.92000 14.400000 *
        13) crim< 7.393425 43 864.76740 33.348840
          26) dis>=1.88595 41 509.52240 32.748780
            52) nox>=0.4885 14 217.65210 30.035710
             104) rm< 7.121 7 49.62857 26.914290 *
             105) rm>=7.121 7 31.61714 33.157140 *
            53) nox< 0.4885 27 135.38670 34.155560
             106) age< 11.95 2 0.18000 29.300000 *
             107) age>=11.95 25 84.28160 34.544000 *
          27) dis< 1.88595 2 37.84500 45.650000 *
       7) rm>=7.437 30 1098.85000 45.096670
        14) nox>=0.6825 1 0.00000 21.900000 *
        15) nox< 0.6825 29 542.20970 45.896550
          30) ptratio>=14.8 15 273.47730 43.653330
            60) black>=385.48 10 164.76000 41.900000
             120) crim>=0.06095 7 55.67714 39.942860 *
             121) crim< 0.06095 3 19.70667 46.466670 *
            61) black< 385.48 5 16.49200 47.160000 *
          31) ptratio< 14.8 14 112.38000 48.300000
            62) rm< 7.706 4 37.84750 44.725000 *
            63) rm>=7.706 10 2.96100 49.730000 *
> rplot(reg3, "largest tree")
> nleaf(reg3)
[1] 44
printcp(reg3)
plotcp(reg3)
We can use the printcp() function to print the complexity parameter table for the reg3 decision tree model. As before, this table shows the complexity parameter (CP), the number of splits (nsplit), the relative error (rel error), the cross-validated error (xerror), and the standard error of the cross-validated error (xstd).
Add another custom function:
# azt::function -> Return the best CP value for given model
bestcp <- function(model){
  model$cptable[which.min(model$cptable[,"xerror"]), "CP"]
}
### get the best cp value that minimizes the xerror
best = bestcp(reg3)
### get the optimal tree with best cp
reg3.pruned = prune.rpart(reg3, cp = best)
rplot(reg3.pruned, "Pruned Tree")
nleaf(reg3.pruned)
We use the prune.rpart() function to prune the reg3 decision tree model back to the best complexity parameter value, which we obtained with the bestcp() helper defined above.
The prune.rpart() function takes two arguments: the decision tree model to be pruned (reg3 in this case) and the complexity parameter value to prune to (best in this case). It returns a new decision tree model in which all splits that do not meet the specified complexity parameter have been removed.
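A quick way to see the effect of pruning (a sketch; the exact numbers depend on the cross-validation run behind the CP table, which uses random folds) is to compare the size and training error of the full and pruned trees:

nleaf(reg3)          # 44 leaves in the full tree
nleaf(reg3.pruned)   # fewer leaves after pruning back to the best CP
rmse(actual = df$medv, predicted = predict(reg3))          # training RMSE, full tree
rmse(actual = df$medv, predicted = predict(reg3.pruned))   # training RMSE, pruned tree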
Classification Example (Iris) with tree
# load the libraries
library(rpart)
library(rpart.plot)
library(caTools)
library(MASS)
library(caret)
library(Metrics)
library(tree)   # provides the tree() function used below

# Functions to reuse later in our code.
## azt::function -> Wrap rpart.plot
rplot <- function(model, main = "Tree") {
  rpart.plot(
    model,
    type = 4,
    extra = 101,
    roundint = FALSE,
    nn = TRUE,
    main = main
  )
}
## azt::function -> Count leaf nodes of a given model
nleaf <- function(model) {
  sum(model$frame$var == "<leaf>")
}
## azt::function -> Return the best CP value for given model
bestcp <- function(model){
  model$cptable[which.min(model$cptable[,"xerror"]), "CP"]
}
df <- iris
The "iris" dataset is a built-in dataset in R that contains information on the length and width of sepals and petals for three different species of iris flowers: setosa, versicolor, and virginica.
 | Sepal.Length | Sepal.Width | Petal.Length | Petal.Width | Species |
---|---|---|---|---|---|
1 | 5.1 | 3.5 | 1.4 | 0.2 | setosa |
2 | 4.9 | 3 | 1.4 | 0.2 | setosa |
3 | 4.7 | 3.2 | 1.3 | 0.2 | setosa |
4 | 4.6 | 3.1 | 1.5 | 0.2 | setosa |
5 | 5 | 3.6 | 1.4 | 0.2 | setosa |
6 | 5.4 | 3.9 | 1.7 | 0.4 | setosa |
str(df)
'data.frame': 150 obs. of 5 variables:
 $ Sepal.Length: num 5.1 4.9 4.7 4.6 5 5.4 4.6 5 4.4 4.9 ...
 $ Sepal.Width : num 3.5 3 3.2 3.1 3.6 3.9 3.4 3.4 2.9 3.1 ...
 $ Petal.Length: num 1.4 1.4 1.3 1.5 1.4 1.7 1.4 1.5 1.4 1.5 ...
 $ Petal.Width : num 0.2 0.2 0.2 0.2 0.2 0.4 0.3 0.2 0.2 0.1 ...
 $ Species     : Factor w/ 3 levels "setosa","versicolor",..: 1 1 1 1 1 1 1 1 1 1 ...
any(is.na(df))

> any(is.na(df))
[1] FALSE
# Create the model
cls = tree(Species ~ ., data = df)
plot(cls)
text(cls, cex = 0.8)
summary(cls)
The model is being trained to predict the "Species" variable based on the other variables in the "df" data frame.
After creating the model, we use the plot() function to visualize the resulting decision tree, and the text() function to label each node in the tree with the corresponding variable and split point. The cex = 0.8 argument adjusts the size of the text labels to make them easier to read.
Finally, we use the summary() function to display some basic information about the tree model, including the number of terminal nodes (i.e., the number of leaf nodes), the residual mean deviance, and the misclassification error rate.
> nleaf(cls)
[1] 6
> summary(cls)

Classification tree:
tree(formula = Species ~ ., data = df)
Variables actually used in tree construction:
[1] "Petal.Length" "Petal.Width"  "Sepal.Length"
Number of terminal nodes:  6
Residual mean deviance:  0.1253 = 18.05 / 144
Misclassification error rate: 0.02667 = 4 / 150
The output of summary(cls)
provides some useful information about the decision tree model that you have created.
The first line shows that the model was created using the tree()
function, with the formula Species ~ .
indicating that the "Species" variable is being predicted based on all other variables in the "df" data frame.
The second line indicates the variables that were actually used in the construction of the tree. In this case, the model selected "Petal.Length", "Petal.Width", and "Sepal.Length" as the most important variables for predicting the species of iris.
The third line shows the number of terminal nodes in the tree. In this case, the tree has six terminal nodes (leaves); each leaf assigns one of the three species, so several leaves can predict the same species.
The fourth line shows the residual mean deviance of the model, which is a measure of how well the model fits the data. The lower the value, the better the fit. In this case, the residual mean deviance is 0.1253, which indicates a relatively good fit.
The fifth line shows the misclassification error rate of the model, which is the proportion of observations that are misclassified. In this case, the misclassification error rate is 0.02667 or 4/150, which means that the model misclassifies about 2.67% of the observations in the dataset.
The decision tree algorithm works by recursively partitioning the data based on the most informative variables. At each step of the algorithm, the variable that provides the most information gain (i.e., the most useful for predicting the target variable) is chosen as the splitting variable.
In the case of the iris dataset, the summary(cls)
output indicates that only three variables were used in the tree construction: Petal.Length, Petal.Width, and Sepal.Length. This suggests that these variables are the most informative for predicting the species of iris in this dataset.
It is possible that the remaining variable (Sepal.Width) did not provide as much information gain as the three selected variables, or that it is highly correlated with them. Correlated variables can sometimes lead to overfitting and can make the model less generalizable to new data.
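We can check this directly by looking at the correlations among the four numeric predictors (a small sketch, not part of the original script):

# Correlation matrix of the four measurements: Petal.Length and Petal.Width
# are very strongly correlated, and both correlate with Sepal.Length.
cor(df[, 1:4])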
prd = predict(cls, newdata = df, type = "class")
prd
table(df$Species, prd)
We are using the predict()
function to generate predictions for the "df" dataset using the decision tree model "cls". However, we have included the newdata = df
argument to specify that we want to generate predictions for the same dataset that was used to train the model.
> table(df$Species, prd)
             prd
              setosa versicolor virginica
  setosa          50          0         0
  versicolor       0         47         3
  virginica        0          1        49
The table()
function with the true labels in the first argument and predicted labels in the second argument creates a confusion matrix that shows the number of correct and incorrect predictions for each class in the dataset.
In this output, the rows of the confusion matrix represent the true labels of the iris flowers ("setosa", "versicolor", and "virginica"), and the columns represent the predicted labels based on the decision tree model.
The diagonal elements of the matrix represent the number of correct predictions for each class, while the off-diagonal elements represent the number of incorrect predictions.
For example, in this output, the decision tree model correctly predicted all 50 instances of the "setosa" class, 47 out of 50 instances of the "versicolor" class, and 49 out of 50 instances of the "virginica" class.
Overall, the model achieved high accuracy in predicting the species of iris flowers, with only 4 misclassifications out of 150 total observations (as shown in the output of summary(cls)
earlier).
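The overall accuracy can be read off the confusion matrix directly (a sketch; cm is a name introduced here):

cm <- table(df$Species, prd)
sum(diag(cm)) / sum(cm)       # (50 + 47 + 49) / 150 ≈ 0.9733
1 - sum(diag(cm)) / sum(cm)   # ≈ 0.0267, matching the misclassification rate above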
In the context of decision trees, deviance is a measure of the goodness of fit of a model to the data. For regression trees it is the sum of squared differences between the observed and predicted responses; for classification trees it is based on the log-likelihood of the observed class labels under the class proportions estimated in each node.
For a classification node containing observations from K classes, the deviance is

D = -2 \sum_{k=1}^{K} n_k \log(p_k)

where n_k is the number of observations in the node that belong to class k and p_k is the proportion of observations in the node that belong to class k; the deviance of the whole tree is the sum of the deviances of its terminal nodes.
In the context of decision trees, the deviance is used to measure how well the tree model fits the training data. The goal of building a decision tree is to minimize the deviance, which means finding the tree that best fits the data.
The residual mean deviance, which is the output of the summary()
function for a decision tree model, is the deviance of the model divided by the degrees of freedom (i.e., the number of observations minus the number of terminal nodes in the tree). The residual mean deviance can be used to compare the goodness of fit of different tree models. A lower residual mean deviance indicates a better fit to the data.
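As a sketch (assuming the node deviances are stored in the model's frame component, as the tree package does), we can recover the residual mean deviance reported by summary(cls) from the terminal-node deviances:

leaf_dev <- cls$frame$dev[cls$frame$var == "<leaf>"]   # deviance of each terminal node
sum(leaf_dev)               # about 18.05, the residual deviance
sum(leaf_dev) / (150 - 6)   # about 0.1253: residual deviance over (n obs - n leaves)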
Dev = -2 * (3 * 50 * log(1/3))
Dev   # 329.5837
In the case of the iris dataset, which has three classes ("setosa", "versicolor", and "virginica"), the root node contains all N = 150 observations, with n_k = 50 observations in each of the K = 3 classes, so p_k = 1/3 for every class. Plugging into the formula above gives D = -2 × (3 × 50 × log(1/3)) ≈ 329.58, which is exactly what the R command above computes. Compared with the fitted tree's residual deviance of 18.05 (from summary(cls)), this shows how much of the class uncertainty the tree removes.