5  Train, Predict, Assess Performance

In this section, we explain how tasks and learners can be used to train a model and predict on a new dataset. Training a learner means fitting a model to a given data set, which is essentially an optimization problem that determines the best parameters (not hyperparameters!) of the model given the data. We then predict the label for observations that the model has not seen during training. Finally, we compare the predictions to ground truth values in order to assess the quality of the predictions.

The concept is demonstrated on a supervised classification task using the pima dataset, in which patient data is used to diagnostically predict diabetes, and the rpart learner, which builds a classification tree. As shown in the previous chapters, we load these objects using the short access functions tsk() and lrn().

library(mlr3)
task = tsk("pima")
learner = lrn("classif.rpart")

5.1 Training the learner

The field $model stores the model that is produced in the training step. Before the $train() method is called on a learner object, this field is NULL:

learner$model
NULL

Now we fit the classification tree using the training set of the task by calling the $train() method of Learner:

learner$train(task)

This operation modifies the learner in-place by adding the fitted model to the existing object. We can now access the stored model via the field $model:

learner$model
n= 768 

node), split, n, loss, yval, (yprob)
      * denotes terminal node

  1) root 768 268 neg (0.34895833 0.65104167)  
    2) glucose>=127.5 283 109 pos (0.61484099 0.38515901)  
      4) mass>=29.95 208  58 pos (0.72115385 0.27884615)  
        8) glucose>=157.5 92  12 pos (0.86956522 0.13043478) *
        9) glucose< 157.5 116  46 pos (0.60344828 0.39655172)  
         18) age>=30.5 66  19 pos (0.71212121 0.28787879) *
         19) age< 30.5 50  23 neg (0.46000000 0.54000000)  
           38) pressure< 73 21   8 pos (0.61904762 0.38095238) *
           39) pressure>=73 29  10 neg (0.34482759 0.65517241)  
             78) mass>=41.8 9   3 pos (0.66666667 0.33333333) *
             79) mass< 41.8 20   4 neg (0.20000000 0.80000000) *
      5) mass< 29.95 75  24 neg (0.32000000 0.68000000) *
    3) glucose< 127.5 485  94 neg (0.19381443 0.80618557)  
      6) age>=28.5 214  71 neg (0.33177570 0.66822430)  
       12) insulin>=142.5 56  26 neg (0.46428571 0.53571429)  
         24) age< 56.5 41  16 pos (0.60975610 0.39024390) *
         25) age>=56.5 15   1 neg (0.06666667 0.93333333) *
       13) insulin< 142.5 158  45 neg (0.28481013 0.71518987)  
         26) glucose>=99.5 102  41 neg (0.40196078 0.59803922)  
           52) mass>=26.35 84  41 neg (0.48809524 0.51190476)  
            104) pedigree>=0.2045 65  27 pos (0.58461538 0.41538462)  
              208) pregnant>=5.5 32   8 pos (0.75000000 0.25000000) *
              209) pregnant< 5.5 33  14 neg (0.42424242 0.57575758)  
                418) age>=34.5 19   7 pos (0.63157895 0.36842105) *
                419) age< 34.5 14   2 neg (0.14285714 0.85714286) *
            105) pedigree< 0.2045 19   3 neg (0.15789474 0.84210526) *
           53) mass< 26.35 18   0 neg (0.00000000 1.00000000) *
         27) glucose< 99.5 56   4 neg (0.07142857 0.92857143) *
      7) age< 28.5 271  23 neg (0.08487085 0.91512915) *

Inspecting the output, we see that the learner has identified features in the task that are predictive of the class (diabetes status) and uses them to partition observations in the tree. There are additional details on how the data is partitioned across branches of the tree; the textual representation of the model depends on the type of learner. For more information on this particular type of model, see rpart::print.rpart().
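
As a minimal sketch of the in-place behaviour described in this section (assuming the mlr3 and rpart packages are installed), the following checks that $model is NULL before training and holds an rpart object afterwards:

```r
library(mlr3)

task = tsk("pima")
learner = lrn("classif.rpart")

# before training, no model is stored
model_before = learner$model
is.null(model_before)

# $train() modifies the learner in place; no reassignment is needed
learner$train(task)

# afterwards the fitted tree is available in the same object
class(learner$model)
```

Because the learner is modified in place, the result of $train() does not have to be assigned to a new variable.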

5.2 Predicting

After the model has been fitted to the training data, we can now use it for prediction. A common case is that a model was fitted on all training data that was available, and should now be used to make predictions for new data for which the actual labels are unknown:

pima_new = data.table::fread("
age, glucose, insulin, mass, pedigree, pregnant, pressure, triceps
24,  145,     306,     41.7, 0.5,      3,        52,       36
47,  133,     NA,      23.3, 0.2,      7,        83,       28
")
pima_new
   age glucose insulin mass pedigree pregnant pressure triceps
1:  24     145     306 41.7      0.5        3       52      36
2:  47     133      NA 23.3      0.2        7       83      28

The learner does not need to know any more meta-information about this data to make a prediction, such as which columns are features and which are targets, since this was already included in the training task. Instead, this data can directly be used to make a prediction using $predict_newdata():

prediction = learner$predict_newdata(pima_new)
prediction
<PredictionClassif> for 2 observations:
 row_ids truth response
       1  <NA>      pos
       2  <NA>      neg

This method returns a Prediction object. More precisely, because the learner is a LearnerClassif, it returns a PredictionClassif object. The easiest way to access information from it is to convert it to a data.table:

as.data.table(prediction)
   row_ids truth response
1:       1  <NA>      pos
2:       2  <NA>      neg

Here the "truth" column is NA, since it is not known. Should the actual truth values for the new data be known, then one can convert this data to a new Task, create predictions that know both the predicted and the actual label, and use this prediction object for performance evaluation.

Suppose both rows of the pima_new data had been measured on positive ("pos") patients:

pima_new_known = cbind(pima_new, diabetes = factor("pos", levels = c("pos", "neg")))
pima_new_known
   age glucose insulin mass pedigree pregnant pressure triceps diabetes
1:  24     145     306 41.7      0.5        3       52      36      pos
2:  47     133      NA 23.3      0.2        7       83      28      pos
task_pima_new = as_task_classif(pima_new_known, target = "diabetes")

This task can then be used to make a prediction using the $predict() method of the Learner class. The result is another PredictionClassif, but with the "truth" column filled out:

prediction = learner$predict(task_pima_new)
prediction
<PredictionClassif> for 2 observations:
 row_ids truth response
       1   pos      pos
       2   pos      neg

5.3 Changing the Predict Type

Classification learners default to predicting the class label. However, many classifiers can also report how confident they are in the predicted label by providing posterior probabilities for the classes. To predict these probabilities, the predict_type field of a LearnerClassif must be changed from "response" (the default) to "prob" before training:

learner$predict_type = "prob"

# re-fit the model
learner$train(task)

# rebuild prediction object
prediction = learner$predict(task_pima_new)
prediction
<PredictionClassif> for 2 observations:
 row_ids truth response  prob.pos  prob.neg
       1   pos      pos 0.6190476 0.3809524
       2   pos      neg 0.3200000 0.6800000

The prediction object now contains probabilities for all class labels in addition to the predicted label (the one with the highest probability):

# directly access the predicted labels:
prediction$response
[1] pos neg
Levels: pos neg
# directly access the matrix of probabilities:
prediction$prob
           pos       neg
[1,] 0.6190476 0.3809524
[2,] 0.3200000 0.6800000
# data.table conversion
as.data.table(prediction)
   row_ids truth response  prob.pos  prob.neg
1:       1   pos      pos 0.6190476 0.3809524
2:       2   pos      neg 0.3200000 0.6800000
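
The relationship between the probability matrix and the predicted label can be checked by hand. The following base R sketch copies the two probability rows from the output above and picks, for each row, the class with the highest probability. It illustrates the rule stated in the text and is not mlr3's internal code:

```r
# probabilities copied from the prediction output above
prob = matrix(
  c(0.6190476, 0.3200000,   # column "pos"
    0.3809524, 0.6800000),  # column "neg"
  ncol = 2,
  dimnames = list(NULL, c("pos", "neg"))
)

# each row sums to 1 (up to rounding of the printed values)
rowSums(prob)

# the predicted label is the column with the largest probability in each row
response = colnames(prob)[max.col(prob, ties.method = "first")]
response
```

The result agrees with the response column shown above: "pos" for the first row and "neg" for the second.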

Similarly to predicting probabilities for classification, many regression learners support the extraction of standard error estimates for predictions by setting the predict type to "se".
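
A minimal sketch of the "se" predict type follows. The learner regr.featureless and the task mtcars are assumptions of this example, chosen only because both ship with mlr3 and the learner supports "se"; they are not used elsewhere in this chapter:

```r
library(mlr3)

task_regr = tsk("mtcars")

# request standard errors in addition to the predicted response
learner_regr = lrn("regr.featureless", predict_type = "se")
learner_regr$train(task_regr)

prediction_regr = learner_regr$predict(task_regr)

# the converted table has an additional "se" column
tab = data.table::as.data.table(prediction_regr)
head(tab)
```

Other regression learners that support "se" can be used in the same way by setting predict_type when constructing the learner or before training.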

5.4 Thresholding

Models trained on binary classification tasks that predict the probability for the positive class usually use a simple rule to determine the predicted class label: if the probability is more than 50%, predict the positive label, otherwise predict the negative label. In some cases you may want to adjust this threshold, for example if the classes are very unbalanced (i.e. one is much more prevalent than the other).

In the example below, we change the threshold to 0.2, making the model predict "pos" for both example rows:

prediction$set_threshold(0.2)
prediction
<PredictionClassif> for 2 observations:
 row_ids truth response  prob.pos  prob.neg
       1   pos      pos 0.6190476 0.3809524
       2   pos      pos 0.3200000 0.6800000
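
The decision rule described above can be sketched in base R using the two positive-class probabilities from the output. This illustrates the rule as stated in the text, not mlr3's internal implementation, and the helper name classify_by_threshold is hypothetical:

```r
# positive-class probabilities copied from the output above
prob_pos = c(0.6190476, 0.3200000)

# hypothetical helper: predict "pos" if the probability exceeds the threshold
classify_by_threshold = function(prob_pos, threshold = 0.5) {
  factor(ifelse(prob_pos > threshold, "pos", "neg"), levels = c("pos", "neg"))
}

default_labels = classify_by_threshold(prob_pos)        # threshold 0.5
lowered_labels = classify_by_threshold(prob_pos, 0.2)   # threshold 0.2

as.character(default_labels)  # "pos" "neg"
as.character(lowered_labels)  # "pos" "pos"
```

With the default threshold only the first row is labelled "pos"; with a threshold of 0.2 the second row (probability 0.32) is labelled "pos" as well.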

5.5 Predicting on known data and train/test splits

We usually do not want to postpone performance evaluation until new data becomes available; instead, we work with all the data available at a given point. However, when evaluating the performance of a Learner, it is important to score predictions made on data that was not seen during training. Making predictions on training data is generally too easy: a Learner could simply memorize the training data responses and get a perfect score.

mlr3 makes it easy to train only on subsets of given tasks. We first create a vector indicating on which row IDs of the task the Learner should be trained, and another that indicates the remaining rows that should be used for prediction. These vectors define the train-test split we are using. This is done manually here for demonstration purposes; in Chapter 7, we show how mlr3 can automatically create training and test sets based on resampling strategies that can be more elaborate.

We will use 67% of all available observations to train and predict on the remaining 33%.

train_set = sample(task$row_ids, 0.67 * task$nrow)
test_set = setdiff(task$row_ids, train_set)

Do not use constructs like sample(task$nrow, ...) for the purpose of creating task subsets, since rows are always identified by their $row_ids. These are not guaranteed to range from 1 to task$nrow and could be any positive integer.
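
The following sketch shows why the distinction matters. After filtering a task, its row IDs are no longer the contiguous range 1 to task$nrow, so sampling from that range would pick IDs that are not part of the task. The choice of keeping only even row IDs is arbitrary and serves purely as an illustration:

```r
library(mlr3)

task_even = tsk("pima")

# keep only the rows with even row IDs (modifies the task in place)
task_even$filter(task_even$row_ids[task_even$row_ids %% 2 == 0])

task_even$nrow            # 384 rows remain
range(task_even$row_ids)  # but the IDs still run from 2 to 768

# seq_len(nrow) contains odd IDs that no longer exist in the task
all(seq_len(task_even$nrow) %in% task_even$row_ids)  # FALSE

# sampling from $row_ids always yields valid IDs
ids = sample(task_even$row_ids, 5)
all(ids %in% task_even$row_ids)  # TRUE
```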

Both $train() and $predict() have an optional row_ids argument that determines which rows are used. Note that it is not a problem to run $train() with a Learner that has already been trained: the old model is automatically discarded and the learner trains from scratch.

# train on the training set
learner$train(task, row_ids = train_set)

# predict on the test set
prediction = learner$predict(task, row_ids = test_set)

# the prediction naturally knows about the "truth" from the task
prediction
<PredictionClassif> for 254 observations:
    row_ids truth response   prob.pos  prob.neg
          8   neg      neg 0.37500000 0.6250000
         12   pos      pos 0.84905660 0.1509434
         19   neg      neg 0.37500000 0.6250000
        762   pos      pos 0.84905660 0.1509434
        765   neg      neg 0.09954751 0.9004525
        768   neg      neg 0.09954751 0.9004525
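
The 254 test observations reported above follow directly from the split arithmetic. The base R sketch below assumes row IDs 1 to 768, as in the unfiltered pima task, and an arbitrary seed; which rows end up in each set depends on the seed, but the set sizes do not:

```r
row_ids = 1:768
set.seed(1)  # arbitrary seed, only for reproducibility of this sketch

# 0.67 * 768 = 514.56; the fractional sample size is truncated to 514
train_set = sample(row_ids, 0.67 * length(row_ids))
test_set = setdiff(row_ids, train_set)

length(train_set)  # 514
length(test_set)   # 254

# the two sets are disjoint and together cover all rows
length(intersect(train_set, test_set))  # 0
```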

5.6 Performance assessment

The last step of modeling is usually assessing the performance of the trained model. For this, the predictions made by the model are compared with the known ground-truth values that are stored in the Prediction object. The exact nature of this comparison is defined by a measure, which is given by a Measure object. If the prediction was made on a dataset without the target column, i.e. without known true labels, then performance cannot be calculated.

Available measures can be retrieved using the msr() function, which accesses objects in mlr_measures:

mlr_measures
<DictionaryMeasure> with 58 stored values
Keys: aic, bic, classif.acc, classif.auc, classif.bacc, classif.bbrier,
  classif.ce, classif.costs, classif.dor, classif.fbeta, classif.fdr,
  classif.fn, classif.fnr, classif.fomr, classif.fp, classif.fpr,
  classif.logloss, classif.mbrier, classif.mcc, classif.npv,
  classif.ppv, classif.prauc, classif.precision, classif.recall,
  classif.sensitivity, classif.specificity, classif.tn, classif.tnr,
  classif.tp, classif.tpr, debug, oob_error, regr.bias, regr.ktau,
  regr.mae, regr.mape, regr.maxae, regr.medae, regr.medse, regr.mse,
  regr.msle, regr.pbias, regr.rae, regr.rmse, regr.rmsle, regr.rrse,
  regr.rse, regr.rsq, regr.sae, regr.smape, regr.srho, regr.sse,
  selected_features, sim.jaccard, sim.phi, time_both, time_predict,
  time_train

We choose accuracy (classif.acc) as our specific performance measure here and call the method $score() of the prediction object to quantify the predictive performance of our model.

measure = msr("classif.acc")
measure
<MeasureClassifSimple:classif.acc>: Classification Accuracy
* Packages: mlr3, mlr3measures
* Range: [0, 1]
* Minimize: FALSE
* Average: macro
* Parameters: list()
* Properties: -
* Predict type: response
prediction$score(measure)

$score() can be called without a given measure. In this case, classification defaults to classification error (classif.ce, which is one minus accuracy) and regression to the mean squared error (regr.mse).
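
The relationship between the default measure and accuracy can be verified with a short sketch. For brevity it predicts on the training data itself, which is not a meaningful performance estimate and is done here only to obtain a prediction object quickly:

```r
library(mlr3)

task = tsk("pima")
learner = lrn("classif.rpart")
learner$train(task)
prediction = learner$predict(task)

# without an argument, $score() uses the default measure for classification
ce = prediction$score()
names(ce)  # "classif.ce"

# accuracy has to be requested explicitly
acc = prediction$score(msr("classif.acc"))

# classification error and accuracy sum to one
unname(ce + acc)
```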

It is possible to calculate multiple measures at the same time by passing a list to $score(). Such a list can easily be constructed using the “plural” msrs() function. If one wanted to have both the “true positive rate” ("classif.tpr") and the “true negative rate” ("classif.tnr"), one would use:

measures = msrs(c("classif.tpr", "classif.tnr"))
prediction$score(measures)
classif.tpr classif.tnr 
  0.4639175   0.8853503 

5.6.1 Confusion Matrix

A special case of performance evaluation is the confusion matrix, which shows, for each class, how many observations were predicted to be in that class and how many were actually in it (more information on Wikipedia). The entries along the diagonal denote the correctly classified observations.

prediction$confusion
        truth
response pos neg
     pos  45  18
     neg  52 139

In this case, our classifier misclassifies a relatively large number of positive samples as negative. In fact, a positive case is still more likely to be classified as "neg" than as "pos" (52 versus 45 of the 97 positive cases). Depending on the application being considered, it may be more important to keep false negatives (lower left element of the confusion matrix) low. Lowering the threshold, so that ambiguous samples are more readily classified as positive rather than negative, can help in this case, although it will also lead to negative cases being classified as "pos" more often.

        truth
response pos neg
     pos  75  65
     neg  22  92

Thresholds can be tuned automatically with the mlr3pipelines package, i.e. using PipeOpTuneThreshold.
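
The trade-off between the two confusion matrices above can be quantified by hand. The base R sketch below copies the counts from both matrices (rows are the response, columns the truth) and computes the true positive rate, true negative rate and accuracy; the helper name rates is hypothetical:

```r
# counts copied from the two confusion matrices above
dims = list(response = c("pos", "neg"), truth = c("pos", "neg"))
conf_default = matrix(c(45, 52, 18, 139), nrow = 2, dimnames = dims)
conf_lowered = matrix(c(75, 22, 65, 92), nrow = 2, dimnames = dims)

# hypothetical helper computing rates from a 2x2 confusion matrix
rates = function(conf) {
  c(
    tpr = conf["pos", "pos"] / sum(conf[, "pos"]),  # correct among true "pos"
    tnr = conf["neg", "neg"] / sum(conf[, "neg"]),  # correct among true "neg"
    acc = sum(diag(conf)) / sum(conf)               # correct overall
  )
}

r_default = round(rates(conf_default), 4)
r_lowered = round(rates(conf_lowered), 4)

r_default  # tpr 0.4639, tnr 0.8854, acc 0.7244
r_lowered  # tpr 0.7732, tnr 0.5860, acc 0.6575
```

The first set of values matches the classif.tpr and classif.tnr scores reported earlier. With the lowered threshold the true positive rate rises considerably, while the true negative rate and the overall accuracy both drop.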

5.7 Plotting Predictions

Similarly to plotting tasks, mlr3viz provides an autoplot() method for Prediction objects. All available types are listed in the manual pages for autoplot.PredictionClassif(), autoplot.PredictionRegr() and the other prediction types (defined by extension packages).

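library(mlr3viz)
task = tsk("penguins")
learner = lrn("classif.rpart", predict_type = "prob")
learner$train(task)  # the learner must be trained before predicting
prediction = learner$predict(task)
autoplot(prediction)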