The goal of this package is to shed light on black box machine learning models.
The main props of {flashlight}:
Currently, models with numeric or binary response are supported.
# From CRAN
install.packages("flashlight")

# Development version (installed from GitHub via devtools)
devtools::install_github("mayer79/flashlight")
Let’s start with an iris example. For simplicity, we do not split the data into training and testing/validation sets.
library(ggplot2)
library(MetricsWeighted)
library(flashlight)
# Linear model of Sepal.Length on all remaining iris columns
fit_lm <- lm(Sepal.Length ~ ., data = iris)
# Make explainer object: bundles the fitted model with the data, the name of
# the response column, a display label, and the performance metrics to use
fl_lm <- flashlight(
  model = fit_lm,
  data = iris,
  y = "Sepal.Length",
  label = "lm",
  metrics = list(RMSE = rmse, `R-squared` = r_squared)
)
# Plot model performance on the training data
fl_lm |>
  light_performance() |>
  plot(fill = "darkred") +
  labs(x = element_blank(), title = "Performance on training data")
# Same performance plot, split by the Species column
fl_lm |>
  light_performance(by = "Species") |>
  plot(fill = "darkred") +
  ggtitle("Performance split by Species")
Error bars represent standard errors, i.e., the uncertainty of the estimated importance.
# Permutation importance, requested with 4 repetitions (m_repetitions = 4)
fl_lm |>
  light_importance(m_repetitions = 4) |>
  plot(fill = "darkred") +
  labs(title = "Permutation importance", y = "Increase in RMSE")
Petal.Width
# ICE curves for Sepal.Width (n_max = 200)
fl_lm |>
  light_ice("Sepal.Width", n_max = 200) |>
  plot(alpha = 0.3, color = "chartreuse4") +
  labs(title = "ICE curves for 'Sepal.Width'", y = "Prediction")
# Centered ICE (c-ICE) curves for Sepal.Width, using center = "middle"
fl_lm |>
  light_ice("Sepal.Width", n_max = 200, center = "middle") |>
  plot(alpha = 0.3, color = "chartreuse4") +
  labs(title = "c-ICE curves for 'Sepal.Width'", y = "Prediction (centered)")
# Partial dependence profile (PDP) for Sepal.Width with n_bins = 40
fl_lm |>
  light_profile("Sepal.Width", n_bins = 40) |>
  plot() +
  ggtitle("PDP for 'Sepal.Width'")
# Same PDP for Sepal.Width, grouped by Species
fl_lm |>
  light_profile("Sepal.Width", n_bins = 40, by = "Species") |>
  plot() +
  ggtitle("Same grouped by 'Species'")
# Two-dimensional profile over Petal.Width and Petal.Length
fl_lm |>
  light_profile2d(c("Petal.Width", "Petal.Length")) |>
  plot()
# ALE profile for Sepal.Width (type = "ale")
fl_lm |>
  light_profile("Sepal.Width", type = "ale") |>
  plot() +
  ggtitle("ALE plot for 'Sepal.Width'")
# Several profile types for Sepal.Width in one plot (use = "all")
fl_lm |>
  light_effects("Sepal.Width") |>
  plot(use = "all") +
  ggtitle("Different types of profiles for 'Sepal.Width'")
# Breakdown of the prediction for a single observation (first row of iris)
fl_lm |>
  light_breakdown(new_obs = iris[1, ]) |>
  plot()
# Plot a global surrogate for the explained model
fl_lm |>
  light_global_surrogate() |>
  plot()
Multiple flashlights can be combined into a multiflashlight.
library(rpart)
# Regression tree of Sepal.Length on all remaining iris columns;
# cp = 0 and xval = 0 are passed as rpart controls, with depth capped at 5
fit_tree <- rpart(
  Sepal.Length ~ .,
  data = iris,
  control = list(cp = 0, xval = 0, maxdepth = 5)
)
# Make explainer object for the tree, with the same data, response,
# and metrics as used for the linear model explainer
fl_tree <- flashlight(
  model = fit_tree,
  data = iris,
  y = "Sepal.Length",
  label = "tree",
  metrics = list(RMSE = rmse, `R-squared` = r_squared)
)
# Combine with other explainer into a single multiflashlight
fls <- multiflashlight(list(fl_tree, fl_lm))
# Performance of the combined explainers, plotted per model
fls |>
  light_performance() |>
  plot(fill = "chartreuse4") +
  labs(x = "Model", title = "Performance")
# PDP for Petal.Length from the combined explainers, grouped by Species
fls |>
  light_profile("Petal.Length", n_bins = 40, by = "Species") |>
  plot() +
  ggtitle("PDP by Species")
Check out the vignette for more information and important references.