7.1 Extending mlr3

7.1.1 Learners

Here, we show step by step how to create a custom LearnerClassif.

Preferably, check out our template package for new learners. Alternatively, here is a template snippet for a new classification learner:
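A minimal skeleton sketch, assuming a recent mlr3 release. In such releases, the internal hooks that this section calls train_internal() and predict_internal() are the private methods .train() and .predict(), and hyperparameters are declared with paradox::ps(). The id "classif.yourlearner" and all field values are placeholders.

```r
library(mlr3)
library(paradox)
library(R6)

# Skeleton of a new classification learner; id and fields are placeholders.
LearnerClassifYourLearner = R6Class("LearnerClassifYourLearner",
  inherit = LearnerClassif,
  public = list(
    initialize = function() {
      super$initialize(
        id = "classif.yourlearner",
        param_set = ps(),                   # no hyperparameters in a first draft
        predict_types = "response",
        feature_types = c("numeric", "integer"),
        properties = c("twoclass", "multiclass"),
        packages = character()
      )
    }
  ),
  private = list(
    # counterpart of train_internal(): receives a Task, returns a model
    .train = function(task) {
      stop("not implemented yet")
    },
    # counterpart of predict_internal(): uses self$model and a Task
    .predict = function(task) {
      stop("not implemented yet")
    }
  )
)

learner = LearnerClassifYourLearner$new()
print(learner$id)
```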

In the first line of the template, we create a new R6 class named "LearnerClassifYourLearner". The next line determines the parent class: as we want to create a classification learner, we inherit from LearnerClassif.

A learner consists of three parts:

  1. Meta information about the learner
  2. A train_internal() function which takes a (filtered) TaskClassif and returns a model
  3. A predict_internal() function which operates on the model in self$model (stored during $train()) and a (differently subsetted) TaskClassif to return a named list of predictions.

Meta-information

In the constructor function initialize(), the constructor of the superclass LearnerClassif is called with meta information about the learner we want to construct. This includes:

  • id: The id of the new learner.
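  • param_set: A set of hyperparameters and their description, provided as paradox::ParamSet. It is perfectly fine to add no parameters here for a first draft. For each hyperparameter you want to add, you have to select the appropriate class:
    • paradox::ParamLgl for scalar logical hyperparameters.
    • paradox::ParamInt for scalar integer hyperparameters.
    • paradox::ParamDbl for scalar numeric hyperparameters.
    • paradox::ParamFct for scalar categorical hyperparameters.
    • paradox::ParamUty for untyped hyperparameters, e.g. arbitrary R objects.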
  • predict_types: Set of predict types the learner is capable of. These differ depending on the type of the learner.
    • LearnerClassif
      • response: Only predicts a class label for each observation in the test set.
      • prob: Also predicts the posterior probability for each class for each observation in the test set.
    • LearnerRegr
      • response: Only predicts a numeric response for each observation in the test set.
      • se: Also predicts the standard error for each value of response for each observation in the test set.
  • feature_types: Set of feature types the learner can handle. See mlr_reflections$task_feature_types for feature types supported by mlr3.
  • properties: Set of properties of the learner. Possible properties include:
    • "twoclass": The learner works on binary classification problems.
    • "multiclass": The learner works on multi-class classification problems.
    • "missings": The learner can natively handle missing values.
    • "weights": The learner can work on tasks which have observation weights / case weights.
    • "parallel": The learner can be parallelized, e.g. via threading.
    • "importance": The learner supports extracting importance values for features. If this property is set, you must also implement a public method importance() to retrieve the importance values from the model.
    • "selected_features": The learner supports extracting the features which were used. If this property is set, you must also implement a public method selected_features() to retrieve the set of used features from the model.
  • packages: Set of required packages to run the learner.

For a simplified rpart::rpart(), the initialization could look like this:
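A sketch of such an initialization, assuming a recent mlr3/paradox release: p_dbl() and p_int() are the shorthand constructors used there in place of ParamDbl$new() and ParamInt$new(). The class name LearnerClassifRpartSimple and the id "classif.rpart_simple" are our own choices for illustration.

```r
library(mlr3)
library(paradox)
library(R6)

LearnerClassifRpartSimple = R6Class("LearnerClassifRpartSimple",
  inherit = LearnerClassif,
  public = list(
    initialize = function() {
      param_set = ps(
        # complexity parameter: numeric in [0, 1], rpart default 0.01
        cp = p_dbl(lower = 0, upper = 1, default = 0.01, tags = "train"),
        # number of cross-validations: integer >= 0, rpart default 10
        xval = p_int(lower = 0L, default = 10L, tags = "train")
      )
      # keep the documented default (10) but set the value to 0 to save time
      param_set$values = list(xval = 0L)

      super$initialize(
        id = "classif.rpart_simple",
        param_set = param_set,
        predict_types = c("response", "prob"),
        feature_types = c("logical", "integer", "numeric", "factor", "ordered"),
        properties = c("twoclass", "multiclass", "missings"),
        packages = "rpart"
      )
    }
  )
)

learner = LearnerClassifRpartSimple$new()
print(learner$param_set$values)
```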

We have only specified a small subset of the available hyperparameters:

  • The complexity "cp" is numeric, its feasible range is [0,1], it defaults to 0.01 and the parameter is used during "train".
  • The number of cross-validations "xval" is integer, its lower bound is 0, its default is 10 and the parameter is also used during "train". Note that we have changed the value here from 10 to 0 to save some computation time. This is not done by setting a different default in ParamInt$new(), but instead by setting the parameter value in the ParamSet.

Train function

We continue to adapt the template for an rpart::rpart() learner and now tackle the train_internal() function. The train function takes a Task as input and must return a model, which can be an arbitrary R object. First, we write something down that works completely without mlr3:
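A first version using only rpart and the iris data frame could look like this; the formula hard-codes the target column Species:

```r
library(rpart)

data = iris
# xval = 0 turns off rpart's internal cross-validation
model = rpart::rpart(Species ~ ., data = data, xval = 0)
print(model)
```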

In the next step, we replace the data frame data with a Task:
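A sketch using the iris task shipped with mlr3 (retrieved via tsk("iris")); task$data() returns a data.table with the target and all feature columns, while the formula still hard-codes the target:

```r
library(mlr3)

task = tsk("iris")
model = rpart::rpart(Species ~ ., data = task$data(), xval = 0)
```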

The target variable "Species" is still hard-coded and specific to the task. This is unnecessary, as the information about the target variable is stored in the task:
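For example, the target name and a formula built from it can be queried from the task object:

```r
library(mlr3)

task = tsk("iris")
print(task$target_names)   # name of the target column
print(task$feature_names)  # names of the feature columns
print(task$formula())      # target ~ all features
```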

We can adapt our code accordingly:
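A sketch of the adapted call, in which no column name is hard-coded anymore; only xval is still fixed:

```r
library(mlr3)

task = tsk("iris")
# formula and data both come from the task
model = rpart::rpart(task$formula(), data = task$data(), xval = 0)
```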

The last thing missing is the handling of hyperparameters. Instead of the hard-coded xval, we query the hyperparameter settings from the Learner itself.

To illustrate this, we quickly construct the tree learner from the mlr3 package and use the method get_values() of its ParamSet to retrieve all set hyperparameters with tag "train".
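A short sketch with the classif.rpart learner shipped with mlr3, which initializes xval to 0; only values that are actually set and tagged "train" are returned:

```r
library(mlr3)

learner = lrn("classif.rpart")
pars = learner$param_set$get_values(tags = "train")
str(pars)
```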

To pass all hyperparameters down to the model fitting function, we recommend using either do.call() or the function mlr3misc::invoke().
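Both variants could look like the following sketch, which also wraps the steps into a function. train_rpart() is a hypothetical helper that receives the learner as an explicit argument; inside the final learner class, the same body uses self instead.

```r
library(mlr3)

task = tsk("iris")
learner = lrn("classif.rpart")
pars = learner$param_set$get_values(tags = "train")

# variant 1: mlr3misc::invoke() with the hyperparameters passed via .args
model1 = mlr3misc::invoke(rpart::rpart,
  formula = task$formula(), data = task$data(), .args = pars)

# variant 2: base R do.call() with one combined argument list
model2 = do.call(rpart::rpart,
  c(list(formula = task$formula(), data = task$data()), pars))

# hypothetical helper: the complete train step wrapped in a function
train_rpart = function(learner, task) {
  pars = learner$param_set$get_values(tags = "train")
  mlr3misc::invoke(rpart::rpart,
    formula = task$formula(), data = task$data(), .args = pars)
}
model3 = train_rpart(learner, task)
```

mlr3misc::invoke() keeps the fixed arguments and the hyperparameter list apart, which avoids building the combined list by hand as in the do.call() variant.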

In the final learner, self will of course reference the learner itself. In the last step, we wrap everything in a function.

Predict function

The internal predict function predict_internal() operates on a Task and on the model stored in self$model during train(). The return value is a Prediction object. We proceed analogously to the section on the train function: we start with a version without any mlr3 objects and replace objects step by step until we reach the desired interface:

The rpart::predict.rpart() function predicts class labels if the argument type is set to "class", and class probabilities if it is set to "prob".
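A sketch of the prediction step without any mlr3 objects; the three query rows of iris are chosen arbitrarily for illustration:

```r
library(rpart)

model = rpart::rpart(Species ~ ., data = iris, xval = 0)
newdata = iris[c(1, 51, 101), setdiff(names(iris), "Species")]

# factor of predicted class labels
response = predict(model, newdata = newdata, type = "class")
# matrix of class probabilities, one column per class
prob = predict(model, newdata = newdata, type = "prob")
```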

Next, we transition from data to a task again and construct a proper PredictionClassif object to return. Additionally, as we do not want to run the prediction twice, we differentiate what type of prediction is requested by querying the set predict type of the learner. The complete predict_internal function looks like this:
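A sketch of this logic as a standalone function. predict_rpart() is a hypothetical helper that receives the learner and the model explicitly, where the class method would use self and self$model. To avoid a second call to predict(), the "prob" branch derives the class labels from the probability matrix.

```r
library(mlr3)

task = tsk("iris")
learner = lrn("classif.rpart")
model = rpart::rpart(task$formula(), data = task$data(), xval = 0)

# hypothetical helper: compute only the requested predict type
predict_rpart = function(learner, model, task) {
  newdata = task$data(cols = task$feature_names)
  if (learner$predict_type == "response") {
    response = predict(model, newdata = newdata, type = "class")
    PredictionClassif$new(task = task, response = response)
  } else {
    prob = predict(model, newdata = newdata, type = "prob")
    # class with the highest probability per row
    response = colnames(prob)[max.col(prob, ties.method = "first")]
    PredictionClassif$new(task = task, response = response, prob = prob)
  }
}

prediction_response = predict_rpart(learner, model, task)
learner$predict_type = "prob"
prediction = predict_rpart(learner, model, task)
acc = prediction$score(msr("classif.acc"))
print(acc)
```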

Note that if the learner needed to handle hyperparameters during the predict step, we would proceed analogously to the train() step and use self$param_set$get_values(tags = "predict") in combination with mlr3misc::invoke().

Also note that you cannot rely on the column order of the data returned by task$data(): the order of columns may differ from the order of the columns during $train(). Make sure that your learner accesses columns by name, not by position (as some algorithms with a matrix interface do). You may have to restore the order manually here; see "classif.svm" for an example.

Final learner

To run some basic tests:
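A sketch of the final learner assembled from the steps above, followed by a few basic checks. It assumes a recent mlr3 release, in which the private .predict() may return a named list (here with response and prob) that mlr3 converts into a PredictionClassif; the class name and id are our own choices.

```r
library(mlr3)
library(paradox)
library(R6)

LearnerClassifRpartSimple = R6Class("LearnerClassifRpartSimple",
  inherit = LearnerClassif,
  public = list(
    initialize = function() {
      param_set = ps(
        cp = p_dbl(lower = 0, upper = 1, default = 0.01, tags = "train"),
        xval = p_int(lower = 0L, default = 10L, tags = "train")
      )
      param_set$values = list(xval = 0L)
      super$initialize(
        id = "classif.rpart_simple",
        param_set = param_set,
        predict_types = c("response", "prob"),
        feature_types = c("logical", "integer", "numeric", "factor", "ordered"),
        properties = c("twoclass", "multiclass", "missings"),
        packages = "rpart"
      )
    }
  ),
  private = list(
    .train = function(task) {
      pars = self$param_set$get_values(tags = "train")
      mlr3misc::invoke(rpart::rpart,
        formula = task$formula(), data = task$data(), .args = pars)
    },
    .predict = function(task) {
      # rpart matches columns by name, so their order does not matter here
      newdata = task$data(cols = task$feature_names)
      if (self$predict_type == "response") {
        response = predict(self$model, newdata = newdata, type = "class")
        list(response = as.character(response))
      } else {
        prob = predict(self$model, newdata = newdata, type = "prob")
        response = colnames(prob)[max.col(prob, ties.method = "first")]
        list(response = response, prob = prob)
      }
    }
  )
)

# basic tests: train on a random subset, predict on the remaining rows
task = tsk("iris")
learner = LearnerClassifRpartSimple$new()
set.seed(1)
train_ids = sample(task$row_ids, 100)
test_ids = setdiff(task$row_ids, train_ids)

learner$train(task, row_ids = train_ids)
prediction = learner$predict(task, row_ids = test_ids)
acc = prediction$score(msr("classif.acc"))
print(acc)

learner$predict_type = "prob"
prediction_prob = learner$predict(task, row_ids = test_ids)
print(head(prediction_prob$prob))
```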

To run a larger set of automatic tests, you may source some auxiliary scripts from the unit tests of mlr3: