2 Basics

This chapter will teach you the essential building blocks of mlr3: the R6 classes and operations used for machine learning. A typical machine learning workflow looks like this:

The data, which mlr3 encapsulates in tasks, is split into non-overlapping training and test sets. As we are interested in models that generalize to new data rather than just memorizing the training data, the separate test data allows us to objectively evaluate models with respect to generalization. The training data is given to a machine learning algorithm, which we call a learner in mlr3. The learner uses the training data to build a model of the relationship between the input features and the output target values. This model is then used to produce predictions on the test data, which are compared to the ground truth values to assess the quality of the model. mlr3 offers a number of different measures to quantify how well a model performs based on the difference between predicted and actual values. Usually, this measure is a numeric score.

The process of splitting up data into training and test sets, building a model, and evaluating it can be repeated several times, resampling different training and test sets from the original data each time. Multiple resampling iterations allow us to get a better, less biased estimate of the generalization performance for a particular type of model. As data are usually partitioned randomly into training and test sets, a single split can, for example, produce training and test sets that are very different, creating the misleading impression that the particular type of model does not perform well.

In many cases, this simple workflow is not sufficient to deal with real-world data, which may require normalization, imputation of missing values, or feature selection. We will cover more complex workflows that allow us to do this, and even more, later in the book.

This chapter covers the following topics:

Tasks

Tasks encapsulate the data with meta-information, such as the name of the prediction target column. We cover how to:

  • create tasks and retrieve predefined example tasks from a dictionary,
  • query and modify the data and meta-data stored in a task,
  • assign roles to individual rows and columns, and
  • plot tasks.

Learners

Learners encapsulate machine learning algorithms to train models and make predictions for a task. These are provided by other packages. We cover how to:

  • retrieve predefined learners from a dictionary,
  • query the meta-information of a learner, and
  • set and query the hyperparameters that control a learner's behavior.

How to extend learners and implement your own is covered in a supplemental advanced technical section.

Train and predict

The section on the train and predict methods illustrates how to use tasks and learners to train a model and make predictions on a new data set. In particular, we cover how to:

  • set up training and test splits of the data,
  • train a learner and use the fitted model to predict on new data,
  • change the predict type, e.g. to predict probabilities instead of class labels, and
  • assess the quality of predictions with performance measures.


Before we get into the details of how to use mlr3 for machine learning, we give a brief introduction to R6 as it is a relatively new part of R. mlr3 heavily relies on R6 and all basic building blocks it provides are R6 classes:

2.1 Quick R6 Intro for Beginners

R6 is one of R’s more recent dialects for object-oriented programming (OO). It addresses shortcomings of earlier OO implementations in R, such as S3, which we used in mlr. If you have done any object-oriented programming before, R6 should feel familiar. We focus on the parts of R6 that you need to know to use mlr3 here.

  • Objects are created by calling the constructor of an R6::R6Class() object, specifically the initialization method $new(). For example, foo = Foo$new(bar = 1) creates a new object of class Foo, setting the bar argument of the constructor to the value 1. Most objects in mlr3 are created through special functions (e.g. lrn("regr.rpart")) that encapsulate this and are also referred to as sugar functions.
  • Objects have mutable state that is encapsulated in their fields, which can be accessed through the dollar operator. We can access the bar value in the foo variable from above through foo$bar and set its value by assigning the field, e.g. foo$bar = 2.
  • In addition to fields, objects expose methods that allow us to inspect the object’s state, retrieve information, or perform an action that changes the internal state of the object. For example, the $train method of a learner changes the internal state of the learner by building and storing a trained model, which can then be used to make predictions, given data.
  • Objects can have public and private fields and methods. The public fields and methods define the API to interact with the object. Private methods are only relevant for you if you want to extend mlr3, e.g. with new learners.
  • R6 objects are internally environments, and as such have reference semantics. For example, foo2 = foo does not create a copy of foo in foo2, but another reference to the same actual object. Setting foo$bar = 3 will also change foo2$bar to 3 and vice versa.
  • To copy an object, use the $clone() method and the deep = TRUE argument for nested objects, for example, foo2 = foo$clone(deep = TRUE).
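
The following minimal sketch illustrates these points with a hypothetical class Foo, matching the examples above:

library("R6")

Foo = R6Class("Foo",
  public = list(
    bar = NULL,
    initialize = function(bar) {
      self$bar = bar  # set the public field from the constructor argument
    }
  )
)

foo = Foo$new(bar = 1)  # construct an object via $new()
foo$bar                 # access a field: 1
foo2 = foo              # reference semantics: no copy is made
foo$bar = 3
foo2$bar                # also 3, both names refer to the same object
foo3 = foo$clone(deep = TRUE)  # an actual (deep) copy
foo3$bar = 5
foo$bar                 # still 3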

For more details on R6, have a look at the excellent R6 vignettes, especially the introduction.

2.2 Tasks

Tasks are objects that contain the (usually tabular) data and additional meta-data to define a machine learning problem. The meta-data is, for example, the name of the target variable for supervised machine learning problems, or the type of the dataset (e.g. spatial or survival data). This information is used by specific operations that can be performed on a task.

2.2.1 Task Types

To create a task from a data.frame(), data.table() or Matrix(), you first need to select the right task type. The mlr3 package itself provides classification tasks (TaskClassif, for targets with a finite set of class labels) and regression tasks (TaskRegr, for numeric targets); further task types, e.g. for survival, density estimation, clustering, or spatial data, are provided by extension packages such as mlr3proba and mlr3cluster.

2.2.2 Task Creation

As an example, we will create a regression task using the mtcars data set from the package datasets. It contains characteristics for different types of cars, along with their fuel consumption. We predict the numeric target variable "mpg" (miles per gallon). We only consider the first two features in the dataset for brevity.

First, we load and prepare the data, inspecting its structure with str() to get a better idea of what it looks like.

data("mtcars", package = "datasets")
data = mtcars[, 1:3]
str(data)
## 'data.frame':    32 obs. of  3 variables:
##  $ mpg : num  21 21 22.8 21.4 18.7 18.1 14.3 24.4 22.8 19.2 ...
##  $ cyl : num  6 6 4 6 8 6 8 4 4 6 ...
##  $ disp: num  160 160 108 258 360 ...

Next, we create a regression task, i.e. we construct a new instance of the R6 class TaskRegr. Usually, this is done by calling the constructor TaskRegr$new(). Instead, we are calling the converter as_task_regr() to convert our data.frame() stored in the variable data to a task and provide the following information:

  1. x: Object to convert. Works for data.frame()/data.table()/tibble(), as well as for abstract data backends implemented in the class DataBackend. The latter allows us to connect to out-of-memory storage systems like SQL servers via the extension package mlr3db.
  2. target: The name of the prediction target column for the regression problem.
  3. id (optional): An arbitrary identifier for the task, used in plots and summaries. If not provided, the deparsed name of x will be used.
library("mlr3")

task_mtcars = as_task_regr(data, target = "mpg", id = "cars")
print(task_mtcars)
## <TaskRegr:cars> (32 x 3)
## * Target: mpg
## * Properties: -
## * Features (2):
##   - dbl (2): cyl, disp

The print() method gives a short summary of the task: It has 32 observations and 3 columns, of which 2 are features.

We can also plot the task using the mlr3viz package, which gives a graphical summary of its properties:

library("mlr3viz")
autoplot(task_mtcars, type = "pairs")

Note that instead of loading all the extension packages individually, it is often more convenient to load the mlr3verse package instead. mlr3verse imports most mlr3 packages and re-exports functions which are used for common machine learning and data science tasks.

2.2.3 Predefined tasks

mlr3 includes a few predefined machine learning tasks. All tasks are stored in an R6 Dictionary (a key-value store) named mlr_tasks. Printing it gives the keys (the names of the datasets):

mlr_tasks
## <DictionaryTask> with 11 stored values
## Keys: boston_housing, breast_cancer, german_credit, iris, mtcars,
##   penguins, pima, sonar, spam, wine, zoo

We can get a more informative summary of the example tasks by converting the dictionary to a data.table() object:

as.data.table(mlr_tasks)
##                key task_type nrow ncol properties lgl int dbl chr fct ord pxc
##  1: boston_housing      regr  506   19              0   3  13   0   2   0   0
##  2:  breast_cancer   classif  683   10   twoclass   0   0   0   0   0   9   0
##  3:  german_credit   classif 1000   21   twoclass   0   3   0   0  14   3   0
##  4:           iris   classif  150    5 multiclass   0   0   4   0   0   0   0
##  5:         mtcars      regr   32   11              0   0  10   0   0   0   0
##  6:       penguins   classif  344    8 multiclass   0   3   2   0   2   0   0
##  7:           pima   classif  768    9   twoclass   0   0   8   0   0   0   0
##  8:          sonar   classif  208   61   twoclass   0   0  60   0   0   0   0
##  9:           spam   classif 4601   58   twoclass   0   0  57   0   0   0   0
## 10:           wine   classif  178   14 multiclass   0   2  11   0   0   0   0
## 11:            zoo   classif  101   17 multiclass  15   1   0   0   0   0   0

Above, the columns "lgl" (logical), "int" (integer), "dbl" (double), "chr" (character), "fct" (factor), "ord" (ordered factor) and "pxc" (POSIXct time) show the number of features in the task of the respective type.

To get a task from the dictionary, use the $get() method from the mlr_tasks class and assign the return value to a new variable. As getting a task from the dictionary is a very common operation, mlr3 provides the shortcut function tsk(). Here, we retrieve the palmer penguins task, which is provided by the package palmerpenguins:

task_penguins = tsk("penguins")
print(task_penguins)
## <TaskClassif:penguins> (344 x 8)
## * Target: species
## * Properties: multiclass
## * Features (7):
##   - int (3): body_mass, flipper_length, year
##   - dbl (2): bill_depth, bill_length
##   - fct (2): island, sex
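
The $get() method mentioned above is equivalent to the tsk() shortcut:

# the same task, retrieved via the dictionary's $get() method
task_penguins = mlr_tasks$get("penguins")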

Note that loading extension packages can add to dictionaries such as mlr_tasks. For example, mlr3data adds some more example and toy tasks for regression and classification, and mlr3proba adds survival and density estimation tasks. Both packages are loaded automatically when the mlr3verse package is loaded:

library("mlr3verse")
as.data.table(mlr_tasks)[, 1:4]
##                key task_type  nrow ncol
##  1:           actg      surv  1151   13
##  2:   bike_sharing      regr 17379   14
##  3: boston_housing      regr   506   19
##  4:  breast_cancer   classif   683   10
##  5:       faithful      dens   272    1
##  6:           gbcs      surv   686   10
##  7:  german_credit   classif  1000   21
##  8:          grace      surv  1000    8
##  9:           ilpd   classif   583   11
## 10:           iris   classif   150    5
## 11:     kc_housing      regr 21613   20
## 12:           lung      surv   228   10
## 13:      moneyball      regr  1232   15
## 14:         mtcars      regr    32   11
## 15:      optdigits   classif  5620   65
## 16:       penguins   classif   344    8
## 17:           pima   classif   768    9
## 18:         precip      dens    70    1
## 19:           rats      surv   300    5
## 20:          sonar   classif   208   61
## 21:           spam   classif  4601   58
## 22:        titanic   classif  1309   11
## 23:   unemployment      surv  3343    6
## 24:      usarrests     clust    50    4
## 25:           whas      surv   481   11
## 26:           wine   classif   178   14
## 27:            zoo   classif   101   17
##                key task_type  nrow ncol

To get more information about a particular task, the corresponding man page can be found under mlr_tasks_[id], e.g. mlr_tasks_german_credit:

help("mlr_tasks_german_credit")

2.2.4 Task API

All properties and characteristics of tasks can be queried using the task’s public fields and methods (see Task). Methods can also be used to change the stored data and the behavior of the task.

2.2.4.1 Retrieving Data

The data stored in a task can be retrieved directly from fields. For example, for the task_mtcars task that we defined above, we can get the number of rows and columns:

task_mtcars
## <TaskRegr:cars> (32 x 3)
## * Target: mpg
## * Properties: -
## * Features (2):
##   - dbl (2): cyl, disp
task_mtcars$nrow
## [1] 32
task_mtcars$ncol
## [1] 3

More information can be obtained through methods of the object, for example:

task_mtcars$data()
##      mpg cyl  disp
##  1: 21.0   6 160.0
##  2: 21.0   6 160.0
##  3: 22.8   4 108.0
##  4: 21.4   6 258.0
##  5: 18.7   8 360.0
##  6: 18.1   6 225.0
##  7: 14.3   8 360.0
##  8: 24.4   4 146.7
##  9: 22.8   4 140.8
## 10: 19.2   6 167.6
## 11: 17.8   6 167.6
## 12: 16.4   8 275.8
## 13: 17.3   8 275.8
## 14: 15.2   8 275.8
## 15: 10.4   8 472.0
## 16: 10.4   8 460.0
## 17: 14.7   8 440.0
## 18: 32.4   4  78.7
## 19: 30.4   4  75.7
## 20: 33.9   4  71.1
## 21: 21.5   4 120.1
## 22: 15.5   8 318.0
## 23: 15.2   8 304.0
## 24: 13.3   8 350.0
## 25: 19.2   8 400.0
## 26: 27.3   4  79.0
## 27: 26.0   4 120.3
## 28: 30.4   4  95.1
## 29: 15.8   8 351.0
## 30: 19.7   6 145.0
## 31: 15.0   8 301.0
## 32: 21.4   4 121.0
##      mpg cyl  disp

In mlr3, each row (observation) has a unique identifier, stored as an integer(). These can be passed as arguments to the $data() method to select specific rows:

head(task_mtcars$row_ids)
## [1] 1 2 3 4 5 6
# retrieve data for rows with IDs 1, 5, and 10
task_mtcars$data(rows = c(1, 5, 10))
##     mpg cyl  disp
## 1: 21.0   6 160.0
## 2: 18.7   8 360.0
## 3: 19.2   6 167.6

Note that although the row IDs are typically just the sequence from 1 to nrow(data), they are only guaranteed to be unique natural numbers. Keep that in mind, especially if you work with data stored in a real database management system (see backends).
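
As a brief sketch of how non-consecutive row IDs can arise, we can construct a task from a data backend with a custom primary key column (the column name "row_id" here is our own choice):

# a sketch: row IDs are taken from a user-defined primary key column
data = as.data.table(datasets::mtcars[1:3, 1:3])
data$row_id = c(10L, 20L, 30L)  # arbitrary unique natural numbers
backend = as_data_backend(data, primary_key = "row_id")
task = as_task_regr(backend, target = "mpg")
task$row_ids  # 10 20 30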

Similarly to row IDs, target and feature columns also have unique identifiers, i.e. names (stored as character()). These can be accessed via the public fields $feature_names and $target_names. Here, “target” refers to the variable we want to predict and “feature” to the predictors for the task. The target will usually be only a single name.

task_mtcars$feature_names
## [1] "cyl"  "disp"
task_mtcars$target_names
## [1] "mpg"

The row_ids and column names can be combined when selecting a subset of the data:

# retrieve data for rows 1, 5, and 10 and only select column "mpg"
task_mtcars$data(rows = c(1, 5, 10), cols = "mpg")
##     mpg
## 1: 21.0
## 2: 18.7
## 3: 19.2

To extract the complete data from the task, one can also simply convert it to a data.table:

# show summary of entire data
summary(as.data.table(task_mtcars))
##       mpg            cyl            disp      
##  Min.   :10.4   Min.   :4.00   Min.   : 71.1  
##  1st Qu.:15.4   1st Qu.:4.00   1st Qu.:120.8  
##  Median :19.2   Median :6.00   Median :196.3  
##  Mean   :20.1   Mean   :6.19   Mean   :230.7  
##  3rd Qu.:22.8   3rd Qu.:8.00   3rd Qu.:326.0  
##  Max.   :33.9   Max.   :8.00   Max.   :472.0

2.2.4.2 Binary classification

Classification problems with a target variable with only two classes are called binary classification tasks. They are special in the sense that one of these classes is denoted positive and the other one negative. You can specify the positive class within the classification task object during task creation. If not explicitly set during construction, the positive class defaults to the first level of the target variable.

# during construction
data("Sonar", package = "mlbench")
task = as_task_classif(Sonar, target = "Class", positive = "R")

# switch positive class to level 'M'
task$positive = "M"

2.2.4.3 Roles (Rows and Columns)

It is possible to assign different roles to rows and columns. These roles affect the behavior of the task for different operations. We have already seen this for the target and feature roles that are assigned to columns when a task is created. For other possible roles and their meaning, see the documentation of Task.

For example, the previously-constructed task_mtcars task has the following column roles:

print(task_mtcars$col_roles)
## $feature
## [1] "cyl"  "disp"
## 
## $target
## [1] "mpg"
## 
## $name
## character(0)
## 
## $order
## character(0)
## 
## $stratum
## character(0)
## 
## $group
## character(0)
## 
## $weight
## character(0)

Columns can have no role (they are ignored) or have multiple roles. To add the row names of task_mtcars as an additional feature, we first add them to the underlying data as a regular column and then recreate the task with the new column.

# with `keep.rownames`, data.table stores the row names in an extra column "rn"
data = as.data.table(datasets::mtcars[, 1:3], keep.rownames = TRUE)
task_mtcars = as_task_regr(data, target = "mpg", id = "cars")

# there is a new feature called "rn"
task_mtcars$feature_names
## [1] "cyl"  "disp" "rn"

The row names are now a feature whose values are stored in the column "rn". We include this column here for educational purposes only. Generally speaking, there is no point in having a feature that uniquely identifies each row. Furthermore, the character data type will cause problems with many types of machine learning algorithms.

The identifier may be useful to label points in plots, for example to identify outliers. To achieve this, we will change the role of the rn column by removing it from the list of features and assigning it the new role "name". There are two ways to do this:

  1. Use the Task method $set_col_roles() (recommended).
  2. Simply modify the field $col_roles, which is a named list of vectors of column names. Each vector in this list corresponds to a column role, and the column names contained in that vector have that role.

Supported column roles can be found in the manual of Task, or just by printing the names of the field $col_roles:

# supported column roles, see ?Task
names(task_mtcars$col_roles)
## [1] "feature" "target"  "name"    "order"   "stratum" "group"   "weight"
# assign column "rn" the role "name", remove from other roles
task_mtcars$set_col_roles("rn", roles = "name")

# note that "rn" not listed as feature anymore
task_mtcars$feature_names
## [1] "cyl"  "disp"
# "rn" also does not appear anymore when we access the data
task_mtcars$data(rows = 1:2)
##    mpg cyl disp
## 1:  21   6  160
## 2:  21   6  160

Changing the role does not change the underlying data, it just updates the view on it. The data is not copied in the code above. The view is changed in-place though, i.e. the task object itself is modified.
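
For completeness, the second approach listed above, modifying the field $col_roles directly, would have achieved the same result:

# equivalent to task_mtcars$set_col_roles("rn", roles = "name"):
task_mtcars$col_roles$name = "rn"
task_mtcars$col_roles$feature = setdiff(task_mtcars$col_roles$feature, "rn")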

Just like columns, it is also possible to assign different roles to rows.

Rows can have two different roles:

  1. Role use: Rows that are generally available for model fitting (although they may also be used as a test set in resampling). This is the default role.

  2. Role validation: Rows that are not used for training. Rows that have missing values in the target column during task creation are automatically set to the validation role.

There are several reasons to hold some observations back or treat them differently:

  1. It is often good practice to validate the final model on an external validation set to identify possible overfitting.
  2. Some observations may be unlabeled, e.g. in competitions like Kaggle.

These observations cannot be used for training a model, but can be used to get predictions.
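
As a minimal sketch, rows can be moved to the validation role with the $set_row_roles() method, which works analogously to $set_col_roles(); such rows are then excluded from the view used for training:

task = tsk("penguins")
task$nrow
## [1] 344
# assign the first 10 rows the "validation" role; they no longer count as usable rows
task$set_row_roles(1:10, roles = "validation")
task$nrow
## [1] 334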

2.2.4.4 Task Mutators

As shown above, modifying $col_roles or $row_roles (either via set_col_roles()/set_row_roles() or directly by modifying the named list) changes the view on the data. The additional convenience method $filter() subsets the current view based on row IDs and $select() subsets the view based on feature names.

task_penguins = tsk("penguins")
task_penguins$select(c("body_mass", "flipper_length")) # keep only these features
task_penguins$filter(1:3) # keep only these rows
task_penguins$head()
##    species body_mass flipper_length
## 1:  Adelie      3750            181
## 2:  Adelie      3800            186
## 3:  Adelie      3250            195

The methods above allow us to subset the data; the methods $rbind() and $cbind() allow us to add extra rows and columns to a task. Again, the original data is not changed. The additional rows or columns are only added to the view of the data.

task_penguins$cbind(data.frame(letters = letters[1:3])) # add column letters
task_penguins$head()
##    species body_mass flipper_length letters
## 1:  Adelie      3750            181       a
## 2:  Adelie      3800            186       b
## 3:  Adelie      3250            195       c
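
As a sketch of $rbind(): the appended table must provide the same columns and types as the task's current view (including the "letters" column we just added), so here we simply re-append a copy of the task's first row:

# duplicate the first row of the current view and append it
new_row = task_penguins$data(rows = task_penguins$row_ids[1])
task_penguins$rbind(new_row)
task_penguins$nrow
## [1] 4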

2.2.5 Plotting Tasks

The mlr3viz package provides plotting facilities for many classes implemented in mlr3. The available plot types depend on the class, but all plots are returned as ggplot2 objects which can be easily customized.

For classification tasks (inheriting from TaskClassif), see the documentation of mlr3viz::autoplot.TaskClassif for the implemented plot types. Here are some examples to get an impression:

library("mlr3viz")

# get the pima indians task
task = tsk("pima")

# subset task to only use the 3 first features
task$select(head(task$feature_names, 3))

# default plot: class frequencies
autoplot(task)
# pairs plot (requires package GGally)
autoplot(task, type = "pairs")
# duo plot (requires package GGally)
autoplot(task, type = "duo")

Of course, you can do the same for regression tasks (inheriting from TaskRegr) as documented in mlr3viz::autoplot.TaskRegr:

library("mlr3viz")

# get the complete mtcars task
task = tsk("mtcars")

# subset task to only use the 3 first features
task$select(head(task$feature_names, 3))

# default plot: boxplot of target variable
autoplot(task)
# pairs plot (requires package GGally)
autoplot(task, type = "pairs")

2.3 Learners

Objects of class Learner provide a unified interface to many popular machine learning algorithms in R. They consist of methods to train and predict a model for a Task and provide meta-information about the learners, such as the hyperparameters you can set (which control the behavior of the learner).

The base class of each learner is Learner, specialized for regression as LearnerRegr and for classification as LearnerClassif. Other types of learners, provided by extension packages, also inherit from the Learner base class, e.g. mlr3proba::LearnerSurv or mlr3cluster::LearnerClust. In contrast to Tasks, the creation of a custom Learner is usually not required and a more advanced topic. Hence, we refer the reader to Section 7.1 and proceed with an overview of the interface of already implemented learners.

All Learners work in a two-stage procedure:
  • training step: The training data (features and target) is passed to the Learner’s $train() function, which trains and stores a model, i.e. the learned relationship between features and target.
  • predict step: The new data, usually a different slice of the original data than used for training, is passed to the $predict() method of the Learner. The model trained in the first step is used to predict the missing target, e.g. labels for classification problems or numerical values for regression problems.

2.3.1 Predefined Learners

The mlr3 package ships with a small set of classification and regression learners, deliberately kept minimal to avoid unnecessary dependencies:

  • classif.featureless and regr.featureless: baseline learners that ignore all features,
  • classif.rpart and regr.rpart: single classification and regression trees from the rpart package, and
  • classif.debug: a learner purely for debugging purposes.

This set of baseline learners is usually insufficient for a real data analysis. Thus, we have cherry-picked implementations of the most popular machine learning methods and collected them in the mlr3learners package:

  • Linear and logistic regression
  • Penalized Generalized Linear Models
  • \(k\)-Nearest Neighbors regression and classification
  • Kriging
  • Linear and Quadratic Discriminant Analysis
  • Naive Bayes
  • Support-Vector machines
  • Gradient Boosting
  • Random Forests for regression, classification and survival

More machine learning methods and alternative implementations are collected in the mlr3extralearners repository.

A full list of implemented learners across all packages is given in this interactive list and also via mlr3extralearners::list_mlr3learners().

head(mlr3extralearners::list_mlr3learners()) # show first six learners
##          name   class                 id      mlr3_package
## 1: AdaBoostM1 classif classif.AdaBoostM1 mlr3extralearners
## 2:       bart classif       classif.bart mlr3extralearners
## 3:        C50 classif        classif.C50 mlr3extralearners
## 4:   catboost classif   classif.catboost mlr3extralearners
## 5:    cforest classif    classif.cforest mlr3extralearners
## 6:      ctree classif      classif.ctree mlr3extralearners
##         required_packages                                      properties
## 1:                  RWeka                             multiclass,twoclass
## 2:                 dbarts                                twoclass,weights
## 3:                    C50            missings,multiclass,twoclass,weights
## 4:               catboost importance,missings,multiclass,twoclass,weights
## 5: partykit,sandwich,coin           multiclass,oob_error,twoclass,weights
## 6: partykit,sandwich,coin                     multiclass,twoclass,weights
##                     feature_types predict_types
## 1:         numeric,factor,ordered response,prob
## 2: integer,numeric,factor,ordered response,prob
## 3:         numeric,factor,ordered response,prob
## 4:         numeric,factor,ordered response,prob
## 5: integer,numeric,factor,ordered response,prob
## 6: integer,numeric,factor,ordered response,prob

The full list of learners uses a large number of extra packages, which sometimes break. We automatically check the status of each learner’s integration; the latest build status of all learners is shown here.

To get one of the predefined learners, you need to access the mlr_learners Dictionary which, similar to mlr_tasks, is automatically populated with more learners by extension packages.

# load most mlr3 packages to populate the dictionary
library("mlr3verse")
mlr_learners
## <DictionaryLearner> with 135 stored values
## Keys: classif.AdaBoostM1, classif.bart, classif.C50, classif.catboost,
##   classif.cforest, classif.ctree, classif.cv_glmnet, classif.debug,
##   classif.earth, classif.extratrees, classif.featureless, classif.fnn,
##   classif.gam, classif.gamboost, classif.gausspr, classif.gbm,
##   classif.glmboost, classif.glmnet, classif.IBk, classif.J48,
##   classif.JRip, classif.kknn, classif.ksvm, classif.lda,
##   classif.liblinear, classif.lightgbm, classif.LMT, classif.log_reg,
##   classif.lssvm, classif.mob, classif.multinom, classif.naive_bayes,
##   classif.nnet, classif.OneR, classif.PART, classif.qda,
##   classif.randomForest, classif.ranger, classif.rfsrc, classif.rpart,
##   classif.svm, classif.xgboost, clust.agnes, clust.ap, clust.cmeans,
##   clust.cobweb, clust.dbscan, clust.diana, clust.em, clust.fanny,
##   clust.featureless, clust.ff, clust.hclust, clust.kkmeans,
##   clust.kmeans, clust.MBatchKMeans, clust.meanshift, clust.pam,
##   clust.SimpleKMeans, clust.xmeans, dens.hist, dens.kde, dens.kde_kd,
##   dens.kde_ks, dens.locfit, dens.logspline, dens.mixed, dens.nonpar,
##   dens.pen, dens.plug, dens.spline, regr.bart, regr.catboost,
##   regr.cforest, regr.ctree, regr.cubist, regr.cv_glmnet, regr.earth,
##   regr.extratrees, regr.featureless, regr.fnn, regr.gam, regr.gamboost,
##   regr.gausspr, regr.gbm, regr.glm, regr.glmboost, regr.glmnet,
##   regr.IBk, regr.kknn, regr.km, regr.ksvm, regr.liblinear,
##   regr.lightgbm, regr.lm, regr.M5Rules, regr.mars, regr.mob,
##   regr.randomForest, regr.ranger, regr.rfsrc, regr.rpart, regr.rvm,
##   regr.svm, regr.xgboost, surv.akritas, surv.blackboost, surv.cforest,
##   surv.coxboost, surv.coxph, surv.coxtime, surv.ctree,
##   surv.cv_coxboost, surv.cv_glmnet, surv.deephit, surv.deepsurv,
##   surv.dnnsurv, surv.flexible, surv.gamboost, surv.gbm, surv.glmboost,
##   surv.glmnet, surv.kaplan, surv.loghaz, surv.mboost, surv.nelson,
##   surv.obliqueRSF, surv.parametric, surv.pchazard, surv.penalized,
##   surv.ranger, surv.rfsrc, surv.rpart, surv.svm, surv.xgboost

To obtain an object from the dictionary you can also use the shortcut function lrn() or the generic mlr_learners$get() method, e.g. lrn("classif.rpart").

2.3.2 Learner API

Each learner provides the following meta-information:

  • feature_types: the type of features the learner can deal with.
  • packages: the packages required to train a model with this learner and make predictions.
  • properties: additional properties and capabilities. For example, a learner has the property “missings” if it is able to handle missing feature values, and “importance” if it computes and allows us to extract data on the relative importance of the features. A complete list of these is available in the mlr3 reference.
  • predict_types: possible prediction types. For example, a classification learner can predict labels (“response”) or probabilities (“prob”). For a complete list of possible predict types see the mlr3 reference.

You can retrieve a specific learner using its ID:

learner = lrn("classif.rpart")
print(learner)
## <LearnerClassifRpart:classif.rpart>
## * Model: -
## * Parameters: xval=0
## * Packages: rpart
## * Predict Type: response
## * Feature types: logical, integer, numeric, factor, ordered
## * Properties: importance, missings, multiclass, selected_features,
##   twoclass, weights
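
Each of these pieces of meta-information can also be accessed directly as a field of the learner; the values mirror the printed summary above:

learner$packages
## [1] "rpart"
learner$predict_types
## [1] "response" "prob"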

Each learner has hyperparameters that control its behavior, for example the minimum number of samples in the leaf of a decision tree, or whether to provide verbose output during training. Setting hyperparameters to values appropriate for a given machine learning task is crucial. The field param_set stores a description of the hyperparameters the learner has, their ranges, defaults, and current values:

learner$param_set
## <ParamSet>
##                 id    class lower upper nlevels        default value
##  1:             cp ParamDbl     0     1     Inf           0.01      
##  2:     keep_model ParamLgl    NA    NA       2          FALSE      
##  3:     maxcompete ParamInt     0   Inf     Inf              4      
##  4:       maxdepth ParamInt     1    30      30             30      
##  5:   maxsurrogate ParamInt     0   Inf     Inf              5      
##  6:      minbucket ParamInt     1   Inf     Inf <NoDefault[3]>      
##  7:       minsplit ParamInt     1   Inf     Inf             20      
##  8: surrogatestyle ParamInt     0     1       2              0      
##  9:   usesurrogate ParamInt     0     2       3              2      
## 10:           xval ParamInt     0   Inf     Inf             10     0

The set of current hyperparameter values is stored in the values field of the param_set field. You can change the current hyperparameter values by assigning a named list to this field:

learner$param_set$values = list(cp = 0.01, xval = 0)
learner
## <LearnerClassifRpart:classif.rpart>
## * Model: -
## * Parameters: cp=0.01, xval=0
## * Packages: rpart
## * Predict Type: response
## * Feature types: logical, integer, numeric, factor, ordered
## * Properties: importance, missings, multiclass, selected_features,
##   twoclass, weights

Note that this operation overwrites all previously set parameters. You can also get the current set of hyperparameter values, modify it, and write it back to the learner:

pv = learner$param_set$values
pv$cp = 0.02
learner$param_set$values = pv

This sets cp to 0.02 but keeps any other values that were set previously.

Note that the lrn() function also accepts additional arguments to update hyperparameters or set fields of the learner in one go:

learner = lrn("classif.rpart", id = "rp", cp = 0.001)
learner$id
## [1] "rp"
learner$param_set$values
## $xval
## [1] 0
## 
## $cp
## [1] 0.001

2.3.2.1 Thresholding

Models trained on binary classification tasks that predict the probability for the positive class usually use a simple rule to determine the predicted class label: if the probability is more than 50%, predict the positive label, otherwise predict the negative label. In some cases you may want to adjust this threshold, for example if the classes are very unbalanced (i.e. one is much more prevalent than the other).

In the example below, we change the threshold to 0.2, which improves the True Positive Rate (TPR). Note that while the new threshold classifies more observations from the positive class correctly, the True Negative Rate (TNR) decreases. Depending on the application, this may or may not be desired.

data("Sonar", package = "mlbench")
task = as_task_classif(Sonar, target = "Class", positive = "M")
learner = lrn("classif.rpart", predict_type = "prob")
pred = learner$train(task)$predict(task)

measures = msrs(c("classif.tpr", "classif.tnr")) # use msrs() to get a list of multiple measures
pred$confusion
##         truth
## response  M  R
##        M 95 10
##        R 16 87
pred$score(measures)
## classif.tpr classif.tnr 
##      0.8559      0.8969
pred$set_threshold(0.2)
pred$confusion
##         truth
## response   M   R
##        M 104  25
##        R   7  72
pred$score(measures)
## classif.tpr classif.tnr 
##      0.9369      0.7423

Thresholds can be tuned automatically with respect to a performance measure with the mlr3pipelines package, i.e. using PipeOpTuneThreshold.

2.4 Train, Predict, Assess Performance

In this section, we explain how tasks and learners can be used to train a model and predict on a new dataset. The concept is demonstrated on a supervised classification task using the penguins dataset and the rpart learner, which builds a single classification tree.

Training a learner means fitting a model to a given data set – essentially an optimization problem that determines the best parameters (not hyperparameters!) of the model given the data. We then predict the label for observations that the model has not seen during training. These predictions are compared to the ground truth values in order to assess the predictive performance of the model.

2.4.1 Creating Task and Learner Objects

First of all, we load the mlr3verse package, which will load all other packages we need here.
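
library("mlr3verse")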

Now, we retrieve the task and the learner from mlr_tasks (with shortcut tsk()) and mlr_learners (with shortcut lrn()), respectively:

task = tsk("penguins")
learner = lrn("classif.rpart")

2.4.2 Setting up the train/test splits of the data

It is common to train on a majority of the data, to give the learner a better chance of fitting a good model. Here we use 80% of all available observations to train and predict on the remaining 20%. For this purpose, we create two index vectors:

train_set = sample(task$nrow, 0.8 * task$nrow)
test_set = setdiff(seq_len(task$nrow), train_set)

In Section 3.2 we will learn how mlr3 can automatically create training and test sets based on different resampling strategies.

2.4.3 Training the learner

The field $model stores the model that is produced in the training step. Before the $train() method is called on a learner object, this field is NULL:

learner$model
## NULL

Now we fit the classification tree using the training set of the task by calling the $train() method of learner:

learner$train(task, row_ids = train_set)

This operation modifies the learner in-place by adding the fitted model to the existing object. We can now access the stored model via the field $model:

print(learner$model)
## n= 275 
## 
## node), split, n, loss, yval, (yprob)
##       * denotes terminal node
## 
## 1) root 275 158 Adelie (0.425455 0.210909 0.363636)  
##   2) flipper_length< 206.5 171  56 Adelie (0.672515 0.321637 0.005848)  
##     4) bill_length< 43.15 113   3 Adelie (0.973451 0.026549 0.000000) *
##     5) bill_length>=43.15 58   6 Chinstrap (0.086207 0.896552 0.017241) *
##   3) flipper_length>=206.5 104   5 Gentoo (0.019231 0.028846 0.951923) *

Inspecting the output, we see that the learner has identified features in the task that are predictive of the class (the type of penguin) and uses them to partition observations in the tree. There are additional details on how the data is partitioned across branches of the tree; the textual representation of the model depends on the type of learner. For more information on this particular type of model, see rpart::print.rpart().

2.4.4 Predicting

After the model has been fitted to the training data, we use the test set for prediction. Remember that we initially split the data in train_set and test_set.

prediction = learner$predict(task, row_ids = test_set)
print(prediction)
## <PredictionClassif> for 69 observations:
##     row_ids     truth  response
##           2    Adelie    Adelie
##           3    Adelie    Adelie
##           4    Adelie    Adelie
## ---                            
##         319 Chinstrap Chinstrap
##         330 Chinstrap Chinstrap
##         340 Chinstrap    Gentoo

The $predict() method of the Learner returns a Prediction object. More precisely, a LearnerClassif returns a PredictionClassif object.

A prediction object holds the row IDs of the test data, the respective true labels of the target column, and the corresponding predictions. The simplest way to extract this information is by converting the Prediction object to a data.table():

head(as.data.table(prediction)) # show first six predictions
##    row_ids  truth response
## 1:       2 Adelie   Adelie
## 2:       3 Adelie   Adelie
## 3:       4 Adelie   Adelie
## 4:       5 Adelie   Adelie
## 5:       8 Adelie   Adelie
## 6:       9 Adelie   Adelie

For classification, you can also extract the confusion matrix:

prediction$confusion
##            truth
## response    Adelie Chinstrap Gentoo
##   Adelie        34         1      0
##   Chinstrap      1         7      0
##   Gentoo         0         2     24

The confusion matrix shows, for each class, how many observations were predicted to be in that class and how many were actually in it (more information on Wikipedia). The entries along the diagonal denote the correctly classified observations. In this case, we can see that our classifier is quite good, correctly predicting almost all observations.

2.4.5 Changing the Predict Type

Classification learners default to predicting the class label. However, many classifiers can also tell you how sure they are about the predicted label by providing posterior probabilities for the classes. To predict these probabilities, the predict_type field of a LearnerClassif must be changed from "response" (the default) to "prob" before training:

learner$predict_type = "prob"

# re-fit the model
learner$train(task, row_ids = train_set)

# rebuild prediction object
prediction = learner$predict(task, row_ids = test_set)

The prediction object now contains probabilities for all class labels in addition to the predicted label (the one with the highest probability):

# data.table conversion
head(as.data.table(prediction)) # show first six
##    row_ids  truth response prob.Adelie prob.Chinstrap prob.Gentoo
## 1:       2 Adelie   Adelie      0.9735        0.02655           0
## 2:       3 Adelie   Adelie      0.9735        0.02655           0
## 3:       4 Adelie   Adelie      0.9735        0.02655           0
## 4:       5 Adelie   Adelie      0.9735        0.02655           0
## 5:       8 Adelie   Adelie      0.9735        0.02655           0
## 6:       9 Adelie   Adelie      0.9735        0.02655           0
# directly access the predicted labels:
head(prediction$response)
## [1] Adelie Adelie Adelie Adelie Adelie Adelie
## Levels: Adelie Chinstrap Gentoo
# directly access the matrix of probabilities:
head(prediction$prob)
##      Adelie Chinstrap Gentoo
## [1,] 0.9735   0.02655      0
## [2,] 0.9735   0.02655      0
## [3,] 0.9735   0.02655      0
## [4,] 0.9735   0.02655      0
## [5,] 0.9735   0.02655      0
## [6,] 0.9735   0.02655      0

Similarly to predicting probabilities for classification, many regression learners support the extraction of standard error estimates for predictions by setting the predict type to "se".
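
As a brief sketch (assuming the mlr3learners package is loaded, which provides the linear regression learner regr.lm):

task = tsk("mtcars")
learner = lrn("regr.lm", predict_type = "se")
learner$train(task)
prediction = learner$predict(task)
# the converted prediction now contains an additional "se" column
head(as.data.table(prediction), 3)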

2.4.6 Plotting Predictions

Similarly to plotting tasks, mlr3viz provides an autoplot() method for Prediction objects. All available types are listed in the manual pages for autoplot.PredictionClassif(), autoplot.PredictionRegr() and the other prediction types (defined by extension packages).

task = tsk("penguins")
learner = lrn("classif.rpart", predict_type = "prob")
learner$train(task)
prediction = learner$predict(task)
autoplot(prediction)

2.4.7 Performance assessment

The last step of modeling is usually assessing the performance of the trained model. We have already had a look at this with the confusion matrix, but it is often convenient to quantify the performance of a model with a single number. The exact nature of this comparison is defined by a measure, which is given by a Measure object. Note that if the prediction was made on a dataset without the target column, i.e. without known true labels, then no performance can be calculated.

Available measures are stored in mlr_measures (with convenience getter function msr()):

mlr_measures
## <DictionaryMeasure> with 85 stored values
## Keys: aic, bic, classif.acc, classif.auc, classif.bacc, classif.bbrier,
##   classif.ce, classif.costs, classif.dor, classif.fbeta, classif.fdr,
##   classif.fn, classif.fnr, classif.fomr, classif.fp, classif.fpr,
##   classif.logloss, classif.mbrier, classif.mcc, classif.npv,
##   classif.ppv, classif.prauc, classif.precision, classif.recall,
##   classif.sensitivity, classif.specificity, classif.tn, classif.tnr,
##   classif.tp, classif.tpr, clust.ch, clust.db, clust.dunn,
##   clust.silhouette, clust.wss, debug, dens.logloss, oob_error,
##   regr.bias, regr.ktau, regr.mae, regr.mape, regr.maxae, regr.medae,
##   regr.medse, regr.mse, regr.msle, regr.pbias, regr.rae, regr.rmse,
##   regr.rmsle, regr.rrse, regr.rse, regr.rsq, regr.sae, regr.smape,
##   regr.srho, regr.sse, selected_features, surv.brier, surv.calib_alpha,
##   surv.calib_beta, surv.chambless_auc, surv.cindex, surv.dcalib,
##   surv.graf, surv.hung_auc, surv.intlogloss, surv.logloss, surv.mae,
##   surv.mse, surv.nagelk_r2, surv.oquigley_r2, surv.rmse, surv.schmid,
##   surv.song_auc, surv.song_tnr, surv.song_tpr, surv.uno_auc,
##   surv.uno_tnr, surv.uno_tpr, surv.xu_r2, time_both, time_predict,
##   time_train

We choose accuracy (classif.acc) as our specific performance measure here and call the method $score() of the prediction object to quantify the predictive performance of our model.

measure = msr("classif.acc")
print(measure)
## <MeasureClassifSimple:classif.acc>
## * Packages: mlr3measures
## * Range: [0, 1]
## * Minimize: FALSE
## * Parameters: list()
## * Properties: -
## * Predict type: response
prediction$score(measure)
## classif.acc 
##      0.9651

Note that if no measure is specified, classification defaults to the classification error (classif.ce, which is 1 minus the accuracy) and regression to the mean squared error (regr.mse).
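
For example, calling $score() without an argument on our classification prediction computes the classification error, which here is 1 - 0.9651:

prediction$score()
## classif.ce 
##      0.0349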