The workflowsets Package

Data Science
R
Modeling
Author

Robert Lankford

Published

September 26, 2023

This post covers the workflowsets package, which focuses on quickly and easily fitting a large number of models by creating a set that holds several workflow objects.

Setup

Packages

The following packages are required:
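Based on the attached packages listed in the session information at the end of this post, the load calls would look something like the following (reconstructed, not copied from the original source):

```r
library(dplyr)        # data manipulation
library(rsample)      # data splitting and resampling
library(recipes)      # pre-processing steps
library(parsnip)      # model specifications
library(workflows)    # individual workflow objects
library(workflowsets) # sets of workflows
library(tune)         # hyperparameter tuning
library(dials)        # tuning parameter objects
library(yardstick)    # performance metrics
library(broom)        # augment() for predictions
```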

Data

I utilized the penguins data set from the modeldata package for these examples.

Note that the outcome variable is forced to be binary by assigning all the odd-numbered rows as “Chinstrap” and the even-numbered rows as “Adelie”. This is done to (1) ensure that the outcome is binary so certain functions and techniques can be used and (2) make the outcome a little more difficult to predict, to better demonstrate the differences in the construction of the modeling algorithms used. In addition, rows with missing values are removed.

data("penguins", package = "modeldata")

penguins_tbl <- penguins %>% 
  mutate(
    #> Need a binary outcome
    species = ifelse(row_number() %% 2 == 0, "Adelie", "Chinstrap"),
    species = factor(species, levels = c("Adelie", "Chinstrap"))
  ) %>% 
  filter(!if_any(everything(), \(x) is.na(x)))

penguins_tbl %>% 
  count(species)
# A tibble: 2 × 2
  species       n
  <fct>     <int>
1 Adelie      167
2 Chinstrap   166

Model Preparation Process

As discussed in a previous post in this series, some preparation work is required before constructing a workflow(), and by extension a workflow_set(). As before, these steps follow the tidymodels process.

Data Splitting

Referencing a previous post on the rsample package, a 70/30 train/test initial_split() of the data is taken. The split object is then passed to the training() and testing() functions to extract the training and testing data sets, respectively.

set.seed(1914)
penguins_split_obj <- initial_split(penguins_tbl, prop = 0.7)
penguins_train_tbl <- training(penguins_split_obj)

penguins_train_tbl
# A tibble: 233 × 7
   species   island bill_length_mm bill_depth_mm flipper_length_mm body_mass_g
   <fct>     <fct>           <dbl>         <dbl>             <int>       <int>
 1 Chinstrap Biscoe           36.4          17.1               184        2850
 2 Adelie    Biscoe           52.1          17                 230        5550
 3 Adelie    Dream            50.8          18.5               201        4450
 4 Adelie    Dream            37.2          18.1               178        3900
 5 Chinstrap Biscoe           45.5          14.5               212        4750
 6 Chinstrap Biscoe           39            17.5               186        3550
 7 Chinstrap Biscoe           48.2          15.6               221        5100
 8 Adelie    Dream            52            19                 197        4150
 9 Chinstrap Biscoe           45.4          14.6               211        4800
10 Chinstrap Biscoe           38.1          17                 181        3175
# ℹ 223 more rows
# ℹ 1 more variable: sex <fct>
penguins_test_tbl <- testing(penguins_split_obj)

penguins_test_tbl
# A tibble: 100 × 7
   species   island   bill_length_mm bill_depth_mm flipper_length_mm body_mass_g
   <fct>     <fct>             <dbl>         <dbl>             <int>       <int>
 1 Chinstrap Torgers…           40.3          18                 195        3250
 2 Adelie    Torgers…           39.2          19.6               195        4675
 3 Chinstrap Torgers…           38.7          19                 195        3450
 4 Adelie    Torgers…           46            21.5               194        4200
 5 Chinstrap Biscoe             37.8          18.3               174        3400
 6 Adelie    Biscoe             37.7          18.7               180        3600
 7 Adelie    Biscoe             38.2          18.1               185        3950
 8 Chinstrap Biscoe             40.6          18.6               183        3550
 9 Chinstrap Dream              36.4          17                 195        3325
10 Adelie    Dream              42.2          18.5               180        3550
# ℹ 90 more rows
# ℹ 1 more variable: sex <fct>

Data Processing

Referencing another previous post on the recipes package, pre-processing steps are performed on the training data set. The resulting recipe() object is then passed to the prep() function to prepare it for use. The result is then passed to the juice() function to extract the transformed training data set.

Note that a slightly different method is used here than in the workflows post. Instead of creating a recipe() for each modeling algorithm, three separate recipes are created without regard to the modeling algorithm. More detail will be provided in a later section, but the short explanation is that each combination of recipe() and modeling algorithm will be used, so there is no reason to specify which recipe() goes with which modeling algorithm.

This recipe() contains just the model formula and the training data set. This may seem needless, as passing this recipe() to the juice() function just returns the original data set. This is what is desired, however, and using recipe() objects makes working with a workflow() easier (in my opinion).

recipe_1_obj <- recipe(species ~ ., data = penguins_train_tbl)

recipe_1_obj %>% 
  prep() %>% 
  juice()
# A tibble: 233 × 7
   island bill_length_mm bill_depth_mm flipper_length_mm body_mass_g sex   
   <fct>           <dbl>         <dbl>             <int>       <int> <fct> 
 1 Biscoe           36.4          17.1               184        2850 female
 2 Biscoe           52.1          17                 230        5550 male  
 3 Dream            50.8          18.5               201        4450 male  
 4 Dream            37.2          18.1               178        3900 male  
 5 Biscoe           45.5          14.5               212        4750 female
 6 Biscoe           39            17.5               186        3550 female
 7 Biscoe           48.2          15.6               221        5100 male  
 8 Dream            52            19                 197        4150 male  
 9 Biscoe           45.4          14.6               211        4800 female
10 Biscoe           38.1          17                 181        3175 female
# ℹ 223 more rows
# ℹ 1 more variable: species <fct>

This recipe() adds a step to the previous recipe that one-hot encodes all categorical predictors.

recipe_2_obj <- recipe_1_obj %>% 
  step_dummy(all_nominal_predictors(), one_hot = TRUE)

recipe_2_obj %>% 
  prep() %>% 
  juice()
# A tibble: 233 × 10
   bill_length_mm bill_depth_mm flipper_length_mm body_mass_g species  
            <dbl>         <dbl>             <int>       <int> <fct>    
 1           36.4          17.1               184        2850 Chinstrap
 2           52.1          17                 230        5550 Adelie   
 3           50.8          18.5               201        4450 Adelie   
 4           37.2          18.1               178        3900 Adelie   
 5           45.5          14.5               212        4750 Chinstrap
 6           39            17.5               186        3550 Chinstrap
 7           48.2          15.6               221        5100 Chinstrap
 8           52            19                 197        4150 Adelie   
 9           45.4          14.6               211        4800 Chinstrap
10           38.1          17                 181        3175 Chinstrap
# ℹ 223 more rows
# ℹ 5 more variables: island_Biscoe <dbl>, island_Dream <dbl>,
#   island_Torgersen <dbl>, sex_female <dbl>, sex_male <dbl>

This recipe() adds a step to the previous recipe that normalizes all numeric predictors.

recipe_3_obj <- recipe_2_obj %>% 
  step_normalize(all_numeric_predictors())

recipe_3_obj %>% 
  prep() %>% 
  juice()
# A tibble: 233 × 10
   bill_length_mm bill_depth_mm flipper_length_mm body_mass_g species  
            <dbl>         <dbl>             <dbl>       <dbl> <fct>    
 1         -1.41       -0.00476           -1.24        -1.72  Chinstrap
 2          1.37       -0.0552             2.02         1.63  Adelie   
 3          1.14        0.701             -0.0360       0.266 Adelie   
 4         -1.26        0.499             -1.67        -0.415 Adelie   
 5          0.204      -1.32               0.745        0.637 Chinstrap
 6         -0.945       0.197             -1.10        -0.849 Chinstrap
 7          0.682      -0.761              1.38         1.07  Chinstrap
 8          1.35        0.953             -0.320       -0.106 Adelie   
 9          0.186      -1.26               0.674        0.699 Chinstrap
10         -1.10       -0.0552            -1.46        -1.31  Chinstrap
# ℹ 223 more rows
# ℹ 5 more variables: island_Biscoe <dbl>, island_Dream <dbl>,
#   island_Torgersen <dbl>, sex_female <dbl>, sex_male <dbl>

Model Specifications

Referencing another previous post on the parsnip package, three model specifications are constructed. The first is a logistic regression, using the logistic_reg() function; the second is a random forest, using the rand_forest() function; the third is a k-nearest neighbors model, using the nearest_neighbor() function.

Notice that two of the model specifications, the random forest and the k-nearest neighbors, have arguments set to tune(). As such, their hyperparameters need to be tuned. This will be handled in a later section.

mod_logreg_spec <- logistic_reg() %>% 
  set_engine("glm")

mod_logreg_spec
Logistic Regression Model Specification (classification)

Computational engine: glm 
mod_rf_spec <- rand_forest() %>%
  set_mode("classification") %>% 
  set_engine("randomForest", importance = TRUE) %>% 
  set_args(
    trees = tune(),
    mtry  = tune(),
    min_n = tune()
  )

mod_rf_spec
Random Forest Model Specification (classification)

Main Arguments:
  mtry = tune()
  trees = tune()
  min_n = tune()

Engine-Specific Arguments:
  importance = TRUE

Computational engine: randomForest 
mod_knn_spec <- nearest_neighbor() %>% 
  set_mode("classification") %>% 
  set_engine("kknn") %>% 
  set_args(neighbors = tune())

mod_knn_spec
K-Nearest Neighbor Model Specification (classification)

Main Arguments:
  neighbors = tune()

Computational engine: kknn 

Model Building Process

Now that three recipe() objects and three model specifications have been created, they can be used to create a workflow_set() and fit some models.

Specify a workflowset

The first step is to specify a workflow_set() object. To do this, a list of recipe() objects and a list of model specifications need to be passed to the preproc and models arguments, respectively. The cross argument can be used to create all combinations of the objects in preproc and models. The default is TRUE. If set to FALSE, then the length of preproc and models should be equal and they are matched pairwise.

models_wset <- workflow_set(
  preproc = list(recipe_1_obj, recipe_2_obj, recipe_3_obj),
  models  = list(mod_logreg_spec, mod_rf_spec, mod_knn_spec),
  cross   = TRUE
)

models_wset
# A workflow set/tibble: 9 × 4
  wflow_id                  info             option    result    
  <chr>                     <list>           <list>    <list>    
1 recipe_1_logistic_reg     <tibble [1 × 4]> <opts[0]> <list [0]>
2 recipe_1_rand_forest      <tibble [1 × 4]> <opts[0]> <list [0]>
3 recipe_1_nearest_neighbor <tibble [1 × 4]> <opts[0]> <list [0]>
4 recipe_2_logistic_reg     <tibble [1 × 4]> <opts[0]> <list [0]>
5 recipe_2_rand_forest      <tibble [1 × 4]> <opts[0]> <list [0]>
6 recipe_2_nearest_neighbor <tibble [1 × 4]> <opts[0]> <list [0]>
7 recipe_3_logistic_reg     <tibble [1 × 4]> <opts[0]> <list [0]>
8 recipe_3_rand_forest      <tibble [1 × 4]> <opts[0]> <list [0]>
9 recipe_3_nearest_neighbor <tibble [1 × 4]> <opts[0]> <list [0]>
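As an aside, the auto-generated wflow_id values above are built from the object names. If the preproc and models lists are given explicit names, workflow_set() uses those names instead, which can make the IDs shorter and easier to read. A small sketch (the list names here are my own, not from the original post):

```r
named_wset <- workflow_set(
  preproc = list(base = recipe_1_obj, dummies = recipe_2_obj, normed = recipe_3_obj),
  models  = list(logreg = mod_logreg_spec, rf = mod_rf_spec, knn = mod_knn_spec),
  cross   = TRUE
)
#> wflow_id values become, e.g., "base_logreg", "dummies_rf", "normed_knn"
```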

Model Hyperparameter Tuning

Every combination of recipe() and model specification has been created, but two of the model specifications had their hyperparameter values set to tune(). The workflowsets package provides a method for tuning all the hyperparameters at once.

During hyperparameter tuning, 10-fold cross-validation will be used, created with the vfold_cv() function from the rsample package. The metric to be optimized during the tuning process is the area under the ROC curve (AUC), computed with the roc_auc() function from the yardstick package. More information on the yardstick package can be found in a previous post.

set.seed(1914)
folds_tbl <- vfold_cv(penguins_train_tbl, v = 10)

To tune all the hyperparameters for each workflow() at the same time, the workflow_map() function can be used on the workflow_set() object. Setting the argument fn = "tune_grid" maps the tune_grid() function from the tune package over each workflow(). More information on tuning can be found in a previous post.

The result of vfold_cv() is passed to the resamples argument. The argument grid = 10 means that 10 candidate values of each hyperparameter set to tune() are tried. The metric to optimize is set to the AUC using the metric_set() function from the yardstick package. Finally, verbose = TRUE prints a progress message as each workflow() is tuned; in a non-interactive setting, this can be set to FALSE to suppress those messages.

The final result has the tuned result for each workflow().

set.seed(1915)

models_tuned_wset <- workflow_map(
  models_wset,
  fn = "tune_grid",
  resamples = folds_tbl,
  grid = 10,
  metrics = metric_set(roc_auc),
  verbose = TRUE,
  seed = 1915
)

models_tuned_wset
# A workflow set/tibble: 9 × 4
  wflow_id                  info             option    result   
  <chr>                     <list>           <list>    <list>   
1 recipe_1_logistic_reg     <tibble [1 × 4]> <opts[3]> <rsmp[+]>
2 recipe_1_rand_forest      <tibble [1 × 4]> <opts[3]> <tune[+]>
3 recipe_1_nearest_neighbor <tibble [1 × 4]> <opts[3]> <tune[+]>
4 recipe_2_logistic_reg     <tibble [1 × 4]> <opts[3]> <rsmp[+]>
5 recipe_2_rand_forest      <tibble [1 × 4]> <opts[3]> <tune[+]>
6 recipe_2_nearest_neighbor <tibble [1 × 4]> <opts[3]> <tune[+]>
7 recipe_3_logistic_reg     <tibble [1 × 4]> <opts[3]> <rsmp[+]>
8 recipe_3_rand_forest      <tibble [1 × 4]> <opts[3]> <tune[+]>
9 recipe_3_nearest_neighbor <tibble [1 × 4]> <opts[3]> <tune[+]>

The autoplot() function can be used on the tuned results to visualize how the tuning process went for each model. Setting metric = "roc_auc" plots the AUC for each workflow(), along with its confidence interval, color-coded by modeling algorithm. Higher values of AUC are better.

autoplot(models_tuned_wset, metric = "roc_auc")

Adding select_best = TRUE plots only the best result for each workflow(), rather than every hyperparameter combination tried.

autoplot(models_tuned_wset, metric = "roc_auc", select_best = TRUE)

Final Model Selection

To select the final model, the tuned workflow_set() object can be passed to the rank_results() function. This function will rank models by the best mean performance metrics over the cross-validation folds. Setting the argument rank_metric = "roc_auc" ranks the models by the AUC. Setting the argument select_best = TRUE will return only the best models, rather than all of them.

models_tuned_wset %>% 
  rank_results(rank_metric = "roc_auc", select_best = TRUE) %>% 
  select(rank, mean, model, wflow_id, .config)
# A tibble: 9 × 5
   rank  mean model            wflow_id                  .config              
  <int> <dbl> <chr>            <chr>                     <chr>                
1     1 0.916 rand_forest      recipe_1_rand_forest      Preprocessor1_Model04
2     2 0.911 rand_forest      recipe_3_rand_forest      Preprocessor1_Model07
3     3 0.911 rand_forest      recipe_2_rand_forest      Preprocessor1_Model07
4     4 0.904 logistic_reg     recipe_1_logistic_reg     Preprocessor1_Model1 
5     5 0.904 logistic_reg     recipe_2_logistic_reg     Preprocessor1_Model1 
6     6 0.904 logistic_reg     recipe_3_logistic_reg     Preprocessor1_Model1 
7     7 0.883 nearest_neighbor recipe_1_nearest_neighbor Preprocessor1_Model02
8     8 0.883 nearest_neighbor recipe_3_nearest_neighbor Preprocessor1_Model02
9     9 0.883 nearest_neighbor recipe_2_nearest_neighbor Preprocessor1_Model02

To extract a specific model, the extract_workflow_set_result() function can be used. The exact ID of the model is specified with the id argument. This returns the cross-validation results for that specific model.

models_tuned_wset %>% 
  extract_workflow_set_result(id = "recipe_1_rand_forest")
# Tuning results
# 10-fold cross-validation 
# A tibble: 10 × 4
   splits           id     .metrics          .notes          
   <list>           <chr>  <list>            <list>          
 1 <split [209/24]> Fold01 <tibble [10 × 7]> <tibble [0 × 3]>
 2 <split [209/24]> Fold02 <tibble [10 × 7]> <tibble [0 × 3]>
 3 <split [209/24]> Fold03 <tibble [10 × 7]> <tibble [0 × 3]>
 4 <split [210/23]> Fold04 <tibble [10 × 7]> <tibble [0 × 3]>
 5 <split [210/23]> Fold05 <tibble [10 × 7]> <tibble [0 × 3]>
 6 <split [210/23]> Fold06 <tibble [10 × 7]> <tibble [0 × 3]>
 7 <split [210/23]> Fold07 <tibble [10 × 7]> <tibble [0 × 3]>
 8 <split [210/23]> Fold08 <tibble [10 × 7]> <tibble [0 × 3]>
 9 <split [210/23]> Fold09 <tibble [10 × 7]> <tibble [0 × 3]>
10 <split [210/23]> Fold10 <tibble [10 × 7]> <tibble [0 × 3]>
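To summarize these per-fold results across all of the hyperparameter combinations, the extracted result can also be passed to the collect_metrics() function from the tune package (a sketch; this output is not shown in the original post):

```r
models_tuned_wset %>% 
  extract_workflow_set_result(id = "recipe_1_rand_forest") %>% 
  collect_metrics()
#> One row per hyperparameter combination, with the mean AUC and its
#> standard error across the 10 cross-validation folds
```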

To extract the tuned hyperparameters of a specific model, pass the result of the extract_workflow_set_result() function into the select_best() function. Specifying the argument metric = "roc_auc" ensures that the combination of hyperparameters selected is the one that, on average, maximized the AUC.

best_params_tbl <- models_tuned_wset %>% 
  extract_workflow_set_result(id = "recipe_1_rand_forest") %>% 
  select_best(metric = "roc_auc")

best_params_tbl
# A tibble: 1 × 4
   mtry trees min_n .config              
  <int> <int> <int> <chr>                
1     1   340    20 Preprocessor1_Model04

Finally, the final workflow() can be built by passing the tuned workflow_set() into the extract_workflow() function, specifying the ID of the desired model with the id argument. The result is then passed into the finalize_workflow() function along with the tibble of optimal hyperparameters. This process is similar to that of a regular workflow(). Additional information can be found in a previous post on the workflows package.

mod_final_spec_wflw <- models_tuned_wset %>% 
  extract_workflow(id = "recipe_1_rand_forest") %>% 
  finalize_workflow(best_params_tbl)

mod_final_spec_wflw
══ Workflow ════════════════════════════════════════════════════════════════════
Preprocessor: Recipe
Model: rand_forest()

── Preprocessor ────────────────────────────────────────────────────────────────
0 Recipe Steps

── Model ───────────────────────────────────────────────────────────────────────
Random Forest Model Specification (classification)

Main Arguments:
  mtry = 1
  trees = 340
  min_n = 20

Engine-Specific Arguments:
  importance = TRUE

Computational engine: randomForest 

This final workflow() specification can be passed to the fit() function to fit the model.

mod_final_fit_wflw <- mod_final_spec_wflw %>% 
  fit(penguins_train_tbl)

mod_final_fit_wflw
══ Workflow [trained] ══════════════════════════════════════════════════════════
Preprocessor: Recipe
Model: rand_forest()

── Preprocessor ────────────────────────────────────────────────────────────────
0 Recipe Steps

── Model ───────────────────────────────────────────────────────────────────────

Call:
 randomForest(x = maybe_data_frame(x), y = y, ntree = ~340L, mtry = min_cols(~1L,      x), nodesize = min_rows(~20L, x), importance = ~TRUE) 
               Type of random forest: classification
                     Number of trees: 340
No. of variables tried at each split: 1

        OOB estimate of  error rate: 13.3%
Confusion matrix:
          Adelie Chinstrap class.error
Adelie       104        17   0.1404959
Chinstrap     14        98   0.1250000

Now that there is a final model, it can be used to make predictions. Among other methods, the augment() function can be used to append both probability and class predictions to the test data by passing in the fitted workflow() object and the test data. More information on the augment() function can be found in a previous post on the broom package.

mod_final_preds_tbl <- mod_final_fit_wflw %>% 
  augment(penguins_test_tbl) %>% 
  relocate(starts_with(".pred"), everything())

mod_final_preds_tbl
# A tibble: 100 × 10
   .pred_class .pred_Adelie .pred_Chinstrap species   island    bill_length_mm
   <fct>              <dbl>           <dbl> <fct>     <fct>              <dbl>
 1 Chinstrap         0.15            0.85   Chinstrap Torgersen           40.3
 2 Adelie            0.918           0.0824 Adelie    Torgersen           39.2
 3 Chinstrap         0.229           0.771  Chinstrap Torgersen           38.7
 4 Adelie            0.885           0.115  Adelie    Torgersen           46  
 5 Chinstrap         0.0971          0.903  Chinstrap Biscoe              37.8
 6 Adelie            0.629           0.371  Adelie    Biscoe              37.7
 7 Adelie            0.815           0.185  Adelie    Biscoe              38.2
 8 Adelie            0.788           0.212  Chinstrap Biscoe              40.6
 9 Chinstrap         0.0294          0.971  Chinstrap Dream               36.4
10 Chinstrap         0.291           0.709  Adelie    Dream               42.2
# ℹ 90 more rows
# ℹ 4 more variables: bill_depth_mm <dbl>, flipper_length_mm <int>,
#   body_mass_g <int>, sex <fct>

Out-of-sample performance on the test data can be calculated using functions such as conf_mat(), for the confusion matrix, and roc_auc(), for the AUC, by passing in the test data predictions. Overall, it would seem that this final model does a good job of predicting out-of-sample data, given its large AUC and relatively few misclassified observations.

mod_final_preds_tbl %>% 
  conf_mat(truth = species, estimate = .pred_class)
           Truth
Prediction  Adelie Chinstrap
  Adelie        38        11
  Chinstrap      8        43
mod_final_preds_tbl %>% 
  roc_auc(truth = species, estimate = .pred_Adelie)
# A tibble: 1 × 3
  .metric .estimator .estimate
  <chr>   <chr>          <dbl>
1 roc_auc binary         0.882
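Other yardstick metrics can be computed the same way. For example, several class metrics can be bundled into a single call with metric_set() (a sketch, not part of the original post):

```r
class_metrics <- metric_set(accuracy, sens, spec)

mod_final_preds_tbl %>% 
  class_metrics(truth = species, estimate = .pred_class)
#> Returns one row per metric: accuracy, sensitivity, and specificity
```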

Notes

This post is based on a presentation that was given on the date listed. It may be updated from time to time to fix errors, detail new functions, and/or remove deprecated functions, so the packages and R version shown below will likely be newer than what was available at the time.

The R session information used for this post:

R version 4.2.1 (2022-06-23)
Platform: aarch64-apple-darwin20 (64-bit)
Running under: macOS 14.1.1

Matrix products: default
BLAS:   /Library/Frameworks/R.framework/Versions/4.2-arm64/Resources/lib/libRblas.0.dylib
LAPACK: /Library/Frameworks/R.framework/Versions/4.2-arm64/Resources/lib/libRlapack.dylib

locale:
[1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8

attached base packages:
[1] stats     graphics  grDevices datasets  utils     methods   base     

other attached packages:
 [1] broom_1.0.5        workflows_1.1.2    yardstick_1.1.0    tune_1.1.2        
 [5] dials_1.1.0        scales_1.2.1       parsnip_1.0.3      recipes_1.0.4     
 [9] dplyr_1.1.3        rsample_1.2.0      workflowsets_1.0.1

loaded via a namespace (and not attached):
 [1] tidyr_1.3.0          jsonlite_1.8.0       splines_4.2.1       
 [4] foreach_1.5.2        prodlim_2019.11.13   GPfit_1.0-8         
 [7] renv_0.16.0          yaml_2.3.5           globals_0.16.2      
[10] ipred_0.9-13         pillar_1.9.0         backports_1.4.1     
[13] lattice_0.20-45      glue_1.6.2           digest_0.6.29       
[16] randomForest_4.7-1.1 hardhat_1.2.0        colorspace_2.0-3    
[19] htmltools_0.5.3      Matrix_1.4-1         timeDate_4022.108   
[22] pkgconfig_2.0.3      lhs_1.1.6            DiceDesign_1.9      
[25] listenv_0.8.0        purrr_1.0.2          gower_1.0.1         
[28] lava_1.7.1           timechange_0.1.1     tibble_3.2.1        
[31] farver_2.1.1         generics_0.1.3       ggplot2_3.4.0       
[34] ellipsis_0.3.2       withr_2.5.1          furrr_0.3.1         
[37] nnet_7.3-17          cli_3.6.1            survival_3.3-1      
[40] magrittr_2.0.3       evaluate_0.16        fansi_1.0.4         
[43] future_1.29.0        parallelly_1.32.1    MASS_7.3-57         
[46] class_7.3-20         tools_4.2.1          lifecycle_1.0.3     
[49] stringr_1.5.0        munsell_0.5.0        compiler_4.2.1      
[52] rlang_1.1.1          grid_4.2.1           iterators_1.0.14    
[55] rstudioapi_0.14      labeling_0.4.2       rmarkdown_2.16      
[58] gtable_0.3.1         codetools_0.2-18     R6_2.5.1            
[61] lubridate_1.9.0      knitr_1.40           fastmap_1.1.0       
[64] future.apply_1.10.0  utf8_1.2.3           stringi_1.7.12      
[67] parallel_4.2.1       Rcpp_1.0.9           vctrs_0.6.3         
[70] rpart_4.1.19         tidyselect_1.2.0     xfun_0.40