No black-box model without XAI. This is where packages like {flashlight} come in.
It offers the following XAI methods:

- light_performance(): Performance metrics like RMSE and/or \(R^2\)
- light_importance(): Permutation variable importance (Fisher, Rudin, and Dominici 2018)
- light_ice(): Individual conditional expectation (ICE) profiles (Goldstein et al. 2015), centered or uncentered
- light_profile(): Partial dependence (Friedman 2001), accumulated local effects (ALE) (Apley and Zhu 2016), average predicted/observed/residual
- light_profile2d(): Two-dimensional version of light_profile()
- light_effects(): Combines partial dependence, ALE, response and prediction profiles
- light_interaction(): Different variants of Friedman’s H statistics (Friedman and Popescu 2008)
- light_breakdown(): Variable contribution breakdown (approximate SHAP) for single observations (Gosiewska and Biecek 2019)
- light_global_surrogate(): Global surrogate trees (Molnar 2019)

Good to know:

- Each method operates on an explainer object called flashlight (see examples and Section “flashlights”).
- Multiple explainers can be combined via multiflashlight().
- plot() visualizes the results via {ggplot2}.

# From CRAN
install.packages("flashlight")
# Development version
devtools::install_github("mayer79/flashlight")

Let’s start with an iris example. For simplicity, we do not split the data into training and testing/validation sets.
library(ggplot2)
library(MetricsWeighted)
library(flashlight)

fit_lm <- lm(Sepal.Length ~ ., data = iris)

# Make explainer object
fl_lm <- flashlight(
  model = fit_lm,
  data = iris,
  y = "Sepal.Length",
  label = "lm",
  metrics = list(RMSE = rmse, `R-squared` = r_squared)
)

fl_lm |>
  light_performance() |>
  plot(fill = "darkred") +
  labs(x = element_blank(), title = "Performance on training data")

fl_lm |>
  light_performance(by = "Species") |>
  plot(fill = "darkred") +
  ggtitle("Performance split by Species")
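
To make the two metrics concrete, here is a minimal base-R sketch that recomputes RMSE and R-squared by hand for the same linear model, overall and per species. The helper names `rmse_manual()` and `r_squared_manual()` are illustrative and not part of {flashlight} or {MetricsWeighted}; the sketch ignores case weights.

```r
# Refit the model from the example above
fit_lm <- lm(Sepal.Length ~ ., data = iris)
pred <- predict(fit_lm, iris)
obs <- iris$Sepal.Length

# Hypothetical, unweighted helpers
rmse_manual <- function(actual, predicted) {
  sqrt(mean((actual - predicted)^2))
}
r_squared_manual <- function(actual, predicted) {
  1 - sum((actual - predicted)^2) / sum((actual - mean(actual))^2)
}

# Metrics on the full data
overall <- c(
  RMSE = rmse_manual(obs, pred),
  `R-squared` = r_squared_manual(obs, pred)
)

# Same metrics within each species (one column per species)
by_species <- sapply(split(seq_len(nrow(iris)), iris$Species), function(idx) {
  c(
    RMSE = rmse_manual(obs[idx], pred[idx]),
    `R-squared` = r_squared_manual(obs[idx], pred[idx])
  )
})

print(overall)
print(by_species)
```

Within a group, R-squared is measured against that group's own mean response, so per-species values can be much lower than the overall value.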
Error bars represent standard errors, i.e., the uncertainty of the estimated importance.

fl_lm |>
  light_importance(m_repetitions = 4) |>
  plot(fill = "darkred") +
  labs(title = "Permutation importance", y = "Increase in RMSE")
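
The idea behind permutation importance can be sketched in a few lines of base R: shuffle one feature at a time, recompute the RMSE, and record the increase over the unshuffled baseline. This is an illustration of the technique under simple assumptions (no case weights, the full data set, a hypothetical helper `rmse_manual()`), not a reproduction of the package internals.

```r
set.seed(1)  # permutations are random, so fix the seed

fit_lm <- lm(Sepal.Length ~ ., data = iris)
rmse_manual <- function(actual, predicted) sqrt(mean((actual - predicted)^2))

features <- setdiff(names(iris), "Sepal.Length")
baseline <- rmse_manual(iris$Sepal.Length, predict(fit_lm, iris))
m_repetitions <- 4

# Rows = repetitions, columns = features
increase <- sapply(features, function(v) {
  replicate(m_repetitions, {
    shuffled <- iris
    shuffled[[v]] <- sample(shuffled[[v]])  # break the link between v and the response
    rmse_manual(iris$Sepal.Length, predict(fit_lm, shuffled)) - baseline
  })
})

# Mean increase in RMSE and its standard error across repetitions
importance <- colMeans(increase)
std_error <- apply(increase, 2, sd) / sqrt(m_repetitions)

print(sort(importance, decreasing = TRUE))
print(std_error)
```

Repeating the shuffle several times is what makes a standard error available in the first place; with a single repetition there is nothing to estimate it from.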

fl_lm |>
  light_ice("Sepal.Width", n_max = 200) |>
  plot(alpha = 0.3, color = "chartreuse4") +
  labs(title = "ICE curves for 'Sepal.Width'", y = "Prediction")

fl_lm |>
  light_ice("Sepal.Width", n_max = 200, center = "middle") |>
  plot(alpha = 0.3, color = "chartreuse4") +
  labs(title = "c-ICE curves for 'Sepal.Width'", y = "Prediction (centered)")

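
As a rough illustration of what an ICE curve is, the following base-R sketch varies 'Sepal.Width' over a grid while keeping all other columns of each observation fixed. Centering at the middle grid point and the grid size of 21 are assumptions made for this sketch; averaging the curves yields the partial dependence profile.

```r
fit_lm <- lm(Sepal.Length ~ ., data = iris)
v <- "Sepal.Width"
grid <- seq(min(iris[[v]]), max(iris[[v]]), length.out = 21)

# One row per observation, one column per grid value
ice <- sapply(grid, function(value) {
  X <- iris
  X[[v]] <- value  # overwrite the feature for all observations
  predict(fit_lm, X)
})

# Centered ICE: subtract each curve's value at the middle grid point
mid <- ceiling(length(grid) / 2)
c_ice <- ice - ice[, mid]

# Averaging the ICE curves gives the partial dependence profile
pdp <- colMeans(ice)

print(dim(ice))
print(round(head(c_ice[, c(1, mid, length(grid))]), 3))
print(round(pdp, 3))
```

Because the linear model has no interaction terms, all ICE curves are parallel and the centered curves coincide. Curves that fan out after centering would hint at interactions.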
### PDPs

fl_lm |>
  light_profile("Sepal.Width", n_bins = 40) |>
  plot() +
  ggtitle("PDP for 'Sepal.Width'")

fl_lm |>
  light_profile("Sepal.Width", n_bins = 40, by = "Species") |>
  plot() +
  ggtitle("Same grouped by 'Species'")

fl_lm |>
  light_profile2d(c("Petal.Width", "Petal.Length")) |>
  plot()

fl_lm |>
  light_profile("Sepal.Width", type = "ale") |>
  plot() +
  ggtitle("ALE plot for 'Sepal.Width'")
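
For intuition on accumulated local effects, here is a simplified first-order ALE computation in base R. The five quantile bins and the uncentered accumulation are choices made for this sketch; it is not meant to mirror the package's implementation.

```r
fit_lm <- lm(Sepal.Length ~ ., data = iris)
v <- "Sepal.Width"

# Bin edges taken from observed values (type = 1 avoids interpolation)
breaks <- unique(quantile(iris[[v]], probs = seq(0, 1, length.out = 6),
                          type = 1, names = FALSE))
bin <- cut(iris[[v]], breaks = breaks, include.lowest = TRUE, labels = FALSE)

# Average change in prediction when moving v from the lower to the upper
# bin edge, using only the observations that fall into that bin
local_effect <- sapply(seq_len(length(breaks) - 1), function(k) {
  lower <- upper <- iris[bin == k, , drop = FALSE]
  lower[[v]] <- breaks[k]
  upper[[v]] <- breaks[k + 1]
  mean(predict(fit_lm, upper) - predict(fit_lm, lower))
})

# Accumulate the local effects, starting at 0 at the left-most edge
ale <- c(0, cumsum(local_effect))
print(data.frame(edge = breaks, ale = ale))
```

Unlike partial dependence, each local effect only uses observations that actually lie in the bin, which avoids evaluating the model on unrealistic feature combinations. For this additive linear model, the ALE slope equals the coefficient of 'Sepal.Width'.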

fl_lm |>
  light_effects("Sepal.Width") |>
  plot(use = "all") +
  ggtitle("Different types of profiles for 'Sepal.Width'")

fl_lm |>
  light_breakdown(new_obs = iris[1, ]) |>
  plot()
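
The breakdown idea can be sketched as follows: starting from the average prediction, set one feature after another to the value of the observation of interest in the whole data set and track how the average prediction moves. The fixed visit order used here (column order) is an assumption of this sketch; the result generally depends on the order unless the model is additive.

```r
fit_lm <- lm(Sepal.Length ~ ., data = iris)
new_obs <- iris[1, ]
features <- setdiff(names(iris), "Sepal.Length")  # visit order: column order

X <- iris
baseline <- mean(predict(fit_lm, X))
current <- baseline
contribution <- numeric(0)

for (v in features) {
  # Set feature v to the observation's value in every row
  X[[v]] <- rep(new_obs[[v]], nrow(X))
  updated <- mean(predict(fit_lm, X))
  contribution[v] <- updated - current
  current <- updated
}

print(c(baseline = baseline, contribution, prediction = current))
```

By construction, the baseline plus all contributions adds up to the model's prediction for the selected observation.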

fl_lm |>
  light_global_surrogate() |>
  plot()

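
A global surrogate can be illustrated by fitting a small decision tree to the predictions of the model instead of the observed response. The sketch below uses {rpart} directly; the depth limit of 3 and the R-squared summary are choices made for this illustration.

```r
library(rpart)

fit_lm <- lm(Sepal.Length ~ ., data = iris)

# Replace the observed response by the model's predictions
surrogate_data <- iris[setdiff(names(iris), "Sepal.Length")]
surrogate_data$prediction <- predict(fit_lm, iris)

# Shallow tree that mimics the linear model
surrogate <- rpart(
  prediction ~ .,
  data = surrogate_data,
  control = rpart.control(maxdepth = 3)
)

# Share of the prediction variance that the tree reproduces
approx_pred <- predict(surrogate, surrogate_data)
r2 <- 1 - sum((surrogate_data$prediction - approx_pred)^2) /
  sum((surrogate_data$prediction - mean(surrogate_data$prediction))^2)

print(surrogate)
print(r2)
```

A high R-squared suggests that the tree's splits are a fair summary of the original model; a low value means the tree should not be read as an explanation of it.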
### Multiple models
Multiple flashlights can be combined to a multiflashlight.
library(rpart)

fit_tree <- rpart(
  Sepal.Length ~ .,
  data = iris,
  control = list(cp = 0, xval = 0, maxdepth = 5)
)

# Make explainer object
fl_tree <- flashlight(
  model = fit_tree,
  data = iris,
  y = "Sepal.Length",
  label = "tree",
  metrics = list(RMSE = rmse, `R-squared` = r_squared)
)

# Combine with other explainer
fls <- multiflashlight(list(fl_tree, fl_lm))

fls |>
  light_performance() |>
  plot(fill = "chartreuse4") +
  labs(x = "Model", title = "Performance")

fls |>
  light_importance() |>
  plot(fill = "chartreuse4") +
  labs(y = "Increase in RMSE", title = "Permutation importance")

fls |>
  light_profile("Petal.Length", n_bins = 40) |>
  plot() +
  ggtitle("PDP")

fls |>
  light_profile("Petal.Length", n_bins = 40, by = "Species") |>
  plot() +
  ggtitle("PDP by Species")

The “flashlight” explainer expects the following information:

- model: Fitted model. Currently, this argument must be named.
- data: Reference data on which the methods are evaluated, often part of the validation data.
- y: Column name in data corresponding to the numeric response.
- predict_function: Function with the same signature as stats::predict(). It takes a model and a data.frame data, and provides numeric predictions; see below for more details.
- linkinv: Optional function applied to the output of predict_function(). Should actually be called “trafo”.
- w: Optional column name in data corresponding to case weights.
- by: Optional column name in data used to group the results. Must be discrete.
- metrics: List of metrics, by default list(rmse = MetricsWeighted::rmse). For binary (probabilistic) classification, good candidate metrics would be MetricsWeighted::logLoss.
- label: Mandatory name of the model.

### predict_functions (a selection)

The default stats::predict() works for models of class

- lm(),
- glm() (for predictions on link scale), and
- rpart().

It also works for meta-learner models.

Manual prediction functions are, e.g., required for

- models whose predict() returns an object with a predictions element (e.g., {ranger}): Use function(m, X) predict(m, X)$predictions for regression, and function(m, X) predict(m, X)$predictions[, 2] for probabilistic binary classification
- glm(): Use function(m, X) predict(m, X, type = "response") to get GLM predictions at the response scale

A bit more complicated are models whose native predict function does not work on data.frames.

Example (XGBoost):
This works when non-numeric features are all factors (not categoricals):

# x: character vector with the names of the features used by the model
predict_function = function(m, df) predict(m, data.matrix(df[x]))
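
The following base-R sketch illustrates two of the points above without requiring any additional package: the difference between link-scale and response-scale GLM predictions, and what data.matrix() does to factor columns. The binary target `virginica`, the wrapper names `pred_link` and `pred_response`, and the feature vector `x` are made up for this illustration.

```r
# Binary target for illustration: is the flower a virginica?
df <- transform(iris, virginica = as.integer(Species == "virginica"))
fit_glm <- glm(virginica ~ Sepal.Length + Sepal.Width, data = df, family = binomial())

# Two prediction functions with the signature function(m, X)
pred_link <- function(m, X) predict(m, X)                         # link (logit) scale
pred_response <- function(m, X) predict(m, X, type = "response")  # probabilities

p_link <- pred_link(fit_glm, df)
p_resp <- pred_response(fit_glm, df)

# Applying the inverse link to link-scale predictions gives the response scale
print(all.equal(plogis(p_link), p_resp))

# For models that need a numeric matrix: data.matrix() turns factors
# into their integer codes
x <- c("Sepal.Width", "Petal.Length", "Species")
X_mat <- data.matrix(iris[x])
str(X_mat)
print(table(X_mat[, "Species"]))
```

The integer codes follow the order of the factor levels, so the factor levels in the data passed to the explainer should match those used when the model was fitted.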