4  Learners

Objects of class Learner provide a unified interface to many popular machine learning algorithms in R. They consist of methods to train and predict a model for a Task and provide meta-information about the learners, such as the hyperparameters (which control the behavior of the learner) you can set.

The base class of each learner is Learner, specialized for regression as LearnerRegr and for classification as LearnerClassif. Other types of learners, provided by extension packages, also inherit from the Learner base class, e.g. mlr3proba::LearnerSurv or mlr3cluster::LearnerClust.

All Learners work in a two-stage procedure: a model is first trained on a Task, and the trained model is then used to make predictions.
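The two stages can be illustrated with a minimal sketch. The choice of the iris task and of the classif.rpart learner is an assumption made only for illustration; any compatible task and learner would do.

```r
library(mlr3)

task = tsk("iris")               # example task shipped with mlr3 (illustrative choice)
learner = lrn("classif.rpart")   # classification tree learner

# Stage 1: train a model on the task; the fitted model is stored in $model
learner$train(task)

# Stage 2: use the stored model to make predictions
prediction = learner$predict(task)
head(prediction$response)
```

Predicting on the training data, as done here, only demonstrates the mechanics and says nothing about how well the model generalizes.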

4.1 Predefined Learners

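The mlr3 package ships with only a small set of classification and regression learners, such as the featureless baselines (classif.featureless, regr.featureless) and the rpart decision trees (classif.rpart, regr.rpart). We deliberately keep this set small to avoid unnecessary dependencies.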

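This set of baseline learners is usually insufficient for a real data analysis. We have therefore cherry-picked implementations of the most popular machine learning methods and collected them in the mlr3learners package, including linear and logistic regression, penalized generalized linear models, k-nearest neighbors, support vector machines, gradient boosting, and random forests.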

More machine learning methods and alternative implementations are collected in the mlr3extralearners repository.

Tip

A full list of available learners across all mlr3 packages is provided in this interactive list and via mlr3extralearners::list_mlr3learners().

To get one of the predefined learners, you need to access the mlr_learners Dictionary which, similar to mlr_tasks, is automatically populated with more learners by extension packages.

library("mlr3learners")       # load recommended learners provided by mlr3learners package
library("mlr3extralearners")  # this loads further less-well-supported learners
library("mlr3proba")          # this loads some survival and density estimation learners
library("mlr3cluster")        # this loads some learners for clustering

mlr_learners
<DictionaryLearner> with 135 stored values
Keys: classif.AdaBoostM1, classif.bart, classif.C50, classif.catboost,
  classif.cforest, classif.ctree, classif.cv_glmnet, classif.debug,
  classif.earth, classif.featureless, classif.fnn, classif.gam,
  classif.gamboost, classif.gausspr, classif.gbm, classif.glmboost,
  classif.glmnet, classif.IBk, classif.J48, classif.JRip, classif.kknn,
  classif.ksvm, classif.lda, classif.liblinear, classif.lightgbm,
  classif.LMT, classif.log_reg, classif.lssvm, classif.mob,
  classif.multinom, classif.naive_bayes, classif.nnet, classif.OneR,
  classif.PART, classif.qda, classif.randomForest, classif.ranger,
  classif.rfsrc, classif.rpart, classif.svm, classif.xgboost,
  clust.agnes, clust.ap, clust.cmeans, clust.cobweb, clust.dbscan,
  clust.diana, clust.em, clust.fanny, clust.featureless, clust.ff,
  clust.hclust, clust.kkmeans, clust.kmeans, clust.MBatchKMeans,
  clust.meanshift, clust.pam, clust.SimpleKMeans, clust.xmeans,
  dens.hist, dens.kde, dens.kde_kd, dens.kde_ks, dens.locfit,
  dens.logspline, dens.mixed, dens.nonpar, dens.pen, dens.plug,
  dens.spline, regr.bart, regr.catboost, regr.cforest, regr.ctree,
  regr.cubist, regr.cv_glmnet, regr.debug, regr.earth,
  regr.featureless, regr.fnn, regr.gam, regr.gamboost, regr.gausspr,
  regr.gbm, regr.glm, regr.glmboost, regr.glmnet, regr.IBk, regr.kknn,
  regr.km, regr.ksvm, regr.liblinear, regr.lightgbm, regr.lm,
  regr.lmer, regr.M5Rules, regr.mars, regr.mob, regr.randomForest,
  regr.ranger, regr.rfsrc, regr.rpart, regr.rvm, regr.svm,
  regr.xgboost, surv.akritas, surv.blackboost, surv.cforest,
  surv.coxboost, surv.coxph, surv.coxtime, surv.ctree,
  surv.cv_coxboost, surv.cv_glmnet, surv.deephit, surv.deepsurv,
  surv.dnnsurv, surv.flexible, surv.gamboost, surv.gbm, surv.glmboost,
  surv.glmnet, surv.kaplan, surv.loghaz, surv.mboost, surv.nelson,
  surv.obliqueRSF, surv.parametric, surv.pchazard, surv.penalized,
  surv.ranger, surv.rfsrc, surv.rpart, surv.svm, surv.xgboost
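Because the dictionary can grow quite large, it is often convenient to query its keys programmatically instead of reading the printed output. The following sketch loads only mlr3, so it returns far fewer keys than the listing above; the regular expression is just an example filter.

```r
library(mlr3)

# All keys currently registered in the dictionary
keys = mlr_learners$keys()
length(keys)

# Only classification learners, selected with a regular expression
classif_keys = mlr_learners$keys("^classif\\.")
classif_keys
```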

To obtain an object from the dictionary, use the lrn() function.

learner = lrn("classif.rpart")

Alternatively, you can call the mlr_learners$get() method directly; lrn() is a shortcut for it.
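As a short sketch, both routes yield a learner of the same class. The example also uses lrns(), the plural counterpart of lrn() in mlr3, to construct several learners at once; the variable names are chosen here for illustration.

```r
library(mlr3)

# Two equivalent ways to retrieve a learner from the dictionary
learner_a = lrn("classif.rpart")
learner_b = mlr_learners$get("classif.rpart")
identical(class(learner_a), class(learner_b))

# lrns() retrieves several learners in one call and returns a list
learners = lrns(c("classif.featureless", "classif.rpart"))
ids = unname(sapply(learners, function(l) l$id))
ids
```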

4.2 Learner API

Each learner provides the following meta-information:

  • $feature_types: the type of features the learner can deal with.
  • $packages: the packages required to train a model with this learner and make predictions.
  • $properties: additional properties and capabilities. For example, a learner has the property “missings” if it is able to handle missing feature values, and “importance” if it computes the relative importance of the features and allows this information to be extracted. A complete list of these is available in the mlr3 reference.
  • $predict_types: possible prediction types. For example, a classification learner can predict labels (“response”) or probabilities (“prob”). For a complete list of possible predict types see the mlr3 reference.

This information can be queried through these slots, or seen at a glance from the printer:

print(learner)
<LearnerClassifRpart:classif.rpart>: Classification Tree
* Model: -
* Parameters: xval=0
* Packages: mlr3, rpart
* Predict Type: response
* Feature types: logical, integer, numeric, factor, ordered
* Properties: importance, missings, multiclass, selected_features,
  twoclass, weights
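The same meta-information can also be read field by field, which is useful when it needs to be checked in code. A brief sketch for the rpart learner shown above:

```r
library(mlr3)

learner = lrn("classif.rpart")

learner$feature_types   # feature types the learner can handle
learner$packages        # packages needed for training and prediction
learner$properties      # capabilities such as "missings" or "importance"
learner$predict_types   # supported prediction types

# Check a capability programmatically before relying on it
supports_missings = "missings" %in% learner$properties
supports_missings
```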

Furthermore, each learner has hyperparameters that control its behavior, for example the minimum number of samples in the leaf of a decision tree, or whether to provide verbose output during training. Setting hyperparameters to values appropriate for a given machine learning task is crucial. The field param_set stores a description of the hyperparameters the learner has, their ranges, defaults, and current values:

learner$param_set
<ParamSet>
                id    class lower upper nlevels        default value
 1:             cp ParamDbl     0     1     Inf           0.01      
 2:     keep_model ParamLgl    NA    NA       2          FALSE      
 3:     maxcompete ParamInt     0   Inf     Inf              4      
 4:       maxdepth ParamInt     1    30      30             30      
 5:   maxsurrogate ParamInt     0   Inf     Inf              5      
 6:      minbucket ParamInt     1   Inf     Inf <NoDefault[3]>      
 7:       minsplit ParamInt     1   Inf     Inf             20      
 8: surrogatestyle ParamInt     0     1       2              0      
 9:   usesurrogate ParamInt     0     2       3              2      
10:           xval ParamInt     0   Inf     Inf             10     0
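Individual pieces of this table can be extracted directly from the parameter set. The following sketch relies on the ids() method and the default and values fields of the ParamSet class from the paradox package; the variable name ps is chosen here for brevity.

```r
library(mlr3)

learner = lrn("classif.rpart")
ps = learner$param_set

ps$ids()         # names of all hyperparameters
ps$default$cp    # default value of a single hyperparameter
ps$values        # hyperparameters that are currently set
```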

The set of current hyperparameter values is stored in the values field of the param_set field. This field is a named list, which you can read to inspect the current values or assign to in order to change them:

learner$param_set$values
$xval
[1] 0
learner$param_set$values$cp = 0.01
learner$param_set$values
$xval
[1] 0

$cp
[1] 0.01

Tip

It is possible to assign all hyperparameters in one go by assigning a named list to $values: learner$param_set$values = list(cp = 0.01, xval = 0). However, be aware that this operation also removes all previously set hyperparameters.
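To illustrate the pitfall, the sketch below first overwrites the whole list and then adds further values without losing the existing ones. It uses insert_named() from the mlr3misc package as one possible way to merge named lists; the concrete hyperparameter values are arbitrary.

```r
library(mlr3)

learner = lrn("classif.rpart")

# Overwriting the list drops everything not mentioned, including xval
learner$param_set$values = list(cp = 0.01)
after_overwrite = names(learner$param_set$values)
after_overwrite

# Merging into the existing list keeps cp and adds the new entries
learner$param_set$values = mlr3misc::insert_named(
  learner$param_set$values,
  list(minsplit = 10, xval = 0)
)
after_merge = names(learner$param_set$values)
after_merge
```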

The lrn() function also accepts additional arguments to update hyperparameters or set fields of the learner in one go:

learner = lrn("classif.rpart", id = "rp", cp = 0.001)
learner$id
[1] "rp"
learner$param_set$values
$xval
[1] 0

$cp
[1] 0.001
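Fields other than the id can be set in the same call. As a small sketch, the following constructs a learner that predicts probabilities and limits the tree depth; the id "rp_prob" and the value of maxdepth are made up for this example.

```r
library(mlr3)

learner = lrn("classif.rpart",
  id = "rp_prob",           # field: identifier of the learner
  predict_type = "prob",    # field: predict probabilities instead of labels
  maxdepth = 3              # hyperparameter: maximum depth of the tree
)

learner$id
learner$predict_type
learner$param_set$values
```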

More on this is discussed in the section on Hyperparameter Tuning.