12.1 Learners

Here, we show how to create a custom LearnerClassif step-by-step.

Preferably, you start by copying over code from an existing Learner, e.g. from the "classif.rpart" learner on GitHub. Alternatively, here is a template for a new classification learner:

LearnerClassifYourLearner = R6::R6Class("LearnerClassifYourLearner",
  inherit = LearnerClassif,
  public = list(
    initialize = function(id = "classif.yourlearner") {
      super$initialize(
        id = id,
        param_set = ParamSet$new(),
        param_vals = list(),
        predict_types = ,
        feature_types = ,
        properties = ,
        packages = ,
      )
    },

    train = function(task) {

    },
    predict = function(task) {

    }
  )
)

In the first line of the template, we create a new R6 class named "LearnerClassifYourLearner". The next line determines the parent class: as we want to create a classification learner, we inherit from LearnerClassif.

A learner consists of three parts:

  1. Meta information about the learner.
  2. A train() function which takes a TaskClassif, fits a model, stores the model in self$model and returns the learner itself.
  3. A predict() function which operates on the model stored in self$model during train() and a (differently subsetted) TaskClassif to calculate predictions for each observation in the provided task.
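
As a rough illustration of these three parts without any mlr3 machinery, here is a minimal sketch using a plain R6 class. The class ToyMajorityLearner and its train(data, target) signature are hypothetical and chosen only for this example; a real learner inherits from LearnerClassif and receives a task instead of a data frame.

```r
library(R6)

# Hypothetical toy class (not part of mlr3): always predicts the most
# frequent class seen during training.
ToyMajorityLearner = R6Class("ToyMajorityLearner",
  public = list(
    # Part 1: meta information, reduced here to an id
    id = NULL,
    model = NULL,
    initialize = function(id = "toy.majority") {
      self$id = id
    },

    # Part 2: fit a model, store it in self$model, return the learner
    train = function(data, target) {
      counts = table(data[[target]])
      self$model = names(counts)[which.max(counts)]
      invisible(self)
    },

    # Part 3: use the stored model to predict for each observation
    predict = function(data) {
      list(response = rep(self$model, nrow(data)))
    }
  )
)

toy = ToyMajorityLearner$new()
pred = toy$train(iris[1:60, ], target = "Species")$predict(iris)
```

Because train() returns the learner itself, the calls to train() and predict() can be chained.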

12.1.1 Meta-information

In the constructor function initialize(), the constructor of the super class LearnerClassif is called with meta information about the learner we want to construct. This includes:

  • id: The id of the new learner.
  • param_set: A set of hyperparameters and their description, provided as paradox::ParamSet.
  • param_vals: Default hyperparameter settings as named list.
  • predict_types: Set of predict types the learner is capable of. For classification, this must be a subset of c("response", "prob"). See mlr_reflections$learner_predict_types for possible predict types of other tasks.
  • feature_types: Set of feature types the learner can handle. See mlr_reflections$task_feature_types for feature types supported by mlr3.
  • properties: Set of properties of the learner. Possible properties include:
    • "twoclass": The learner works on binary classification problems.
    • "multiclass": The learner works on multi-class classification problems.
    • "missings": The learner can natively handle missing values.
    • "weights": The learner can work on tasks which have observation weights / case weights.
    • "parallel": The learner can be parallelized, e.g. via threading.
    • "importance": The learner supports extracting importance values for features. If this property is set, you must also implement a public method importance() to retrieve the importance values from the model.
    • "selected_features": The learner supports extracting the features which were used. If this property is set, you must also implement a public method selected_features() to retrieve the set of used features from the model.
  • packages: Set of required packages to run the learner.
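
For the "importance" and "selected_features" properties, the following is a minimal sketch of what the bodies of such methods might look like for an rpart model. It is written as free-standing functions taking the model as argument (an assumption for illustration; in a learner they would be public methods reading self$model), and it relies on the variable.importance and frame components of a fitted rpart object.

```r
library(rpart)

model = rpart(Species ~ ., data = iris, xval = 0)

# Sketch of an importance() body: named numeric vector, largest first.
# Returns an empty vector if the tree has no splits.
importance = function(model) {
  imp = model$variable.importance
  if (is.null(imp)) numeric(0) else sort(imp, decreasing = TRUE)
}

# Sketch of a selected_features() body: variables actually used in splits.
# Leaf nodes are marked as "<leaf>" in the frame and are dropped.
selected_features = function(model) {
  setdiff(unique(as.character(model$frame$var)), "<leaf>")
}

imp = importance(model)
sel = selected_features(model)
```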

For a simplified rpart::rpart(), the initialization could look like this:

initialize = function(id = "classif.rpart") {
    super$initialize(
        id = id,
        packages = "rpart",
        feature_types = c("logical", "integer", "numeric", "factor"),
        predict_types = c("response", "prob"),
        param_set = ParamSet$new(
            params = list(
                ParamDbl$new(id = "cp", default = 0.01, lower = 0, upper = 1, tags = "train"),
                ParamInt$new(id = "xval", default = 0L, lower = 0L, tags = "train")
            )
        ),
        param_vals = list(xval = 0L),
        properties = c("twoclass", "multiclass", "weights", "missings")
    )
}

We only have specified a small subset of the available hyperparameters:

  • The complexity "cp" is numeric, its feasible range is [0,1], it defaults to 0.01 and the parameter is used during "train".
  • The parameter "xval" (the number of cross-validations) is integer, its lower bound is 0, its default is 0 and the parameter is also used during "train". Note that we have changed the default here from 10, the default of rpart::rpart(), to 0 to save some computation time. The value that takes effect does not come from the default declared in ParamInt$new(), but from setting the value explicitly via param_vals.
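
To see what setting xval to 0 changes, the following sketch fits the same tree with and without cross-validation, using rpart directly and no mlr3 objects. The variable names fit_cv and fit_nocv are chosen only for this example.

```r
library(rpart)

# rpart's own default is xval = 10: cross-validated errors are computed
fit_cv = rpart(Species ~ ., data = iris)

# xval = 0 skips the cross-validation entirely
fit_nocv = rpart(Species ~ ., data = iris, xval = 0)

# The complexity table only has cross-validation columns in the first fit
colnames(fit_cv$cptable)
colnames(fit_nocv$cptable)

# The fitted tree structure itself is the same in both cases
same_tree = identical(
  as.character(fit_cv$frame$var),
  as.character(fit_nocv$frame$var)
)
```

Cross-validation in rpart only adds error estimates to the complexity table, so disabling it saves time without altering the fitted tree.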

12.1.2 Train function

We continue to adapt the template for a rpart::rpart() learner, and now tackle the train() function. The train function takes a Task as input and must fit a model (an arbitrary R object), which is stored in self$model. First, we write something down that works completely without mlr3:

data = iris
model = rpart::rpart(Species ~ ., data = data, xval = 0)

In the next step, we replace the data frame data with a Task:

task = mlr_tasks$get("iris")
model = rpart::rpart(Species ~ ., data = task$data(), xval = 0)

The target variable "Species" is still hard-coded and specific to the task. This is unnecessary, as the information about the target variable is stored in the task:

task$target_names
## [1] "Species"
task$formula()
## Species ~ .
## NULL

We can adapt our code accordingly:

model = rpart::rpart(task$formula(), data = task$data(), xval = 0)
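
For intuition, a formula of this kind can also be built from the target name alone with base R. This is only a sketch; it does not claim to show how task$formula() is implemented, and the variable target merely stands in for task$target_names.

```r
# Stand-in for task$target_names (assumption for this example)
target = "Species"

# Build "<target> ~ ." without hard-coding the target in the formula
f = reformulate(".", response = target)

# The constructed formula can be passed to rpart like a literal one
model = rpart::rpart(f, data = iris, xval = 0)
```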

The last thing missing is the handling of hyperparameters. Instead of the hard-coded xval, we query the hyperparameter settings from the Learner itself.

To illustrate this, we quickly construct the rpart learner from the mlr3 package, and use the method params() to retrieve all actively set hyperparameters with tag "train".

self = mlr_learners$get("classif.rpart")
self$params("train")
## $xval
## [1] 0

To pass all hyperparameters returned by the params() method to the learner function, we recommend using either do.call() or the function mlr3misc::invoke().

pars = self$params("train")
model = mlr3misc::invoke(rpart::rpart, task$formula(),
    data = task$data(), .args = pars)
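
The do.call() variant can be sketched with base R and rpart alone. Here, the list pars is a hand-written stand-in for the result of self$params("train"), and iris replaces task$data(); both are assumptions made to keep the example self-contained.

```r
library(rpart)

# Stand-in for self$params("train"): a named list of hyperparameters
pars = list(xval = 0, cp = 0.05)

# Combine the fixed arguments with the hyperparameters and call rpart()
args = c(list(formula = Species ~ ., data = iris), pars)
model = do.call(rpart, args)

# Equivalent direct call with the hyperparameters written out
direct = rpart(Species ~ ., data = iris, xval = 0, cp = 0.05)
```

Both calls fit the same tree; the do.call() form has the advantage that the set of hyperparameters does not need to be known when the code is written.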

In the final learner, self will reference the learner itself. In the last step, we wrap everything in a function. The model is stored in the learner self and self is returned.

train = function(task) {
    pars = self$params("train")
    self$model = mlr3misc::invoke(rpart::rpart, task$formula(),
        data = task$data(), .args = pars)
    self
}

12.1.3 Predict function

The predict function also operates on a Task as well as on the model stored during train(). The return value must be a named list, where each list element corresponds to a predict type.

We proceed analogously to the section on the train function: We start with a version without any mlr3 objects and continue to replace objects until we have reached the desired interface:

# inputs:
task = mlr_tasks$get("iris")
self = list(model = rpart::rpart(task$formula(), data = task$data()))

data = iris
response = predict(self$model, newdata = data, type = "class")
prob = predict(self$model, newdata = data, type = "prob")

The rpart::predict.rpart() function predicts class labels if argument type is set to "class", and class probabilities if set to "prob".
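
The two prediction types return differently shaped objects, which matters when they are later placed into the named list. A short sketch with rpart only, inspecting both results on the iris data:

```r
library(rpart)

model = rpart(Species ~ ., data = iris, xval = 0)

# type = "class": a factor with one predicted label per observation
response = predict(model, newdata = iris, type = "class")

# type = "prob": a numeric matrix with one column per class level
prob = predict(model, newdata = iris, type = "prob")

str(response)
dim(prob)
colnames(prob)
```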

Next, we transition from data to a task again. Additionally, as we do not want to run the prediction twice, we differentiate what type of prediction is requested by querying the set predict type of the learner.

self$predict_type = "response"

if (self$predict_type == "response") {
  list(response = predict(self$model, newdata = task$data(), type = "class"))
} else {
  list(prob = predict(self$model, newdata = task$data(), type = "prob"))
}

The complete predict() function then looks like this:

predict = function(task) {
  if (self$predict_type == "response") {
    list(response = predict(self$model, newdata = task$data(), type = "class"))
  } else {
    list(prob = predict(self$model, newdata = task$data(), type = "prob"))
  }
}

Note that if the learner needed to handle hyperparameters during the predict step, we would proceed analogously to the train() step and use self$params("predict") in combination with mlr3misc::invoke().

12.1.4 Final learner

LearnerClassifYourRpart = R6::R6Class("LearnerClassifYourRpart",
  inherit = LearnerClassif,
  public = list(
    initialize = function(id = "classif.rpart") {
      super$initialize(
        id = id,
        packages = "rpart",
        feature_types = c("logical", "integer", "numeric", "factor"),
        predict_types = c("response", "prob"),
        param_set = paradox::ParamSet$new(
          params = list(
            paradox::ParamDbl$new(id = "cp", default = 0.01, lower = 0, upper = 1, tags = "train"),
            paradox::ParamInt$new(id = "xval", default = 0L, lower = 0L, tags = "train")
          )
        ),
        param_vals = list(xval = 0L),
        properties = c("twoclass", "multiclass", "weights", "missings")
      )
    },

    train = function(task) {
      pars = self$params("train")
      self$model = mlr3misc::invoke(rpart::rpart, task$formula(),
        data = task$data(), .args = pars)
      self
    },

    predict = function(task) {
      if (self$predict_type == "response") {
        list(response = predict(self$model, newdata = task$data(), type = "class"))
      } else {
        list(prob = predict(self$model, newdata = task$data(), type = "prob"))
      }
    }
  )
)

lrn = LearnerClassifYourRpart$new()
print(lrn)
## <LearnerClassifYourRpart:classif.rpart>
## Parameters: xval=0
## Packages: rpart
## Predict Type: response
## Feature types: logical, integer, numeric, factor
## Properties: missings, multiclass, twoclass, weights

To run some basic tests:

task = mlr_tasks$get("iris")
e = Experiment$new(task, lrn)$train()$predict()$score()

To run a bunch of automatic tests, you may source some auxiliary scripts from the unit tests of mlr3:

helper = list.files(system.file("testthat", package = "mlr3"), pattern = "^helper.*\\.[rR]", full.names = TRUE)
ok = lapply(helper, source)
stopifnot(run_autotest(lrn))