Tidy tools for quantifying how well a model fits a data set, such as confusion matrices, class probability curve summaries, and regression metrics (e.g., RMSE).
yardstick is a package to estimate how well models are working using tidy data principles. See the package webpage for more information.

To install the package:
install.packages("yardstick")# Development version:devtools::install_github("tidymodels/yardstick")
For example, suppose you create a classification model and predict on a new data set. You might have data that looks like this:
library(yardstick)
library(dplyr)

head(two_class_example)
#>    truth  Class1   Class2 predicted
#> 1 Class2 0.00359 0.996411    Class2
#> 2 Class1 0.67862 0.321379    Class1
#> 3 Class2 0.11089 0.889106    Class2
#> 4 Class1 0.73516 0.264838    Class1
#> 5 Class2 0.01624 0.983760    Class2
#> 6 Class1 0.99928 0.000725    Class1
You can use a dplyr-like syntax to compute common performance characteristics of the model and get them back in a data frame:
metrics(two_class_example, truth, predicted)
#> # A tibble: 2 x 3
#>   .metric  .estimator .estimate
#>   <chr>    <chr>          <dbl>
#> 1 accuracy binary         0.838
#> 2 kap      binary         0.675

# or
two_class_example %>%
  roc_auc(truth, Class1)
#> # A tibble: 1 x 3
#>   .metric .estimator .estimate
#>   <chr>   <chr>          <dbl>
#> 1 roc_auc binary         0.939
All classification metrics have at least one multiclass extension, with many of them having multiple ways to calculate multiclass metrics.
data("hpc_cv")hpc_cv <- as_tibble(hpc_cv)hpc_cv#> # A tibble: 3,467 x 7#> obs pred VF F M L Resample#> <fct> <fct> <dbl> <dbl> <dbl> <dbl> <chr>#> 1 VF VF 0.914 0.0779 0.00848 0.0000199 Fold01#> 2 VF VF 0.938 0.0571 0.00482 0.0000101 Fold01#> 3 VF VF 0.947 0.0495 0.00316 0.00000500 Fold01#> 4 VF VF 0.929 0.0653 0.00579 0.0000156 Fold01#> 5 VF VF 0.942 0.0543 0.00381 0.00000729 Fold01#> 6 VF VF 0.951 0.0462 0.00272 0.00000384 Fold01#> 7 VF VF 0.914 0.0782 0.00767 0.0000354 Fold01#> 8 VF VF 0.918 0.0744 0.00726 0.0000157 Fold01#> 9 VF VF 0.843 0.128 0.0296 0.000192 Fold01#> 10 VF VF 0.920 0.0728 0.00703 0.0000147 Fold01#> # … with 3,457 more rows
# Macro averaged multiclass precision
precision(hpc_cv, obs, pred)
#> # A tibble: 1 x 3
#>   .metric   .estimator .estimate
#>   <chr>     <chr>          <dbl>
#> 1 precision macro          0.631

# Micro averaged multiclass precision
precision(hpc_cv, obs, pred, estimator = "micro")
#> # A tibble: 1 x 3
#>   .metric   .estimator .estimate
#>   <chr>     <chr>          <dbl>
#> 1 precision micro          0.709
If you have multiple resamples of a model, you can use a metric on a grouped data frame to calculate the metric across all resamples at once.
This calculates multiclass ROC AUC using the method described in Hand and Till (2001), and does it across all 10 resamples at once.
hpc_cv %>%
  group_by(Resample) %>%
  roc_auc(obs, VF:L)
#> # A tibble: 10 x 4
#>    Resample .metric .estimator .estimate
#>    <chr>    <chr>   <chr>          <dbl>
#>  1 Fold01   roc_auc hand_till      0.831
#>  2 Fold02   roc_auc hand_till      0.817
#>  3 Fold03   roc_auc hand_till      0.869
#>  4 Fold04   roc_auc hand_till      0.849
#>  5 Fold05   roc_auc hand_till      0.811
#>  6 Fold06   roc_auc hand_till      0.836
#>  7 Fold07   roc_auc hand_till      0.825
#>  8 Fold08   roc_auc hand_till      0.846
#>  9 Fold09   roc_auc hand_till      0.836
#> 10 Fold10   roc_auc hand_till      0.820
Curve-based methods such as roc_curve(), pr_curve(), and gain_curve() all have ggplot2::autoplot() methods that allow for powerful and easy visualization.
library(ggplot2)

hpc_cv %>%
  group_by(Resample) %>%
  roc_curve(obs, VF:L) %>%
  autoplot()
Quasiquotation can also be used to supply inputs.
# probability columns:
lvl <- levels(two_class_example$truth)

two_class_example %>%
  mn_log_loss(truth, !! lvl[1])
#> # A tibble: 1 x 3
#>   .metric     .estimator .estimate
#>   <chr>       <chr>          <dbl>
#> 1 mn_log_loss binary         0.328
mase() is a numeric metric for the mean absolute scaled error. It is generally useful when forecasting with time series (@alexhallam, #68).
huber_loss() is a numeric metric that is less sensitive to outliers than rmse(), but is more sensitive than mae() for small errors (@blairj09, #71).
huber_loss_pseudo() is a smoothed form of huber_loss() (@blairj09, #71).
smape() is a numeric metric that is based on percentage errors (@riazhedayati, #67). A combined sketch of these new numeric metrics follows below.
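To make these additions concrete, here is a rough sketch applying the new numeric metrics to a small, made-up set of predictions; the data frame and its column names are invented for illustration and the default arguments (e.g., mase()'s seasonal period) are assumptions:

library(yardstick)

# Hypothetical regression results, ordered in time for mase()
pred_df <- data.frame(
  truth    = c(10.2, 11.5, 12.3, 14.8, 15.1, 16.7),
  estimate = c(10.0, 12.0, 12.1, 14.0, 15.9, 16.2)
)

smape(pred_df, truth, estimate)              # symmetric mean absolute percent error
huber_loss(pred_df, truth, estimate)         # Huber loss
huber_loss_pseudo(pred_df, truth, estimate)  # smoothed (pseudo) Huber loss
mase(pred_df, truth, estimate)               # mean absolute scaled error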
conf_mat objects now have two ggplot2::autoplot() methods for easy visualization of the confusion matrix as either a heat map or a mosaic plot (@EmilHvitfeldt, #10).
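A minimal sketch of the two displays, assuming the autoplot() method takes a type argument with values "heatmap" and "mosaic":

library(yardstick)
library(ggplot2)

cm <- conf_mat(two_class_example, truth, predicted)

# Confusion matrix as a heat map or a mosaic plot
autoplot(cm, type = "heatmap")
autoplot(cm, type = "mosaic")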
metric_set() now returns a classed function. If numeric metrics are used, a "numeric_metric_set" function is returned. If class or probability metrics are used, a "class_prob_metric_set" function is returned.

Tests related to the fixed sample() function in R 3.6 have been updated.
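A short sketch of inspecting the classes that metric_set() now attaches; the metric choices here are arbitrary:

library(yardstick)

# Numeric metrics produce a "numeric_metric_set" function
reg_metrics <- metric_set(rmse, mae)
class(reg_metrics)

# Class/probability metrics produce a "class_prob_metric_set" function
cls_metrics <- metric_set(accuracy, kap, roc_auc)
class(cls_metrics)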
f_meas() propagates NA values from precision() and recall() correctly (#77).
All "micro"
estimators now propagate NA
values through correctly.
roc_auc(estimator = "hand_till") now correctly computes the metric when the column names of the probability matrix do not exactly match the levels of truth. Note that the computation still assumes that the order of the supplied probability matrix columns matches the order of levels(truth), like other multiclass metrics (#86).
These breaking changes were driven by a desire to standardize the yardstick API. The output of each metric is now in line with tidy principles, returning a tibble rather than a single numeric value. Additionally, all metrics now have a standard argument list, so you should be able to switch between metrics and combine them together effortlessly.
All metrics now return a tibble rather than a single numeric value. This format allows metrics to work with grouped data frames (for resamples). It also allows you to bundle multiple metrics together with a new function, metric_set().
For all class probability metrics, only one column can now be passed to ... when a binary implementation is used. These metrics no longer select only the first column when multiple columns are supplied; instead, they throw an error.
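For example, with a binary outcome like the one in two_class_example, only the probability column for the event class should be supplied; the second call below is a sketch of the new error behavior:

library(yardstick)

# Binary case: pass exactly one probability column (the event class)
roc_auc(two_class_example, truth, Class1)

# Passing both probability columns for a binary implementation now errors
# (previously only the first column would have been silently used)
try(roc_auc(two_class_example, truth, Class1, Class2))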
The summary() method for conf_mat objects now returns a tibble, to be consistent with the change to the metric functions.
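A brief sketch, reusing the two_class_example data from above:

library(yardstick)

cm <- conf_mat(two_class_example, truth, predicted)

# summary() now returns a tibble of metrics computed from the confusion matrix
summary(cm)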
For naming consistency, mnLogLoss() was renamed to mn_log_loss().

mn_log_loss() now returns the negative log loss for the multinomial distribution.
The argument na.rm has been changed to na_rm in all metrics to align with the tidymodels model implementation principles.
Each metric now has a vector interface to go alongside the data frame interface. All vector functions end in _vec(). The vector interface accepts vector/matrix inputs and returns a single numeric value.
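A small sketch of the two interfaces side by side; the numeric vectors in the last call are made up for illustration:

library(yardstick)

# Data frame interface: returns a one-row tibble
accuracy(two_class_example, truth, predicted)

# Vector interface: returns a single numeric value
accuracy_vec(two_class_example$truth, two_class_example$predicted)

rmse_vec(truth = c(1.2, 3.4, 5.6), estimate = c(1.0, 3.9, 5.1))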
Multiclass support has been added for each classification metric. The support varies from one metric to the next, but generally macro and micro averaging is available for all metrics, with some metrics having specialized multiclass implementations (for example, roc_auc() supports the multiclass generalization presented in a paper by Hand and Till). For more information, see vignette("multiclass", "yardstick").
All metrics now work with grouped data frames. This produces a tibble with as many rows as there are groups, and is useful when used alongside resampling techniques.
mape() calculates the mean absolute percent error.
kap() is a metric similar to accuracy() that calculates Cohen's kappa.
detection_prevalence() calculates the number of predicted positive events relative to the total number of predictions.
bal_accuracy() calculates balanced accuracy as the average of sensitivity and specificity.
roc_curve() calculates receiver operating characteristic (ROC) curves and returns the results as a tibble.
pr_curve() calculates precision-recall curves.
gain_curve() and lift_curve() calculate the information used in gain and lift curves.
gain_capture() is a measure of performance similar in spirit to AUC but applied to a gain curve.
pr_curve(), roc_curve(), gain_curve(), and lift_curve() all have ggplot2::autoplot() methods for easy visualization.
metric_set() constructs functions that calculate multiple metrics at once.
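As a rough usage sketch, a metric set can be applied much like a single metric function; the exact argument arrangement of the generated function and the use of the solubility_test data are assumptions here:

library(yardstick)
library(dplyr)

# Bundle class and probability metrics into one function
class_metrics <- metric_set(accuracy, kap, roc_auc)

two_class_example %>%
  class_metrics(truth = truth, estimate = predicted, Class1)

# The same idea for numeric metrics
reg_metrics <- metric_set(rmse, rsq, mae)
reg_metrics(solubility_test, truth = solubility, estimate = prediction)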
The infrastructure for creating metrics has been exposed to allow users to extend yardstick to work with their own metrics. You might want to do this if you want your metrics to work with grouped data frames out of the box, or if you want the standardization and error checking that yardstick already provides. See vignette("custom-metrics", "yardstick") for a few examples.
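The sketch below follows the general pattern that vignette describes for a hypothetical mean squared error metric; the helper functions metric_vec_template() and metric_summarizer() and their argument names are assumptions based on that vignette and should be verified against it:

library(yardstick)
library(rlang)

# Vector version: does the actual computation
mse_vec <- function(truth, estimate, na_rm = TRUE, ...) {
  mse_impl <- function(truth, estimate) {
    mean((truth - estimate)^2)
  }
  metric_vec_template(
    metric_impl = mse_impl,
    truth = truth,
    estimate = estimate,
    na_rm = na_rm,
    cls = "numeric",
    ...
  )
}

# Data frame version: delegates to the vector version and returns a tibble,
# picking up grouped data frame support and argument checking for free
mse <- function(data, truth, estimate, na_rm = TRUE, ...) {
  metric_summarizer(
    metric_nm = "mse",
    metric_fn = mse_vec,
    data = data,
    truth = !!enquo(truth),
    estimate = !!enquo(estimate),
    na_rm = na_rm,
    ...
  )
}

mse(solubility_test, truth = solubility, estimate = prediction)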
A vignette describing the three classes of metrics used in yardstick has been added. It also includes a list of every metric available, grouped by class. See vignette("metric-types", "yardstick").
The error messages in yardstick should now be much more informative, with better feedback about the types of input that each metric can use and about what kinds of metrics can be used together (i.e., in metric_set()).
There is now a grouped_df method for conf_mat() that returns a tibble with a list column of conf_mat objects.
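For example, grouping by resample before calling conf_mat() should yield one confusion matrix per fold (a sketch, reusing the hpc_cv data from above):

library(yardstick)
library(dplyr)

# Returns a tibble with one row per resample and a list column of conf_mat objects
hpc_cv %>%
  group_by(Resample) %>%
  conf_mat(obs, pred)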
Each metric now has its own help page. This allows us to better document the nuances of each metric without cluttering the help pages of other metrics.
broom has been removed from Depends and is replaced by generics in Suggests.
tidyr and ggplot2 have been moved to Suggests.
MLmetrics has been removed as a dependency.