library("mlr3extralearners")
create_learner(
pkg = ".",
classname = "Rpart",
algorithm = "decision tree",
type = "classif",
key = "rpart",
package = "rpart",
caller = "rpart",
feature_types = c("logical", "integer", "numeric", "factor", "ordered"),
predict_types = c("response", "prob"),
properties = c("importance", "missings", "multiclass",
"selected_features", "twoclass", "weights"),
references = TRUE,
gh_name = "RaphaelS1"
)
26 Adding new Learners
Here, we show how to create a custom mlr3learner step-by-step using mlr3extralearners::create_learner().
If you plan on creating a pull request to the mlr-org, it is strongly recommended that you first open a learner request issue to discuss the learner you want to implement. This allows us to discuss the purpose and necessity of the learner before you put in the real work!
This section explains how an mlr3learner is constructed and how to troubleshoot issues. See the Learner FAQ subsection for help.
Summary of steps for adding a new learner
- Check that the learner does not already exist here.
- Fork, clone and load mlr3extralearners.
- Run mlr3extralearners::create_learner().
- Add the learner param_set.
- Manually add .train and .predict private methods to the learner.
- If applicable, add importance and oob_error public methods to the learner.
- If applicable, add references to the learner.
- Check that the unit tests and parameter tests pass (these are created automatically).
- Run the cleaning functions.
- Open a pull request with the new learner template.
(Do not copy/paste the code shown in this section. Use create_learner() to start.)
26.1 Setting up mlr3extralearners
In order to use the mlr3extralearners::create_learner function, you must have a local copy of the mlr3extralearners repository and must specify the correct path to the package. To do so, follow these steps:
- Fork the repository.
- Clone a local copy of your forked repository.
Then do one of the following:
- Open a new R session, call library("mlr3extralearners") (install it if you haven't already), and then run mlr3extralearners::create_learner with the pkg argument set to the path (the folder location) of the package directory.
- Open a new R session, set your working directory to your newly cloned repository, run devtools::load_all, and then run mlr3extralearners::create_learner(), leaving pkg = ".".
- In your newly cloned repository, open the R project, which will automatically set your working directory, run devtools::load_all, and then run mlr3extralearners::create_learner(), leaving pkg = ".".
We recommend the last option. It is also important that you are familiar with these three devtools commands:
- devtools::document() - Generates the roxygen documentation for your new learner.
- devtools::load_all() - Loads all functions from mlr3extralearners locally, including hidden helper functions.
- devtools::check() - Checks that the package still passes all tests locally.
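A typical session, once the repository is cloned and opened as an R project, might look like the following minimal sketch (the create_learner() call is abbreviated here; the full call is shown at the start of this chapter):

devtools::load_all()    # load mlr3extralearners, including hidden helpers
# mlr3extralearners::create_learner(pkg = ".", ...)  # generate the templates
# ... edit the generated learner and test files ...
devtools::document()    # regenerate the roxygen documentation
devtools::check()       # run all package checks locally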
26.2 Calling create_learner
The learner classif.rpart will be used as a running example throughout this section; the full create_learner() call is shown at the start of this chapter. The full documentation for the function arguments is in mlr3extralearners::create_learner. In this example we are doing the following:
- pkg = "." - Sets the package root to the current directory (assumes mlr3extralearners is already set as the working directory).
- classname = "Rpart" - Sets the R6 class name to LearnerClassifRpart ("Classif" comes from type below).
- algorithm = "decision tree" - Creates the title as "Classification Decision Tree Learner", where "Classification" is determined automatically from type and "Learner" is added for all learners.
- type = "classif" - Sets the learner as a classification learner, automatically filling the title, class name, id ("classif.rpart") and task type.
- key = "rpart" - Used with type to create the unique ID of the learner, classif.rpart.
- package = "rpart" - Sets the package in which the learner is implemented; this fills in things like the training function (along with caller) and the man field.
- caller = "rpart" - Tells the .train function, and the description, which function is called to run the algorithm; together with package this automatically fills in rpart::rpart.
- feature_types = c("logical", "integer", "numeric", "factor", "ordered") - Sets the type of features that can be handled by the learner. See meta information.
- predict_types = c("response", "prob") - Sets the possible prediction types as response (deterministic) and prob (probabilistic). See meta information.
- properties = c("importance", "missings", "multiclass", "selected_features", "twoclass", "weights") - Sets the properties that are handled by the learner; by including "importance", a public method called importance will be created that must be filled in manually. See meta information.
- references = TRUE - Tells the template to add a "references" tag that must be filled in manually.
- gh_name = "RaphaelS1" - Fills the "author" tag with my GitHub handle; this is required as it identifies the maintainer of the learner.
The sections below demonstrate what happens after the function has been run and the files that are created.
26.3 learner_package_type_key.R
The first script to complete after running create_learner() is the file of the form learner_package_type_key.R, which in our case is learner_rpart_classif_rpart.R.
This name must not be changed, as the automated tests rely on a strict naming scheme.
For our example, the resulting script looks like this:
#' @title Classification Decision Tree Learner
#' @author RaphaelS1
#' @name mlr_learners_classif.rpart
#'
#' @template class_learner
#' @templateVar id classif.rpart
#' @templateVar caller rpart
#'
#' @references
#' <FIXME - DELETE THIS AND LINE ABOVE IF OMITTED>
#'
#' @template seealso_learner
#' @template example
#' @export
= R6Class("LearnerClassifRpart",
LearnerClassifRpart inherit = LearnerClassif,
public = list(
#' @description
#' Creates a new instance of this [R6][R6::R6Class] class.
initialize = function() {
# FIXME - MANUALLY ADD PARAM_SET BELOW AND THEN DELETE THIS LINE
= <param_set>
ps
# FIXME - MANUALLY UPDATE PARAM VALUES BELOW IF APPLICABLE THEN DELETE THIS LINE.
# OTHERWISE DELETE THIS AND LINE BELOW.
$values = list(<param_vals>)
ps
$initialize(
superid = "classif.rpart",
packages = "rpart",
feature_types = c("logical", "integer", "numeric", "factor", "ordered"),
predict_types = c("response", "prob"),
param_set = ps,
properties = c("importance", "missings", "multiclass", "selected_features", "twoclass", "weights"),
man = "mlr3extralearners::mlr_learners_classif.rpart"
)
},
# FIXME - ADD IMPORTANCE METHOD HERE AND DELETE THIS LINE.
# <See LearnerRegrRandomForest for an example>
#' @description
#' The importance scores are extracted from the slot <FIXME>.
#' @return Named `numeric()`.
importance = function() { }
),
private = list(
.train = function(task) {
= self$param_set$get_values(tags = "train")
pars
# set column names to ensure consistency in fit and predict
$state$feature_names = task$feature_names
self
# FIXME - <Create objects for the train call
# <At least "data" and "formula" are required>
= task$formula()
formula = task$data()
data
# FIXME - <here is space for some custom adjustments before proceeding to the
# train call. Check other learners for what can be done here>
# use the mlr3misc::invoke function (it's similar to do.call())
::invoke(rpart::rpart,
mlr3miscformula = formula,
data = data,
.args = pars)
},
.predict = function(task) {
# get parameters with tag "predict"
= self$param_set$get_values(tags = "predict")
pars # get newdata
= task$data(cols = task$feature_names)
newdata
= mlr3misc::invoke(predict, self$model, newdata = newdata,
pred type = type, .args = pars)
# FIXME - ADD PREDICTIONS TO LIST BELOW
list(...)
}
)
)
$add("classif.rpart", LearnerClassifRpart) .extralrns_dict
Now we have to do the following (from top to bottom):
- Fill in the references under the "references" tag and delete the line starting with "FIXME".
- Replace <param_set> with a parameter set.
- Optionally change default values for parameters in <param_vals>.
- As we included "importance" in properties, we have to add a function to the public method importance.
- Fill in the private .train method, which takes a (filtered) Task and returns a model.
- Fill in the private .predict method, which operates on the model in self$model (stored during $train()) and a (differently subsetted) Task to return a named list of predictions.
26.4 Meta-information
In the constructor (initialize()) the constructor of the super class (e.g. LearnerClassif) is called with meta information about the learner to be constructed. This includes:
- id: The ID of the new learner. Usually consists of <type>.<algorithm>, for example: "classif.rpart".
- packages: The upstream package name of the implemented learner.
- param_set: A set of hyperparameters and their descriptions, provided as a paradox::ParamSet. For each hyperparameter the appropriate class needs to be chosen. When using the paradox::ps shortcut, a short constructor of the form p_*** can be used:
  - paradox::ParamLgl / paradox::p_lgl for scalar logical hyperparameters.
  - paradox::ParamInt / paradox::p_int for scalar integer hyperparameters.
  - paradox::ParamDbl / paradox::p_dbl for scalar numeric hyperparameters.
  - paradox::ParamFct / paradox::p_fct for scalar factor hyperparameters (this includes characters).
  - paradox::ParamUty / paradox::p_uty for everything else (e.g. vector or list parameters).
- predict_types: Set of predict types the learner is able to handle. These differ depending on the type of the learner. See mlr_reflections$learner_predict_types for the full list of predict types supported by mlr3.
  - LearnerClassif
    - response: Only predicts a class label for each observation in the test set.
    - prob: Also predicts the posterior probability for each class for each observation in the test set.
  - LearnerRegr
    - response: Only predicts a numeric response for each observation in the test set.
    - se: Also predicts the standard error for each value of response for each observation in the test set.
- feature_types: Set of feature types the learner is able to handle. See mlr_reflections$task_feature_types for the feature types supported by mlr3.
- properties: Set of properties of the learner. See mlr_reflections$learner_properties for the full list of properties supported by mlr3. Possible properties include:
  - "twoclass": The learner works on binary classification problems.
  - "multiclass": The learner works on multi-class classification problems.
  - "missings": The learner can natively handle missing values.
  - "weights": The learner can work on tasks which have observation weights / case weights.
  - "parallel": The learner supports internal parallelization in some way. Currently not used; this is an experimental property.
  - "importance": The learner supports extracting importance values for features. If this property is set, you must also implement a public method importance() to retrieve the importance values from the model.
  - "selected_features": The learner supports extracting the features which were used. If this property is set, you must also implement a public method selected_features() to retrieve the set of used features from the model.
- man: The roxygen identifier of the learner. This is used within the $help() method of the super class to open the help page of the learner.
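If you want to check which values are currently supported, the mlr_reflections object mentioned above can be inspected directly in an R session, for example:

library("mlr3")
mlr_reflections$learner_predict_types   # predict types per task type
mlr_reflections$task_feature_types      # supported feature types
mlr_reflections$learner_properties      # supported learner properties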
26.5 ParamSet
The param_set is the set of hyperparameters used in model training and predicting; it is given as a paradox::ParamSet. The set consists of a list of hyperparameters, each with a specific class for the hyperparameter type (see above).
For classif.rpart the following replaces <param_set> above:
ps = ParamSet$new(list(
ParamInt$new(id = "minsplit", default = 20L, lower = 1L, tags = "train"),
ParamInt$new(id = "minbucket", lower = 1L, tags = "train"),
ParamDbl$new(id = "cp", default = 0.01, lower = 0, upper = 1, tags = "train"),
ParamInt$new(id = "maxcompete", default = 4L, lower = 0L, tags = "train"),
ParamInt$new(id = "maxsurrogate", default = 5L, lower = 0L, tags = "train"),
ParamInt$new(id = "maxdepth", default = 30L, lower = 1L, upper = 30L, tags = "train"),
ParamInt$new(id = "usesurrogate", default = 2L, lower = 0L, upper = 2L, tags = "train"),
ParamInt$new(id = "surrogatestyle", default = 0L, lower = 0L, upper = 1L, tags = "train"),
ParamInt$new(id = "xval", default = 0L, lower = 0L, tags = "train"),
ParamLgl$new(id = "keep_model", default = FALSE, tags = "train")
))
ps$values = list(xval = 0L)
Within mlr3 packages we suggest sticking to the lengthier definition for consistency; however, the <param_set> can be written more concisely using the paradox::ps shortcut:
ps = ps(
minsplit = p_int(lower = 1L, default = 20L, tags = "train"),
minbucket = p_int(lower = 1L, tags = "train"),
cp = p_dbl(lower = 0, upper = 1, default = 0.01, tags = "train"),
maxcompete = p_int(lower = 0L, default = 4L, tags = "train"),
maxsurrogate = p_int(lower = 0L, default = 5L, tags = "train"),
maxdepth = p_int(lower = 1L, upper = 30L, default = 30L, tags = "train"),
usesurrogate = p_int(lower = 0L, upper = 2L, default = 2L, tags = "train"),
surrogatestyle = p_int(lower = 0L, upper = 1L, default = 0L, tags = "train"),
xval = p_int(lower = 0L, default = 0L, tags = "train"),
keep_model = p_lgl(default = FALSE, tags = "train")
)
You should read through the learner documentation to find the full list of available parameters. Looking at just a few of these in this example:
- "cp" is numeric, has a feasible range of [0, 1] and defaults to 0.01. The parameter is used during "train".
- "xval" is integer, has a lower bound of 0 and a default of 0. The parameter is used during "train".
- "keep_model" is logical with a default of FALSE and is used during "train".
In some rare cases you may want to change the default parameter values. You can do this by passing a list to <param_vals> in the template script above. You can see we have done this for "classif.rpart", where the default for xval is changed to 0. Note that the default in the ParamSet is recorded as our changed default (0), and not the original (10). It is strongly recommended to only change the defaults if absolutely required; when this is the case, add the following to the learner documentation:
#' @section Custom mlr3 defaults:
#' - `<parameter>`:
#' - Actual default: <value>
#' - Adjusted default: <value>
#' - Reason for change: <text>
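For the xval change above, the filled-in section could look roughly like this (the wording of the reason is illustrative, not quoted from the package):

#' @section Custom mlr3 defaults:
#' - `xval`:
#'   - Actual default: 10
#'   - Adjusted default: 0
#'   - Reason for change: resampling is handled by mlr3, so internal
#'     cross-validation is disabled by default to save runtime.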
26.6 Train function
Let’s talk about the .train()
method. The train function takes a Task
as input and must return a model.
Let’s say we want to translate the following call of rpart::rpart()
into code that can be used inside the .train()
method.
First, we write something down that works completely without mlr3:
data = iris
model = rpart::rpart(Species ~ ., data = iris, xval = 0)
We need to pass the formula notation Species ~ ., the data and the hyperparameters. To get the hyperparameters, we call self$param_set$get_values() and query all parameters that are used during "train".
The dataset is extracted from the Task.
Last, we call the upstream function rpart::rpart() with the data and pass all hyperparameters via the argument .args using the mlr3misc::invoke() function. The latter is simply an optimized version of do.call() that we use within the mlr3 ecosystem.
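Putting these pieces together, the completed .train() method for our example could look like the sketch below; it mirrors the template shown earlier with the FIXME lines resolved (the version shipped in mlr3extralearners may differ in details):

.train = function(task) {
  # hyperparameters tagged "train"
  pars = self$param_set$get_values(tags = "train")

  # remember the column order so that .predict() can reproduce it
  self$state$feature_names = task$feature_names

  # extract formula and data from the task
  formula = task$formula()
  data = task$data()

  # fit and return the rpart model
  mlr3misc::invoke(rpart::rpart,
    formula = formula,
    data = data,
    .args = pars)
}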
26.7 Predict function
The internal predict method .predict() also operates on a Task, as well as on the fitted model that was created by the previous $train() call and stored in self$model.
The return value is a Prediction
object. We proceed analogously to what we did in the previous section. We start with a version without any mlr3 objects and continue to replace objects until we have reached the desired interface:
The rpart::predict.rpart() function predicts class labels if the argument type is set to "class", and class probabilities if it is set to "prob".
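For instance, with the plain rpart model fitted in the previous section, the two calls look like this (a minimal non-mlr3 sketch):

predict(model, newdata = iris, type = "class")  # class labels
predict(model, newdata = iris, type = "prob")   # posterior probabilities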
Next, we transition from data to a Task again and construct a list with the return type requested by the user; the requested type is stored in the $predict_type slot of the learner class. Note that the Task is automatically passed to the prediction object, so all you need to do is return the predictions! Make sure the list names are identical to the task predict types.
The final .predict() method is shown below. We could omit the pars line, as there are no parameters with the "predict" tag, but we keep it here for consistency:
.predict = function(task) {
pars = self$param_set$get_values(tags = "predict")
# get newdata and ensure same ordering in train and predict
newdata = task$data(cols = self$state$feature_names)
if (self$predict_type == "response") {
response = mlr3misc::invoke(predict,
self$model,
newdata = newdata,
type = "class",
.args = pars)
return(list(response = response))
} else {
prob = mlr3misc::invoke(predict,
self$model,
newdata = newdata,
type = "prob",
.args = pars)
return(list(prob = prob))
}
}
Note that you cannot rely on the column order of the data returned by task$data(), as the order of columns may differ from the order during .train(). The newdata line ensures the ordering is the same by using the column order saved in .train(); do not delete either of these lines!
26.8 Control objects/functions of learners
Some learners rely on a "control" object/function such as glmnet::glmnet.control(). How to account for these depends on how the underlying package works:
- If the package forwards the control parameters via ... and makes it possible to pass control parameters as additional arguments directly to the train call, there is no need to distinguish "train" and "control" parameters. Both can be tagged with "train" in the ParamSet and handed over as shown previously.
- If the control parameters need to be passed via a separate argument, the parameters should be tagged accordingly in the ParamSet. Afterwards they can be queried via their tag and passed separately to mlr3misc::invoke(). See the example below.
control_pars = mlr3misc::invoke(<package>::<function>,
  .args = self$param_set$get_values(tags = "control"))

train_pars = self$param_set$get_values(tags = "train")

mlr3misc::invoke([...], .args = train_pars, control = control_pars)
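As a concrete illustration, a .train() method for a hypothetical learner built on party::ctree(), whose tuning parameters are collected by party::ctree_control() and passed via a separate controls argument, might look roughly like this (a sketch under these assumptions, not the actual mlr3extralearners implementation):

.train = function(task) {
  # parameters tagged "control" go into the control object,
  # parameters tagged "train" go to the fitting function directly
  control_pars = self$param_set$get_values(tags = "control")
  train_pars = self$param_set$get_values(tags = "train")

  control = mlr3misc::invoke(party::ctree_control, .args = control_pars)

  mlr3misc::invoke(party::ctree,
    formula = task$formula(),
    data = task$data(),
    controls = control,
    .args = train_pars)
}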
26.9 Testing the learner
Once your learner is created, you are ready to start testing whether it works. There are three types of tests: manual, unit and parameter tests.
26.9.1 Train and Predict
For a bare-bones check you can just try to run a simple train() call locally.
task = tsk("iris") # assuming a Classif learner
lrn = lrn("classif.rpart")
lrn$train(task)
p = lrn$predict(task)
p$confusion
If it runs without erroring, that’s a very good start!
26.9.2 Autotest
To ensure that your learner is able to handle all kinds of different properties and feature types, we have written an "autotest" that checks the learner for different combinations of these.
The "autotest" setup is generated automatically by create_learner() and will open after the function runs. It will have a name of the form test_package_type_key.R, which in our case is test_rpart_classif_rpart.R. This name must not be changed, as the automated tests rely on a strict naming scheme. In our example this creates the following script, for which no changes are required to pass (assuming the learner was correctly created):
install_learners("classif.rpart")
test_that("autotest", {
learner = LearnerClassifRpart$new()
expect_learner(learner)
result = run_autotest(learner)
expect_true(result, info = result$error)
})
For some learners that have required parameters, it is necessary to set values for these parameters after construction so that the learner can be run in the first place.
You can also exclude specific test arrangements within the "autotest" via the argument exclude of the run_autotest() function. Currently the run_autotest() function lives in inst/testthat of the mlr3 package and still lacks documentation. This should change in the near future.
To finally run the test suite, call devtools::test() or hit CTRL + Shift + T if you are using RStudio.
26.9.3 Checking Parameters
Some learners have a large number of parameters and it is easy to miss some during the creation of a new learner. In addition, if the maintainer of the upstream package changes something with respect to the arguments of the algorithm, the learner is in danger of breaking. Also, new arguments could be added upstream, and manually checking for new additions all the time is tedious.
Therefore we have written a "Parameter Check" that runs for every learner asynchronously to the R CMD check of the package itself. This "Parameter Check" compares the parameters of the mlr3 ParamSet against all arguments available in the upstream function that is called during $train() and $predict(). Again the file is automatically created and opened by create_learner(); it is named like test_paramtest_package_type_key.R, so in our example test_paramtest_rpart_classif_rpart.R.
The test comes with an exclude argument that should be used to exclude certain arguments of the upstream function from the check and to explain why they are not within the ParamSet of the mlr3learner. This will likely be required for all learners, as common arguments like x, target or data are handled by the mlr3 interface and are therefore not included within the ParamSet.
However, there might be more parameters that need to be excluded, for example:
- Type dependent parameters, i.e. parameters that only apply for classification or regression learners.
- Parameters that are actually deprecated by the upstream package and which were therefore not included in the mlr3 ParamSet.
All excluded parameters should have a comment justifying their exclusion.
In our example, the final paramtest script looks like:
library("mlr3extralearners")
install_learners("classif.rpart")
test_that("classif.rpart train", {
learner = lrn("classif.rpart")
fun = rpart::rpart
exclude = c(
"formula", # handled internally
"model", # handled internally
"data", # handled internally
"weights", # handled by task
"subset", # handled by task
"na.action", # handled internally
"method", # handled internally
"x", # handled internally
"y", # handled internally
"parms", # handled internally
"control", # handled internally
"cost" # handled internally
)
ParamTest = run_paramtest(learner, fun, exclude)
expect_true(ParamTest, info = paste0(
"Missing parameters:",
paste0("- '", ParamTest$missing, "'", collapse = "
")))
})
test_that("classif.rpart predict", {
learner = lrn("classif.rpart")
fun = rpart:::predict.rpart
exclude = c(
"object", # handled internally
"newdata", # handled internally
"type", # handled internally
"na.action" # handled internally
)
ParamTest = run_paramtest(learner, fun, exclude)
expect_true(ParamTest, info = paste0(
"Missing parameters:",
paste0("- '", ParamTest$missing, "'", collapse = "
")))
})
26.10 Package Cleaning
Once all tests are passing, run the following functions to ensure that the package remains clean and tidy:
- devtools::document(roclets = c('rd', 'collate', 'namespace'))
- styler::style_pkg(style = styler::mlr_style) (if you haven't done this before, first run remotes::install_github('pat-s/styler@mlr-style'))
- usethis::use_tidy_description()
- lintr::lint_package()
Please fix any errors indicated by lintr before creating a pull request. Finally, ensure that all FIXME comments in the generated files are resolved and deleted.
You are now ready to add your learner to the mlr3 ecosystem! Simply open a pull request with the new learner template and complete the checklist in there. Once the pull request is approved and merged, your learner will automatically appear on the package website.
26.11 Thanks and Maintenance
Thank you for contributing to the mlr3 ecosystem!
When you created the learner you provided your GitHub handle, which means you are now listed as the learner's author and maintainer. If the learner breaks, it is your responsibility to fix it; you can view the status of your learner here.
26.12 Learner FAQ
Question 1
How to deal with Parameters which have no default?
Answer
If the learner does not work without a value being provided, set a reasonable default in param_set$values, add the tag "required" to the parameter, and document your default properly.
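For example, a hypothetical required parameter with no upstream default could be handled like this (the parameter name and value are illustrative only):

ps = ps(
  ntrees = p_int(lower = 1L, tags = c("train", "required"))
)
ps$values = list(ntrees = 100L)  # custom default; document it in the help page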
Question 2
Where should the upstream package be added in the DESCRIPTION file?
Answer
Add it to the "Suggests" section.
Question 3
How to handle arguments from external "control" functions such as glmnet::glmnet.control()?
Answer
See “Control objects/functions of learners”.
Question 4
How to document if my learner uses a custom default value that differs from the default of the upstream package?
Answer
If you set a custom default for the mlr3learner that does not match the upstream package's default (think twice whether this is really needed!), add this information to the help page of the respective learner.
You can use the following skeleton for this:
#' @section Custom mlr3 defaults:
#' - `<parameter>`:
#' - Actual default: <value>
#' - Adjusted default: <value>
#' - Reason for change: <text>
Question 5
When should the "required"
tag be used when defining Params and what is its purpose?
Answer
The "required"
tag should be used when the following conditions are met:
- The upstream function cannot be run without setting this parameter, i.e. it would throw an error.
- The parameter has no default in the upstream function.
In mlr3 we follow the principle that every learner should be constructable without setting custom parameters. Therefore, if a parameter has no default in the upstream function, a custom value is usually set for this parameter in the mlr3learner (remember to document such changes in the help page of the learner).
Even though this practice ensures that no parameter is unset in an mlr3learner and partially removes the usefulness of the "required"
tag, the tag is still useful in the following scenario:
If a user sets custom parameters after construction of the learner:

lrn = lrn("<id>")
lrn$param_set$values = list("<param>" = <value>)
Here, all parameters besides the ones set in the list would be unset. See paradox::ParamSet
for more information. If a parameter is tagged as "required"
in the ParamSet, the call above would error and prompt the user that required parameters are missing.
Question 6
What is this error when I run devtools::load_all()?
> devtools::load_all(".")
Loading mlr3extralearners:
Warning messagein unloadNamespace() for 'mlr3extralearners', details:
.onUnload failed : vapply(hooks, function(x) environment(x)$pkgname, NA_character_)
call: values must be length 1,
errorFUN(X[[1]]) result is length 0 but
Answer
This is not an error but a warning and you can safely ignore it!