Here, we show how to create a custom LearnerClassif step-by-step.

As a starting point, please check out the mlr3learnertemplate repository. Alternatively, here is a template snippet for a new classification learner:

Learner<type><algorithm> = R6Class("Learner<type><algorithm>",
  inherit = Learner<type>,

  public = list(

    #' @description
    #' Creates a new instance of this [R6][R6::R6Class] class.
    initialize = function() {
      ps = ParamSet$new(params = list(<params>))

      super$initialize(
        id = "<type>.<algorithm>",
        packages = "<package>",
        feature_types = "<feature types>",
        predict_types = "<predict types>",
        param_set = ps,
        properties = "<properties>",
        man = "<pkgname>::<help file name>"
      )
    }
  ),

  private = list(
    .train = function(task) { },
    .predict = function(task) { }
  )
)

In the first line of the template, we create a new R6 class Learner<type><algorithm>. The next line determines the parent class: as we want to create a classification learner, we inherit from LearnerClassif.

A learner consists of three parts:

1. Meta information about the learner
2. A .train() function which takes a (filtered) TaskClassif and returns a model
3. A .predict() function which operates on the model in self$model (stored during $train()) and a (differently subsetted) TaskClassif to return a named list of predictions.

### 6.1.1 Meta-information

In the constructor (initialize()), the constructor of the super class (e.g. LearnerClassif) is called with meta information about the learner being constructed. This includes:

• id: The id of the new learner.
• packages: Set of required packages to run the learner.
• param_set: A set of hyperparameters and their descriptions, provided as a paradox::ParamSet. It is perfectly fine to add no parameters here for a first draft. For each hyperparameter you want to add, you have to select the appropriate class:
  • ParamLgl for scalar logical hyperparameters.
  • ParamInt for scalar integer hyperparameters.
  • ParamDbl for scalar numeric (double) hyperparameters.
  • ParamFct for scalar factor hyperparameters (this includes characters).
  • ParamUty for everything else (e.g. vector or list parameters).
• predict_types: Set of predict types the learner is capable of. These differ depending on the type of the learner.
• LearnerClassif
• response: Only predicts a class label for each observation in the test set.
• prob: Also predicts the posterior probability for each class for each observation in the test set.
• LearnerRegr
• response: Only predicts a numeric response for each observation in the test set.
• se: Also predicts the standard error for each value of response for each observation in the test set.
• feature_types: Set of feature types the learner can handle. See mlr_reflections$task_feature_types for feature types supported by mlr3.
• properties: Set of properties of the learner. Possible properties include:
  • "twoclass": The learner works on binary classification problems.
  • "multiclass": The learner works on multi-class classification problems.
  • "missings": The learner can natively handle missing values.
  • "weights": The learner can work on tasks which have observation weights / case weights.
  • "parallel": The learner can be parallelized, e.g. via threading.
  • "importance": The learner supports extracting importance values for features. If this property is set, you must also implement a public method importance() to retrieve the importance values from the model.
  • "selected_features": The learner supports extracting the features which were actually used. If this property is set, you must also implement a public method selected_features() to retrieve the set of used features from the model.
• man: The roxygen identifier of the learner. This is used within the $help() method of the super class to open the help page of the learner.
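For example, if a learner declares the "importance" property, the matching public method must be supplied alongside initialize(). A minimal sketch of such a method for an rpart-based learner, assuming the fitted model is stored in self$model and using stopf() from mlr3misc, could look like this (a fragment to be placed inside public = list(...), not a complete class):

```r
# inside public = list(...) of the learner class:
importance = function() {
  if (is.null(self$model)) {
    mlr3misc::stopf("No model stored")
  }
  # rpart stores per-feature importance in model$variable.importance
  sort(self$model$variable.importance, decreasing = TRUE)
}
```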

For a simplified rpart::rpart(), the initialization could look like this:

initialize = function(id = "classif.rpart") {
  ps = ParamSet$new(list(
    ParamDbl$new(id = "cp", default = 0.01, lower = 0, upper = 1, tags = "train"),
    ParamInt$new(id = "xval", default = 10L, lower = 0L, tags = "train")
  ))
  ps$values = list(xval = 0L)

  super$initialize(
    id = id,
    packages = "rpart",
    feature_types = c("logical", "integer", "numeric", "factor"),
    predict_types = c("response", "prob"),
    param_set = ps,
    properties = c("twoclass", "multiclass", "weights", "missings"),
    man = "mlr3learners.rpart::mlr_learners_classif.rpart"
  )
}

We have only specified a small subset of the available hyperparameters:

• The complexity parameter "cp" is numeric, its feasible range is [0, 1], it defaults to 0.01, and the parameter is used during "train".
• The parameter "xval" is integer-valued, its lower bound is 0, its default is 10, and the parameter is also used during "train". Note that we have changed the value here from 10 to 0 to save some computation time. This is not done by setting a different default in ParamInt$new(), but instead by setting the value explicitly.
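Once such a learner is constructed, the declared hyperparameters can be inspected and overridden through its ParamSet. A quick sketch using the ready-made "classif.rpart" learner shipped with mlr3 (assuming mlr3 and rpart are installed):

```r
library(mlr3)

learner = lrn("classif.rpart")
# override the stored values explicitly, as done with xval above
learner$param_set$values = list(cp = 0.05, xval = 0L)
learner$param_set$values$cp
## [1] 0.05
```

Invalid values (e.g. cp outside [0, 1]) are rejected by the ParamSet, which is exactly why declaring bounds in the constructor pays off.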

### 6.1.2 Train function

We continue to adapt the template for a rpart::rpart() learner, and now tackle the .train() function. The train function takes a Task as input and must return an arbitrary model. First, we write something down that works completely without mlr3:

data = iris
model = rpart::rpart(Species ~ ., data = iris, xval = 0)

In the next step, we replace the data frame data with a Task:

task = tsk("iris")
model = rpart::rpart(Species ~ ., data = task$data(), xval = 0)

The target variable "Species" is still hard-coded and specific to the task. This is unnecessary, as the information about the target variable is stored in the task:

task$target_names
## [1] "Species"

task$formula()
## Species ~ .

We can adapt our code accordingly:

rpart::rpart(task$formula(), data = task$data(), xval = 0)
## n= 150
##
## node), split, n, loss, yval, (yprob)
##       * denotes terminal node
##
## 1) root 150 100 setosa (0.33333 0.33333 0.33333)
##   2) Petal.Length< 2.45 50   0 setosa (1.00000 0.00000 0.00000) *
##   3) Petal.Length>=2.45 100  50 versicolor (0.00000 0.50000 0.50000)
##     6) Petal.Width< 1.75 54   5 versicolor (0.00000 0.90741 0.09259) *
##     7) Petal.Width>=1.75 46   1 virginica (0.00000 0.02174 0.97826) *

The last thing missing is the handling of hyperparameters. Instead of the hard-coded xval, we query the hyperparameter settings from the Learner itself. To illustrate this, we quickly construct the tree learner from the mlr3 package and use the method get_values() of the ParamSet to retrieve all set hyperparameters with tag "train":

self = lrn("classif.rpart")
self$param_set$get_values(tags = "train")
## $xval
## [1] 0

To pass all hyperparameters down to the model fitting function, we recommend using mlr3misc::invoke(), which works like an improved version of do.call():

pars = self$param_set$get_values(tags = "train")
mlr3misc::invoke(rpart::rpart, task$formula(), data = task$data(), .args = pars)
## n= 150
##
## node), split, n, loss, yval, (yprob)
##       * denotes terminal node
##
## 1) root 150 100 setosa (0.33333 0.33333 0.33333)
##   2) Petal.Length< 2.45 50   0 setosa (1.00000 0.00000 0.00000) *
##   3) Petal.Length>=2.45 100  50 versicolor (0.00000 0.50000 0.50000)
##     6) Petal.Width< 1.75 54   5 versicolor (0.00000 0.90741 0.09259) *
##     7) Petal.Width>=1.75 46   1 virginica (0.00000 0.02174 0.97826) *
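To see what mlr3misc::invoke() adds over do.call() in isolation, here is a toy sketch with a plain function (only base R plus mlr3misc; the function f is made up for illustration):

```r
f = function(x, y = 1) x + y

# do.call() requires all arguments collected in one list:
do.call(f, list(x = 2, y = 3))
## [1] 5

# invoke() takes fixed arguments directly and merges extra ones via .args:
mlr3misc::invoke(f, x = 2, .args = list(y = 3))
## [1] 5
```

This is convenient in learners, where some arguments (formula, data) are fixed and the rest come from the ParamSet as a named list.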

In the final learner, self will of course reference the learner itself. In the last step, we wrap everything in a function.

.train = function(task) {
pars = self$param_set$get_values(tags = "train")
mlr3misc::invoke(rpart::rpart, task$formula(), data = task$data(), .args = pars)
}
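Before the function lives inside the R6 class, its body can be checked interactively by mimicking self with a plain list, as done above (a throwaway sketch, assuming mlr3, mlr3misc, and rpart are available):

```r
library(mlr3)

task = tsk("iris")
# stand-in for the learner's self, carrying only the ParamSet
self = list(param_set = lrn("classif.rpart")$param_set)

pars = self$param_set$get_values(tags = "train")
model = mlr3misc::invoke(rpart::rpart, task$formula(),
  data = task$data(), .args = pars)
class(model)
## [1] "rpart"
```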

### 6.1.3 Predict function

The internal predict function .predict() also operates on a Task, as well as on the model stored during train() in self$model. The return value is a Prediction object. We proceed analogously to the section on the train function: we start with a version without any mlr3 objects and continue to replace objects until we have reached the desired interface:

# inputs:
task = tsk("iris")
self = list(model = rpart::rpart(task$formula(), data = task$data()))
data = iris

response = predict(self$model, newdata = data, type = "class")
prob = predict(self$model, newdata = data, type = "prob")

The rpart::predict.rpart() function predicts class labels if argument type is set to "class", and class probabilities if set to "prob".

Next, we transition from data to a task again and construct a proper PredictionClassif object to return. Additionally, as we do not want to run the prediction twice, we differentiate which type of prediction is requested by querying the predict type of the learner. The complete .predict() function looks like this:

.predict = function(task) {
  response = prob = NULL

  if (self$predict_type == "response") {
    response = predict(self$model, newdata = task$data(), type = "class")
  } else {
    prob = predict(self$model, newdata = task$data(), type = "prob")
  }

  PredictionClassif$new(task, response = response, prob = prob)
}
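To convince yourself that these pieces fit together, the raw rpart predictions can be wrapped into a PredictionClassif manually and scored against the task's truth (a sketch assuming mlr3 and rpart are installed):

```r
library(mlr3)

task = tsk("iris")
model = rpart::rpart(task$formula(), data = task$data())
response = predict(model, newdata = task$data(), type = "class")

# build the prediction object by hand, exactly as .predict() will
p = PredictionClassif$new(task, response = response)
p$score(msr("classif.acc"))
```

Training accuracy on the training data should be high here; the point is only that the object constructs and scores without errors.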

Note that if the learner needs to handle hyperparameters during the predict step, we proceed analogously to the train() step and use self$param_set$get_values(tags = "predict") in combination with mlr3misc::invoke().

Also note that you cannot rely on the column order of the data returned by task$data(): the order of columns may differ from the order of the columns during $train(). You have to make sure that your learner accesses columns by name, not by position (like some algorithms with a matrix interface do). You may have to restore the order manually here, see "classif.svm" for an example.

### 6.1.4 Control objects/functions of learners

Some learners rely on a "control" object/function such as glmnet::glmnet.control(). To account for such learners, add the parameters of the control function to the ParamSet and tag them with "control". Then, construct the control object in the .train() method as follows:

control = mlr3misc::invoke(<package>::<function>,
  .args = self$param_set$get_values(tags = "control"))
mlr3misc::invoke([...], control = control)

### 6.1.5 Testing the learner

To run some basic tests:

task = tsk("iris")
# assuming lrn is a classification learner
lrn$train(task)
p = lrn$predict(task)
p$confusion

To run the automatic tests stored in the mlr3 package, call devtools::test() after you have updated the scripts in tests/ with the correct names.
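Beyond a single train/predict round trip, a short resampling run is a useful smoke test, since it exercises $train() and $predict() on varying subsets of the data. A sketch shown with the built-in "classif.rpart" learner; substitute your own learner's id once it is registered:

```r
library(mlr3)

task = tsk("iris")
learner = lrn("classif.rpart")
resampling = rsmp("cv", folds = 3)

rr = resample(task, learner, resampling)
rr$aggregate(msr("classif.ce"))  # mean misclassification error across folds
```

If this runs without errors and yields a plausible error rate, the learner's train and predict plumbing is in working order.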