pre is an R package for deriving prediction rule ensembles for binary, multinomial, (multivariate) continuous, count and survival outcome variables. Input variables may be numeric, ordinal and categorical. An extensive description of the implementation and functionality is provided in Fokkema (2017). The package largely implements the algorithm for deriving prediction rule ensembles as described in Friedman & Popescu (2008), with several adjustments, among them function gpe() for fitting generalized prediction ensembles.
Note that pre is under development, and much work still needs to be done. Below, a short introductory example is provided. Fokkema (2017) provides an extensive description of the fitting procedures implemented in function pre(), as well as example analyses with more extensive explanations.
To get a first impression of how function pre() works, we will fit a prediction rule ensemble to predict Ozone levels using the airquality dataset:
library("pre")
airq <- airquality[complete.cases(airquality), ]
set.seed(42)
airq.ens <- pre(Ozone ~ ., data = airq)
Note that the random seed was set first, to allow for later replication of the results, as the fitting procedure depends on random sampling of training observations.
We can print the resulting ensemble by typing its name (alternatively, we could call the print method explicitly):
airq.ens
#> Final ensemble with cv error within 1se of minimum:
#>   lambda = 2.331694
#>   number of terms = 13
#>   mean cv error (se) = 302.4644 (79.28454)
#>
#>   cv error type : Mean-Squared Error
#>
#>         rule  coefficient                          description
#>  (Intercept)   72.9680699                                    1
#>      rule191  -15.6401487              Wind > 5.7 & Temp <= 87
#>      rule173   -8.6645924              Wind > 5.7 & Temp <= 82
#>      rule204    8.1715564         Wind <= 10.3 & Solar.R > 148
#>       rule42   -7.6928586              Wind > 6.3 & Temp <= 84
#>       rule10   -6.8032890              Temp <= 84 & Temp <= 77
#>      rule192   -4.6926624  Wind > 5.7 & Temp <= 87 & Day <= 23
#>       rule93    3.1468762              Temp > 77 & Wind <= 8.6
#>       rule51   -2.6981570              Wind > 5.7 & Temp <= 84
#>       rule25   -2.4481192              Wind > 6.3 & Temp <= 82
#>       rule28   -2.1119330              Temp <= 84 & Wind > 7.4
#>       rule74   -0.8276940              Wind > 6.9 & Temp <= 84
#>      rule200   -0.4479854                       Solar.R <= 201
#>      rule166   -0.1202175              Wind > 6.9 & Temp <= 82
The cross-validated error printed here is calculated using the same data as was used for generating the rules and therefore may provide an overly optimistic estimate of future prediction error. To obtain a more realistic prediction error estimate, we will use function cvpre() later on.
The table represents the rules and linear terms selected for the final ensemble, with the estimated coefficients. For rules, the description column provides the conditions. If all conditions of a rule apply to an observation, the predicted value of the response increases by the estimated coefficient, which is printed in the coefficient column. If linear terms were selected for the final ensemble (which is not the case here), the winsorizing points used to reduce the influence of outliers on the estimated coefficient would be printed in the description column. For linear terms, the estimated coefficient in the coefficient column reflects the increase in the predicted value of the response for a unit increase in the predictor variable.
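The idea of winsorizing a linear term can be illustrated in base R. This is only a sketch, not the package's internal code: the helper name winsorize() and the trimming fraction of .025 are hypothetical choices made for illustration.

```r
# Hypothetical helper (not part of pre): clamp a numeric vector to its
# lower and upper quantiles, the "winsorizing points".
winsorize <- function(x, frac = 0.025) {
  limits <- quantile(x, probs = c(frac, 1 - frac), na.rm = TRUE, names = FALSE)
  pmin(pmax(x, limits[1]), limits[2])
}

airq <- airquality[complete.cases(airquality), ]
wind_wins <- winsorize(airq$Wind)

# Extreme values are pulled in to the winsorizing points;
# values between the two points are left unchanged.
range(airq$Wind)
range(wind_wins)
```

Because the most extreme values are replaced by the winsorizing points, a single outlying observation has less leverage on the coefficient estimated for the linear term.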
If we want to plot the rules in the ensemble as simple decision trees, we can use the plot method. Here, we request the nine most important baselearners through the nterms argument. Through the cex argument, we specify the size of the node and path labels:
plot(airq.ens, nterms = 9, cex = .5)
We can obtain the estimated coefficients for each of the baselearners using the coef method (only the first ten are printed here):
coefs <- coef(airq.ens)
coefs[1:10, ]
#>             rule  coefficient                          description
#> 201  (Intercept)    72.968070                                    1
#> 167      rule191   -15.640149              Wind > 5.7 & Temp <= 87
#> 150      rule173    -8.664592              Wind > 5.7 & Temp <= 82
#> 179      rule204     8.171556         Wind <= 10.3 & Solar.R > 148
#> 39        rule42    -7.692859              Wind > 6.3 & Temp <= 84
#> 10        rule10    -6.803289              Temp <= 84 & Temp <= 77
#> 168      rule192    -4.692662  Wind > 5.7 & Temp <= 87 & Day <= 23
#> 84        rule93     3.146876              Temp > 77 & Wind <= 8.6
#> 48        rule51    -2.698157              Wind > 5.7 & Temp <= 84
#> 23        rule25    -2.448119              Wind > 6.3 & Temp <= 82
We can generate predictions for new observations using the predict method:
predict(airq.ens, newdata = airq[1:4, ])
#>        1        2        3        4
#> 31.10390 20.82041 20.82041 21.26840
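These predictions can be reproduced by hand from the printed ensemble, which makes the additive structure of the model explicit. The base R sketch below does not use the package: the intercept and coefficients are copied from the printed output above, and each rule is evaluated as a 0/1 indicator. Evaluating the rule descriptions with eval(parse()) is a shortcut chosen for this illustration, not necessarily how pre evaluates rules internally.

```r
airq <- airquality[complete.cases(airquality), ]
new_obs <- airq[1:4, ]

# Intercept and rule coefficients, copied from the printed ensemble.
intercept <- 72.9680699
rules <- c(
  "Wind > 5.7 & Temp <= 87"             = -15.6401487,
  "Wind > 5.7 & Temp <= 82"             =  -8.6645924,
  "Wind <= 10.3 & Solar.R > 148"        =   8.1715564,
  "Wind > 6.3 & Temp <= 84"             =  -7.6928586,
  "Temp <= 84 & Temp <= 77"             =  -6.8032890,
  "Wind > 5.7 & Temp <= 87 & Day <= 23" =  -4.6926624,
  "Temp > 77 & Wind <= 8.6"             =   3.1468762,
  "Wind > 5.7 & Temp <= 84"             =  -2.6981570,
  "Wind > 6.3 & Temp <= 82"             =  -2.4481192,
  "Temp <= 84 & Wind > 7.4"             =  -2.1119330,
  "Wind > 6.9 & Temp <= 84"             =  -0.8276940,
  "Solar.R <= 201"                      =  -0.4479854,
  "Wind > 6.9 & Temp <= 82"             =  -0.1202175
)

# Evaluate every rule as a 0/1 indicator for each observation
# (rows: observations, columns: rules).
rule_values <- sapply(names(rules), function(rule) {
  as.numeric(eval(parse(text = rule), envir = new_obs))
})

# Prediction = intercept + sum of the coefficients of the rules that apply.
manual_pred <- intercept + as.vector(rule_values %*% rules)
round(manual_pred, 5)
#> [1] 31.10390 20.82041 20.82041 21.26840
```

For the first observation (Wind = 7.4, Temp = 67, Solar.R = 190, Day = 1), all rules apply except "Temp > 77 & Wind <= 8.6" and "Temp <= 84 & Wind > 7.4", so the prediction is the intercept plus the remaining eleven coefficients.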
We can assess the expected prediction error of the prediction rule ensemble through cross validation (10-fold, by default) using the cvpre() function:
set.seed(43)
airq.cv <- cvpre(airq.ens)
#> $MSE
#>       MSE        se
#> 332.48191  72.23573
#>
#> $MAE
#>       MAE        se
#> 13.186533  1.200747
The results provide the mean squared error (MSE) and mean absolute error (MAE) with their respective standard errors. The cross-validated predictions, which can be used to compute alternative estimates of predictive accuracy, are saved in airq.cv$cvpreds. The folds to which observations were assigned are saved in airq.cv$fold_indicators.
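As an illustration of such alternative estimates, the base R sketch below computes the root mean squared error, the mean absolute error and the proportion of variance explained from a vector of observed values and a vector of predictions. The helper accuracy_summary() is hypothetical (not part of pre), and the toy vectors merely stand in for the observed response and the cross-validated predictions.

```r
# Hypothetical helper (not part of pre): accuracy estimates computed from
# observed values and (cross-validated) predictions.
accuracy_summary <- function(observed, predicted) {
  err <- observed - predicted
  c(
    rmse = sqrt(mean(err^2)),
    mae  = mean(abs(err)),
    r2   = 1 - sum(err^2) / sum((observed - mean(observed))^2)
  )
}

# Toy vectors standing in for the observed response and the
# cross-validated predictions.
observed  <- c(10, 20, 30, 40)
predicted <- c(12, 18, 33, 35)
acc <- accuracy_summary(observed, predicted)
acc
```

With the objects from the example above, the call would presumably be accuracy_summary(airq$Ozone, airq.cv$cvpreds), assuming that cvpreds is a numeric vector in the same row order as airq.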
Package pre provides several additional tools for interpretation of the final ensemble. These may be especially helpful for complex ensembles containing many rules and linear terms.
We can assess the relative importance of input variables as well as baselearners using the importance() function:
imps <- importance(airq.ens, round = 4)
As we already observed in the printed ensemble, the plotted variable importances indicate that Temperature and Wind are most strongly associated with Ozone levels. Solar.R and Day are also associated with Ozone levels, but much less strongly. Variable Month is not plotted, which means it obtained an importance of zero: it was not selected as a linear term and did not appear in any of the selected rules, so it does not contribute to the ensemble's predictions. The variable and baselearner importances are saved in imps$varimps and imps$baseimps, respectively.
We can obtain partial dependence plots to assess the effect of single predictor variables on the outcome using the singleplot() function:
singleplot(airq.ens, varname = "Temp")
We can obtain partial dependence plots to assess the effects of pairs of predictor variables on the outcome using the pairplot() function:
pairplot(airq.ens, varnames = c("Temp", "Wind"))
Note that creating partial dependence plots is computationally intensive, and computation time increases quickly with the number of observations and the number of variables. R package plotmo, created by Stephen Milborrow (2018), provides more efficient functions for plotting partial dependence, which also support pre models.
If the final ensemble does not contain many terms, inspecting individual rules and linear terms through the print method may be (much) more informative than partial dependence plots. One of the main advantages of prediction rule ensembles is their interpretability: the predictive model contains only simple functions of the predictor variables (rules and linear terms), which are easy to grasp. Partial dependence plots are often much more useful for the interpretation of complex models, such as random forests.
We can obtain explanations of the predictions for individual observations using function explain():
expl <- explain(airq.ens, newdata = airq[1:4, ], cex = .6)
The values of the rules and linear terms for each observation are saved in expl$predictors and the contributions in expl$contribution.
We can assess correlations between the baselearners appearing in the ensemble using the corplot() function:
corplot(airq.ens)
We can assess the presence of interactions between the input variables using the interact() and bsnullinteract() functions. Function bsnullinteract() computes null-interaction models (10, by default) based on bootstrap-sampled and permuted datasets. Function interact() computes interaction test statistics for each predictor variable appearing in the specified ensemble. If null-interaction models are provided through the nullmods argument, interaction test statistics will also be computed for the null-interaction models, providing a reference null distribution.
Note that computing null interaction models and interaction test statistics is computationally very intensive.
set.seed(44)
nullmods <- bsnullinteract(airq.ens)
int <- interact(airq.ens, nullmods = nullmods)
The plotted variable interaction strengths indicate that Temperature and Wind may be involved in interactions, as their observed interaction strengths (darker grey) exceed the upper limit of the 90% confidence interval (CI) of interaction strengths in the null interaction models (the lighter grey bar represents the median, error bars represent the 90% CIs). The plot indicates that Solar.R and Day are not involved in any interactions. A more reliable result can be obtained by computing a larger number of bootstrapped null interaction datasets, by setting the nsamp argument of function bsnullinteract() to a larger value (e.g., 100), at the cost of longer computation time.
More complex prediction ensembles can be obtained using the gpe() function. Abbreviation gpe stands for generalized prediction ensembles, which can also include hinge functions of the predictor variables as described in Friedman (1991), in addition to rules and/or linear terms. Addition of hinge functions may further improve predictive accuracy. See the following example:
set.seed(42)
airq.gpe <- gpe(Ozone ~ ., data = airquality[complete.cases(airquality), ],
                base_learners = list(gpe_trees(), gpe_linear(), gpe_earth()))
airq.gpe
#>
#> Final ensemble with cv error within 1se of minimum:
#>   lambda = 2.44272
#>   number of terms = 14
#>   mean cv error (se) = 296.5474 (74.18922)
#>
#>   cv error type : Mean-squared Error
#>
#>                                  description  coefficient
#>                                  (Intercept)  67.22404190
#>                                   Temp <= 77  -7.72729559
#>                      Temp <= 84 & Wind > 7.4  -0.03574864
#>                 Wind <= 10.3 & Solar.R > 148   6.29678603
#>                      Wind > 5.7 & Temp <= 82  -6.51245200
#>                      Wind > 5.7 & Temp <= 84  -7.58076900
#>                      Wind > 5.7 & Temp <= 87  -9.64346611
#>          Wind > 5.7 & Temp <= 87 & Day <= 23  -4.38371322
#>                      Wind > 6.3 & Temp <= 82  -4.18790433
#>                      Wind > 6.3 & Temp <= 84  -4.88269073
#>                      Wind > 6.9 & Temp <= 82  -0.12188611
#>                      Wind > 6.9 & Temp <= 84  -0.93529314
#>  eTerm(Solar.R * h(6.3 - Wind), scale = 150)   1.42794086
#>       eTerm(h(6.9 - Wind) * Day, scale = 16)   1.35764132
#>  eTerm(Solar.R * h(9.7 - Wind), scale = 410)   9.84395780
#>
#>   'h' in the 'eTerm' indicates the hinge function
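To make the hinge terms in this output more concrete, the base R sketch below evaluates two of them by hand. It assumes the usual definition of the hinge function, h(x) = max(0, x), as used in multivariate adaptive regression splines (Friedman, 1991). The scale argument of eTerm is left out here; presumably it rescales the term before the final model is fitted, but that is an assumption and not verified against the package source.

```r
# Hinge function: zero for negative arguments, identity for positive ones.
h <- function(x) pmax(0, x)

airq <- airquality[complete.cases(airquality), ]

# Unscaled values of two hinge-based terms from the printed ensemble.
term1 <- airq$Solar.R * h(6.3 - airq$Wind)  # Solar.R * h(6.3 - Wind)
term2 <- h(6.9 - airq$Wind) * airq$Day      # h(6.9 - Wind) * Day

# Observations with Wind >= 6.3 obtain a value of zero on the first term.
head(data.frame(Wind = airq$Wind, Solar.R = airq$Solar.R, Day = airq$Day,
                term1, term2))
```

A hinge term is thus only "active" on one side of its knot: the first term contributes to the prediction only for observations with Wind below 6.3, where its value grows with both Solar.R and the distance of Wind from the knot.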
Breiman, L., Friedman, J., Olshen, R., & Stone, C. (1984). Classification and regression trees. Chapman & Hall/CRC.
Fokkema, M. (2017). Pre: An R package for fitting prediction rule ensembles. arXiv:1707.07149. Retrieved from https://arxiv.org/abs/1707.07149
Friedman, J. (1991). Multivariate adaptive regression splines. The Annals of Statistics, 19(1), 1–67.
Friedman, J., & Popescu, B. (2008). Predictive learning via rule ensembles. The Annals of Applied Statistics, 2(3), 916–954. Retrieved from http://www.jstor.org/stable/30245114
Hothorn, T., Hornik, K., & Zeileis, A. (2006). Unbiased recursive partitioning: A conditional inference framework. Journal of Computational and Graphical Statistics, 15(3), 651–674.
Milborrow, S. (2018). Plotmo: Plot a model’s residuals, response, and partial dependence plots. Retrieved from https://CRAN.R-project.org/package=plotmo
Zeileis, A., Hothorn, T., & Hornik, K. (2008). Model-based recursive partitioning. Journal of Computational and Graphical Statistics, 17(2), 492–514.
Changes in Version 0.7.1 (2019-04-24)
Bugs fixed in caret_pre_model: tuning of penalty.par.val argument always yielded results for "lambda.1se" only. Results are now correctly returned for "lambda.min" and "lambda.1se". caret's varImp() and predictors() not supported (perhaps temporarily), as these would always employ default penalty.par.val of "lambda.1se".
Bugs fixed in explain().
Changes in Version 0.7 (2019-03-30)
Added support for sparse rule matrix, which can be invoked through sparse argument in pre(). If sparse = TRUE, memory usage will be reduced and computation speed may be improved for large datasets.
Added function explain(), which provides (graphical) explanations of the ensemble's predictions at the individual observation level.
Changes in Version 0.6 (2018-08-03)
Added support for survival responses (i.e., family = "cox") in pre()
Added summary methods for pre and gpe.
Extended support to all response variable types available in pre() for functions plot(), importance() and cvpre().
plot.pre now allows for specifying separate plotting colors for rules with positive and negative coefficients.
coef and print methods for pre now return descriptions for the intercept (and factor variables), thanks to suggestion by Stephen Milborrow.
Bug fix in pre(): ordered factors no longer yield error. Implemented new argument 'ordinal' in pre(), which specifies how ordered factors should be processed.
Bug fix in cvpre(): pclass argument now processed correctly.
Bug fix in cvpre(): previously, SDs instead of SEs were returned for binary classification. Accurate standard errors are returned now.
Bugs fixed in coef.pre(), print.pre(), plot.pre() and importance() when tree.unbiased = FALSE, thanks to a bug report by Stephen Milborrow.
Changes in Version 0.5 (2018-05-07)
Function pre() now also supports multinomial and multivariate gaussian response variables.
Function pre() now has argument 'tree.unbiased'; if set to FALSE, the CART algorithm (as implemented in package 'rpart') is employed for rule induction.
Argument 'maxdepth' of function pre() allows for specifying varying maximum depth across trees, through specifying a vector of length ntrees, or a random number generating function. See ?maxdepth.sampler for details.
Changes in Version 0.4 (2017-08-31)
Added dataset 'carrillo'
By default, a gradient boosting approach is now taken for all response types. That is, partykit::ctree() and a learning rate of .01 are employed by default. Alternatively, glmtree() can be employed for tree induction by specifying use.grad = FALSE.
The 'family' argument in pre() now takes character strings as well as glm family objects.
Functions pairplot() and interact() now use HCL instead of highly saturated HSV colors as default plotting colors.
Bug fixed in plot.pre: Node directions are now in accordance with rule definition.
Bug fixed in predict.pre: No error printed when response variable is not supplied.
Changes in Version 0.3 (2017-08-03):
Function gpe() added, which fits general prediction ensembles. By default, it fits an ensemble of rules, linear and hinge functions. Function gpe() allows for specifying custom baselearner generating functions and a custom fitting function for the final model.
Numerous bugs fixed, yielding faster computation times and clearer plots with more customization options.
Added support for count responses. Function pre() now has a 'family' argument, which should be set to 'poisson' for count outcomes (the 'family' argument is set automatically to 'gaussian' for numeric response variables and to 'binomial' for binary response variables (factors)).
A gradient boosting approach for binary outcomes is applied, by default, substantially reducing computation times. This can be turned off through the 'use.grad' argument in function pre().
The default of the 'learnrate' argument of function pre() has been changed to .01. Before, it was .01 for continuous outcomes, but 0 for binary outcomes, to reduce computation time. With gradient boosting implemented, computation time is much reduced.
Argument 'tree.control' in function pre() allows for passing arguments to partykit tree fitting functions.
Arguments for the cv.glmnet() function are directly passed through better use of ... . Most importantly, this means that argument 'mod.sel.crit' cannot be used anymore and should be referred to as 'type.measure' (which will be directly passed to cv.glmnet). Similarly, 'thres' and 'standardize' are not explicit arguments of function pre() anymore and can now be directly passed to cv.glmnet() using ... .
Better use of sample weights: weights specified with the 'weights' argument in pre() are now used as weights in the subsampling procedure, instead of as observation weights in the tree-fitting procedure.
Added corplot() function, which shows the correlation between the baselearners in the ensemble.
Function pairplot() returns a heatmap by default, a 3D or contour plot can also be requested.
Appearance of plot resulting from interact() improved.
Changes in Version 0.2 (2017-04-25):
Added print() and plot() method for objects of class pre.
Added support for using functions like factor() and log() in formula statement of function pre(). (thanks to Bill Venables for suggesting this)
Added support for parallel computing in functions pre(), cvpre(), bsnullinteract() and interact().
Winsorizing points used for the linear terms are reported in the description of the base learners returned by coef() and importance(). (Thanks to Rishi Sadhir for suggesting this)
Added README file.
Legend included in plot for interaction test statistics.
Fixed importance() function to allow for selecting final ensemble with different value than 'lambda.1se'.
Cleaned up all occurrences of set.seed()
Fixed cvpre() function: penalty.par.val argument now included
Many minor bug fixes.
Changes in Version 0.1 (2016-12-23):