2.4 Train, Predict, Score

In this section, we explain how tasks and learners can be used to train a model and predict on new data. The concept is demonstrated on a supervised classification problem using the penguins dataset and the rpart learner, which builds a single classification tree.

Training a learner means fitting a model to a given data set. Subsequently, we want to predict the label for new observations. These predictions are compared to the ground truth values in order to assess the predictive performance of the model.

2.4.1 Creating Task and Learner Objects

First of all, we load the mlr3verse package.

library("mlr3verse")

Next, we retrieve the task and the learner from mlr_tasks (with shortcut tsk()) and mlr_learners (with shortcut lrn()), respectively:

  1. The classification task:
task = tsk("penguins")
  2. A learner for the classification tree:
learner = lrn("classif.rpart")

2.4.2 Setting up the Train/Test Splits of the Data

It is common to train on a majority of the data. Here we use 80% of all available observations and predict on the remaining 20%. For this purpose, we create two index vectors:

train_set = sample(task$nrow, 0.8 * task$nrow)
test_set = setdiff(seq_len(task$nrow), train_set)

In Section 2.5 we will learn how mlr3 can automatically create training and test sets based on different resampling strategies.
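As a brief preview, the predefined "holdout" resampling produces a single train/test split comparable to the manual one above. The following is only a sketch, assuming the ratio parameter of the holdout resampling and the $train_set()/$test_set() accessors of the instantiated resampling:

resampling = rsmp("holdout", ratio = 0.8)
resampling$instantiate(task)

# row ids analogous to train_set and test_set above
head(resampling$train_set(1))
head(resampling$test_set(1))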

2.4.3 Training the Learner

The field $model stores the model that is produced in the training step. Before the $train() method is called on a learner object, this field is NULL:

learner$model
## NULL

Next, the classification tree is trained using the train set of the penguins task by calling the $train() method of the Learner:

learner$train(task, row_ids = train_set)

This operation modifies the learner in-place. We can now access the stored model via the field $model:

print(learner$model)
## n= 275 
## 
## node), split, n, loss, yval, (yprob)
##       * denotes terminal node
## 
## 1) root 275 162 Adelie (0.410909 0.203636 0.385455)  
##   2) flipper_length< 206.5 164  52 Adelie (0.682927 0.310976 0.006098)  
##     4) bill_length< 43.35 117   5 Adelie (0.957265 0.042735 0.000000) *
##     5) bill_length>=43.35 47   1 Chinstrap (0.000000 0.978723 0.021277) *
##   3) flipper_length>=206.5 111   6 Gentoo (0.009009 0.045045 0.945946)  
##     6) bill_depth>=17.2 8   3 Chinstrap (0.125000 0.625000 0.250000) *
##     7) bill_depth< 17.2 103   0 Gentoo (0.000000 0.000000 1.000000) *

2.4.4 Predicting

After the model has been trained, we use the remaining part of the data for prediction. Remember that we initially split the data into train_set and test_set.

prediction = learner$predict(task, row_ids = test_set)
print(prediction)
## <PredictionClassif> for 69 observations:
##     row_ids     truth  response
##           4    Adelie    Adelie
##           9    Adelie    Adelie
##          12    Adelie    Adelie
## ---                            
##         330 Chinstrap Chinstrap
##         339 Chinstrap Chinstrap
##         342 Chinstrap Chinstrap

The $predict() method of the Learner returns a Prediction object. More precisely, a LearnerClassif returns a PredictionClassif object.
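If you want to verify this, the class of the returned object shows the inheritance chain; the following line is merely illustrative:

# PredictionClassif inherits from the abstract Prediction class
class(prediction)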

A prediction object holds the row ids of the test data, the corresponding true labels of the target column, and the respective predictions. The simplest way to extract this information is by converting the Prediction object to a data.table():

head(as.data.table(prediction))
##    row_ids  truth response
## 1:       4 Adelie   Adelie
## 2:       9 Adelie   Adelie
## 3:      12 Adelie   Adelie
## 4:      14 Adelie   Adelie
## 5:      15 Adelie   Adelie
## 6:      19 Adelie   Adelie

For classification, you can also extract the confusion matrix:

prediction$confusion
##            truth
## response    Adelie Chinstrap Gentoo
##   Adelie        34         0      0
##   Chinstrap      5        12      0
##   Gentoo         0         0     18

2.4.5 Changing the Predict Type

Classification learners default to predicting the class label. However, many classifiers can additionally tell you how confident they are in the predicted label by providing posterior probabilities. To switch to predicting these probabilities, the predict_type field of a LearnerClassif must be changed from "response" to "prob" before training:

learner$predict_type = "prob"

# re-fit the model
learner$train(task, row_ids = train_set)

# rebuild prediction object
prediction = learner$predict(task, row_ids = test_set)

The prediction object now contains probabilities for all class labels:

# data.table conversion
head(as.data.table(prediction))
##    row_ids  truth response prob.Adelie prob.Chinstrap prob.Gentoo
## 1:       4 Adelie   Adelie      0.9573        0.04274           0
## 2:       9 Adelie   Adelie      0.9573        0.04274           0
## 3:      12 Adelie   Adelie      0.9573        0.04274           0
## 4:      14 Adelie   Adelie      0.9573        0.04274           0
## 5:      15 Adelie   Adelie      0.9573        0.04274           0
## 6:      19 Adelie   Adelie      0.9573        0.04274           0
# directly access the predicted labels:
head(prediction$response)
## [1] Adelie Adelie Adelie Adelie Adelie Adelie
## Levels: Adelie Chinstrap Gentoo
# directly access the matrix of probabilities:
head(prediction$prob)
##      Adelie Chinstrap Gentoo
## [1,] 0.9573   0.04274      0
## [2,] 0.9573   0.04274      0
## [3,] 0.9573   0.04274      0
## [4,] 0.9573   0.04274      0
## [5,] 0.9573   0.04274      0
## [6,] 0.9573   0.04274      0

Analogously to predicting probabilities, many regression learners support the extraction of standard error estimates by setting the predict type to "se".
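As a sketch of the regression case, assuming the regr.lm learner (provided by mlr3learners, which mlr3verse loads) and the predefined mtcars task:

# linear regression with standard error estimates for each prediction
learner_lm = lrn("regr.lm", predict_type = "se")
task_mtcars = tsk("mtcars")
learner_lm$train(task_mtcars)
prediction_lm = learner_lm$predict(task_mtcars)

# the converted data.table now contains a response and an se column
head(as.data.table(prediction_lm))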

2.4.6 Plotting Predictions

Analogously to plotting tasks, mlr3viz provides an autoplot() method for Prediction objects. All available types are listed on the manual page of autoplot.PredictionClassif() or autoplot.PredictionRegr(), respectively.

task = tsk("penguins")
learner = lrn("classif.rpart", predict_type = "prob")
learner$train(task)
prediction = learner$predict(task)
autoplot(prediction)
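For binary classification tasks, further plot types are available. As a sketch, assuming the predefined sonar task and the "roc" plot type of mlr3viz (which requires predicted probabilities and the precrec package):

task_sonar = tsk("sonar")
learner$train(task_sonar)
prediction_sonar = learner$predict(task_sonar)

# ROC curve of the predicted probabilities
autoplot(prediction_sonar, type = "roc")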

2.4.7 Performance Assessment

The last step of modeling is usually the performance assessment. To assess the quality of the predictions, the predicted labels are compared with the true labels. How this comparison is computed is defined by a measure, which is represented by a Measure object. Note that if the prediction was made on a dataset without the target column, i.e. without true labels, then no performance can be calculated.

Predefined available measures are stored in mlr_measures (with convenience getter msr()):

mlr_measures
## <DictionaryMeasure> with 82 stored values
## Keys: classif.acc, classif.auc, classif.bacc, classif.bbrier,
##   classif.ce, classif.costs, classif.dor, classif.fbeta, classif.fdr,
##   classif.fn, classif.fnr, classif.fomr, classif.fp, classif.fpr,
##   classif.logloss, classif.mbrier, classif.mcc, classif.npv,
##   classif.ppv, classif.prauc, classif.precision, classif.recall,
##   classif.sensitivity, classif.specificity, classif.tn, classif.tnr,
##   classif.tp, classif.tpr, clust.ch, clust.db, clust.dunn,
##   clust.silhouette, debug, dens.logloss, oob_error, regr.bias,
##   regr.ktau, regr.mae, regr.mape, regr.maxae, regr.medae, regr.medse,
##   regr.mse, regr.msle, regr.pbias, regr.rae, regr.rmse, regr.rmsle,
##   regr.rrse, regr.rse, regr.rsq, regr.sae, regr.smape, regr.srho,
##   regr.sse, selected_features, surv.brier, surv.calib_alpha,
##   surv.calib_beta, surv.chambless_auc, surv.cindex, surv.dcalib,
##   surv.graf, surv.hung_auc, surv.intlogloss, surv.logloss, surv.mae,
##   surv.mse, surv.nagelk_r2, surv.oquigley_r2, surv.rmse, surv.schmid,
##   surv.song_auc, surv.song_tnr, surv.song_tpr, surv.uno_auc,
##   surv.uno_tnr, surv.uno_tpr, surv.xu_r2, time_both, time_predict,
##   time_train

We choose accuracy (classif.acc) as a specific performance measure and call the method $score() of the Prediction object to quantify the predictive performance.

measure = msr("classif.acc")
print(measure)
## <MeasureClassifSimple:classif.acc>
## * Packages: mlr3measures
## * Range: [0, 1]
## * Minimize: FALSE
## * Properties: -
## * Predict type: response
prediction$score(measure)
## classif.acc 
##      0.9651

Note that, if no measure is specified, classification defaults to classification error (classif.ce) and regression defaults to the mean squared error (regr.mse).
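Multiple measures can be scored at once by passing a list of Measure objects, e.g. retrieved with the helper msrs(); the following sketch uses the accuracy and classification error measures mentioned above:

measures = msrs(c("classif.acc", "classif.ce"))
prediction$score(measures)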