3.1 Hyperparameter Tuning
Hyperparameters are second-order parameters of machine learning models that, while often not explicitly optimized during the model estimation process, can have an important impact on the outcome and predictive performance of a model. Typically, hyperparameters are fixed before training a model. However, because the output of a model can be sensitive to their specification, it is often recommended to make an informed decision about which hyperparameter settings may yield better model performance. Although hyperparameter settings are often chosen a priori, it can be advantageous to try several settings before fitting the model on the training data. This process is often called model ‘tuning’.
Hyperparameter tuning is supported via the mlr3tuning extension package.
At the heart of mlr3tuning are the R6 classes:
- TuningInstanceSingleCrit, TuningInstanceMultiCrit: These two classes describe the tuning problem and store the results.
- Tuner: This class is the base class for implementations of tuning algorithms.
3.1.1 The TuningInstance Classes
The following sub-section examines the optimization of a simple classification tree on the Pima Indian Diabetes data set.
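The task print below can be reproduced along these lines, using the tsk() sugar function to retrieve the task from mlr3's task dictionary:
task = tsk("pima")
print(task)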
## <TaskClassif:pima> (768 x 9)
## * Target: diabetes
## * Properties: twoclass
## * Features (8):
## - dbl (8): age, glucose, insulin, mass, pedigree, pregnant, pressure,
## triceps
We use the classification tree from rpart and choose a subset of the hyperparameters we want to tune. This is often referred to as the “tuning space”.
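The learner and the printout of its full hyperparameter set, shown below, can presumably be obtained like this:
learner = lrn("classif.rpart")
learner$param_set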
## <ParamSet>
## id class lower upper levels default value
## 1: minsplit ParamInt 1 Inf 20
## 2: minbucket ParamInt 1 Inf <NoDefault[3]>
## 3: cp ParamDbl 0 1 0.01
## 4: maxcompete ParamInt 0 Inf 4
## 5: maxsurrogate ParamInt 0 Inf 5
## 6: maxdepth ParamInt 1 30 30
## 7: usesurrogate ParamInt 0 2 2
## 8: surrogatestyle ParamInt 0 1 0
## 9: xval ParamInt 0 Inf 10 0
## 10: keep_model ParamLgl NA NA TRUE,FALSE FALSE
Here, we opt to tune two parameters:
- The complexity cp
- The termination criterion minsplit
The tuning space needs to be bounded; therefore, one has to set lower and upper bounds:
library("paradox")
tune_ps = ParamSet$new(list(
  ParamDbl$new("cp", lower = 0.001, upper = 0.1),
  ParamInt$new("minsplit", lower = 1, upper = 10)
))
tune_ps
## <ParamSet>
## id class lower upper levels default value
## 1: cp ParamDbl 0.001 0.1 <NoDefault[3]>
## 2: minsplit ParamInt 1.000 10.0 <NoDefault[3]>
Next, we need to specify how to evaluate the performance.
For this, we need to choose a resampling strategy and a performance measure.
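For example, a holdout resampling and the classification error, matching the hout and measure objects used below:
hout = rsmp("holdout")
measure = msr("classif.ce")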
Finally, one has to select the available budget for solving this tuning instance.
This is done by selecting one of the available Terminators:
- Terminate after a given time (TerminatorClockTime)
- Terminate after a given amount of iterations (TerminatorEvals)
- Terminate after a specific performance is reached (TerminatorPerfReached)
- Terminate when tuning does not improve (TerminatorStagnation)
- A combination of the above in an ALL or ANY fashion (TerminatorCombo)
For this short introduction, we specify a budget of 20 evaluations and then put everything together into a TuningInstanceSingleCrit:
library("mlr3tuning")
evals20 = trm("evals", n_evals = 20)
instance = TuningInstanceSingleCrit$new(
  task = task,
  learner = learner,
  resampling = hout,
  measure = measure,
  search_space = tune_ps,
  terminator = evals20
)
instance
## <TuningInstanceSingleCrit>
## * State: Not optimized
## * Objective: <ObjectiveTuning:classif.rpart_on_pima>
## * Search Space:
## <ParamSet>
## id class lower upper levels default value
## 1: cp ParamDbl 0.001 0.1 <NoDefault[3]>
## 2: minsplit ParamInt 1.000 10.0 <NoDefault[3]>
## * Terminator: <TerminatorEvals>
## * Terminated: FALSE
## * Archive:
## <ArchiveTuning>
## Null data.table (0 rows and 0 cols)
To start the tuning, we still need to select how the optimization should take place.
In other words, we need to choose the optimization algorithm via the Tuner class.
3.1.2 The Tuner Class
The following algorithms are currently implemented in mlr3tuning:
- Grid Search (TunerGridSearch)
- Random Search (TunerRandomSearch) (Bergstra and Bengio 2012)
- Generalized Simulated Annealing (TunerGenSA)
- Non-Linear Optimization (TunerNLoptr)
In this example, we will use a simple grid search with a grid resolution of 5.
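Such a tuner can be constructed via the tnr() shortcut:
# grid search over the search space, 5 points per parameter
tuner = tnr("grid_search", resolution = 5)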
Since we have only numeric parameters, TunerGridSearch will create an equidistant grid between the respective upper and lower bounds.
As we have two hyperparameters with a resolution of 5, the two-dimensional grid consists of \(5^2 = 25\) configurations.
Each configuration serves as a hyperparameter setting for the previously defined Learner and triggers an evaluation with the chosen resampling (here, a holdout) on the task.
All configurations will be examined by the tuner (in a random order), until either all configurations are evaluated or the Terminator signals that the budget is exhausted.
3.1.3 Triggering the Tuning
To start the tuning, we simply pass the TuningInstanceSingleCrit to the $optimize() method of the initialized Tuner.
The tuner proceeds as follows:
1. The Tuner proposes at least one hyperparameter configuration (the Tuner may propose multiple points to improve parallelization, which can be controlled via the setting batch_size).
2. For each configuration, the given Learner is fitted on the Task using the provided Resampling. All evaluations are stored in the archive of the TuningInstanceSingleCrit.
3. The Terminator is queried if the budget is exhausted. If the budget is not exhausted, restart with 1) until it is.
4. Determine the configuration with the best observed performance.
5. Store the best configuration as the result in the instance object.
The best hyperparameter settings ($result_learner_param_vals) and the corresponding measured performance ($result_y) can be accessed from the instance.
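The log below stems from starting the optimization, presumably via:
tuner$optimize(instance)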
## INFO [13:49:38.512] Starting to optimize 2 parameter(s) with '<OptimizerGridSearch>' and '<TerminatorEvals>'
## INFO [13:49:38.554] Evaluating 1 configuration(s)
## INFO [13:49:38.812] Result of batch 1:
## INFO [13:49:38.815] cp minsplit classif.ce uhash
## INFO [13:49:38.815] 0.1 5 0.2734 4aa261bb-9ece-4eb6-ab3a-c544618d40f2
## INFO [13:49:38.818] Evaluating 1 configuration(s)
## INFO [13:49:39.085] Result of batch 2:
## INFO [13:49:39.088] cp minsplit classif.ce uhash
## INFO [13:49:39.088] 0.001 8 0.2891 03d02715-8995-40c8-acd3-f483cd8a6890
## INFO [13:49:39.091] Evaluating 1 configuration(s)
## INFO [13:49:39.205] Result of batch 3:
## INFO [13:49:39.207] cp minsplit classif.ce uhash
## INFO [13:49:39.207] 0.1 1 0.2734 ce0fb101-9d80-45b9-bb01-0cdfd12ca0a0
## INFO [13:49:39.210] Evaluating 1 configuration(s)
## INFO [13:49:39.308] Result of batch 4:
## INFO [13:49:39.311] cp minsplit classif.ce uhash
## INFO [13:49:39.311] 0.001 3 0.2969 5d3780a8-3b43-42c7-9a5e-590ce0743334
## INFO [13:49:39.314] Evaluating 1 configuration(s)
## INFO [13:49:39.414] Result of batch 5:
## INFO [13:49:39.417] cp minsplit classif.ce uhash
## INFO [13:49:39.417] 0.07525 5 0.2461 1066e59f-c28a-488d-a10b-9257f559d248
## INFO [13:49:39.420] Evaluating 1 configuration(s)
## INFO [13:49:39.516] Result of batch 6:
## INFO [13:49:39.518] cp minsplit classif.ce uhash
## INFO [13:49:39.518] 0.0505 5 0.2461 467e6051-926c-407c-a631-3ea3ca710073
## INFO [13:49:39.521] Evaluating 1 configuration(s)
## INFO [13:49:39.621] Result of batch 7:
## INFO [13:49:39.624] cp minsplit classif.ce uhash
## INFO [13:49:39.624] 0.07525 3 0.2461 6673e5fe-4694-4abc-9f83-8ce9b517fd86
## INFO [13:49:39.626] Evaluating 1 configuration(s)
## INFO [13:49:39.721] Result of batch 8:
## INFO [13:49:39.723] cp minsplit classif.ce uhash
## INFO [13:49:39.723] 0.02575 3 0.2617 2984bede-9eba-49f4-a1a9-5f5c5cf955b9
## INFO [13:49:39.726] Evaluating 1 configuration(s)
## INFO [13:49:39.833] Result of batch 9:
## INFO [13:49:39.835] cp minsplit classif.ce uhash
## INFO [13:49:39.835] 0.0505 10 0.2461 93df4fdb-45f9-4fa6-bd9f-867bacd8a4e0
## INFO [13:49:39.838] Evaluating 1 configuration(s)
## INFO [13:49:39.933] Result of batch 10:
## INFO [13:49:39.935] cp minsplit classif.ce uhash
## INFO [13:49:39.935] 0.1 8 0.2734 67b606f7-2994-4206-8872-b0200b479326
## INFO [13:49:39.938] Evaluating 1 configuration(s)
## INFO [13:49:40.039] Result of batch 11:
## INFO [13:49:40.042] cp minsplit classif.ce uhash
## INFO [13:49:40.042] 0.02575 5 0.2617 1fbe2c4f-6b40-4348-8cf1-2b8b6dfb0f18
## INFO [13:49:40.044] Evaluating 1 configuration(s)
## INFO [13:49:40.140] Result of batch 12:
## INFO [13:49:40.142] cp minsplit classif.ce uhash
## INFO [13:49:40.142] 0.02575 1 0.2617 ac37ae4c-5c58-4d12-9198-0b54e43b3578
## INFO [13:49:40.145] Evaluating 1 configuration(s)
## INFO [13:49:40.246] Result of batch 13:
## INFO [13:49:40.249] cp minsplit classif.ce uhash
## INFO [13:49:40.249] 0.02575 10 0.2617 317ae8c9-efd3-449b-bef7-568c029e9696
## INFO [13:49:40.251] Evaluating 1 configuration(s)
## INFO [13:49:40.347] Result of batch 14:
## INFO [13:49:40.349] cp minsplit classif.ce uhash
## INFO [13:49:40.349] 0.07525 8 0.2461 f7858924-8100-4886-adf8-e88095af40b9
## INFO [13:49:40.352] Evaluating 1 configuration(s)
## INFO [13:49:40.461] Result of batch 15:
## INFO [13:49:40.463] cp minsplit classif.ce uhash
## INFO [13:49:40.463] 0.0505 1 0.2461 f88de77f-e995-4106-884e-e7924a9dff55
## INFO [13:49:40.466] Evaluating 1 configuration(s)
## INFO [13:49:40.561] Result of batch 16:
## INFO [13:49:40.564] cp minsplit classif.ce uhash
## INFO [13:49:40.564] 0.1 3 0.2734 545e30d0-6a93-45e8-a080-cc22553a7b91
## INFO [13:49:40.566] Evaluating 1 configuration(s)
## INFO [13:49:40.668] Result of batch 17:
## INFO [13:49:40.670] cp minsplit classif.ce uhash
## INFO [13:49:40.670] 0.1 10 0.2734 9316df85-b526-4082-8e04-701c880ad675
## INFO [13:49:40.673] Evaluating 1 configuration(s)
## INFO [13:49:40.770] Result of batch 18:
## INFO [13:49:40.772] cp minsplit classif.ce uhash
## INFO [13:49:40.772] 0.001 1 0.3008 a62ad22e-681d-4956-b3ef-c8ee1796c77e
## INFO [13:49:40.775] Evaluating 1 configuration(s)
## INFO [13:49:40.877] Result of batch 19:
## INFO [13:49:40.879] cp minsplit classif.ce uhash
## INFO [13:49:40.879] 0.07525 1 0.2461 4f2bbd3e-977a-4d3c-8868-9785fbbfafe3
## INFO [13:49:40.882] Evaluating 1 configuration(s)
## INFO [13:49:40.977] Result of batch 20:
## INFO [13:49:40.980] cp minsplit classif.ce uhash
## INFO [13:49:40.980] 0.001 10 0.2969 27e08957-095f-4b46-941d-99d96b4388a9
## INFO [13:49:40.987] Finished optimizing after 20 evaluation(s)
## INFO [13:49:40.989] Result:
## INFO [13:49:40.991] cp minsplit learner_param_vals x_domain classif.ce
## INFO [13:49:40.991] 0.07525 5 <list[3]> <list[2]> 0.2461
## cp minsplit learner_param_vals x_domain classif.ce
## 1: 0.07525 5 <list[3]> <list[2]> 0.2461
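The tuned settings printed next are accessed via:
instance$result_learner_param_vals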
## $xval
## [1] 0
##
## $cp
## [1] 0.07525
##
## $minsplit
## [1] 5
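and the measured performance via:
instance$result_y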
## classif.ce
## 0.2461
One can investigate all resamplings which were undertaken, as they are stored in the archive of the TuningInstanceSingleCrit and can be accessed through the $data() method:
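instance$archive$data()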
## cp minsplit classif.ce uhash x_domain
## 1: 0.10000 5 0.2734 4aa261bb-9ece-4eb6-ab3a-c544618d40f2 <list[2]>
## 2: 0.00100 8 0.2891 03d02715-8995-40c8-acd3-f483cd8a6890 <list[2]>
## 3: 0.10000 1 0.2734 ce0fb101-9d80-45b9-bb01-0cdfd12ca0a0 <list[2]>
## 4: 0.00100 3 0.2969 5d3780a8-3b43-42c7-9a5e-590ce0743334 <list[2]>
## 5: 0.07525 5 0.2461 1066e59f-c28a-488d-a10b-9257f559d248 <list[2]>
## 6: 0.05050 5 0.2461 467e6051-926c-407c-a631-3ea3ca710073 <list[2]>
## 7: 0.07525 3 0.2461 6673e5fe-4694-4abc-9f83-8ce9b517fd86 <list[2]>
## 8: 0.02575 3 0.2617 2984bede-9eba-49f4-a1a9-5f5c5cf955b9 <list[2]>
## 9: 0.05050 10 0.2461 93df4fdb-45f9-4fa6-bd9f-867bacd8a4e0 <list[2]>
## 10: 0.10000 8 0.2734 67b606f7-2994-4206-8872-b0200b479326 <list[2]>
## 11: 0.02575 5 0.2617 1fbe2c4f-6b40-4348-8cf1-2b8b6dfb0f18 <list[2]>
## 12: 0.02575 1 0.2617 ac37ae4c-5c58-4d12-9198-0b54e43b3578 <list[2]>
## 13: 0.02575 10 0.2617 317ae8c9-efd3-449b-bef7-568c029e9696 <list[2]>
## 14: 0.07525 8 0.2461 f7858924-8100-4886-adf8-e88095af40b9 <list[2]>
## 15: 0.05050 1 0.2461 f88de77f-e995-4106-884e-e7924a9dff55 <list[2]>
## 16: 0.10000 3 0.2734 545e30d0-6a93-45e8-a080-cc22553a7b91 <list[2]>
## 17: 0.10000 10 0.2734 9316df85-b526-4082-8e04-701c880ad675 <list[2]>
## 18: 0.00100 1 0.3008 a62ad22e-681d-4956-b3ef-c8ee1796c77e <list[2]>
## 19: 0.07525 1 0.2461 4f2bbd3e-977a-4d3c-8868-9785fbbfafe3 <list[2]>
## 20: 0.00100 10 0.2969 27e08957-095f-4b46-941d-99d96b4388a9 <list[2]>
## timestamp batch_nr
## 1: 2020-10-12 13:49:38 1
## 2: 2020-10-12 13:49:39 2
## 3: 2020-10-12 13:49:39 3
## 4: 2020-10-12 13:49:39 4
## 5: 2020-10-12 13:49:39 5
## 6: 2020-10-12 13:49:39 6
## 7: 2020-10-12 13:49:39 7
## 8: 2020-10-12 13:49:39 8
## 9: 2020-10-12 13:49:39 9
## 10: 2020-10-12 13:49:39 10
## 11: 2020-10-12 13:49:40 11
## 12: 2020-10-12 13:49:40 12
## 13: 2020-10-12 13:49:40 13
## 14: 2020-10-12 13:49:40 14
## 15: 2020-10-12 13:49:40 15
## 16: 2020-10-12 13:49:40 16
## 17: 2020-10-12 13:49:40 17
## 18: 2020-10-12 13:49:40 18
## 19: 2020-10-12 13:49:40 19
## 20: 2020-10-12 13:49:40 20
In sum, the grid search evaluated 20 of the 25 configurations of the grid in a random order before the Terminator stopped the tuning.
The associated resampling iterations can be accessed in the BenchmarkResult:
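A sketch of the accessor, assuming the archive exposes the resampling iterations as a benchmark_result field:
# assumption: the archive stores the resampling iterations as $benchmark_result
instance$archive$benchmark_result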
## <ResultData>
## Public:
## as_data_table: function (view = NULL, reassemble_learners = TRUE, convert_predictions = TRUE,
## clone: function (deep = FALSE)
## combine: function (rdata)
## data: list
## initialize: function (data = NULL)
## iterations: function (view = NULL)
## learners: function (view = NULL, states = TRUE, reassemble = TRUE)
## logs: function (view = NULL, condition)
## prediction: function (view = NULL, predict_sets = "test")
## predictions: function (view = NULL, predict_sets = "test")
## resamplings: function (view = NULL)
## sweep: function ()
## task_type: active binding
## tasks: function (view = NULL, reassemble = TRUE)
## uhashes: function (view = NULL)
## Private:
## deep_clone: function (name, value)
## get_view_index: function (view)
The uhash column links the resampling iterations to the evaluated configurations stored in instance$archive$data(). This allows, for example, scoring the included ResampleResults on a different measure.
Now we can take the previously created Learner, set the returned hyperparameters, and train it on the full dataset.
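A minimal sketch of this step, reusing the learner and task objects from above:
learner$param_set$values = instance$result_learner_param_vals
learner$train(task)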
The trained model can now be used to make a prediction on external data.
Note that predicting on observations present in the task should be avoided.
The model has seen these observations already during tuning and therefore results would be statistically biased.
Hence, the resulting performance measure would be over-optimistic.
Instead, to get statistically unbiased performance estimates for the current task, nested resampling is required.
3.1.4 Automating the Tuning
The AutoTuner wraps a learner and augments it with an automatic tuning for a given set of hyperparameters.
Because the AutoTuner itself inherits from the Learner base class, it can be used like any other learner.
Analogously to the previous subsection, a new classification tree learner is created.
This classification tree learner automatically tunes the parameters cp and minsplit using an inner resampling (holdout).
We create a terminator which allows 10 evaluations, and use a simple random search as tuning algorithm:
library("paradox")
library("mlr3tuning")
learner = lrn("classif.rpart")
tune_ps = ParamSet$new(list(
  ParamDbl$new("cp", lower = 0.001, upper = 0.1),
  ParamInt$new("minsplit", lower = 1, upper = 10)
))
terminator = trm("evals", n_evals = 10)
tuner = tnr("random_search")
at = AutoTuner$new(
  learner = learner,
  resampling = rsmp("holdout"),
  measure = msr("classif.ce"),
  search_space = tune_ps,
  terminator = terminator,
  tuner = tuner
)
at
## <AutoTuner:classif.rpart.tuned>
## * Model: -
## * Parameters: xval=0
## * Packages: rpart
## * Predict Type: response
## * Feature types: logical, integer, numeric, factor, ordered
## * Properties: importance, missings, multiclass, selected_features,
## twoclass, weights
We can now use the learner like any other learner, calling the $train() and $predict() methods.
This time however, we pass it to benchmark() to compare the tuner to a classification tree without tuning.
This way, the AutoTuner will do its resampling for tuning on the training set of the respective split of the outer resampling.
The learner then undertakes predictions using the test set of the outer resampling.
This yields unbiased performance measures, as the observations in the test set have not been used during tuning or fitting of the respective learner.
This is called nested resampling.
To compare the tuned learner with the learner that uses default values, we can use benchmark():
grid = benchmark_grid(
  task = tsk("pima"),
  learner = list(at, lrn("classif.rpart")),
  resampling = rsmp("cv", folds = 3)
)
# avoid console output from mlr3tuning
logger = lgr::get_logger("bbotk")
logger$set_threshold("warn")
bmr = benchmark(grid)
bmr$aggregate(msrs(c("classif.ce", "time_train")))
## nr resample_result task_id learner_id resampling_id iters
## 1: 1 <ResampleResult[21]> pima classif.rpart.tuned cv 3
## 2: 2 <ResampleResult[21]> pima classif.rpart cv 3
## classif.ce time_train
## 1: 0.2695 0
## 2: 0.2409 0
Note that we do not expect any differences compared to the non-tuned approach for multiple reasons:
- the task is too easy
- the task is rather small, and thus prone to overfitting
- the tuning budget (10 evaluations) is small
- rpart does not benefit that much from tuning
References
Bergstra, James, and Yoshua Bengio. 2012. “Random Search for Hyper-Parameter Optimization.” J. Mach. Learn. Res. 13: 281–305.