2  Basics

This chapter will teach you the essential building blocks of mlr3: the R6 classes and operations used for machine learning.

How these building blocks interoperate is summarized in the following figure.

The data, which mlr3 encapsulates in tasks, is split into non-overlapping training and test sets. As we are interested in models that extrapolate to new data rather than just memorizing the training data, the separate test data allows us to objectively evaluate models with respect to their generalization. The training data is given to a machine learning algorithm, which we call a learner in mlr3. The learner uses the training data to build a model of the relationship of the input features to the output target values. This model is then used to produce predictions on the test data, which are compared to the ground truth values to assess the quality of the model. mlr3 offers a number of different measures to quantify how well a model performs based on the difference between predicted and actual values. Usually, this measure is a numeric score.

Splitting data into training and test sets, building a model, and evaluating it can be repeated several times, resampling different training and test sets from the original data each time. Multiple resampling iterations give us a better, less biased estimate of the generalization performance of a particular type of model. As data are usually partitioned randomly into training and test sets, a single split can, for example, produce training and test sets that are very different, creating the misleading impression that the particular type of model does not perform well.
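
The following code sketches the workflow described above from end to end; every function used here is introduced step by step in the remainder of this chapter:

library("mlr3")

task = tsk("penguins")                                  # data encapsulated with meta-information
learner = lrn("classif.rpart")                          # a classification tree learner
set.seed(1)
train_set = sample(task$row_ids, 0.67 * task$nrow)      # random train-test split
test_set = setdiff(task$row_ids, train_set)
learner$train(task, row_ids = train_set)                # fit a model on the training data
prediction = learner$predict(task, row_ids = test_set)  # predict on the held-out test data
prediction$score(msr("classif.acc"))                    # compare predictions with ground truth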

Note

Real-world problems usually require preprocessing operations such as normalization or imputation of missing values. In the workflow sketched above, these steps would also be part of the learner. This will be covered in the chapter on pipelines.

This chapter covers the following topics:

  1. Tasks encapsulate the data with meta-information, such as the name of the prediction target column.

  2. Learners encapsulate machine learning algorithms that train models and make predictions for a task. They are provided by mlr3 and its extension packages.

  3. Training and predicting: fitting a learner on a task and using the resulting model to make and assess predictions.

2.1 Tasks

Tasks are objects that contain the (usually tabular) data and additional meta-data to define a machine learning problem. The meta-data is, for example, the name of the target variable for supervised machine learning problems, or the type of the dataset (e.g. a spatial or survival task). This information is used by specific operations that can be performed on a task.

2.1.1 Task Types

To create a task object, you first need to choose the right task type:

  • Classification Task: The target is a label (stored as character or factor) with only relatively few distinct values → TaskClassif.

  • Regression Task: The target is a numeric quantity (stored as integer or numeric) → TaskRegr.

  • Survival Task: The target is the (right-censored) time to an event. More censoring types are currently in development → mlr3proba::TaskSurv in add-on package mlr3proba.

  • Density Task: An unsupervised task to estimate the density → mlr3proba::TaskDens in add-on package mlr3proba.

  • Cluster Task: An unsupervised task type; there is no target and the aim is to identify similar groups within the feature space → mlr3cluster::TaskClust in add-on package mlr3cluster.

  • Spatial Task: Observations in the task have spatio-temporal information (e.g. coordinates) → mlr3spatiotempcv::TaskRegrST or mlr3spatiotempcv::TaskClassifST in add-on package mlr3spatiotempcv.

  • Ordinal Regression Task: The target is ordinal → TaskOrdinal in add-on package mlr3ordinal (still in development).
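
For example, a classification task can be created by calling the corresponding converter on a data.frame, here using the iris data that ships with R (task creation is covered in detail in the next section):

library("mlr3")

task_iris = as_task_classif(iris, target = "Species")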

2.1.2 Task Creation

As an example, we will create a regression task using the mtcars data set from package datasets (ships with R). It contains characteristics for different types of cars, along with their fuel consumption. We predict the numeric target variable stored in column "mpg" (miles per gallon). Here, we only consider the first two features in the dataset for brevity:

data("mtcars", package = "datasets")
data = mtcars[, 1:3]
str(data)
'data.frame':   32 obs. of  3 variables:
 $ mpg : num  21 21 22.8 21.4 18.7 18.1 14.3 24.4 22.8 19.2 ...
 $ cyl : num  6 6 4 6 8 6 8 4 4 6 ...
 $ disp: num  160 160 108 258 360 ...

Next, we create a regression task, i.e. we construct a new instance of the R6 class TaskRegr. Formally, the intended way to initialize an R6 object is to call the constructor TaskRegr$new(). Here, we instead call the converter as_task_regr(), which converts the data.frame stored in the object data to a regression task, and provide the following additional information:

  1. x: Object to convert. Works for rectangular data formats such as data.frame(), data.table(), or tibble(). Internally, the data is converted and stored in an abstract DataBackend. This allows connecting to out-of-memory storage systems like SQL servers via the extension package mlr3db.
  2. target: The name of the prediction target column for the regression problem, here miles per gallon ("mpg").
  3. id (optional): An arbitrary identifier for the task, used in plots and summaries. If not provided, the deparsed name of x will be used.
library("mlr3")

task_mtcars = as_task_regr(data, target = "mpg", id = "cars")
print(task_mtcars)
<TaskRegr:cars> (32 x 3)
* Target: mpg
* Properties: -
* Features (2):
  - dbl (2): cyl, disp

The print() method gives a short summary of the task: It has 32 observations and 3 columns, of which 2 are features stored in double-precision floating point format.

We can also plot the task using the mlr3viz package, which gives a graphical summary of its properties:

library("mlr3viz")
autoplot(task_mtcars, type = "pairs")

Tip

Instead of loading multiple extension packages individually, it is often more convenient to load the mlr3verse package. mlr3verse imports the namespaces of the most important mlr3 packages and re-exports functions commonly used for machine learning and data science tasks.
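
In code, this replaces a series of individual library() calls:

# loads mlr3 and the most important extension packages in one go:
library("mlr3verse")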

2.1.3 Predefined tasks

mlr3 includes a few predefined machine learning tasks. All tasks are stored in an R6 Dictionary (a key-value store) named mlr_tasks. Printing it gives the keys (the names of the datasets):

mlr_tasks
<DictionaryTask> with 11 stored values
Keys: boston_housing, breast_cancer, german_credit, iris, mtcars,
  penguins, pima, sonar, spam, wine, zoo

We can get a more informative summary of the example tasks by converting the dictionary to a data.table() object:

as.data.table(mlr_tasks)
               key                   label task_type nrow ncol properties lgl
 1: boston_housing   Boston Housing Prices      regr  506   19              0
 2:  breast_cancer Wisconsin Breast Cancer   classif  683   10   twoclass   0
 3:  german_credit           German Credit   classif 1000   21   twoclass   0
 4:           iris            Iris Flowers   classif  150    5 multiclass   0
 5:         mtcars            Motor Trends      regr   32   11              0
 6:       penguins         Palmer Penguins   classif  344    8 multiclass   0
 7:           pima    Pima Indian Diabetes   classif  768    9   twoclass   0
 8:          sonar  Sonar: Mines vs. Rocks   classif  208   61   twoclass   0
 9:           spam       HP Spam Detection   classif 4601   58   twoclass   0
10:           wine            Wine Regions   classif  178   14 multiclass   0
11:            zoo             Zoo Animals   classif  101   17 multiclass  15
6 variables not shown: [int, dbl, chr, fct, ord, pxc]

Above, the columns "lgl" (logical), "int" (integer), "dbl" (double), "chr" (character), "fct" (factor), "ord" (ordered factor) and "pxc" (POSIXct time) show the number of features in the task of the respective type.

To get a task from the dictionary, use the $get() method of mlr_tasks and assign the return value to a new variable.
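
For example, the Palmer Penguins classification task, which is provided by the imported package palmerpenguins, can be retrieved the verbose way:

task_penguins = mlr_tasks$get("penguins")

As getting a task from the dictionary is a very common operation, mlr3 provides the shortcut function tsk(), which is equivalent: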

task_penguins = tsk("penguins")
print(task_penguins)
<TaskClassif:penguins> (344 x 8): Palmer Penguins
* Target: species
* Properties: multiclass
* Features (7):
  - int (3): body_mass, flipper_length, year
  - dbl (2): bill_depth, bill_length
  - fct (2): island, sex

Note

Loading extension packages can add elements to dictionaries such as mlr_tasks. For example, mlr3data adds some more example and toy tasks for regression and classification, and mlr3proba adds survival and density estimation tasks.

To get more information about a particular task, it is easiest to use the help() method that all mlr3 objects provide:

task_penguins$help()

Alternatively, the corresponding man page can be found under mlr_tasks_[id], e.g. mlr_tasks_penguins:

help("mlr_tasks_penguins")

Tip

Thousands more data sets are readily available via OpenML.org (Vanschoren et al. 2013) and the mlr3oml package. For example, to download the data set “credit-g” with data ID 31 and automatically convert it to a classification task:

library("mlr3oml")
tsk("oml", task_id = 31)

2.1.4 Task API

All properties and characteristics of tasks can be queried using the task’s public fields and methods (see Task). Methods can also be used to change the stored data and the behavior of the task.

2.1.4.1 Retrieving Data

The Task object primarily represents a tabular dataset, combined with meta-data about which columns of that data should be used to predict which other columns in what way, as well as some more information about column data types.

Various fields can be used to retrieve meta-data about a task. The dimensions can, for example, be retrieved using $nrow and $ncol:

task_mtcars$nrow
[1] 32
task_mtcars$ncol
[1] 3

The names of the feature and target columns are stored in the $feature_names and $target_names slots, respectively. Here, “target” refers to the variable we want to predict and “feature” to the predictors for the task.

task_mtcars$feature_names
[1] "cyl"  "disp"
task_mtcars$target_names
[1] "mpg"

For the most common task types, regression and classification, the target is a single column. Other task types, such as survival tasks, may have more than one target column, while clustering tasks have no target at all:

requireNamespace("mlr3proba", quietly = TRUE)
tsk("unemployment")$target_names
[1] "spell"   "censor1"
requireNamespace("mlr3cluster", quietly = TRUE)
tsk("usarrests")$target_names
character(0)

While the columns of a task have unique character-valued names, their rows are identified by unique natural numbers, their row-IDs. They can be accessed through the $row_ids slot:

head(task_mtcars$row_ids)
[1] 1 2 3 4 5 6

Warning

Although the row IDs are typically just the sequence from 1 to nrow(data), they are only guaranteed to be unique natural numbers. It is possible that they do not start at 1, that they do not increase by 1, or that they are not even in increasing order. The reasoning behind this is simple: it allows mlr3 to operate transparently on real database management systems, where uniqueness is the only requirement for primary keys. For more information on connecting to databases, see the section on backends.

The data contained in a task can be accessed through $data(), which returns a data.table object. It has optional rows and cols arguments to specify subsets of the data to retrieve. When a database backend is used, this avoids loading unnecessary data into memory, making it more efficient than retrieving the entire data first and then subsetting it with [<rows>, <cols>].

task_mtcars$data()
     mpg cyl  disp
 1: 21.0   6 160.0
 2: 21.0   6 160.0
 3: 22.8   4 108.0
 4: 21.4   6 258.0
 5: 18.7   8 360.0
---               
28: 30.4   4  95.1
29: 15.8   8 351.0
30: 19.7   6 145.0
31: 15.0   8 301.0
32: 21.4   4 121.0
# retrieve data for rows with ids 1, 5, and 10 and select column "mpg"
task_mtcars$data(rows = c(1, 5, 10), cols = "mpg")
    mpg
1: 21.0
2: 18.7
3: 19.2

To extract the complete data from the task, one can also convert it to a data.table:

# show summary of entire data
summary(as.data.table(task_mtcars))
      mpg             cyl             disp      
 Min.   :10.40   Min.   :4.000   Min.   : 71.1  
 1st Qu.:15.43   1st Qu.:4.000   1st Qu.:120.8  
 Median :19.20   Median :6.000   Median :196.3  
 Mean   :20.09   Mean   :6.188   Mean   :230.7  
 3rd Qu.:22.80   3rd Qu.:8.000   3rd Qu.:326.0  
 Max.   :33.90   Max.   :8.000   Max.   :472.0  

2.1.4.2 Task Mutators

It is often necessary to create tasks that encompass subsets of other tasks’ data, for example to manually create train-test splits, or to fit models on a subset of given features. Restricting a task to a given set of features is done by calling $select() with the desired feature names; restricting it to rows is done by calling $filter() with the row IDs to keep.

task_penguins_small = tsk("penguins")
task_penguins_small$select(c("body_mass", "flipper_length")) # keep only these features
task_penguins_small$filter(2:4) # keep only these rows
task_penguins_small$data()
   species body_mass flipper_length
1:  Adelie      3800            186
2:  Adelie      3250            195
3:  Adelie        NA             NA

These methods are so-called mutators: they modify the given Task in-place. If you want to keep an unmodified version of the task, use the $clone() method to create a copy first.

task_penguins_smaller = task_penguins_small$clone()
task_penguins_smaller$filter(2)
task_penguins_smaller$data()
   species body_mass flipper_length
1:  Adelie      3800            186
task_penguins_small$data()  # this task is unmodified
   species body_mass flipper_length
1:  Adelie      3800            186
2:  Adelie      3250            195
3:  Adelie        NA             NA

Note also how the call to $filter(2) did not select the second row of task_penguins_smaller, but the row with ID 2, which is its first row.

Tip

If you ever really need to work with row numbers instead of row IDs, you can work around this by translating row numbers to row IDs and passing those to the task:

# keep the 2nd row of task_penguins_small:
keep = task_penguins_small$row_ids[2] # extract the ID of the 2nd row
task_penguins_small$filter(keep)

While the methods above allow us to subset the data, the methods $rbind() and $cbind() allow adding extra rows and columns to a task.

task_penguins_smaller$rbind( # add another row
  data.frame(body_mass = 1e9, flipper_length = 1e9, species = "GigaPeng")
)
task_penguins_smaller$cbind(data.frame(letters = letters[2:3])) # add column with letters
task_penguins_smaller$data()
    species  body_mass flipper_length letters
1:   Adelie       3800            186       b
2: GigaPeng 1000000000     1000000000       c

2.1.4.3 Roles (Rows and Columns)

We have seen that during task creation certain columns are designated as “targets” and “features”; these are their “roles”. Target refers to the variable(s) we want to predict, and features are the predictors (also called covariates) for the target. Besides these two, there are other possible roles for columns; see the documentation of Task. These roles affect the behavior of the task for different operations.

The previously-constructed task_penguins_small task, for example, has the following column roles:

task_penguins_small$col_roles
$feature
[1] "body_mass"      "flipper_length"

$target
[1] "species"

$name
character(0)

$order
character(0)

$stratum
character(0)

$group
character(0)

$weight
character(0)

Columns can have multiple roles. It is also possible for a column to have no role at all, in which case it is ignored. This is, in fact, how $select() and $filter() operate: they unassign the "feature" role (for columns) or the "use" role (for rows) without modifying the data, which is stored in an immutable backend:

task_penguins_small$backend
<DataBackendDataTable> (344x9)
 species    island bill_length bill_depth flipper_length body_mass    sex year
  Adelie Torgersen        39.1       18.7            181      3750   male 2007
  Adelie Torgersen        39.5       17.4            186      3800 female 2007
  Adelie Torgersen        40.3       18.0            195      3250 female 2007
  Adelie Torgersen          NA         NA             NA        NA   <NA> 2007
  Adelie Torgersen        36.7       19.3            193      3450 female 2007
  Adelie Torgersen        39.3       20.6            190      3650   male 2007
1 variable not shown: [..row_id]
[...] (338 rows omitted)

There are two main ways to manipulate the column roles of a Task:

  1. Use the Task method $set_col_roles() (recommended).
  2. Simply modify the field $col_roles, which is a named list of vectors of column names. Each vector in this list corresponds to a column role, and the column names contained in that vector have that role.

Just as $select()/$filter(), these are in-place operations, so the task object itself is modified. To retain another unmodified version of a task, use $clone().

Changing the column or row roles, whether by $select()/$filter() or directly, does not change the underlying data; it just updates the view on it. Because the underlying data is still there (and accessible through $backend), we can add the "bill_length" column back into the task by setting its column role to "feature".

task_penguins_small$set_col_roles("bill_length", roles = "feature")
task_penguins_small$feature_names  # bill_length is now a feature again
[1] "body_mass"      "flipper_length" "bill_length"   
task_penguins_small$data()
   species body_mass flipper_length bill_length
1:  Adelie      3800            186        39.5
2:  Adelie      3250            195        40.3
3:  Adelie        NA             NA          NA
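
The same change could have been made with the second method from above, by modifying the $col_roles field directly; a sketch:

# equivalent to the $set_col_roles() call above:
task_penguins_small$col_roles$feature = union(
  task_penguins_small$col_roles$feature, "bill_length"
)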

Supported column roles can be found in the manual of Task, or just by printing the names of the field $col_roles:

# supported column roles, see ?Task
names(task_penguins_small$col_roles)
[1] "feature" "target"  "name"    "order"   "stratum" "group"   "weight" 

Just like columns, it is also possible to assign different roles to rows. Rows can have two different roles:

  1. Role use: Rows that are generally available for model fitting (although they may also be used as test set in resampling). This role is the default role. The $filter() call changes this role, in the same way that $select() changes the "feature" role.
  2. Role validation: Rows that are not used for training. Rows that have missing values in the target column during task creation are automatically set to the validation role.

There are several reasons to hold some observations back or treat them differently:

  1. It is often good practice to validate the final model on an external validation set to identify possible overfitting.
  2. Some observations may be unlabeled, e.g. in competitions like Kaggle.

These observations cannot be used for training a model, but can be used to get predictions.
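
Row roles can be inspected analogously to column roles, via the field $row_roles; a brief sketch:

# $row_roles is the row analogue of $col_roles:
names(task_penguins_small$row_roles)
head(task_penguins_small$row_roles$use)  # IDs of rows available for model fitting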

2.1.5 Task API Extensions

While the previous section described (a subset of) the API all tasks have in common, some tasks come with additional getters or setters.

For example, classification problems with a target variable with only two classes are called binary classification tasks. They are special in the sense that one of these classes is denoted positive and the other one negative. You can specify the positive class within the classification task object during task creation. If not explicitly set during construction, the positive class defaults to the first level of the target variable.

# during construction
data("Sonar", package = "mlbench")
task = as_task_classif(Sonar, target = "Class", positive = "R")

# switch positive class to level 'M'
task$positive = "M"

2.1.6 Plotting Tasks

The mlr3viz package provides plotting facilities for many classes implemented in mlr3. The available plot types depend on the class, but all plots are returned as ggplot2 objects which can be easily customized.

For classification tasks (inheriting from TaskClassif), see the documentation of mlr3viz::autoplot.TaskClassif for the implemented plot types. Here are some examples to get an impression:

library("mlr3viz")

# get the pima indians task
task = tsk("pima")

# subset task to only use the first 3 features
task$select(head(task$feature_names, 3))

# default plot: class frequencies
autoplot(task)

# pairs plot (requires package GGally)
autoplot(task, type = "pairs")

# duo plot (requires package GGally)
autoplot(task, type = "duo")

Of course, you can do the same for regression tasks (inheriting from TaskRegr) as documented in mlr3viz::autoplot.TaskRegr:

library("mlr3viz")

# get the complete mtcars task
task = tsk("mtcars")

# subset task to only use the first 3 features
task$select(head(task$feature_names, 3))

# default plot: boxplot of target variable
autoplot(task)

# pairs plot (requires package GGally)
autoplot(task, type = "pairs")

2.2 Learners

Objects of class Learner provide a unified interface to many popular machine learning algorithms in R. They offer methods to train and predict a model for a Task and provide meta-information about the learner, such as the hyperparameters you can set (which control its behavior).

The base class of each learner is Learner, specialized for regression as LearnerRegr and for classification as LearnerClassif. Other types of learners, provided by extension packages, also inherit from the Learner base class, e.g. mlr3proba::LearnerSurv or mlr3cluster::LearnerClust.

All Learners work in a two-stage procedure:

  • Training stage: The training data (features and target) is passed to the Learner’s $train() function which trains and stores a model, i.e. the relationship of the target and features.
  • Predict stage: The new data, usually a different slice of the original data than used for training, is passed to the $predict() method of the Learner. The model trained in the first step is used to predict the missing target, e.g. labels for classification problems or the numerical value for regression problems.

2.2.1 Predefined Learners

The mlr3 package ships with a small set of classification and regression learners, deliberately kept minimal to avoid unnecessary dependencies.
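
With only mlr3 loaded, this built-in set can be listed by querying the keys of the learner dictionary introduced below:

library("mlr3")
mlr_learners$keys()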

This set of baseline learners is usually insufficient for a real data analysis. Thus, we have cherry-picked implementations of the most popular machine learning methods and collected them in the mlr3learners package:

More machine learning methods and alternative implementations are collected in the mlr3extralearners repository.

Tip

A full list of available learners across all mlr3 packages is hosted on our website: list of learners.

Analogously to mlr_tasks, which stores the shipped tasks, the dictionary mlr_learners stores the available learners.

library("mlr3learners")       # load recommended learners provided by mlr3learners package
library("mlr3extralearners")  # this loads further less-well-supported learners
library("mlr3proba")          # this loads some survival and density estimation learners
library("mlr3cluster")        # this loads some learners for clustering

mlr_learners
<DictionaryLearner> with 137 stored values
Keys: classif.AdaBoostM1, classif.bart, classif.C50, classif.catboost,
  classif.cforest, classif.ctree, classif.cv_glmnet, classif.debug,
  classif.earth, classif.featureless, classif.fnn, classif.gam,
  classif.gamboost, classif.gausspr, classif.gbm, classif.glmboost,
  classif.glmnet, classif.IBk, classif.J48, classif.JRip, classif.kknn,
  classif.ksvm, classif.lda, classif.liblinear, classif.lightgbm,
  classif.LMT, classif.log_reg, classif.lssvm, classif.mob,
  classif.multinom, classif.naive_bayes, classif.nnet, classif.OneR,
  classif.PART, classif.qda, classif.randomForest, classif.ranger,
  classif.rfsrc, classif.rpart, classif.svm, classif.xgboost,
  clust.agnes, clust.ap, clust.cmeans, clust.cobweb, clust.dbscan,
  clust.diana, clust.em, clust.fanny, clust.featureless, clust.ff,
  clust.hclust, clust.kkmeans, clust.kmeans, clust.MBatchKMeans,
  clust.meanshift, clust.pam, clust.SimpleKMeans, clust.xmeans,
  dens.hist, dens.kde, dens.kde_ks, dens.locfit, dens.logspline,
  dens.mixed, dens.nonpar, dens.pen, dens.plug, dens.spline, regr.bart,
  regr.catboost, regr.cforest, regr.ctree, regr.cubist, regr.cv_glmnet,
  regr.debug, regr.earth, regr.featureless, regr.fnn, regr.gam,
  regr.gamboost, regr.gausspr, regr.gbm, regr.glm, regr.glmboost,
  regr.glmnet, regr.IBk, regr.kknn, regr.km, regr.ksvm, regr.liblinear,
  regr.lightgbm, regr.lm, regr.lmer, regr.M5Rules, regr.mars, regr.mob,
  regr.nnet, regr.randomForest, regr.ranger, regr.rfsrc, regr.rpart,
  regr.rsm, regr.rvm, regr.svm, regr.xgboost, surv.akritas, surv.aorsf,
  surv.blackboost, surv.cforest, surv.coxboost, surv.coxph,
  surv.coxtime, surv.ctree, surv.cv_coxboost, surv.cv_glmnet,
  surv.deephit, surv.deepsurv, surv.dnnsurv, surv.flexible,
  surv.gamboost, surv.gbm, surv.glmboost, surv.glmnet, surv.kaplan,
  surv.loghaz, surv.mboost, surv.nelson, surv.obliqueRSF,
  surv.parametric, surv.pchazard, surv.penalized, surv.ranger,
  surv.rfsrc, surv.rpart, surv.svm, surv.xgboost

To obtain an object from the dictionary, use the syntactic sugar function lrn():

learner = lrn("classif.rpart")

2.2.2 Learner API

Each learner provides the following meta-information:

  • $feature_types: the type of features the learner can deal with.
  • $packages: the packages required to train a model with this learner and make predictions.
  • $properties: additional properties and capabilities. For example, a learner has the property “missings” if it is able to handle missing feature values, and “importance” if it computes and allows to extract data on the relative importance of the features.
  • $predict_types: possible prediction types. For example, a classification learner can predict labels (“response”) or probabilities (“prob”).
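
For example, these fields can be queried directly on the classif.rpart learner created above; the values match the printer output shown below:

learner$feature_types
[1] "logical" "integer" "numeric" "factor"  "ordered"
learner$packages
[1] "mlr3"  "rpart"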

All of this meta-information is also shown at a glance by the printer:

print(learner)
<LearnerClassifRpart:classif.rpart>: Classification Tree
* Model: -
* Parameters: xval=0
* Packages: mlr3, rpart
* Predict Types:  [response], prob
* Feature Types: logical, integer, numeric, factor, ordered
* Properties: importance, missings, multiclass, selected_features,
  twoclass, weights

Furthermore, each learner has hyperparameters that control its behavior, for example the minimum number of samples in the leaf of a decision tree, or whether to provide verbose output during training. Setting hyperparameters to values appropriate for a given machine learning task is crucial. The field $param_set stores a description of the learner's hyperparameters: their ranges, defaults, and current values:

learner$param_set
<ParamSet>
                id    class lower upper nlevels        default value
 1:             cp ParamDbl     0     1     Inf           0.01      
 2:     keep_model ParamLgl    NA    NA       2          FALSE      
 3:     maxcompete ParamInt     0   Inf     Inf              4      
 4:       maxdepth ParamInt     1    30      30             30      
 5:   maxsurrogate ParamInt     0   Inf     Inf              5      
 6:      minbucket ParamInt     1   Inf     Inf <NoDefault[3]>      
 7:       minsplit ParamInt     1   Inf     Inf             20      
 8: surrogatestyle ParamInt     0     1       2              0      
 9:   usesurrogate ParamInt     0     2       3              2      
10:           xval ParamInt     0   Inf     Inf             10     0

The set of current hyperparameter values is stored in the values field of the param_set field. You can access and change the current hyperparameter values through this field, which is a named list:

learner$param_set$values
$xval
[1] 0
learner$param_set$values$cp = 0.01
learner$param_set$values
$xval
[1] 0

$cp
[1] 0.01

Tip

It is possible to assign all hyperparameters in one go by assigning a named list to $values: learner$param_set$values = list(cp = 0.01, xval = 0). However, be aware that this operation also removes all previously set hyperparameters.

The lrn() function also accepts additional arguments to update hyperparameters or set fields of the learner in one go:

learner = lrn("classif.rpart", id = "rp", cp = 0.001)
learner$id
[1] "rp"
learner$param_set$values
$xval
[1] 0

$cp
[1] 0.001

More on this is discussed in the section on Hyperparameter Tuning.

2.3 Train, Predict, Assess Performance

In this section, we explain how tasks and learners can be used to train a model and predict on a new dataset. Training a learner means fitting a model to a given data set – essentially, an optimization problem that determines the best parameters (not hyperparameters!) of the model given the data. We then predict labels for observations that the model has not seen during training and compare these predictions to the ground truth values to assess the quality of the model.

The concept is demonstrated on a supervised classification task using the pima dataset, in which patient data is used to diagnostically predict diabetes, and the rpart learner, which builds a classification tree. As shown in the previous sections, we load these objects using the short access functions tsk() and lrn().

task = tsk("pima")
learner = lrn("classif.rpart")

2.3.1 Training the learner

The field $model stores the model fitted in the training step. Before the $train() method is called on a learner object, this field is NULL:

learner$model
NULL

Now we fit the classification tree using the training set of the task by calling the $train() method of Learner:

learner$train(task)

This operation modifies the learner in-place by adding the fitted model to the existing object. We can now access the stored model via the field $model:

learner$model
n= 768 

node), split, n, loss, yval, (yprob)
      * denotes terminal node

  1) root 768 268 neg (0.34895833 0.65104167)  
    2) glucose>=127.5 283 109 pos (0.61484099 0.38515901)  
      4) mass>=29.95 208  58 pos (0.72115385 0.27884615)  
        8) glucose>=157.5 92  12 pos (0.86956522 0.13043478) *
        9) glucose< 157.5 116  46 pos (0.60344828 0.39655172)  
         18) age>=30.5 66  19 pos (0.71212121 0.28787879) *
         19) age< 30.5 50  23 neg (0.46000000 0.54000000)  
           38) pressure< 73 21   8 pos (0.61904762 0.38095238) *
           39) pressure>=73 29  10 neg (0.34482759 0.65517241)  
             78) mass>=41.8 9   3 pos (0.66666667 0.33333333) *
             79) mass< 41.8 20   4 neg (0.20000000 0.80000000) *
      5) mass< 29.95 75  24 neg (0.32000000 0.68000000) *
    3) glucose< 127.5 485  94 neg (0.19381443 0.80618557)  
      6) age>=28.5 214  71 neg (0.33177570 0.66822430)  
       12) insulin>=142.5 56  26 neg (0.46428571 0.53571429)  
         24) age< 56.5 41  16 pos (0.60975610 0.39024390) *
         25) age>=56.5 15   1 neg (0.06666667 0.93333333) *
       13) insulin< 142.5 158  45 neg (0.28481013 0.71518987)  
         26) glucose>=99.5 102  41 neg (0.40196078 0.59803922)  
           52) mass>=26.35 84  41 neg (0.48809524 0.51190476)  
            104) pedigree>=0.2045 65  27 pos (0.58461538 0.41538462)  
              208) pregnant>=5.5 32   8 pos (0.75000000 0.25000000) *
              209) pregnant< 5.5 33  14 neg (0.42424242 0.57575758)  
                418) age>=34.5 19   7 pos (0.63157895 0.36842105) *
                419) age< 34.5 14   2 neg (0.14285714 0.85714286) *
            105) pedigree< 0.2045 19   3 neg (0.15789474 0.84210526) *
           53) mass< 26.35 18   0 neg (0.00000000 1.00000000) *
         27) glucose< 99.5 56   4 neg (0.07142857 0.92857143) *
      7) age< 28.5 271  23 neg (0.08487085 0.91512915) *

Inspecting the output, we see that the learner has identified features in the task that are predictive of the class (diabetes status) and uses them to partition observations in the tree. There are additional details on how the data is partitioned across branches of the tree; the textual representation of the model depends on the type of learner. For more information on this particular type of model and its output, see rpart::print.rpart().

2.3.2 Predicting

After the model has been fitted to the training data, we can now use it for prediction. A common case is that a model was fitted on all training data that was available, and should now be used to make predictions for new data for which the actual labels are unknown:

pima_new = data.table::fread("
age, glucose, insulin, mass, pedigree, pregnant, pressure, triceps
24,  145,     306,     41.7, 0.5,      3,        52,       36
47,  133,     NA,      23.3, 0.2,      7,        83,       28
")
pima_new
   age glucose insulin mass pedigree pregnant pressure triceps
1:  24     145     306 41.7      0.5        3       52      36
2:  47     133      NA 23.3      0.2        7       83      28

The learner does not need any further meta-information about this data to make a prediction, such as which columns are features and which is the target, since this was already included in the training task. The data can therefore be used directly to make a prediction via $predict_newdata():

prediction = learner$predict_newdata(pima_new)
prediction
<PredictionClassif> for 2 observations:
 row_ids truth response
       1  <NA>      pos
       2  <NA>      neg

This method returns a Prediction object. More precisely, because the learner is a LearnerClassif, it returns a PredictionClassif object. The easiest way to access information from it is to convert it to a data.table:

as.data.table(prediction)
   row_ids truth response
1:       1  <NA>      pos
2:       2  <NA>      neg

Here the "truth" column is NA, since the target column was not provided in the pima_new data frame. If we add the column, we will have the true and predicted labels side by side in the prediction object.

pima_new_known = cbind(pima_new, diabetes = factor("pos", levels = c("pos", "neg")))
prediction = learner$predict_newdata(pima_new_known)
prediction
<PredictionClassif> for 2 observations:
 row_ids truth response
       1   pos      pos
       2   pos      neg

Note that it is sometimes helpful to first convert the data to predict on into a task. Predicting on a task's data works analogously; you only need to call the $predict() method instead of $predict_newdata():

task_pima_new = as_task_classif(pima_new_known, target = "diabetes")
prediction = learner$predict(task_pima_new)
prediction
<PredictionClassif> for 2 observations:
 row_ids truth response
       1   pos      pos
       2   pos      neg

2.3.3 Changing the Predict Type

Classification learners default to predicting the class label. However, many classifiers also tell you how sure they are about the predicted label by providing posterior probabilities for the classes. To predict these probabilities, the predict_type field of a LearnerClassif must be changed from "response" (the default) to "prob" before training:

learner$predict_type = "prob"

# re-fit the model
learner$train(task)

# rebuild prediction object
prediction = learner$predict(task_pima_new)
prediction
<PredictionClassif> for 2 observations:
 row_ids truth response  prob.pos  prob.neg
       1   pos      pos 0.6190476 0.3809524
       2   pos      neg 0.3200000 0.6800000

The prediction object now contains probabilities for all class labels in addition to the predicted label (the one with the highest probability):

# directly access the predicted labels:
prediction$response
[1] pos neg
Levels: pos neg
# directly access the matrix of probabilities:
prediction$prob
           pos       neg
[1,] 0.6190476 0.3809524
[2,] 0.3200000 0.6800000
# data.table conversion
as.data.table(prediction)
   row_ids truth response  prob.pos  prob.neg
1:       1   pos      pos 0.6190476 0.3809524
2:       2   pos      neg 0.3200000 0.6800000

Similarly to predicting probabilities for classification, many regression learners support the extraction of standard error estimates for predictions by setting the predict type to "se".
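
A minimal sketch, assuming the mlr3learners package is installed (it provides the linear regression learner regr.lm, which supports the "se" predict type):

library("mlr3learners")

learner_lm = lrn("regr.lm", predict_type = "se")
learner_lm$train(tsk("mtcars"))
prediction_lm = learner_lm$predict(tsk("mtcars"))  # predictions now include an "se" column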

2.3.4 Thresholding

Models trained on binary classification tasks that predict the probability for the positive class usually use a simple rule to determine the predicted class label: if the probability is more than 50%, predict the positive label, otherwise, predict the negative label. In some cases, you may want to adjust this threshold, for example, if the classes are very unbalanced (i.e., one is much more prevalent than the other).

In the example below, we change the threshold to 0.2, making the model predict "pos" for both example rows:

prediction$set_threshold(0.2)
prediction
<PredictionClassif> for 2 observations:
 row_ids truth response  prob.pos  prob.neg
       1   pos      pos 0.6190476 0.3809524
       2   pos      pos 0.3200000 0.6800000

2.3.5 Predicting on known data and train/test splits

Usually, we do not want to postpone performance evaluation until new data becomes available; instead, we work with all the training data available at a given point. However, when evaluating the performance of a Learner, it is important to score predictions made on data that have not been seen during training, since making predictions on training data is too easy in general – a Learner could just memorize the training data responses and get a perfect score.

mlr3 makes it easy to train on subsets of given tasks. We first create a vector of the task's row IDs on which the Learner should be trained, and another with the remaining rows that should be used for prediction. These two vectors define the train-test split we are using. This is done manually here for demonstration purposes; in Section 3.2, we show how mlr3 can automatically create training and test sets based on more elaborate resampling strategies.

We will use 67% of all available observations for training and predict on the remaining 33%.

set.seed(7)
train_set = sample(task$row_ids, 0.67 * task$nrow)
test_set = setdiff(task$row_ids, train_set)

Danger

Do not use constructs like sample(task$nrow, ...) to subset tasks, since rows are always identified by their $row_ids. These are not guaranteed to range from 1 to task$nrow and could be any positive integer.

Both $train() and $predict() have an optional row_ids-argument that determines which rows are used. Note that it is not a problem to run $train() with a Learner that has already been trained: the old model is automatically discarded, and the learner trains from scratch.

# train on the training set
learner$train(task, row_ids = train_set)

# predict on the test set
prediction = learner$predict(task, row_ids = test_set)

# the prediction naturally knows about the "truth" from the task
prediction
<PredictionClassif> for 254 observations:
    row_ids truth response   prob.pos  prob.neg
          8   neg      neg 0.37500000 0.6250000
         12   pos      pos 0.84905660 0.1509434
         19   neg      neg 0.37500000 0.6250000
---                                            
        762   pos      pos 0.84905660 0.1509434
        765   neg      neg 0.09954751 0.9004525
        768   neg      neg 0.09954751 0.9004525

2.3.6 Performance assessment

The last step of modeling is usually assessing the performance of the trained model. For this, the predictions made by the model are compared with the known ground-truth values stored in the Prediction object. The exact nature of this comparison is defined by a measure, represented by a Measure object. If the prediction was made on a dataset without the target column, i.e. without known true labels, then performance cannot be calculated.

Available measures can be retrieved using the msr() function, which accesses objects in mlr_measures:

mlr_measures
<DictionaryMeasure> with 92 stored values
Keys: aic, bic, classif.acc, classif.auc, classif.bacc, classif.bbrier,
  classif.ce, classif.costs, classif.dor, classif.fbeta, classif.fdr,
  classif.fn, classif.fnr, classif.fomr, classif.fp, classif.fpr,
  classif.logloss, classif.mauc_au1p, classif.mauc_au1u,
  classif.mauc_aunp, classif.mauc_aunu, classif.mbrier, classif.mcc,
  classif.npv, classif.ppv, classif.prauc, classif.precision,
  classif.recall, classif.sensitivity, classif.specificity, classif.tn,
  classif.tnr, classif.tp, classif.tpr, clust.ch, clust.db, clust.dunn,
  clust.silhouette, clust.wss, debug, dens.logloss, oob_error,
  regr.bias, regr.ktau, regr.mae, regr.mape, regr.maxae, regr.medae,
  regr.medse, regr.mse, regr.msle, regr.pbias, regr.rae, regr.rmse,
  regr.rmsle, regr.rrse, regr.rse, regr.rsq, regr.sae, regr.smape,
  regr.srho, regr.sse, selected_features, sim.jaccard, sim.phi,
  surv.brier, surv.calib_alpha, surv.calib_beta, surv.chambless_auc,
  surv.cindex, surv.dcalib, surv.graf, surv.hung_auc, surv.intlogloss,
  surv.logloss, surv.mae, surv.mse, surv.nagelk_r2, surv.oquigley_r2,
  surv.rcll, surv.rmse, surv.schmid, surv.song_auc, surv.song_tnr,
  surv.song_tpr, surv.uno_auc, surv.uno_tnr, surv.uno_tpr, surv.xu_r2,
  time_both, time_predict, time_train

We choose accuracy (classif.acc) as our specific performance measure here and call the method $score() of the prediction object to quantify the predictive performance of our model.

measure = msr("classif.acc")
measure
<MeasureClassifSimple:classif.acc>: Classification Accuracy
* Packages: mlr3, mlr3measures
* Range: [0, 1]
* Minimize: FALSE
* Average: macro
* Parameters: list()
* Properties: -
* Predict type: response
prediction$score(measure)
classif.acc 
  0.7244094 

Note

$score() can also be called without a measure. In this case, classification defaults to the classification error (classif.ce, which is one minus accuracy) and regression to the mean squared error (regr.mse).
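
For the prediction above, this yields one minus the accuracy computed earlier:

prediction$score()
classif.ce 
 0.2755906 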

It is possible to calculate multiple measures at the same time by passing a list to $score(). Such a list can easily be constructed using the “plural” msrs() function. If one wanted to have both the “true positive rate” ("classif.tpr") and the “true negative rate” ("classif.tnr"), one would use:

measures = msrs(c("classif.tpr", "classif.tnr"))
prediction$score(measures)
classif.tpr classif.tnr 
  0.4639175   0.8853503 

2.3.6.1 Confusion Matrix

A special case of performance evaluation is the confusion matrix, which shows, for each class, how many observations were predicted to be in that class and how many were actually in it (more information on Wikipedia). The entries along the diagonal denote the correctly classified observations.

prediction$confusion
        truth
response pos neg
     pos  45  18
     neg  52 139

In this case, we can see that our classifier misclassifies a relatively large number of positive samples as negative. In fact, a positive case is still more likely to be classified as "neg" than "pos". Depending on the application being considered, it may be more important to keep false negatives (the lower-left element of the confusion matrix) low. Lowering the threshold, so that ambiguous samples are more readily classified as positive rather than negative, can help in this case, although it will also lead to negative cases being classified as "pos" more often.

prediction$set_threshold(0.3)
prediction$confusion
        truth
response pos neg
     pos  75  65
     neg  22  92

Tip

Thresholds can be tuned automatically with the mlr3pipelines package, i.e. using PipeOpTuneThreshold.

2.3.7 Plotting Predictions

Similarly to plotting tasks, mlr3viz provides an autoplot() method for Prediction objects. All available types are listed in the manual pages for autoplot.PredictionClassif(), autoplot.PredictionRegr() and the other prediction types (defined by extension packages).

task = tsk("penguins")
learner = lrn("classif.rpart", predict_type = "prob")
learner$train(task)
prediction = learner$predict(task)

library("mlr3viz")
autoplot(prediction)