3.4 Train & Predict

In this chapter, we explain how tasks and learners can be used to train a model and make predictions on a new dataset.

The concept is demonstrated on a supervised classification task using the iris dataset and the rpart learner (a classification tree).

Additionally, this chapter includes the following use cases:

  • Functional Data Analysis using (WIP)
  • Regression Analysis using (WIP)
  • Survival Analysis using (WIP)
  • Spatial Analysis using (WIP)

3.4.1 Basic concept

Creating Task and Learner Objects

The first step is to generate the following mlr3 objects from the task dictionary and the learner dictionary, respectively:

  1. The classification task
  2. A learner for the classification tree

Setting up the train/test splits of the data
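The two objects listed above, together with the 80/20 split described in this subsection, can be sketched as follows. This is a minimal example assuming the mlr3 shorthands tsk() and lrn() (for mlr_tasks$get() and mlr_learners$get()); sampling from task$row_ids and the seed value are illustrative choices, not taken from the text.

```r
library(mlr3)

# Retrieve the task and learner from their dictionaries;
# tsk() and lrn() are shorthands for mlr_tasks$get() and mlr_learners$get()
task = tsk("iris")
learner = lrn("classif.rpart")

# Draw 80% of the row IDs for training and keep the remaining 20% for testing
set.seed(1)  # arbitrary seed, only for reproducibility
train_set = sample(task$row_ids, round(0.8 * task$nrow))
test_set = setdiff(task$row_ids, train_set)
```

Sampling from task$row_ids rather than 1:n keeps the split valid even for tasks whose row IDs are not consecutive integers.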

It is common to train on the majority of the data. Here we use 80% of all available observations for training and predict on the remaining 20%. For this purpose, we create two index vectors, train_set and test_set.

Predicting

After the model has been trained, we use the remaining part of the data for prediction. Recall that we initially split the data into train_set and test_set.
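Prediction presupposes a fitted model, so the sketch below shows both steps: $train() fits the learner on the rows in train_set, and $predict() is then applied to the rows in test_set. The task, learner, and split are recreated so the block runs on its own; the seed is an arbitrary assumption.

```r
library(mlr3)

task = tsk("iris")
learner = lrn("classif.rpart")

# 80/20 split of the row IDs (arbitrary seed)
set.seed(1)
train_set = sample(task$row_ids, round(0.8 * task$nrow))
test_set = setdiff(task$row_ids, train_set)

# Fit the classification tree on the training rows only;
# the fitted rpart model is stored in learner$model
learner$train(task, row_ids = train_set)
print(learner$model)

# Predict the held-out rows with the trained model
prediction = learner$predict(task, row_ids = test_set)
print(prediction)
```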

The $predict() method of the Learner returns a Prediction object. More precisely, because the learner is specialized for classification, a LearnerClassif returns a PredictionClassif object.
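The class relationship can be checked directly. In this minimal sketch (same assumed seed and split as in the other examples), the prediction returned by a LearnerClassif is a PredictionClassif that also inherits from the generic Prediction class.

```r
library(mlr3)

task = tsk("iris")
learner = lrn("classif.rpart")
set.seed(1)  # arbitrary seed
train_set = sample(task$row_ids, round(0.8 * task$nrow))
test_set = setdiff(task$row_ids, train_set)
learner$train(task, row_ids = train_set)

# A classification learner returns a classification-specific prediction
prediction = learner$predict(task, row_ids = test_set)
print(class(prediction))

# PredictionClassif is a subclass of the general Prediction class
print(inherits(prediction, "Prediction"))
```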

A prediction object holds the row IDs of the test data, the corresponding true labels of the target column, and the corresponding predictions. The simplest way to extract this information is to convert the object with as.data.table().
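A minimal sketch of extracting the stored information, assuming the same illustrative split as before: as.data.table() yields one row per test observation with the columns row_ids, truth, and response, and the $confusion field cross-tabulates predicted against true labels.

```r
library(mlr3)
library(data.table)

task = tsk("iris")
learner = lrn("classif.rpart")
set.seed(1)  # arbitrary seed
train_set = sample(task$row_ids, round(0.8 * task$nrow))
test_set = setdiff(task$row_ids, train_set)
learner$train(task, row_ids = train_set)
prediction = learner$predict(task, row_ids = test_set)

# One row per test observation: row ID, true label and predicted label
tab = as.data.table(prediction)
print(head(tab))

# Confusion matrix: predicted labels (rows) against true labels (columns)
print(prediction$confusion)
```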

For classification, you can also extract the confusion matrix from the $confusion field of the prediction.

Performance assessment

The last step of modeling is usually the performance assessment, where we choose a performance measure that quantifies the quality of the predictions by comparing the predicted labels with the true labels. Available measures are stored in the dictionary mlr_measures.
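Because mlr_measures is a dictionary, its contents can be listed and individual measures retrieved by key. A minimal sketch; the regular expression used to filter classification measures is an illustrative choice.

```r
library(mlr3)

# All registered performance measures
print(mlr_measures$keys())

# Only the keys starting with "classif." (classification measures)
print(mlr_measures$keys("^classif\\."))

# Retrieve a single measure object by key; msr() is the shorthand
acc = msr("classif.acc")
print(acc)
```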

We select the accuracy (classif.acc) and pass it to the $score() method of the Prediction object.
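The following sketch scores the test-set predictions with classif.acc and, as a sanity check, recomputes the accuracy by hand as the share of rows whose predicted label equals the true label. The split and seed are illustrative assumptions.

```r
library(mlr3)

task = tsk("iris")
learner = lrn("classif.rpart")
set.seed(1)  # arbitrary seed
train_set = sample(task$row_ids, round(0.8 * task$nrow))
test_set = setdiff(task$row_ids, train_set)
learner$train(task, row_ids = train_set)
prediction = learner$predict(task, row_ids = test_set)

# Accuracy via the measure object; returns a named numeric vector
acc = prediction$score(msr("classif.acc"))
print(acc)

# The same value computed by hand from the truth and response vectors
print(mean(prediction$truth == prediction$response))
```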

Note that if no measure is specified, $score() defaults to the classification error (classif.ce) for classification and to the mean squared error (regr.mse) for regression.
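The default can be observed by calling $score() without arguments. This minimal sketch (same assumed split) also scores several measures at once via msrs(); for a classification prediction, accuracy and classification error sum to one.

```r
library(mlr3)

task = tsk("iris")
learner = lrn("classif.rpart")
set.seed(1)  # arbitrary seed
train_set = sample(task$row_ids, round(0.8 * task$nrow))
test_set = setdiff(task$row_ids, train_set)
learner$train(task, row_ids = train_set)
prediction = learner$predict(task, row_ids = test_set)

# No measure given: falls back to the default for classification (classif.ce)
default_score = prediction$score()
print(default_score)

# Several measures at once; msrs() builds a list of measure objects
both = prediction$score(msrs(c("classif.acc", "classif.ce")))
print(both)
```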