3.1 Hyperparameter Tuning

Hyperparameters are second-order parameters of machine learning models that, while often not explicitly optimized during the model estimation process, can have an important impact on the outcome and predictive performance of a model. Typically, hyperparameters are fixed before training a model. However, because the output of a model can be sensitive to the specification of hyperparameters, it is often recommended to make an informed decision about which hyperparameter settings may yield better model performance. Hyperparameter settings can be chosen a priori, but it is usually advantageous to try several settings and keep the best-performing one before fitting the final model on the training data. This process is called model ‘tuning.’
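For illustration, a hyperparameter such as the complexity parameter cp of a classification tree can simply be fixed at construction time. This is a minimal sketch; the value 0.05 is an arbitrary choice, and the machinery introduced below automates picking such values:

library("mlr3")
# Fix the complexity parameter manually instead of tuning it
fixed_learner = lrn("classif.rpart", cp = 0.05)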

Hyperparameter tuning is supported via the mlr3tuning extension package.

[Figure: illustration of the hyperparameter tuning process.]

At the heart of mlr3tuning are the following R6 classes:

  • TuningInstanceSingleCrit and TuningInstanceMultiCrit: these two classes describe the tuning problem and store the results.
  • Tuner: this class is the base class for implementations of tuning algorithms.

3.1.1 The TuningInstance* Classes

This subsection examines the optimization of a simple classification tree on the Pima Indian Diabetes data set.

library("mlr3verse")
task = tsk("pima")
print(task)
## <TaskClassif:pima> (768 x 9)
## * Target: diabetes
## * Properties: twoclass
## * Features (8):
##   - dbl (8): age, glucose, insulin, mass, pedigree, pregnant, pressure,
##     triceps

We use the classification tree from rpart and choose a subset of the hyperparameters we want to tune. This is often referred to as the “tuning space.”

learner = lrn("classif.rpart")
learner$param_set
## <ParamSet>
##                 id    class lower upper nlevels        default value
##  1:       minsplit ParamInt     1   Inf     Inf             20      
##  2:      minbucket ParamInt     1   Inf     Inf <NoDefault[3]>      
##  3:             cp ParamDbl     0     1     Inf           0.01      
##  4:     maxcompete ParamInt     0   Inf     Inf              4      
##  5:   maxsurrogate ParamInt     0   Inf     Inf              5      
##  6:       maxdepth ParamInt     1    30      30             30      
##  7:   usesurrogate ParamInt     0     2       3              2      
##  8: surrogatestyle ParamInt     0     1       2              0      
##  9:           xval ParamInt     0   Inf     Inf             10     0
## 10:     keep_model ParamLgl    NA    NA       2          FALSE

Here, we opt to tune two parameters:

  • The complexity cp
  • The termination criterion minsplit

The tuning space needs to be bounded; therefore, one has to set lower and upper bounds:

search_space = ps(
  cp = p_dbl(lower = 0.001, upper = 0.1),
  minsplit = p_int(lower = 1, upper = 10)
)
search_space
## <ParamSet>
##          id    class lower upper nlevels        default value
## 1:       cp ParamDbl 0.001   0.1     Inf <NoDefault[3]>      
## 2: minsplit ParamInt 1.000  10.0      10 <NoDefault[3]>
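
Hyperparameters like cp often act on a multiplicative scale, so it can make sense to search them on a log scale. Below is a minimal sketch using a per-parameter trafo (standard paradox mechanics; the bounds are simply the logs of the ones above):

search_space_log = ps(
  # Search log(cp) uniformly, then map each value back to the original scale
  cp = p_dbl(lower = log(0.001), upper = log(0.1), trafo = function(x) exp(x)),
  minsplit = p_int(lower = 1, upper = 10)
)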

Next, we need to specify how to evaluate the performance. For this, we need to choose a resampling strategy and a performance measure.

hout = rsmp("holdout")
measure = msr("classif.ce")

Finally, one has to select the budget available for solving this tuning instance. This is done by choosing one of the available Terminators:

  • Terminate after a given time (TerminatorClockTime).
  • Terminate after a given number of iterations (TerminatorEvals).
  • Terminate after a specific performance has been reached (TerminatorPerfReached).
  • Terminate when tuning does not improve (TerminatorStagnation).
  • A combination of the above in an ALL or ANY fashion (TerminatorCombo).
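
A few alternative budgets can be constructed as follows (a sketch; the parameter names secs, level and iters belong to the respective Terminators in the underlying bbotk package):

trm("run_time", secs = 300)        # stop after 5 minutes of tuning time
trm("perf_reached", level = 0.25)  # stop once the measure reaches 0.25
trm("stagnation", iters = 10)      # stop if 10 evaluations bring no improvement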

For this short introduction, we specify a budget of 20 evaluations and then put everything together into a TuningInstanceSingleCrit:

library("mlr3tuning")
## Loading required package: paradox
evals20 = trm("evals", n_evals = 20)

instance = TuningInstanceSingleCrit$new(
  task = task,
  learner = learner,
  resampling = hout,
  measure = measure,
  search_space = search_space,
  terminator = evals20
)
instance
## <TuningInstanceSingleCrit>
## * State:  Not optimized
## * Objective: <ObjectiveTuning:classif.rpart_on_pima>
## * Search Space:
## <ParamSet>
##          id    class lower upper nlevels        default value
## 1:       cp ParamDbl 0.001   0.1     Inf <NoDefault[3]>      
## 2: minsplit ParamInt 1.000  10.0      10 <NoDefault[3]>      
## * Terminator: <TerminatorEvals>
## * Terminated: FALSE
## * Archive:
## <ArchiveTuning>
## Null data.table (0 rows and 0 cols)

To start the tuning, we still need to select how the optimization should take place. In other words, we need to choose the optimization algorithm via the Tuner class.

3.1.2 The Tuner Class

The following algorithms are currently implemented in mlr3tuning:

  • Grid Search (TunerGridSearch)
  • Random Search (TunerRandomSearch) (Bergstra and Bengio 2012)
  • Generalized Simulated Annealing (TunerGenSA)
  • Non-linear Optimization (TunerNLoptr)

In this example, we will use a simple grid search with a grid resolution of 5.

tuner = tnr("grid_search", resolution = 5)

Since we have only numeric parameters, TunerGridSearch will create an equidistant grid between the respective upper and lower bounds. As we have two hyperparameters with a resolution of 5, the two-dimensional grid consists of \(5^2 = 25\) configurations. Each configuration serves as a hyperparameter setting for the previously defined Learner which is then fitted on the task using the provided Resampling. All configurations will be examined by the tuner (in a random order), until either all configurations are evaluated or the Terminator signals that the budget is exhausted.
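
To see the candidate configurations up front, the same grid can be generated explicitly with paradox's generate_design_grid() (shown purely for illustration):

# Materialize the 5 x 5 grid over the search space as a data.table
design = generate_design_grid(search_space, resolution = 5)
head(design$data)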

3.1.3 Triggering the Tuning

To start the tuning, we simply pass the TuningInstanceSingleCrit to the $optimize() method of the initialized Tuner. The tuner proceeds as follows:

  1. The Tuner proposes at least one hyperparameter configuration (the Tuner may propose multiple points to improve parallelization, which can be controlled via the setting batch_size).
  2. For each configuration, the given Learner is fitted on the Task using the provided Resampling. All evaluations are stored in the archive of the TuningInstanceSingleCrit.
  3. The Terminator is queried if the budget is exhausted. If the budget is not exhausted, restart with 1) until it is.
  4. Determine the configuration with the best observed performance.
  5. Store the best configuration as the result in the instance object. The best hyperparameter settings ($result_learner_param_vals) and the corresponding measured performance ($result_y) can be accessed from the instance.

tuner$optimize(instance)
## INFO  [09:35:46.507] [bbotk] Starting to optimize 2 parameter(s) with '<OptimizerGridSearch>' and '<TerminatorEvals> [n_evals=20]' 
## INFO  [09:35:46.538] [bbotk] Evaluating 1 configuration(s) 
## INFO  [09:35:46.808] [bbotk] Result of batch 1: 
## INFO  [09:35:46.810] [bbotk]     cp minsplit classif.ce                                uhash 
## INFO  [09:35:46.810] [bbotk]  0.001       10      0.293 7f29c986-707f-41bf-b8e8-6d5306c14736 
## INFO  [09:35:46.811] [bbotk] Evaluating 1 configuration(s) 
## INFO  [09:35:46.885] [bbotk] Result of batch 2: 
## INFO  [09:35:46.887] [bbotk]   cp minsplit classif.ce                                uhash 
## INFO  [09:35:46.887] [bbotk]  0.1        1     0.2852 a2d2d43d-32a6-4fc3-9ee6-beffe5e3ed65 
## INFO  [09:35:46.888] [bbotk] Evaluating 1 configuration(s) 
## INFO  [09:35:46.962] [bbotk] Result of batch 3: 
## INFO  [09:35:46.963] [bbotk]   cp minsplit classif.ce                                uhash 
## INFO  [09:35:46.963] [bbotk]  0.1        3     0.2852 653492d5-fece-498e-936f-a26074e7dc74 
## INFO  [09:35:46.965] [bbotk] Evaluating 1 configuration(s) 
## INFO  [09:35:47.041] [bbotk] Result of batch 4: 
## INFO  [09:35:47.042] [bbotk]     cp minsplit classif.ce                                uhash 
## INFO  [09:35:47.042] [bbotk]  0.001        1     0.2852 0f0b7b99-3e6a-46cd-9534-8b314a8536d1 
## INFO  [09:35:47.044] [bbotk] Evaluating 1 configuration(s) 
## INFO  [09:35:47.125] [bbotk] Result of batch 5: 
## INFO  [09:35:47.126] [bbotk]   cp minsplit classif.ce                                uhash 
## INFO  [09:35:47.126] [bbotk]  0.1        5     0.2852 d945d2a7-f600-4957-80f3-6070972903f1 
## INFO  [09:35:47.128] [bbotk] Evaluating 1 configuration(s) 
## INFO  [09:35:47.201] [bbotk] Result of batch 6: 
## INFO  [09:35:47.203] [bbotk]     cp minsplit classif.ce                                uhash 
## INFO  [09:35:47.203] [bbotk]  0.001        8      0.293 bcc78278-ab93-4781-98b9-5b6550aea1b3 
## INFO  [09:35:47.204] [bbotk] Evaluating 1 configuration(s) 
## INFO  [09:35:47.279] [bbotk] Result of batch 7: 
## INFO  [09:35:47.281] [bbotk]      cp minsplit classif.ce                                uhash 
## INFO  [09:35:47.281] [bbotk]  0.0505       10     0.2852 95cef3a9-265a-42f5-8041-f9577a48c546 
## INFO  [09:35:47.282] [bbotk] Evaluating 1 configuration(s) 
## INFO  [09:35:47.362] [bbotk] Result of batch 8: 
## INFO  [09:35:47.363] [bbotk]       cp minsplit classif.ce                                uhash 
## INFO  [09:35:47.363] [bbotk]  0.07525        8     0.2852 6af1f6f4-c98c-40f7-a427-115b8e3c7ac8 
## INFO  [09:35:47.365] [bbotk] Evaluating 1 configuration(s) 
## INFO  [09:35:47.438] [bbotk] Result of batch 9: 
## INFO  [09:35:47.440] [bbotk]      cp minsplit classif.ce                                uhash 
## INFO  [09:35:47.440] [bbotk]  0.0505        5     0.2852 6a38cd0f-65f6-4a5e-a29e-1775ffbf7d29 
## INFO  [09:35:47.441] [bbotk] Evaluating 1 configuration(s) 
## INFO  [09:35:47.516] [bbotk] Result of batch 10: 
## INFO  [09:35:47.517] [bbotk]       cp minsplit classif.ce                                uhash 
## INFO  [09:35:47.517] [bbotk]  0.07525        1     0.2852 412ad69b-f6d9-4db9-98a9-21040b8aa316 
## INFO  [09:35:47.518] [bbotk] Evaluating 1 configuration(s) 
## INFO  [09:35:47.598] [bbotk] Result of batch 11: 
## INFO  [09:35:47.599] [bbotk]       cp minsplit classif.ce                                uhash 
## INFO  [09:35:47.599] [bbotk]  0.02575        1     0.2852 ea61d7dc-edd1-410d-a9ff-5e6d25adbc35 
## INFO  [09:35:47.601] [bbotk] Evaluating 1 configuration(s) 
## INFO  [09:35:47.674] [bbotk] Result of batch 12: 
## INFO  [09:35:47.676] [bbotk]       cp minsplit classif.ce                                uhash 
## INFO  [09:35:47.676] [bbotk]  0.07525        5     0.2852 6cd4f418-ded1-42bd-803e-e7f1c663ac5c 
## INFO  [09:35:47.677] [bbotk] Evaluating 1 configuration(s) 
## INFO  [09:35:47.751] [bbotk] Result of batch 13: 
## INFO  [09:35:47.753] [bbotk]   cp minsplit classif.ce                                uhash 
## INFO  [09:35:47.753] [bbotk]  0.1       10     0.2852 556d4bae-dc1e-43b3-9a7e-698035dc1ec8 
## INFO  [09:35:47.754] [bbotk] Evaluating 1 configuration(s) 
## INFO  [09:35:47.833] [bbotk] Result of batch 14: 
## INFO  [09:35:47.834] [bbotk]      cp minsplit classif.ce                                uhash 
## INFO  [09:35:47.834] [bbotk]  0.0505        8     0.2852 1b3a3313-1d65-4fd8-978f-c6f9c93446ff 
## INFO  [09:35:47.836] [bbotk] Evaluating 1 configuration(s) 
## INFO  [09:35:47.909] [bbotk] Result of batch 15: 
## INFO  [09:35:47.911] [bbotk]       cp minsplit classif.ce                                uhash 
## INFO  [09:35:47.911] [bbotk]  0.02575       10     0.2852 2d8a6ce5-8944-48a4-a2cb-02187425736b 
## INFO  [09:35:47.912] [bbotk] Evaluating 1 configuration(s) 
## INFO  [09:35:47.987] [bbotk] Result of batch 16: 
## INFO  [09:35:47.988] [bbotk]       cp minsplit classif.ce                                uhash 
## INFO  [09:35:47.988] [bbotk]  0.07525        3     0.2852 55acb10f-69ef-4575-866c-0a7eeae9d493 
## INFO  [09:35:47.996] [bbotk] Evaluating 1 configuration(s) 
## INFO  [09:35:48.069] [bbotk] Result of batch 17: 
## INFO  [09:35:48.071] [bbotk]   cp minsplit classif.ce                                uhash 
## INFO  [09:35:48.071] [bbotk]  0.1        8     0.2852 48055c19-a4b5-40ca-80af-8363671f7871 
## INFO  [09:35:48.072] [bbotk] Evaluating 1 configuration(s) 
## INFO  [09:35:48.146] [bbotk] Result of batch 18: 
## INFO  [09:35:48.147] [bbotk]      cp minsplit classif.ce                                uhash 
## INFO  [09:35:48.147] [bbotk]  0.0505        3     0.2852 f8c77038-884d-49fc-9073-dbef0713c8e1 
## INFO  [09:35:48.149] [bbotk] Evaluating 1 configuration(s) 
## INFO  [09:35:48.229] [bbotk] Result of batch 19: 
## INFO  [09:35:48.231] [bbotk]       cp minsplit classif.ce                                uhash 
## INFO  [09:35:48.231] [bbotk]  0.02575        8     0.2852 64192cdb-fc37-4f9e-bc50-f738971788a0 
## INFO  [09:35:48.232] [bbotk] Evaluating 1 configuration(s) 
## INFO  [09:35:48.305] [bbotk] Result of batch 20: 
## INFO  [09:35:48.306] [bbotk]       cp minsplit classif.ce                                uhash 
## INFO  [09:35:48.306] [bbotk]  0.02575        5     0.2852 ace22a1d-1843-4ce5-ae86-5ec9828fc533 
## INFO  [09:35:48.312] [bbotk] Finished optimizing after 20 evaluation(s) 
## INFO  [09:35:48.313] [bbotk] Result: 
## INFO  [09:35:48.314] [bbotk]   cp minsplit learner_param_vals  x_domain classif.ce 
## INFO  [09:35:48.314] [bbotk]  0.1        1          <list[3]> <list[2]>     0.2852
##     cp minsplit learner_param_vals  x_domain classif.ce
## 1: 0.1        1          <list[3]> <list[2]>     0.2852

instance$result_learner_param_vals
## $xval
## [1] 0
## 
## $cp
## [1] 0.1
## 
## $minsplit
## [1] 1

instance$result_y
## classif.ce 
##     0.2852

One can investigate all resamplings that were undertaken: they are stored in the archive of the TuningInstanceSingleCrit and can be accessed via as.data.table():

as.data.table(instance$archive)
##          cp minsplit classif.ce                                uhash
##  1: 0.00100       10     0.2930 7f29c986-707f-41bf-b8e8-6d5306c14736
##  2: 0.10000        1     0.2852 a2d2d43d-32a6-4fc3-9ee6-beffe5e3ed65
##  3: 0.10000        3     0.2852 653492d5-fece-498e-936f-a26074e7dc74
##  4: 0.00100        1     0.2852 0f0b7b99-3e6a-46cd-9534-8b314a8536d1
##  5: 0.10000        5     0.2852 d945d2a7-f600-4957-80f3-6070972903f1
##  6: 0.00100        8     0.2930 bcc78278-ab93-4781-98b9-5b6550aea1b3
##  7: 0.05050       10     0.2852 95cef3a9-265a-42f5-8041-f9577a48c546
##  8: 0.07525        8     0.2852 6af1f6f4-c98c-40f7-a427-115b8e3c7ac8
##  9: 0.05050        5     0.2852 6a38cd0f-65f6-4a5e-a29e-1775ffbf7d29
## 10: 0.07525        1     0.2852 412ad69b-f6d9-4db9-98a9-21040b8aa316
## 11: 0.02575        1     0.2852 ea61d7dc-edd1-410d-a9ff-5e6d25adbc35
## 12: 0.07525        5     0.2852 6cd4f418-ded1-42bd-803e-e7f1c663ac5c
## 13: 0.10000       10     0.2852 556d4bae-dc1e-43b3-9a7e-698035dc1ec8
## 14: 0.05050        8     0.2852 1b3a3313-1d65-4fd8-978f-c6f9c93446ff
## 15: 0.02575       10     0.2852 2d8a6ce5-8944-48a4-a2cb-02187425736b
## 16: 0.07525        3     0.2852 55acb10f-69ef-4575-866c-0a7eeae9d493
## 17: 0.10000        8     0.2852 48055c19-a4b5-40ca-80af-8363671f7871
## 18: 0.05050        3     0.2852 f8c77038-884d-49fc-9073-dbef0713c8e1
## 19: 0.02575        8     0.2852 64192cdb-fc37-4f9e-bc50-f738971788a0
## 20: 0.02575        5     0.2852 ace22a1d-1843-4ce5-ae86-5ec9828fc533
##               timestamp batch_nr x_domain_cp x_domain_minsplit
##  1: 2021-06-01 09:35:46        1     0.00100                10
##  2: 2021-06-01 09:35:46        2     0.10000                 1
##  3: 2021-06-01 09:35:46        3     0.10000                 3
##  4: 2021-06-01 09:35:47        4     0.00100                 1
##  5: 2021-06-01 09:35:47        5     0.10000                 5
##  6: 2021-06-01 09:35:47        6     0.00100                 8
##  7: 2021-06-01 09:35:47        7     0.05050                10
##  8: 2021-06-01 09:35:47        8     0.07525                 8
##  9: 2021-06-01 09:35:47        9     0.05050                 5
## 10: 2021-06-01 09:35:47       10     0.07525                 1
## 11: 2021-06-01 09:35:47       11     0.02575                 1
## 12: 2021-06-01 09:35:47       12     0.07525                 5
## 13: 2021-06-01 09:35:47       13     0.10000                10
## 14: 2021-06-01 09:35:47       14     0.05050                 8
## 15: 2021-06-01 09:35:47       15     0.02575                10
## 16: 2021-06-01 09:35:47       16     0.07525                 3
## 17: 2021-06-01 09:35:48       17     0.10000                 8
## 18: 2021-06-01 09:35:48       18     0.05050                 3
## 19: 2021-06-01 09:35:48       19     0.02575                 8
## 20: 2021-06-01 09:35:48       20     0.02575                 5
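
Since this is a regular data.table, the archive can be sorted and sliced as usual; for example, to rank the evaluated configurations by misclassification error (plain data.table syntax, nothing tuning-specific):

# Rank all evaluated configurations by their holdout error
as.data.table(instance$archive)[order(classif.ce), .(cp, minsplit, classif.ce)]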

In sum, the grid search evaluated 20 of the 25 configurations of the grid in a random order before the Terminator stopped the tuning.

The associated resampling iterations can be accessed in the BenchmarkResult:

instance$archive$benchmark_result
## <BenchmarkResult> of 20 rows with 20 resampling runs
##  nr task_id    learner_id resampling_id iters warnings errors
##   1    pima classif.rpart       holdout     1        0      0
##   2    pima classif.rpart       holdout     1        0      0
##   3    pima classif.rpart       holdout     1        0      0
##   4    pima classif.rpart       holdout     1        0      0
##   5    pima classif.rpart       holdout     1        0      0
##   6    pima classif.rpart       holdout     1        0      0
##   7    pima classif.rpart       holdout     1        0      0
##   8    pima classif.rpart       holdout     1        0      0
##   9    pima classif.rpart       holdout     1        0      0
##  10    pima classif.rpart       holdout     1        0      0
##  11    pima classif.rpart       holdout     1        0      0
##  12    pima classif.rpart       holdout     1        0      0
##  13    pima classif.rpart       holdout     1        0      0
##  14    pima classif.rpart       holdout     1        0      0
##  15    pima classif.rpart       holdout     1        0      0
##  16    pima classif.rpart       holdout     1        0      0
##  17    pima classif.rpart       holdout     1        0      0
##  18    pima classif.rpart       holdout     1        0      0
##  19    pima classif.rpart       holdout     1        0      0
##  20    pima classif.rpart       holdout     1        0      0

The uhash column links the resampling iterations to the evaluated configurations stored in instance$archive$data. This makes it possible, e.g., to score the included ResampleResults on a different measure.

instance$archive$benchmark_result$score(msr("classif.acc"))
##                                    uhash nr              task task_id
##  1: 7f29c986-707f-41bf-b8e8-6d5306c14736  1 <TaskClassif[46]>    pima
##  2: a2d2d43d-32a6-4fc3-9ee6-beffe5e3ed65  2 <TaskClassif[46]>    pima
##  3: 653492d5-fece-498e-936f-a26074e7dc74  3 <TaskClassif[46]>    pima
##  4: 0f0b7b99-3e6a-46cd-9534-8b314a8536d1  4 <TaskClassif[46]>    pima
##  5: d945d2a7-f600-4957-80f3-6070972903f1  5 <TaskClassif[46]>    pima
##  6: bcc78278-ab93-4781-98b9-5b6550aea1b3  6 <TaskClassif[46]>    pima
##  7: 95cef3a9-265a-42f5-8041-f9577a48c546  7 <TaskClassif[46]>    pima
##  8: 6af1f6f4-c98c-40f7-a427-115b8e3c7ac8  8 <TaskClassif[46]>    pima
##  9: 6a38cd0f-65f6-4a5e-a29e-1775ffbf7d29  9 <TaskClassif[46]>    pima
## 10: 412ad69b-f6d9-4db9-98a9-21040b8aa316 10 <TaskClassif[46]>    pima
## 11: ea61d7dc-edd1-410d-a9ff-5e6d25adbc35 11 <TaskClassif[46]>    pima
## 12: 6cd4f418-ded1-42bd-803e-e7f1c663ac5c 12 <TaskClassif[46]>    pima
## 13: 556d4bae-dc1e-43b3-9a7e-698035dc1ec8 13 <TaskClassif[46]>    pima
## 14: 1b3a3313-1d65-4fd8-978f-c6f9c93446ff 14 <TaskClassif[46]>    pima
## 15: 2d8a6ce5-8944-48a4-a2cb-02187425736b 15 <TaskClassif[46]>    pima
## 16: 55acb10f-69ef-4575-866c-0a7eeae9d493 16 <TaskClassif[46]>    pima
## 17: 48055c19-a4b5-40ca-80af-8363671f7871 17 <TaskClassif[46]>    pima
## 18: f8c77038-884d-49fc-9073-dbef0713c8e1 18 <TaskClassif[46]>    pima
## 19: 64192cdb-fc37-4f9e-bc50-f738971788a0 19 <TaskClassif[46]>    pima
## 20: ace22a1d-1843-4ce5-ae86-5ec9828fc533 20 <TaskClassif[46]>    pima
##                       learner    learner_id              resampling
##  1: <LearnerClassifRpart[34]> classif.rpart <ResamplingHoldout[19]>
##  2: <LearnerClassifRpart[34]> classif.rpart <ResamplingHoldout[19]>
##  3: <LearnerClassifRpart[34]> classif.rpart <ResamplingHoldout[19]>
##  4: <LearnerClassifRpart[34]> classif.rpart <ResamplingHoldout[19]>
##  5: <LearnerClassifRpart[34]> classif.rpart <ResamplingHoldout[19]>
##  6: <LearnerClassifRpart[34]> classif.rpart <ResamplingHoldout[19]>
##  7: <LearnerClassifRpart[34]> classif.rpart <ResamplingHoldout[19]>
##  8: <LearnerClassifRpart[34]> classif.rpart <ResamplingHoldout[19]>
##  9: <LearnerClassifRpart[34]> classif.rpart <ResamplingHoldout[19]>
## 10: <LearnerClassifRpart[34]> classif.rpart <ResamplingHoldout[19]>
## 11: <LearnerClassifRpart[34]> classif.rpart <ResamplingHoldout[19]>
## 12: <LearnerClassifRpart[34]> classif.rpart <ResamplingHoldout[19]>
## 13: <LearnerClassifRpart[34]> classif.rpart <ResamplingHoldout[19]>
## 14: <LearnerClassifRpart[34]> classif.rpart <ResamplingHoldout[19]>
## 15: <LearnerClassifRpart[34]> classif.rpart <ResamplingHoldout[19]>
## 16: <LearnerClassifRpart[34]> classif.rpart <ResamplingHoldout[19]>
## 17: <LearnerClassifRpart[34]> classif.rpart <ResamplingHoldout[19]>
## 18: <LearnerClassifRpart[34]> classif.rpart <ResamplingHoldout[19]>
## 19: <LearnerClassifRpart[34]> classif.rpart <ResamplingHoldout[19]>
## 20: <LearnerClassifRpart[34]> classif.rpart <ResamplingHoldout[19]>
##     resampling_id iteration              prediction classif.acc
##  1:       holdout         1 <PredictionClassif[19]>      0.7070
##  2:       holdout         1 <PredictionClassif[19]>      0.7148
##  3:       holdout         1 <PredictionClassif[19]>      0.7148
##  4:       holdout         1 <PredictionClassif[19]>      0.7148
##  5:       holdout         1 <PredictionClassif[19]>      0.7148
##  6:       holdout         1 <PredictionClassif[19]>      0.7070
##  7:       holdout         1 <PredictionClassif[19]>      0.7148
##  8:       holdout         1 <PredictionClassif[19]>      0.7148
##  9:       holdout         1 <PredictionClassif[19]>      0.7148
## 10:       holdout         1 <PredictionClassif[19]>      0.7148
## 11:       holdout         1 <PredictionClassif[19]>      0.7148
## 12:       holdout         1 <PredictionClassif[19]>      0.7148
## 13:       holdout         1 <PredictionClassif[19]>      0.7148
## 14:       holdout         1 <PredictionClassif[19]>      0.7148
## 15:       holdout         1 <PredictionClassif[19]>      0.7148
## 16:       holdout         1 <PredictionClassif[19]>      0.7148
## 17:       holdout         1 <PredictionClassif[19]>      0.7148
## 18:       holdout         1 <PredictionClassif[19]>      0.7148
## 19:       holdout         1 <PredictionClassif[19]>      0.7148
## 20:       holdout         1 <PredictionClassif[19]>      0.7148

Now we can take the previously created Learner, set the returned hyperparameters and train it on the full dataset.

learner$param_set$values = instance$result_learner_param_vals
learner$train(task)

The trained model can now be used to make predictions on external data. Note that predicting on observations present in the task should be avoided: the model has already seen these observations during tuning, so the results would be statistically biased and the resulting performance measure over-optimistic. To get statistically unbiased performance estimates for the current task, nested resampling is required.
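As a minimal sketch of the prediction step, assuming external_data is a hypothetical data.frame containing the same eight feature columns as the pima task:

# 'external_data' is a hypothetical data.frame of new observations with the
# columns age, glucose, insulin, mass, pedigree, pregnant, pressure, triceps
prediction = learner$predict_newdata(external_data)
prediction$response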

3.1.4 Automating the Tuning

The AutoTuner wraps a learner and augments it with automatic tuning for a given set of hyperparameters. Because the AutoTuner itself inherits from the Learner base class, it can be used like any other learner. Analogously to the previous subsection, we create a new classification tree learner that automatically tunes the parameters cp and minsplit using an inner resampling (holdout). We create a terminator which allows 10 evaluations and use a simple random search as the tuning algorithm:

learner = lrn("classif.rpart")
search_space = ps(
  cp = p_dbl(lower = 0.001, upper = 0.1),
  minsplit = p_int(lower = 1, upper = 10)
)
terminator = trm("evals", n_evals = 10)
tuner = tnr("random_search")

at = AutoTuner$new(
  learner = learner,
  resampling = rsmp("holdout"),
  measure = msr("classif.ce"),
  search_space = search_space,
  terminator = terminator,
  tuner = tuner
)
at
## <AutoTuner:classif.rpart.tuned>
## * Model: -
## * Parameters: list()
## * Packages: rpart
## * Predict Type: response
## * Feature types: logical, integer, numeric, factor, ordered
## * Properties: importance, missings, multiclass, selected_features,
##   twoclass, weights

We can now use the learner like any other learner, calling the $train() and $predict() methods.

at$train(task)
## INFO  [09:35:48.706] [bbotk] Starting to optimize 2 parameter(s) with '<OptimizerRandomSearch>' and '<TerminatorEvals> [n_evals=10]' 
## INFO  [09:35:48.721] [bbotk] Evaluating 1 configuration(s) 
## INFO  [09:35:48.792] [bbotk] Result of batch 1: 
## INFO  [09:35:48.793] [bbotk]       cp minsplit classif.ce                                uhash 
## INFO  [09:35:48.793] [bbotk]  0.04267        4     0.2344 210d7e04-d5bb-4d36-a88c-030a2c708d5d 
## INFO  [09:35:48.797] [bbotk] Evaluating 1 configuration(s) 
## INFO  [09:35:48.883] [bbotk] Result of batch 2: 
## INFO  [09:35:48.884] [bbotk]       cp minsplit classif.ce                                uhash 
## INFO  [09:35:48.884] [bbotk]  0.04667        9     0.2344 16096ad4-078e-42dc-bcd9-7d98e52b6a0b 
## INFO  [09:35:48.888] [bbotk] Evaluating 1 configuration(s) 
## INFO  [09:35:48.962] [bbotk] Result of batch 3: 
## INFO  [09:35:48.963] [bbotk]       cp minsplit classif.ce                                uhash 
## INFO  [09:35:48.963] [bbotk]  0.06371        6     0.2344 78a8f6d6-8277-4fb7-b0d5-44ffb34906d7 
## INFO  [09:35:48.967] [bbotk] Evaluating 1 configuration(s) 
## INFO  [09:35:49.040] [bbotk] Result of batch 4: 
## INFO  [09:35:49.041] [bbotk]       cp minsplit classif.ce                                uhash 
## INFO  [09:35:49.041] [bbotk]  0.05232        3     0.2344 1f9d4087-bbc7-4f32-8916-e8ad9a09bba2 
## INFO  [09:35:49.045] [bbotk] Evaluating 1 configuration(s) 
## INFO  [09:35:49.117] [bbotk] Result of batch 5: 
## INFO  [09:35:49.119] [bbotk]       cp minsplit classif.ce                                uhash 
## INFO  [09:35:49.119] [bbotk]  0.04077        2     0.2344 94f9608f-e49d-48bc-b7ee-24f131f9b15b 
## INFO  [09:35:49.122] [bbotk] Evaluating 1 configuration(s) 
## INFO  [09:35:49.204] [bbotk] Result of batch 6: 
## INFO  [09:35:49.206] [bbotk]       cp minsplit classif.ce                                uhash 
## INFO  [09:35:49.206] [bbotk]  0.03277        2     0.2344 bc78ab00-74f7-471c-a06e-3c9be449829e 
## INFO  [09:35:49.209] [bbotk] Evaluating 1 configuration(s) 
## INFO  [09:35:49.282] [bbotk] Result of batch 7: 
## INFO  [09:35:49.284] [bbotk]       cp minsplit classif.ce                                uhash 
## INFO  [09:35:49.284] [bbotk]  0.07266        3     0.2344 61a6e92e-59ba-40bb-94ce-9288fb4899cd 
## INFO  [09:35:49.287] [bbotk] Evaluating 1 configuration(s) 
## INFO  [09:35:49.384] [bbotk] Result of batch 8: 
## INFO  [09:35:49.386] [bbotk]       cp minsplit classif.ce                                uhash 
## INFO  [09:35:49.386] [bbotk]  0.06448        9     0.2344 a269d255-e032-40e7-a4b9-8f0415ed3168 
## INFO  [09:35:49.389] [bbotk] Evaluating 1 configuration(s) 
## INFO  [09:35:49.463] [bbotk] Result of batch 9: 
## INFO  [09:35:49.465] [bbotk]       cp minsplit classif.ce                                uhash 
## INFO  [09:35:49.465] [bbotk]  0.03403       10     0.2344 a4e0f533-7290-41c5-9a82-25f1ccbbf197 
## INFO  [09:35:49.468] [bbotk] Evaluating 1 configuration(s) 
## INFO  [09:35:49.542] [bbotk] Result of batch 10: 
## INFO  [09:35:49.543] [bbotk]       cp minsplit classif.ce                                uhash 
## INFO  [09:35:49.543] [bbotk]  0.07853        9     0.2344 96c44ffb-5904-41e6-8cc7-a697455b59f0 
## INFO  [09:35:49.551] [bbotk] Finished optimizing after 10 evaluation(s) 
## INFO  [09:35:49.552] [bbotk] Result: 
## INFO  [09:35:49.553] [bbotk]       cp minsplit learner_param_vals  x_domain classif.ce 
## INFO  [09:35:49.553] [bbotk]  0.04267        4          <list[3]> <list[2]>     0.2344

We can also pass it to resample() and benchmark(); doing so results in nested resampling, which is discussed in the next chapter.
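A minimal sketch of this usage (the outer 3-fold cross-validation is an arbitrary choice for illustration):

# Outer resampling: each fold triggers a full inner tuning run inside the AutoTuner
rr = resample(task, at, rsmp("cv", folds = 3), store_models = TRUE)
rr$aggregate(msr("classif.ce"))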

References

Bergstra, James, and Yoshua Bengio. 2012. “Random Search for Hyper-Parameter Optimization.” Journal of Machine Learning Research 13: 281–305.