Skip to content

Commit

Permalink
Merge pull request #114 from mayer79/ice-plot-facets
Browse files Browse the repository at this point in the history
Ice plot facets
  • Loading branch information
mayer79 authored Feb 3, 2024
2 parents b60c28b + 14321f1 commit 3c37e6f
Show file tree
Hide file tree
Showing 7 changed files with 62 additions and 65 deletions.
6 changes: 3 additions & 3 deletions CRAN-SUBMISSION
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
Version: 1.1.1
Date: 2023-12-06 07:21:05 UTC
SHA: 7ee87f42cc26b33f15e8944e7019d22f58c81350
Version: 1.1.2
Date: 2024-02-03 15:31:07 UTC
SHA: 9eebdab2e583bf97d51dff9e7783df2b87751097
4 changes: 2 additions & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,13 @@ Depends:
R (>= 3.2.0)
Encoding: UTF-8
Roxygen: list(markdown = TRUE)
RoxygenNote: 7.2.3
RoxygenNote: 7.3.1
Imports:
ggplot2,
stats,
utils
Suggests:
testthat (>= 3.0.0)
Config/testthat/edition: 3
URL: https://github.com/mayer79/hstats
URL: https://github.com/mayer79/hstats, https://mayer79.github.io/hstats/
BugReports: https://github.com/mayer79/hstats/issues
4 changes: 4 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# hstats 1.1.2

## ICE plots

- The ICE plot of a multioutput model without BY variable will now be using facets (instead of color). Use `swap_dim = TRUE` for the old behavior.

## API

- {mlr3}: Non-probabilistic classification now works.
Expand Down
7 changes: 2 additions & 5 deletions R/ice.R
Original file line number Diff line number Diff line change
Expand Up @@ -258,11 +258,8 @@ plot.ice <- function(x, center = FALSE, alpha = 0.2,
data <- poor_man_stack(data, to_stack = pred_names)

# Distinguish all possible cases
grp <- if (is.null(by_names) && K > 1L) "varying_" else by_names[1L] # can be NULL
wrp <- if (!is.null(by_names) && K > 1L) "varying_"
if (length(by_names) == 2L) {
wrp <- by_names[2L]
}
grp <- if (!is.null(by_names)) by_names[1L]
wrp <- if (K > 1L) "varying_" else if (length(by_names) == 2L) by_names[2L]
if (swap_dim) {
tmp <- grp
grp <- wrp
Expand Down
26 changes: 9 additions & 17 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ ice(fit, v = "Petal.Length", X = iris, BY = "Petal.Width") |>

### LightGBM

Note: Versions from 4.0.0 upwards to not anymore require passing `reshape = TRUE` to the prediction function.
Note: Versions < 4.0.0 require passing `reshape = TRUE` to the prediction function.

```r
library(lightgbm)
Expand All @@ -294,30 +294,22 @@ fit <- lgb.train(
nrounds = 1000
)

# Check that predictions require reshape = TRUE to be a matrix
predict(fit, head(X_train, 2), reshape = TRUE)
# [,1] [,2] [,3]
# [1,] 0.9999997 2.918695e-07 2.858720e-14
# [2,] 0.9999999 1.038470e-07 7.337221e-10

# mlogloss: 9.331699e-05
average_loss(fit, X = X_valid, y = y_valid, loss = "mlogloss", reshape = TRUE)
average_loss(fit, X = X_valid, y = y_valid, loss = "mlogloss")

perm_importance(
fit, X = X_valid, y = y_valid, loss = "mlogloss", reshape = TRUE, m_rep = 100
)
perm_importance(fit, X = X_valid, y = y_valid, loss = "mlogloss", m_rep = 100)
# Permutation importance regarding mlogloss
# Petal.Length Petal.Width Sepal.Width Sepal.Length
# 2.624241332 1.011168660 0.082477177 0.009757393

partial_dep(fit, v = "Petal.Length", X = X_train, reshape = TRUE) |>
partial_dep(fit, v = "Petal.Length", X = X_train) |>
plot(show_points = FALSE)

ice(fit, v = "Petal.Length", X = X_train, reshape = TRUE) |>
plot(swap_dim = TRUE, alpha = 0.05)
ice(fit, v = "Petal.Length", X = X_train) |>
plot(alpha = 0.05)

# Interaction statistics, including three-way stats
(H <- hstats(fit, X = X_train, reshape = TRUE, threeway_m = 4))
(H <- hstats(fit, X = X_train, threeway_m = 4))
# 0.3010446 0.4167927 0.1623982

plot(H, ncol = 1)
Expand All @@ -327,7 +319,7 @@ plot(H, ncol = 1)

### XGBoost

Also here, mind the `reshape = TRUE` sent to the prediction function.
Mind the `reshape = TRUE` sent to the prediction function.

```r
library(xgboost)
Expand Down Expand Up @@ -379,7 +371,7 @@ plot(H, normalize = FALSE, squared = FALSE, facet_scales = "free_y", ncol = 1)

### Non-probabilistic classification

When predictions are factor levels, {hstats} uses internal one-hot-encoding.
When predictions are factor levels, {hstats} uses internal one-hot-encoding. Usually, probabilistic classification makes more sense though.

```r
library(ranger)
Expand Down
58 changes: 29 additions & 29 deletions backlog/benchmark.R
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ fit <- xgb.train(
)

# Interpret via {hstats}
average_loss(fit, X = X_valid, y = y_valid) # 0.0247 MSE -> 0.157 RMSE
average_loss(fit, X = X_valid, y = y_valid) # 0.023 MSE

perm_importance(fit, X = X_valid, y = y_valid) |>
plot()
Expand Down Expand Up @@ -116,11 +116,11 @@ bench::mark(
min_iterations = 3
)

# expression min median `itr/sec` mem_alloc `gc/sec` n_itr n_gc total_time
# iml 1.58s 1.58s 0.631 209.4MB 2.73 3 13 4.76s
# dalex 566.21ms 586.91ms 1.72 34.6MB 0.572 3 1 1.75s
# flashlight 587.03ms 613.15ms 1.63 27.1MB 1.63 3 3 1.84s
# hstats 353.78ms 360.57ms 2.79 27.2MB 0 3 0 1.08s
# expression min median `itr/sec` mem_alloc `gc/sec` n_itr n_gc total_time
# 1 iml 1.72s 1.75s 0.574 210.6MB 1.34 3 7 5.23s <NULL>
# 2 dalex 744.82ms 760.02ms 1.31 35.2MB 0.877 3 2 2.28s <NULL>
# 3 flashlight 1.29s 1.35s 0.742 63MB 0.990 3 4 4.04s <NULL>
# 4 hstats 407.26ms 412.31ms 2.43 26.5MB 0 3 0 1.23s <NULL>

# Partial dependence (cont)
v <- "tot_lvg_area"
Expand All @@ -132,12 +132,12 @@ bench::mark(
check = FALSE,
min_iterations = 3
)
# expression min median `itr/sec` mem_alloc `gc/sec` n_itr n_gc total_time
# iml 1.11s 1.13s 0.887 376.3MB 3.84 3 13 3.38s
# dalex 782.13ms 783.08ms 1.24 192.8MB 2.90 3 7 2.41s
# flashlight 367.73ms 372.5ms 2.68 67.9MB 2.68 3 3 1.12s
# hstats 220.88ms 222.5ms 4.50 14.2MB 0 3 0 666.33ms

# expression min median `itr/sec` mem_alloc `gc/sec` n_itr n_gc total_time
# 1 iml 1.14s 1.16s 0.861 376.7MB 3.73 3 13 3.48s <NULL>
# 2 dalex 653.24ms 654.51ms 1.35 192.8MB 2.24 3 5 2.23s <NULL>
# 3 flashlight 352.34ms 361.79ms 2.72 66.7MB 0.906 3 1 1.1s <NULL>
# 4 hstats 239.03ms 242.79ms 4.04 14.2MB 1.35 3 1 743.43ms <NULL>
# Partial dependence (discrete)
v <- "structure_quality"
bench::mark(
Expand All @@ -148,11 +148,11 @@ bench::mark(
check = FALSE,
min_iterations = 3
)
# expression min median `itr/sec` mem_alloc `gc/sec` n_itr n_gc total_time
# iml 90ms 96ms 10.6 13.29MB 7.06 3 2 283ms
# dalex 170.6ms 174.4ms 5.73 20.55MB 2.87 2 1 349ms
# flashlight 40.8ms 43.8ms 23.1 6.36MB 2.10 11 1 476ms
# hstats 23.5ms 24.4ms 40.6 1.53MB 2.14 19 1 468ms
# expression min median `itr/sec` mem_alloc `gc/sec` n_itr n_gc total_time
# 1 iml 100.6ms 103.6ms 9.46 13.34MB 0 5 0 529ms <NULL>
# 2 dalex 172.4ms 177.9ms 5.62 20.55MB 2.81 2 1 356ms <NULL>
# 3 flashlight 43.5ms 45.5ms 21.9 6.36MB 2.19 10 1 457ms <NULL>
# 4 hstats 25.3ms 25.8ms 37.9 1.54MB 2.10 18 1 475ms <NULL>

# H-Stats -> we use a subset of 500 rows
X_v500 <- X_valid[1:500, ]
Expand All @@ -167,17 +167,17 @@ system.time( # 135s for all combinations of latitude
iml_pairwise <- Interaction$new(mod500, grid.size = 500, feature = "latitude")
)

# flashlight: 14s total, doing only one pairwise calculation, otherwise would take 63s
system.time( # 12s
# flashlight: 13s total, doing only one pairwise calculation, otherwise would take 63s
system.time( # 11.5s
fl_overall <- light_interaction(fl500, v = x, grid_size = Inf, n_max = Inf)
)
system.time( # 2s
system.time( # 2.4s
fl_pairwise <- light_interaction(
fl500, v = coord, grid_size = Inf, n_max = Inf, pairwise = TRUE
)
)

# hstats: 3.4s total
# hstats: 3.5s total
system.time({
H <- hstats(fit, v = x, X = X_v500, n_max = Inf)
hstats_overall <- h2_overall(H, squared = FALSE, zero = FALSE)
Expand All @@ -193,26 +193,26 @@ system.time(
# Overall statistics correspond exactly
iml_overall$results |> filter(.interaction > 1e-6)
# .feature .interaction
# 1: latitude 0.2458269
# 2: longitude 0.2458269
# 1: latitude 0.2791144
# 2: longitude 0.2791144

fl_overall$data |> subset(value > 0, select = c(variable, value))
# variable value
# 1 latitude 0.246
# 2 longitude 0.246
# 1 latitude 0.279
# 2 longitude 0.279

hstats_overall
# longitude latitude
# 0.2458269 0.2458269
# 0.2791144 0.2791144

# Pairwise results match as well
iml_pairwise$results |> filter(.interaction > 1e-6)
# .feature .interaction
# 1: longitude:latitude 0.3942526
# 1: longitude:latitude 0.4339574

fl_pairwise$data |> subset(value > 0, select = c(variable, value))
# latitude:longitude 0.394
# latitude:longitude 0.434

hstats_pairwise
# latitude:longitude
# 0.3942526
# 0.4339574
22 changes: 13 additions & 9 deletions cran-comments.md
Original file line number Diff line number Diff line change
@@ -1,17 +1,21 @@
# hstats 1.1.1
# hstats 1.1.2

Hello CRAN
Hello CRAN team

This is a small release with a bugfix and some speed improvements.
This is a small release with two convenient API improvements.

## Local checks: 1 NOTE
## Local checks: 0 errors, 0 warnings, 0 notes

- Skipping checking math rendering: package 'V8' unavailable

## Rhub: 2 NOTES
## Rhub: 3 NOTES (sounding harmless)

- Skipping checking HTML validation: no command 'tidy' found
- Skipping checking math rendering: package 'V8' unavailable
- checking HTML version of manual ... NOTE
Skipping checking math rendering: package 'V8' unavailable
- checking for non-standard things in the check directory ... NOTE
Found the following files/directories:
''NULL''
- checking for detritus in the temp directory ... NOTE
Found the following files/directories:
'lastMiKTeXException'

## Winbuilder

Expand Down

0 comments on commit 3c37e6f

Please sign in to comment.