Tidy, Type-Safe 'prediction()' Methods

A one-function package containing 'prediction()', a type-safe alternative to 'predict()' that always returns a data frame. The 'summary()' method provides a data frame with average predictions, possibly over counterfactual versions of the data (à la the 'margins' command in 'Stata'). Marginal effect estimation is provided by the related package, 'margins' <https://cran.r-project.org/package=margins>. The package currently supports common model types (e.g., "lm", "glm") from the 'stats' package, as well as numerous other model classes from other add-on packages. See the README or main package documentation page for a complete listing.


The prediction and margins packages are a combined effort to port the functionality of Stata's (closed source) margins command to (open source) R. prediction is focused on one function, prediction(), which provides type-safe methods for generating predictions from fitted regression models. prediction() is an S3 generic that always returns a "data.frame" class object rather than the mix of vectors, lists, etc. returned by the predict() methods for various model types. It provides a key piece of underlying infrastructure for the margins package. Users interested in generating marginal (partial) effects, like those generated by Stata's margins, dydx(*) command, should consider using margins() from the sibling project, margins.
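The type-safety idea can be illustrated with a small base-R sketch of an S3 generic. tidy_prediction() and its methods are hypothetical names written only for this illustration (they are not part of the package), and they cover just the core idea: whatever predict() returns is reshaped into a data frame, optionally bound to the data used for prediction.

```r
# Hypothetical sketch of a type-safe prediction generic (not the package's code)
library("stats")
library("datasets")

# S3 generic: dispatch on the class of the fitted model object
tidy_prediction <- function(model, data = NULL, ...) {
  UseMethod("tidy_prediction")
}

# Fallback method: coerce whatever predict() returns into a single
# "fitted" column (assumes one prediction per observation)
tidy_prediction.default <- function(model, data = NULL, ...) {
  fit <- if (is.null(data)) {
    predict(model, ...)
  } else {
    predict(model, newdata = data, ...)
  }
  out <- data.frame(fitted = as.vector(fit))
  if (!is.null(data)) out <- cbind(data, out)
  out
}

# Method for "lm" (and, through class inheritance, "glm"):
# also keep the standard errors of the fitted values
tidy_prediction.lm <- function(model, data = NULL, ...) {
  p <- if (is.null(data)) {
    predict(model, se.fit = TRUE, ...)
  } else {
    predict(model, newdata = data, se.fit = TRUE, ...)
  }
  out <- data.frame(fitted = unname(p$fit), se.fitted = unname(p$se.fit))
  if (!is.null(data)) out <- cbind(data, out)
  out
}

m <- lm(mpg ~ cyl * hp + wt, data = mtcars)
class(tidy_prediction(m))                 # "data.frame"
class(tidy_prediction(m, data = mtcars))  # "data.frame"
```

Whichever method is dispatched, the caller always receives a data frame, which is the property the package guarantees across its supported model classes.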

In addition to prediction(), this package provides a number of utility functions for generating useful predictions:

  • find_data(), an S3 generic with methods that find the data frame used to estimate a regression model. This is a wrapper around get_all_vars() that attempts to locate data as well as modify it according to subset and na.action arguments used in the original modelling call.
  • mean_or_mode() and median_or_mode(), which provide a convenient way to compute the data needed for predicted values at means (or at medians), respecting the differences between factor and numeric variables.
  • seq_range(), which generates a vector of n evenly spaced values spanning the range of a variable.
  • build_datalist(), which generates a list of data frames from an input data frame and a specified set of replacement at values (mimicking the atlist option of Stata's margins command).
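The following base-R sketches show, under simplifying assumptions, the kind of computation these helpers perform. The names mean_or_mode_sketch(), seq_range_sketch(), and build_datalist_sketch() are hypothetical; the package's own functions handle more input types and options (for example, build_datalist()'s as.data.frame argument), so treat these as illustrations only.

```r
# Hypothetical base-R sketches; the "_sketch" suffix marks them as
# illustrations, not the package's implementations.
library("datasets")

# Typical value: mean for numeric input, most frequent level for factors
# (only numeric and factor input are handled here)
mean_or_mode_sketch <- function(x) {
  if (is.factor(x)) {
    counts <- table(x)
    factor(names(counts)[which.max(counts)], levels = levels(x))
  } else {
    mean(x, na.rm = TRUE)
  }
}

# n evenly spaced values from the minimum to the maximum of x
seq_range_sketch <- function(x, n = 2) {
  seq(min(x, na.rm = TRUE), max(x, na.rm = TRUE), length.out = n)
}

# One copy of `data` per combination of `at` values, with those
# variables overwritten in every row (a counterfactual dataset each time)
build_datalist_sketch <- function(data, at) {
  combos <- expand.grid(at, KEEP.OUT.ATTRS = FALSE, stringsAsFactors = FALSE)
  lapply(seq_len(nrow(combos)), function(i) {
    d <- data
    for (v in names(combos)) d[[v]] <- combos[[v]][i]
    d
  })
}

seq_range_sketch(mtcars$hp, 5)  # returns c(52, 122.75, 193.5, 264.25, 335)
dl <- build_datalist_sketch(mtcars, list(hp = c(100, 200), am = 0:1))
length(dl)                      # 4 data frames, one per hp/am combination
```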

Simple code examples

A major downside of the predict() methods for common modelling classes is that the result is not type-safe. Consider the following simple example:

library("stats")
library("datasets")
x <- lm(mpg ~ cyl * hp + wt, data = mtcars)
class(predict(x))
## [1] "numeric"
class(predict(x, se.fit = TRUE))
## [1] "list"

prediction solves this issue by providing a wrapper around predict(), called prediction(), that always returns a tidy data frame with a very simple print() method:

library("prediction")
(p <- prediction(x))
## Average prediction for 32 observations: 20.0906
class(p)
## [1] "prediction" "data.frame"
head(p)
##    mpg cyl disp  hp drat    wt  qsec vs am gear carb   fitted se.fitted
## 1 21.0   6  160 110 3.90 2.620 16.46  0  1    4    4 21.90488 0.6927034
## 2 21.0   6  160 110 3.90 2.875 17.02  0  1    4    4 21.10933 0.6266557
## 3 22.8   4  108  93 3.85 2.320 18.61  1  1    4    1 25.64753 0.6652076
## 4 21.4   6  258 110 3.08 3.215 19.44  1  0    3    1 20.04859 0.6041400
## 5 18.7   8  360 175 3.15 3.440 17.02  0  0    3    2 17.25445 0.7436172
## 6 18.1   6  225 105 2.76 3.460 20.22  1  0    3    1 19.53360 0.6436862

The output always contains the original data (i.e., either data found using the find_data() function or passed to the data argument to prediction()). This makes it much simpler to pass predictions to, e.g., further summary or plotting functions.
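A simplified version of this data-recovery step can be sketched in base R. find_model_data() is a hypothetical helper, not the package's find_data(); it assumes the model was fit with a data argument, that the model object stores its call and an na.action component the way stats::lm() does, and that the data object named in the call is still visible from the calling environment.

```r
# Hypothetical sketch of recovering a model's estimation data,
# in the spirit of find_data() (not the package's implementation)
library("stats")
library("datasets")

find_model_data <- function(model, env = parent.frame()) {
  # Re-evaluate the `data` argument recorded in the original call
  dat <- eval(model$call$data, env)
  # Re-apply any `subset` argument, evaluated within the data
  if (!is.null(model$call$subset)) {
    keep <- eval(model$call$subset, dat, env)
    dat <- dat[keep, , drop = FALSE]
  }
  # Drop the rows that na.action removed during fitting
  if (!is.null(model$na.action)) {
    dat <- dat[-as.integer(model$na.action), , drop = FALSE]
  }
  dat
}

d <- mtcars
d$wt[1] <- NA                                 # one missing value
m <- lm(mpg ~ wt, data = d, subset = cyl > 4)
nrow(find_model_data(m))                      # 20, matching nobs(m)
```

Because the data are recovered by re-evaluating the original call, this approach breaks if the data object has been modified or removed since the model was fit.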

Additionally, the vast majority of methods accept an at argument, which can be used to obtain predicted values from a modified version of the data in which specified variables are held at particular values:

prediction(x, at = list(hp = seq_range(mtcars$hp, 5)))
## Average predictions for 32 observations:
##  at(hp)  value
##    52.0 22.605
##   122.8 19.328
##   193.5 16.051
##   264.2 12.774
##   335.0  9.497

This serves, more or less, as a direct R port of the subset of Stata's margins functionality that calculates predictive margins (average predictions). For calculation of marginal or partial effects, see the margins package.
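The idea behind the at argument, predicting on counterfactual copies of the data and then averaging, can be sketched in base R. average_predictions_at() is a hypothetical helper written for illustration; it assumes the model's predict() method accepts newdata and returns one numeric prediction per row, handles one numeric variable at a time, and omits the standard errors and factor handling that prediction() provides.

```r
# Hypothetical sketch of average predictions over counterfactual data
library("stats")
library("datasets")

average_predictions_at <- function(model, data, var, values) {
  avg <- vapply(values, function(v) {
    counterfactual <- data
    counterfactual[[var]] <- v   # hold `var` at v for every observation
    mean(predict(model, newdata = counterfactual), na.rm = TRUE)
  }, numeric(1))
  data.frame(at = values, value = avg)
}

x <- lm(mpg ~ cyl * hp + wt, data = mtcars)
hp_values <- seq(min(mtcars$hp), max(mtcars$hp), length.out = 5)
average_predictions_at(x, mtcars, "hp", hp_values)
```

Since this model is linear in hp, the averaged predictions change by a constant amount between evenly spaced hp values, which is the pattern visible in the table above.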

Supported model classes

The currently supported model classes are:

  • "lm" from stats::lm()
  • "glm" from stats::glm(), MASS::glm.nb(), glmx::glmx(), glmx::hetglm(), brglm::brglm()
  • "ar" from stats::ar()
  • "Arima" from stats::arima()
  • "arima0" from stats::arima0()
  • "biglm" from biglm::biglm() (including "ffdf" backed models)
  • "bigLm" from bigFastlm::bigLm()
  • "betareg" from betareg::betareg()
  • "bruto" from mda::bruto()
  • "clm" from ordinal::clm()
  • "coxph" from survival::coxph()
  • "crch" from crch::crch()
  • "earth" from earth::earth()
  • "fda" from mda::fda()
  • "Gam" from gam::gam()
  • "gausspr" from kernlab::gausspr()
  • "gee" from gee::gee()
  • "glimML" from aod::betabin(), aod::negbin()
  • "glimQL" from aod::quasibin(), aod::quasipois()
  • "glmnet" from glmnet::glmnet()
  • "gls" from nlme::gls()
  • "hurdle" from pscl::hurdle()
  • "hxlr" from crch::hxlr()
  • "ivreg" from AER::ivreg()
  • "knnreg" from caret::knnreg()
  • "kqr" from kernlab::kqr()
  • "ksvm" from kernlab::ksvm()
  • "lda" from MASS::lda()
  • "lme" from nlme::lme()
  • "loess" from stats::loess()
  • "lqs" from MASS::lqs()
  • "mars" from mda::mars()
  • "mca" from MASS::mca()
  • "mclogit" from mclogit::mclogit()
  • "mda" from mda::mda()
  • "merMod" from lme4::lmer() and lme4::glmer()
  • "mnlogit" from mnlogit::mnlogit()
  • "mnp" from MNP::mnp()
  • "naiveBayes" from e1071::naiveBayes()
  • "nlme" from nlme::nlme()
  • "nls" from stats::nls()
  • "nnet" from nnet::nnet(), nnet::multinom()
  • "plm" from plm::plm()
  • "polr" from MASS::polr()
  • "ppr" from stats::ppr()
  • "princomp" from stats::princomp()
  • "qda" from MASS::qda()
  • "rlm" from MASS::rlm()
  • "rpart" from rpart::rpart()
  • "rq" from quantreg::rq()
  • "selection" from sampleSelection::selection()
  • "speedglm" from speedglm::speedglm()
  • "speedlm" from speedglm::speedlm()
  • "survreg" from survival::survreg()
  • "svm" from e1071::svm()
  • "svyglm" from survey::svyglm()
  • "tobit" from AER::tobit()
  • "train" from caret::train()
  • "truncreg" from truncreg::truncreg()
  • "zeroinfl" from pscl::zeroinfl()

Requirements and Installation

Project status: Active. The project has reached a stable, usable state and is being actively developed.

The development version of this package can be installed directly from GitHub using remotes:

if (!require("remotes")) {
    install.packages("remotes")
}
remotes::install_github("leeper/prediction")

News

prediction 0.3.6

  • Small fixes for failing CRAN checks. (#25)
  • Remove prediction.bigglm() method (from biglm) due to failing tests. (#25)

prediction 0.3.5

  • Fixed a bug that required specifying stats::poly() rather than just poly() in model formulae. (#22)

prediction 0.3.4

  • Added prediction.glmnet() method for "glmnet" objects from glmnet. (#1)

prediction 0.3.3

  • prediction.merMod() gains an re.form argument to pass forward to predict.merMod().

prediction 0.3.2

  • Fixed a typo in the "speedglm" method that was overwriting the "glm" method.

prediction 0.3.0

  • CRAN release.

prediction 0.2.11

  • Added prediction.glimML() method for "glimML" objects from aod. (#1)
  • Added prediction.glimQL() method for "glimQL" objects from aod. (#1)
  • Added prediction.truncreg() method for "truncreg" objects from truncreg. (#1)
  • Noted implicit support for "tobit" objects from AER. (#1)

prediction 0.2.10

  • Added prediction.bruto() method for "bruto" objects from mda. (#1)
  • Added prediction.fda() method for "fda" objects from mda. (#1)
  • Added prediction.mars() method for "mars" objects from mda. (#1)
  • Added prediction.mda() method for "mda" objects from mda. (#1)
  • Added prediction.polyreg() method for "polyreg" objects from mda. (#1)

prediction 0.2.9

  • Added prediction.speedglm() and prediction.speedlm() methods for "speedglm" and "speedlm" objects from speedglm. (#1)
  • Added prediction.bigLm() method for "bigLm" objects from bigFastlm. (#1)
  • Added prediction.biglm() and prediction.bigglm() methods for "biglm" and "bigglm" objects from biglm, including those backed by "ffdf" from ff. (#1)

prediction 0.2.8

  • Changed internal behavior of build_datalist(). The function now returns an at_specification attribute, which is a data frame representation of the at argument.

prediction 0.2.7

  • Due to a change in gam_1.15, prediction.gam() is now prediction.Gam() for "Gam" objects from gam. (#1)

prediction 0.2.6

  • Added prediction.train() method for "train" objects from caret. (#1)

prediction 0.2.5

  • The at argument in build_datalist() now accepts a data frame of combinations for limiting the set of levels.

prediction 0.2.4

  • Most prediction() methods gain an (experimental) calculate_se argument, which regulates whether to calculate standard errors for predictions. Setting it to FALSE can improve performance when standard errors are not needed.

prediction 0.2.3

  • build_datalist() gains an as.data.frame argument which, if TRUE, returns a stacked data frame rather than a list. This argument is now used internally in most prediction() functions to improve performance. (#18)

prediction 0.2.2

  • Expanded test suite scope and fixed a few small bugs.
  • Added a summary.prediction() method to interact with the average predicted values that are printed when at is not NULL.

prediction 0.2.1

  • Added prediction.knnreg() method for "knnreg" objects from caret. (#1)
  • Added prediction.gausspr() method for "gausspr" objects from kernlab. (#1)
  • Added prediction.ksvm() method for "ksvm" objects from kernlab. (#1)
  • Added prediction.kqr() method for "kqr" objects from kernlab. (#1)
  • Added prediction.earth() method for "earth" objects from earth. (#1)
  • Added prediction.rpart() method for "rpart" objects from rpart. (#1)

prediction 0.2.0

  • CRAN Release.
  • Added mean_or_mode.data.frame() and median_or_mode.data.frame() methods.

prediction 0.1.17

  • Added prediction.zeroinfl() method for "zeroinfl" objects from pscl. (#1)
  • Added prediction.hurdle() method for "hurdle" objects from pscl. (#1)
  • Added prediction.lme() method for "lme" and "nlme" objects from nlme. (#1)
  • Documented prediction.merMod().

prediction 0.1.16

  • Added prediction.plm() method for "plm" objects from plm. (#1)

prediction 0.1.15

  • Expanded test suite considerably and updated CONTRIBUTING.md to reflect expected test-driven development.
  • A few small code tweaks and bug fixes resulting from the updated test suite.

prediction 0.1.14

  • Added prediction.mnp() method for "mnp" objects from MNP. (#1)
  • Added prediction.mnlogit() method for "mnlogit" objects from mnlogit. (#1)
  • Added prediction.gee() method for "gee" objects from gee. (#1)
  • Added prediction.lqs() method for "lqs" objects from MASS. (#1)
  • Added prediction.mca() method for "mca" objects from MASS. (#1)
  • Noted (built-in) support for "brglm" objects from brglm via the prediction.glm() method. (#1)

prediction 0.1.13

  • Added a category argument to prediction() methods for models of multilevel outcomes (e.g., ordered probit) to dictate which level is expressed in the "fitted" column. (#14)
  • Added an at argument to prediction() methods. (#13)
  • Made mean_or_mode() and median_or_mode() S3 generics.
  • Fixed a bug in mean_or_mode() and median_or_mode() where incorrect factor levels were being returned.

prediction 0.1.12

  • Added prediction.princomp() method for "princomp" objects from stats. (#1)
  • Added prediction.ppr() method for "ppr" objects from stats. (#1)
  • Added prediction.naiveBayes() method for "naiveBayes" objects from e1071. (#1)
  • Added prediction.rlm() method for "rlm" objects from MASS. (#1)
  • Added prediction.qda() method for "qda" objects from MASS. (#1)
  • Added prediction.lda() method for "lda" objects from MASS. (#1)
  • find_data() now respects the subset argument in an original model call. (#15)
  • find_data() now respects the na.action argument in an original model call. (#15)
  • find_data() now gracefully fails when a model is specified without a formula. (#16)
  • prediction() methods no longer add a "fit" or "se.fit" class to any columns. Fitted values are identifiable by the column name only.

prediction 0.1.11

  • build_datalist() now returns at value combinations as a list.

prediction 0.1.10

  • Added prediction.nnet() method for "nnet" and "multinom" objects from nnet. (#1)

prediction 0.1.9

  • prediction() methods now return the value of data as part of the response data frame. (#8, h/t Ben Whalley)
  • Slight change to find_data() methods for "crch" and "hxlr". (#5)
  • Added prediction.glmx() and prediction.hetglm() methods for "glmx" and "hetglm" objects from glmx. (#1)
  • Added prediction.betareg() method for "betareg" objects from betareg. (#1)
  • Added prediction.rq() method for "rq" objects from quantreg. (#1)
  • Added prediction.gam() method for "gam" objects from gam. (#1)
  • Expanded basic test suite.

prediction 0.1.8

  • Added prediction() and find_data() methods for "crch" and "hxlr" objects from crch. (#4, h/t Carl Ganz)

prediction 0.1.7

  • Added prediction() and find_data() methods for "merMod" objects from lme4. (#1)

prediction 0.1.6

  • Moved the seq_range() function from margins to prediction.
  • Moved the build_datalist() function from margins to prediction. This simplifies the calculation of arbitrary predictions.

prediction 0.1.5

  • Added prediction.svm() method for objects of class "svm" from e1071. (#1)
  • Fixed a bug in prediction.polr() when attempting to pass a type argument, which is always ignored. A warning is now issued when attempting to override this.

prediction 0.1.4

  • Added mean_or_mode() and median_or_mode() functions, which provide a simple way to aggregate a variable of factor or numeric type. (#3)
  • Added prediction() methods for various time-series model classes: "ar", "arima0", and "Arima".

prediction 0.1.3

  • find_data() is now a generic, with methods for "lm", "glm", and "svyglm" classes. (#2, h/t Carl Ganz)

prediction 0.1.2

  • Added support for "svyglm" class from the survey package. (#1)
  • Added tentative support for "clm" class from the ordinal package. (#1)

prediction 0.1.0

  • Initial package release.


install.packages("prediction")

0.3.14 by Thomas J. Leeper


https://github.com/leeper/prediction


Report a bug at https://github.com/leeper/prediction/issues


Browse source code at https://github.com/cran/prediction


Authors: Thomas J. Leeper [aut, cre], Carl Ganz [ctb], Vincent Arel-Bundock [ctb]




MIT + file LICENSE license


Imports utils, stats, data.table

Suggests datasets, methods, testthat

Enhances AER, aod, betareg, biglm, brglm, caret, crch, e1071, earth, ff, ffbase, gam, gee, glmnet, glmx, kernlab, lme4, MASS, mclogit, mda, mlogit, MNP, nlme, nnet, ordinal, plm, pscl, quantreg, rpart, sampleSelection, speedglm, survey, survival, truncreg, VGAM


Imported by DynNom, iml, margins, predict3d, processR.

Suggested by estimatr, ggeffects.

