6.1 Adding new Learners

Here, we show how to create a custom mlr3learner step-by-step using mlr3extralearners::create_learner.

This section gives insights on how a mlr3learner is constructed and how to troubleshoot issues. See the Learner FAQ subsection for help.

Summary of steps for adding a new learner

  1. Check the learner does not already exist here.
  2. (Install mlr3extralearners from GitHub: remotes::install_github("mlr-org/mlr3extralearners").)
  3. Run mlr3extralearners::create_learner.
  4. Add the learner param_set.
  5. Manually add .train and .predict private methods to the learner.
  6. If applicable add importance and oob_error public methods to the learner.
  7. If applicable add references to the learner.
  8. Check unit tests and paramtests pass (these are automatically created).
  9. Run the cleaning functions.
  10. Open a pull request with the new learner template.

(Do not copy/paste the code shown in this section. Use create_learner to start.)

6.1.1 Calling create_learner

The learner classif.rpart will be used as a running example throughout this section.

library(mlr3extralearners)
create_learner(
  pkg = ".",
  classname = "Rpart",
  algorithm = "decision tree",
  type = "classif",
  key = "rpart",
  package = "rpart",
  caller = "rpart",
  feature_types = c("logical", "integer", "numeric", "factor", "ordered"),
  predict_types = c("response", "prob"),
  properties = c("importance", "missings", "multiclass", "selected_features", "twoclass", "weights"),
  references = TRUE,
  gh_name = "RaphaelS1"
)

The full documentation for the function arguments is in mlr3extralearners::create_learner. In this example we are doing the following:

  1. pkg = "." - Set the package root to the current directory.
  2. classname = "Rpart" - Set the R6 class name to LearnerClassifRpart (the "Classif" part is derived from type, described below).
  3. algorithm = "decision tree" - Create the title as “Classification Decision Tree Learner”, where “Classification” is determined automatically from type and “Learner” is added for all learners.
  4. type = "classif" - Setting the learner as a classification learner, automatically filling the title, class name, id (“classif.rpart”) and task type.
  5. key = "rpart" - Used with type to create the unique ID of the learner, classif.rpart.
  6. package = "rpart" - Sets the package that implements the learner. This fills in things like the training function (along with caller) and the man field.
  7. caller = "rpart" - Tells the .train function and the description which function is called to run the algorithm. Together with package, this automatically fills in rpart::rpart.
  8. feature_types = c("logical", "integer", "numeric", "factor", "ordered") - Sets the type of features that can be handled by the learner. See meta information.
  9. predict_types = c("response", "prob"), - Sets the possible prediction types as response (deterministic) and prob (probabilistic). See meta information.
  10. properties = c("importance", "missings", "multiclass", "selected_features", "twoclass", "weights") - Sets the properties that are handled by the learner. Because "importance" is included, a public method called importance will be created that must be filled in manually. See meta information.
  11. references = TRUE - Tells the template to add a “references” tag that must be filled manually.
  12. gh_name = "RaphaelS1" - Fills the “author” tag with my GitHub handle. This is required as it identifies the maintainer of the learner.

The sections below demonstrate what happens after the function has been run and the files that are created.
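
As a rough illustration of the naming scheme used throughout this section, the following base R sketch derives the learner ID, the R6 class name and the generated file names from the create_learner arguments. The helper derive_names is hypothetical and is not part of mlr3extralearners; it only mirrors the names quoted in the text.

```r
# Hypothetical helper (not part of mlr3extralearners): mirrors the naming
# scheme described in this section.
derive_names = function(type, key, package, classname) {
  # capitalise the first letter of the type, e.g. "classif" -> "Classif"
  type_cap = paste0(toupper(substring(type, 1, 1)), substring(type, 2))
  list(
    id = paste(type, key, sep = "."),
    class = paste0("Learner", type_cap, classname),
    learner_file = sprintf("learner_%s_%s_%s.R", package, type, key),
    test_file = sprintf("test_%s_%s_%s.R", package, type, key),
    paramtest_file = sprintf("test_paramtest_%s_%s_%s.R", package, type, key)
  )
}

nms = derive_names(type = "classif", key = "rpart", package = "rpart", classname = "Rpart")
str(nms)
```

Because the automated tests depend on these names, computing them once and comparing against the generated files is a cheap sanity check.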

6.1.2 learner_package_type_key.R

The first script to complete after running create_learner is the file with the form learner_package_type_key.R, in our case learner_rpart_classif_rpart.R. This name must not be changed because the automated tests rely on a strict naming scheme. For our example, the resulting script looks like this:

#' @title Classification Decision Tree Learner
#' @author RaphaelS1
#' @name mlr_learners_classif.rpart
#'
#' @template class_learner
#' @templateVar id classif.rpart
#' @templateVar caller rpart
#'
#' @references
#' <FIXME - DELETE THIS AND LINE ABOVE IF OMITTED>
#'
#' @template seealso_learner
#' @template example
#' @export
LearnerClassifRpart = R6Class("LearnerClassifRpart",
  inherit = LearnerClassif,

  public = list(
    #' @description
    #' Creates a new instance of this [R6][R6::R6Class] class.
    initialize = function() {
      # FIXME - MANUALLY ADD PARAM_SET BELOW AND THEN DELETE THIS LINE
      ps = <param_set>

      # FIXME - MANUALLY UPDATE PARAM VALUES BELOW IF APPLICABLE THEN DELETE THIS LINE.
      # OTHERWISE DELETE THIS AND LINE BELOW.
      ps$values = list(<param_vals>)

      super$initialize(
        id = "classif.rpart",
        packages = "rpart",
        feature_types = c("logical", "integer", "numeric", "factor", "ordered"),
        predict_types = c("response", "prob"),
        param_set = ps,
        properties = c("importance", "missings", "multiclass", "selected_features", "twoclass", "weights"),
        man = "mlr3extralearners::mlr_learners_classif.rpart"
      )
    },

    # FIXME - ADD IMPORTANCE METHOD HERE AND DELETE THIS LINE.
    # <See LearnerRegrRandomForest for an example>
    #' @description
    #' The importance scores are extracted from the slot <FIXME>.
    #' @return Named `numeric()`.
    importance = function() { }

  ),

  private = list(

    .train = function(task) {
      pars = self$param_set$get_values(tags = "train")

      # set column names to ensure consistency in fit and predict
      self$state$feature_names = task$feature_names

      # FIXME - <Create objects for the train call
      # <At least "data" and "formula" are required>
      formula = task$formula()
      data = task$data()

      # FIXME - <here is space for some custom adjustments before proceeding to the
      # train call. Check other learners for what can be done here>

      # use the mlr3misc::invoke function (it's similar to do.call())
      mlr3misc::invoke(rpart::rpart,
                       formula = formula,
                       data = data,
                       .args = pars)
    },

    .predict = function(task) {
      # get parameters with tag "predict"
      pars = self$param_set$get_values(tags = "predict")
      # get newdata
      newdata = task$data(cols = task$feature_names)

      pred = mlr3misc::invoke(predict, self$model, newdata = newdata,
                              type = type, .args = pars)

      # FIXME - ADD PREDICTIONS TO LIST BELOW
      list(...)
    }
  )
)

.extralrns_dict$add("classif.rpart", LearnerClassifRpart)

Now we have to do the following (from top to bottom):

  1. Fill in the references under “references” and delete the tag that starts “FIXME”
  2. Replace <param_set> with a parameter set
  3. Optionally change default values for parameters in <param_vals>
  4. As we included “importance” in properties we have to add a function to the public method importance
  5. Fill in the private .train method, which takes a (filtered) Task and returns a model.
  6. Fill in the private .predict method, which operates on the model in self$model (stored during $train()) and a (differently subsetted) Task to return a named list of predictions.

6.1.3 Meta-information

In the constructor (initialize()) the constructor of the super class (e.g. LearnerClassif) is called with meta information about the learner which should be constructed. This includes:

  • id: The ID of the new learner. Usually consists of <type>.<algorithm>, for example: "classif.rpart".
  • packages: The upstream package name of the implemented learner.
  • param_set: A set of hyperparameters and their descriptions provided as a paradox::ParamSet. For each hyperparameter the appropriate parameter class needs to be chosen (see the ParamSet section below).
  • predict_types: Set of predict types the learner is able to handle. These differ depending on the type of the learner. See mlr_reflections$learner_predict_types for the full list of predict types supported by mlr3.
    • LearnerClassif
      • response: Only predicts a class label for each observation in the test set.
      • prob: Also predicts the posterior probability for each class for each observation in the test set.
    • LearnerRegr
      • response: Only predicts a numeric response for each observation in the test set.
      • se: Also predicts the standard error for each value of response for each observation in the test set.
  • feature_types: Set of feature types the learner is able to handle. See mlr_reflections$task_feature_types for feature types supported by mlr3.
  • properties: Set of properties of the learner. See mlr_reflections$learner_properties for the full list of properties supported by mlr3. Possible properties include:
    • "twoclass": The learner works on binary classification problems.
    • "multiclass": The learner works on multi-class classification problems.
    • "missings": The learner can natively handle missing values.
    • "weights": The learner can work on tasks which have observation weights / case weights.
    • "parallel": The learner supports internal parallelization in some way. Currently not used, this is an experimental property.
    • "importance": The learner supports extracting importance values for features. If this property is set, you must also implement a public method importance() to retrieve the importance values from the model.
    • "selected_features": The learner supports extracting the features which were used. If this property is set, you must also implement a public method selected_features() to retrieve the set of used features from the model.
  • man: The roxygen identifier of the learner. This is used within the $help() method of the super class to open the help page of the learner.
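
To make the "importance" and "selected_features" properties more concrete, here is a minimal standalone sketch of what the two public methods could compute for an rpart model. It assumes the fitted rpart object stores importance scores in variable.importance and split variables in frame$var. The function names importance_sketch and selected_features_sketch are hypothetical; inside a learner the model would come from self$model.

```r
library(rpart)

# Standalone stand-in for self$model inside the learner.
model = rpart(Species ~ ., data = iris)

# Sketch of an importance() body: named numeric vector, largest first.
importance_sketch = function(model) {
  imp = model$variable.importance
  if (is.null(imp)) {
    # a tree without splits stores no importance values
    return(stats::setNames(numeric(0), character(0)))
  }
  sort(imp, decreasing = TRUE)
}

# Sketch of a selected_features() body: variables used in at least one split.
selected_features_sketch = function(model) {
  setdiff(unique(as.character(model$frame$var)), "<leaf>")
}

imp = importance_sketch(model)
sel = selected_features_sketch(model)
```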

6.1.4 ParamSet

The param_set is the set of hyperparameters used in model training and predicting. It is given as a paradox::ParamSet, which consists of a list of hyperparameters, each with a class that matches the hyperparameter type (for example ParamInt, ParamDbl and ParamLgl in the example below).

For classif.rpart the following replaces <param_set> above:

      ps = ParamSet$new(list(
        ParamInt$new(id = "minsplit", default = 20L, lower = 1L, tags = "train"),
        ParamInt$new(id = "minbucket", lower = 1L, tags = "train"),
        ParamDbl$new(id = "cp", default = 0.01, lower = 0, upper = 1, tags = "train"),
        ParamInt$new(id = "maxcompete", default = 4L, lower = 0L, tags = "train"),
        ParamInt$new(id = "maxsurrogate", default = 5L, lower = 0L, tags = "train"),
        ParamInt$new(id = "maxdepth", default = 30L, lower = 1L, upper = 30L, tags = "train"),
        ParamInt$new(id = "usesurrogate", default = 2L, lower = 0L, upper = 2L, tags = "train"),
        ParamInt$new(id = "surrogatestyle", default = 0L, lower = 0L, upper = 1L, tags = "train"),
        ParamInt$new(id = "xval", default = 0L, lower = 0L, tags = "train"),
        ParamLgl$new(id = "keep_model", default = FALSE, tags = "train")
      ))
      ps$values = list(xval = 0L)

You should read through the learner documentation to find the full list of available parameters. Just looking at some of these in this example:

  • "cp" is numeric, has a feasible range of [0,1] and defaults to 0.01. The parameter is used during "train".
  • "xval" is integer, has a lower bound of 0 and a default of 0. The parameter is used during "train".
  • "keep_model" is logical with a default of FALSE and is used during "train".

In some rare cases you may want to change the default parameter values. You can do this by passing a list to <param_vals> in the template script above. We have done this for "classif.rpart", where the default for xval is changed to 0. Note that the default in the ParamSet is recorded as our changed default (0), and not the original (10). It is strongly recommended to change defaults only if absolutely required. When this is the case, add the following to the learner documentation:

#' @section Custom mlr3 defaults:
#' - `<parameter>`:
#'   - Actual default: <value>
#'   - Adjusted default: <value>
#'   - Reason for change: <text>
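
One way to cross-check upstream defaults, such as the original xval = 10 mentioned above, is to inspect the formal arguments of the upstream function. The sketch below assumes the rpart package is installed and that the hyperparameters of interest are arguments of rpart::rpart.control(). The variable names upstream, learner_values and changed are only illustrative.

```r
# Look up upstream defaults for a few arguments of rpart::rpart.control().
upstream = formals(rpart::rpart.control)[c("minsplit", "cp", "maxdepth", "xval")]
str(upstream)

# Values set by the learner, taken from the ParamSet example in this section.
learner_values = list(xval = 0L)

# Parameters whose learner value differs from the upstream default and
# therefore should be documented under "Custom mlr3 defaults".
differs = vapply(names(learner_values), function(id) {
  !isTRUE(all.equal(upstream[[id]], learner_values[[id]]))
}, logical(1))
changed = names(learner_values)[differs]
changed
```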

6.1.5 Train function

Let’s talk about the .train() method. The train function takes a Task as input and must return a model.

Let’s say we want to translate the following call of rpart::rpart() into code that can be used inside the .train() method.

First, we write something down that works completely without mlr3:

data = iris
model = rpart::rpart(Species ~ ., data = data, xval = 0)

We need to pass the formula notation Species ~ ., the data and the hyperparameters. To get the hyperparameters, we call self$param_set$get_values() and query all parameters that are used during "train".

The dataset is extracted from the Task.

Last, we call the upstream function rpart::rpart() with the data and pass all hyperparameters via argument .args using the mlr3misc::invoke() function. The latter is simply an optimized version of do.call() that we use within the mlr3 ecosystem.

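.train = function(task) {
  pars = self$param_set$get_values(tags = "train")
  # store the column order so that .predict() can reuse it
  self$state$feature_names = task$feature_names
  formula = task$formula()
  data = task$data()
  mlr3misc::invoke(rpart::rpart,
                   formula = formula,
                   data = data,
                   .args = pars)
}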
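
To see the argument handling outside of a learner, the sketch below reproduces the same call with base R's do.call() as a stand-in for mlr3misc::invoke(). The pars list is a hypothetical replacement for the values returned by self$param_set$get_values(tags = "train").

```r
library(rpart)

# Stand-in for self$param_set$get_values(tags = "train")
pars = list(xval = 0L, minsplit = 10L)

# Base R equivalent of the invoke() call: fixed arguments plus the
# hyperparameter list spliced in as further named arguments.
model = do.call(rpart, c(list(formula = Species ~ ., data = iris), pars))

# The hyperparameters end up in the control settings of the fitted model.
model$control$minsplit
```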

6.1.6 Predict function

The internal predict method .predict() operates on a Task and on the fitted model that was created by the earlier train() call and stored in self$model.

The method must return a named list of predictions, which mlr3 uses to build the Prediction object. We proceed analogously to what we did in the previous section. We start with a version without any mlr3 objects and continue to replace objects until we have reached the desired interface:

# inputs:
library(mlr3) # provides tsk()
task = tsk("iris")
self = list(model = rpart::rpart(task$formula(), data = task$data()))

data = iris
response = predict(self$model, newdata = data, type = "class")
prob = predict(self$model, newdata = data, type = "prob")

The rpart:::predict.rpart() function predicts class labels if argument type is set to "class", and class probabilities if set to "prob".

Next, we transition from data to a task again and construct a list with the return type requested by the user, which is stored in the $predict_type slot of a learner class. Note that the task is automatically passed to the prediction object, so all you need to do is return the predictions! Make sure the list names are identical to the predict types.

The final .predict() method is below. We could omit the pars line as there are no parameters with the "predict" tag, but we keep it here to be consistent:

.predict = function(task) {
  pars = self$param_set$get_values(tags = "predict")
  # get newdata and ensure same ordering in train and predict
  newdata = task$data(cols = self$state$feature_names)
  if (self$predict_type == "response") {
    response = mlr3misc::invoke(predict,
                            self$model,
                            newdata = newdata,
                            type = "class",
                            .args = pars)

    return(list(response = response))
  } else {
    prob = mlr3misc::invoke(predict,
                            self$model,
                            newdata = newdata,
                            type = "prob",
                            .args = pars)
    return(list(prob = prob))
  }
}

Note that you cannot rely on the column order of the data returned by task$data(), as the order of columns may differ from the order of the columns during $.train. The newdata line ensures the ordering is the same by using the order saved in self$state$feature_names during $.train. Don’t delete either of these lines!
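
The mechanics of reordering columns and returning a list named after the predict type can be tried without any mlr3 objects. The following is only a sketch: predict_sketch is a hypothetical stand-in for .predict(), and feature_names plays the role of self$state$feature_names.

```r
library(rpart)

feature_names = c("Petal.Length", "Petal.Width", "Sepal.Length", "Sepal.Width")
model = rpart(Species ~ ., data = iris[, c("Species", feature_names)])

# Hypothetical stand-in for .predict(): reorder columns, then return a list
# named after the requested predict type.
predict_sketch = function(model, newdata, feature_names, predict_type) {
  newdata = newdata[, feature_names, drop = FALSE]
  if (predict_type == "response") {
    list(response = predict(model, newdata = newdata, type = "class"))
  } else {
    list(prob = predict(model, newdata = newdata, type = "prob"))
  }
}

# New data arrives with the columns in a different order.
shuffled = iris[, rev(feature_names)]
out_response = predict_sketch(model, shuffled, feature_names, "response")
out_prob = predict_sketch(model, shuffled, feature_names, "prob")
```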

6.1.7 Control objects/functions of learners

Some learners rely on a “control” object/function such as glmnet::glmnet.control(). Accounting for such depends on how the underlying package works:

  • If the package forwards the control parameters via ... and makes it possible to just pass control parameters as additional parameters directly to the train call, there is no need to distinguish both "train" and "control" parameters. Both can be tagged with “train” in the ParamSet and just be handed over as shown previously.
  • If the control parameters need to be passed via a separate argument, the parameters should also be tagged accordingly in the ParamSet. Afterwards they can be queried via their tag and passed separately to mlr3misc::invoke(). See example below.

control_pars = mlr3misc::invoke(<package>::<function>,
  .args = self$param_set$get_values(tags = "control"))

train_pars = self$param_set$get_values(tags = "train")

mlr3misc::invoke([...], .args = train_pars, control = control_pars)
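
As a concrete, runnable instance of the second case, the sketch below uses rpart::rpart.control() and base R's do.call() in place of mlr3misc::invoke(). The two parameter lists are hypothetical stand-ins for the values a learner would obtain from get_values(tags = "control") and get_values(tags = "train").

```r
library(rpart)

# Stand-ins for the tagged parameter values of a learner.
control_pars = list(maxdepth = 2L, cp = 0.05)
train_pars = list(model = TRUE)

# Build the control object first, then hand it over via its own argument.
control = do.call(rpart.control, control_pars)
fit = do.call(rpart, c(
  list(formula = Species ~ ., data = iris, control = control),
  train_pars
))

fit$control$maxdepth
```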

6.1.8 Testing the learner

Once your learner is created, you are ready to start testing if it works. There are three types of tests: manual, unit and parameter.

6.1.8.1 Train and Predict

For a bare-bone check you can just try to run a simple train() call locally.

task = tsk("iris") # assuming a Classif learner
lrn = lrn("classif.rpart")
lrn$train(task)
p = lrn$predict(task)
p$confusion

If it runs without erroring, that’s a very good start!

6.1.8.2 Autotest

To ensure that your learner is able to handle all kinds of different properties and feature types, we have written an “autotest” that checks the learner for different combinations of such.

The “autotest” setup is generated automatically by create_learner and will open after running the function. It has a name of the form test_package_type_key.R, in our case test_rpart_classif_rpart.R. This name must not be changed because the automated tests rely on a strict naming scheme. In our example this creates the following script, for which no changes are required to pass (assuming the learner was correctly created):

install_learners("classif.rpart")

test_that("autotest", {
  learner = LearnerClassifRpart$new()
  expect_learner(learner)
  result = run_autotest(learner)
  expect_true(result, info = result$error)
})

Some learners have required parameters. For these, you need to set values for the required parameters after construction so that the learner can be run in the first place.

You can also exclude some specific test arrangements within the “autotest” via the argument exclude in the run_autotest() function. Currently the run_autotest() function lives in inst/testthat of the mlr3 package and still lacks documentation. This should change in the near future.

To finally run the test suite, call devtools::test() or hit CTRL + Shift + T if you are using RStudio.

6.1.8.3 Checking Parameters

Some learners have a high number of parameters and it is easy to miss some during the creation of a new learner. In addition, if the maintainer of the upstream package changes something with respect to the arguments of the algorithm, the learner is in danger of breaking. Also, new arguments could be added upstream and manually checking for new additions all the time is tedious.

Therefore we have written a “Parameter Check” that runs for every learner asynchronously to the R CMD Check of the package itself. This “Parameter Check” compares the parameters of the mlr3 ParamSet against all arguments available in the upstream function that is called during $train() and $predict(). Again, the file is automatically created and opened by create_learner. It is named test_paramtest_package_type_key.R, so in our example test_paramtest_rpart_classif_rpart.R.

The test comes with an exclude argument that should be used to exclude and explain why certain arguments of the upstream function are not within the ParamSet of the mlr3learner. This will likely be required for all learners as common arguments like x, target or data are handled by the mlr3 interface and are therefore not included within the ParamSet.

However, there might be more parameters that need to be excluded, for example:

  • Type dependent parameters, i.e. parameters that only apply for classification or regression learners.
  • Parameters that are actually deprecated by the upstream package and which were therefore not included in the mlr3 ParamSet.

All excluded parameters should have a comment justifying their exclusion.

In our example, the final paramtest script looks like:

library(mlr3extralearners)
install_learners("classif.rpart")

test_that("classif.rpart train", {
  learner = lrn("classif.rpart")
  fun = rpart::rpart
  exclude = c(
    "formula", # handled internally
    "model", # handled internally
    "data", # handled internally
    "weights", # handled by task
    "subset", # handled by task
    "na.action", # handled internally
    "method", # handled internally
    "x", # handled internally
    "y", # handled internally
    "parms", # handled internally
    "control", # handled internally
    "cost" # handled internally
  )

  ParamTest = run_paramtest(learner, fun, exclude)
  expect_true(ParamTest, info = paste0(
    "Missing parameters:",
    paste0("- '", ParamTest$missing, "'", collapse = "\n")))
})

test_that("classif.rpart predict", {
  learner = lrn("classif.rpart")
  fun = rpart:::predict.rpart
  exclude = c(
    "object", # handled internally
    "newdata", # handled internally
    "type", # handled internally
    "na.action" # handled internally
  )

  ParamTest = run_paramtest(learner, fun, exclude)
  expect_true(ParamTest, info = paste0(
    "Missing parameters:",
    paste0("- '", ParamTest$missing, "'", collapse = "\n")))
})
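
The comparison behind the parameter check can be approximated in base R. The sketch below is not the real run_paramtest(); missing_params is a hypothetical helper that lists upstream arguments which are neither in the ParamSet nor excluded. It assumes that the rpart hyperparameters are arguments of rpart::rpart.control(), which rpart::rpart() reaches through its ... argument.

```r
# Hypothetical helper (not the real run_paramtest()): upstream arguments that
# are neither in the ParamSet nor explicitly excluded.
missing_params = function(funs, param_ids, exclude = character()) {
  upstream = unique(unlist(lapply(funs, function(f) names(formals(f)))))
  setdiff(upstream, c(param_ids, exclude, "..."))
}

# Parameter ids from the ParamSet example in this section.
param_ids = c("minsplit", "minbucket", "cp", "maxcompete", "maxsurrogate",
  "maxdepth", "usesurrogate", "surrogatestyle", "xval", "keep_model")

# Arguments excluded in the paramtest script above.
exclude = c("formula", "model", "data", "weights", "subset", "na.action",
  "method", "x", "y", "parms", "control", "cost")

funs = list(rpart::rpart, rpart::rpart.control)
missing_all = missing_params(funs, param_ids, exclude)

# Dropping a parameter from the ParamSet makes the check report it.
missing_without_cp = missing_params(funs, setdiff(param_ids, "cp"), exclude)
```

Running such a comparison locally before opening a pull request can make it easier to decide which arguments belong in exclude.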

6.1.9 Package Cleaning

Once all tests are passing, run the following functions to ensure that the package remains clean and tidy:

  1. devtools::document(roclets = c('rd', 'collate', 'namespace'))
  2. If you haven’t done this before run: remotes::install_github('pat-s/styler@mlr-style')
  3. styler::style_pkg(style = styler::mlr_style)
  4. usethis::use_tidy_description()
  5. lintr::lint_package()

Please fix any errors indicated by lintr before creating a pull request. Finally, ensure that all FIXME comments are resolved and deleted in the generated files.

You are now ready to add your learner to the mlr3 ecosystem! Simply open a pull request with the new learner template and complete the checklist in there. Once the pull request is approved and merged, your learner will automatically appear on the package website.

6.1.10 Thanks and Maintenance

Thank you for contributing to the mlr3 ecosystem!

When you created the learner you would have given your GitHub handle, meaning that you are now listed as the learner author and maintainer. This means that if the learner breaks it is your responsibility to fix the learner - you can view the status of your learner here.

6.1.11 Learner FAQ

Question 1

How to deal with Parameters which have no default?

Answer

If the learner does not work without providing a value, set a reasonable default in param_set$values, add tag "required" to the parameter and document your default properly.

Question 2

Where to add the upstream package in the DESCRIPTION file?

Answer

Add it to the “Imports” section. This will install the upstream package during the installation of the mlr3learner if it has not yet been installed by the user.

Question 3

How to handle arguments from external “control” functions such as glmnet::glmnet.control()?

Answer

See “Control objects/functions of learners”.

Question 4

How to document if my learner uses a custom default value that differs from the default of the upstream package?

Answer

If you set a custom default for the mlr3learner that does not match the one of the upstream package (think twice if this is really needed!), add this information to the help page of the respective learner.

You can use the following skeleton for this:

#' @section Custom mlr3 defaults:
#' - `<parameter>`:
#'   - Actual default: <value>
#'   - Adjusted default: <value>
#'   - Reason for change: <text>

Question 5

When should the "required" tag be used when defining Params and what is its purpose?

Answer

The "required" tag should be used when the following conditions are met:

  • The upstream function cannot be run without setting this parameter, i.e. it would throw an error.
  • The parameter has no default in the upstream function.

In mlr3 we follow the principle that every learner should be constructable without setting custom parameters. Therefore, if a parameter has no default in the upstream function, a custom value is usually set for this parameter in the mlr3learner (remember to document such changes in the help page of the learner).

Even though this practice ensures that no parameter is unset in an mlr3learner and partially removes the usefulness of the "required" tag, the tag is still useful in the following scenario:

If a user sets custom parameters after construction of the learner:

lrn = lrn("<id>")
lrn$param_set$values = list("<param>" = <value>)

Here, all parameters besides the ones set in the list would be unset. See paradox::ParamSet for more information. If a parameter is tagged as "required" in the ParamSet, the call above would error and prompt the user that required parameters are missing.