# 2 Basics

This chapter will teach you the essential building blocks of mlr3: the R6 classes and operations used for machine learning. A typical machine learning workflow looks like this:

The data, which mlr3 encapsulates in tasks, is split into non-overlapping training and test sets. Since we are interested in models that extrapolate to new data rather than just memorizing the training data, the separate test data allows us to objectively evaluate models with respect to generalization. The training data is given to a machine learning algorithm, which we call a learner in mlr3. The learner uses the training data to build a model of the relationship of the input features to the output target values. This model is then used to produce predictions on the test data, which are compared to the ground truth values to assess the quality of the model. mlr3 offers a number of different measures to quantify how well a model performs based on the difference between predicted and actual values. Usually this measure is a numeric score.

The process of splitting up data into training and test sets, building a model, and evaluating it may be repeated several times, resampling different training and test sets from the original data each time. Multiple resampling iterations allow us to get a better, more generalizable performance estimate for a particular type of model, as it is tested under different conditions and is less likely to get lucky or unlucky because of how the data happened to be split.

In many cases, this simple workflow is not sufficient to deal with real-world data, which may require normalization, imputation of missing values, or feature selection. We will cover more complex workflows that allow us to do this, and more, later in the book.

This chapter covers the following subtopics:

Tasks

Tasks encapsulate the data with meta-information, such as the name of the prediction target column. We cover how to:

Learners

Learners encapsulate machine learning algorithms to train models and make predictions for a task. They are provided by R and other packages. We cover how to:

How to modify and extend learners is covered in a supplemental advanced technical section.

Train and predict

The section on the train and predict methods illustrates how to use tasks and learners to train a model and make predictions on a new data set. In particular, we cover how to:

Resampling

A resampling is a method to create training and test splits. We cover how to:

Additional information on resampling can be found in the section about nested resampling and in the chapter on model optimization.

Benchmarking

Benchmarking is used to compare the performance of different models, for example models trained with different learners, on different tasks, or with different resampling methods. We cover how to:

Binary classification

Binary classification is a special case of classification where the target variable to predict has only two possible values. In this case, additional considerations apply; in particular:

• ROC curves and the threshold where to predict one class versus the other, and
• threshold tuning (WIP).

Before we get into the details of how to use mlr3 for machine learning, we give a brief introduction to R6 as it is a relatively new part of R. mlr3 heavily relies on R6 and all basic building blocks it provides are R6 classes:

## 2.1 Quick R6 Intro for Beginners

R6 is one of R’s more recent dialects for object-oriented programming (OOP). It addresses shortcomings of earlier OOP implementations in R, such as S3, which we used in mlr. If you have done any object-oriented programming before, R6 should feel familiar. Here we focus on the parts of R6 that you need to know to use mlr3.
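Since R6 objects are built on environments, the reference semantics discussed in the list below can be sketched with a plain environment, no R6 required. This is a minimal stand-alone illustration, not mlr3 code:

```r
# environments, like R6 objects, have reference semantics
e1 <- new.env()
e1$bar <- 1

e2 <- e1        # no copy is made: e2 points to the same environment
e2$bar <- 3
e1$bar          # 3 -- changed through the other reference

# an explicit copy (conceptually what $clone() does for R6 objects)
e3 <- as.environment(as.list(e1))
e3$bar <- 5
e1$bar          # still 3 -- e3 is an independent copy
```

This is exactly why `foo2 = foo` does not protect you from accidental modification of `foo`, while `foo2 = foo$clone()` does.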

• Objects are created by calling the constructor of an R6::R6Class() object, specifically the initialization method $new(). For example, foo = Foo$new(bar = 1) creates a new object of class Foo, setting the bar argument of the constructor to the value 1. Most objects in mlr3 are created through special functions (e.g. lrn("regr.rpart")) that are also referred to as sugar functions.
• Objects have mutable state that is encapsulated in their fields, accessed through the dollar operator. We can access the bar value in the Foo class through foo$bar and set its value by assigning to the field, e.g. foo$bar = 2.
• In addition to fields, objects expose methods that allow us to inspect the object’s state, retrieve information, or perform an action that may change the internal state of the object. For example, the $train() method of a learner changes the internal state of the learner by building and storing a trained model, which can then be used to make predictions given data.
• Objects can have public and private fields and methods. The public fields and methods define the API to interact with the object. Private fields and methods are only relevant for you if you want to extend mlr3, e.g. with new learners.
• R6 objects are internally environments, and as such have reference semantics. For example, foo2 = foo does not create a copy of foo in foo2, but another reference to the same actual object. Setting foo$bar = 3 will also change foo2$bar to 3 and vice versa.
• To copy an object, use the $clone() method and the deep = TRUE argument for nested objects, for example, foo2 = foo$clone(deep = TRUE).

For more details on R6, have a look at the excellent R6 vignettes, especially the introduction.

## 2.2 Tasks

Tasks are objects that contain the (usually tabular) data and additional meta-data that define a machine learning problem. The meta-data is, for example, the name of the target variable for supervised machine learning problems, or the type of the dataset (e.g. spatial or survival data). This information is used by specific operations that can be performed on a task.

### 2.2.1 Task Types

To create a task from a data.frame(), data.table() or Matrix(), you first need to select the right task type:

### 2.2.2 Task Creation

As an example, we will create a regression task using the mtcars data set from the package datasets and predict the numeric target variable "mpg" (miles per gallon). We only consider the first two features in the dataset for brevity.

First, we load and prepare the data.

data("mtcars", package = "datasets")
data = mtcars[, 1:3]
str(data)
## 'data.frame':    32 obs. of  3 variables:
##  $ mpg : num  21 21 22.8 21.4 18.7 18.1 14.3 24.4 22.8 19.2 ...
##  $ cyl : num  6 6 4 6 8 6 8 4 4 6 ...
##  $ disp: num  160 160 108 258 360 ...

Next, we create a regression task, i.e. we construct a new instance of the R6 class TaskRegr. Usually, this is done by calling the constructor TaskRegr$new(). Instead, we call the converter as_task_regr() to convert the data.frame() stored in data to a task, providing the following information:

1. x: Object to convert. Works for data.frame(), data.table(), tibble(), and abstract data backends implemented in the class DataBackendDataTable. The latter allow connecting to out-of-memory storage systems like SQL servers via the extension package mlr3db.
2. target: The name of the target column for the regression problem.
3. id (optional): An arbitrary identifier for the task, used in plots and summaries. If not provided, the deparsed and substituted name of x will be used.

library("mlr3")

task_mtcars = as_task_regr(data, target = "mpg", id = "cars")
print(task_mtcars)
## <TaskRegr:cars> (32 x 3)
## * Target: mpg
## * Properties: -
## * Features (2):
##   - dbl (2): cyl, disp

The print() method gives a short summary of the task: it has 32 observations and 3 columns, of which 2 are features. We can also plot the task using the mlr3viz package, which gives a graphical summary of its properties:

library("mlr3viz")
autoplot(task_mtcars, type = "pairs")

Note that instead of loading all the extension packages individually, it is often more convenient to load the mlr3verse package. mlr3verse imports most mlr3 packages and re-exports functions used for regular machine learning and data science tasks.

### 2.2.3 Predefined tasks

mlr3 ships with a few predefined machine learning tasks. All tasks are stored in an R6 Dictionary (a key-value store) named mlr_tasks.
Printing it gives the keys (the names of the datasets):

mlr_tasks
## <DictionaryTask> with 11 stored values
## Keys: boston_housing, breast_cancer, german_credit, iris, mtcars,
##   penguins, pima, sonar, spam, wine, zoo

We can get a more informative summary of the example tasks by converting the dictionary to a data.table() object:

as.data.table(mlr_tasks)
##                key task_type nrow ncol properties lgl int dbl chr fct ord pxc
##  1: boston_housing      regr  506   19              0   3  13   0   2   0   0
##  2:  breast_cancer   classif  683   10   twoclass   0   0   0   0   0   9   0
##  3:  german_credit   classif 1000   21   twoclass   0   3   0   0  14   3   0
##  4:           iris   classif  150    5 multiclass   0   0   4   0   0   0   0
##  5:         mtcars      regr   32   11              0   0  10   0   0   0   0
##  6:       penguins   classif  344    8 multiclass   0   3   2   0   2   0   0
##  7:           pima   classif  768    9   twoclass   0   0   8   0   0   0   0
##  8:          sonar   classif  208   61   twoclass   0   0  60   0   0   0   0
##  9:           spam   classif 4601   58   twoclass   0   0  57   0   0   0   0
## 10:           wine   classif  178   14 multiclass   0   2  11   0   0   0   0
## 11:            zoo   classif  101   17 multiclass  15   1   0   0   0   0   0

In the above display, the columns "lgl" (logical), "int" (integer), "dbl" (double), "chr" (character), "fct" (factor), "ord" (ordered factor) and "pxc" (POSIXct time) give the number of features of the corresponding storage type in each dataset.

To get a task from the dictionary, one can use the $get() method of the mlr_tasks dictionary and assign the return value to a new object. Since mlr3 arranges most of its object instances in dictionaries and extraction is such a common operation, there is a shortcut for this: the function tsk(). Here, we retrieve the palmer penguins task originating from the package palmerpenguins:

task_penguins = tsk("penguins")
print(task_penguins)
## <TaskClassif:penguins> (344 x 8)
## * Target: species
## * Properties: multiclass
## * Features (7):
##   - int (3): body_mass, flipper_length, year
##   - dbl (2): bill_depth, bill_length
##   - fct (2): island, sex

Note that dictionaries such as mlr_tasks can be populated by extension packages. For example, mlr3data comes with some more example and toy tasks for regression and classification, and mlr3proba ships with additional survival and density estimation tasks. Both packages are loaded together with the mlr3verse package, so we load it here and look at the available tasks again:

library("mlr3verse")
as.data.table(mlr_tasks)[, 1:4]
##                key task_type  nrow ncol
##  1:           actg      surv  1151   13
##  2:   bike_sharing      regr 17379   14
##  3: boston_housing      regr   506   19
##  4:  breast_cancer   classif   683   10
##  5:       faithful      dens   272    1
##  6:           gbcs      surv   686   10
##  7:  german_credit   classif  1000   21
##  8:          grace      surv  1000    8
##  9:           ilpd   classif   583   11
## 10:           iris   classif   150    5
## 11:     kc_housing      regr 21613   20
## 12:           lung      surv   228   10
## 13:      moneyball      regr  1232   15
## 14:         mtcars      regr    32   11
## 15:      optdigits   classif  5620   65
## 16:       penguins   classif   344    8
## 17:           pima   classif   768    9
## 18:         precip      dens    70    1
## 19:           rats      surv   300    5
## 20:          sonar   classif   208   61
## 21:           spam   classif  4601   58
## 22:        titanic   classif  1309   11
## 23:   unemployment      surv  3343    6
## 24:      usarrests     clust    50    4
## 25:           whas      surv   481   11
## 26:           wine   classif   178   14
## 27:            zoo   classif   101   17
##                key task_type  nrow ncol

To get more information about the respective task, the corresponding man page can be found under mlr_tasks_[id], e.g. mlr_tasks_german_credit.

### 2.2.4 Task Querying

All task properties and characteristics can be queried using the task’s public fields and methods (see Task). Methods can also be used to change the stored data and the behavior of the task.

#### 2.2.4.1 Retrieving Data

The data stored in a task can be retrieved directly from fields, for example:

task_mtcars
## <TaskRegr:cars> (32 x 3)
## * Target: mpg
## * Properties: -
## * Features (2):
##   - dbl (2): cyl, disp
task_mtcars$nrow
## [1] 32

task_mtcars$ncol
## [1] 3

More information can be obtained through methods of the object, for example:

task_mtcars$data()
##      mpg cyl  disp
##  1: 21.0   6 160.0
##  2: 21.0   6 160.0
##  3: 22.8   4 108.0
##  4: 21.4   6 258.0
##  5: 18.7   8 360.0
##  6: 18.1   6 225.0
##  7: 14.3   8 360.0
##  8: 24.4   4 146.7
##  9: 22.8   4 140.8
## 10: 19.2   6 167.6
## 11: 17.8   6 167.6
## 12: 16.4   8 275.8
## 13: 17.3   8 275.8
## 14: 15.2   8 275.8
## 15: 10.4   8 472.0
## 16: 10.4   8 460.0
## 17: 14.7   8 440.0
## 18: 32.4   4  78.7
## 19: 30.4   4  75.7
## 20: 33.9   4  71.1
## 21: 21.5   4 120.1
## 22: 15.5   8 318.0
## 23: 15.2   8 304.0
## 24: 13.3   8 350.0
## 25: 19.2   8 400.0
## 26: 27.3   4  79.0
## 27: 26.0   4 120.3
## 28: 30.4   4  95.1
## 29: 15.8   8 351.0
## 30: 19.7   6 145.0
## 31: 15.0   8 301.0
## 32: 21.4   4 121.0
##      mpg cyl  disp

In mlr3, each row (observation) has a unique identifier, stored as an integer(). These can be passed as arguments to the $data() method to select specific rows:

head(task_mtcars$row_ids)
## [1] 1 2 3 4 5 6

# retrieve data for rows with ids 1, 5, and 10
task_mtcars$data(rows = c(1, 5, 10))
##     mpg cyl  disp
## 1: 21.0   6 160.0
## 2: 18.7   8 360.0
## 3: 19.2   6 167.6

Note that although the row ids are typically just the sequence from 1 to nrow(data), they are only guaranteed to be unique natural numbers. Keep this in mind, especially if you work with data stored in a real database management system (see backends).

Similarly to row ids, the target and feature columns also have unique identifiers, i.e. names (stored as character()). These can be accessed via the public fields $feature_names and $target_names. Here, “target” refers to the variable we want to predict and “feature” to the predictors for the task.

task_mtcars$feature_names
## [1] "cyl"  "disp"

task_mtcars$target_names
## [1] "mpg"

The row_ids and column names can be combined when selecting a subset of the data:

# retrieve data for rows 1, 5, and 10 and only select column "mpg"
task_mtcars$data(rows = c(1, 5, 10), cols = "mpg")
##     mpg
## 1: 21.0
## 2: 18.7
## 3: 19.2

To extract the complete data from the task, one can also simply convert it to a data.table:

summary(as.data.table(task_mtcars))
##       mpg            cyl            disp
##  Min.   :10.4   Min.   :4.00   Min.   : 71.1
##  1st Qu.:15.4   1st Qu.:4.00   1st Qu.:120.8
##  Median :19.2   Median :6.00   Median :196.3
##  Mean   :20.1   Mean   :6.19   Mean   :230.7
##  3rd Qu.:22.8   3rd Qu.:8.00   3rd Qu.:326.0
##  Max.   :33.9   Max.   :8.00   Max.   :472.0

#### 2.2.4.2 Roles (Rows and Columns)

It is possible to assign different roles to rows and columns. These roles affect the behavior of the task for different operations. We have already seen this for the target and feature columns, which serve different purposes. For example, the previously constructed mtcars task has the following column roles:

print(task_mtcars$col_roles)
## $feature
## [1] "cyl"  "disp"
##
## $target
## [1] "mpg"
##
## $name
## character(0)
##
## $order
## character(0)
##
## $stratum
## character(0)
##
## $group
## character(0)
##
## $weight
## character(0)

Columns can also have no role (they are ignored) or have multiple roles. To add the row names of mtcars as an additional feature, we first add them to the data table as a regular column and then recreate the task with the new column.

# with keep.rownames, data.table stores the row names in an extra column "rn"
data = as.data.table(datasets::mtcars[, 1:3], keep.rownames = TRUE)
task_mtcars = as_task_regr(data, target = "mpg", id = "cars")

# there is a new feature called "rn"
task_mtcars$feature_names
## [1] "cyl"  "disp" "rn"

The row names are now a feature whose values are stored in the column "rn". We include this column here for educational purposes only. Generally speaking, there is no point in having a feature that uniquely identifies each row. Furthermore, the character data type will cause problems with many types of machine learning algorithms.

On the other hand, the identifier may be useful to label points in plots, for example to identify and label outliers. Therefore, we will change the role of the rn column by removing it from the list of features and assigning it the new role "name". There are two ways to do this:

1. Use the Task method $set_col_roles() (recommended).
2. Simply modify the field $col_roles, which is a named list of vectors of column names. Each vector in this list corresponds to a column role, and the column names contained in that vector are designated as having that role.

Supported column roles can be found in the manual of Task, or just by printing the names of the field $col_roles:

# supported column roles, see ?Task
names(task_mtcars$col_roles)
## [1] "feature" "target"  "name"    "order"   "stratum" "group"   "weight"
# assign column "rn" the role "name", remove from other roles
task_mtcars$set_col_roles("rn", roles = "name")

# note that "rn" is not listed as a feature anymore
task_mtcars$feature_names
## [1] "cyl"  "disp"
# "rn" also does not appear anymore when we access the data
task_mtcars$data(rows = 1:2)
##    mpg cyl disp
## 1:  21   6  160
## 2:  21   6  160

task_mtcars$head(2)
##    mpg cyl disp
## 1:  21   6  160
## 2:  21   6  160

Changing the role does not change the underlying data, it just updates the view on it. The data is not copied in the code above. The view is changed in-place though, i.e. the task object itself is modified.

Just like columns, it is also possible to assign different roles to rows.

Rows can have two different roles:

1. Role use: Rows that are generally available for model fitting (although they may also be used as test set in resampling). This role is the default role.

2. Role validation: Rows that are not used for training. Rows that have missing values in the target column during task creation are automatically set to the validation role.

There are several reasons to hold some observations back or treat them differently:

1. It is often good practice to validate the final model on an external validation set to identify possible overfitting.
2. Some observations may be unlabeled, e.g. in competitions like Kaggle.

These observations cannot be used for training a model, but can be used to get predictions.

As shown above, modifying $col_roles or $row_roles (either via set_col_roles()/set_row_roles() or directly by modifying the named list) changes the view on the data. The additional convenience method $filter() subsets the current view based on row ids and $select() subsets the view based on feature names.

task_penguins = tsk("penguins")
task_penguins$select(c("body_mass", "flipper_length")) # keep only these features
task_penguins$filter(1:3) # keep only these rows
task_penguins$head()
##    species body_mass flipper_length
## 1:  Adelie      3750            181
## 2:  Adelie      3800            186
## 3:  Adelie      3250            195

While the methods discussed above allow us to subset the data, the methods $rbind() and $cbind() allow adding extra rows and columns to a task. Again, the original data is not changed; the additional rows or columns are only added to the view of the data.

task_penguins$cbind(data.frame(letters = letters[1:3])) # add column "letters"

### 2.3.2 Learner API

Each learner provides the following meta-information:

• feature_types: the type of features the learner can deal with.
• packages: the packages required to train a model with this learner and make predictions.
• properties: additional properties and capabilities. For example, a learner has the property “missings” if it is able to handle missing feature values, and “importance” if it computes and allows the extraction of data on the relative importance of the features. A complete list of these is available in the mlr3 reference.
• predict_types: possible prediction types. For example, a classification learner can predict labels (“response”) or probabilities (“prob”). For a complete list of possible predict types see the mlr3 reference.

You can retrieve a specific learner using its id:

learner = lrn("classif.rpart")
print(learner)
## <LearnerClassifRpart:classif.rpart>
## * Model: -
## * Parameters: xval=0
## * Packages: rpart
## * Predict Type: response
## * Feature types: logical, integer, numeric, factor, ordered
## * Properties: importance, missings, multiclass, selected_features,
##   twoclass, weights

The field param_set stores a description of the hyperparameters the learner has, their ranges, defaults, and current values:

learner$param_set
## <ParamSet>
##                 id    class lower upper nlevels        default value
##  1:             cp ParamDbl     0     1     Inf           0.01
##  2:     keep_model ParamLgl    NA    NA       2          FALSE
##  3:     maxcompete ParamInt     0   Inf     Inf              4
##  4:       maxdepth ParamInt     1    30      30             30
##  5:   maxsurrogate ParamInt     0   Inf     Inf              5
##  6:      minbucket ParamInt     1   Inf     Inf <NoDefault[3]>
##  7:       minsplit ParamInt     1   Inf     Inf             20
##  8: surrogatestyle ParamInt     0     1       2              0
##  9:   usesurrogate ParamInt     0     2       3              2
## 10:           xval ParamInt     0   Inf     Inf             10     0

The set of current hyperparameter values is stored in the values field of the param_set field. You can change the current hyperparameter values by assigning a named list to this field:

learner$param_set$values = list(cp = 0.01, xval = 0)
learner
## <LearnerClassifRpart:classif.rpart>
## * Model: -
## * Parameters: cp=0.01, xval=0
## * Packages: rpart
## * Predict Type: response
## * Feature types: logical, integer, numeric, factor, ordered
## * Properties: importance, missings, multiclass, selected_features,
##   twoclass, weights

Note that this operation overwrites all previously set parameters. If you just want to add a new hyperparameter, retrieve the current set of parameter values, modify the named list and write it back to the learner:

pv = learner$param_set$values
pv$cp = 0.02
learner$param_set$values = pv

This updates cp to 0.02 and keeps the previously set parameter xval.

Note that the lrn() function also accepts additional arguments which are then used to update hyperparameters or set fields of the learner in one go:

learner = lrn("classif.rpart", id = "rp", cp = 0.001)
learner$id
## [1] "rp"

learner$param_set$values
## $xval
## [1] 0
##
## $cp
## [1] 0.001

## 2.4 Train, Predict, Score

In this section, we explain how tasks and learners can be used to train a model and predict on a new dataset. The concept is demonstrated on a supervised classification task using the penguins dataset and the rpart learner, which builds a single classification tree.

Training a learner means fitting a model to a given data set. Subsequently, we want to predict the label for new observations. These predictions are compared to the ground truth values in order to assess the predictive performance of the model.

### 2.4.1 Creating Task and Learner Objects

First of all, we load the mlr3verse package.

library("mlr3verse")

Next, we retrieve the task and the learner from mlr_tasks (with shortcut tsk()) and mlr_learners (with shortcut lrn()), respectively:

1. The classification task:

task = tsk("penguins")

2. A learner for the classification tree:

learner = lrn("classif.rpart")

### 2.4.2 Setting up the train/test splits of the data

It is common to train on a majority of the data. Here we use 80% of all available observations to train and predict on the remaining 20%. For this purpose, we create two index vectors:

train_set = sample(task$nrow, 0.8 * task$nrow)
test_set = setdiff(seq_len(task$nrow), train_set)
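To convince yourself that the two index vectors really partition the data, here is a quick base R check. This is a stand-alone sketch that assumes a row count of 344 (the size of the penguins task) instead of querying a task object:

```r
n <- 344                                  # assumed number of rows in the task
set.seed(1)                               # for reproducibility
train_set <- sample(n, 0.8 * n)
test_set  <- setdiff(seq_len(n), train_set)

# the two sets are disjoint and together cover all rows
length(intersect(train_set, test_set))    # 0
length(union(train_set, test_set))        # 344
```

Because setdiff() removes exactly the sampled indices, every row ends up in exactly one of the two sets.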

In Section 2.5 we will learn how mlr3 can automatically create training and test sets based on different resampling strategies.

### 2.4.3 Training the learner

The field $model stores the model that is produced in the training step. Before the $train() method is called on a learner object, this field is NULL:

learner$model
## NULL

Next, the classification tree is trained using the train set of the penguins task by calling the $train() method of the Learner:

learner$train(task, row_ids = train_set)

This operation modifies the learner in-place. We can now access the stored model via the field $model:

print(learner$model)
## n= 275
##
## node), split, n, loss, yval, (yprob)
##       * denotes terminal node
##
## 1) root 275 158 Adelie (0.42545 0.20000 0.37455)
##   2) flipper_length< 207.5 167 52 Adelie (0.68862 0.31138 0.00000)
##     4) bill_length< 43.35 115 3 Adelie (0.97391 0.02609 0.00000) *
##     5) bill_length>=43.35 52 3 Chinstrap (0.05769 0.94231 0.00000) *
##   3) flipper_length>=207.5 108 5 Gentoo (0.01852 0.02778 0.95370) *

### 2.4.4 Predicting

After the model has been trained, we use the remaining part of the data for prediction. Remember that we initially split the data into train_set and test_set.

prediction = learner$predict(task, row_ids = test_set)
print(prediction)
## <PredictionClassif> for 69 observations:
##     row_ids     truth  response
## ---
##         335 Chinstrap Chinstrap
##         336 Chinstrap Chinstrap
##         344 Chinstrap Chinstrap

The $predict() method of the Learner returns a Prediction object. More precisely, a LearnerClassif returns a PredictionClassif object.

A prediction object holds the row ids of the test data, the respective true label of the target column and the respective predictions. The simplest way to extract this information is by converting the Prediction object to a data.table():

head(as.data.table(prediction))
##    row_ids  truth  response
## 1:       3 Adelie    Adelie
## 2:       8 Adelie    Adelie
## 3:      15 Adelie    Adelie
## 4:      20 Adelie Chinstrap
## 5:      23 Adelie    Adelie
## 6:      24 Adelie    Adelie

For classification, you can also extract the confusion matrix:

prediction$confusion
##            truth
## response    Adelie Chinstrap Gentoo
##   Chinstrap      1        10      2
##   Gentoo         0         1     19
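A confusion matrix is simply a cross-tabulation of predicted versus true class labels. A base R sketch using hypothetical toy vectors (not the actual penguins predictions):

```r
# hypothetical true and predicted labels, sharing the same factor levels
truth    <- factor(c("Adelie", "Adelie", "Gentoo", "Chinstrap", "Gentoo"),
                   levels = c("Adelie", "Chinstrap", "Gentoo"))
response <- factor(c("Adelie", "Gentoo", "Gentoo", "Chinstrap", "Gentoo"),
                   levels = levels(truth))

# rows: predicted class, columns: true class
table(response, truth)
```

The diagonal entries count correct predictions; off-diagonal entries show which classes get confused with each other.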

### 2.4.5 Changing the Predict Type

Classification learners default to predicting the class label. However, many classifiers also tell you how sure they are about the predicted label by providing posterior probabilities. To switch to predicting these probabilities, the predict_type field of a LearnerClassif must be changed from "response" to "prob" before training:

learner$predict_type = "prob"

# re-fit the model
learner$train(task, row_ids = train_set)

# rebuild prediction object
prediction = learner$predict(task, row_ids = test_set)

The prediction object now contains probabilities for all class labels:

# data.table conversion
head(as.data.table(prediction))
##    row_ids  truth  response prob.Adelie prob.Chinstrap prob.Gentoo
## 1:       3 Adelie    Adelie     0.97391        0.02609           0
## 2:       8 Adelie    Adelie     0.97391        0.02609           0
## 3:      15 Adelie    Adelie     0.97391        0.02609           0
## 4:      20 Adelie Chinstrap     0.05769        0.94231           0
## 5:      23 Adelie    Adelie     0.97391        0.02609           0
## 6:      24 Adelie    Adelie     0.97391        0.02609           0

# directly access the predicted labels:
head(prediction$response)
## [1] Adelie    Adelie    Adelie    Chinstrap Adelie    Adelie
## Levels: Adelie Chinstrap Gentoo
# directly access the matrix of probabilities:
head(prediction$prob)

### 2.4.6 Plotting Predictions

prediction = learner$predict(task)
autoplot(prediction)

### 2.4.7 Performance assessment

The last step of modeling is usually the performance assessment. To assess the quality of the predictions, the predicted labels are compared with the true labels. How this comparison is calculated is defined by a measure, which is given by a Measure object. Note that if the prediction was made on a dataset without the target column, i.e. without true labels, then no performance can be calculated.

Predefined measures are stored in mlr_measures (with convenience getter msr()):

mlr_measures
## <DictionaryMeasure> with 85 stored values
## Keys: aic, bic, classif.acc, classif.auc, classif.bacc, classif.bbrier,
##   classif.ce, classif.costs, classif.dor, classif.fbeta, classif.fdr,
##   classif.fn, classif.fnr, classif.fomr, classif.fp, classif.fpr,
##   classif.logloss, classif.mbrier, classif.mcc, classif.npv,
##   classif.ppv, classif.prauc, classif.precision, classif.recall,
##   classif.sensitivity, classif.specificity, classif.tn, classif.tnr,
##   classif.tp, classif.tpr, clust.ch, clust.db, clust.dunn,
##   clust.silhouette, clust.wss, debug, dens.logloss, oob_error,
##   regr.bias, regr.ktau, regr.mae, regr.mape, regr.maxae, regr.medae,
##   regr.medse, regr.mse, regr.msle, regr.pbias, regr.rae, regr.rmse,
##   regr.rmsle, regr.rrse, regr.rse, regr.rsq, regr.sae, regr.smape,
##   regr.srho, regr.sse, selected_features, surv.brier, surv.calib_alpha,
##   surv.calib_beta, surv.chambless_auc, surv.cindex, surv.dcalib,
##   surv.graf, surv.hung_auc, surv.intlogloss, surv.logloss, surv.mae,
##   surv.mse, surv.nagelk_r2, surv.oquigley_r2, surv.rmse, surv.schmid,
##   surv.song_auc, surv.song_tnr, surv.song_tpr, surv.uno_auc,
##   surv.uno_tnr, surv.uno_tpr, surv.xu_r2, time_both, time_predict,
##   time_train

We choose accuracy (classif.acc) as a specific performance measure and call the method $score() of the Prediction object to quantify the predictive performance.

measure = msr("classif.acc")
print(measure)
## <MeasureClassifSimple:classif.acc>
## * Packages: mlr3measures
## * Range: [0, 1]
## * Minimize: FALSE
## * Parameters: list()
## * Properties: -
## * Predict type: response
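Conceptually, classif.acc just measures the share of matching labels, and the classification error (classif.ce) is its complement. A quick base R illustration with hypothetical toy labels, before scoring the actual prediction:

```r
# hypothetical toy labels, not the penguins predictions
truth    <- c("Adelie", "Adelie", "Gentoo", "Chinstrap")
response <- c("Adelie", "Gentoo", "Gentoo", "Chinstrap")

acc <- mean(truth == response)  # accuracy: 3 of 4 labels match
err <- 1 - acc                  # classification error
acc  # 0.75
err  # 0.25
```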
prediction$score(measure)
## classif.acc
##      0.9651

Note that, if no measure is specified, classification defaults to the classification error (classif.ce) and regression defaults to the mean squared error (regr.mse).

## 2.5 Resampling

Resampling strategies are usually used to assess the performance of a learning algorithm. mlr3 comes with the following predefined resampling strategies:

The following sections provide guidance on how to select a resampling strategy and how to subsequently instantiate the resampling process. Here is a graphical illustration of the resampling process:

### 2.5.1 Settings

In this example we use the penguins task and a simple classification tree from the rpart package once again.

library("mlr3verse")

task = tsk("penguins")
learner = lrn("classif.rpart")

When performing resampling with a dataset, we first need to define which approach should be used. mlr3 resampling strategies and their parameters can be queried by looking at the data.table output of the mlr_resamplings dictionary:

as.data.table(mlr_resamplings)
##            key        params iters
## 1:   bootstrap ratio,repeats    30
## 2:      custom                  NA
## 3:   custom_cv                  NA
## 4:          cv         folds    10
## 5:     holdout         ratio     1
## 6:    insample                   1
## 7:         loo                  NA
## 8: repeated_cv folds,repeats   100
## 9: subsampling ratio,repeats    30

Additional resampling methods for special use cases are available via extension packages, such as mlr3spatiotemporal for spatial data.

The model fit conducted in the train/predict/score chapter is equivalent to a “holdout resampling”, so let’s consider this one first. Again, we can retrieve elements from the dictionary mlr_resamplings via $get() or with the convenience function rsmp():

resampling = rsmp("holdout")
print(resampling)
## <ResamplingHoldout> with 1 iterations
## * Instantiated: FALSE
## * Parameters: ratio=0.6667
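Under the hood, an instantiated resampling is little more than stored train/test index vectors per iteration. A base R sketch of what, say, a 3-fold cross-validation split boils down to (a hypothetical stand-alone illustration, not mlr3 code):

```r
set.seed(42)                              # for reproducibility
n <- 12                                   # pretend the task has 12 rows
folds <- sample(rep(1:3, length.out = n)) # randomly assign each row to a fold

test_1  <- which(folds == 1)              # test set of iteration 1
train_1 <- setdiff(seq_len(n), test_1)    # remaining rows form the train set

length(test_1)   # 4
length(train_1)  # 8
```

Each of the 3 iterations uses a different fold as test set, so every row is tested exactly once.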

Note that the $is_instantiated field is set to FALSE. This means we did not actually apply the strategy to a dataset yet. Applying the strategy to a dataset is done in the next section on instantiation.

By default we get a .66/.33 split of the data. There are two ways in which the ratio can be changed:

1. Overwriting the slot in $param_set$values using a named list:

resampling$param_set$values = list(ratio = 0.8)

2. Specifying the resampling parameters directly during construction:

rsmp("holdout", ratio = 0.8)
## <ResamplingHoldout> with 1 iterations
## * Instantiated: FALSE
## * Parameters: ratio=0.8

### 2.5.2 Instantiation

So far we just set the stage and selected the resampling strategy. To actually perform the splitting and obtain indices for the training and the test split, the resampling needs a Task. By calling the method $instantiate(), we split the indices of the data into indices for training and test sets. The resulting indices are stored in the Resampling object.

To better illustrate the following operations, we switch to a 3-fold cross-validation:

resampling = rsmp("cv", folds = 3)
resampling$instantiate(task)
resampling$iters
## [1] 3
str(resampling$train_set(1))
##  int [1:229] 3 9 11 16 19 20 21 25 29 33 ...
str(resampling$test_set(1))
##  int [1:115] 1 2 8 12 15 22 23 24 36 38 ...

Note that if you want to compare multiple Learners in a fair manner, it is mandatory to use the same instantiated resampling for each learner. A way to greatly simplify the comparison of multiple learners is discussed in the next section on benchmarking.

### 2.5.3 Execution

With a Task, a Learner and a Resampling object, we can call resample(), which repeatedly fits the learner to the task according to the given resampling strategy. This in turn creates a ResampleResult object. We tell resample() to keep the fitted models by setting the store_models option to TRUE and then start the computation:

task = tsk("penguins")
learner = lrn("classif.rpart", maxdepth = 3, predict_type = "prob")
resampling = rsmp("cv", folds = 3)

rr = resample(task, learner, resampling, store_models = TRUE)
print(rr)
## <ResampleResult> of 3 iterations
## * Task: penguins
## * Learner: classif.rpart
## * Warnings: 0 in 0 iterations
## * Errors: 0 in 0 iterations

The returned ResampleResult, stored as rr, provides various getters to access the stored information:

• Calculate the average performance across all resampling iterations:

rr$aggregate(msr("classif.ce"))
## classif.ce
##    0.06974
• Extract the performance for the individual resampling iterations:

rr$score(msr("classif.ce"))
##                 task  task_id                   learner    learner_id
## 1: <TaskClassif[47]> penguins <LearnerClassifRpart[36]> classif.rpart
## 2: <TaskClassif[47]> penguins <LearnerClassifRpart[36]> classif.rpart
## 3: <TaskClassif[47]> penguins <LearnerClassifRpart[36]> classif.rpart
##            resampling resampling_id iteration              prediction
## 1: <ResamplingCV[19]>            cv         1 <PredictionClassif[19]>
## 2: <ResamplingCV[19]>            cv         2 <PredictionClassif[19]>
## 3: <ResamplingCV[19]>            cv         3 <PredictionClassif[19]>
##    classif.ce
## 1:    0.11304
## 2:    0.03478
## 3:    0.06140

• Check for warnings or errors:

rr$warnings
## Empty data.table (0 rows and 2 cols): iteration,msg
rr$errors
## Empty data.table (0 rows and 2 cols): iteration,msg

• Extract and inspect the resampling splits:

rr$resampling
## <ResamplingCV> with 3 iterations
## * Instantiated: TRUE
## * Parameters: folds=3
rr$resampling$iters
## [1] 3
str(rr$resampling$test_set(1))
##  int [1:115] 1 3 9 10 11 13 19 23 26 27 ...
str(rr$resampling$train_set(1))
##  int [1:229] 2 4 14 15 16 22 25 29 30 36 ...
• Retrieve the learner of a specific iteration and inspect it:

lrn = rr$learners[[1]]
lrn$model
## n= 229
##
## node), split, n, loss, yval, (yprob)
##       * denotes terminal node
##
##  1) root 229 131 Adelie (0.427948 0.196507 0.375546)
##    2) flipper_length< 207.5 142  44 Adelie (0.690141 0.302817 0.007042)
##      4) bill_length< 43.15 94   1 Adelie (0.989362 0.010638 0.000000) *
##      5) bill_length>=43.15 48   6 Chinstrap (0.104167 0.875000 0.020833)
##       10) body_mass>=4125 9   4 Adelie (0.555556 0.333333 0.111111) *
##       11) body_mass< 4125 39   0 Chinstrap (0.000000 1.000000 0.000000) *
##    3) flipper_length>=207.5 87   2 Gentoo (0.000000 0.022989 0.977011) *
• Extract the predictions:

rr$prediction() # all predictions merged into a single Prediction object
## <PredictionClassif> for 344 observations:
##  row_ids     truth  response prob.Adelie prob.Chinstrap prob.Gentoo
##        1    Adelie    Adelie      0.9894        0.01064           0
##        3    Adelie    Adelie      0.9894        0.01064           0
##        9    Adelie    Adelie      0.9894        0.01064           0
## ---
##      337 Chinstrap Chinstrap      0.1000        0.90000           0
##      340 Chinstrap Chinstrap      0.1000        0.90000           0
##      341 Chinstrap Chinstrap      0.1000        0.90000           0

rr$predictions()[[1]] # prediction of first resampling iteration
## <PredictionClassif> for 115 observations:
##     row_ids     truth  response prob.Adelie prob.Chinstrap prob.Gentoo
## ---
##         338 Chinstrap Chinstrap      0.0000        1.00000           0
##         339 Chinstrap Chinstrap      0.0000        1.00000           0
##         344 Chinstrap Chinstrap      0.0000        1.00000           0
• Filter to only keep specified iterations:

rr$filter(c(1, 3))
print(rr)
## <ResampleResult> of 2 iterations
## * Task: penguins
## * Learner: classif.rpart
## * Warnings: 0 in 0 iterations
## * Errors: 0 in 0 iterations

### 2.5.4 Custom resampling

Sometimes it is necessary to perform resampling with custom splits, e.g. to reproduce results reported in a study. A manual resampling instance can be created using the "custom" template.

resampling = rsmp("custom")
resampling$instantiate(task,
train = list(c(1:10, 51:60, 101:110)),
test = list(c(11:20, 61:70, 111:120))
)
resampling$iters
## [1] 1
resampling$train_set(1)
##  [1]   1   2   3   4   5   6   7   8   9  10  51  52  53  54  55  56  57  58  59
## [20]  60 101 102 103 104 105 106 107 108 109 110
resampling$test_set(1)
##  [1]  11  12  13  14  15  16  17  18  19  20  61  62  63  64  65  66  67  68  69
## [20]  70 111 112 113 114 115 116 117 118 119 120

### 2.5.5 Resampling with (predefined) groups

One can specify that certain observations should be grouped during resampling, meaning they always belong to either the train set or the test set together. This characteristic must be defined via the column role "group" during Task creation (see also the help page on this column role). In mlr this was previously called "blocking". See also the mlr3gallery post on this topic for a practical example.

An even stricter approach is to supply a custom grouping structure in which each group makes up exactly one fold. mlr3 supports this via the "custom_cv" resampling method. This method accepts a factor vector with the same length as task$nrow, or the name of an existing variable in the data. Note that this approach does not allow setting the folds argument, because the number of folds is determined by the number of factor levels. This predefined approach was called "grouping" in mlr2.

The term “blocking” is overloaded in the resampling context. In the field of spatial resampling (see the mlr3spatiotempcv package), “blocking” is used as a general concept for grouping observations, while there are also specific resampling methods that carry “blocking” in their name (e.g. "spcv_block") and focus on rectangular partitioning.
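To make the grouping idea concrete, here is a minimal base-R sketch (not the mlr3 implementation) of a group-preserving fold assignment: whole groups, rather than single observations, are assigned to folds, so no group is ever split between train and test.

```r
# toy data: 9 observations belonging to 3 groups
group = factor(c("a", "a", "a", "b", "b", "b", "c", "c", "c"))

# assign each *group* (not each observation) to one of 3 folds
set.seed(1)
fold_of_group = sample(rep_len(1:3, nlevels(group)))
names(fold_of_group) = levels(group)

# observations inherit the fold of their group
fold = fold_of_group[as.character(group)]

# indices per fold: all observations of a group share one fold
split(seq_along(group), fold)
```

With three groups and three folds this degenerates to the "custom_cv" case described above, where each factor level makes up exactly one fold.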

### 2.5.6 Plotting Resample Results

mlr3viz provides an autoplot() method. To showcase some of the plots, we create a binary classification task with two features, perform resampling with a 10-fold cross-validation and visualize the results:

task = tsk("pima")
task$select(c("glucose", "mass"))
learner = lrn("classif.rpart", predict_type = "prob")
rr = resample(task, learner, rsmp("cv"), store_models = TRUE)

# boxplot of AUC values across the 10 folds
autoplot(rr, measure = msr("classif.auc"))

# ROC curve, averaged over 10 folds
autoplot(rr, type = "roc")

# learner predictions for first fold
rr$filter(1)
autoplot(rr, type = "prediction")
## Warning: Removed 1 rows containing missing values (geom_point).

All available plot types are listed on the manual page of autoplot.ResampleResult().

### 2.5.7 Plotting Resample Partitions

mlr3spatiotempcv provides autoplot() methods to visualize resampling partitions of spatiotemporal datasets. See the function reference and vignette “Spatiotemporal visualization” for more info.


## 2.6 Benchmarking

Comparing the performance of different learners on multiple tasks and/or different resampling schemes is a common task. This operation is usually referred to as "benchmarking" in the field of machine learning. The mlr3 package offers the benchmark() convenience function for this.

### 2.6.1 Design Creation

In mlr3 we require you to supply a “design” of your benchmark experiment. Such a design is essentially a table of settings you want to execute. It consists of unique combinations of Task, Learner and Resampling triplets.

We use the benchmark_grid() function to create an exhaustive design and instantiate the resampling properly, so that all learners are executed on the same train/test splits for each task. We set the learners to predict probabilities and also tell them to predict the observations of the training set (by setting predict_sets to c("train", "test")). Additionally, we use tsks(), lrns(), and rsmps() to retrieve lists of Task, Learner and Resampling in the same fashion as tsk(), lrn() and rsmp().

library("mlr3verse")

design = benchmark_grid(
  tasks = tsks(c("spam", "german_credit", "sonar")),
  learners = lrns(c("classif.ranger", "classif.rpart", "classif.featureless"),
    predict_type = "prob", predict_sets = c("train", "test")),
  resamplings = rsmps("cv", folds = 3)
)
print(design)
##                 task                         learner         resampling
## 9: <TaskClassif[47]> <LearnerClassifFeatureless[36]> <ResamplingCV[19]>

The created design can be passed to benchmark() to start the computation. It is also possible to create a custom design manually. However, if you create a custom design with data.table(), the train/test splits will be different for each row of the design if you do not manually instantiate the resampling before creating the design. See the help page on benchmark_grid() for an example.
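Conceptually, the exhaustive design is nothing more than the cross product of tasks, learners and resamplings. A plain base-R sketch using only the names (the real design holds the corresponding mlr3 objects):

```r
# cross product of task, learner and resampling names
design_sketch = expand.grid(
  task       = c("spam", "german_credit", "sonar"),
  learner    = c("classif.ranger", "classif.rpart", "classif.featureless"),
  resampling = "cv",
  stringsAsFactors = FALSE
)
nrow(design_sketch) # 3 tasks x 3 learners x 1 resampling = 9 rows
```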

### 2.6.2 Execution and Aggregation of Results

After the benchmark design is ready, we can directly call benchmark():

# execute the benchmark
bmr = benchmark(design)

Note that we did not instantiate the resampling instance manually. benchmark_grid() took care of it for us: Each resampling strategy is instantiated once for each task during the construction of the exhaustive grid.

Once the benchmarking is done, we can aggregate the performance with $aggregate(). We create two measures to calculate the AUC for the training set and for the test set:

measures = list(
  msr("classif.auc", predict_sets = "train", id = "auc_train"),
  msr("classif.auc", id = "auc_test")
)
tab = bmr$aggregate(measures)
print(tab)
##    nr      resample_result       task_id          learner_id resampling_id
## 1:  1 <ResampleResult[20]>          spam      classif.ranger            cv
## 2:  2 <ResampleResult[20]>          spam       classif.rpart            cv
## 3:  3 <ResampleResult[20]>          spam classif.featureless            cv
## 4:  4 <ResampleResult[20]> german_credit      classif.ranger            cv
## 5:  5 <ResampleResult[20]> german_credit       classif.rpart            cv
## 6:  6 <ResampleResult[20]> german_credit classif.featureless            cv
## 7:  7 <ResampleResult[20]>         sonar      classif.ranger            cv
## 8:  8 <ResampleResult[20]>         sonar       classif.rpart            cv
## 9:  9 <ResampleResult[20]>         sonar classif.featureless            cv
##    iters auc_train auc_test
## 1:     3    0.9994   0.9858
## 2:     3    0.9104   0.9026
## 3:     3    0.5000   0.5000
## 4:     3    0.9986   0.7944
## 5:     3    0.8143   0.7012
## 6:     3    0.5000   0.5000
## 7:     3    1.0000   0.9326
## 8:     3    0.9129   0.7478
## 9:     3    0.5000   0.5000

We can aggregate the results even further. For example, we might be interested in knowing which learner performed best across all tasks simultaneously. Simply averaging the performances with the mean is usually not statistically sound. Instead, we calculate the rank statistic for each learner, grouped by task. The calculated ranks, grouped by learner, are then aggregated with data.table. Since the AUC needs to be maximized, we multiply the values by $$-1$$ so that the best learner has a rank of $$1$$.
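The sign flip works because rank() assigns rank 1 to the smallest value; negating a score that should be maximized therefore gives the best learner rank 1. Using the test AUC values of the spam task from the table above:

```r
# AUC is maximized, so negate before ranking: best learner gets rank 1
auc = c(ranger = 0.9858, rpart = 0.9026, featureless = 0.5)
rank(-auc)
##      ranger       rpart featureless
##           1           2           3
```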

library("data.table")
# group by levels of task_id, return columns:
# - learner_id
# - rank of col '-auc_train' (within each task)
# - rank of col '-auc_test' (within each task)
ranks = tab[, .(learner_id, rank_train = rank(-auc_train), rank_test = rank(-auc_test)), by = task_id]
print(ranks)
##          task_id          learner_id rank_train rank_test
## 1:          spam      classif.ranger          1         1
## 2:          spam       classif.rpart          2         2
## 3:          spam classif.featureless          3         3
## 4: german_credit      classif.ranger          1         1
## 5: german_credit       classif.rpart          2         2
## 6: german_credit classif.featureless          3         3
## 7:         sonar      classif.ranger          1         1
## 8:         sonar       classif.rpart          2         2
## 9:         sonar classif.featureless          3         3
# group by levels of learner_id, return columns:
# - mean of col 'rank_train'
# - mean of col 'rank_test'
ranks = ranks[, .(mrank_train = mean(rank_train), mrank_test = mean(rank_test)), by = learner_id]

# print the final table, ordered by mean rank of AUC test
ranks[order(mrank_test)]
##             learner_id mrank_train mrank_test
## 1:      classif.ranger           1          1
## 2:       classif.rpart           2          2
## 3: classif.featureless           3          3

Unsurprisingly, the featureless learner is outperformed on both training and test set. The classification forest also outperforms a single classification tree.

### 2.6.3 Plotting Benchmark Results

Analogously to plotting tasks, predictions or resample results, mlr3viz also provides an autoplot() method for benchmark results.

autoplot(bmr) + ggplot2::theme(axis.text.x = ggplot2::element_text(angle = 45, hjust = 1))

We can also plot ROC curves. To do so, we first need to filter the BenchmarkResult to only contain a single Task:

bmr_small = bmr$clone()$filter(task_id = "german_credit")
autoplot(bmr_small, type = "roc")

All available plot types are listed on the manual page of autoplot.BenchmarkResult().

### 2.6.4 Extracting ResampleResults

A BenchmarkResult object is essentially a collection of multiple ResampleResult objects. As these are stored in a column of the aggregated data.table(), we can easily extract them:

tab = bmr$aggregate(measures)
rr = tab[task_id == "german_credit" & learner_id == "classif.ranger"]$resample_result[[1]]
print(rr)
## <ResampleResult> of 3 iterations
## * Task: german_credit
## * Learner: classif.ranger
## * Warnings: 0 in 0 iterations
## * Errors: 0 in 0 iterations

We can now investigate this resampling and even single resampling iterations using one of the approaches shown in the previous section:

measure = msr("classif.auc")
rr$aggregate(measure)
## classif.auc
##      0.7944

# get the iteration with the worst AUC
perf = rr$score(measure)
i = which.min(perf$classif.auc)

# get the corresponding learner and train set
print(rr$learners[[i]])
## <LearnerClassifRanger:classif.ranger>
## * Model: -
## * Packages: ranger
## * Predict Type: prob
## * Feature types: logical, integer, numeric, character, factor, ordered
## * Properties: importance, multiclass, oob_error, twoclass, weights
head(rr$resampling$train_set(i))
## [1]  4  5  6  7  9 16

### 2.6.5 Converting and Merging

A ResampleResult can be cast to a BenchmarkResult using the converter as_benchmark_result(). Additionally, two BenchmarkResults can be merged into a larger result object.

task = tsk("iris")
resampling = rsmp("holdout")$instantiate(task)

rr1 = resample(task, lrn("classif.rpart"), resampling)
rr2 = resample(task, lrn("classif.featureless"), resampling)

# cast both ResampleResults to BenchmarkResults
bmr1 = as_benchmark_result(rr1)
bmr2 = as_benchmark_result(rr2)

# merge the second BenchmarkResult into the first one
bmr1$combine(bmr2)

bmr1
## <BenchmarkResult> of 2 rows with 2 resampling runs
##  nr task_id          learner_id resampling_id iters warnings errors
##   1    iris       classif.rpart       holdout     1        0      0
##   2    iris classif.featureless       holdout     1        0      0

## 2.7 Binary classification

Classification problems with a target variable containing only two classes are called “binary”. For such binary target variables, you can specify the positive class within the classification task object during task creation. If not explicitly set during construction, the positive class defaults to the first level of the target variable.

# during construction
data("Sonar", package = "mlbench")
task = as_task_classif(Sonar, target = "Class", positive = "R")

# switch positive class to level 'M'
task$positive = "M"

### 2.7.1 ROC Curve and Thresholds

ROC analysis, which stands for "receiver operating characteristics", is a subfield of machine learning that studies the evaluation of binary prediction systems. We saw earlier that one can retrieve the confusion matrix of a Prediction by accessing the $confusion field:

learner = lrn("classif.rpart", predict_type = "prob")
pred = learner$train(task)$predict(task)
C = pred$confusion
print(C)
##         truth
## response  M  R
##        M 95 10
##        R 16 87

The confusion matrix contains the counts of correct and incorrect class assignments, grouped by class labels. The columns illustrate the true (observed) labels and the rows display the predicted labels. The positive class is always the first row and column of the confusion matrix. Thus, the element $$C_{11}$$ is the number of times our model predicted the positive class and was right about it. Analogously, the element $$C_{22}$$ is the number of times our model predicted the negative class and was also right about it. The elements on the diagonal are called True Positives (TP) and True Negatives (TN). The element $$C_{12}$$ is the number of times we falsely predicted a positive label, and is called False Positives (FP). The element $$C_{21}$$ is called False Negatives (FN).

We can now normalize the rows and columns of the confusion matrix to derive several informative metrics:

• True Positive Rate (TPR): How many of the true positives did we predict as positive?
• True Negative Rate (TNR): How many of the true negatives did we predict as negative?
• Positive Predictive Value (PPV): If we predict positive, how likely is it a true positive?
• Negative Predictive Value (NPV): If we predict negative, how likely is it a true negative?

(Overview of confusion-matrix-based metrics; source: Wikipedia.)

It is difficult to achieve a high TPR and a low FPR at the same time, so the two are combined to construct the ROC curve. We characterize a classifier by its TPR and FPR values and plot them in a coordinate system. The best possible classifier lies in the top-left corner, while classifiers on the diagonal produce random labels (with different proportions). For example, if each positive instance is randomly classified as "positive" with probability 25%, we get a TPR of 0.25; if each negative instance is randomly assigned to "positive" in the same way, we get an FPR of 0.25.
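The four rates above can be computed directly from the entries of the confusion matrix. A small base-R sketch, applied to the matrix shown earlier (rows are predictions, columns the truth, positive class first):

```r
# confusion matrix from the Sonar example: rows = response, cols = truth
C = matrix(c(95, 16, 10, 87), nrow = 2,
  dimnames = list(response = c("M", "R"), truth = c("M", "R")))

TP = C[1, 1]; FN = C[2, 1]; FP = C[1, 2]; TN = C[2, 2]

round(c(
  TPR = TP / (TP + FN), # recall on the positive class
  TNR = TN / (TN + FP), # recall on the negative class
  PPV = TP / (TP + FP), # precision of positive predictions
  NPV = TN / (TN + FN)  # precision of negative predictions
), 4)
##    TPR    TNR    PPV    NPV
## 0.8559 0.8969 0.9048 0.8447
```

The TPR and TNR values computed here reappear below when scoring the prediction with msrs(c("classif.tpr", "classif.tnr")).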
In practice, we should never observe a classifier below the diagonal: inverting its predicted labels yields a mirror image above the diagonal.

A scoring classifier is a model that produces scores or probabilities instead of discrete labels. To obtain probabilities from a learner in mlr3, set predict_type = "prob" for a LearnerClassif. Whether a classifier can predict probabilities is given in its $predict_types field.

Thresholding flexibly converts predicted probabilities to labels: predict $$1$$ (the positive class) if $$\hat{f}(x) > \tau$$, else predict $$0$$. Normally one uses $$\tau = 0.5$$, but for imbalanced or cost-sensitive situations another threshold can be more suitable. After thresholding, any metric defined on labels can be used.

For mlr3 prediction objects, the ROC curve can easily be created with mlr3viz, which relies on the precrec package to calculate and plot ROC curves:

library("mlr3viz")

# TPR vs FPR / Sensitivity vs (1 - Specificity)
autoplot(pred, type = "roc")
# Precision vs Recall
autoplot(pred, type = "prc")

### 2.7.2 Threshold Tuning

Learners which can predict the probability for the positive class usually use a simple rule to determine the predicted class label: if the probability exceeds the threshold $$t = 0.5$$, pick the positive label, and select the negative label otherwise. If the model is not well calibrated or the class labels are heavily unbalanced, selecting a different threshold can help to improve the predictive performance w.r.t. a chosen performance measure.

Here, we change the threshold to $$t = 0.2$$, improving the True Positive Rate (TPR). Note that with the new threshold more observations from the positive class will get correctly classified with the positive label, but at the same time the True Negative Rate (TNR) decreases. Depending on the application, this may be a desired trade-off.
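The trade-off can be reproduced on toy data. Assuming a made-up vector of predicted positive-class probabilities and true labels (both purely illustrative), lowering the threshold flips more predictions to the positive label, which raises the TPR and lowers the TNR:

```r
prob  = c(0.9, 0.7, 0.4, 0.3, 0.6, 0.25, 0.1, 0.05) # predicted P(positive)
truth = c("pos", "pos", "pos", "pos", "neg", "neg", "neg", "neg")

# TPR and TNR for a given threshold t
tpr_tnr = function(t) {
  pred = ifelse(prob > t, "pos", "neg")
  c(TPR = mean(pred[truth == "pos"] == "pos"),
    TNR = mean(pred[truth == "neg"] == "neg"))
}

tpr_tnr(0.5) # TPR 0.50, TNR 0.75
tpr_tnr(0.2) # TPR 1.00, TNR 0.50
```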

measures = msrs(c("classif.tpr", "classif.tnr"))
pred$confusion
##         truth
## response  M  R
##        M 95 10
##        R 16 87

pred$score(measures)
## classif.tpr classif.tnr
##      0.8559      0.8969
pred$set_threshold(0.2)
pred$confusion
##         truth
## response   M   R
##        M 104  25
##        R   7  72
pred$score(measures)
## classif.tpr classif.tnr
##      0.9369      0.7423

Thresholds can also be tuned with the mlr3pipelines package, i.e. using PipeOpTuneThreshold.