Ice plot facets #114

Merged (3 commits) on Feb 3, 2024
6 changes: 3 additions & 3 deletions CRAN-SUBMISSION
@@ -1,3 +1,3 @@
-Version: 1.1.1
-Date: 2023-12-06 07:21:05 UTC
-SHA: 7ee87f42cc26b33f15e8944e7019d22f58c81350
+Version: 1.1.2
+Date: 2024-02-03 15:31:07 UTC
+SHA: 9eebdab2e583bf97d51dff9e7783df2b87751097
4 changes: 2 additions & 2 deletions DESCRIPTION
@@ -18,13 +18,13 @@ Depends:
R (>= 3.2.0)
Encoding: UTF-8
Roxygen: list(markdown = TRUE)
-RoxygenNote: 7.2.3
+RoxygenNote: 7.3.1
Imports:
ggplot2,
stats,
utils
Suggests:
testthat (>= 3.0.0)
Config/testthat/edition: 3
-URL: https://github.com/mayer79/hstats
+URL: https://github.com/mayer79/hstats, https://mayer79.github.io/hstats/
BugReports: https://github.com/mayer79/hstats/issues
4 changes: 4 additions & 0 deletions NEWS.md
@@ -1,5 +1,9 @@
# hstats 1.1.2

+## ICE plots
+
+- The ICE plot of a multioutput model without BY variable will now be using facets (instead of color). Use `swap_dim = TRUE` for the old behavior.
+
## API

- {mlr3}: Non-probabilistic classification now works.
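
A minimal sketch of what the ICE plot entry in NEWS.md means in practice, assuming {hstats} >= 1.1.2 is installed. The multi-output linear model and the object names (`fit`, `ic`, `p_facets`, `p_color`) are illustrative choices and not part of this PR. The {hstats} calls are guarded so the block also runs when the package is unavailable.

```r
# Illustrative multi-output model (assumption): two response columns,
# so predict() returns a two-column matrix.
fit <- lm(
  as.matrix(iris[, 1:2]) ~ Petal.Length + Petal.Width + Species,
  data = iris
)

if (requireNamespace("hstats", quietly = TRUE)) {
  # ICE curves without a BY variable
  ic <- hstats::ice(fit, v = "Petal.Length", X = iris)

  # Per NEWS.md for 1.1.2: one facet per model output
  p_facets <- plot(ic)

  # Per NEWS.md: swap_dim = TRUE gives the previous behavior (outputs as colors)
  p_color <- plot(ic, swap_dim = TRUE)
}
```

With more than a handful of outputs, facets tend to stay readable where overlapping colored curves would not, which is presumably the motivation for the new default.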
7 changes: 2 additions & 5 deletions R/ice.R
@@ -258,11 +258,8 @@ plot.ice <- function(x, center = FALSE, alpha = 0.2,
data <- poor_man_stack(data, to_stack = pred_names)

# Distinguish all possible cases
-grp <- if (is.null(by_names) && K > 1L) "varying_" else by_names[1L] # can be NULL
-wrp <- if (!is.null(by_names) && K > 1L) "varying_"
-if (length(by_names) == 2L) {
-wrp <- by_names[2L]
-}
+grp <- if (!is.null(by_names)) by_names[1L]
+wrp <- if (K > 1L) "varying_" else if (length(by_names) == 2L) by_names[2L]
if (swap_dim) {
tmp <- grp
grp <- wrp
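
The new grouping logic in this hunk can be tried in isolation. The sketch below wraps it in a hypothetical helper, `resolve_ice_dims()`, which does not exist in {hstats}. The hunk is cut off after `grp <- wrp`, so the final `wrp <- tmp` that completes the swap is an assumption. Here `grp` stands for the variable mapped to color and `wrp` for the facet variable, as described in NEWS.md.

```r
# Hypothetical helper (not part of {hstats}) mirroring the new lines of the hunk.
# by_names: character vector of BY variable names, or NULL
# K:        number of prediction columns (model outputs)
resolve_ice_dims <- function(by_names = NULL, K = 1L, swap_dim = FALSE) {
  # An `if` without `else` yields NULL when the condition is FALSE
  grp <- if (!is.null(by_names)) by_names[1L]
  wrp <- if (K > 1L) "varying_" else if (length(by_names) == 2L) by_names[2L]
  if (swap_dim) {
    tmp <- grp
    grp <- wrp
    wrp <- tmp  # assumption: the visible hunk ends before this line
  }
  list(grp = grp, wrp = wrp)
}

# Multi-output model, no BY variable: outputs now go to facets
str(resolve_ice_dims(NULL, K = 3L))

# swap_dim = TRUE restores the previous mapping (outputs as color groups)
str(resolve_ice_dims(NULL, K = 3L, swap_dim = TRUE))
```

For a single-output model with two BY variables, the helper returns the first as `grp` and the second as `wrp`, which matches the result of the removed lines for that case.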
26 changes: 9 additions & 17 deletions README.md
@@ -275,7 +275,7 @@ ice(fit, v = "Petal.Length", X = iris, BY = "Petal.Width") |>

### LightGBM

-Note: Versions from 4.0.0 upwards to not anymore require passing `reshape = TRUE` to the prediction function.
+Note: Versions < 4.0.0 require passing `reshape = TRUE` to the prediction function.

```r
library(lightgbm)
@@ -294,30 +294,22 @@ fit <- lgb.train(
nrounds = 1000
)

-# Check that predictions require reshape = TRUE to be a matrix
-predict(fit, head(X_train, 2), reshape = TRUE)
-# [,1] [,2] [,3]
-# [1,] 0.9999997 2.918695e-07 2.858720e-14
-# [2,] 0.9999999 1.038470e-07 7.337221e-10
-
# mlogloss: 9.331699e-05
-average_loss(fit, X = X_valid, y = y_valid, loss = "mlogloss", reshape = TRUE)
+average_loss(fit, X = X_valid, y = y_valid, loss = "mlogloss")

-perm_importance(
-fit, X = X_valid, y = y_valid, loss = "mlogloss", reshape = TRUE, m_rep = 100
-)
+perm_importance(fit, X = X_valid, y = y_valid, loss = "mlogloss", m_rep = 100)
# Permutation importance regarding mlogloss
# Petal.Length Petal.Width Sepal.Width Sepal.Length
# 2.624241332 1.011168660 0.082477177 0.009757393

-partial_dep(fit, v = "Petal.Length", X = X_train, reshape = TRUE) |>
+partial_dep(fit, v = "Petal.Length", X = X_train) |>
plot(show_points = FALSE)

-ice(fit, v = "Petal.Length", X = X_train, reshape = TRUE) |>
-plot(swap_dim = TRUE, alpha = 0.05)
+ice(fit, v = "Petal.Length", X = X_train) |>
+plot(alpha = 0.05)

# Interaction statistics, including three-way stats
-(H <- hstats(fit, X = X_train, reshape = TRUE, threeway_m = 4))
+(H <- hstats(fit, X = X_train, threeway_m = 4))
# 0.3010446 0.4167927 0.1623982

plot(H, ncol = 1)
@@ -327,7 +319,7 @@ plot(H, ncol = 1)

### XGBoost

-Also here, mind the `reshape = TRUE` sent to the prediction function.
+Mind the `reshape = TRUE` sent to the prediction function.

```r
library(xgboost)
@@ -379,7 +371,7 @@ plot(H, normalize = FALSE, squared = FALSE, facet_scales = "free_y", ncol = 1)

### Non-probabilistic classification

-When predictions are factor levels, {hstats} uses internal one-hot-encoding.
+When predictions are factor levels, {hstats} uses internal one-hot-encoding. Usually, probabilistic classification makes more sense though.

```r
library(ranger)
58 changes: 29 additions & 29 deletions backlog/benchmark.R
@@ -44,7 +44,7 @@ fit <- xgb.train(
)

# Interpret via {hstats}
-average_loss(fit, X = X_valid, y = y_valid) # 0.0247 MSE -> 0.157 RMSE
+average_loss(fit, X = X_valid, y = y_valid) # 0.023 MSE

perm_importance(fit, X = X_valid, y = y_valid) |>
plot()
@@ -116,11 +116,11 @@ bench::mark(
min_iterations = 3
)

-# expression min median `itr/sec` mem_alloc `gc/sec` n_itr n_gc total_time
-# iml 1.58s 1.58s 0.631 209.4MB 2.73 3 13 4.76s
-# dalex 566.21ms 586.91ms 1.72 34.6MB 0.572 3 1 1.75s
-# flashlight 587.03ms 613.15ms 1.63 27.1MB 1.63 3 3 1.84s
-# hstats 353.78ms 360.57ms 2.79 27.2MB 0 3 0 1.08s
+# expression min median `itr/sec` mem_alloc `gc/sec` n_itr n_gc total_time
+# 1 iml 1.72s 1.75s 0.574 210.6MB 1.34 3 7 5.23s <NULL>
+# 2 dalex 744.82ms 760.02ms 1.31 35.2MB 0.877 3 2 2.28s <NULL>
+# 3 flashlight 1.29s 1.35s 0.742 63MB 0.990 3 4 4.04s <NULL>
+# 4 hstats 407.26ms 412.31ms 2.43 26.5MB 0 3 0 1.23s <NULL>

# Partial dependence (cont)
v <- "tot_lvg_area"
@@ -132,12 +132,12 @@ bench::mark(
check = FALSE,
min_iterations = 3
)
-# expression min median `itr/sec` mem_alloc `gc/sec` n_itr n_gc total_time
-# iml 1.11s 1.13s 0.887 376.3MB 3.84 3 13 3.38s
-# dalex 782.13ms 783.08ms 1.24 192.8MB 2.90 3 7 2.41s
-# flashlight 367.73ms 372.5ms 2.68 67.9MB 2.68 3 3 1.12s
-# hstats 220.88ms 222.5ms 4.50 14.2MB 0 3 0 666.33ms

+# expression min median `itr/sec` mem_alloc `gc/sec` n_itr n_gc total_time
+# 1 iml 1.14s 1.16s 0.861 376.7MB 3.73 3 13 3.48s <NULL>
+# 2 dalex 653.24ms 654.51ms 1.35 192.8MB 2.24 3 5 2.23s <NULL>
+# 3 flashlight 352.34ms 361.79ms 2.72 66.7MB 0.906 3 1 1.1s <NULL>
+# 4 hstats 239.03ms 242.79ms 4.04 14.2MB 1.35 3 1 743.43ms <NULL>
# Partial dependence (discrete)
v <- "structure_quality"
bench::mark(
@@ -148,11 +148,11 @@ bench::mark(
check = FALSE,
min_iterations = 3
)
-# expression min median `itr/sec` mem_alloc `gc/sec` n_itr n_gc total_time
-# iml 90ms 96ms 10.6 13.29MB 7.06 3 2 283ms
-# dalex 170.6ms 174.4ms 5.73 20.55MB 2.87 2 1 349ms
-# flashlight 40.8ms 43.8ms 23.1 6.36MB 2.10 11 1 476ms
-# hstats 23.5ms 24.4ms 40.6 1.53MB 2.14 19 1 468ms
+# expression min median `itr/sec` mem_alloc `gc/sec` n_itr n_gc total_time
+# 1 iml 100.6ms 103.6ms 9.46 13.34MB 0 5 0 529ms <NULL>
+# 2 dalex 172.4ms 177.9ms 5.62 20.55MB 2.81 2 1 356ms <NULL>
+# 3 flashlight 43.5ms 45.5ms 21.9 6.36MB 2.19 10 1 457ms <NULL>
+# 4 hstats 25.3ms 25.8ms 37.9 1.54MB 2.10 18 1 475ms <NULL>

# H-Stats -> we use a subset of 500 rows
X_v500 <- X_valid[1:500, ]
@@ -167,17 +167,17 @@ system.time( # 135s for all combinations of latitude
iml_pairwise <- Interaction$new(mod500, grid.size = 500, feature = "latitude")
)

-# flashlight: 14s total, doing only one pairwise calculation, otherwise would take 63s
-system.time( # 12s
+# flashlight: 13s total, doing only one pairwise calculation, otherwise would take 63s
+system.time( # 11.5s
fl_overall <- light_interaction(fl500, v = x, grid_size = Inf, n_max = Inf)
)
-system.time( # 2s
+system.time( # 2.4s
fl_pairwise <- light_interaction(
fl500, v = coord, grid_size = Inf, n_max = Inf, pairwise = TRUE
)
)

-# hstats: 3.4s total
+# hstats: 3.5s total
system.time({
H <- hstats(fit, v = x, X = X_v500, n_max = Inf)
hstats_overall <- h2_overall(H, squared = FALSE, zero = FALSE)
@@ -193,26 +193,26 @@ system.time(
# Overall statistics correspond exactly
iml_overall$results |> filter(.interaction > 1e-6)
# .feature .interaction
-# 1: latitude 0.2458269
-# 2: longitude 0.2458269
+# 1: latitude 0.2791144
+# 2: longitude 0.2791144

fl_overall$data |> subset(value > 0, select = c(variable, value))
# variable value
-# 1 latitude 0.246
-# 2 longitude 0.246
+# 1 latitude 0.279
+# 2 longitude 0.279

hstats_overall
# longitude latitude
-# 0.2458269 0.2458269
+# 0.2791144 0.2791144

# Pairwise results match as well
iml_pairwise$results |> filter(.interaction > 1e-6)
# .feature .interaction
-# 1: longitude:latitude 0.3942526
+# 1: longitude:latitude 0.4339574

fl_pairwise$data |> subset(value > 0, select = c(variable, value))
-# latitude:longitude 0.394
+# latitude:longitude 0.434

hstats_pairwise
# latitude:longitude
-# 0.3942526
+# 0.4339574
22 changes: 13 additions & 9 deletions cran-comments.md
@@ -1,17 +1,21 @@
-# hstats 1.1.1
+# hstats 1.1.2

-Hello CRAN
+Hello CRAN team

-This is a small release with a bugfix and some speed improvements.
+This is a small release with two convenient API improvements.

-## Local checks: 1 NOTE
+## Local checks: 0 errors, 0 warnings, 0 notes

-- Skipping checking math rendering: package 'V8' unavailable
-
-## Rhub: 2 NOTES
+## Rhub: 3 NOTES (sounding harmless)

-- Skipping checking HTML validation: no command 'tidy' found
-- Skipping checking math rendering: package 'V8' unavailable
+- checking HTML version of manual ... NOTE
+Skipping checking math rendering: package 'V8' unavailable
+- checking for non-standard things in the check directory ... NOTE
+Found the following files/directories:
+''NULL''
+- checking for detritus in the temp directory ... NOTE
+Found the following files/directories:
+'lastMiKTeXException'

## Winbuilder
