Skip to content

Commit 5f65a24

Browse files
authored
Merge pull request #145 from ModelOriented/update-readme
Update time benchmarks
2 parents 29f0de1 + 22aee87 commit 5f65a24

File tree

4 files changed

+24
-35
lines changed

4 files changed

+24
-35
lines changed

DESCRIPTION

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
Package: kernelshap
22
Title: Kernel SHAP
3-
Version: 0.7.0
3+
Version: 0.7.1
44
Authors@R: c(
55
person("Michael", "Mayer", , "[email protected]", role = c("aut", "cre"),
66
comment = c(ORCID = "0009-0007-2540-9629")),

README.md

+9-7
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ X <- diamonds[sample(nrow(diamonds), 1000), xvars]
8686
# from X is used
8787
bg_X <- diamonds[sample(nrow(diamonds), 200), ]
8888

89-
# 3) Crunch SHAP values for all 1000 rows of X (54 seconds)
89+
# 3) Crunch SHAP values for all 1000 rows of X (22 seconds)
9090
# Note: Since the number of features is small, we use permshap()
9191
system.time(
9292
ps <- permshap(fit, X, bg_X = bg_X)
@@ -137,8 +137,10 @@ plan(multisession, workers = 4) # Windows
137137

138138
fit <- gam(log_price ~ s(log_carat) + clarity * color + cut, data = diamonds)
139139

140-
system.time( # 9 seconds in parallel
141-
ps <- permshap(fit, X, parallel = TRUE, parallel_args = list(.packages = "mgcv"))
140+
system.time( # 4 seconds in parallel
141+
ps <- permshap(
142+
fit, X, bg_X = bg_X, parallel = TRUE, parallel_args = list(.packages = "mgcv")
143+
)
142144
)
143145
ps
144146

@@ -148,7 +150,7 @@ ps
148150
# [2,] -0.51546 -0.1174766 0.11122775 0.030243973
149151

150152
# Because there are no interactions of order above 2, Kernel SHAP gives the same:
151-
system.time( # 27 s non-parallel
153+
system.time( # 13 s non-parallel
152154
ks <- kernelshap(fit, X, bg_X = bg_X)
153155
)
154156
all.equal(ps$S, ks$S)
@@ -202,9 +204,9 @@ nn |>
202204
)
203205

204206
pred_fun <- function(mod, X)
205-
predict(mod, data.matrix(X), batch_size = 1e4, verbose = FALSE)
207+
predict(mod, data.matrix(X), batch_size = 1e4, verbose = FALSE, workers = 4)
206208

207-
system.time( # 60 s
209+
system.time( # 50 s
208210
ps <- permshap(nn, X, bg_X = bg_X, pred_fun = pred_fun)
209211
)
210212

@@ -284,7 +286,7 @@ iris_wf <- workflow() |>
284286
fit <- iris_wf |>
285287
fit(iris)
286288

287-
system.time( # 4s
289+
system.time( # 3s
288290
ps <- permshap(fit, iris[-5], type = "prob")
289291
)
290292
ps

backlog/compare_with_python.R

+11-24
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,9 @@ bg_X <- diamonds[seq(1, nrow(diamonds), 450), ]
1414
# Subset of 1018 diamonds to explain
1515
X_small <- diamonds[seq(1, nrow(diamonds), 53), c("carat", ord)]
1616

17-
# Exact KernelSHAP (5s)
17+
# Exact KernelSHAP (2s)
1818
system.time(
19-
ks <- kernelshap(fit, X_small, bg_X = bg_X)
19+
ks <- kernelshap(fit, X_small, bg_X = bg_X)
2020
)
2121
ks
2222

@@ -25,9 +25,9 @@ ks
2525
# [1,] -2.050074 -0.28048747 0.1281222 0.01587382
2626
# [2,] -2.085838 0.04050415 0.1283010 0.03731644
2727

28-
# Pure sampling version takes a bit longer (12 seconds)
28+
# Pure sampling version takes a bit longer (7 seconds)
2929
system.time(
30-
ks2 <- kernelshap(fit, X_small, bg_X = bg_X, exact = FALSE, hybrid_degree = 0)
30+
ks2 <- kernelshap(fit, X_small, bg_X = bg_X, exact = FALSE, hybrid_degree = 0)
3131
)
3232
ks2
3333

@@ -36,18 +36,6 @@ ks2
3636
# [1,] -2.050074 -0.28048747 0.1281222 0.01587382
3737
# [2,] -2.085838 0.04050415 0.1283010 0.03731644
3838

39-
# Using parallel backend
40-
library("doFuture")
41-
42-
registerDoFuture()
43-
plan(multisession, workers = 2) # Windows
44-
# plan(multicore, workers = 2) # Linux, macOS, Solaris
45-
46-
# 3 seconds
47-
system.time(
48-
ks3 <- kernelshap(fit, X_small, bg_X = bg_X, parallel = TRUE)
49-
)
50-
ks3
5139

5240
library(shapviz)
5341

@@ -58,18 +46,17 @@ sv_dependence(sv, "carat")
5846
# More features (but nonsensical model)
5947
# Fit model
6048
fit <- lm(
61-
log(price) ~ log(carat) * (clarity + color + cut) + x + y + z + table + depth,
49+
log(price) ~ log(carat) * (clarity + color + cut) + x + y + z + table + depth,
6250
data = diamonds
6351
)
6452

6553
# Subset of 1018 diamonds to explain
6654
X_small <- diamonds[seq(1, nrow(diamonds), 53), setdiff(names(diamonds), "price")]
6755

68-
# Exact KernelSHAP on X_small, using X_small as background data
69-
# (58/67(?) seconds for exact, 25/18 for hybrid deg 2, 16/9 for hybrid deg 1,
70-
# 26/17 for pure sampling; second number with 2 parallel sessions on Windows)
56+
# Exact KernelSHAP on X_small, using X_small as background data
57+
# (39s for exact, 15s for hybrid deg 2, 8s for hybrid deg 1, 16s for sampling)
7158
system.time(
72-
ks <- kernelshap(fit, X_small, bg_X = bg_X)
59+
ks <- kernelshap(fit, X_small, bg_X = bg_X)
7360
)
7461
ks
7562

@@ -98,7 +85,7 @@ X = diamonds[x].to_numpy()
9885

9986
# Fit model with interactions and dummy variables
10087
fit = ols(
101-
"np.log(price) ~ np.log(carat) * (C(clarity) + C(cut) + C(color))", # + x + y + z + table + depth",
88+
"np.log(price) ~ np.log(carat) * (C(clarity) + C(cut) + C(color))", # + x + y + z + table + depth",
10289
data=diamonds
10390
).fit()
10491

@@ -110,7 +97,7 @@ X_small = X[0:len(X):53]
11097

11198
# Calculate KernelSHAP values
11299
ks = KernelExplainer(
113-
model=lambda X: fit.predict(pd.DataFrame(X, columns=x)),
100+
model=lambda X: fit.predict(pd.DataFrame(X, columns=x)),
114101
data = bg_X
115102
)
116103
sv = ks.shap_values(X_small) # 11 minutes
@@ -127,4 +114,4 @@ sv[0:2]
127114
# -1.72078182e-01, 1.33027467e-03, -6.44569296e-03],
128115
# [-1.87670887e+00, 3.93291219e-02, 1.26654599e-01,
129116
# 3.85695742e-02, -4.87177593e-04, -4.20263565e-04,
130-
# -1.73988040e-01, 1.39779179e-03, -6.56062359e-03]])
117+
# -1.73988040e-01, 1.39779179e-03, -6.56062359e-03]])

packaging.R

+3-3
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,13 @@ library(usethis)
1515
use_description(
1616
fields = list(
1717
Title = "Kernel SHAP",
18-
Version = "0.7.0",
18+
Version = "0.7.1",
1919
Description = "Efficient implementation of Kernel SHAP, see Lundberg and Lee (2017),
2020
and Covert and Lee (2021) <http://proceedings.mlr.press/v130/covert21a>.
2121
Furthermore, for up to 14 features, exact permutation SHAP values can be calculated.
2222
The package plays well together with meta-learning packages like 'tidymodels', 'caret' or 'mlr3'.
2323
Visualizations can be done using the R package 'shapviz'.",
24-
`Authors@R` =
24+
`Authors@R` =
2525
"c(person('Michael', family='Mayer', role=c('aut', 'cre'), email='[email protected]', comment=c(ORCID='0009-0007-2540-9629')),
2626
person('David', family='Watson', role='aut', email='[email protected]', comment=c(ORCID='0000-0001-9632-2159')),
2727
person('Przemyslaw', family='Biecek', email='[email protected]', role='ctb', comment=c(ORCID='0000-0001-8423-1823'))
@@ -98,7 +98,7 @@ install(upgrade = FALSE)
9898
if (FALSE) {
9999
check_win_devel()
100100
check_rhub()
101-
101+
102102
# Takes long
103103
revdepcheck::revdep_check(num_workers = 4L, bioc = FALSE)
104104

0 commit comments

Comments
 (0)