2.5 Resampling

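Resampling strategies are usually used to assess the performance of a learning algorithm. mlr3 includes the following predefined resampling strategies:

  • cross-validation ("cv")
  • leave-one-out cross-validation ("loo")
  • repeated cross-validation ("repeated_cv")
  • bootstrapping ("bootstrap")
  • subsampling ("subsampling")
  • holdout ("holdout")
  • in-sample resampling ("insample")
  • custom splits ("custom")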

The following sections provide guidance on how to set and select a resampling strategy and how to subsequently instantiate the resampling process.
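The following base-R sketch makes the differences between some of these strategies concrete. It draws the training and test indices of a single iteration for holdout/subsampling, bootstrap, and leave-one-out. It is only an illustration under simple assumptions (20 hypothetical rows, with the training size rounded from the ratio) and is not mlr3's implementation.

```r
# Illustration only: how single iterations of some resampling strategies
# could draw row indices. Not mlr3's internal implementation.
set.seed(1)
n = 20                 # hypothetical number of observations
ids = seq_len(n)

# Holdout (and one subsampling iteration): draw a fraction of rows
# WITHOUT replacement; the remaining rows form the test set
ratio = 2 / 3
train_sub = sample(ids, round(ratio * n))
test_sub = setdiff(ids, train_sub)

# Bootstrap iteration: draw n rows WITH replacement; rows never drawn
# ("out-of-bag") form the test set
train_boot = sample(ids, n, replace = TRUE)
test_boot = setdiff(ids, train_boot)

# Leave-one-out, iteration i: test on row i, train on all other rows
i = 5
train_loo = setdiff(ids, i)
test_loo = i

# Test-set sizes of the three schemes
c(subsampling = length(test_sub), bootstrap_oob = length(test_boot), loo = length(test_loo))
```

Subsampling repeats the first scheme several times with fresh random draws. Cross-validation instead partitions the rows so that each row is tested exactly once.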

Here is a graphical illustration of the resampling process:

2.5.1 Settings

In this example we use the penguins task and a simple classification tree from the rpart package once again.

library("mlr3verse")

task = tsk("penguins")
learner = lrn("classif.rpart")

When performing resampling with a dataset, we first need to define which approach should be used. mlr3 resampling strategies and their parameters can be queried by looking at the data.table output of the mlr_resamplings dictionary:

as.data.table(mlr_resamplings)
##            key        params iters
## 1:   bootstrap repeats,ratio    30
## 2:      custom                   0
## 3:          cv         folds    10
## 4:     holdout         ratio     1
## 5:    insample                   1
## 6:         loo                  NA
## 7: repeated_cv repeats,folds   100
## 8: subsampling repeats,ratio    30
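The iters column gives the number of training/test iterations implied by each strategy's default parameters. For instance, repeated_cv presumably uses 10 repeats of 10 folds, which gives 100 iterations. loo shows NA because its number of iterations equals the number of observations, which is only known once a task is given. The hypothetical helper n_iters() below is not part of mlr3. It spells out this arithmetic in base R, assuming the defaults shown in the table.

```r
# Hypothetical helper (not part of mlr3) reproducing the logic of the
# "iters" column: how many train/test iterations a strategy produces.
n_iters = function(key, folds = 10, repeats = 30, n = NA) {
  switch(key,
    holdout = 1,
    insample = 1,
    custom = 0,  # iterations only exist once custom splits are supplied
    cv = folds,
    repeated_cv = repeats * folds,
    bootstrap = repeats,
    subsampling = repeats,
    loo = n,     # one iteration per observation, unknown before a task is given
    stop("unknown resampling key: ", key)
  )
}

n_iters("cv")                         # 10
n_iters("repeated_cv", repeats = 10)  # 100
n_iters("loo")                        # NA
n_iters("loo", n = 344)               # 344, e.g. for the penguins task
```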

Additional resampling methods for special use cases are available via extension packages, such as mlr3spatiotempcv for spatial and spatiotemporal data.

The model fit conducted in the train/predict/score chapter is equivalent to a “holdout resampling”, so let’s consider this one first. Again, we can retrieve elements from the dictionary mlr_resamplings via $get() or with the convenience function rsmp():

resampling = rsmp("holdout")
print(resampling)
## <ResamplingHoldout> with 1 iterations
## * Instantiated: FALSE
## * Parameters: ratio=0.6667

Note that the $is_instantiated field is set to FALSE. This means that we have not yet applied the strategy to a dataset. This is done in the next section, Instantiation.

By default, two thirds of the data are used for training and one third for testing (ratio = 0.6667). There are two ways in which the ratio can be changed:

  1. Overwriting the slot in $param_set$values using a named list:
resampling$param_set$values = list(ratio = 0.8)
  2. Specifying the resampling parameters directly during construction:
rsmp("holdout", ratio = 0.8)
## <ResamplingHoldout> with 1 iterations
## * Instantiated: FALSE
## * Parameters: ratio=0.8
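The ratio is the fraction of observations used for training. The minimal base-R sketch below shows what the default and the modified ratio mean for the 344 rows of the penguins task. It assumes that the training size is the ratio times the number of rows, rounded to the nearest integer. The helper holdout_split() is hypothetical.

```r
# Base-R sketch of a holdout split. Assumption: the training set has
# round(ratio * n) rows and the test set contains all remaining rows.
holdout_split = function(n, ratio) {
  train = sort(sample(seq_len(n), round(ratio * n)))
  list(train = train, test = setdiff(seq_len(n), train))
}

set.seed(42)
n = 344  # rows in the penguins task
split_default = holdout_split(n, 0.6667)
split_80 = holdout_split(n, 0.8)

lengths(split_default)  # train 229, test 115
lengths(split_80)       # train 275, test 69
```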

2.5.2 Instantiation

So far, we have only set the stage and selected the resampling strategy.

To actually perform the splitting and obtain indices for the training and test sets, the resampling needs a Task. Calling the method instantiate() splits the row indices of the task into indices for the training and test sets. The resulting indices are stored in the Resampling object. To better illustrate the following operations, we switch to a 3-fold cross-validation:

resampling = rsmp("cv", folds = 3)
resampling$instantiate(task)
resampling$iters
## [1] 3
str(resampling$train_set(1))
##  int [1:229] 3 6 10 11 17 18 28 31 34 44 ...
str(resampling$test_set(1))
##  int [1:115] 1 2 4 7 13 15 16 20 27 29 ...
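Conceptually, instantiating a k-fold cross-validation assigns every row to exactly one of k folds. In iteration i, fold i is the test set and the remaining folds form the training set. The base-R sketch below reproduces this for 344 rows and 3 folds. It is an illustration rather than mlr3's code, and its train_set()/test_set() helpers only mimic the methods used above. With balanced fold labels, two folds receive 115 rows and one receives 114, which is consistent with the 229/115 split printed above.

```r
# Base-R sketch of instantiating 3-fold cross-validation on 344 rows
set.seed(7)
n = 344
folds = 3

# Balanced fold labels (1, 2, 3, 1, 2, 3, ...) in random order
fold_id = sample(rep_len(seq_len(folds), n))

# Hypothetical helpers mimicking resampling$train_set() / $test_set()
train_set = function(i) which(fold_id != i)
test_set = function(i) which(fold_id == i)

# Training and test sizes per iteration
# (train: 229 229 230, test: 115 115 114)
sapply(seq_len(folds), function(i) c(train = length(train_set(i)), test = length(test_set(i))))
```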

Note that if you want to compare multiple Learners in a fair manner, you must use the same instantiated resampling for each learner, so that every learner is trained and evaluated on identical splits. A way to greatly simplify the comparison of multiple learners is discussed in the next section on benchmarking.

2.5.3 Execution

With a Task, a Learner and a Resampling object, we can call resample(). It repeatedly fits the learner to the task according to the given resampling strategy and returns a ResampleResult object. We tell resample() to keep the fitted models by setting the store_models option to TRUE and then start the computation:

task = tsk("penguins")
learner = lrn("classif.rpart", maxdepth = 3, predict_type = "prob")
resampling = rsmp("cv", folds = 3)

rr = resample(task, learner, resampling, store_models = TRUE)
print(rr)
## <ResampleResult> of 3 iterations
## * Task: penguins
## * Learner: classif.rpart
## * Warnings: 0 in 0 iterations
## * Errors: 0 in 0 iterations
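To show what happens inside such a call, here is a conceptual base-R sketch of the resampling loop: for each iteration, train on the training rows, predict the test rows, then store the prediction and its score. This is not mlr3's implementation. It uses the built-in iris data and a simple nearest-centroid classifier as a stand-in for classif.rpart. The helpers fit_centroids() and predict_centroids() are hypothetical.

```r
# Conceptual base-R sketch of resample() with 3-fold cross-validation.
# A nearest-centroid classifier on iris stands in for classif.rpart.
set.seed(123)
X = as.matrix(iris[, 1:4])
y = iris$Species
folds = 3
fold_id = sample(rep_len(seq_len(folds), nrow(X)))

# Hypothetical "learner": one row of feature means per class
fit_centroids = function(X, y) {
  do.call(rbind, lapply(split(as.data.frame(X), y), colMeans))
}

# Predict the class whose centroid is closest in squared Euclidean distance
predict_centroids = function(centroids, X) {
  d = sapply(seq_len(nrow(centroids)), function(k) colSums((t(X) - centroids[k, ])^2))
  factor(rownames(centroids)[max.col(-d, ties.method = "first")], levels = rownames(centroids))
}

scores = numeric(folds)
predictions = vector("list", folds)
for (i in seq_len(folds)) {
  train = which(fold_id != i)
  test = which(fold_id == i)
  model = fit_centroids(X[train, ], y[train])
  pred = predict_centroids(model, X[test, , drop = FALSE])
  predictions[[i]] = data.frame(row_ids = test, truth = y[test], response = pred)
  scores[i] = mean(pred != y[test])  # classification error, like classif.ce
}

scores        # per-iteration errors, analogous to rr$score()
mean(scores)  # aggregated error, analogous to rr$aggregate()
```

Keeping model inside the loop rather than discarding it corresponds roughly to what store_models = TRUE enables.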

The returned ResampleResult, stored as rr, provides various getters to access the stored information:

  • Calculate the average performance across all resampling iterations:

    rr$aggregate(msr("classif.ce"))
    ## classif.ce 
    ##    0.06982
  • Extract the performance for the individual resampling iterations:

    rr$score(msr("classif.ce"))
    ##                 task  task_id                   learner    learner_id
    ## 1: <TaskClassif[46]> penguins <LearnerClassifRpart[34]> classif.rpart
    ## 2: <TaskClassif[46]> penguins <LearnerClassifRpart[34]> classif.rpart
    ## 3: <TaskClassif[46]> penguins <LearnerClassifRpart[34]> classif.rpart
    ##            resampling resampling_id iteration              prediction
    ## 1: <ResamplingCV[19]>            cv         1 <PredictionClassif[19]>
    ## 2: <ResamplingCV[19]>            cv         2 <PredictionClassif[19]>
    ## 3: <ResamplingCV[19]>            cv         3 <PredictionClassif[19]>
    ##    classif.ce
    ## 1:    0.07826
    ## 2:    0.04348
    ## 3:    0.08772
  • Check for warnings or errors:

    rr$warnings
    ## Empty data.table (0 rows and 2 cols): iteration,msg
    rr$errors
    ## Empty data.table (0 rows and 2 cols): iteration,msg
  • Extract and inspect the resampling splits:

    rr$resampling
    ## <ResamplingCV> with 3 iterations
    ## * Instantiated: TRUE
    ## * Parameters: folds=3
    rr$resampling$iters
    ## [1] 3
    str(rr$resampling$test_set(1))
    ##  int [1:115] 4 8 9 12 13 17 18 22 23 24 ...
    str(rr$resampling$train_set(1))
    ##  int [1:229] 1 6 7 10 11 15 16 20 21 26 ...
  • Retrieve the learner of a specific iteration and inspect it:

    learner = rr$learners[[1]]  # avoid masking the lrn() function
    learner$model
    ## n= 229 
    ## 
    ## node), split, n, loss, yval, (yprob)
    ##       * denotes terminal node
    ## 
    ##  1) root 229 131 Adelie (0.42795 0.21397 0.35808)  
    ##    2) flipper_length< 207.5 141  45 Adelie (0.68085 0.31915 0.00000)  
    ##      4) bill_length< 42.35 91   1 Adelie (0.98901 0.01099 0.00000) *
    ##      5) bill_length>=42.35 50   6 Chinstrap (0.12000 0.88000 0.00000)  
    ##       10) body_mass>=4062 10   4 Adelie (0.60000 0.40000 0.00000) *
    ##       11) body_mass< 4062 40   0 Chinstrap (0.00000 1.00000 0.00000) *
    ##    3) flipper_length>=207.5 88   6 Gentoo (0.02273 0.04545 0.93182)  
    ##      6) bill_depth>=17.05 7   3 Chinstrap (0.28571 0.57143 0.14286) *
    ##      7) bill_depth< 17.05 81   0 Gentoo (0.00000 0.00000 1.00000) *
  • Extract the predictions:

    rr$prediction() # all predictions merged into a single Prediction object
    ## <PredictionClassif> for 344 observations:
    ##     row_ids     truth  response prob.Adelie prob.Chinstrap prob.Gentoo
    ##           4    Adelie    Adelie     0.98901        0.01099     0.00000
    ##           8    Adelie    Adelie     0.98901        0.01099     0.00000
    ##           9    Adelie    Adelie     0.98901        0.01099     0.00000
    ## ---                                                                   
    ##         338 Chinstrap Chinstrap     0.05263        0.92105     0.02632
    ##         340 Chinstrap    Gentoo     0.00000        0.03333     0.96667
    ##         342 Chinstrap Chinstrap     0.05263        0.92105     0.02632
    rr$predictions()[[1]] # prediction of first resampling iteration
    ## <PredictionClassif> for 115 observations:
    ##     row_ids     truth  response prob.Adelie prob.Chinstrap prob.Gentoo
    ##           4    Adelie    Adelie       0.989        0.01099           0
    ##           8    Adelie    Adelie       0.989        0.01099           0
    ##           9    Adelie    Adelie       0.989        0.01099           0
    ## ---                                                                   
    ##         326 Chinstrap Chinstrap       0.000        1.00000           0
    ##         337 Chinstrap Chinstrap       0.000        1.00000           0
    ##         339 Chinstrap Chinstrap       0.000        1.00000           0
  • Filter to only keep specified iterations:

    rr$filter(c(1, 3))
    print(rr)
    ## <ResampleResult> of 2 iterations
    ## * Task: penguins
    ## * Learner: classif.rpart
    ## * Warnings: 0 in 0 iterations
    ## * Errors: 0 in 0 iterations
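The value returned by rr$aggregate() above (0.06982) is the mean of the three per-iteration scores. As far as we know, this "macro" averaging is the default for mlr3 measures. Pooling all predictions first ("micro" averaging) can give a slightly different value when test sets differ in size. The base-R arithmetic below reproduces both numbers. It assumes that the three folds had 115, 115 and 114 test observations, which is consistent with the printed error rates, and that these translate into 9, 5 and 10 misclassified rows.

```r
# Misclassified test rows per iteration (inferred from the scores above)
# and the assumed test-set sizes (115 + 115 + 114 = 344)
n_errors = c(9, 5, 10)
n_test = c(115, 115, 114)

ce = n_errors / n_test
round(ce, 5)  # 0.07826 0.04348 0.08772, matching rr$score()

macro = mean(ce)                     # mean of per-iteration scores
micro = sum(n_errors) / sum(n_test)  # error over all pooled predictions

round(c(macro = macro, micro = micro), 5)  # macro 0.06982, micro 0.06977
```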

2.5.4 Custom resampling

Sometimes it is necessary to perform resampling with custom splits, e.g. to reproduce results reported in a study. A manual resampling instance can be created using the "custom" template. The train and test arguments of instantiate() take lists with one vector of row ids per iteration.

resampling = rsmp("custom")
resampling$instantiate(task,
  train = list(c(1:10, 51:60, 101:110)),
  test = list(c(11:20, 61:70, 111:120))
)
resampling$iters
## [1] 1
resampling$train_set(1)
##  [1]   1   2   3   4   5   6   7   8   9  10  51  52  53  54  55  56  57  58  59
## [20]  60 101 102 103 104 105 106 107 108 109 110
resampling$test_set(1)
##  [1]  11  12  13  14  15  16  17  18  19  20  61  62  63  64  65  66  67  68  69
## [20]  70 111 112 113 114 115 116 117 118 119 120
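A common mistake with hand-made splits is an overlap between training and test rows. The following base-R sketch validates splits given in the same list-of-vectors shape as above, using a hypothetical helper check_custom_splits() that is not part of mlr3. Treating row ids as positions 1..n is a simplifying assumption, since task row ids need not be consecutive.

```r
# Hypothetical validation helper (not part of mlr3): train and test are
# lists with one vector of row ids per iteration.
check_custom_splits = function(train, test, n) {
  stopifnot(length(train) == length(test))
  for (i in seq_along(train)) {
    ids = c(train[[i]], test[[i]])
    if (length(intersect(train[[i]], test[[i]])) > 0)
      stop("iteration ", i, ": training and test sets overlap")
    if (any(ids < 1 | ids > n))
      stop("iteration ", i, ": row id outside 1..", n)
  }
  invisible(TRUE)
}

train = list(c(1:10, 51:60, 101:110))
test = list(c(11:20, 61:70, 111:120))
check_custom_splits(train, test, n = 344)  # passes silently
lengths(train)  # 30 training rows in the single iteration
lengths(test)   # 30 test rows
```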

2.5.5 Resampling with predefined groups

In contrast to the column role "group", which ensures that specific observations always appear together in either the training or the test set, one can also supply a factor variable that predefines all partitions (this is still work in progress in mlr3).

This means that, in each iteration, all observations of one factor level form the test set. Hence, this method does not allow setting the “folds” argument, because the number of folds is determined by the number of factor levels.
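Here is a minimal base-R sketch of this idea, using a hypothetical factor site. Each level defines the test set of one iteration, and the remaining rows form the training set, so the number of iterations equals the number of levels. It illustrates the concept only and does not use any mlr3 API.

```r
# Base-R sketch: one train/test split per level of a grouping factor.
# The factor 'site' is hypothetical example data.
site = factor(c("A", "A", "B", "C", "B", "A", "C", "C"))
ids = seq_along(site)

test_sets = split(ids, site)                          # one test set per level
train_sets = lapply(test_sets, function(te) setdiff(ids, te))

length(test_sets)  # 3 iterations, one per level of site
test_sets$B        # 3 5
train_sets$B       # 1 2 4 6 7 8
```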

This predefined approach was called “blocking” in mlr2. It should not be confused with the term “blocking” in mlr3spatiotempcv, which refers to a category of resampling methods that use square or rectangular partitioning.

2.5.6 Plotting Resample Results

mlr3viz provides an autoplot() method for resample results. To showcase some of the plots, we create a binary classification task with two features, perform a resampling with a 10-fold cross-validation and visualize the results:

task = tsk("pima")
task$select(c("glucose", "mass"))
learner = lrn("classif.rpart", predict_type = "prob")
rr = resample(task, learner, rsmp("cv"), store_models = TRUE)

# boxplot of AUC values across the 10 folds
autoplot(rr, measure = msr("classif.auc"))

# ROC curve, averaged over 10 folds
autoplot(rr, type = "roc")

# learner predictions for first fold
rr$filter(1)
autoplot(rr, type = "prediction")
## Warning: Removed 1 rows containing missing values (geom_point).

All available plot types are listed on the manual page of autoplot.ResampleResult().
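The AUC shown in the boxplot summarizes the ROC curve of each fold. One common way to compute it is the rank-sum (Mann-Whitney) formulation, sketched below in base R. Under this formulation, the AUC is the probability that a randomly chosen positive observation receives a higher predicted probability than a randomly chosen negative one, with ties counting one half. The probabilities and labels are made-up example data, and this textbook formula is not the code behind classif.auc.

```r
# Base-R sketch of the AUC via the rank-sum (Mann-Whitney) formula.
# prob_pos: predicted probabilities of the positive class
# truth_pos: logical vector, TRUE for truly positive observations
auc = function(prob_pos, truth_pos) {
  r = rank(prob_pos)  # average ranks, so ties count one half
  n1 = sum(truth_pos)
  n0 = sum(!truth_pos)
  (sum(r[truth_pos]) - n1 * (n1 + 1) / 2) / (n1 * n0)
}

# Hypothetical predictions and labels for one fold
prob = c(0.9, 0.8, 0.7, 0.6, 0.55, 0.4, 0.3, 0.2)
truth = c(TRUE, TRUE, FALSE, TRUE, FALSE, FALSE, TRUE, FALSE)
auc(prob, truth)  # 0.75: 12 of the 16 positive/negative pairs are ordered correctly
```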

2.5.7 Plotting Resample Partitions

mlr3spatiotempcv provides autoplot() methods to visualize resampling partitions of spatiotemporal datasets. See the function reference and the vignette “Spatiotemporal visualization” for more information.