2 Basics

This chapter will teach you the essential building blocks of mlr3, as well as its R6 classes and operations used for machine learning. A typical machine learning workflow looks like this:

The data, which mlr3 encapsulates in tasks, is split into non-overlapping training and test sets. Since we are interested in models that generalize to new data rather than just memorizing the training data, the separate test data allows us to objectively evaluate how well a model generalizes. The training data is given to a machine learning algorithm, which we call a learner in mlr3. The learner uses the training data to build a model of the relationship between the input features and the output target values. This model is then used to produce predictions on the test data, which are compared to the ground truth values to assess the quality of the model. mlr3 offers a number of different measures to quantify how well a model performs based on the difference between predicted and actual values. Usually this measure is a numeric score.

The process of splitting up data into training and test sets, building a model, and evaluating it may be repeated several times, resampling different training and test sets from the original data each time. Multiple resampling iterations allow us to get a better, more generalizable performance estimate for a particular type of model as it is tested under different conditions and less likely to get lucky or unlucky because of a particular way the data was resampled.
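As a rough preview of this workflow in code, the following sketch uses functions that are introduced in detail later in this chapter. The choice of the iris task, the rpart learner, 3-fold cross-validation, the classification error measure, and the seed are assumptions made purely for illustration.

```r
library(mlr3)

# the data is encapsulated in a task, the algorithm in a learner
task = tsk("iris")
learner = lrn("classif.rpart")

# repeat the train/test procedure on 3 different splits (cross-validation)
set.seed(1) # arbitrary seed, only to make the splits reproducible
resampling = rsmp("cv", folds = 3)
rr = resample(task, learner, resampling)

# aggregate the per-split scores into one numeric performance estimate
score = rr$aggregate(msr("classif.ce"))
print(score)
```

Each of the objects created here (task, learner, resampling, measure) is an R6 object and is discussed in its own section below.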

In many cases, this simple workflow is not sufficient to deal with real-world data, which may require normalization, imputation of missing values, or feature selection. Later in the book, we cover more complex workflows that support these steps and more.

This chapter covers the following subtopics:

Tasks

Tasks encapsulate the data with meta-information, such as the name of the prediction target column. We cover how to:

Learners

Learners encapsulate machine learning algorithms to train models and make predictions for a task. They are provided by R and other packages. We cover how to:

How to modify and extend learners is covered in a supplemental advanced technical section.

Train and predict

The section on the train and predict methods illustrates how to use tasks and learners to train a model and make predictions on a new data set. In particular, we cover how to:

Resampling

A resampling is a method to create training and test splits. We cover how to

Additional information on resampling can be found in the section about nested resampling and in the chapter on model optimization.

Benchmarking

Benchmarking is used to compare the performance of different models, for example models trained with different learners, on different tasks, or with different resampling methods. We cover how to

Binary classification

Binary classification is a special case of classification where the target variable to predict has only two possible values. In this case, additional considerations apply; in particular:

  • ROC curves and the threshold where to predict one class versus the other, and
  • threshold tuning (WIP).

Before we get into the details of how to use mlr3 for machine learning, we give a brief introduction to R6 as it is a relatively new part of R. mlr3 heavily relies on R6 and all basic building blocks it provides are R6 classes:

2.1 Quick R6 Intro for Beginners

R6 is one of R’s more recent dialects for object-oriented programming (OO). It addresses shortcomings of earlier OO implementations in R, such as S3, which we used in mlr. If you have done any object-oriented programming before, R6 should feel familiar. We focus on the parts of R6 that you need to know to use mlr3 here.

  • Objects are created by calling the constructor of an R6::R6Class() object, specifically the initialization method $new(). For example, foo = Foo$new(bar = 1) creates a new object of class Foo, setting the bar argument of the constructor to the value 1. Most objects in mlr3 are created through special functions (e.g. lrn("regr.rpart")) that are also referred to as sugar functions.
  • Objects have mutable state, which is encapsulated in their fields, which can be accessed through the dollar operator. We can access the bar value in the Foo class through foo$bar and set its value by assigning the field, e.g. foo$bar = 2.
  • In addition to fields, objects expose methods that allow you to inspect the object’s state, retrieve information, or perform an action that may change the internal state of the object. For example, the $train method of a learner changes the internal state of the learner by building and storing a trained model, which can then be used to make predictions given data.
  • Objects can have public and private fields and methods. The public fields and methods define the API to interact with the object. Private methods are only relevant for you if you want to extend mlr3, e.g. with new learners.
  • R6 objects are internally environments, and as such have reference semantics. For example, foo2 = foo does not create a copy of foo in foo2, but another reference to the same actual object. Setting foo$bar = 3 will also change foo2$bar to 3 and vice versa.
  • To copy an object, use the $clone() method and the deep = TRUE argument for nested objects, for example, foo2 = foo$clone(deep = TRUE).

For more details on R6, have a look at the excellent R6 vignettes, especially the introduction.
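The points above can be tried out with a small sketch. The class Foo, its field bar, and the method double_bar() are hypothetical names used only for illustration; they are not part of mlr3.

```r
library(R6)

# hypothetical class with one public field and one method
Foo = R6Class("Foo",
  public = list(
    bar = NULL,
    initialize = function(bar) {
      self$bar = bar
    },
    # a method that changes the internal state of the object
    double_bar = function() {
      self$bar = self$bar * 2
      invisible(self)
    }
  )
)

foo = Foo$new(bar = 1) # construct an object
foo$bar = 2            # set a field via the dollar operator
foo$double_bar()       # bar is now 4

foo2 = foo             # another reference to the same object
foo3 = foo$clone()     # an independent copy

foo$bar = 3
foo2$bar               # 3, because foo2 and foo are the same object
foo3$bar               # 4, because the clone is unaffected
```

For objects whose fields contain other R6 objects, clone(deep = TRUE) is needed to copy the nested objects as well.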

2.2 Tasks

Tasks are objects that contain the (usually tabular) data and additional meta-data to define a machine learning problem. The meta-data is, for example, the name of the target variable for supervised machine learning problems, or the type of the dataset (e.g. spatial or survival). This information is used for specific operations that can be performed on a task.

2.2.1 Task Types

To create a task from a data.frame(), data.table() or Matrix(), you first need to select the right task type:

2.2.2 Task Creation

As an example, we will create a regression task using the mtcars data set from the package datasets and predict the numeric target variable "mpg" (miles per gallon). We only consider the first two features in the dataset for brevity.

First, we load and prepare the data.

data("mtcars", package = "datasets")
data = mtcars[, 1:3]
str(data)
## 'data.frame':    32 obs. of  3 variables:
##  $ mpg : num  21 21 22.8 21.4 18.7 18.1 14.3 24.4 22.8 19.2 ...
##  $ cyl : num  6 6 4 6 8 6 8 4 4 6 ...
##  $ disp: num  160 160 108 258 360 ...

Next, we create a regression task, i.e. we construct a new instance of the R6 class TaskRegr. Usually, this is done by calling the constructor TaskRegr$new(). Instead, we are calling the converter as_task_regr() to convert our data.frame() stored as data to a task and provide the following information:

  1. x: Object to convert. Works for data.frame()/data.table()/tibble() objects, which are stored internally in a DataBackendDataTable, and for abstract data backends of class DataBackend. The latter make it possible to connect to out-of-memory storage systems like SQL servers via the extension package mlr3db.
  2. target: The name of the target column for the regression problem.
  3. id (optional): An arbitrary identifier for the task, used in plots and summaries. If not provided, the deparsed and substituted name of x will be used.
library("mlr3")

task_mtcars = as_task_regr(data, target = "mpg", id = "cars")
print(task_mtcars)
## <TaskRegr:cars> (32 x 3)
## * Target: mpg
## * Properties: -
## * Features (2):
##   - dbl (2): cyl, disp

The print() method gives a short summary of the task: It has 32 observations and 3 columns, of which 2 are features.
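For comparison, the direct constructor call mentioned above can be sketched as follows. This assumes the id, backend and target arguments of TaskRegr$new(); the object names task_conv and task_new are chosen only for this example.

```r
library(mlr3)

data = datasets::mtcars[, 1:3]

# converter, as used in the text
task_conv = as_task_regr(data, target = "mpg", id = "cars")

# direct call of the R6 constructor; the data is passed as `backend`
task_new = TaskRegr$new(id = "cars", backend = data, target = "mpg")

# both tasks describe the same regression problem
task_new$target_names
task_new$feature_names
task_new$nrow
```

The converter is usually more convenient because it also accepts other inputs, such as existing tasks or data backends.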

We can also plot the task using the mlr3viz package, which gives a graphical summary of its properties:

library("mlr3viz")
autoplot(task_mtcars, type = "pairs")

Note that it is often more convenient to load the mlr3verse package than to load all the extension packages individually. mlr3verse imports most mlr3 packages and re-exports functions which are used for regular machine learning and data science tasks.

2.2.3 Predefined tasks

mlr3 ships with a few predefined machine learning tasks. All tasks are stored in an R6 Dictionary (a key-value store) named mlr_tasks. Printing it gives the keys (the names of the datasets):

mlr_tasks
## <DictionaryTask> with 11 stored values
## Keys: boston_housing, breast_cancer, german_credit, iris, mtcars,
##   penguins, pima, sonar, spam, wine, zoo

We can get a more informative summary of the example tasks by converting the dictionary to a data.table() object:

as.data.table(mlr_tasks)
##                key task_type nrow ncol properties lgl int dbl chr fct ord pxc
##  1: boston_housing      regr  506   19              0   3  13   0   2   0   0
##  2:  breast_cancer   classif  683   10   twoclass   0   0   0   0   0   9   0
##  3:  german_credit   classif 1000   21   twoclass   0   3   0   0  14   3   0
##  4:           iris   classif  150    5 multiclass   0   0   4   0   0   0   0
##  5:         mtcars      regr   32   11              0   0  10   0   0   0   0
##  6:       penguins   classif  344    8 multiclass   0   3   2   0   2   0   0
##  7:           pima   classif  768    9   twoclass   0   0   8   0   0   0   0
##  8:          sonar   classif  208   61   twoclass   0   0  60   0   0   0   0
##  9:           spam   classif 4601   58   twoclass   0   0  57   0   0   0   0
## 10:           wine   classif  178   14 multiclass   0   2  11   0   0   0   0
## 11:            zoo   classif  101   17 multiclass  15   1   0   0   0   0   0

In the above display, the columns "lgl" (logical), "int" (integer), "dbl" (double), "chr" (character), "fct" (factor), "ord" (ordered factor) and "pxc" (POSIXct time) display the number of features in the dataset with the corresponding storage type.
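Because the result is an ordinary table, it can be filtered like any other. The following sketch lists the keys of the predefined regression tasks; the variable names tab and regr_keys are chosen only for this example.

```r
library(mlr3)

tab = as.data.table(mlr_tasks)

# keep only the rows describing regression tasks and extract their keys
regr_keys = tab[tab$task_type == "regr", ]$key
print(regr_keys)
```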

To get a task from the dictionary, one can use the $get() method of the mlr_tasks dictionary and assign the return value to a new object. Since mlr3 arranges most of its object instances in dictionaries and extraction is such a common task, there is a shortcut for this: the function tsk(). Here, we retrieve the palmer penguins task originating from the package palmerpenguins:

task_penguins = tsk("penguins")
print(task_penguins)
## <TaskClassif:penguins> (344 x 8)
## * Target: species
## * Properties: multiclass
## * Features (7):
##   - int (3): body_mass, flipper_length, year
##   - dbl (2): bill_depth, bill_length
##   - fct (2): island, sex
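The shortcut and the dictionary method can be compared side by side. This is a minimal sketch; the variable names task_a and task_b are chosen only for this example.

```r
library(mlr3)

# explicit retrieval from the dictionary
task_a = mlr_tasks$get("penguins")

# the same retrieval using the shortcut function
task_b = tsk("penguins")

# both calls return a task with the same id and dimensions
task_a$id
task_b$id
c(task_a$nrow, task_b$nrow)

# all keys currently stored in the dictionary
mlr_tasks$keys()
```

Each call constructs a fresh task object, so modifying task_a does not affect task_b.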

Note that dictionaries such as mlr_tasks can get populated by extension packages. E.g., mlr3data comes with some more example and toy tasks for regression and classification, and mlr3proba ships with additional survival and density estimation tasks. Both packages will get loaded once we load the mlr3verse package, so we do it here and have a look at the available tasks again:

library("mlr3verse")
as.data.table(mlr_tasks)[, 1:4]
##                key task_type  nrow ncol
##  1:           actg      surv  1151   13
##  2:   bike_sharing      regr 17379   14
##  3: boston_housing      regr   506   19
##  4:  breast_cancer   classif   683   10
##  5:       faithful      dens   272    1
##  6:           gbcs      surv   686   10
##  7:  german_credit   classif  1000   21
##  8:          grace      surv  1000    8
##  9:           ilpd   classif   583   11
## 10:           iris   classif   150    5
## 11:     kc_housing      regr 21613   20
## 12:           lung      surv   228   10
## 13:      moneyball      regr  1232   15
## 14:         mtcars      regr    32   11
## 15:      optdigits   classif  5620   65
## 16:       penguins   classif   344    8
## 17:           pima   classif   768    9
## 18:         precip      dens    70    1
## 19:           rats      surv   300    5
## 20:          sonar   classif   208   61
## 21:           spam   classif  4601   58
## 22:        titanic   classif  1309   11
## 23:   unemployment      surv  3343    6
## 24:      usarrests     clust    50    4
## 25:           whas      surv   481   11
## 26:           wine   classif   178   14
## 27:            zoo   classif   101   17
##                key task_type  nrow ncol

To get more information about the respective task, the corresponding man page can be found under mlr_tasks_[id], e.g. mlr_tasks_german_credit.

2.2.4 Task API

All task properties and characteristics can be queried using the task’s public fields and methods (see Task). Methods can also be used to change the stored data and the behavior of the task.

2.2.4.1 Retrieving Data

The data stored in a task can be retrieved directly from fields, for example:

task_mtcars
## <TaskRegr:cars> (32 x 3)
## * Target: mpg
## * Properties: -
## * Features (2):
##   - dbl (2): cyl, disp
task_mtcars$nrow
## [1] 32
task_mtcars$ncol
## [1] 3

More information can be obtained through methods of the object, for example:

task_mtcars$data()
##      mpg cyl  disp
##  1: 21.0   6 160.0
##  2: 21.0   6 160.0
##  3: 22.8   4 108.0
##  4: 21.4   6 258.0
##  5: 18.7   8 360.0
##  6: 18.1   6 225.0
##  7: 14.3   8 360.0
##  8: 24.4   4 146.7
##  9: 22.8   4 140.8
## 10: 19.2   6 167.6
## 11: 17.8   6 167.6
## 12: 16.4   8 275.8
## 13: 17.3   8 275.8
## 14: 15.2   8 275.8
## 15: 10.4   8 472.0
## 16: 10.4   8 460.0
## 17: 14.7   8 440.0
## 18: 32.4   4  78.7
## 19: 30.4   4  75.7
## 20: 33.9   4  71.1
## 21: 21.5   4 120.1
## 22: 15.5   8 318.0
## 23: 15.2   8 304.0
## 24: 13.3   8 350.0
## 25: 19.2   8 400.0
## 26: 27.3   4  79.0
## 27: 26.0   4 120.3
## 28: 30.4   4  95.1
## 29: 15.8   8 351.0
## 30: 19.7   6 145.0
## 31: 15.0   8 301.0
## 32: 21.4   4 121.0
##      mpg cyl  disp

In mlr3, each row (observation) has a unique identifier, stored as an integer(). These can be passed as arguments to the $data() method to select specific rows:

head(task_mtcars$row_ids)
## [1] 1 2 3 4 5 6
# retrieve data for rows with ids 1, 5, and 10
task_mtcars$data(rows = c(1, 5, 10))
##     mpg cyl  disp
## 1: 21.0   6 160.0
## 2: 18.7   8 360.0
## 3: 19.2   6 167.6

Note that although the row ids are typically just the sequence from 1 to nrow(data), they are only guaranteed to be unique natural numbers. Keep that in mind, especially if you work with data stored in a real database management system (see backends).
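A small sketch of why this matters: after a task has been subset, its row ids no longer form the sequence from 1 to the number of rows. The example uses the predefined mtcars task and the $filter() method, which is described later in this section.

```r
library(mlr3)

task = tsk("mtcars")

# keep only the rows with ids 2, 5 and 10
task$filter(c(2, 5, 10))

# the task now has 3 rows, but their ids are not 1, 2, 3
task$nrow
ids = task$row_ids
print(ids)

# always select rows by their ids, not by their position
first_row = task$data(rows = ids[1])
print(first_row)
```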

Similarly to row ids, target and feature columns also have unique identifiers, i.e. names (stored as character()). Their names can be accessed via the public slots $feature_names and $target_names. Here, “target” refers to the variable we want to predict and “feature” to the predictors for the task.

task_mtcars$feature_names
## [1] "cyl"  "disp"
task_mtcars$target_names
## [1] "mpg"

The row_ids and column names can be combined when selecting a subset of the data:

# retrieve data for rows 1, 5, and 10 and only select column "mpg"
task_mtcars$data(rows = c(1, 5, 10), cols = "mpg")
##     mpg
## 1: 21.0
## 2: 18.7
## 3: 19.2

To extract the complete data from the task, one can also simply convert it to a data.table:

summary(as.data.table(task_mtcars))
##       mpg            cyl            disp      
##  Min.   :10.4   Min.   :4.00   Min.   : 71.1  
##  1st Qu.:15.4   1st Qu.:4.00   1st Qu.:120.8  
##  Median :19.2   Median :6.00   Median :196.3  
##  Mean   :20.1   Mean   :6.19   Mean   :230.7  
##  3rd Qu.:22.8   3rd Qu.:8.00   3rd Qu.:326.0  
##  Max.   :33.9   Max.   :8.00   Max.   :472.0

2.2.4.2 Roles (Rows and Columns)

It is possible to assign different roles to rows and columns. These roles affect the behavior of the task for different operations. We have already seen this for the target and feature columns, which serve different purposes.

For example, the previously-constructed mtcars task has the following column roles:

print(task_mtcars$col_roles)
## $feature
## [1] "cyl"  "disp"
## 
## $target
## [1] "mpg"
## 
## $name
## character(0)
## 
## $order
## character(0)
## 
## $stratum
## character(0)
## 
## $group
## character(0)
## 
## $weight
## character(0)

Columns can also have no role (they are ignored) or have multiple roles. To add the row names of mtcars as an additional feature, we first add them to the data table as a regular column and then recreate the task with the new column.

# with `keep.rownames`, data.table stores the row names in an extra column "rn"
data = as.data.table(datasets::mtcars[, 1:3], keep.rownames = TRUE)
task_mtcars = as_task_regr(data, target = "mpg", id = "cars")

# there is a new feature called "rn"
task_mtcars$feature_names
## [1] "cyl"  "disp" "rn"

The row names are now a feature whose values are stored in the column "rn". We include this column here for educational purposes only. Generally speaking, there is no point in having a feature that uniquely identifies each row. Furthermore, the character data type will cause problems with many types of machine learning algorithms.

On the other hand, the identifier may be useful to label points in plots, for example to identify and label outliers. Therefore we will change the role of the rn column by removing it from the list of features and assigning it the new role "name". There are two ways to do this:

  1. Use the Task method $set_col_roles() (recommended).
  2. Simply modify the field $col_roles, which is a named list of vectors of column names. Each vector in this list corresponds to a column role, and the column names contained in that vector are designated as having that role.

Supported column roles can be found in the manual of Task, or just by printing the names of the field $col_roles:

# supported column roles, see ?Task
names(task_mtcars$col_roles)
## [1] "feature" "target"  "name"    "order"   "stratum" "group"   "weight"
# assign column "rn" the role "name", remove from other roles
task_mtcars$set_col_roles("rn", roles = "name")

# note that "rn" is not listed as a feature anymore
task_mtcars$feature_names
## [1] "cyl"  "disp"
# "rn" also does not appear anymore when we access the data
task_mtcars$data(rows = 1:2)
##    mpg cyl disp
## 1:  21   6  160
## 2:  21   6  160
task_mtcars$head(2)
##    mpg cyl disp
## 1:  21   6  160
## 2:  21   6  160

Changing the role does not change the underlying data, it just updates the view on it. The data is not copied in the code above. The view is changed in-place though, i.e. the task object itself is modified.
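The second way listed above, modifying the field $col_roles directly, can be sketched as follows. The task is recreated here so that the example is self-contained; the intermediate variable roles is chosen only for this example.

```r
library(mlr3)

# recreate the task with the row names stored in column "rn"
data = data.table::as.data.table(datasets::mtcars[, 1:3], keep.rownames = TRUE)
task = as_task_regr(data, target = "mpg", id = "cars")

# retrieve the named list of roles, modify it, and write it back
roles = task$col_roles
roles$feature = setdiff(roles$feature, "rn") # remove "rn" from the features
roles$name = "rn"                            # assign "rn" the role "name"
task$col_roles = roles

task$feature_names
task$col_roles$name
```

Unlike $set_col_roles(), this approach requires removing the column from its old role by hand, which is one reason the method is the recommended way.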

Just like columns, it is also possible to assign different roles to rows.

Rows can have two different roles:

  1. Role use: Rows that are generally available for model fitting (although they may also be used as test set in resampling). This role is the default role.

  2. Role validation: Rows that are not used for training. Rows that have missing values in the target column during task creation are automatically set to the validation role.

There are several reasons to hold some observations back or treat them differently:

  1. It is often good practice to validate the final model on an external validation set to identify possible overfitting.
  2. Some observations may be unlabeled, e.g. in competitions like Kaggle.

These observations cannot be used for training a model, but can be used to get predictions.

2.2.4.3 Task Mutators

As shown above, modifying $col_roles or $row_roles (either via $set_col_roles()/$set_row_roles() or directly by modifying the named list) changes the view on the data. The additional convenience method $filter() subsets the current view based on row ids and $select() subsets the view based on feature names.

task_penguins = tsk("penguins")
task_penguins$select(c("body_mass", "flipper_length")) # keep only these features
task_penguins$filter(1:3) # keep only these rows
task_penguins$head()
##    species body_mass flipper_length
## 1:  Adelie      3750            181
## 2:  Adelie      3800            186
## 3:  Adelie      3250            195

While the methods discussed above subset the data, the methods $rbind() and $cbind() add extra rows and columns to a task. Again, the original data is not changed. The additional rows or columns are only added to the view of the data.

task_penguins$cbind(data.frame(letters = letters[1:3])) # add column "letters"
task_penguins$head()
##    species body_mass flipper_length letters
## 1:  Adelie      3750            181       a
## 2:  Adelie      3800            186       b
## 3:  Adelie      3250            195       c

2.2.5 Plotting Tasks

The mlr3viz package provides plotting facilities for many classes implemented in mlr3. The available plot types depend on the inherited class, but all plots are returned as ggplot2 objects which can be easily customized.

For classification tasks (inheriting from TaskClassif), see the documentation of mlr3viz::autoplot.TaskClassif for the implemented plot types. Here are some examples to get an impression:

library("mlr3viz")

# get the pima indians task
task = tsk("pima")

# subset task to only use the 3 first features
task$select(head(task$feature_names, 3))

# default plot: class frequencies
autoplot(task)
# pairs plot (requires package GGally)
autoplot(task, type = "pairs")
# duo plot (requires package GGally)
autoplot(task, type = "duo")

Of course, you can do the same for regression tasks (inheriting from TaskRegr) as documented in mlr3viz::autoplot.TaskRegr:

library("mlr3viz")

# get the complete mtcars task
task = tsk("mtcars")

# subset task to only use the 3 first features
task$select(head(task$feature_names, 3))

# default plot: boxplot of target variable
autoplot(task)
# pairs plot (requires package GGally)
autoplot(task, type = "pairs")

2.3 Learners

Objects of class Learner provide a unified interface to many popular machine learning algorithms in R. They consist of methods to train and predict a model for a Task and provide meta-information about the learners, such as the hyperparameters you can set.

The base class of each learner is Learner, specialized for regression as LearnerRegr and for classification as LearnerClassif. Extension packages inherit from the Learner base class, e.g. mlr3proba::LearnerSurv or mlr3cluster::LearnerClust. In contrast to the Task, the creation of a custom Learner is usually not required and a more advanced topic. Hence, we refer the reader to Section 6.1 and proceed with an overview of the interface of already implemented learners.

All Learners work in a two-stage procedure:
  • training step: The training data (features and target) is passed to the Learner’s $train() method, which trains and stores a model, i.e. the relationship between the target and the features.
  • predict step: A new slice of data, the inference data, is passed to the $predict() method of the Learner. The model trained in the first step is used to predict the missing target values, e.g. labels for classification problems or the numerical outcome for regression problems.

2.3.1 Predefined Learners

The mlr3 package ships with the following minimal set of classification and regression learners to avoid unnecessary dependencies:

This set of baseline learners is usually insufficient for a real data analysis. Thus, we have cherry-picked one implementation of each of the most popular machine learning methods and connected them in the mlr3learners package:

  • Linear and logistic regression
  • Penalized Generalized Linear Models
  • k-Nearest Neighbors regression and classification
  • Kriging
  • Linear and Quadratic Discriminant Analysis
  • Naive Bayes
  • Support-Vector machines
  • Gradient Boosting
  • Random Forests for regression, classification and survival

More machine learning methods and alternative implementations are collected in the mlr3extralearners repository.

A full list of implemented learners across all packages is given in this interactive list and also via mlr3extralearners::list_mlr3learners().

head(mlr3extralearners::list_mlr3learners())
##          name   class                 id      mlr3_package
## 1: AdaBoostM1 classif classif.AdaBoostM1 mlr3extralearners
## 2:       bart classif       classif.bart mlr3extralearners
## 3:        C50 classif        classif.C50 mlr3extralearners
## 4:   catboost classif   classif.catboost mlr3extralearners
## 5:    cforest classif    classif.cforest mlr3extralearners
## 6:      ctree classif      classif.ctree mlr3extralearners
##         required_packages                                      properties
## 1:                  RWeka                             multiclass,twoclass
## 2:                 dbarts                                twoclass,weights
## 3:                    C50            missings,multiclass,twoclass,weights
## 4:               catboost importance,missings,multiclass,twoclass,weights
## 5: partykit,sandwich,coin           multiclass,oob_error,twoclass,weights
## 6: partykit,sandwich,coin                     multiclass,twoclass,weights
##                     feature_types predict_types
## 1:         numeric,factor,ordered response,prob
## 2: integer,numeric,factor,ordered response,prob
## 3:         numeric,factor,ordered response,prob
## 4:         numeric,factor,ordered response,prob
## 5: integer,numeric,factor,ordered response,prob
## 6: integer,numeric,factor,ordered response,prob

The latest build status of all learners is listed here.

To create an object for one of the predefined learners, you need to access the mlr_learners Dictionary which, similar to mlr_tasks, gets automatically populated with more learners by extension packages.

# load most mlr3 packages to populate the dictionary
library("mlr3verse")
mlr_learners
## <DictionaryLearner> with 135 stored values
## Keys: classif.AdaBoostM1, classif.bart, classif.C50, classif.catboost,
##   classif.cforest, classif.ctree, classif.cv_glmnet, classif.debug,
##   classif.earth, classif.extratrees, classif.featureless, classif.fnn,
##   classif.gam, classif.gamboost, classif.gausspr, classif.gbm,
##   classif.glmboost, classif.glmnet, classif.IBk, classif.J48,
##   classif.JRip, classif.kknn, classif.ksvm, classif.lda,
##   classif.liblinear, classif.lightgbm, classif.LMT, classif.log_reg,
##   classif.lssvm, classif.mob, classif.multinom, classif.naive_bayes,
##   classif.nnet, classif.OneR, classif.PART, classif.qda,
##   classif.randomForest, classif.ranger, classif.rfsrc, classif.rpart,
##   classif.svm, classif.xgboost, clust.agnes, clust.ap, clust.cmeans,
##   clust.cobweb, clust.dbscan, clust.diana, clust.em, clust.fanny,
##   clust.featureless, clust.ff, clust.hclust, clust.kkmeans,
##   clust.kmeans, clust.MBatchKMeans, clust.meanshift, clust.pam,
##   clust.SimpleKMeans, clust.xmeans, dens.hist, dens.kde, dens.kde_kd,
##   dens.kde_ks, dens.locfit, dens.logspline, dens.mixed, dens.nonpar,
##   dens.pen, dens.plug, dens.spline, regr.bart, regr.catboost,
##   regr.cforest, regr.ctree, regr.cubist, regr.cv_glmnet, regr.earth,
##   regr.extratrees, regr.featureless, regr.fnn, regr.gam, regr.gamboost,
##   regr.gausspr, regr.gbm, regr.glm, regr.glmboost, regr.glmnet,
##   regr.IBk, regr.kknn, regr.km, regr.ksvm, regr.liblinear,
##   regr.lightgbm, regr.lm, regr.M5Rules, regr.mars, regr.mob,
##   regr.randomForest, regr.ranger, regr.rfsrc, regr.rpart, regr.rvm,
##   regr.svm, regr.xgboost, surv.akritas, surv.blackboost, surv.cforest,
##   surv.coxboost, surv.coxph, surv.coxtime, surv.ctree,
##   surv.cv_coxboost, surv.cv_glmnet, surv.deephit, surv.deepsurv,
##   surv.dnnsurv, surv.flexible, surv.gamboost, surv.gbm, surv.glmboost,
##   surv.glmnet, surv.kaplan, surv.loghaz, surv.mboost, surv.nelson,
##   surv.obliqueRSF, surv.parametric, surv.pchazard, surv.penalized,
##   surv.ranger, surv.rfsrc, surv.rpart, surv.svm, surv.xgboost

To obtain an object from the dictionary we can use lrn() or the generic mlr_learners$get() method, e.g. lrn("classif.rpart").

2.3.2 Learner API

Each learner provides the following meta-information:

  • feature_types: the type of features the learner can deal with.
  • packages: the packages required to train a model with this learner and make predictions.
  • properties: additional properties and capabilities. For example, a learner has the property “missings” if it is able to handle missing feature values, and “importance” if it computes and allows to extract data on the relative importance of the features. A complete list of these is available in the mlr3 reference.
  • predict_types: possible prediction types. For example, a classification learner can predict labels (“response”) or probabilities (“prob”). For a complete list of possible predict types see the mlr3 reference.

You can retrieve a specific learner using its id:

learner = lrn("classif.rpart")
print(learner)
## <LearnerClassifRpart:classif.rpart>
## * Model: -
## * Parameters: xval=0
## * Packages: rpart
## * Predict Type: response
## * Feature types: logical, integer, numeric, factor, ordered
## * Properties: importance, missings, multiclass, selected_features,
##   twoclass, weights
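The meta-information listed above is available as fields of the learner object, named as in the list. A minimal sketch for the classification tree learner:

```r
library(mlr3)

learner = lrn("classif.rpart")

# feature types the learner can deal with
learner$feature_types

# packages required for training and prediction
learner$packages

# additional properties and capabilities
learner$properties

# possible prediction types
learner$predict_types

# check for a specific capability before using the learner
can_handle_missings = "missings" %in% learner$properties
print(can_handle_missings)
```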

The field param_set stores a description of the hyperparameters the learner has, their ranges, defaults, and current values:

learner$param_set
## <ParamSet>
##                 id    class lower upper nlevels        default value
##  1:             cp ParamDbl     0     1     Inf           0.01      
##  2:     keep_model ParamLgl    NA    NA       2          FALSE      
##  3:     maxcompete ParamInt     0   Inf     Inf              4      
##  4:       maxdepth ParamInt     1    30      30             30      
##  5:   maxsurrogate ParamInt     0   Inf     Inf              5      
##  6:      minbucket ParamInt     1   Inf     Inf <NoDefault[3]>      
##  7:       minsplit ParamInt     1   Inf     Inf             20      
##  8: surrogatestyle ParamInt     0     1       2              0      
##  9:   usesurrogate ParamInt     0     2       3              2      
## 10:           xval ParamInt     0   Inf     Inf             10     0

The set of current hyperparameter values is stored in the values field of the param_set field. You can change the current hyperparameter values by assigning a named list to this field:

learner$param_set$values = list(cp = 0.01, xval = 0)
learner
## <LearnerClassifRpart:classif.rpart>
## * Model: -
## * Parameters: cp=0.01, xval=0
## * Packages: rpart
## * Predict Type: response
## * Feature types: logical, integer, numeric, factor, ordered
## * Properties: importance, missings, multiclass, selected_features,
##   twoclass, weights

Note that this operation just overwrites all previously set parameters. If you just want to add a new hyperparameter, retrieve the current set of parameter values, modify the named list and write it back to the learner:

pv = learner$param_set$values
pv$cp = 0.02
learner$param_set$values = pv

This updates cp to 0.02 and keeps the previously set parameter xval.
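A more compact way to write the same update is to assign to a single element of the values list. R then performs the retrieve, modify and write-back steps implicitly. The following sketch repeats the setup from above so that it is self-contained.

```r
library(mlr3)

learner = lrn("classif.rpart")
learner$param_set$values = list(cp = 0.01, xval = 0)

# replace a single hyperparameter value; all other values are kept
learner$param_set$values$cp = 0.02

learner$param_set$values
```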

Note that the lrn() function also accepts additional arguments which are then used to update hyperparameters or set fields of the learner in one go:

learner = lrn("classif.rpart", id = "rp", cp = 0.001)
learner$id
## [1] "rp"
learner$param_set$values
## $xval
## [1] 0
## 
## $cp
## [1] 0.001

2.4 Train, Predict, Score

In this section, we explain how tasks and learners can be used to train a model and make predictions on a new dataset. The concept is demonstrated on a supervised classification task using the penguins dataset and the rpart learner, which builds a single classification tree.

Training a learner means fitting a model to a given data set. Subsequently, we want to predict the label for new observations. These predictions are compared to the ground truth values in order to assess the predictive performance of the model.

2.4.1 Creating Task and Learner Objects

First of all, we load the mlr3verse package.

library("mlr3verse")

Next, we retrieve the task and the learner from mlr_tasks (with shortcut tsk()) and mlr_learners (with shortcut lrn()), respectively:

  1. The classification task:
task = tsk("penguins")
  2. A learner for the classification tree:
learner = lrn("classif.rpart")

2.4.2 Setting up the train/test splits of the data

It is common to train on a majority of the data. Here we use 80% of all available observations and predict on the remaining 20%. For this purpose, we create two index vectors:

train_set = sample(task$nrow, 0.8 * task$nrow)
test_set = setdiff(seq_len(task$nrow), train_set)
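
Because sample() draws at random, the split differs on every run. Below is a minimal sketch of a reproducible variant, assuming a fixed seed is acceptable for your use case. The seed value 42 and the hard-coded number of rows (344, the size of the penguins task) are illustrative choices so that the snippet runs without mlr3.

```r
set.seed(42)                      # arbitrary seed, only for reproducibility
n = 344                           # number of rows, i.e. task$nrow for penguins
train_set = sample(n, floor(0.8 * n))
test_set = setdiff(seq_len(n), train_set)

# The two index vectors are disjoint and together cover all rows
length(train_set)                        # 275
length(test_set)                         # 69
length(intersect(train_set, test_set))   # 0
```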

In Section 2.5 we will learn how mlr3 can automatically create training and test sets based on different resampling strategies.

2.4.3 Training the learner

The field $model stores the model that is produced in the training step. Before the $train() method is called on a learner object, this field is NULL:

learner$model
## NULL

Next, the classification tree is trained on the training set of the penguins task by calling the $train() method of the Learner:

learner$train(task, row_ids = train_set)

This operation modifies the learner in-place. We can now access the stored model via the field $model:

print(learner$model)
## n= 275 
## 
## node), split, n, loss, yval, (yprob)
##       * denotes terminal node
## 
## 1) root 275 158 Adelie (0.42545 0.20000 0.37455)  
##   2) flipper_length< 207.5 167  52 Adelie (0.68862 0.31138 0.00000)  
##     4) bill_length< 43.35 115   3 Adelie (0.97391 0.02609 0.00000) *
##     5) bill_length>=43.35 52   3 Chinstrap (0.05769 0.94231 0.00000) *
##   3) flipper_length>=207.5 108   5 Gentoo (0.01852 0.02778 0.95370) *

2.4.4 Predicting

After the model has been trained, we use the remaining part of the data for prediction. Remember that we initially split the data into train_set and test_set.

prediction = learner$predict(task, row_ids = test_set)
print(prediction)
## <PredictionClassif> for 69 observations:
##     row_ids     truth  response
##           3    Adelie    Adelie
##           8    Adelie    Adelie
##          15    Adelie    Adelie
## ---                            
##         335 Chinstrap Chinstrap
##         336 Chinstrap Chinstrap
##         344 Chinstrap Chinstrap

The $predict() method of the Learner returns a Prediction object. More precisely, a LearnerClassif returns a PredictionClassif object.

A prediction object holds the row ids of the test data, the respective true labels of the target column and the respective predictions. The simplest way to extract this information is by converting the Prediction object to a data.table():

head(as.data.table(prediction))
##    row_ids  truth  response
## 1:       3 Adelie    Adelie
## 2:       8 Adelie    Adelie
## 3:      15 Adelie    Adelie
## 4:      20 Adelie Chinstrap
## 5:      23 Adelie    Adelie
## 6:      24 Adelie    Adelie

For classification, you can also extract the confusion matrix:

prediction$confusion
##            truth
## response    Adelie Chinstrap Gentoo
##   Adelie        34         2      0
##   Chinstrap      1        10      2
##   Gentoo         0         1     19
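
Summary statistics can be read off the confusion matrix directly. As a sketch, the matrix printed above is re-entered by hand below (so the numbers only apply to this particular random split), and the accuracy and the per-class recall are computed with base R:

```r
classes = c("Adelie", "Chinstrap", "Gentoo")
C = matrix(
  c(34, 1, 0,    # truth = Adelie
    2, 10, 1,    # truth = Chinstrap
    0, 2, 19),   # truth = Gentoo
  nrow = 3,
  dimnames = list(response = classes, truth = classes)
)

accuracy = sum(diag(C)) / sum(C)    # correct predictions / all predictions
recall = diag(C) / colSums(C)       # per true class: share predicted correctly

round(accuracy, 4)
round(recall, 4)
```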

2.4.5 Changing the Predict Type

Classification learners default to predicting the class label. However, many classifiers can also tell you how sure they are about the predicted label by providing posterior probabilities. To switch to predicting these probabilities, the predict_type field of a LearnerClassif must be changed from "response" to "prob" before training:

learner$predict_type = "prob"

# re-fit the model
learner$train(task, row_ids = train_set)

# rebuild prediction object
prediction = learner$predict(task, row_ids = test_set)

The prediction object now contains probabilities for all class labels:

# data.table conversion
head(as.data.table(prediction))
##    row_ids  truth  response prob.Adelie prob.Chinstrap prob.Gentoo
## 1:       3 Adelie    Adelie     0.97391        0.02609           0
## 2:       8 Adelie    Adelie     0.97391        0.02609           0
## 3:      15 Adelie    Adelie     0.97391        0.02609           0
## 4:      20 Adelie Chinstrap     0.05769        0.94231           0
## 5:      23 Adelie    Adelie     0.97391        0.02609           0
## 6:      24 Adelie    Adelie     0.97391        0.02609           0
# directly access the predicted labels:
head(prediction$response)
## [1] Adelie    Adelie    Adelie    Chinstrap Adelie    Adelie   
## Levels: Adelie Chinstrap Gentoo
# directly access the matrix of probabilities:
head(prediction$prob)
##       Adelie Chinstrap Gentoo
## [1,] 0.97391   0.02609      0
## [2,] 0.97391   0.02609      0
## [3,] 0.97391   0.02609      0
## [4,] 0.05769   0.94231      0
## [5,] 0.97391   0.02609      0
## [6,] 0.97391   0.02609      0
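
The response column is consistent with the probabilities: each observation is labelled with its most probable class. A minimal sketch of that relationship, using the six rows printed above as a hand-typed matrix (how mlr3 breaks ties internally is not covered here):

```r
prob = matrix(
  c(0.97391, 0.02609, 0,
    0.97391, 0.02609, 0,
    0.97391, 0.02609, 0,
    0.05769, 0.94231, 0,
    0.97391, 0.02609, 0,
    0.97391, 0.02609, 0),
  ncol = 3, byrow = TRUE,
  dimnames = list(NULL, c("Adelie", "Chinstrap", "Gentoo"))
)

# Pick, per row, the column with the highest probability
response = colnames(prob)[max.col(prob, ties.method = "first")]
response
```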

Analogously to predicting probabilities, many regression learners support the extraction of standard error estimates by setting the predict type to "se".

2.4.6 Plotting Predictions

Analogously to plotting tasks, mlr3viz provides an autoplot() method for Prediction objects. All available types are listed on the manual page of autoplot.PredictionClassif() or autoplot.PredictionRegr(), respectively.

task = tsk("penguins")
learner = lrn("classif.rpart", predict_type = "prob")
learner$train(task)
prediction = learner$predict(task)
autoplot(prediction)

2.4.7 Performance assessment

The last step of modeling is usually the performance assessment. To assess the quality of the predictions, the predicted labels are compared with the true labels. How this comparison is calculated is defined by a measure, which is given by a Measure object. Note that if the prediction was made on a dataset without the target column, i.e. without true labels, then no performance can be calculated.

Predefined available measures are stored in mlr_measures (with convenience getter msr()):

mlr_measures
## <DictionaryMeasure> with 85 stored values
## Keys: aic, bic, classif.acc, classif.auc, classif.bacc, classif.bbrier,
##   classif.ce, classif.costs, classif.dor, classif.fbeta, classif.fdr,
##   classif.fn, classif.fnr, classif.fomr, classif.fp, classif.fpr,
##   classif.logloss, classif.mbrier, classif.mcc, classif.npv,
##   classif.ppv, classif.prauc, classif.precision, classif.recall,
##   classif.sensitivity, classif.specificity, classif.tn, classif.tnr,
##   classif.tp, classif.tpr, clust.ch, clust.db, clust.dunn,
##   clust.silhouette, clust.wss, debug, dens.logloss, oob_error,
##   regr.bias, regr.ktau, regr.mae, regr.mape, regr.maxae, regr.medae,
##   regr.medse, regr.mse, regr.msle, regr.pbias, regr.rae, regr.rmse,
##   regr.rmsle, regr.rrse, regr.rse, regr.rsq, regr.sae, regr.smape,
##   regr.srho, regr.sse, selected_features, surv.brier, surv.calib_alpha,
##   surv.calib_beta, surv.chambless_auc, surv.cindex, surv.dcalib,
##   surv.graf, surv.hung_auc, surv.intlogloss, surv.logloss, surv.mae,
##   surv.mse, surv.nagelk_r2, surv.oquigley_r2, surv.rmse, surv.schmid,
##   surv.song_auc, surv.song_tnr, surv.song_tpr, surv.uno_auc,
##   surv.uno_tnr, surv.uno_tpr, surv.xu_r2, time_both, time_predict,
##   time_train

We choose accuracy (classif.acc) as a specific performance measure and call the method $score() of the Prediction object to quantify the predictive performance.

measure = msr("classif.acc")
print(measure)
## <MeasureClassifSimple:classif.acc>
## * Packages: mlr3measures
## * Range: [0, 1]
## * Minimize: FALSE
## * Parameters: list()
## * Properties: -
## * Predict type: response
prediction$score(measure)
## classif.acc 
##      0.9651

Note that, if no measure is specified, classification defaults to classification error (classif.ce) and regression defaults to the mean squared error (regr.mse).
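
Several measures can be scored in one call. The sketch below assumes that msrs(), the plural counterpart of msr(), is available from mlr3, and it re-creates all objects so that it runs on its own. Exact scores depend on the random split and are therefore not shown.

```r
library("mlr3")

task = tsk("penguins")
learner = lrn("classif.rpart")

# 80/20 split as in the sections above
train_set = sample(task$nrow, floor(0.8 * task$nrow))
test_set = setdiff(seq_len(task$nrow), train_set)

learner$train(task, row_ids = train_set)
prediction = learner$predict(task, row_ids = test_set)

# Accuracy and classification error are complementary: acc = 1 - ce
scores = prediction$score(msrs(c("classif.acc", "classif.ce")))
scores
```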

2.5 Resampling

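Resampling strategies are usually used to assess the performance of a learning algorithm. mlr3 provides the following predefined resampling strategies: cross-validation ("cv"), leave-one-out cross-validation ("loo"), repeated cross-validation ("repeated_cv"), bootstrapping ("bootstrap"), subsampling ("subsampling"), holdout ("holdout"), in-sample resampling ("insample"), and custom resampling ("custom" and "custom_cv").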

The following sections provide guidance on how to set and select a resampling strategy and how to subsequently instantiate the resampling process.

Here is a graphical illustration of the resampling process:

2.5.1 Settings

In this example we use the penguins task and a simple classification tree from the rpart package once again.

library("mlr3verse")

task = tsk("penguins")
learner = lrn("classif.rpart")

When performing resampling with a dataset, we first need to define which approach should be used. mlr3 resampling strategies and their parameters can be queried by looking at the data.table output of the mlr_resamplings dictionary:

as.data.table(mlr_resamplings)
##            key        params iters
## 1:   bootstrap ratio,repeats    30
## 2:      custom                  NA
## 3:   custom_cv                  NA
## 4:          cv         folds    10
## 5:     holdout         ratio     1
## 6:    insample                   1
## 7:         loo                  NA
## 8: repeated_cv folds,repeats   100
## 9: subsampling ratio,repeats    30

Additional resampling methods for special use cases are available via extension packages, such as mlr3spatiotempcv for spatiotemporal data.

The model fit conducted in the train/predict/score section is equivalent to a “holdout resampling”, so let’s consider this one first. Again, we can retrieve elements from the dictionary mlr_resamplings via $get() or with the convenience function rsmp():

resampling = rsmp("holdout")
print(resampling)
## <ResamplingHoldout> with 1 iterations
## * Instantiated: FALSE
## * Parameters: ratio=0.6667

Note that the $is_instantiated field is set to FALSE. This means we did not actually apply the strategy on a dataset yet. Applying the strategy on a dataset is done in the next section Instantiation.

By default, we get a 2/3 to 1/3 split of the data (ratio = 0.6667). There are two ways in which the ratio can be changed:

  1. Overwriting the slot in $param_set$values using a named list:
resampling$param_set$values = list(ratio = 0.8)
  2. Specifying the resampling parameters directly during construction:
rsmp("holdout", ratio = 0.8)
## <ResamplingHoldout> with 1 iterations
## * Instantiated: FALSE
## * Parameters: ratio=0.8

2.5.2 Instantiation

So far we just set the stage and selected the resampling strategy.

To actually perform the splitting and obtain indices for the training and the test split, the resampling needs a Task. By calling the method $instantiate(), we split the indices of the data into indices for training and test sets. The resulting indices are stored in the Resampling object. To better illustrate the following operations, we switch to a 3-fold cross-validation:

resampling = rsmp("cv", folds = 3)
resampling$instantiate(task)
resampling$iters
## [1] 3
str(resampling$train_set(1))
##  int [1:229] 3 9 11 16 19 20 21 25 29 33 ...
str(resampling$test_set(1))
##  int [1:115] 1 2 8 12 15 22 23 24 36 38 ...

Note that if you want to compare multiple Learners in a fair manner, using the same instantiated resampling for each learner is mandatory. A way to greatly simplify the comparison of multiple learners is discussed in the next section on benchmarking.
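
As a sketch of what using the same instantiated resampling means in practice, the following instantiates a resampling once and passes it to resample() (introduced in the next section) for two learners. The featureless baseline learner is chosen only for illustration.

```r
library("mlr3")

task = tsk("penguins")
resampling = rsmp("cv", folds = 3)
resampling$instantiate(task)

# Both learners are evaluated on identical train/test splits
rr_tree = resample(task, lrn("classif.rpart"), resampling)
rr_base = resample(task, lrn("classif.featureless"), resampling)

# Compare the test indices of the first fold across both results
identical(rr_tree$resampling$test_set(1), rr_base$resampling$test_set(1))
```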

2.5.3 Execution

With a Task, a Learner and a Resampling object we can call resample(), which repeatedly fits the learner to the task at hand according to the given resampling strategy. This in turn creates a ResampleResult object. We tell resample() to keep the fitted models by setting the store_models option to TRUE and then start the computation:

task = tsk("penguins")
learner = lrn("classif.rpart", maxdepth = 3, predict_type = "prob")
resampling = rsmp("cv", folds = 3)

rr = resample(task, learner, resampling, store_models = TRUE)
print(rr)
## <ResampleResult> of 3 iterations
## * Task: penguins
## * Learner: classif.rpart
## * Warnings: 0 in 0 iterations
## * Errors: 0 in 0 iterations

The returned ResampleResult stored as rr provides various getters to access the stored information:

  • Calculate the average performance across all resampling iterations:

    rr$aggregate(msr("classif.ce"))
    ## classif.ce 
    ##    0.06974
  • Extract the performance for the individual resampling iterations:

    rr$score(msr("classif.ce"))
    ##                 task  task_id                   learner    learner_id
    ## 1: <TaskClassif[47]> penguins <LearnerClassifRpart[36]> classif.rpart
    ## 2: <TaskClassif[47]> penguins <LearnerClassifRpart[36]> classif.rpart
    ## 3: <TaskClassif[47]> penguins <LearnerClassifRpart[36]> classif.rpart
    ##            resampling resampling_id iteration              prediction
    ## 1: <ResamplingCV[19]>            cv         1 <PredictionClassif[19]>
    ## 2: <ResamplingCV[19]>            cv         2 <PredictionClassif[19]>
    ## 3: <ResamplingCV[19]>            cv         3 <PredictionClassif[19]>
    ##    classif.ce
    ## 1:    0.11304
    ## 2:    0.03478
    ## 3:    0.06140
  • Check for warnings or errors:

    rr$warnings
    ## Empty data.table (0 rows and 2 cols): iteration,msg
    rr$errors
    ## Empty data.table (0 rows and 2 cols): iteration,msg
  • Extract and inspect the resampling splits:

    rr$resampling
    ## <ResamplingCV> with 3 iterations
    ## * Instantiated: TRUE
    ## * Parameters: folds=3
    rr$resampling$iters
    ## [1] 3
    str(rr$resampling$test_set(1))
    ##  int [1:115] 1 3 9 10 11 13 19 23 26 27 ...
    str(rr$resampling$train_set(1))
    ##  int [1:229] 2 4 14 15 16 22 25 29 30 36 ...
  • Retrieve the learner of a specific iteration and inspect it:

    lrn = rr$learners[[1]]
    lrn$model
    ## n= 229 
    ## 
    ## node), split, n, loss, yval, (yprob)
    ##       * denotes terminal node
    ## 
    ##  1) root 229 131 Adelie (0.427948 0.196507 0.375546)  
    ##    2) flipper_length< 207.5 142  44 Adelie (0.690141 0.302817 0.007042)  
    ##      4) bill_length< 43.15 94   1 Adelie (0.989362 0.010638 0.000000) *
    ##      5) bill_length>=43.15 48   6 Chinstrap (0.104167 0.875000 0.020833)  
    ##       10) body_mass>=4125 9   4 Adelie (0.555556 0.333333 0.111111) *
    ##       11) body_mass< 4125 39   0 Chinstrap (0.000000 1.000000 0.000000) *
    ##    3) flipper_length>=207.5 87   2 Gentoo (0.000000 0.022989 0.977011) *
  • Extract the predictions:

    rr$prediction() # all predictions merged into a single Prediction object
    ## <PredictionClassif> for 344 observations:
    ##     row_ids     truth  response prob.Adelie prob.Chinstrap prob.Gentoo
    ##           1    Adelie    Adelie      0.9894        0.01064           0
    ##           3    Adelie    Adelie      0.9894        0.01064           0
    ##           9    Adelie    Adelie      0.9894        0.01064           0
    ## ---                                                                   
    ##         337 Chinstrap Chinstrap      0.1000        0.90000           0
    ##         340 Chinstrap Chinstrap      0.1000        0.90000           0
    ##         341 Chinstrap Chinstrap      0.1000        0.90000           0
    rr$predictions()[[1]] # prediction of first resampling iteration
    ## <PredictionClassif> for 115 observations:
    ##     row_ids     truth  response prob.Adelie prob.Chinstrap prob.Gentoo
    ##           1    Adelie    Adelie      0.9894        0.01064           0
    ##           3    Adelie    Adelie      0.9894        0.01064           0
    ##           9    Adelie    Adelie      0.9894        0.01064           0
    ## ---                                                                   
    ##         338 Chinstrap Chinstrap      0.0000        1.00000           0
    ##         339 Chinstrap Chinstrap      0.0000        1.00000           0
    ##         344 Chinstrap Chinstrap      0.0000        1.00000           0
  • Filter to only keep specified iterations:

    rr$filter(c(1, 3))
    print(rr)
    ## <ResampleResult> of 2 iterations
    ## * Task: penguins
    ## * Learner: classif.rpart
    ## * Warnings: 0 in 0 iterations
    ## * Errors: 0 in 0 iterations
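
For classif.ce, the aggregated value reported by $aggregate() above equals the mean of the per-iteration scores returned by $score(). This can be checked by hand with the three fold-wise errors copied from the output above:

```r
# Per-fold classification errors, copied from the rr$score() output
ce_per_fold = c(0.11304, 0.03478, 0.06140)

# Their mean reproduces the aggregated value
round(mean(ce_per_fold), 5)
## [1] 0.06974
```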

2.5.4 Custom resampling

Sometimes it is necessary to perform resampling with custom splits, e.g. to reproduce results reported in a study. A manual resampling instance can be created using the "custom" template.

resampling = rsmp("custom")
resampling$instantiate(task,
  train = list(c(1:10, 51:60, 101:110)),
  test = list(c(11:20, 61:70, 111:120))
)
resampling$iters
## [1] 1
resampling$train_set(1)
##  [1]   1   2   3   4   5   6   7   8   9  10  51  52  53  54  55  56  57  58  59
## [20]  60 101 102 103 104 105 106 107 108 109 110
resampling$test_set(1)
##  [1]  11  12  13  14  15  16  17  18  19  20  61  62  63  64  65  66  67  68  69
## [20]  70 111 112 113 114 115 116 117 118 119 120

2.5.5 Resampling with (predefined) groups

One can specify that certain observations should be grouped during resampling, meaning that all observations of a group always end up together in either the training set or the test set. This characteristic must be defined via the column role "group" during Task creation (see also the help page on this column role). In {mlr} this was previously called “blocking”. See also the mlr3gallery post on this topic for a practical example.

An even stricter approach is to supply a custom grouping structure in which each group makes up exactly one fold. {mlr3} supports this via the "custom_cv" resampling method. This method accepts a factor vector with the same length as task$nrow or the name of an existing variable in the data. Note that this approach does not allow setting the “folds” argument because the number of folds is determined by the number of factor levels. This predefined approach was called “grouping” in {mlr}.
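
The idea behind "custom_cv" can be illustrated in base R: every level of the factor defines the test set of one fold, and the remaining rows form the corresponding training set. This is a conceptual sketch with a made-up factor f, not the mlr3 implementation.

```r
# Hypothetical grouping factor: one level per fold
f = factor(c("a", "a", "b", "c", "b", "c", "a", "b"))
row_ids = seq_along(f)

# Each factor level yields the test set of one fold ...
test_sets = split(row_ids, f)
# ... and all other rows form the matching training set
train_sets = lapply(test_sets, function(test) setdiff(row_ids, test))

length(test_sets)   # number of folds = number of factor levels
test_sets$a
train_sets$a
```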

The term “blocking” is overloaded in the resampling context. In the field of spatial resampling (see mlr3spatiotempcv), “blocking” is used as a general concept for grouping observations, while there are also specific resampling methods which have “blocking” in their name ("spcv_block") and focus on rectangular partitioning.

2.5.6 Plotting Resample Results

mlr3viz provides an autoplot() method for resample results. To showcase some of the plots, we create a binary classification task with two features, perform a resampling with a 10-fold cross-validation and visualize the results:

task = tsk("pima")
task$select(c("glucose", "mass"))
learner = lrn("classif.rpart", predict_type = "prob")
rr = resample(task, learner, rsmp("cv"), store_models = TRUE)

# boxplot of AUC values across the 10 folds
autoplot(rr, measure = msr("classif.auc"))
# ROC curve, averaged over 10 folds
autoplot(rr, type = "roc")
# learner predictions for first fold
rr$filter(1)
autoplot(rr, type = "prediction")
## Warning: Removed 1 rows containing missing values (geom_point).

All available plot types are listed on the manual page of autoplot.ResampleResult().

2.5.7 Plotting Resample Partitions

mlr3spatiotempcv provides autoplot() methods to visualize resampling partitions of spatiotemporal datasets. See the function reference and vignette “Spatiotemporal visualization” for more info.


2.6 Benchmarking

Comparing the performance of different learners on multiple tasks and/or different resampling schemes is a common task. This operation is usually referred to as “benchmarking” in the field of machine learning. The mlr3 package offers the benchmark() convenience function.

2.6.1 Design Creation

In mlr3 we require you to supply a “design” of your benchmark experiment. Such a design is essentially a table of settings you want to execute. It consists of unique combinations of Task, Learner and Resampling triplets.

We use the benchmark_grid() function to create an exhaustive design and instantiate the resampling properly, so that all learners are executed on the same train/test split for each task. We set the learners to predict probabilities and also tell them to predict the observations of the training set (by setting predict_sets to c("train", "test")). Additionally, we use tsks(), lrns(), and rsmps() to retrieve lists of Task, Learner and Resampling in the same fashion as tsk(), lrn() and rsmp().

library("mlr3verse")

design = benchmark_grid(
  tasks = tsks(c("spam", "german_credit", "sonar")),
  learners = lrns(c("classif.ranger", "classif.rpart", "classif.featureless"),
    predict_type = "prob", predict_sets = c("train", "test")),
  resamplings = rsmps("cv", folds = 3)
)
print(design)
##                 task                         learner         resampling
## 1: <TaskClassif[47]>      <LearnerClassifRanger[36]> <ResamplingCV[19]>
## 2: <TaskClassif[47]>       <LearnerClassifRpart[36]> <ResamplingCV[19]>
## 3: <TaskClassif[47]> <LearnerClassifFeatureless[36]> <ResamplingCV[19]>
## 4: <TaskClassif[47]>      <LearnerClassifRanger[36]> <ResamplingCV[19]>
## 5: <TaskClassif[47]>       <LearnerClassifRpart[36]> <ResamplingCV[19]>
## 6: <TaskClassif[47]> <LearnerClassifFeatureless[36]> <ResamplingCV[19]>
## 7: <TaskClassif[47]>      <LearnerClassifRanger[36]> <ResamplingCV[19]>
## 8: <TaskClassif[47]>       <LearnerClassifRpart[36]> <ResamplingCV[19]>
## 9: <TaskClassif[47]> <LearnerClassifFeatureless[36]> <ResamplingCV[19]>

The created design can be passed to benchmark() to start the computation. It is also possible to create a custom design manually. However, if you create a custom design with data.table(), the train/test splits will be different for each row of the design if you do not manually instantiate the resampling before creating the design. See the help page on benchmark_grid() for an example.
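
A manually created design might look like the following sketch. It assumes that benchmark() accepts a data.table with list columns named task, learner and resampling (the same columns that benchmark_grid() returns), and it instantiates the resampling once before placing it in both rows. The choice of task and learners is only illustrative.

```r
library("mlr3")
library("data.table")

task = tsk("penguins")

# Instantiate once so that every row of the design shares the same splits
resampling = rsmp("cv", folds = 3)
resampling$instantiate(task)

design = data.table(
  task = list(task, task),
  learner = list(lrn("classif.rpart"), lrn("classif.featureless")),
  resampling = list(resampling, resampling)
)

bmr = benchmark(design)
bmr$aggregate(msr("classif.ce"))
```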

2.6.2 Execution and Aggregation of Results

After the benchmark design is ready, we can directly call benchmark():

# execute the benchmark
bmr = benchmark(design)

Note that we did not instantiate the resampling instance manually. benchmark_grid() took care of it for us: Each resampling strategy is instantiated once for each task during the construction of the exhaustive grid.

Once the benchmarking is done, we can aggregate the performance with $aggregate(). We create two measures to calculate the AUC for the training set and for the test set:

measures = list(
  msr("classif.auc", predict_sets = "train", id = "auc_train"),
  msr("classif.auc", id = "auc_test")
)

tab = bmr$aggregate(measures)
print(tab)
##    nr      resample_result       task_id          learner_id resampling_id
## 1:  1 <ResampleResult[20]>          spam      classif.ranger            cv
## 2:  2 <ResampleResult[20]>          spam       classif.rpart            cv
## 3:  3 <ResampleResult[20]>          spam classif.featureless            cv
## 4:  4 <ResampleResult[20]> german_credit      classif.ranger            cv
## 5:  5 <ResampleResult[20]> german_credit       classif.rpart            cv
## 6:  6 <ResampleResult[20]> german_credit classif.featureless            cv
## 7:  7 <ResampleResult[20]>         sonar      classif.ranger            cv
## 8:  8 <ResampleResult[20]>         sonar       classif.rpart            cv
## 9:  9 <ResampleResult[20]>         sonar classif.featureless            cv
##    iters auc_train auc_test
## 1:     3    0.9994   0.9858
## 2:     3    0.9104   0.9026
## 3:     3    0.5000   0.5000
## 4:     3    0.9986   0.7944
## 5:     3    0.8143   0.7012
## 6:     3    0.5000   0.5000
## 7:     3    1.0000   0.9326
## 8:     3    0.9129   0.7478
## 9:     3    0.5000   0.5000

We can aggregate the results even further. For example, we might be interested to know which learner performed best over all tasks simultaneously. Simply aggregating the performances with the mean is usually not statistically sound. Instead, we calculate the rank statistic for each learner grouped by task. Then the calculated ranks grouped by learner are aggregated with data.table. Since the AUC needs to be maximized, we multiply the values by \(-1\) so that the best learner has a rank of \(1\).

library("data.table")
# group by levels of task_id, return columns:
# - learner_id
# - rank of col '-auc_train' (within each task_id)
# - rank of col '-auc_test' (within each task_id)
ranks = tab[, .(learner_id, rank_train = rank(-auc_train), rank_test = rank(-auc_test)), by = task_id]
print(ranks)
##          task_id          learner_id rank_train rank_test
## 1:          spam      classif.ranger          1         1
## 2:          spam       classif.rpart          2         2
## 3:          spam classif.featureless          3         3
## 4: german_credit      classif.ranger          1         1
## 5: german_credit       classif.rpart          2         2
## 6: german_credit classif.featureless          3         3
## 7:         sonar      classif.ranger          1         1
## 8:         sonar       classif.rpart          2         2
## 9:         sonar classif.featureless          3         3
# group by levels of learner_id, return columns:
# - mean rank of col 'rank_train' (per level of learner_id)
# - mean rank of col 'rank_test' (per level of learner_id)
ranks = ranks[, .(mrank_train = mean(rank_train), mrank_test = mean(rank_test)), by = learner_id]

# print the final table, ordered by mean rank of AUC test
ranks[order(mrank_test)]
##             learner_id mrank_train mrank_test
## 1:      classif.ranger           1          1
## 2:       classif.rpart           2          2
## 3: classif.featureless           3          3

Unsurprisingly, the featureless learner is outperformed on both training and test set. The classification forest also outperforms a single classification tree.
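
The same rank aggregation can be sketched without data.table. The snippet below re-enters the test AUC values from the aggregated table above by hand and uses base R's ave() and aggregate(). It is an illustrative alternative, not the approach used in this chapter.

```r
tab = data.frame(
  task_id = rep(c("spam", "german_credit", "sonar"), each = 3),
  learner_id = rep(c("classif.ranger", "classif.rpart", "classif.featureless"), times = 3),
  auc_test = c(0.9858, 0.9026, 0.5000, 0.7944, 0.7012, 0.5000, 0.9326, 0.7478, 0.5000)
)

# Rank within each task; negate the AUC so that the best learner gets rank 1
tab$rank_test = ave(-tab$auc_test, tab$task_id, FUN = rank)

# Mean rank per learner across all tasks, ordered from best to worst
mrank = aggregate(rank_test ~ learner_id, data = tab, FUN = mean)
mrank[order(mrank$rank_test), ]
```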

2.6.3 Plotting Benchmark Results

Analogously to plotting tasks, predictions or resample results, mlr3viz also provides an autoplot() method for benchmark results.

autoplot(bmr) + ggplot2::theme(axis.text.x = ggplot2::element_text(angle = 45, hjust = 1))

We can also plot ROC curves. To do so, we first need to filter the BenchmarkResult to only contain a single Task:

bmr_small = bmr$clone()$filter(task_id = "german_credit")
autoplot(bmr_small, type = "roc")

All available plot types are listed on the manual page of autoplot.BenchmarkResult().

2.6.4 Extracting ResampleResults

A BenchmarkResult object is essentially a collection of multiple ResampleResult objects. As these are stored in a column of the aggregated data.table(), we can easily extract them:

tab = bmr$aggregate(measures)
rr = tab[task_id == "german_credit" & learner_id == "classif.ranger"]$resample_result[[1]]
print(rr)
## <ResampleResult> of 3 iterations
## * Task: german_credit
## * Learner: classif.ranger
## * Warnings: 0 in 0 iterations
## * Errors: 0 in 0 iterations

We can now investigate this resampling and even single resampling iterations using one of the approaches shown in the previous section:

measure = msr("classif.auc")
rr$aggregate(measure)
## classif.auc 
##      0.7944
# get the iteration with worst AUC
perf = rr$score(measure)
i = which.min(perf$classif.auc)

# get the corresponding learner and train set
print(rr$learners[[i]])
## <LearnerClassifRanger:classif.ranger>
## * Model: -
## * Parameters: num.threads=1
## * Packages: ranger
## * Predict Type: prob
## * Feature types: logical, integer, numeric, character, factor, ordered
## * Properties: importance, multiclass, oob_error, twoclass, weights
head(rr$resampling$train_set(i))
## [1]  4  5  6  7  9 16

2.6.5 Converting and Merging

A ResampleResult can be converted to a BenchmarkResult using the converter as_benchmark_result(). Additionally, two BenchmarkResults can be merged into a larger result object.

task = tsk("iris")
resampling = rsmp("holdout")$instantiate(task)

rr1 = resample(task, lrn("classif.rpart"), resampling)
rr2 = resample(task, lrn("classif.featureless"), resampling)

# Cast both ResampleResults to BenchmarkResults
bmr1 = as_benchmark_result(rr1)
bmr2 = as_benchmark_result(rr2)

# Merge 2nd BMR into the first BMR
bmr1$combine(bmr2)

bmr1
## <BenchmarkResult> of 2 rows with 2 resampling runs
##  nr task_id          learner_id resampling_id iters warnings errors
##   1    iris       classif.rpart       holdout     1        0      0
##   2    iris classif.featureless       holdout     1        0      0

2.7 Binary classification

Classification problems with a target variable containing only two classes are called “binary”. For such binary target variables, you can specify the positive class within the classification task object during task creation. If not explicitly set during construction, the positive class defaults to the first level of the target variable.

# during construction
data("Sonar", package = "mlbench")
task = as_task_classif(Sonar, target = "Class", positive = "R")

# switch positive class to level 'M'
task$positive = "M"

2.7.1 ROC Curve and Thresholds

ROC Analysis, which stands for “receiver operating characteristics”, is a subfield of machine learning which studies the evaluation of binary prediction systems. We saw earlier that one can retrieve the confusion matrix of a Prediction by accessing the $confusion field:

learner = lrn("classif.rpart", predict_type = "prob")
pred = learner$train(task)$predict(task)
C = pred$confusion
print(C)
##         truth
## response  M  R
##        M 95 10
##        R 16 87

The confusion matrix contains the counts of correct and incorrect class assignments, grouped by class labels. The columns show the true (observed) labels and the rows show the predicted labels. The positive class is always the first row and the first column of the confusion matrix. Thus, the element \(C_{11}\) is the number of times our model predicted the positive class and was right about it. Analogously, the element \(C_{22}\) is the number of times our model predicted the negative class and was also right about it. These two elements on the diagonal are called True Positives (TP) and True Negatives (TN). The element \(C_{12}\) is the number of times we falsely predicted a positive label, and is called False Positives (FP). The element \(C_{21}\) is the number of times we falsely predicted a negative label, and is called False Negatives (FN).

We can now normalize in rows and columns of the confusion matrix to derive several informative metrics:

  • True Positive Rate (TPR): How many of the actual positives did we predict as positive?
  • True Negative Rate (TNR): How many of the actual negatives did we predict as negative?
  • Positive Predictive Value (PPV): If we predict positive, how likely is it a true positive?
  • Negative Predictive Value (NPV): If we predict negative, how likely is it a true negative?

Source: Wikipedia
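
These metrics can be computed by hand from the confusion matrix printed above. The sketch below re-enters that matrix in base R (rows are predicted labels, columns are true labels, positive class "M") and additionally derives the False Positive Rate, FPR = 1 - TNR, which is used for the ROC curve.

```r
# Confusion matrix from above; rows = response, columns = truth
C = matrix(c(95, 16, 10, 87), nrow = 2,
  dimnames = list(response = c("M", "R"), truth = c("M", "R")))

TP = C["M", "M"]; FP = C["M", "R"]
FN = C["R", "M"]; TN = C["R", "R"]

metrics = c(
  TPR = TP / (TP + FN),   # column-wise: share of actual positives found
  TNR = TN / (TN + FP),   # column-wise: share of actual negatives found
  FPR = FP / (FP + TN),   # equals 1 - TNR
  PPV = TP / (TP + FP),   # row-wise: share of positive predictions that are right
  NPV = TN / (TN + FN)    # row-wise: share of negative predictions that are right
)
round(metrics, 4)
```

The resulting TPR (0.8559) and TNR (0.8969) agree with the values that classif.tpr and classif.tnr report for this prediction in the threshold tuning section.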

It is difficult to achieve a high TPR and a low False Positive Rate (FPR = 1 - TNR) at the same time, so the two are used to construct the ROC curve. We characterize a classifier by its TPR and FPR values and plot them in a coordinate system with the FPR on the x-axis and the TPR on the y-axis. The best classifier lies in the top-left corner. The worst classifiers lie on the diagonal: they produce random labels (with different proportions). For example, if each positive \(x\) is randomly classified as “positive” with a probability of 25%, we get a TPR of 0.25; if each negative \(x\) is also randomly classified as “positive” with a probability of 25%, we get an FPR of 0.25. In practice, we should never obtain a classifier below the diagonal, as inverting its predicted labels yields a classifier above the diagonal.

A scoring classifier is a model which produces scores or probabilities, instead of discrete labels. To obtain probabilities from a learner in mlr3, you have to set predict_type = "prob" for a LearnerClassif. Whether a classifier can predict probabilities is given in its $predict_types field. Thresholding flexibly converts predicted probabilities to labels: predict \(1\) (positive class) if \(\hat{f}(x) > \tau\), else predict \(0\). Normally, one could use \(\tau = 0.5\) to convert probabilities to labels, but for imbalanced or cost-sensitive situations another threshold could be more suitable. After thresholding, any metric defined on labels can be used.
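
The thresholding rule can be written down in a few lines of base R. This is a minimal sketch with made-up probabilities; the helper apply_threshold() is hypothetical and not part of mlr3.

```r
# Hypothetical predicted probabilities for the positive class "M"
prob_pos = c(0.90, 0.60, 0.35, 0.10)

# Predict the positive label if the probability exceeds tau, else the negative one
apply_threshold = function(prob, tau = 0.5, positive = "M", negative = "R") {
  factor(ifelse(prob > tau, positive, negative), levels = c(positive, negative))
}

apply_threshold(prob_pos)             # default threshold 0.5
apply_threshold(prob_pos, tau = 0.2)  # lower threshold: more positive labels
```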

For mlr3 prediction objects, the ROC curve can easily be created with mlr3viz, which relies on the precrec package to calculate and plot ROC curves:

library("mlr3viz")

# TPR vs FPR / Sensitivity vs (1 - Specificity)
autoplot(pred, type = "roc")
# Precision vs Recall
autoplot(pred, type = "prc")
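
Conceptually, each point of a ROC curve corresponds to one threshold. The following base R sketch sweeps the threshold over a handful of made-up scores and records the resulting FPR and TPR. It only illustrates the idea and is not how precrec computes the curve; note that it uses >= so that the lowest threshold labels every observation as positive.

```r
truth = c(1, 1, 0, 1, 0, 0)               # 1 = positive, 0 = negative (toy data)
score = c(0.9, 0.8, 0.7, 0.6, 0.4, 0.2)   # hypothetical predicted probabilities

# One candidate threshold per distinct score, from strict to lenient
taus = sort(unique(score), decreasing = TRUE)

roc = t(sapply(taus, function(tau) {
  pred = as.integer(score >= tau)
  c(
    tau = tau,
    fpr = sum(pred == 1 & truth == 0) / sum(truth == 0),
    tpr = sum(pred == 1 & truth == 1) / sum(truth == 1)
  )
}))
roc
```

Plotting the tpr column against the fpr column gives the empirical ROC curve for the toy data.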

2.7.2 Threshold Tuning

Learners which can predict the probability for the positive class usually use a simple rule to determine the predicted class label: if the probability exceeds the threshold \(t = 0.5\), pick the positive label, and select the negative label otherwise. If the model is not well calibrated or the class labels are heavily unbalanced, selecting a different threshold can help to improve the predictive performance w.r.t. a chosen performance measure.

Here, we change the threshold to \(t = 0.2\), improving the True Positive Rate (TPR). Note that with the new threshold more observations from the positive class will get correctly classified with the positive label, but at the same time the True Negative Rate (TNR) decreases. Depending on the application, this may be a desired trade-off.

measures = msrs(c("classif.tpr", "classif.tnr"))
pred$confusion
##         truth
## response  M  R
##        M 95 10
##        R 16 87
pred$score(measures)
## classif.tpr classif.tnr 
##      0.8559      0.8969
pred$set_threshold(0.2)
pred$confusion
##         truth
## response   M   R
##        M 104  25
##        R   7  72
pred$score(measures)
## classif.tpr classif.tnr 
##      0.9369      0.7423

Thresholds can also be tuned with the mlr3pipelines package, e.g. using PipeOpTuneThreshold.