
Commit 9808196

Merge pull request #132 from ModelOriented/additive-shap
Additive shap
2 parents 559ba5c + d1c6958 commit 9808196

23 files changed (+387 -176 lines)

DESCRIPTION

+2 -2

@@ -1,6 +1,6 @@
 Package: kernelshap
 Title: Kernel SHAP
-Version: 0.4.2
+Version: 0.5.0
 Authors@R: c(
     person("Michael", "Mayer", , "[email protected]", role = c("aut", "cre")),
     person("David", "Watson", , "[email protected]", role = "aut"),
@@ -19,7 +19,7 @@ Depends:
     R (>= 3.2.0)
 Encoding: UTF-8
 Roxygen: list(markdown = TRUE)
-RoxygenNote: 7.2.3
+RoxygenNote: 7.3.1
 Imports:
     foreach,
     stats,

NAMESPACE

+1 -3

@@ -5,11 +5,9 @@ S3method(kernelshap,ranger)
 S3method(permshap,default)
 S3method(permshap,ranger)
 S3method(print,kernelshap)
-S3method(print,permshap)
 S3method(summary,kernelshap)
-S3method(summary,permshap)
+export(additive_shap)
 export(is.kernelshap)
-export(is.permshap)
 export(kernelshap)
 export(permshap)
 importFrom(foreach,"%dopar%")

NEWS.md

+23

@@ -1,3 +1,26 @@
+# kernelshap 0.5.0
+
+## New features
+
+New additive explainer `additive_shap()` that works for models fitted via
+
+- `lm()`,
+- `glm()`,
+- `mgcv::gam()`,
+- `mgcv::bam()`,
+- `gam::gam()`,
+- `survival::coxph()`,
+- `survival::survreg()`.
+
+The explainer uses `predict(..., type = "terms")`, a beautiful trick
+used in `fastshap::explain.lm()`. The results are identical to those returned by `kernelshap()` and `permshap()`, but the computation is exponentially faster. Thanks to David Watson for the great idea, discussed in [#130](https://github.com/ModelOriented/kernelshap/issues/130).
+
+## User visible changes
+
+- `permshap()` now returns an object of class "kernelshap" to reduce the number of redundant methods.
+- To distinguish which algorithm has generated a "kernelshap" object, the outputs of `kernelshap()`, `permshap()`, and `additive_shap()` now contain an element "algorithm".
+- `is.permshap()` has been removed.
+
 # kernelshap 0.4.2
 
 ## API
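The `predict(..., type = "terms")` trick from the news entry above can be illustrated in a few lines. The following sketch is not part of the commit; it assumes a plain `lm()` fit on `iris` and shows that the per-term contributions plus the "constant" attribute reproduce the model's predictions.

```r
# Illustration only (not part of this commit): the type = "terms" trick.
fit <- lm(Sepal.Length ~ Sepal.Width + Petal.Length + Species, data = iris)

# Matrix of per-term contributions for the rows to explain
tm <- predict(fit, newdata = head(iris), type = "terms")

# The "constant" attribute plays the role of the SHAP baseline
b <- as.vector(attr(tm, "constant"))

# Baseline + row sums of the term contributions equal the usual predictions
all.equal(as.vector(b + rowSums(tm)), as.vector(predict(fit, newdata = head(iris))))
#> TRUE
```

Because the model is additive, each term contribution is already the exact SHAP value of its feature (under feature independence), so no sampling over feature coalitions is needed.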

R/additive_shap.R

+95

@@ -0,0 +1,95 @@
+#' Additive SHAP
+#'
+#' Exact additive SHAP assuming feature independence. The implementation
+#' works for models fitted via
+#' - [lm()],
+#' - [glm()],
+#' - [mgcv::gam()],
+#' - [mgcv::bam()],
+#' - [gam::gam()],
+#' - [survival::coxph()], and
+#' - [survival::survreg()].
+#'
+#' The SHAP values are extracted via `predict(object, newdata = X, type = "terms")`,
+#' a logic heavily inspired by `fastshap:::explain.lm(..., exact = TRUE)`.
+#' Models with interactions (specified via `:` or `*`), or with terms of
+#' multiple features like `log(x1/x2)` are not supported.
+#'
+#' @inheritParams kernelshap
+#' @param X Dataframe with rows to be explained. Will be used like
+#'   `predict(object, newdata = X, type = "terms")`.
+#' @param ... Currently unused.
+#' @returns
+#'   An object of class "kernelshap" with the following components:
+#'   - `S`: \eqn{(n \times p)} matrix with SHAP values.
+#'   - `X`: Same as input argument `X`.
+#'   - `baseline`: The baseline.
+#'   - `exact`: `TRUE`.
+#'   - `txt`: Summary text.
+#'   - `predictions`: Vector with predictions of `X` on the scale of "terms".
+#'   - `algorithm`: "additive_shap".
+#' @export
+#' @examples
+#' # MODEL ONE: Linear regression
+#' fit <- lm(Sepal.Length ~ ., data = iris)
+#' s <- additive_shap(fit, head(iris))
+#' s
+#'
+#' # MODEL TWO: More complicated (but not very clever) formula
+#' fit <- lm(
+#'   Sepal.Length ~ poly(Sepal.Width, 2) + log(Petal.Length) + log(Sepal.Width),
+#'   data = iris
+#' )
+#' s <- additive_shap(fit, head(iris))
+#' s
+additive_shap <- function(object, X, verbose = TRUE, ...) {
+  stopifnot(
+    inherits(object, c("lm", "glm", "gam", "bam", "Gam", "coxph", "survreg"))
+  )
+  if (any(attr(stats::terms(object), "order") > 1)) {
+    stop("Additive SHAP not appropriate for models with interactions.")
+  }
+
+  txt <- "Exact additive SHAP via predict(..., type = 'terms')"
+  if (verbose) {
+    message(txt)
+  }
+
+  S <- stats::predict(object, newdata = X, type = "terms")
+  rownames(S) <- NULL
+
+  # Baseline value
+  b <- as.vector(attr(S, "constant"))
+  if (is.null(b)) {
+    b <- 0
+  }
+
+  # Which columns of X are used in each column of S?
+  s_names <- colnames(S)
+  cols_used <- lapply(s_names, function(z) all.vars(stats::reformulate(z)))
+  if (any(lengths(cols_used) > 1L)) {
+    stop("The formula contains terms with multiple features (not supported).")
+  }
+
+  # Collapse all columns in S using the same column in X and rename accordingly
+  mapping <- split(
+    s_names, factor(unlist(cols_used), levels = colnames(X)), drop = TRUE
+  )
+  S <- do.call(
+    cbind,
+    lapply(mapping, function(z) rowSums(S[, z, drop = FALSE], na.rm = TRUE))
+  )
+
+  structure(
+    list(
+      S = S,
+      X = X,
+      baseline = b,
+      exact = TRUE,
+      txt = txt,
+      predictions = b + rowSums(S),
+      algorithm = "additive_shap"
+    ),
+    class = "kernelshap"
+  )
+}
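To make the collapsing step in `additive_shap()` concrete, here is a short sketch (not part of the commit) mirroring the "MODEL TWO" example from the roxygen block: `all.vars(reformulate(.))` maps each term column of the `type = "terms"` matrix back to its underlying feature, and columns that share a feature are summed.

```r
# Sketch (illustration only): mapping term columns back to features.
fit <- lm(
  Sepal.Length ~ poly(Sepal.Width, 2) + log(Petal.Length) + log(Sepal.Width),
  data = iris
)
tm <- predict(fit, newdata = head(iris), type = "terms")

# Term labels such as "poly(Sepal.Width, 2)" and "log(Sepal.Width)" both involve
# only Sepal.Width; all.vars() on a one-sided formula recovers that feature name.
sapply(colnames(tm), function(z) all.vars(reformulate(z)))

# additive_shap() sums all term columns mapping to the same feature, so the
# resulting SHAP matrix S has exactly one column per used column of X.
```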

R/kernelshap.R

+3 -1

@@ -141,6 +141,7 @@
 #' - `exact`: Logical flag indicating whether calculations are exact or not.
 #' - `txt`: Summary text.
 #' - `predictions`: \eqn{(n \times K)} matrix with predictions of `X`.
+#' - `algorithm`: "kernelshap".
 #' @references
 #' 1. Scott M. Lundberg and Su-In Lee. A unified approach to interpreting model
 #'   predictions. Proceedings of the 31st International Conference on Neural
@@ -318,7 +319,8 @@ kernelshap.default <- function(object, X, bg_X, pred_fun = stats::predict,
     prop_exact = prop_exact,
     exact = exact || trunc(p / 2) == hybrid_degree,
     txt = txt,
-    predictions = v1
+    predictions = v1,
+    algorithm = "kernelshap"
   )
   class(out) <- "kernelshap"
   out

R/methods.R

+4 -48

@@ -16,21 +16,6 @@ print.kernelshap <- function(x, n = 2L, ...) {
   invisible(x)
 }
 
-#' Prints "permshap" Object
-#'
-#' @param x An object of class "permshap".
-#' @inheritParams print.kernelshap
-#' @inherit print.kernelshap return
-#' @export
-#' @examples
-#' fit <- lm(Sepal.Length ~ ., data = iris)
-#' s <- permshap(fit, iris[1:3, -1], bg_X = iris[, -1])
-#' s
-#' @seealso [permshap()]
-print.permshap <- function(x, n = 2L, ...) {
-  print.kernelshap(x, n = n, ...)
-}
-
 #' Summarizes "kernelshap" Object
 #'
 #' @param object An object of class "kernelshap".
@@ -67,7 +52,10 @@ summary.kernelshap <- function(object, compact = FALSE, n = 2L, ...) {
       "\n - m/iter:", getElement(object, "m")
     )
   }
-  cat("\n - m_exact:", getElement(object, "m_exact"))
+  m_exact <- getElement(object, "m_exact")
+  if (!is.null(m_exact)) {
+    cat("\n - m_exact:", m_exact)
+  }
   if (!compact) {
     cat("\n\nSHAP values of first observations:\n")
     print(head_list(S, n = n))
@@ -79,21 +67,6 @@ summary.kernelshap <- function(object, compact = FALSE, n = 2L, ...) {
   invisible(object)
 }
 
-#' Summarizes "permshap" Object
-#'
-#' @param object An object of class "permshap".
-#' @inheritParams summary.kernelshap
-#' @inherit summary.kernelshap return
-#' @export
-#' @examples
-#' fit <- lm(Sepal.Length ~ ., data = iris)
-#' s <- permshap(fit, iris[1:3, -1], bg_X = iris[, -1])
-#' summary(s)
-#' @seealso [permshap()]
-summary.permshap <- function(object, compact = FALSE, n = 2L, ...) {
-  summary.kernelshap(object, compact = compact, n = n, ...)
-}
-
 #' Check for kernelshap
 #'
 #' Is object of class "kernelshap"?
@@ -110,20 +83,3 @@ summary.permshap <- function(object, compact = FALSE, n = 2L, ...) {
 is.kernelshap <- function(object){
   inherits(object, "kernelshap")
 }
-
-#' Check for permshap
-#'
-#' Is object of class "permshap"?
-#'
-#' @param object An R object.
-#' @returns `TRUE` if `object` is of class "permshap", and `FALSE` otherwise.
-#' @export
-#' @examples
-#' fit <- lm(Sepal.Length ~ ., data = iris)
-#' s <- permshap(fit, iris[1:2, -1], bg_X = iris[, -1])
-#' is.permshap(s)
-#' is.permshap("a")
-#' @seealso [kernelshap()]
-is.permshap <- function(object){
-  inherits(object, "permshap")
-}
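The `m_exact` guard above matters because `additive_shap()` objects carry no `m_exact` element. A brief sketch (not part of the commit) of the resulting behavior:

```r
library(kernelshap)

fit <- lm(Sepal.Length ~ ., data = iris)
s <- additive_shap(fit, head(iris), verbose = FALSE)

# getElement() returns NULL for missing list elements, so summary() now simply
# skips the "m_exact" line for additive explanations instead of printing NULL.
is.null(getElement(s, "m_exact"))  # TRUE
summary(s)
```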

R/permshap.R

+5 -3

@@ -5,7 +5,7 @@
 #'
 #' @inheritParams kernelshap
 #' @returns
-#' An object of class "permshap" with the following components:
+#' An object of class "kernelshap" with the following components:
 #' - `S`: \eqn{(n \times p)} matrix with SHAP values or, if the model output has
 #'   dimension \eqn{K > 1}, a list of \eqn{K} such matrices.
 #' - `X`: Same as input argument `X`.
@@ -16,6 +16,7 @@
 #'   (currently `TRUE`).
 #' - `txt`: Summary text.
 #' - `predictions`: \eqn{(n \times K)} matrix with predictions of `X`.
+#' - `algorithm`: "permshap".
 #' @references
 #' 1. Erik Strumbelj and Igor Kononenko. Explaining prediction models and individual
 #'   predictions with feature contributions. Knowledge and Information Systems 41, 2014.
@@ -141,9 +142,10 @@ permshap.default <- function(object, X, bg_X, pred_fun = stats::predict,
     m_exact = m_exact,
     exact = TRUE,
     txt = txt,
-    predictions = v1
+    predictions = v1,
+    algorithm = "permshap"
   )
-  class(out) <- "permshap"
+  class(out) <- "kernelshap"
   out
 }
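Since `permshap()` now returns a "kernelshap" object, code that previously relied on `is.permshap()` or the "permshap" class can switch to the new "algorithm" element. A small sketch, not part of the commit, reusing the call from the removed examples:

```r
library(kernelshap)

fit <- lm(Sepal.Length ~ ., data = iris)
ps <- permshap(fit, iris[1:3, -1], bg_X = iris[, -1])

class(ps)          # "kernelshap": shared class for all three explainers
ps$algorithm       # "permshap": identifies the generating algorithm
is.kernelshap(ps)  # TRUE; is.permshap() has been removed
```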

R/utils.R

+2 -1

@@ -226,7 +226,8 @@ case_p1 <- function(n, feature_names, v0, v1, X, verbose) {
     prop_exact = 1,
     exact = TRUE,
     txt = txt,
-    predictions = v1
+    predictions = v1,
+    algorithm = "kernelshap"
   )
   class(out) <- "kernelshap"
   out

README.md

+17 -1

@@ -13,11 +13,14 @@
 
 ## Overview
 
-The package contains two workhorses to calculate SHAP values for any model:
+The package contains two workhorses to calculate SHAP values for *any* model:
 
 - `permshap()`: Exact permutation SHAP algorithm of [1]. Available for up to $p=14$ features.
 - `kernelshap()`: Kernel SHAP algorithm of [2] and [3]. By default, exact Kernel SHAP is used for up to $p=8$ features, and an almost exact hybrid algorithm otherwise.
 
+Furthermore, the function `additive_shap()` produces SHAP values for additive models fitted via `lm()`, `glm()`, `mgcv::gam()`, `mgcv::bam()`, `gam::gam()`,
+`survival::coxph()`, or `survival::survreg()`. It is exponentially faster than `permshap()` and `kernelshap()`, with identical results.
+
 ### Kernel SHAP or permutation SHAP?
 
 Kernel SHAP has been introduced in [2] as an approximation of permutation SHAP [1]. For up to ten features, exact calculations are realistic for both algorithms. Since exact Kernel SHAP is still only an approximation of exact permutation SHAP, the latter should be preferred in this case, even if the results are often very similar.
@@ -38,6 +41,7 @@ If the training data is small, use the full training data. In cases with a natur
 - Factor-valued predictions are automatically turned into one-hot-encoded columns.
 - Case weights are supported via the argument `bg_w`.
 - By changing the defaults in `kernelshap()`, the iterative pure sampling approach in [3] can be enforced.
+- The `additive_shap()` explainer is easier to use: Only the model and `X` are required.
 
 ## Installation
 
@@ -215,6 +219,18 @@ sv_dependence(ps, xvars)
 
 ![](man/figures/README-nn-dep.svg)
 
+### Additive SHAP
+
+The additive explainer extracts the additive contribution of each feature from a model of suitable class.
+
+```r
+fit <- lm(log(price) ~ log(carat) + color + clarity + cut, data = diamonds)
+shap_values <- additive_shap(fit, diamonds) |>
+  shapviz()
+sv_importance(shap_values)
+sv_dependence(shap_values, v = "carat", color_var = NULL)
+```
+
 ### Multi-output models
 
 {kernelshap} supports multivariate predictions like:
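For completeness, the README snippet above assumes that {ggplot2} (for the `diamonds` data) and {shapviz} are attached. A self-contained version of the same sketch, not part of the commit, might look like this:

```r
library(ggplot2)    # provides the diamonds data
library(shapviz)
library(kernelshap)

fit <- lm(log(price) ~ log(carat) + color + clarity + cut, data = diamonds)
shap_values <- additive_shap(fit, diamonds) |>
  shapviz()
sv_importance(shap_values)
sv_dependence(shap_values, v = "carat", color_var = NULL)
```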

backlog/test_additive_shap.R

+67

@@ -0,0 +1,67 @@
+# Some tests that need contributed packages
+
+library(mgcv)
+library(gam)
+library(survival)
+library(splines)
+library(testthat)
+
+formulas_ok <- list(
+  Sepal.Length ~ Sepal.Width + Petal.Width + Species,
+  Sepal.Length ~ log(Sepal.Width) + poly(Petal.Width, 2) + ns(Petal.Length, 2),
+  Sepal.Length ~ log(Sepal.Width) + poly(Sepal.Width, 2)
+)
+
+formulas_bad <- list(
+  Sepal.Length ~ Species * Petal.Length,
+  Sepal.Length ~ Species + Petal.Length + Species:Petal.Length,
+  Sepal.Length ~ log(Petal.Length / Petal.Width)
+)
+
+models <- list(mgcv::gam, mgcv::bam, gam::gam)
+
+X <- head(iris)
+for (formula in formulas_ok) {
+  for (model in models) {
+    fit <- model(formula, data = iris)
+    s <- additive_shap(fit, X = X, verbose = FALSE)
+    expect_equal(s$predictions, as.vector(predict(fit, newdata = X)))
+  }
+}
+
+for (formula in formulas_bad) {
+  for (model in models) {
+    fit <- model(formula, data = iris)
+    expect_error(s <- additive_shap(fit, X = X, verbose = FALSE))
+  }
+}
+
+# Survival
+iris$s <- rep(1, nrow(iris))
+formulas_ok <- list(
+  Surv(Sepal.Length, s) ~ Sepal.Width + Petal.Width + Species,
+  Surv(Sepal.Length, s) ~ log(Sepal.Width) + poly(Petal.Width, 2) + ns(Petal.Length, 2),
+  Surv(Sepal.Length, s) ~ log(Sepal.Width) + poly(Sepal.Width, 2)
+)
+
+formulas_bad <- list(
+  Surv(Sepal.Length, s) ~ Species * Petal.Length,
+  Surv(Sepal.Length, s) ~ Species + Petal.Length + Species:Petal.Length,
+  Surv(Sepal.Length, s) ~ log(Petal.Length / Petal.Width)
+)
+
+models <- list(survival::coxph, survival::survreg)
+
+for (formula in formulas_ok) {
+  for (model in models) {
+    fit <- model(formula, data = iris)
+    s <- additive_shap(fit, X = X, verbose = FALSE)
+  }
+}
+
+for (formula in formulas_bad) {
+  for (model in models) {
+    fit <- model(formula, data = iris)
+    expect_error(s <- additive_shap(fit, X = X, verbose = FALSE))
+  }
+}
