Extracts metrics from a fitted table

R/ml_metrics.R

ml_metrics_binary

Description

The function works best when passed a tbl_spark created by ml_predict(). The table returned by ml_predict() contains columns with the variable types and format that the corresponding Spark model “evaluator” expects.

Usage

 
ml_metrics_binary( 
  x, 
  truth = label, 
  estimate = rawPrediction, 
  metrics = c("roc_auc", "pr_auc"), 
  ... 
) 

Arguments

Arguments Description
x A tbl_spark containing the estimate (prediction) and the truth (value of what actually happened)
truth The name of the integer column in x that contains the binary response (0 or 1). The ml_predict() function creates a new field named label, which has the expected type and values. Defaults to label.
estimate The name of the column from x that contains the prediction. Defaults to rawPrediction, since its type and expected values will match truth.
metrics A character vector with the metrics to calculate. For binary models the possible values are: roc_auc (area under the receiver operating characteristic curve) and pr_auc (area under the precision-recall curve). Defaults to: roc_auc, pr_auc
... Optional arguments; currently unused.
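
To make the roc_auc value concrete, the following base R sketch computes the area under the ROC curve with the rank-sum (Mann-Whitney) identity. It is an illustration of what the metric measures, not how Spark computes it: the helper name roc_auc_by_rank is hypothetical, and Spark's evaluator runs on distributed data and may return slightly different values.

```r
# Hypothetical helper, not part of sparklyr: ROC AUC via the rank-sum identity.
# truth: numeric vector of 0/1 labels
# score: numeric vector, where a higher value means "more likely to be 1"
roc_auc_by_rank <- function(truth, score) {
  stopifnot(length(truth) == length(score), all(truth %in% c(0, 1)))
  n_pos <- sum(truth == 1)
  n_neg <- sum(truth == 0)
  stopifnot(n_pos > 0, n_neg > 0)
  ranks <- rank(score)  # tied scores receive average ranks
  (sum(ranks[truth == 1]) - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg)
}

truth <- c(0, 0, 1, 1)
score <- c(0.1, 0.4, 0.35, 0.8)
auc <- roc_auc_by_rank(truth, score)
auc
```

In this toy input, three of the four positive/negative pairs are ranked in the right order, so the result is 0.75.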

Details

The ml_metrics family of functions wraps Spark’s evaluators so that they behave more like the functions in the yardstick package. The functions expect a table containing the truth and the estimate, and return a tibble with the results. The tibble has the same format and variable names as the output of the yardstick functions.
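
For comparison, this is roughly what the yardstick side looks like on a local data frame. The sketch assumes the yardstick package is installed and uses its bundled two_class_example data; the object names binary_metrics and local_metrics are chosen here for illustration. Note that yardstick takes a factor truth column and a class-probability column, whereas ml_metrics_binary() takes the integer label and rawPrediction columns described above.

```r
library(yardstick)

# two_class_example ships with yardstick: `truth` is a factor and
# `Class1` holds the predicted probability of its first level.
binary_metrics <- metric_set(roc_auc, pr_auc)
local_metrics <- binary_metrics(two_class_example, truth = truth, Class1)

# One row per metric, with columns .metric, .estimator and .estimate.
local_metrics
```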

Examples

library(sparklyr)
library(dplyr) # provides mutate()
 
sc <- spark_connect("local") 
tbl_iris <- copy_to(sc, iris) 
prep_iris <- tbl_iris %>% 
  mutate(is_setosa = ifelse(Species == "setosa", 1, 0)) 
iris_split <- sdf_random_split(prep_iris, training = 0.5, test = 0.5) 
model <- ml_logistic_regression(iris_split$training, "is_setosa ~ Sepal_Length") 
tbl_predictions <- ml_predict(model, iris_split$test) 
ml_metrics_binary(tbl_predictions) 
#> # A tibble: 2 × 3
#>   .metric .estimator .estimate
#>   <chr>   <chr>          <dbl>
#> 1 roc_auc binary         0.988
#> 2 pr_auc  binary         0.978
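
The call above relies on the defaults. As a minimal sketch of passing the arguments explicitly, the block below requests a single metric. It assumes a local Spark installation is available, uses mtcars purely as convenient sample data, and scores the training data only to stay short; the object names are illustrative.

```r
library(sparklyr)

# Assumes a local Spark installation (for example via spark_install()).
sc <- spark_connect(master = "local")
tbl_mtcars <- copy_to(sc, mtcars, overwrite = TRUE)

# Predict transmission type (am is already coded 0/1) from fuel economy.
model <- ml_logistic_regression(tbl_mtcars, am ~ mpg)

# Scored on the training data only to keep the sketch short.
tbl_predictions <- ml_predict(model, tbl_mtcars)

# Name the truth and estimate columns explicitly and ask for one metric.
pr_only <- ml_metrics_binary(
  tbl_predictions,
  truth = label,
  estimate = rawPrediction,
  metrics = "pr_auc"
)
pr_only

spark_disconnect(sc)
```

The result should be a one-row tibble with the same three columns as the two-metric output shown earlier.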