10 Model Interpretation

Author: Przemysław Biecek (MI2.AI, Warsaw University of Technology)


Abstract
The goal of this chapter is to present key methods for in-depth post hoc analysis of an already trained model. The methods presented are model-agnostic, i.e., they can be applied to models of different classes. When predictive models are used in practice, high performance on a validation set is often not enough. Users increasingly want to know which variables are important and how they influence the model’s predictions. For the end user, such knowledge allows models to be used better in the decision-making process, e.g., by analysing different possible decision options. In addition, if the model’s behaviour turns out to be in line with domain knowledge or the user’s intuition, the user’s confidence in the predictions will increase. For the modeller, an in-depth analysis of the model allows undesirable model behaviour to be detected and corrected.

Predictive models have numerous applications in virtually every area of life. The increasing availability of data and of frameworks for building models has led to the widespread adoption of these solutions. However, this is not always accompanied by sufficient testing of the models, and the consequences of incorrect predictions can be severe. The best-selling book “Weapons of Math Destruction” (O’Neil 2016) discusses examples of deployed black-box models that have led to wrong-headed decisions, sometimes on a massive scale. So what can we do to test our models more thoroughly? One answer is to use methods that allow a deeper interpretation of predictive models. In this chapter, we illustrate how to apply the most popular of these methods (Holzinger et al. 2022).

In principle, any generic framework for model interpretation can be applied to models fitted with mlr3 by extracting the fitted model from the Learner object (it is stored in the learner’s $model field).

However, two of the most popular frameworks, iml and DALEX, additionally come with some conveniences for mlr3.

Both packages offer similar functionality but differ in their design choices. iml is based on the R6 class system, so working with it feels similar to working with the mlr3 package. DALEX is based on the S3 class system and focuses mainly on comparing multiple models on the same plot and on the explanatory model analysis process.

10.1 Penguin Task

To understand what model interpretation packages can offer, we start with a thorough example. The goal of this example is to predict the species of a penguin from a set of features. We use the palmerpenguins::penguins data set (Horst, Hill, and Gorman 2020), an alternative to the iris data set. It contains 8 variables measured on 344 penguins:

data("penguins", package = "palmerpenguins")
str(penguins)
tibble [344 × 8] (S3: tbl_df/tbl/data.frame)
 $ species          : Factor w/ 3 levels "Adelie","Chinstrap",..: 1 1 1 1 1 1 1 1 1 1 ...
 $ island           : Factor w/ 3 levels "Biscoe","Dream",..: 3 3 3 3 3 3 3 3 3 3 ...
 $ bill_length_mm   : num [1:344] 39.1 39.5 40.3 NA 36.7 39.3 38.9 39.2 34.1 42 ...
 $ bill_depth_mm    : num [1:344] 18.7 17.4 18 NA 19.3 20.6 17.8 19.6 18.1 20.2 ...
 $ flipper_length_mm: int [1:344] 181 186 195 NA 193 190 181 195 193 190 ...
 $ body_mass_g      : int [1:344] 3750 3800 3250 NA 3450 3650 3625 4675 3475 4250 ...
 $ sex              : Factor w/ 2 levels "female","male": 2 1 1 NA 1 2 1 2 NA NA ...
 $ year             : int [1:344] 2007 2007 2007 2007 2007 2007 2007 2007 2007 2007 ...

To get started run:

library("mlr3")
library("mlr3learners")
set.seed(1)
penguins = na.omit(penguins)
task_peng = as_task_classif(penguins, target = "species")

The call penguins = na.omit(penguins) removes the 11 penguins with missing values (344 rows become 333). Without this step, training fails with an error, because the classif.ranger learner cannot handle missing (NA) feature values.

learner = lrn("classif.ranger")
learner$predict_type = "prob"
learner$train(task_peng)
learner$model
Ranger result

Call:
 ranger::ranger(dependent.variable.name = task$target_names, data = task$data(),      probability = self$predict_type == "prob", case.weights = task$weights$weight,      num.threads = 1L) 

Type:                             Probability estimation 
Number of trees:                  500 
Sample size:                      333 
Number of independent variables:  7 
Mtry:                             2 
Target node size:                 10 
Variable importance mode:         none 
Splitrule:                        gini 
OOB prediction error (Brier s.):  0.01790106 
x = penguins[which(names(penguins) != "species")]

As explained in Section Learners, specific learners can be queried with mlr_learners. Section Train/Predict recommends, for some classifiers, setting predict_type to "prob" instead of directly predicting a label, which is what we do in this example. x = penguins[which(names(penguins) != "species")] holds all the features, and y will be penguins$species. learner$train(task_peng) trains the model, and learner$model stores the fitted model. In iml, a Predictor object holds the machine learning model and the data; all interpretation methods in iml need the model and the data to be wrapped in such a Predictor object.

10.2 iml

Author: Shawn Storm

iml is an R package that interprets the behaviour and explains the predictions of machine learning models. The functions provided in the iml package are model-agnostic, which gives the flexibility to use any machine learning model.

This section provides examples of how to use iml with mlr3. For more information, refer to the iml GitHub repository and the Interpretable Machine Learning (IML) book.

Next, we present the core functionality of iml. In this example, three separate interpretation methods are used: FeatureEffects, FeatureImp and Shapley.

10.2.1 FeatureEffects

In addition to the commands above, the following two commands need to be run:

library("iml")

model = Predictor$new(learner, data = x, y = penguins$species)

num_features = c("bill_length_mm", "bill_depth_mm", "flipper_length_mm", "body_mass_g", "year")
effect = FeatureEffects$new(model)
plot(effect, features = num_features)

Plot of the results from FeatureEffects. FeatureEffects computes and plots feature effects of prediction models.

effect stores the result of the FeatureEffects computation, which can then be plotted. FeatureEffects computes the effects of all features in the data (by default as accumulated local effects, method = "ale"); the plot shows the numeric features listed in num_features.

All features except year provide meaningful, interpretable information. This is expected for year: it only records when a penguin was observed, which is not expected to be related to its species. bill_length_mm shows, for example, that when the bill length is smaller than roughly 40 mm, there is a high chance that the penguin is an Adelie.
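FeatureEffects computes accumulated local effects (ALE) by default. To make that computation concrete, here is a minimal base-R sketch of first-order ALE for a single numeric feature. It is not iml's implementation: ale_numeric() is a hypothetical helper, and a linear model on the built-in mtcars data stands in for the ranger classifier so that the result can be checked by hand (for an additive linear model, the ALE curve is a straight line whose slope equals the feature's coefficient).

```r
# Minimal first-order ALE for one numeric feature (hypothetical helper, not iml's code).
ale_numeric = function(predict_fun, data, feature, n_intervals = 10) {
  x = data[[feature]]
  # Interval boundaries at observed quantiles (type = 1 returns observed values)
  grid = unique(quantile(x, probs = seq(0, 1, length.out = n_intervals + 1),
                         type = 1, names = FALSE))
  # Interval index of each observation; include.lowest puts the minimum in interval 1
  idx = cut(x, breaks = grid, include.lowest = TRUE, labels = FALSE)
  n_int = length(grid) - 1
  local_effect = numeric(n_int)
  counts = numeric(n_int)
  for (k in seq_len(n_int)) {
    rows = data[idx == k, , drop = FALSE]
    counts[k] = nrow(rows)
    if (counts[k] == 0) next
    lower = rows
    lower[[feature]] = grid[k]
    upper = rows
    upper[[feature]] = grid[k + 1]
    # Average prediction change across the interval, other features left as observed
    local_effect[k] = mean(predict_fun(upper) - predict_fun(lower))
  }
  # Accumulate the local effects, then centre them (weighted by interval counts)
  ale = c(0, cumsum(local_effect))
  midpoints = (head(ale, -1) + tail(ale, -1)) / 2
  ale = ale - sum(counts * midpoints) / sum(counts)
  data.frame(value = grid, ale = ale)
}

# Stand-in model (assumption): linear regression on the built-in mtcars data
fit = lm(mpg ~ wt + hp + qsec, data = mtcars)
predict_fun = function(newdata) predict(fit, newdata = newdata)
ale_wt = ale_numeric(predict_fun, mtcars, "wt")
print(ale_wt)
```

Because only predictions for observations inside each interval are used, ALE avoids evaluating the model on unrealistic feature combinations, which is one reason it is a sensible default when features are correlated.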

10.2.2 Shapley

x = penguins[which(names(penguins) != "species")]
model = Predictor$new(learner, data = penguins, y = "species")
x.interest = data.frame(penguins[1, ])
shapley = Shapley$new(model, x.interest = x.interest)
plot(shapley)

Plot of the results from Shapley. \(\phi\) gives the increase or decrease in predicted probability attributed to the feature values shown on the vertical axis.

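The Shapley value \(\phi\) of a feature shows how much that feature’s value (listed on the vertical axis) increases or decreases the predicted probability of each class compared with the average prediction. For example, bill_depth_mm = 18.7 makes this penguin less likely to be a Gentoo and much more likely to be an Adelie than a Chinstrap.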
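iml estimates Shapley values by sampling rather than exactly (the number of samples is controlled by the sample.size argument of Shapley$new()). The sketch below shows the idea of such a Monte Carlo estimator for one feature: draw a random reference observation and a random feature ordering, and compare predictions with and without the feature taken from the observation of interest. shapley_mc() is a hypothetical helper, not iml's code, and a linear model on mtcars stands in for the classifier; for an additive linear model the exact Shapley value is \(\beta_j (x_j - \bar{x}_j)\), which the estimate should approach.

```r
# Monte Carlo estimate of one feature's Shapley value (hypothetical helper, not iml's code).
shapley_mc = function(predict_fun, data, x, feature, n_samples = 1000) {
  features = names(x)
  contrib = numeric(n_samples)
  for (m in seq_len(n_samples)) {
    # Random reference observation and random feature ordering
    z = data[sample(nrow(data), 1), features, drop = FALSE]
    perm = sample(features)
    before = perm[seq_len(match(feature, perm) - 1)]
    # x_minus: features preceding `feature` come from x, the rest (incl. `feature`) from z
    x_minus = z
    if (length(before) > 0) x_minus[before] = x[before]
    # x_plus: identical, except that `feature` is also taken from x
    x_plus = x_minus
    x_plus[feature] = x[feature]
    contrib[m] = predict_fun(x_plus) - predict_fun(x_minus)
  }
  mean(contrib)
}

set.seed(1)
# Stand-in model (assumption): linear regression on mtcars
fit = lm(mpg ~ wt + hp + qsec, data = mtcars)
predict_fun = function(newdata) predict(fit, newdata = newdata)
x = mtcars["Cadillac Fleetwood", c("wt", "hp", "qsec")]
phi_wt = shapley_mc(predict_fun, mtcars, x, "wt")
# For an additive linear model, the exact value is beta_wt * (x_wt - mean(wt))
exact_wt = unname(coef(fit)["wt"]) * (x$wt - mean(mtcars$wt))
c(estimate = phi_wt, exact = exact_wt)
```

Increasing n_samples reduces the Monte Carlo error at the cost of more model evaluations.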

10.2.3 FeatureImp

effect = FeatureImp$new(model, loss = "ce")
effect$plot(features = num_features)

Plot of the results from FeatureImp. FeatureImp visualizes the importance of features given the prediction model.

FeatureImp shows how important each feature is for classifying the penguins: it measures how much the loss (here the classification error, loss = "ce") increases when the values of a feature are permuted. bill_length_mm is clearly of high importance, so one should concentrate on the different boundaries of this feature when attempting to classify the three species.

10.2.4 Independent Test Data

It is also interesting to see how the model behaves on an independent test data set. For this section, as recommended in Section Train/Predict, 80% of the penguin data set is used for training and 20% for testing:

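# random 80/20 split of the row ids of the task
train_set = sample(task_peng$nrow, 0.8 * task_peng$nrow)
test_set = setdiff(seq_len(task_peng$nrow), train_set)
# retrain on the training rows only, then predict the held-out rows
learner$train(task_peng, row_ids = train_set)
prediction = learner$predict(task_peng, row_ids = test_set)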

First, we compare the feature importance on the training and test sets:

# plot on training
model = Predictor$new(learner, data = penguins[train_set, ], y = "species")
effect = FeatureImp$new(model, loss = "ce")
plot_train = plot(effect, features = num_features)

# plot on test data
model = Predictor$new(learner, data = penguins[test_set, ], y = "species")
effect = FeatureImp$new(model, loss = "ce")
plot_test = plot(effect, features = num_features)

# combine into single plot
library("patchwork")
plot_train + plot_test

FeatureImp on the training set (left) and test set (right)

The FeatureImp results on the training set are very similar to those obtained earlier on the full data set, which is expected because the model was trained on these observations. We follow a similar approach to compare the feature effects:

model = Predictor$new(learner, data = penguins[train_set, ], y = "species")
effect = FeatureEffects$new(model)
plot(effect, features = num_features)

FeatureEffects on the training set

model = Predictor$new(learner, data = penguins[test_set, ], y = "species")
effect = FeatureEffects$new(model)
plot(effect, features = num_features)

FeatureEffects on the test set

As with FeatureImp, the test set results either over- or underestimate the feature importance and feature effects compared with the results obtained on the entire penguin data set. This is not surprising: the test set contains only 67 observations, so estimates computed on it are noisier. This is a good opportunity for the reader to try to reduce the discrepancy by varying the number of features and the amount of data used in the training and test sets for FeatureImp and FeatureEffects. Be sure not to re-run or change the line train_set = sample(task_peng$nrow, 0.8 * task_peng$nrow), as doing so draws a new random split of the data.

10.3 DALEX

The DALEX (Biecek 2018) package belongs to the DrWhy family of solutions created to support the responsible development of machine learning models. It implements the most common methods for explaining predictive models using post hoc, model-agnostic techniques. You can use it for any model built with the mlr3 package as well as with other frameworks in R. Its counterpart in Python is the dalex library (Baniecki et al. 2021).

The philosophy of working with the DALEX package is based on the process of explanatory model analysis described in the EMA book (Biecek and Burzykowski 2021). In this section, we present code snippets and a general overview of this package. For illustrative purposes, we reuse the learner built in Section 10.1 on the palmerpenguins::penguins data; note that in Section 10.2.4 it was retrained on the 80% training split, and the explanations below use the 20% test split.

Once you become familiar with the philosophy of working with the DALEX package, you can also use other packages from this family such as fairmodels (Wiśniewski and Biecek 2022) for detection and mitigation of biases, modelStudio (Baniecki and Biecek 2019) for interactive model exploration, modelDown (Romaszko et al. 2019) for the automatic generation of IML model documentation in the form of a report, survex (Krzyziński et al. 2023) for the explanation of survival models, or treeshap for the analysis of tree-based models.

10.3.1 Explanatory model analysis

The analysis of a model is usually an interactive process that starts with a shallow analysis, usually a single-number summary. In a series of subsequent steps, one can then systematically deepen the understanding of the model, moving from the importance of single variables or pairs of variables to an in-depth analysis of the relationship between selected variables and the model outcome. See Bücker et al. (2022) for a broader discussion of what the model exploration process looks like.

This explanatory model analysis (EMA) process can focus on a single observation, in which case we speak of local model analysis, or on a set of observations, in which case we speak of global model analysis. Below, we present these two scenarios in separate subsections. See Figure 10.1 for an overview of the key functions discussed.

Figure 10.1: Taxonomy of methods for model exploration presented in this chapter. The left part gives an overview of methods for global-level exploration, while the right part covers local-level model exploration.

Predictive models in R have different internal structures. To analyse them systematically, an intermediate object (a wrapper) is needed that provides a consistent interface for accessing the model. Working with explanations in the DALEX package always starts with creating such a wrapper with the DALEX::explain() function. Its arguments allow models created by various frameworks to be wrapped appropriately. For models created with the mlr3 package, it is more convenient to use DALEXtra::explain_mlr3(). In the snippet below, DALEX::explain() is called directly; as its output shows, the yhat.LearnerClassif prediction method is then used for the mlr3 learner.
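To illustrate why such a wrapper helps, the sketch below builds a toy explainer in base R: a list that stores the model, the data, the observed target and a prediction function with a common signature. Downstream code only calls that function, so it works the same way for any model class. make_explainer() and its helpers are hypothetical and far simpler than DALEX::explain(); two stand-in models on mtcars (a linear regression and a constant baseline) are assumptions chosen only to show the design idea.

```r
# A toy "explainer": model + data + target + a prediction function with a fixed signature.
# make_explainer() and the helpers below are hypothetical, not part of DALEX.
make_explainer = function(model, data, y, predict_function, label) {
  stopifnot(nrow(data) == length(y))
  structure(list(model = model, data = data, y = y,
                 predict_function = predict_function, label = label),
            class = "toy_explainer")
}

# Downstream code only ever calls predict_function(model, newdata)
explainer_predict = function(explainer, newdata = explainer$data) {
  explainer$predict_function(explainer$model, newdata)
}

explainer_rmse = function(explainer) {
  sqrt(mean((explainer$y - explainer_predict(explainer))^2))
}

features = mtcars[, c("wt", "hp")]

# Model 1: a linear regression, predicted via predict.lm()
fit = lm(mpg ~ wt + hp, data = mtcars)
exp_lm = make_explainer(fit, features, mtcars$mpg,
  predict_function = function(m, newdata) unname(predict(m, newdata = newdata)),
  label = "lm")

# Model 2: a constant baseline stored as a plain list with its own prediction rule
baseline = list(mu = mean(mtcars$mpg))
exp_mean = make_explainer(baseline, features, mtcars$mpg,
  predict_function = function(m, newdata) rep(m$mu, nrow(newdata)),
  label = "mean baseline")

# The same helper works for both wrapped models
sapply(list(lm = exp_lm, mean = exp_mean), explainer_rmse)
```

The design choice is the same one the paragraph above describes: model-specific details live in one place (the prediction function), and every explanation method is written once against the common interface.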

library("DALEX")
library("DALEXtra")

ranger_exp = DALEX::explain(learner,
  data = penguins[test_set, ],
  y = penguins[test_set, "species"],
  label = "Ranger Penguins",
  colorize = FALSE)
Preparation of a new explainer is initiated
  -> model label       :  Ranger Penguins 
  -> data              :  67  rows  8  cols 
  -> data              :  tibble converted into a data.frame 
  -> target variable   :  Argument 'y' was a data frame. Converted to a vector. (  WARNING  )
  -> target variable   :  67  values 
  -> predict function  :  yhat.LearnerClassif  will be used (  default  )
  -> predicted values  :  No value for predict function target column. (  default  )
  -> model_info        :  package mlr3 , ver. 0.14.1 , task multiclass (  default  ) 
  -> predicted values  :  predict function returns multiple columns:  3  (  default  ) 
  -> residual function :  difference between 1 and probability of true class (  default  )
  -> residuals         :  numerical, min =  0 , mean =  0.07756016 , max =  0.5380321  
  A new explainer has been created!  

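The DALEX::explain() function performs a series of internal checks, so its output is rather verbose. Set verbose = FALSE to make it less wordy. The warning about the target variable appears because penguins[test_set, "species"] returns a one-column tibble rather than a vector; passing penguins$species[test_set] instead avoids it.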

10.3.2 Global level exploration

The global model analysis aims to understand how a model behaves on average on a set of observations, most commonly a test set. In the DALEX package, functions for global analysis have names starting with the prefix model_.

10.3.2.1 Model Performance

As shown in Figure 10.1, the analysis starts by evaluating the performance of the model. This can be done with a variety of tools; in the DALEX package, the default is the DALEX::model_performance() function. Since the explain() function checks what type of task is being analysed, model_performance() can select appropriate performance measures for it. In our illustration, we have a multiclass classification task, so measures such as micro-aggregated F1 and macro-aggregated F1 are calculated in the following snippet. One of the calculated measures is cross-entropy, which will be used in the following sections.

Each explanation can be drawn with the generic plot() function; for multiclass classification, the distribution of residuals is drawn by default.

perf_penguin = model_performance(ranger_exp)
perf_penguin
Measures for:  multiclass
micro_F1   : 1 
macro_F1   : 1 
w_macro_F1 : 1 
accuracy   : 1 
w_macro_auc: 1 
cross_entro: 6.034954

Residuals:
          0%          10%          20%          30%          40%          50% 
0.0000000000 0.0005846154 0.0036863492 0.0111489133 0.0315985873 0.0440341048 
         60%          70%          80%          90%         100% 
0.0535907937 0.0683762754 0.0956176783 0.2191798413 0.5380321429 
library("ggplot2")
old_theme = set_theme_dalex("ema") 
plot(perf_penguin)

The task of classifying the penguin species is rather easy, which is why there are so many values of 1 in the performance assessment of this model.
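The measures listed above can be computed from a matrix of predicted class probabilities and the observed classes. Below is a minimal base-R sketch on a small made-up example of four penguins; multiclass_measures() is a hypothetical helper, not DALEX's code. The residual is defined as 1 minus the predicted probability of the true class, as in the explainer output above. The cross-entropy is summed over observations here, which is an assumption; check the DALEX documentation for the exact definition it uses.

```r
# Multiclass measures from a probability matrix (hypothetical helper, not DALEX's code)
multiclass_measures = function(prob, observed) {
  classes = colnames(prob)
  observed = factor(observed, levels = classes)
  # Predicted class = column with the highest probability
  predicted = factor(classes[max.col(prob, ties.method = "first")], levels = classes)
  # One-vs-rest counts per class (rows: tp, fp, fn; columns: classes)
  counts = sapply(classes, function(cl) c(
    tp = sum(predicted == cl & observed == cl),
    fp = sum(predicted == cl & observed != cl),
    fn = sum(predicted != cl & observed == cl)))
  f1 = 2 * counts["tp", ] / (2 * counts["tp", ] + counts["fp", ] + counts["fn", ])
  f1[is.nan(f1)] = 0  # class neither observed nor predicted
  p_true = prob[cbind(seq_along(observed), as.integer(observed))]
  list(
    accuracy = mean(predicted == observed),
    macro_F1 = mean(f1),  # unweighted mean of per-class F1
    # micro F1 pools the counts over all classes
    micro_F1 = 2 * sum(counts["tp", ]) /
      (2 * sum(counts["tp", ]) + sum(counts["fp", ]) + sum(counts["fn", ])),
    # summed over observations (assumption); probabilities clipped to avoid log(0)
    cross_entropy = -sum(log(pmax(p_true, 1e-15))),
    residuals = 1 - p_true
  )
}

# Made-up predictions for four penguins
prob = matrix(c(0.8, 0.1, 0.1,
                0.2, 0.7, 0.1,
                0.3, 0.3, 0.4,
                0.6, 0.3, 0.1),
              ncol = 3, byrow = TRUE,
              dimnames = list(NULL, c("Adelie", "Chinstrap", "Gentoo")))
observed = c("Adelie", "Chinstrap", "Gentoo", "Chinstrap")
m = multiclass_measures(prob, observed)
str(m)
```

Three of the four penguins are classified correctly, so accuracy is 0.75; for single-label multiclass data, micro-averaged F1 always coincides with accuracy, while macro-averaged F1 (here 7/9) weights every class equally.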

10.3.2.2 Permutational Variable Importance

A popular technique for assessing variable importance in a model-agnostic manner is permutation variable importance. It is based on the difference (or ratio) between the values of a selected loss function before and after the selected variable, or set of variables, has been permuted. Read more about this technique in the Variable-importance Measures chapter.
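The procedure can be sketched in a few lines of base R. This is a minimal version assuming a regression model (lm on mtcars) and RMSE as the loss; permutation_importance() is a hypothetical helper that reports the average loss after permuting each variable next to the loss of the full model, in the spirit of DALEX's mean_dropout_loss. The difference form of importance is the dropout loss minus the full-model loss.

```r
# Permutation variable importance (hypothetical helper, minimal sketch)
permutation_importance = function(predict_fun, data, y, loss, features, n_rep = 10) {
  full_loss = loss(y, predict_fun(data))
  dropout = sapply(features, function(f) {
    losses = replicate(n_rep, {
      permuted = data
      permuted[[f]] = sample(permuted[[f]])  # break the link between f and the target
      loss(y, predict_fun(permuted))
    })
    mean(losses)  # average loss after permuting f
  })
  c(`_full_model_` = full_loss, dropout)
}

rmse = function(y, yhat) sqrt(mean((y - yhat)^2))

set.seed(1)
# Stand-in model (assumption); qsec is deliberately NOT used by the model
fit = lm(mpg ~ wt + hp, data = mtcars)
predict_fun = function(newdata) predict(fit, newdata = newdata)
imp = permutation_importance(predict_fun, mtcars, mtcars$mpg, rmse,
                             features = c("wt", "hp", "qsec"))
print(round(imp, 3))
# Difference form of the importance
print(round(imp[-1] - imp["_full_model_"], 3))
```

Because the model does not use qsec, permuting it leaves the predictions, and hence the loss, unchanged; any variable in the data that the model ignores behaves the same way.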

The DALEX::model_parts() function calculates the importance of variables and its results can be visualized with the generic plot() function.

ranger_effect = model_parts(ranger_exp)
head(ranger_effect)
       variable mean_dropout_loss           label
1  _full_model_          6.034954 Ranger Penguins
2          year          5.988560 Ranger Penguins
3       species          6.034954 Ranger Penguins
4           sex          7.002289 Ranger Penguins
5   body_mass_g         12.377824 Ranger Penguins
6 bill_depth_mm         15.617252 Ranger Penguins
plot(ranger_effect, show_boxplots = FALSE) 

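Each bar starts at the loss (here cross-entropy) of the model on the original data and ends at the loss after the selected variable has been permuted. The more important the variable, the larger the loss after its permutation. Note that species appears in the table because the data passed to explain() still contains the target column; since the learner does not use it as a feature, permuting it leaves the loss unchanged (6.034954, the same as _full_model_).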

10.3.2.3 Partial Dependence

Once we know which variables are most important, we can use Partial Dependence Plots to show how the model, on average, changes with changes in selected variables.

The DALEX::model_profile() function calculates partial dependence profiles. Its type argument also allows marginal profiles (type = "conditional") and accumulated local profiles (type = "accumulated") to be calculated. Again, the result can be visualised with the generic plot() function.
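A partial dependence value for a feature is the average prediction over the data when that feature is set to the same fixed value in every row. A minimal base-R sketch, with a hypothetical partial_dependence() helper and a linear model on mtcars standing in for the classifier (for a multiclass model, the same computation would be done separately for each class probability, giving one curve per class):

```r
# Partial dependence of the prediction on one feature (hypothetical helper)
partial_dependence = function(predict_fun, data, feature, grid) {
  sapply(grid, function(value) {
    modified = data
    modified[[feature]] = value      # set the feature to `value` in every row
    mean(predict_fun(modified))      # average over the observed values of the other features
  })
}

# Stand-in model (assumption): linear regression on mtcars
fit = lm(mpg ~ wt + hp, data = mtcars)
predict_fun = function(newdata) predict(fit, newdata = newdata)
grid = seq(min(mtcars$wt), max(mtcars$wt), length.out = 5)
pd = partial_dependence(predict_fun, mtcars, "wt", grid)
data.frame(wt = grid, pd = pd)
```

For an additive linear model the profile is a straight line with slope equal to the coefficient of wt; for a random forest, the curve can be non-linear and differ between classes.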

ranger_profiles = model_profile(ranger_exp)
ranger_profiles
Top profiles    : 
        _vname_                   _label_    _x_    _yhat_ _ids_
1 bill_depth_mm    Ranger Penguins.Adelie 13.500 0.2839077     0
2 bill_depth_mm Ranger Penguins.Chinstrap 13.500 0.1908264     0
3 bill_depth_mm    Ranger Penguins.Gentoo 13.500 0.5252659     0
4 bill_depth_mm    Ranger Penguins.Adelie 13.566 0.2839077     0
5 bill_depth_mm Ranger Penguins.Chinstrap 13.566 0.1908264     0
6 bill_depth_mm    Ranger Penguins.Gentoo 13.566 0.5252659     0
plot(ranger_profiles) + 
  theme(legend.position = "top") + 
  ggtitle("Partial Dependence for Penguins","")

For a multiclass classification model, profiles are drawn for each class separately, using different colours. We already know which variables are the most important, so now we can read how the model’s output changes as these variables change. In our example, bill_length_mm separates Adelie from Chinstrap, and flipper_length_mm separates Adelie from Gentoo.

10.3.3 Local level explanation

The local model analysis aims to understand how a model behaves for a single observation. In the DALEX package, functions for local analysis have names starting with the prefix predict_.

The following examples use Steve, a penguin of the Adelie species.

steve = penguins[1,]
steve
# A tibble: 1 × 8
  species island    bill_length_mm bill_depth_mm flipper_l…¹ body_…² sex    year
  <fct>   <fct>              <dbl>         <dbl>       <int>   <int> <fct> <int>
1 Adelie  Torgersen           39.1          18.7         181    3750 male   2007
# … with abbreviated variable names ¹​flipper_length_mm, ²​body_mass_g

10.3.3.1 Model Prediction

As shown in Figure 10.1, the local analysis starts with the calculation of a model prediction.

For Steve, the species was correctly predicted as Adelie with high probability.

predict(ranger_exp, steve)
        Adelie   Chinstrap Gentoo
[1,] 0.9900897 0.009910317      0

10.3.3.2 Break Down

A popular technique for assessing the contributions of variables to a model prediction is Break Down (see the Introduction to Break Down chapter for more information about this method).

The DALEX::predict_parts() function calculates the attributions of variables, and its results can be visualised with the generic plot() function.
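Break Down builds the attribution step by step: starting from the mean prediction over the data, it fixes the variables of the observation one at a time and records how much the mean prediction changes at each step. A minimal base-R sketch, assuming a user-supplied variable order (DALEX determines the order itself by default), a hypothetical break_down() helper, and a linear model on mtcars standing in for the classifier:

```r
# Break Down attributions for a fixed variable order (hypothetical helper)
break_down = function(predict_fun, data, x, order) {
  current = data
  intercept = mean(predict_fun(current))   # mean prediction over the data
  contributions = setNames(numeric(length(order)), order)
  previous = intercept
  for (f in order) {
    current[[f]] = x[[f]]                  # fix f at the observation's value in every row
    now = mean(predict_fun(current))
    contributions[f] = now - previous      # change in the mean prediction
    previous = now
  }
  list(intercept = intercept, contributions = contributions, prediction = previous)
}

# Stand-in model (assumption): linear regression on mtcars
fit = lm(mpg ~ wt + hp, data = mtcars)
predict_fun = function(newdata) predict(fit, newdata = newdata)
x = mtcars["Valiant", ]
bd = break_down(predict_fun, mtcars, x, order = c("wt", "hp"))
bd
```

By construction, the intercept plus all contributions adds up to the model's prediction for the observation, which is what the final bar of a Break Down plot shows.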

ranger_attributions = predict_parts(ranger_exp, new_observation = steve)
plot(ranger_attributions) + ggtitle("Break Down for Steve") 

The plot shows that, for Steve, the largest contributions to the final prediction come from bill length and flipper length.

10.3.3.3 Shapley Values

By far the most popular technique for local model exploration (Holzinger et al. 2022) is Shapley values and the most popular algorithm for estimating these values is the SHAP algorithm. Find a detailed description of the method and algorithm in the chapter SHapley Additive exPlanations (SHAP).

The DALEX::predict_parts() function also calculates SHAP attributions; you just need to set type = "shap". Its results can be visualised with the generic plot() function.
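For a small number of variables, Shapley values can be computed exactly by averaging Break Down contributions over all possible variable orderings. DALEX, as we understand it, instead averages over a random sample of orderings (controlled by its B argument), which scales to many variables. The sketch below is a hypothetical base-R version with 3 features (3! = 6 orderings) and a linear model on mtcars as a stand-in, for which the exact values are \(\beta_j (x_j - \bar{x}_j)\).

```r
# All orderings of a (short) character vector
all_orders = function(v) {
  if (length(v) <= 1) return(list(v))
  out = list()
  for (i in seq_along(v)) {
    for (rest in all_orders(v[-i])) out[[length(out) + 1]] = c(v[i], rest)
  }
  out
}

# Break Down contributions along one ordering (hypothetical helper)
bd_contributions = function(predict_fun, data, x, order) {
  current = data
  previous = mean(predict_fun(current))
  out = setNames(numeric(length(order)), order)
  for (f in order) {
    current[[f]] = x[[f]]                 # fix f at the observation's value
    now = mean(predict_fun(current))
    out[f] = now - previous
    previous = now
  }
  out
}

# Exact Shapley values: average the contributions over all p! orderings
shapley_exact = function(predict_fun, data, x, features) {
  contrib = sapply(all_orders(features), function(o) {
    bd_contributions(predict_fun, data, x, o)[features]  # align by feature name
  })
  rowMeans(contrib)
}

features = c("wt", "hp", "qsec")
# Stand-in model (assumption): linear regression on mtcars
fit = lm(mpg ~ wt + hp + qsec, data = mtcars)
predict_fun = function(newdata) predict(fit, newdata = newdata)
x = mtcars["Valiant", ]
phi = shapley_exact(predict_fun, mtcars, x, features)
phi
```

The values always sum to the prediction for the observation minus the mean prediction; the cost grows as p!, which is why sampling orderings is used in practice.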

ranger_shap = predict_parts(ranger_exp, new_observation = steve,
  type = "shap")
plot(ranger_shap, show_boxplots = FALSE) +
  ggtitle("Shapley values for Steve", "")

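The results of the Break Down and SHAP methods are generally similar. Differences emerge when the model contains strong interactions, because Break Down attributions depend on the order in which the variables are considered, while SHAP averages the attributions over many orderings.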
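A toy example makes the role of interactions concrete. For a hypothetical model \(f(x_1, x_2) = x_1 x_2\), the Break Down attributions for the same observation depend on the order in which the variables are fixed, while averaging over both orders (which, for two variables, is exactly the Shapley value) removes that dependence. The model, background data and helper below are made up for illustration:

```r
# Toy model with a pure interaction (hypothetical, for illustration only)
f = function(d) d$x1 * d$x2
background = data.frame(x1 = c(0, 1, 2), x2 = c(0, 1, 2))
x = data.frame(x1 = 2, x2 = 2)

# Break Down contributions along one variable order (hypothetical helper)
bd_contributions = function(predict_fun, data, x, order) {
  current = data
  previous = mean(predict_fun(current))
  out = setNames(numeric(length(order)), order)
  for (v in order) {
    current[[v]] = x[[v]]               # fix v at the observation's value
    now = mean(predict_fun(current))
    out[v] = now - previous
    previous = now
  }
  out
}

bd_12 = bd_contributions(f, background, x, c("x1", "x2"))[c("x1", "x2")]
bd_21 = bd_contributions(f, background, x, c("x2", "x1"))[c("x1", "x2")]
# With two variables, averaging both orders gives the exact Shapley values
shap = (bd_12 + bd_21) / 2
rbind(order_x1_first = bd_12, order_x2_first = bd_21, shapley = shap)
```

The mean prediction over the background data is 5/3 and f(x) = 4. Whichever variable is fixed first receives only 1/3 and the second receives 2, so Break Down gives opposite answers for the two orders, whereas the Shapley values assign 7/6 to each variable.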

10.3.3.4 Ceteris Paribus

In the global-level exploration section, we introduced Partial Dependence plots. Ceteris Paribus plots are the local-level version of that plot. They show the response of a model when only one variable is changed while the others stay unchanged. These profiles are also called Individual Conditional Expectations (ICE); read more about this technique in the Ceteris Paribus chapter.

The function DALEX::predict_profile() calculates Ceteris paribus profiles which can be visualized with the generic plot() function.
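A ceteris paribus profile is obtained by copying the observation, varying one variable over a grid and predicting. Averaging the profiles of all observations gives the partial dependence curve, which is why PD plots are the global counterpart of these local profiles. A minimal base-R sketch with a hypothetical ceteris_paribus() helper and a linear model on mtcars standing in for the classifier:

```r
# Ceteris paribus (ICE) profile: vary one feature, keep the others fixed (hypothetical helper)
ceteris_paribus = function(predict_fun, x, feature, grid) {
  profile = x[rep(1, length(grid)), , drop = FALSE]   # one copy of x per grid value
  profile[[feature]] = grid                           # only this feature changes
  data.frame(value = grid, prediction = unname(predict_fun(profile)))
}

# Stand-in model (assumption): linear regression on mtcars
fit = lm(mpg ~ wt + hp, data = mtcars)
predict_fun = function(newdata) predict(fit, newdata = newdata)
grid = seq(1.5, 5.5, by = 0.5)
cp_valiant = ceteris_paribus(predict_fun, mtcars["Valiant", ], "wt", grid)
cp_valiant

# Averaging the profiles of all observations gives the partial dependence curve
ice = sapply(seq_len(nrow(mtcars)), function(i) {
  ceteris_paribus(predict_fun, mtcars[i, ], "wt", grid)$prediction
})
pd = rowMeans(ice)
```

For a model with interactions, the individual profiles can have different shapes, which the averaged partial dependence curve hides.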

ranger_ceteris = predict_profile(ranger_exp, steve)
plot(ranger_ceteris) + ggtitle("Ceteris paribus for Steve", " ") 

The blue dot marks the prediction for Steve. Only a large change in bill length would make the model predict a different species for Steve.