2.4 Train and Predict

In this section, we explain how tasks and learners can be used to train a model and predict on new data. The concept is demonstrated on a supervised classification task using the sonar dataset and the rpart learner (classification tree).

Training a learner means fitting a model to a given data set. Subsequently, we want to predict the target value for new observations. These predictions are compared to the ground truth values to assess the quality of the model. In sum, the goal of training and predicting is to evaluate the predictive power of different models.

2.4.1 Creating Task and Learner Objects

The first step is to generate the following mlr3 objects from the task dictionary and the learner dictionary, respectively:

  1. The classification task:
task = tsk("sonar")
  2. A learner for the classification tree:
learner = lrn("classif.rpart")

2.4.2 Setting up the train/test splits of the data

It is common to train on a majority of the data. Here we use 80% of all available observations for training and predict on the remaining 20%. For this purpose, we create two index vectors:

train_set = sample(task$nrow, 0.8 * task$nrow)
test_set = setdiff(seq_len(task$nrow), train_set)
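
Because sample() draws the training indices at random, the split (and hence all numbers shown below) varies between runs. A reproducible variant of the same split, using an arbitrary seed value:

set.seed(123)  # any fixed value makes the split reproducible
train_set = sample(task$nrow, 0.8 * task$nrow)
test_set = setdiff(seq_len(task$nrow), train_set)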

2.4.3 Training the learner

The field $model stores the model that is produced in the training step. Before the $train() method is called on a learner object, this field is NULL:

learner$model
## NULL

Next, the classification tree is trained using the train set of the sonar task, applying the $train() method of the Learner:

learner$train(task, row_ids = train_set)

This operation modifies the learner in-place. We can now access the stored model via the field $model:

print(learner$model)
## n= 166 
## 
## node), split, n, loss, yval, (yprob)
##       * denotes terminal node
## 
##  1) root 166 82 M (0.50602 0.49398)  
##    2) V49>=0.04525 81 21 M (0.74074 0.25926)  
##      4) V37< 0.4634 53  5 M (0.90566 0.09434) *
##      5) V37>=0.4634 28 12 R (0.42857 0.57143)  
##       10) V17< 0.3785 17  5 M (0.70588 0.29412) *
##       11) V17>=0.3785 11  0 R (0.00000 1.00000) *
##    3) V49< 0.04525 85 24 R (0.28235 0.71765)  
##      6) V21>=0.6573 39 18 M (0.53846 0.46154)  
##       12) V42>=0.0851 30  9 M (0.70000 0.30000)  
##         24) V51>=0.0108 16  1 M (0.93750 0.06250) *
##         25) V51< 0.0108 14  6 R (0.42857 0.57143) *
##       13) V42< 0.0851 9  0 R (0.00000 1.00000) *
##      7) V21< 0.6573 46  3 R (0.06522 0.93478) *
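
Since classif.rpart wraps the rpart package, the stored model is an ordinary rpart object, so rpart's own tooling can be applied to it. A minimal sketch, assuming the rpart package is available (it is required by this learner anyway):

library(rpart)                      # provides the plot() and text() methods for rpart objects
class(learner$model)                # "rpart"
plot(learner$model, margin = 0.05)  # draw the tree structure
text(learner$model, use.n = TRUE)   # add split labels and node counts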

2.4.4 Predicting

After the model has been trained, we use the remaining part of the data for prediction. Remember that we initially split the data into train_set and test_set.

prediction = learner$predict(task, row_ids = test_set)
print(prediction)
## <PredictionClassif> for 42 observations:
##     row_id truth response
##          9     R        M
##         12     R        M
##         13     R        R
## ---                      
##        200     M        R
##        207     M        M
##        208     M        R

The $predict() method of the Learner returns a Prediction object. More precisely, since the learner is specialized for classification, a LearnerClassif returns a PredictionClassif object.

A prediction object holds the row ids of the test data, the respective true labels of the target column and the corresponding predictions. The simplest way to extract this information is by converting the object to a data.table:

head(as.data.table(prediction))
##    row_id truth response
## 1:      9     R        M
## 2:     12     R        M
## 3:     13     R        R
## 4:     24     R        M
## 5:     25     R        M
## 6:     28     R        R
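
The converted data.table also makes it easy to inspect individual mistakes, for example all test observations for which the predicted class disagrees with the truth. A small sketch; which rows are affected depends on your random split:

library(data.table)           # for the filtering syntax below
tab = as.data.table(prediction)
head(tab[truth != response])  # misclassified test observations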

For classification, you can also extract the confusion matrix:

prediction$confusion
##         truth
## response  M  R
##        M 17  7
##        R 10  8
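
As a quick cross-check, the overall accuracy can be recomputed by hand as the share of observations on the diagonal of the confusion matrix; this anticipates the measures introduced in the performance assessment section below:

sum(diag(prediction$confusion)) / sum(prediction$confusion)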

2.4.5 Changing the Predict Type

Classification learners default to predicting the class label. However, many classifiers can additionally tell you how confident they are about the predicted label by providing posterior probabilities. To switch to predicting these probabilities, the predict_type field of a LearnerClassif must be changed from "response" to "prob":

learner$predict_type = "prob"

# re-fit the model
learner$train(task, row_ids = train_set)

# rebuild prediction object
prediction = learner$predict(task, row_ids = test_set)

The prediction object now contains probabilities for all class labels:

# data.table conversion
head(as.data.table(prediction))
##    row_id truth response  prob.M  prob.R
## 1:      9     R        M 0.90566 0.09434
## 2:     12     R        M 0.93750 0.06250
## 3:     13     R        R 0.00000 1.00000
## 4:     24     R        M 0.70588 0.29412
## 5:     25     R        M 0.70588 0.29412
## 6:     28     R        R 0.06522 0.93478

# directly access the predicted labels:
head(prediction$response)
## [1] M M R M M R
## Levels: M R

# directly access the matrix of probabilities:
head(prediction$prob)
##            M       R
## [1,] 0.90566 0.09434
## [2,] 0.93750 0.06250
## [3,] 0.00000 1.00000
## [4,] 0.70588 0.29412
## [5,] 0.70588 0.29412
## [6,] 0.06522 0.93478
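
Having the probability matrix at hand also allows changing the threshold that turns probabilities into class labels. For binary tasks, PredictionClassif offers $set_threshold() for this; the threshold of 0.6 below is an arbitrary example value, and we work on a clone so the prediction object used in the remainder of this section stays untouched:

pred_thr = prediction$clone()
pred_thr$set_threshold(0.6)  # require a probability of at least 0.6 for the positive class
head(pred_thr$response)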

Analogously to predicting probabilities, many regression learners support the extraction of standard error estimates by setting the predict type to "se".
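
A minimal sketch for the regression case, assuming the mlr3learners package for the regr.lm learner (which supports the "se" predict type) and the built-in mtcars task:

library(mlr3learners)
learner_lm = lrn("regr.lm", predict_type = "se")
task_mtcars = tsk("mtcars")
learner_lm$train(task_mtcars)
# the resulting PredictionRegr additionally stores standard errors in $se
head(learner_lm$predict(task_mtcars)$se)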

2.4.6 Plotting Predictions

Analogously to plotting tasks, mlr3viz provides an autoplot() method for predictions. All available plot types are listed on the manual pages of autoplot.PredictionClassif(), autoplot.PredictionRegr() and the other prediction classes, respectively.

library(mlr3viz)

task = tsk("sonar")
learner = lrn("classif.rpart", predict_type = "prob")
learner$train(task)
prediction = learner$predict(task)
autoplot(prediction)

autoplot(prediction, type = "roc")

library(mlr3viz)
library(mlr3learners)
local({ # we do this locally to not overwrite the objects from previous chunks
    task = tsk("mtcars")
    learner = lrn("regr.lm")
    learner$train(task)
    prediction = learner$predict(task)
    autoplot(prediction)
})

2.4.7 Performance assessment

The last step of modeling is usually performance assessment. The quality of a model's predictions in mlr3 can be assessed with respect to a number of different performance measures: we choose a specific measure that quantifies the predictions by comparing the predicted labels with the true labels. Predefined measures are stored in mlr_measures (with the convenience getter msr()):

mlr_measures
## <DictionaryMeasure> with 51 stored values
## Keys: classif.acc, classif.auc, classif.bacc, classif.ce,
##   classif.costs, classif.dor, classif.fbeta, classif.fdr, classif.fn,
##   classif.fnr, classif.fomr, classif.fp, classif.fpr, classif.logloss,
##   classif.mcc, classif.npv, classif.ppv, classif.precision,
##   classif.recall, classif.sensitivity, classif.specificity, classif.tn,
##   classif.tnr, classif.tp, classif.tpr, debug, oob_error, regr.bias,
##   regr.ktau, regr.mae, regr.mape, regr.maxae, regr.medae, regr.medse,
##   regr.mse, regr.msle, regr.pbias, regr.rae, regr.rmse, regr.rmsle,
##   regr.rrse, regr.rse, regr.rsq, regr.sae, regr.smape, regr.srho,
##   regr.sse, selected_features, time_both, time_predict, time_train
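
The dictionary can also be converted to a data.table to search or filter it programmatically; the exact columns depend on the mlr3 version:

head(as.data.table(mlr_measures))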

We select the accuracy (classif.acc) and call the $score() method of the Prediction object:

measure = msr("classif.acc")
prediction$score(measure)
## classif.acc 
##       0.875

Note that, if no measure is specified, classification defaults to classification error (classif.ce) and regression defaults to the mean squared error (regr.mse).
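
Several measures can be computed at once by passing a list of measures, conveniently retrieved with msrs(). Since the prediction above stores probabilities, threshold-independent measures such as the AUC are available as well; the resulting values depend on your random split:

prediction$score(msrs(c("classif.acc", "classif.auc")))

# calling $score() without arguments falls back to the default classif.ce
prediction$score()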