library(ggplot2)
library(rpart)
library(rpart.plot)
library(tree) 
# Restore the saved workspace objects; the data frame used below is Housing1.
load(file = "BostonHousing1.Rdata")

There are two R packages for tree models: tree and rpart. We will mainly use rpart; the tree package is loaded only for its function partition.tree, which we use to generate the first figure.

Fit a regression tree using just two predictors

# Grow a regression tree on the two spatial coordinates only, then prune it
# back to the subtree with 7 terminal nodes so it is small enough to draw.
trfit <- tree(Y ~ lon + lat, data = Housing1)
small.tree <- prune.tree(tree = trfit, best = 7)
print(small.tree)
## node), split, n, deviance, yval
##       * denotes terminal node
## 
##   1) root 506 84.1800 3.035  
##     2) lon < -71.0667 202 16.8800 3.297 *
##     3) lon > -71.0667 304 44.1400 2.860  
##       6) lon < -71.0155 185 32.8300 2.752  
##        12) lat < 42.241 147 27.4400 2.671  
##          24) lat < 42.1698 18  0.8337 3.104 *
##          25) lat > 42.1698 129 22.7700 2.611  
##            50) lon < -71.0332 102 15.4200 2.707  
##             100) lat < 42.2011 51  3.5830 2.525 *
##             101) lat > 42.2011 51  8.4700 2.888 *
##            51) lon > -71.0332 27  2.8910 2.250 *
##        13) lat > 42.241 38  0.6923 3.066 *
##       7) lon > -71.0155 119  5.8160 3.028 *
# Two panels side by side:
#   left  - the pruned tree with its split labels;
#   right - houses in the (lat, lon) plane shaded by price, overlaid with the
#           rectangular regions defined by the tree's splits.
old.par <- par(mfrow = c(1, 2))
plot(small.tree)
text(small.tree, cex = .75)

# Bin Y into 20 quantile groups. grey(20:1/21) runs from near-white to
# near-black, so higher-priced groups are drawn darker.
price.quantiles <- cut(Housing1$Y, quantile(Housing1$Y, 0:20/20),
                       include.lowest = TRUE)
plot(Housing1$lat, Housing1$lon, col = grey(20:1/21)[price.quantiles],
     pch = 20, ylab = "Longitude", xlab = "Latitude")
# ordvars puts lat on the x-axis and lon on the y-axis, matching the scatter.
partition.tree(small.tree, ordvars = c("lat", "lon"), add = TRUE)

# Restore the previous layout so later figures are not forced into half-panels.
par(old.par)

# partition.tree was the only thing needed from 'tree'; take the package off
# the search path before switching to rpart.
detach("package:tree", character.only = TRUE)

Fit a regression tree using all predictors

# rpart's built-in cross-validation (the xerror column of the CP table) is
# randomized, so fix the seed for reproducible results.
set.seed(1234)
tr1 <- rpart(Y ~ ., data = Housing1)  # default control settings

# Side by side: base plot.rpart (tree skeleton only, since text() is not
# called) and the labelled rpart.plot version.
old.par <- par(mfrow = c(1, 2))
plot(tr1)
rpart.plot(tr1)
# Restore the layout so later plots (e.g. plotcp) get the full device.
par(old.par)

Pruning

# Complexity-parameter table for tr1: CP, number of splits, training error
# (rel error) and cross-validated error (xerror, xstd), all relative to the
# root node error.
printcp(x = tr1)
## 
## Regression tree:
## rpart(formula = Y ~ ., data = Housing1)
## 
## Variables actually used in tree construction:
## [1] age   crim  lstat nox   rm   
## 
## Root node error: 84.178/506 = 0.16636
## 
## n= 506 
## 
##          CP nsplit rel error  xerror     xstd
## 1  0.464749      0   1.00000 1.00287 0.074777
## 2  0.157620      1   0.53525 0.55363 0.037792
## 3  0.077637      2   0.37763 0.39804 0.034795
## 4  0.034542      3   0.29999 0.34091 0.029867
## 5  0.022416      4   0.26545 0.31719 0.030337
## 6  0.020665      5   0.24304 0.31267 0.030365
## 7  0.016709      6   0.22237 0.28880 0.028304
## 8  0.012136      7   0.20566 0.27497 0.027298
## 9  0.011932      8   0.19353 0.27223 0.026260
## 10 0.010000      9   0.18159 0.26576 0.025580
# Prune tr1 at three thresholds. cp = 0.3 and cp = 0.2 both lie between the
# first two CP values in the table (0.157620 and 0.464749), so both keep only
# the first split; cp = 0.156 is just below 0.157620, so the second split
# (on rm) survives as well.
prune(tree = tr1, cp = 0.3)
## n= 506 
## 
## node), split, n, deviance, yval
##       * denotes terminal node
## 
## 1) root 506 84.17756 3.034558  
##   2) lstat>=3.794731 178 19.20389 2.657108 *
##   3) lstat< 3.794731 328 25.85221 3.239394 *
prune(tree = tr1, cp = 0.2)
## n= 506 
## 
## node), split, n, deviance, yval
##       * denotes terminal node
## 
## 1) root 506 84.17756 3.034558  
##   2) lstat>=3.794731 178 19.20389 2.657108 *
##   3) lstat< 3.794731 328 25.85221 3.239394 *
prune(tree = tr1, cp = 0.156)
## n= 506 
## 
## node), split, n, deviance, yval
##       * denotes terminal node
## 
## 1) root 506 84.177560 3.034558  
##   2) lstat>=3.794731 178 19.203890 2.657108 *
##   3) lstat< 3.794731 328 25.852210 3.239394  
##     6) rm< 1.898519 224  6.537080 3.102350 *
##     7) rm>=1.898519 104  6.047061 3.534565 *
# summary(tr1) prints long, detailed output for every node of the tree;
# uncomment the next line to see it.
# summary(tr1)

# Cross-validated relative error versus cp for tr1.
plotcp(x = tr1)

It is not clear whether xerror has reached its minimum yet, so let’s grow a bigger tree (cp = 0) and prune it back.

# Grow a much bigger tree: cp = 0 removes the complexity threshold (the other
# rpart.control defaults still apply), and xval = 10 requests 10-fold CV.
# The CV folds are random, so xerror depends on the RNG state left over from
# the earlier set.seed(1234).
tr2 <- rpart(Y ~ ., data = Housing1,
             control = list(cp = 0, xval = 10))
plot(tr2)

# Full CP table of the large tree, down to CP = 0.
printcp(x = tr2)
## 
## Regression tree:
## rpart(formula = Y ~ ., data = Housing1, control = list(cp = 0, 
##     xval = 10))
## 
## Variables actually used in tree construction:
##  [1] age     b       crim    dis     indus   lat     lon     lstat  
##  [9] nox     ptratio rad     rm      tax    
## 
## Root node error: 84.178/506 = 0.16636
## 
## n= 506 
## 
##            CP nsplit rel error  xerror     xstd
## 1  0.46474932      0   1.00000 1.00249 0.074594
## 2  0.15762004      1   0.53525 0.55364 0.037483
## 3  0.07763659      2   0.37763 0.39954 0.034712
## 4  0.03454223      3   0.29999 0.35290 0.031400
## 5  0.02241607      4   0.26545 0.32013 0.030876
## 6  0.02066531      5   0.24304 0.30977 0.031068
## 7  0.01670855      6   0.22237 0.28816 0.028852
## 8  0.01213622      7   0.20566 0.28067 0.028073
## 9  0.01193153      8   0.19353 0.28210 0.027765
## 10 0.00801572      9   0.18159 0.27776 0.027241
## 11 0.00737715     10   0.17358 0.26378 0.025496
## 12 0.00697612     11   0.16620 0.25816 0.025126
## 13 0.00617548     12   0.15923 0.25663 0.024799
## 14 0.00610398     13   0.15305 0.25481 0.024778
## 15 0.00425990     14   0.14695 0.25119 0.023478
## 16 0.00422742     15   0.14269 0.24758 0.023430
## 17 0.00374584     16   0.13846 0.24197 0.022823
## 18 0.00313710     17   0.13471 0.23871 0.022757
## 19 0.00290104     18   0.13158 0.23620 0.022875
## 20 0.00280686     19   0.12867 0.23291 0.022368
## 21 0.00251807     20   0.12587 0.23174 0.022381
## 22 0.00232167     21   0.12335 0.22808 0.021564
## 23 0.00192404     22   0.12103 0.22601 0.021786
## 24 0.00175484     23   0.11910 0.22467 0.021773
## 25 0.00154239     24   0.11735 0.22423 0.021838
## 26 0.00104456     26   0.11426 0.22102 0.021804
## 27 0.00095568     27   0.11322 0.22022 0.021551
## 28 0.00090998     28   0.11226 0.21921 0.021542
## 29 0.00087953     29   0.11135 0.21927 0.021538
## 30 0.00083581     30   0.11047 0.21927 0.021539
## 31 0.00081484     31   0.10964 0.21964 0.021565
## 32 0.00077591     32   0.10882 0.22016 0.021590
## 33 0.00068275     34   0.10727 0.22064 0.021598
## 34 0.00046201     35   0.10659 0.22141 0.021638
## 35 0.00041203     38   0.10520 0.22129 0.021622
## 36 0.00041002     39   0.10479 0.22129 0.021622
## 37 0.00034920     40   0.10438 0.22137 0.021628
## 38 0.00000000     41   0.10403 0.22129 0.021634
# Check where the cross-validated error bottoms out for the large tree.
plotcp(x = tr2)

# Choose the pruning threshold by the one-standard-error (1-SE) rule.
# Columns of the CP table: CP, nsplit, rel error, xerror, xstd.
cp.tab <- tr2$cptable

# index of the row with the lowest cross-validated error
opt <- which.min(cp.tab[, "xerror"])  # 28
# the minimum cross-validated error itself (column 4 is xerror, not CP)
cp.tab[opt, "xerror"]
## [1] 0.2192144
# upper bound for xerror values treated as equivalent to the minimum:
# min(xerror) + its standard error (xstd)
cp.tab[opt, "xerror"] + cp.tab[opt, "xstd"]
## [1] 0.2407566
# row IDs for CPs whose xerror is equivalent to min(xerror)
tmp.id <- which(cp.tab[, "xerror"] <=
                  cp.tab[opt, "xerror"] + cp.tab[opt, "xstd"])

# The 1-SE tree is the smallest equivalent one, i.e. row i.1se = min(tmp.id).
# prune() returns the tree of row i for any cp with CP[i] <= cp < CP[i - 1].
# Take the geometric mean of those two bounds (the same midpoint plotcp()
# uses) instead of a hand-picked constant, so the choice tracks the table if
# the data or CV folds change. Here this is about 0.0034 (row 18, 17 splits),
# the same tree that the earlier hard-coded 0.0032 selected.
i.1se <- min(tmp.id)
CP.1se <- if (i.1se > 1) {
  sqrt(cp.tab[i.1se, "CP"] * cp.tab[i.1se - 1, "CP"])
} else {
  cp.tab[1, "CP"]  # the root-only tree is already within 1 SE
}

# Prune tree with CP.1se
tr3 <- prune(tr2, cp = CP.1se)

Connection between α and CP

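$$\text{RSS}(T) + \alpha\,|T|, \qquad \frac{\text{RSS}(T)}{\text{RSS}(\text{root})} + \text{CP}\cdot|T|.$$

Dividing the first criterion by $\text{RSS}(\text{root})$ gives the second, so $\text{CP} = \alpha / \text{RSS}(\text{root})$. Here $\text{RSS}(\text{root}) = 84.178$ (the "Root node error" numerator in the CP table), so $\alpha = 84.178 \cdot \text{CP}$.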

Understand the relationship between the 1st column (CP) and the 3rd column (rel error) of the CP table: the CP in each row is the decrease in rel error to the next row, per split added.

# Column 1: CP. Column 2: drop in "rel error" from each row to the next.
# CP is that drop divided by the number of splits added. Rows 25, 32 and 34
# add 2, 2 and 3 splits (nsplit 24->26, 32->34, 35->38), so in those rows the
# second column is a multiple of CP; everywhere else the two columns agree.
cbind(tr2$cptable[, "CP"], c(-diff(tr2$cptable[, "rel error"]), 0))
##            [,1]         [,2]
## 1  0.4647493178 0.4647493178
## 2  0.1576200424 0.1576200424
## 3  0.0776365910 0.0776365910
## 4  0.0345422266 0.0345422266
## 5  0.0224160718 0.0224160718
## 6  0.0206653062 0.0206653062
## 7  0.0167085516 0.0167085516
## 8  0.0121362195 0.0121362195
## 9  0.0119315308 0.0119315308
## 10 0.0080157229 0.0080157229
## 11 0.0073771456 0.0073771456
## 12 0.0069761226 0.0069761226
## 13 0.0061754834 0.0061754834
## 14 0.0061039757 0.0061039757
## 15 0.0042599035 0.0042599035
## 16 0.0042274176 0.0042274176
## 17 0.0037458397 0.0037458397
## 18 0.0031371039 0.0031371039
## 19 0.0029010376 0.0029010376
## 20 0.0028068555 0.0028068555
## 21 0.0025180707 0.0025180707
## 22 0.0023216719 0.0023216719
## 23 0.0019240413 0.0019240413
## 24 0.0017548430 0.0017548430
## 25 0.0015423890 0.0030847780
## 26 0.0010445649 0.0010445649
## 27 0.0009556794 0.0009556794
## 28 0.0009099819 0.0009099819
## 29 0.0008795299 0.0008795299
## 30 0.0008358143 0.0008358143
## 31 0.0008148426 0.0008148426
## 32 0.0007759120 0.0015518239
## 33 0.0006827532 0.0006827532
## 34 0.0004620058 0.0013860175
## 35 0.0004120334 0.0004120334
## 36 0.0004100166 0.0004100166
## 37 0.0003491957 0.0003491957
## 38 0.0000000000 0.0000000000

Prediction

# Help page on how to use a fitted rpart tree for prediction.
help("predict.rpart")

Handle categorical predictors

# Simulate a categorical predictor with m = 30 levels, drawn at random and so
# unrelated to Y, to see how rpart splits on a factor.
set.seed(1234)
n <- nrow(Housing1)
m <- 30
X <- as.factor(sample(seq_len(m), size = n, replace = TRUE))
tmp <- data.frame(Y = Housing1$Y, X = X)
myfit <- rpart(Y ~ X, data = tmp)
print(myfit)
## n= 506 
## 
## node), split, n, deviance, yval
##       * denotes terminal node
## 
## 1) root 506 84.177560 3.034558  
##   2) X=2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,23,24,27,28,29,30 450 71.340350 3.013409  
##     4) X=6,21 31  6.175555 2.831359 *
##     5) X=2,3,4,5,7,8,9,10,11,12,13,14,15,16,17,18,19,20,23,24,27,28,29,30 419 64.061380 3.026878 *
##   3) X=1,22,25,26 56 11.018500 3.204507 *
# Mean of Y within each level of X. Sorted by this mean, the four highest
# levels (22, 26, 1, 25) form node 3 of the tree above, and the two lowest
# (6, 21) form node 4: rpart splits the factor along its level means.
group.mean <- as.vector(tapply(X = tmp$Y, INDEX = tmp$X, FUN = mean))
order(group.mean)
##  [1]  6 21 28  7  2 29 19  4 23  9  5 17 27 18 12 13 14 30 11 24 16 20 15
## [24] 10  8  3 22 26  1 25
group.mean[order(group.mean)]  # same as sort(group.mean) when there are no NAs
##  [1] 2.817958 2.847632 2.956281 2.964265 2.976372 2.994647 2.994683
##  [8] 3.002507 3.008517 3.010799 3.011183 3.011745 3.012958 3.019534
## [15] 3.022470 3.030586 3.031969 3.040009 3.044503 3.051756 3.062507
## [22] 3.064296 3.065691 3.077757 3.087916 3.098846 3.142409 3.197372
## [29] 3.198564 3.275671
# Add a numerical feature Z (standard normal noise) and refit on X and Z.
tmp$Z <- rnorm(n)
myfit <- rpart(Y ~ X + Z, data = tmp)
rpart.plot(myfit)

print(myfit)
## n= 506 
## 
## node), split, n, deviance, yval
##       * denotes terminal node
## 
##  1) root 506 84.177560 3.034558  
##    2) X=2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,23,24,27,28,29,30 450 71.340350 3.013409  
##      4) X=6,21 31  6.175555 2.831359 *
##      5) X=2,3,4,5,7,8,9,10,11,12,13,14,15,16,17,18,19,20,23,24,27,28,29,30 419 64.061380 3.026878  
##       10) Z< 0.9497702 365 55.316660 3.012491  
##         20) Z>=0.8147289 14  2.357930 2.753210 *
##         21) Z< 0.8147289 351 51.980020 3.022832 *
##       11) Z>=0.9497702 54  8.158479 3.124125  
##         22) X=3,4,7,9,10,11,12,13,14,17,18,23,24,28,30 41  4.088782 2.997439 *
##         23) X=2,5,8,15,16,19,20,27 13  1.336357 3.523675 *
##    3) X=1,22,25,26 56 11.018500 3.204507 *
# Reproduce node 11 of the tree above by hand.
# Node 5: drop the levels sent elsewhere by the first two X splits.
id1 <- which(!(tmp$X %in% c(1, 22, 25, 26, 6, 21)))
length(id1)  # 419
## [1] 419
# Node 11: keep those with Z above the split point (0.9497702, here
# approximated by 0.95; the count matches the 54 in node 11).
id2 <- id1[which(tmp$Z[id1] > 0.95)]
length(id2)  # 54
## [1] 54
# Level means within node 11. Levels with no observations there give NA.
# Sorted, the 8 highest-mean levels are exactly those sent to node 23.
group.mean <- as.vector(tapply(X = tmp$Y[id2], INDEX = tmp$X[id2], FUN = mean))
order(group.mean)
##  [1] 12  4 14 24 28 10  7  9 18 23 30 17  3 11 13  2 15 20 19 16  5  8 27
## [24]  1  6 21 22 25 26 29
group.mean
##  [1]       NA 3.314186 3.122365 2.860671 3.802208       NA 2.991151
##  [8] 3.912023 2.999644 2.960736 3.139833 2.624669 3.173206 2.923436
## [15] 3.339919 3.658274 3.116115 3.000720 3.465736 3.355096       NA
## [22]       NA 3.016733 2.939057       NA       NA 3.912023 2.942704
## [29]       NA 3.020425