2.3 Learners
Objects of class Learner provide a unified interface to many popular machine learning algorithms in R.
They consist of methods to train and predict a model for a Task and provide meta-information about the learners, such as the hyperparameters you can set.
The base class of each learner is Learner, specialized for regression as LearnerRegr and for classification as LearnerClassif.
Extension packages inherit from the Learner base class, e.g. mlr3proba::LearnerSurv or mlr3cluster::LearnerClust.
In contrast to the Task, the creation of a custom Learner is usually not required and is a more advanced topic.
Hence, we refer the reader to Section 6.1 and proceed with an overview of the interface of already implemented learners.
All Learners work in a two-stage procedure:
- training step: The training data (features and target) is passed to the Learner’s $train() function, which trains and stores a model, i.e. the relationship between the target and the features.
- predict step: A new slice of data, the inference data, is passed to the $predict() method of the Learner. The model trained in the first step is used to predict the missing target, e.g. labels for classification problems or the numerical outcome for regression problems.
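The two steps can be sketched as follows. This is a minimal illustration, assuming the mlr3 and rpart packages are installed; the iris task, the classif.rpart learner, the seed, and the 120/30 row split are arbitrary choices made for this example and are not prescribed by the text above.

```r
library(mlr3)

# Example task and learner (assumptions for this sketch)
task = tsk("iris")
learner = lrn("classif.rpart")

# Arbitrary split into training rows and inference rows
set.seed(1)
train_ids = sample(task$nrow, 120)
test_ids = setdiff(seq_len(task$nrow), train_ids)

# Training step: fits the model and stores it in learner$model
learner$train(task, row_ids = train_ids)

# Predict step: uses the stored model to predict the target for the held-out rows
prediction = learner$predict(task, row_ids = test_ids)
print(prediction)
```

Before $train() is called, learner$model is empty; afterwards it holds the fitted rpart object, which $predict() then reuses.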
2.3.1 Predefined Learners
The mlr3 package ships with the following minimal set of classification and regression learners to avoid unnecessary dependencies:
- mlr_learners_classif.featureless: Simple baseline classification learner (inheriting from LearnerClassif). By default, it constantly predicts the label that is most frequent in the training set.
- mlr_learners_regr.featureless: Simple baseline regression learner (inheriting from LearnerRegr). By default, it constantly predicts the mean of the outcome in the training set.
- mlr_learners_classif.rpart: Single classification tree from package rpart.
- mlr_learners_regr.rpart: Single regression tree from package rpart.
This set of baseline learners is usually insufficient for a real data analysis. Thus, we have cherry-picked one implementation for each of the most popular machine learning methods and collected them in the mlr3learners package:
- Linear and logistic regression
- Penalized Generalized Linear Models
- \(k\)-Nearest Neighbors regression and classification
- Kriging
- Linear and Quadratic Discriminant Analysis
- Naive Bayes
- Support-Vector machines
- Gradient Boosting
- Random Forests for regression, classification and survival
More machine learning methods and alternative implementations are collected in the mlr3extralearners repository.
A full list of implemented learners across all packages is given in this interactive list and also via mlr3extralearners::list_mlr3learners().
The latest build status of all learners is listed here.
To create an object for one of the predefined learners, you need to access the mlr_learners Dictionary which, similar to mlr_tasks, gets automatically populated with more learners by extension packages.
# load most mlr3 packages to populate the dictionary
library("mlr3verse")
mlr_learners
## <DictionaryLearner> with 51 stored values
## Keys: classif.cv_glmnet, classif.debug, classif.featureless,
## classif.glmnet, classif.kknn, classif.lda, classif.log_reg,
## classif.multinom, classif.naive_bayes, classif.nnet, classif.qda,
## classif.ranger, classif.rpart, classif.svm, classif.xgboost,
## clust.agnes, clust.ap, clust.cmeans, clust.cobweb, clust.dbscan,
## clust.diana, clust.em, clust.fanny, clust.featureless, clust.ff,
## clust.kkmeans, clust.kmeans, clust.MBatchKMeans, clust.meanshift,
## clust.pam, clust.SimpleKMeans, clust.xmeans, dens.hist, dens.kde,
## regr.cv_glmnet, regr.featureless, regr.glmnet, regr.kknn, regr.km,
## regr.lm, regr.ranger, regr.rpart, regr.svm, regr.xgboost, surv.coxph,
## surv.cv_glmnet, surv.glmnet, surv.kaplan, surv.ranger, surv.rpart,
## surv.xgboost
Again, there is an alternative to writing down the lengthy mlr_learners$get() part: lrn().
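As a short sketch of this equivalence, the two calls below construct the same kind of learner. The variable names learner_long and learner_short are made up for this illustration, and the example assumes that mlr3 and rpart are installed.

```r
library(mlr3)

# Long form: retrieve the learner from the dictionary by its key
learner_long = mlr_learners$get("classif.rpart")

# Short form: lrn() performs the same dictionary lookup
learner_short = lrn("classif.rpart")

# Both calls return a new learner object with the same class and id
class(learner_short)[1]
identical(learner_long$id, learner_short$id)
```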
2.3.2 Learner API
Each learner provides the following meta-information:
- feature_types: the type of features the learner can deal with.
- packages: the packages required to train a model with this learner and make predictions.
- properties: additional properties and capabilities. For example, a learner has the property “missings” if it is able to handle missing feature values, and “importance” if it computes the relative importance of the features and allows this information to be extracted. A complete list of these is available in the mlr3 reference on regression learners and classification learners.
- predict_types: possible prediction types. For example, a classification learner can predict labels (“response”) or probabilities (“prob”). For a complete list of possible predict types see the mlr3 reference.
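Each of these pieces of meta-information is available as a field of the learner object. The following sketch queries them for the classif.rpart learner, chosen here only as an example, and assumes mlr3 and rpart are installed.

```r
library(mlr3)

learner = lrn("classif.rpart")

learner$feature_types  # feature types the learner can handle
learner$packages       # packages needed for training and prediction
learner$properties     # capabilities such as "missings" or "importance"
learner$predict_types  # supported prediction types such as "response" or "prob"

# Properties can also be checked programmatically,
# e.g. before training on data with missing values
"missings" %in% learner$properties
```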
A tabular overview of integrated learners can be found here.
You can get a specific learner using its id, listed under key in the dictionary:
learner = mlr_learners$get("classif.rpart")
print(learner)
## <LearnerClassifRpart:classif.rpart>
## * Model: -
## * Parameters: xval=0
## * Packages: rpart
## * Predict Type: response
## * Feature types: logical, integer, numeric, factor, ordered
## * Properties: importance, missings, multiclass, selected_features,
## twoclass, weights
The field param_set stores a description of the hyperparameters the learner has, their ranges, defaults, and current values:
learner$param_set
## <ParamSet>
## id class lower upper levels default value
## 1: minsplit ParamInt 1 Inf 20
## 2: minbucket ParamInt 1 Inf <NoDefault[3]>
## 3: cp ParamDbl 0 1 0.01
## 4: maxcompete ParamInt 0 Inf 4
## 5: maxsurrogate ParamInt 0 Inf 5
## 6: maxdepth ParamInt 1 30 30
## 7: usesurrogate ParamInt 0 2 2
## 8: surrogatestyle ParamInt 0 1 0
## 9: xval ParamInt 0 Inf 10 0
## 10: keep_model ParamLgl NA NA TRUE,FALSE FALSE
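Besides printing the table, the parameter set can be queried programmatically. The sketch below assumes the $ids() method and the $default field of the paradox ParamSet class, which backs param_set; it is meant as an illustration rather than a complete tour of that class.

```r
library(mlr3)

learner = lrn("classif.rpart")

# Names of all hyperparameters the learner exposes
learner$param_set$ids()

# Default of a single hyperparameter, as listed in the "default" column
learner$param_set$default$cp

# Values that are currently set on the learner
learner$param_set$values
```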
The set of current hyperparameter values is stored in the values field of the param_set field.
You can change the current hyperparameter values by assigning a named list to this field:
learner$param_set$values = list(cp = 0.01, xval = 0)
learner
## <LearnerClassifRpart:classif.rpart>
## * Model: -
## * Parameters: cp=0.01, xval=0
## * Packages: rpart
## * Predict Type: response
## * Feature types: logical, integer, numeric, factor, ordered
## * Properties: importance, missings, multiclass, selected_features,
## twoclass, weights
Note that this operation just overwrites all previously set parameters. If you just want to add a new hyperparameter, retrieve the current set of parameter values, modify the named list and write it back to the learner:
pv = learner$param_set$values
pv$cp = 0.02
learner$param_set$values = pv
This updates cp to 0.02 and keeps the previously set parameter xval.
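A more compact way to achieve the same effect is to assign directly to a single element of the values list. The sketch below is a shorthand for the retrieve-modify-write pattern and assumes mlr3 and rpart are installed.

```r
library(mlr3)

learner = lrn("classif.rpart")
learner$param_set$values = list(cp = 0.01, xval = 0)

# Assigning to one element changes only that hyperparameter
learner$param_set$values$cp = 0.02

# xval keeps its previously set value
learner$param_set$values
```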
Note that if you use the lrn() function, you can construct learners and simultaneously add hyperparameters or change the identifier in one go:
lrn("classif.rpart", id = "rp", cp = 0.001)
## <LearnerClassifRpart:rp>
## * Model: -
## * Parameters: xval=0, cp=0.001
## * Packages: rpart
## * Predict Type: response
## * Feature types: logical, integer, numeric, factor, ordered
## * Properties: importance, missings, multiclass, selected_features,
## twoclass, weights