library(ggplot2)
library(rpart)
library(rpart.plot)
library(tree)
# Read the saved workspace file; it provides the data frame `Housing1`
# (response Y plus the predictors used below).
load(file = "BostonHousing1.Rdata")
There are two R packages for tree models, `tree` and `rpart`. We will mainly use `rpart`. The package `tree` is loaded only for its function `partition.tree`, which we use to generate the first figure.
# Regression tree of Y on location only (longitude and latitude), grown with
# the `tree` package and then cut back to the subtree with 7 terminal nodes,
# so that its rectangular partition can be drawn on a map.
trfit <- tree(Y ~ lon + lat, data = Housing1)
small.tree <- prune.tree(trfit, best = 7)
print(small.tree)
## node), split, n, deviance, yval
## * denotes terminal node
##
## 1) root 506 84.1800 3.035
## 2) lon < -71.0667 202 16.8800 3.297 *
## 3) lon > -71.0667 304 44.1400 2.860
## 6) lon < -71.0155 185 32.8300 2.752
## 12) lat < 42.241 147 27.4400 2.671
## 24) lat < 42.1698 18 0.8337 3.104 *
## 25) lat > 42.1698 129 22.7700 2.611
## 50) lon < -71.0332 102 15.4200 2.707
## 100) lat < 42.2011 51 3.5830 2.525 *
## 101) lat > 42.2011 51 8.4700 2.888 *
## 51) lon > -71.0332 27 2.8910 2.250 *
## 13) lat > 42.241 38 0.6923 3.066 *
## 7) lon > -71.0155 119 5.8160 3.028 *
# Two panels side by side: the pruned location tree on the left; on the right,
# every house plotted at its location and shaded by its price ventile, with
# the tree's partition of the (lat, lon) plane drawn on top.
par(mfrow = c(1, 2))

plot(small.tree)
text(small.tree, cex = 0.75)

# 20 equal-count price bins; bin 1 (cheapest) maps to grey(20/21) (light)
# and bin 20 (most expensive) maps to grey(1/21) (dark).
price.quantiles <- cut(Housing1$Y,
                       breaks = quantile(Housing1$Y, probs = 0:20 / 20),
                       include.lowest = TRUE)
plot(x = Housing1$lat, y = Housing1$lon,
     col = grey(20:1 / 21)[price.quantiles], pch = 20,
     xlab = "Latitude", ylab = "Longitude")
# ordvars puts lat on the x-axis and lon on the y-axis, matching the scatter.
partition.tree(small.tree, ordvars = c("lat", "lon"), add = TRUE)
# `tree` was needed only for partition.tree(); the code below uses rpart.
detach("package:tree")

# The xerror/xstd columns of the CP table come from cross-validation, which
# assigns folds at random; fixing the seed makes them reproducible.
set.seed(1234)
tr1 <- rpart(Y ~ ., data = Housing1)

# Base-graphics plot (drawn without text() labels) next to rpart.plot().
par(mfrow = c(1, 2))
plot(tr1)
rpart.plot(tr1)
printcp(tr1)
##
## Regression tree:
## rpart(formula = Y ~ ., data = Housing1)
##
## Variables actually used in tree construction:
## [1] age crim lstat nox rm
##
## Root node error: 84.178/506 = 0.16636
##
## n= 506
##
## CP nsplit rel error xerror xstd
## 1 0.464749 0 1.00000 1.00287 0.074777
## 2 0.157620 1 0.53525 0.55363 0.037792
## 3 0.077637 2 0.37763 0.39804 0.034795
## 4 0.034542 3 0.29999 0.34091 0.029867
## 5 0.022416 4 0.26545 0.31719 0.030337
## 6 0.020665 5 0.24304 0.31267 0.030365
## 7 0.016709 6 0.22237 0.28880 0.028304
## 8 0.012136 7 0.20566 0.27497 0.027298
## 9 0.011932 8 0.19353 0.27223 0.026260
## 10 0.010000 9 0.18159 0.26576 0.025580
# cp = 0.3 lies between the CP of row 2 (0.1576) and row 1 (0.4647) of the
# table above, so only the first split (lstat) survives.
print(prune(tr1, cp = 0.3))
## n= 506
##
## node), split, n, deviance, yval
## * denotes terminal node
##
## 1) root 506 84.17756 3.034558
## 2) lstat>=3.794731 178 19.20389 2.657108 *
## 3) lstat< 3.794731 328 25.85221 3.239394 *
# cp = 0.2 also lies between 0.1576 and 0.4647, so this is the same one-split
# tree as with cp = 0.3.
print(prune(tr1, cp = 0.2))
## n= 506
##
## node), split, n, deviance, yval
## * denotes terminal node
##
## 1) root 506 84.17756 3.034558
## 2) lstat>=3.794731 178 19.20389 2.657108 *
## 3) lstat< 3.794731 328 25.85221 3.239394 *
# cp = 0.156 is just below row 2's CP (0.1576) but above row 3's (0.0776),
# so the second split (on rm) is kept as well.
print(prune(tr1, cp = 0.156))
## n= 506
##
## node), split, n, deviance, yval
## * denotes terminal node
##
## 1) root 506 84.177560 3.034558
## 2) lstat>=3.794731 178 19.203890 2.657108 *
## 3) lstat< 3.794731 328 25.852210 3.239394
## 6) rm< 1.898519 224 6.537080 3.102350 *
## 7) rm>=1.898519 104 6.047061 3.534565 *
# summary() prints long, detailed output for every node of the tree;
# uncomment the next line to see it.
# summary(tr1)

# Cross-validated relative error for each candidate cp / tree size.
plotcp(tr1)
# Grow a large tree: with cp = 0 no split is rejected for too small an
# improvement, so growth is limited only by the other rpart.control() limits.
# xval = 10 requests 10-fold cross-validation for the CP table.
tr2 <- rpart(Y ~ ., data = Housing1,
             control = list(cp = 0, xval = 10))
plot(tr2)
printcp(tr2)
##
## Regression tree:
## rpart(formula = Y ~ ., data = Housing1, control = list(cp = 0,
## xval = 10))
##
## Variables actually used in tree construction:
## [1] age b crim dis indus lat lon lstat
## [9] nox ptratio rad rm tax
##
## Root node error: 84.178/506 = 0.16636
##
## n= 506
##
## CP nsplit rel error xerror xstd
## 1 0.46474932 0 1.00000 1.00249 0.074594
## 2 0.15762004 1 0.53525 0.55364 0.037483
## 3 0.07763659 2 0.37763 0.39954 0.034712
## 4 0.03454223 3 0.29999 0.35290 0.031400
## 5 0.02241607 4 0.26545 0.32013 0.030876
## 6 0.02066531 5 0.24304 0.30977 0.031068
## 7 0.01670855 6 0.22237 0.28816 0.028852
## 8 0.01213622 7 0.20566 0.28067 0.028073
## 9 0.01193153 8 0.19353 0.28210 0.027765
## 10 0.00801572 9 0.18159 0.27776 0.027241
## 11 0.00737715 10 0.17358 0.26378 0.025496
## 12 0.00697612 11 0.16620 0.25816 0.025126
## 13 0.00617548 12 0.15923 0.25663 0.024799
## 14 0.00610398 13 0.15305 0.25481 0.024778
## 15 0.00425990 14 0.14695 0.25119 0.023478
## 16 0.00422742 15 0.14269 0.24758 0.023430
## 17 0.00374584 16 0.13846 0.24197 0.022823
## 18 0.00313710 17 0.13471 0.23871 0.022757
## 19 0.00290104 18 0.13158 0.23620 0.022875
## 20 0.00280686 19 0.12867 0.23291 0.022368
## 21 0.00251807 20 0.12587 0.23174 0.022381
## 22 0.00232167 21 0.12335 0.22808 0.021564
## 23 0.00192404 22 0.12103 0.22601 0.021786
## 24 0.00175484 23 0.11910 0.22467 0.021773
## 25 0.00154239 24 0.11735 0.22423 0.021838
## 26 0.00104456 26 0.11426 0.22102 0.021804
## 27 0.00095568 27 0.11322 0.22022 0.021551
## 28 0.00090998 28 0.11226 0.21921 0.021542
## 29 0.00087953 29 0.11135 0.21927 0.021538
## 30 0.00083581 30 0.11047 0.21927 0.021539
## 31 0.00081484 31 0.10964 0.21964 0.021565
## 32 0.00077591 32 0.10882 0.22016 0.021590
## 33 0.00068275 34 0.10727 0.22064 0.021598
## 34 0.00046201 35 0.10659 0.22141 0.021638
## 35 0.00041203 38 0.10520 0.22129 0.021622
## 36 0.00041002 39 0.10479 0.22129 0.021622
## 37 0.00034920 40 0.10438 0.22137 0.021628
## 38 0.00000000 41 0.10403 0.22129 0.021634
plotcp(tr2)
# Row of the CP table with the smallest cross-validated error.
opt <- which.min(tr2$cptable[, "xerror"]) # 28
# The minimum cross-validated error (xerror) itself. Note: this is NOT the
# optimal CP value; that would be tr2$cptable[opt, "CP"].
tr2$cptable[opt, "xerror"]
## [1] 0.2192144
# Upper end of the 1-SE band: min(xerror) plus its standard error (xstd).
# Any tree whose xerror is at or below this is treated as equally good.
tr2$cptable[opt, "xerror"] + tr2$cptable[opt, "xstd"]
## [1] 0.2407566
# Rows of the CP table whose xerror is within one standard error of the
# minimum (the 1-SE rule). Never empty: row `opt` always qualifies.
tmp.id <- which(tr2$cptable[, "xerror"] <=
                  tr2$cptable[opt, "xerror"] + tr2$cptable[opt, "xstd"])
# The 1-SE tree is the smallest qualifying tree, i.e. the first such row.
id.1se <- tmp.id[1]

# The tree in row k is obtained by pruning with any cp in
# [CP[k], CP[k - 1]). Rather than hard-coding a number read off the printout
# (previously 0.0032), take the geometric mean of the two ends, the same cp
# values plotcp() uses on its x-axis. If row 1 (root only) is chosen, any
# cp >= CP[1] works, so use CP[1].
cp.col <- tr2$cptable[, "CP"]
if (id.1se > 1) {
  CP.1se <- unname(sqrt(cp.col[id.1se] * cp.col[id.1se - 1]))
} else {
  CP.1se <- unname(cp.col[1])
}

# Prune tree with CP.1se
tr3 <- prune(tr2, cp = CP.1se)
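Understand the relationship between the 1st column (`CP`) and the 3rd column (`rel error`) of the CP table. Each CP value is the decrease in relative error per added split, $\text{CP}_k = (\text{rel error}_k - \text{rel error}_{k+1}) / (\text{nsplit}_{k+1} - \text{nsplit}_k)$. In the comparison below, the two columns therefore agree except in the rows where `nsplit` jumps by more than one (rows 25, 32 and 34). In those rows the second column is 2 or 3 times the CP.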
# Column 1: CP; column 2: drop in rel error from each row to the next
# (0 appended for the last row, which has no successor).
cbind(tr2$cptable[, "CP"],
      c(-diff(tr2$cptable[, "rel error"]), 0))
## [,1] [,2]
## 1 0.4647493178 0.4647493178
## 2 0.1576200424 0.1576200424
## 3 0.0776365910 0.0776365910
## 4 0.0345422266 0.0345422266
## 5 0.0224160718 0.0224160718
## 6 0.0206653062 0.0206653062
## 7 0.0167085516 0.0167085516
## 8 0.0121362195 0.0121362195
## 9 0.0119315308 0.0119315308
## 10 0.0080157229 0.0080157229
## 11 0.0073771456 0.0073771456
## 12 0.0069761226 0.0069761226
## 13 0.0061754834 0.0061754834
## 14 0.0061039757 0.0061039757
## 15 0.0042599035 0.0042599035
## 16 0.0042274176 0.0042274176
## 17 0.0037458397 0.0037458397
## 18 0.0031371039 0.0031371039
## 19 0.0029010376 0.0029010376
## 20 0.0028068555 0.0028068555
## 21 0.0025180707 0.0025180707
## 22 0.0023216719 0.0023216719
## 23 0.0019240413 0.0019240413
## 24 0.0017548430 0.0017548430
## 25 0.0015423890 0.0030847780
## 26 0.0010445649 0.0010445649
## 27 0.0009556794 0.0009556794
## 28 0.0009099819 0.0009099819
## 29 0.0008795299 0.0008795299
## 30 0.0008358143 0.0008358143
## 31 0.0008148426 0.0008148426
## 32 0.0007759120 0.0015518239
## 33 0.0006827532 0.0006827532
## 34 0.0004620058 0.0013860175
## 35 0.0004120334 0.0004120334
## 36 0.0004100166 0.0004100166
## 37 0.0003491957 0.0003491957
## 38 0.0000000000 0.0000000000
help("predict.rpart") # how to use a fitted tree for prediction
# Simulate a pure-noise categorical feature: each of the n houses gets one of
# m = 30 levels uniformly at random, independently of Y. Then fit a tree of
# Y on X alone to see which level groupings rpart picks.
set.seed(1234)
n <- nrow(Housing1)
m <- 30
X <- as.factor(sample(seq_len(m), n, replace = TRUE))
tmp <- data.frame(Y = Housing1$Y, X = X)
myfit <- rpart(Y ~ X, data = tmp)
print(myfit)
## n= 506
##
## node), split, n, deviance, yval
## * denotes terminal node
##
## 1) root 506 84.177560 3.034558
## 2) X=2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,23,24,27,28,29,30 450 71.340350 3.013409
## 4) X=6,21 31 6.175555 2.831359 *
## 5) X=2,3,4,5,7,8,9,10,11,12,13,14,15,16,17,18,19,20,23,24,27,28,29,30 419 64.061380 3.026878 *
## 3) X=1,22,25,26 56 11.018500 3.204507 *
# Mean of Y for each level of X (in factor-level order), then the levels
# ranked from lowest to highest mean. The tree's splits above send the two
# lowest-ranked levels (6, 21) and the four highest (22, 26, 1, 25) apart.
group.mean <- with(tmp, as.vector(tapply(Y, X, mean)))
order(group.mean)
## [1] 6 21 28 7 2 29 19 4 23 9 5 17 27 18 12 13 14 30 11 24 16 20 15
## [24] 10 8 3 22 26 1 25
# Group means in increasing order. This equals sort(group.mean) here because
# no group is empty; sort() would drop NAs, whereas this keeps them last.
group.mean[order(group.mean, na.last = TRUE)]
## [1] 2.817958 2.847632 2.956281 2.964265 2.976372 2.994647 2.994683
## [8] 3.002507 3.008517 3.010799 3.011183 3.011745 3.012958 3.019534
## [15] 3.022470 3.030586 3.031969 3.040009 3.044503 3.051756 3.062507
## [22] 3.064296 3.065691 3.077757 3.087916 3.098846 3.142409 3.197372
## [29] 3.198564 3.275671
# Add an independent standard-normal numeric feature Z, then refit on X and Z.
tmp$Z <- rnorm(n)
myfit <- rpart(Y ~ X + Z, data = tmp)
rpart.plot(myfit)
print(myfit)
## n= 506
##
## node), split, n, deviance, yval
## * denotes terminal node
##
## 1) root 506 84.177560 3.034558
## 2) X=2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,23,24,27,28,29,30 450 71.340350 3.013409
## 4) X=6,21 31 6.175555 2.831359 *
## 5) X=2,3,4,5,7,8,9,10,11,12,13,14,15,16,17,18,19,20,23,24,27,28,29,30 419 64.061380 3.026878
## 10) Z< 0.9497702 365 55.316660 3.012491
## 20) Z>=0.8147289 14 2.357930 2.753210 *
## 21) Z< 0.8147289 351 51.980020 3.022832 *
## 11) Z>=0.9497702 54 8.158479 3.124125
## 22) X=3,4,7,9,10,11,12,13,14,17,18,23,24,28,30 41 4.088782 2.997439 *
## 23) X=2,5,8,15,16,19,20,27 13 1.336357 3.523675 *
## 3) X=1,22,25,26 56 11.018500 3.204507 *
# Observations reaching node 5 of myfit: everything except the X levels sent
# to node 3 (1, 22, 25, 26) and to node 4 (6, 21).
id1 <- which(!tmp$X %in% c(1, 22, 25, 26, 6, 21))
length(id1) # 419
## [1] 419
# Of those, the observations sent to node 11, whose printed rule is
# Z >= 0.9497702. Use that cut-point rather than the rounded "> 0.95", which
# would wrongly drop any observation with Z in [0.9497702, 0.95].
id2 <- id1[tmp$Z[id1] >= 0.9497702]
length(id2) # 54
## [1] 54
# Mean of Y per X level, restricted to node 11. Levels with no observations
# in the node give NA, and order() ranks those NA levels last. The X split
# of node 11 (nodes 22 vs 23) separates the first 15 non-NA levels of this
# ranking from the next 8.
group.mean <- with(tmp, as.vector(tapply(Y[id2], X[id2], mean)))
order(group.mean)
## [1] 12 4 14 24 28 10 7 9 18 23 30 17 3 11 13 2 15 20 19 16 5 8 27
## [24] 1 6 21 22 25 26 29
# NA entries are X levels with no observations in node 11.
print(group.mean)
## [1] NA 3.314186 3.122365 2.860671 3.802208 NA 2.991151
## [8] 3.912023 2.999644 2.960736 3.139833 2.624669 3.173206 2.923436
## [15] 3.339919 3.658274 3.116115 3.000720 3.465736 3.355096 NA
## [22] NA 3.016733 2.939057 NA NA 3.912023 2.942704
## [29] NA 3.020425