Extracts metrics from a fitted table




The function works best when passed a tbl_spark created by ml_predict(). That tbl_spark already contains the variable types and format that the corresponding Spark model “evaluator” expects.


ml_metrics_multiclass(
  x,
  truth = label,
  estimate = prediction,
  metrics = c("accuracy"),
  beta = NULL,
  ...
)


Arguments Description
x A tbl_spark containing the estimate (prediction) and the truth (value of what actually happened)
truth The name of the column from x with an integer field containing the indexed value for each outcome. The ml_predict() function creates a new field named label, which contains the expected type and values. truth defaults to label.
estimate The name of the column from x that contains the prediction. Defaults to prediction, since its type and indexed values will match truth.
metrics A character vector with the metrics to calculate. For multiclass models the possible values are: accuracy, f_meas (F-score), recall and precision. This function translates the argument into an acceptable Spark parameter. If no translation is found, the raw value of the argument is passed to Spark. This makes it possible to request a metric that is not listed here but is available in Spark, depending on the Spark version. Other metrics for multiclass models are: weightedTruePositiveRate, weightedFalsePositiveRate, weightedFMeasure, truePositiveRateByLabel, falsePositiveRateByLabel, precisionByLabel, recallByLabel, fMeasureByLabel, logLoss, hammingLoss.
beta Numerical value used for precision and recall. Defaults to NULL, but if the Spark session’s version is 3.0 or above, NULL is changed to 1 unless a different value is supplied in this argument.
... Optional arguments; currently unused.
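
The translate-or-pass-through behavior described for metrics can be sketched in base R. This is only an illustration, not the package's internal code: metric_lookup and translate_metric are hypothetical names, and the Spark-side metric names in the lookup table are assumptions.

```r
# Hypothetical lookup table: yardstick-style name -> Spark metric name.
# The Spark-side names are assumptions made for this illustration.
metric_lookup <- c(
  accuracy  = "accuracy",
  recall    = "weightedRecall",
  precision = "weightedPrecision"
)

# Translate a requested metric; if no translation is found,
# return the raw value so it can be passed to Spark unchanged.
translate_metric <- function(metric) {
  if (metric %in% names(metric_lookup)) {
    unname(metric_lookup[metric])
  } else {
    metric
  }
}

# "recall" is translated, "logLoss" has no entry and is passed through as-is
translated <- vapply(
  c("recall", "logLoss"),
  translate_metric,
  character(1),
  USE.NAMES = FALSE
)
print(translated)
```

A fallback of this kind is what allows a value such as logLoss to reach Spark even though it has no yardstick-style alias.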


The ml_metrics family of functions wraps Spark’s evaluators so that they behave more like the yardstick package. The functions expect a table containing the truth and the estimate, and return a tibble with the results. The tibble has the same format and variable names as the output of the yardstick functions.
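
To make the yardstick-style result shape concrete, here is a minimal base R sketch that builds a table with the same three columns from local vectors. It does not use Spark or yardstick. The truth and estimate vectors are made-up data, and a plain data.frame stands in for the tibble.

```r
# Made-up indexed labels and predictions, for illustration only
truth    <- c(0, 1, 2, 2, 1, 0)
estimate <- c(0, 1, 2, 1, 1, 0)

# One row per metric, using the yardstick column names
result <- data.frame(
  .metric    = "accuracy",
  .estimator = "multiclass",
  .estimate  = mean(truth == estimate)  # share of correct predictions
)
print(result)
```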


library(sparklyr)

sc <- spark_connect("local")
tbl_iris <- copy_to(sc, iris)
iris_split <- sdf_random_split(tbl_iris, training = 0.5, test = 0.5)
model <- ml_random_forest(iris_split$training, "Species ~ .")
tbl_predictions <- ml_predict(model, iris_split$test)
# Default metric (accuracy)
ml_metrics_multiclass(tbl_predictions)
#> # A tibble: 1 × 3
#>   .metric  .estimator .estimate
#>   <chr>    <chr>          <dbl>
#> 1 accuracy multiclass     0.951
# Request different metrics 
ml_metrics_multiclass(tbl_predictions, metrics = c("recall", "precision")) 
#> # A tibble: 2 × 3
#>   .metric   .estimator .estimate
#>   <chr>     <chr>          <dbl>
#> 1 recall    multiclass     0.951
#> 2 precision multiclass     0.957
# Request metrics not translated by the function, but valid in Spark 
ml_metrics_multiclass(tbl_predictions, metrics = c("logLoss", "hammingLoss")) 
#> # A tibble: 2 × 3
#>   .metric     .estimator .estimate
#>   <chr>       <chr>          <dbl>
#> 1 logLoss     multiclass    0.133 
#> 2 hammingLoss multiclass    0.0494
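
In the output above, recall equals accuracy (0.951) while precision differs. One possible explanation, assuming recall and precision are translated to Spark's class-weighted variants (an assumption, not confirmed by this page), is that weighted recall always reduces to accuracy. The base R sketch below works this through on made-up vectors, without Spark.

```r
# Made-up indexed labels and predictions, for illustration only
truth    <- c(0, 0, 1, 1, 1, 2, 2, 2, 2, 2)
estimate <- c(0, 1, 1, 1, 2, 2, 2, 2, 2, 1)

classes <- sort(unique(truth))

# Weight of each class = its share of the true labels
weights <- as.numeric(table(factor(truth, levels = classes))) / length(truth)

# Per-class recall: of the rows that truly are class k, how many were predicted k
recall_by_class <- vapply(
  classes,
  function(k) mean(estimate[truth == k] == k),
  numeric(1)
)

# Per-class precision: of the rows predicted as class k, how many truly are k
precision_by_class <- vapply(
  classes,
  function(k) mean(truth[estimate == k] == k),
  numeric(1)
)

weighted_recall    <- sum(weights * recall_by_class)
weighted_precision <- sum(weights * precision_by_class)
accuracy           <- mean(truth == estimate)

print(c(
  accuracy = accuracy,
  recall = weighted_recall,
  precision = weighted_precision
))
```

Weighting per-class recall by class frequency sums the correct predictions over all rows, which is the definition of accuracy. Weighted precision divides by predicted counts instead, so it can differ.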