2.5 Resampling
Resampling strategies are commonly used to assess the performance of a learning algorithm. mlr3 ships with the following predefined resampling strategies: cross-validation ("cv"), leave-one-out cross-validation ("loo"), repeated cross-validation ("repeated_cv"), bootstrapping ("bootstrap"), subsampling ("subsampling"), holdout ("holdout"), in-sample resampling ("insample"), and custom resampling ("custom").
The following sections provide guidance on how to set and select a resampling strategy and how to subsequently instantiate the resampling process.
Below you can find a graphical illustration of the resampling process.
2.5.1 Settings
In this example we use the iris task and a simple classification tree from the rpart package.
= tsk("iris")
task = lrn("classif.rpart") learner
When performing resampling with a dataset, we first need to define which approach should be used.
mlr3 resampling strategies and their parameters can be queried by looking at the data.table output of the mlr_resamplings dictionary:
as.data.table(mlr_resamplings)
## key params iters
## 1: bootstrap repeats,ratio 30
## 2: custom 0
## 3: cv folds 10
## 4: holdout ratio 1
## 5: insample 1
## 6: loo NA
## 7: repeated_cv repeats,folds 100
## 8: subsampling repeats,ratio 30
Additional resampling methods for special use cases will be available via extension packages, such as mlr3spatiotemporal for spatial data (still in development).
The model fit conducted in the train/predict/score chapter is equivalent to a “holdout resampling,” so let’s consider this one first.
Again, we can retrieve elements from the dictionary mlr_resamplings via $get() or with the convenience function rsmp():
= rsmp("holdout")
resampling print(resampling)
## <ResamplingHoldout> with 1 iterations
## * Instantiated: FALSE
## * Parameters: ratio=0.6667
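The dictionary lookup mentioned above is equivalent; for example:

mlr_resamplings$get("holdout")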
Note that the $is_instantiated field is set to FALSE. This means that we have not applied the strategy to a dataset yet; this happens in the next section, Instantiation.
By default we get a 2/3 train and 1/3 test split of the data (ratio = 0.6667). There are two ways in which the ratio can be changed:
- Overwriting the slot in $param_set$values using a named list:

resampling$param_set$values = list(ratio = 0.8)
- Specifying the resampling parameters directly during construction:
rsmp("holdout", ratio = 0.8)
## <ResamplingHoldout> with 1 iterations
## * Instantiated: FALSE
## * Parameters: ratio=0.8
2.5.2 Instantiation
So far we have just set the stage and selected the resampling strategy. To actually perform the splitting and obtain indices for the training and test sets, the resampling needs a Task. By calling the method instantiate(), we split the indices of the data into indices for training and test sets. These resulting indices are stored in the Resampling object:
= rsmp("cv", folds = 3L)
resampling $instantiate(task)
resampling$iters resampling
## [1] 3
str(resampling$train_set(1))
## int [1:100] 4 5 7 9 18 20 31 32 33 37 ...
str(resampling$test_set(1))
## int [1:50] 1 6 8 11 12 21 22 23 24 26 ...
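As a quick sanity check, the train and test indices of any given fold are disjoint; a minimal sketch:

# the intersection of train and test indices of fold 1 is empty
intersect(resampling$train_set(1), resampling$test_set(1)) # integer(0)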
2.5.3 Execution
With a Task, a Learner, and a Resampling object, we can call resample(), which fits the learner to the task at hand according to the given resampling strategy. This in turn creates a ResampleResult object. Before we go into more detail, let's change the resampling to a 3-fold cross-validation to better illustrate what operations are possible with a ResampleResult. Additionally, when actually fitting the models, we tell resample() to keep the fitted models by setting the store_models option to TRUE:
= tsk("pima")
task = lrn("classif.rpart", maxdepth = 3, predict_type = "prob")
learner = rsmp("cv", folds = 3L)
resampling
= resample(task, learner, resampling, store_models = TRUE)
rr print(rr)
## <ResampleResult> of 3 iterations
## * Task: pima
## * Learner: classif.rpart
## * Warnings: 0 in 0 iterations
## * Errors: 0 in 0 iterations
The following operations are supported with ResampleResult objects:
Calculate the average performance across all resampling iterations:
rr$aggregate(msr("classif.ce"))
## classif.ce
## 0.2839
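Several measures can be aggregated at once by passing a list of measures; a sketch using the msrs() helper (classif.acc and classif.auc are standard measures shipped with mlr3, and the AUC is only available because we set predict_type = "prob" above):

rr$aggregate(msrs(c("classif.acc", "classif.auc")))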
Extract the performance for the individual resampling iterations:
rr$score(msr("classif.ce"))
## task task_id learner learner_id
## 1: <TaskClassif[45]> pima <LearnerClassifRpart[34]> classif.rpart
## 2: <TaskClassif[45]> pima <LearnerClassifRpart[34]> classif.rpart
## 3: <TaskClassif[45]> pima <LearnerClassifRpart[34]> classif.rpart
## resampling resampling_id iteration prediction
## 1: <ResamplingCV[19]> cv 1 <PredictionClassif[19]>
## 2: <ResamplingCV[19]> cv 2 <PredictionClassif[19]>
## 3: <ResamplingCV[19]> cv 3 <PredictionClassif[19]>
## classif.ce
## 1: 0.2656
## 2: 0.2500
## 3: 0.3359
Check for warnings or errors:
rr$warnings
## Empty data.table (0 rows and 2 cols): iteration,msg
rr$errors
## Empty data.table (0 rows and 2 cols): iteration,msg
Extract and inspect the resampling splits:
rr$resampling
## <ResamplingCV> with 3 iterations
## * Instantiated: TRUE
## * Parameters: folds=3
rr$resampling$iters
## [1] 3
str(rr$resampling$test_set(1))
## int [1:256] 2 4 5 7 10 13 15 21 22 24 ...
str(rr$resampling$train_set(1))
## int [1:512] 6 17 18 23 28 29 30 33 35 38 ...
Retrieve the learner of a specific iteration and inspect it:
lrn = rr$learners[[1]]
lrn$model
## n= 512
##
## node), split, n, loss, yval, (yprob)
## * denotes terminal node
##
## 1) root 512 175 neg (0.34180 0.65820)
## 2) glucose>=123.5 208 86 pos (0.58654 0.41346)
## 4) glucose>=154.5 83 18 pos (0.78313 0.21687) *
## 5) glucose< 154.5 125 57 neg (0.45600 0.54400)
## 10) mass>=26.3 107 51 pos (0.52336 0.47664) *
## 11) mass< 26.3 18 1 neg (0.05556 0.94444) *
## 3) glucose< 123.5 304 53 neg (0.17434 0.82566) *
Extract the predictions:
rr$prediction() # all predictions merged into a single Prediction
## <PredictionClassif> for 768 observations:
## row_id truth response prob.pos prob.neg
## 2 neg neg 0.1743 0.8257
## 4 neg neg 0.1743 0.8257
## 5 pos pos 0.5234 0.4766
## ---
## 764 neg neg 0.1823 0.8177
## 767 pos neg 0.1823 0.8177
## 768 neg neg 0.1823 0.8177
rr$predictions()[[1]] # prediction of first resampling iteration
## <PredictionClassif> for 256 observations:
## row_id truth response prob.pos prob.neg
## 2 neg neg 0.1743 0.8257
## 4 neg neg 0.1743 0.8257
## 5 pos pos 0.5234 0.4766
## ---
## 760 pos pos 0.7831 0.2169
## 761 neg neg 0.1743 0.8257
## 766 neg neg 0.1743 0.8257
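Each per-iteration prediction can also be scored on its own; a minimal sketch:

# AUC of the first fold (requires predict_type = "prob")
rr$predictions()[[1]]$score(msr("classif.auc"))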
Note that if you want to compare multiple learners in a fair manner, it is important to ensure that each learner operates on the same resampling instance. This can be achieved by manually instantiating the resampling before fitting the models, as sketched below.
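A minimal sketch of this pattern (classif.featureless, the baseline learner shipped with mlr3, serves as the second learner here):

resampling = rsmp("cv", folds = 3L)
resampling$instantiate(task) # fix the splits once
rr_rpart = resample(task, lrn("classif.rpart"), resampling)
rr_featureless = resample(task, lrn("classif.featureless"), resampling)
# both results are now based on identical train/test splits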
Hint: If your aim is to compare different Tasks, Learners, or Resamplings, you are better off using the benchmark() function, which is covered in the next section on benchmarking. It is a wrapper around resample(), simplifying the handling of large comparison grids.
If you discover this only after you have run multiple resample() calls, don't worry: you can combine multiple ResampleResult objects into a BenchmarkResult (also explained in the section on benchmarking), as sketched below.
2.5.4 Custom resampling
Sometimes it is necessary to perform resampling with custom splits. If you need this because you are coming from a specific modeling field, first take a look at the mlr3 extension packages to check whether your resampling method has already been implemented. If this is not the case, feel welcome to extend an existing package or create your own extension package.
A manual resampling instance can be created using the "custom" template:
= rsmp("custom")
resampling $instantiate(task,
resamplingtrain = list(c(1:10, 51:60, 101:110)),
test = list(c(11:20, 61:70, 111:120))
)$iters resampling
## [1] 1
resampling$train_set(1)
## [1] 1 2 3 4 5 6 7 8 9 10 51 52 53 54 55 56 57 58 59
## [20] 60 101 102 103 104 105 106 107 108 109 110
resampling$test_set(1)
## [1] 11 12 13 14 15 16 17 18 19 20 61 62 63 64 65 66 67 68 69
## [20] 70 111 112 113 114 115 116 117 118 119 120
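The instantiated custom resampling can then be passed to resample() just like any predefined strategy; a sketch:

rr_custom = resample(task, learner, resampling)
rr_custom$aggregate(msr("classif.ce"))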
2.5.5 Plotting Resample Results
Again, mlr3viz provides an autoplot() method:
library("mlr3viz")
autoplot(rr)
autoplot(rr, type = "roc")
All available plot types are listed on the manual page of autoplot.ResampleResult().