Slightly more compact README #146

Merged · 2 commits · Sep 7, 2024
6 changes: 6 additions & 0 deletions NEWS.md
@@ -1,3 +1,9 @@
# kernelshap 0.7.1

## Documentation

- More compact README.

# kernelshap 0.7.0

This release is intended to be the last before stable version 1.0.0.
53 changes: 20 additions & 33 deletions README.md
@@ -13,34 +13,21 @@

## Overview

The package contains two workhorses to calculate SHAP values for *any* model:
The package contains three functions to crunch SHAP values:

- `permshap()`: Exact permutation SHAP algorithm of [1]. Recommended for up to 8-10 features.
- `kernelshap()`: Kernel SHAP algorithm of [2] and [3]. By default, exact Kernel SHAP is used for up to $p=8$ features, and an almost exact hybrid algorithm otherwise.
- `permshap()`: Exact permutation SHAP algorithm of [1]. Recommended for models with up to 8 features.
- `kernelshap()`: Kernel SHAP algorithm of [2] and [3]. Recommended for models with more than 8 features.
- `additive_shap()`: For *additive models* fitted via `lm()`, `glm()`, `mgcv::gam()`, `mgcv::bam()`, `gam::gam()`, `survival::coxph()`, or `survival::survreg()`. Exponentially faster than the model-agnostic options above, and recommended if possible.

Furthermore, the function `additive_shap()` produces SHAP values for additive models fitted via `lm()`, `glm()`, `mgcv::gam()`, `mgcv::bam()`, `gam::gam()`,
`survival::coxph()`, or `survival::survreg()`. It is exponentially faster than `permshap()` and `kernelshap()`, with identical results when the background dataset of the latter equals the full training data.
To explain your model, select an explanation dataset `X` (up to 1000 rows from the training data) and apply the recommended function. Use {shapviz} to visualize the resulting SHAP values.
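
To see `additive_shap()` in action, here is a minimal hedged sketch (the model, data, and {shapviz} calls are illustrative, not part of this PR):

```r
library(kernelshap)
library(shapviz)

# Additive model: no interactions, so additive_shap() applies
fit <- lm(Sepal.Length ~ ., data = iris)
X <- iris[, -1]  # explanation data with feature columns only

shap <- additive_shap(fit, X)  # exponentially faster than permshap()/kernelshap()
sv_importance(shapviz(shap))   # visualize with {shapviz}
```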

### Kernel SHAP or permutation SHAP?
**Remarks for `permshap()` and `kernelshap()`**

Kernel SHAP was introduced in [2] as an approximation of permutation SHAP [1]. For up to ten features, exact calculations are realistic for both algorithms. Since exact Kernel SHAP is still only an approximation of exact permutation SHAP, the latter should be preferred in this case, even if the results are often very similar.

A situation where the two approaches give different results: The model has interactions of order three or higher.
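
As a hedged sketch of this effect (illustrative model and data; `$S` holds the matrix of SHAP values):

```r
library(kernelshap)

# Three-way interaction: exact Kernel SHAP and exact permutation SHAP can disagree
fit <- lm(Sepal.Length ~ Sepal.Width * Petal.Length * Petal.Width, data = iris)
X <- iris[, 2:4]

ps <- permshap(fit, X, bg_X = X)
ks <- kernelshap(fit, X, bg_X = X)
max(abs(ps$S - ks$S))  # typically nonzero here, unlike for interactions of order two
```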

### Typical workflow to explain any model

1. **Sample rows to explain:** Sample 500 to 2000 rows `X` to be explained. If the training dataset is small, simply use the full training data for this purpose. `X` should only contain feature columns.
2. **Select background data (optional):** Both algorithms require a representative background dataset `bg_X` to calculate marginal means. For this purpose, set aside 50 to 500 rows from the training data.
If the training data is small, use the full training data. In cases with a natural "off" value (like MNIST digits), this can also be a single row with all values set to the off value. If not specified, at most `bg_n = 200` rows are randomly sampled from `X`.
3. **Crunch:** Use `kernelshap(object, X, bg_X = NULL, ...)` or `permshap(object, X, bg_X = NULL, ...)` to calculate SHAP values. Runtime is proportional to `nrow(X)`, while memory consumption scales linearly in `nrow(bg_X)`.
4. **Analyze:** Use {shapviz} to visualize the results.

**Remarks**

- Multivariate predictions are handled at no additional computational cost.
- Case weights are supported via the argument `bg_w`.
- By changing the defaults in `kernelshap()`, the iterative pure sampling approach in [3] can be enforced.
- The `additive_shap()` explainer is easier to use: Only the model and `X` are required.
- `X` should only contain feature columns.
- Both algorithms need a representative background dataset `bg_X` to calculate marginal means (up to 500 rows from the training data). In cases with a natural "off" value (like MNIST digits), this can also be a single row with all values set to the off value. If unspecified, 200 rows are randomly sampled from `X` (see the sketch below this list).
- By changing the defaults in `kernelshap()`, the iterative pure sampling approach of [3] can be enforced.
- `permshap()` vs. `kernelshap()`: For models with interactions of order up to two, exact Kernel SHAP agrees with exact permutation SHAP.
- `additive_shap()` vs. the model-agnostic explainers: The results agree if the full training data is used as background data.
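
A hedged sketch tying these remarks together (`fit`, `X`, `train`, and the `weight` column are placeholders; `exact` and `hybrid_degree` are the assumed switches for pure sampling):

```r
# Representative background data: up to 500 rows from the training data
bg_X <- train[sample(nrow(train), 200), ]

# Alternative with a natural "off" value (e.g., MNIST): one all-zero row
bg_off <- X[1, , drop = FALSE]
bg_off[1, ] <- 0

ps <- permshap(fit, X, bg_X = bg_X)

# Case weights via bg_w
ks <- kernelshap(fit, X, bg_X = bg_X, bg_w = bg_X$weight)

# Enforce the iterative pure sampling approach of [3]
ks_sampling <- kernelshap(fit, X, bg_X = bg_X, exact = FALSE, hybrid_degree = 0)
```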

## Installation

@@ -64,7 +51,7 @@
library(shapviz)

diamonds <- transform(
diamonds,
log_price = log(price),
log_carat = log(carat)
)

@@ -82,11 +69,10 @@
fit # OOB R-squared 0.989
set.seed(10)
X <- diamonds[sample(nrow(diamonds), 1000), xvars]

# 2) Optional: Select background data. If not specified, a random sample of 200 rows
# from X is used
# 2) Optional: Select background data. If unspecified, 200 rows from X are used
bg_X <- diamonds[sample(nrow(diamonds), 200), ]

# 3) Crunch SHAP values for all 1000 rows of X (22 seconds)
# 3) Crunch SHAP values (22 seconds)
# Note: Since the number of features is small, we use permshap()
system.time(
ps <- permshap(fit, X, bg_X = bg_X)
@@ -99,15 +85,15 @@
ps
# [2,] -0.4931989 -0.11724773 0.09868921 0.028563613

# Kernel SHAP gives almost the same:
system.time( # 49 s
system.time( # 22 s
ks <- kernelshap(fit, X, bg_X = bg_X)
)
ks
# log_carat clarity color cut
# [1,] 1.1911791 0.0900462 -0.13531648 0.001845958
# [2,] -0.4927482 -0.1168517 0.09815062 0.028255442

# 4) Analyze with our sister package {shapviz}
# 4) Analyze with {shapviz}
ps <- shapviz(ps)
sv_importance(ps)
sv_dependence(ps, xvars)
@@ -123,9 +109,9 @@

### Parallel computing

Parallel computing is supported via {foreach}. Note that this does not work with all models, and that there is no progress bar.
Parallel computing for `permshap()` and `kernelshap()` is supported via {foreach}. Note that this does not work for all models.

On Windows, sometimes not all packages or global objects are passed to the parallel sessions. Often, this can be fixed via `parallel_args`, see the generalized additive model below.
On Windows, sometimes not all packages or global objects are passed to the parallel sessions. Often, this can be fixed via `parallel_args`, see this example:

```r
library(doFuture)
registerDoFuture()
plan(multisession, workers = 4) # Windows
# plan(multicore, workers = 4) # Linux, macOS, Solaris

# GAM with interactions - we cannot use additive_shap()
fit <- gam(log_price ~ s(log_carat) + clarity * color + cut, data = diamonds)

system.time( # 4 seconds in parallel
  ps <- permshap(fit, X, bg_X = bg_X, parallel = TRUE,
                 parallel_args = list(.packages = "mgcv"))  # assumed completion: the diff collapses these lines
)
```

@@ -174,7 +161,7 @@

In this {keras} example, we show how to use a tailored `predict()` function that
- uses sufficiently large batches, and
- turns off the Keras progress bar.

The results are not fully reproducible though.
(The results are not fully reproducible.)

```r
library(keras)
# ... (rest of the example is collapsed in the diff)
```
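
The tailored `predict()` function might look like this hedged sketch (`fit`, `X`, and `bg_X` come from the collapsed example; the batch size is an assumption):

```r
# Large batches and verbose = 0 silence the Keras progress bar
pred_fun <- function(model, X) {
  predict(model, data.matrix(X), batch_size = 10000, verbose = 0)
}

ks <- kernelshap(fit, X, bg_X = bg_X, pred_fun = pred_fun)
```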