Spark ML – Tuning

R/ml_tuning.R, R/ml_tuning_cross_validator.R,

ml-tuning

Description

Perform hyper-parameter tuning using either K-fold cross validation or train-validation split.

Usage

 
ml_sub_models(model) 
 
ml_validation_metrics(model) 
 
ml_cross_validator( 
  x, 
  estimator = NULL, 
  estimator_param_maps = NULL, 
  evaluator = NULL, 
  num_folds = 3, 
  collect_sub_models = FALSE, 
  parallelism = 1, 
  seed = NULL, 
  uid = random_string("cross_validator_"), 
  ... 
) 
 
ml_train_validation_split( 
  x, 
  estimator = NULL, 
  estimator_param_maps = NULL, 
  evaluator = NULL, 
  train_ratio = 0.75, 
  collect_sub_models = FALSE, 
  parallelism = 1, 
  seed = NULL, 
  uid = random_string("train_validation_split_"), 
  ... 
) 

Arguments

Argument | Description
model | A cross validation or train-validation-split model.
x | A spark_connection, ml_pipeline, or a tbl_spark.
estimator | A ml_estimator object.
estimator_param_maps | A named list of stages and hyper-parameter sets to tune. See details.
evaluator | A ml_evaluator object, see ml_evaluator.
num_folds | Number of folds for cross validation. Must be >= 2. Default: 3
collect_sub_models | Whether to collect a list of sub-models trained during tuning. If set to FALSE, only the single best sub-model is available after fitting. If set to TRUE, all sub-models are available. Warning: for large models, collecting all sub-models can cause out-of-memory errors (OOMs) on the Spark driver.
parallelism | The number of threads to use when running parallel algorithms. Default is 1 for serial execution.
seed | A random seed. Set this value if you need your results to be reproducible across repeated calls.
uid | A character string used to uniquely identify the ML estimator.
... | Optional arguments; currently unused.
train_ratio | Ratio between train and validation data. Must be between 0 and 1. Default: 0.75
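
To get a feel for how large a search estimator_param_maps describes, the grid can be expanded in base R before any Spark job is started. This is only a sketch: it reuses the grid from the Examples section and assumes that every combination of the listed values is tried, which matches the eight rows of metrics shown there. The names combos and n_fits are illustrative and not part of sparklyr.

```r
# Hyperparameter grid, copied from the Examples section
grid <- list(
  random_forest = list(
    num_trees = c(5, 10),
    max_depth = c(5, 10),
    impurity = c("entropy", "gini")
  )
)

# Enumerate every combination for the random_forest stage (base R only)
combos <- expand.grid(grid$random_forest, stringsAsFactors = FALSE)
nrow(combos)

# Assumption: with k-fold cross validation each combination is fit once
# per fold, so the number of fits grows as combinations * num_folds
num_folds <- 3
n_fits <- nrow(combos) * num_folds
n_fits
```

Because the cost grows multiplicatively with each added hyperparameter value, checking the grid size this way can help when choosing num_folds and parallelism.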

Details

ml_cross_validator() performs k-fold cross validation while ml_train_validation_split() performs tuning on one pair of train and validation datasets.
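
The difference between the two strategies can be illustrated with row indices in base R. The sketch below is not how Spark partitions data internally (Spark splits randomly and the resulting sizes are only approximate); it simply shows that k-fold cross validation yields num_folds train/validation pairs, while a train-validation split yields exactly one pair. All variable names are illustrative.

```r
set.seed(42)
n <- 150           # number of rows, e.g. the iris data set
num_folds <- 3
train_ratio <- 0.75

# K-fold: assign each row to one fold, then hold out each fold in turn
fold_id <- sample(rep(seq_len(num_folds), length.out = n))
cv_splits <- lapply(seq_len(num_folds), function(k) {
  list(
    train = which(fold_id != k),
    validation = which(fold_id == k)
  )
})

# Train-validation split: a single random partition of the rows
train_idx <- sample(n, size = floor(train_ratio * n))
tvs_split <- list(
  train = train_idx,
  validation = setdiff(seq_len(n), train_idx)
)

length(cv_splits)  # one train/validation pair per fold
```

Every candidate model is therefore evaluated num_folds times under cross validation and only once under a train-validation split, which makes the latter cheaper but more sensitive to how the single split happens to fall.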

Value

The object returned depends on the class of x.

  • spark_connection: When x is a spark_connection, the function returns an instance of a ml_cross_validator or ml_train_validation_split object.

  • ml_pipeline: When x is a ml_pipeline, the function returns a ml_pipeline with the tuning estimator appended to the pipeline.

  • tbl_spark: When x is a tbl_spark, a tuning estimator is constructed then immediately fit with the input tbl_spark, returning a ml_cross_validation_model or a ml_train_validation_split_model object.

For cross validation, ml_sub_models() returns a nested list of models, where the first layer represents fold indices and the second layer represents param maps. For train-validation split, ml_sub_models() returns a list of models, corresponding to the order of the estimator param maps. ml_validation_metrics() returns a data frame of performance metrics and hyperparameter combinations.

Examples

library(sparklyr)
 
sc <- spark_connect(master = "local") 
iris_tbl <- sdf_copy_to(sc, iris, name = "iris_tbl", overwrite = TRUE) 
 
# Create a pipeline 
pipeline <- ml_pipeline(sc) %>% 
  ft_r_formula(Species ~ .) %>% 
  ml_random_forest_classifier() 
 
# Specify hyperparameter grid 
grid <- list( 
  random_forest = list( 
    num_trees = c(5, 10), 
    max_depth = c(5, 10), 
    impurity = c("entropy", "gini") 
  ) 
) 
 
# Create the cross validator object 
cv <- ml_cross_validator( 
  sc, 
  estimator = pipeline, estimator_param_maps = grid, 
  evaluator = ml_multiclass_classification_evaluator(sc), 
  num_folds = 3, 
  parallelism = 4 
) 
 
# Train the models 
cv_model <- ml_fit(cv, iris_tbl) 
 
# Print the metrics 
ml_validation_metrics(cv_model) 
#>          f1 impurity_1 num_trees_1 max_depth_1
#> 1 0.9397887    entropy           5           5
#> 2 0.9056526    entropy          10           5
#> 3 0.9397887    entropy           5          10
#> 4 0.9056526    entropy          10          10
#> 5 0.9397887       gini           5           5
#> 6 0.9127092       gini          10           5
#> 7 0.9397887       gini           5          10
#> 8 0.9127092       gini          10          10
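
The example above covers only ml_cross_validator(). A comparable sketch for ml_train_validation_split() follows, using only the arguments listed under Usage. It assumes a local Spark installation that supports collect_sub_models; the seed value and the object names (tvs, tvs_model, tvs_metrics, tvs_sub_models) are arbitrary choices for illustration. No output is shown because the metric values depend on the Spark version and the random split.

```r
library(sparklyr)

# Assumes a local Spark installation is available to sparklyr
sc <- spark_connect(master = "local")
iris_tbl <- sdf_copy_to(sc, iris, name = "iris_tbl", overwrite = TRUE)

# Same pipeline and grid as in the cross validation example
pipeline <- ml_pipeline(sc) %>%
  ft_r_formula(Species ~ .) %>%
  ml_random_forest_classifier()

grid <- list(
  random_forest = list(
    num_trees = c(5, 10),
    max_depth = c(5, 10),
    impurity = c("entropy", "gini")
  )
)

# Create the train-validation split object and keep all sub-models
tvs <- ml_train_validation_split(
  sc,
  estimator = pipeline,
  estimator_param_maps = grid,
  evaluator = ml_multiclass_classification_evaluator(sc),
  train_ratio = 0.75,
  collect_sub_models = TRUE,
  seed = 1234
)

# Train the models on a single train/validation pair
tvs_model <- ml_fit(tvs, iris_tbl)

# One row of metrics per hyperparameter combination
tvs_metrics <- ml_validation_metrics(tvs_model)

# A list of sub-models, in the order of the estimator param maps
tvs_sub_models <- ml_sub_models(tvs_model)
length(tvs_sub_models)

spark_disconnect(sc)
```

Setting collect_sub_models = TRUE is what makes ml_sub_models() useful here; as the argument description warns, it can be costly for large models, so it is usually best left at FALSE unless the individual sub-models are needed.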