Skip to content

Commit

Permalink
Add unit tests with iml
Browse files Browse the repository at this point in the history
  • Loading branch information
mayer79 committed Oct 22, 2023
1 parent d7708e1 commit adc8e9f
Show file tree
Hide file tree
Showing 6 changed files with 40 additions and 35 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: hstats
Title: Interaction Statistics
Version: 1.0.0
Version: 1.0.1
Authors@R:
person("Michael", "Mayer", , "[email protected]", role = c("aut", "cre"))
Description: Fast, model-agnostic implementation of different H-statistics
Expand Down
6 changes: 6 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
# hstats 1.0.1

## Other changes

- Add unit tests to compare against {iml}.

# hstats 1.0.0

## Major changes
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ The core functions `hstats()`, `partial_dep()`, `ice()`, `perm_importance()`, an
{hstats} is not the first R package to explore interactions. Here is an incomplete selection:

- [{gbm}](https://CRAN.R-project.org/package=gbm): Implementation of m-wise interaction statistics of [1] for {gbm} models using the weighted tree-traversal method of [2] to estimate partial dependence functions.
- [{iml}](https://CRAN.R-project.org/package=iml): Variant of pairwise interaction statistics of [1].
- [{iml}](https://CRAN.R-project.org/package=iml): Implementation of overall and pairwise H-statistics.
- [{EIX}](https://CRAN.R-project.org/package=EIX): Interaction statistics extracted from the tree structure of XGBoost and LightGBM.
- [{randomForestExplainer}](https://CRAN.R-project.org/package=randomForestExplainer): Interaction statistics extracted from the tree structure of random forests.
- [{vivid}](https://CRAN.R-project.org/package=vivid): Cool visualization of interaction patterns. Partly based on {flashlight}.
Expand Down
11 changes: 9 additions & 2 deletions backlog/benchmark.R
Original file line number Diff line number Diff line change
Expand Up @@ -160,10 +160,13 @@ X_v500 <- X_valid[1:500, ]
mod500 <- Predictor$new(fit, data = as.data.frame(X_v500), predict.function = predf)
fl500 <- flashlight(fl, data = as.data.frame(valid[1:500, ]))

# iml # 90 s (no pairwise possible)
system.time(
# iml # 225s total, using slow exact calculations
system.time( # 90s
iml_overall <- Interaction$new(mod500, grid.size = 500)
)
system.time( # 135s for all combinations of latitude
iml_pairwise <- Interaction$new(mod500, grid.size = 500, feature = "latitude")
)

# flashlight: 14s total, doing only one pairwise calculation, otherwise would take 63s
system.time( # 12s
Expand Down Expand Up @@ -199,6 +202,10 @@ hstats_overall
# 0.2458269 0.2458269

# Pairwise results match as well
iml_pairwise$results |> filter(.interaction > 1e-6)
# .feature .interaction
# 1: longitude:latitude 0.3942526

fl_pairwise$data |> subset(value > 0, select = c(variable, value))
# latitude:longitude 0.394

Expand Down
2 changes: 1 addition & 1 deletion packaging.R
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ library(usethis)
use_description(
fields = list(
Title = "Interaction Statistics",
Version = "1.0.0",
Version = "1.0.1",
Description = "Fast, model-agnostic implementation of different H-statistics
introduced by Jerome H. Friedman and Bogdan E. Popescu (2008) <doi:10.1214/07-AOAS148>.
These statistics quantify interaction strength per feature, feature pair,
Expand Down
52 changes: 22 additions & 30 deletions tests/testthat/test_hstats.R
Original file line number Diff line number Diff line change
Expand Up @@ -333,33 +333,25 @@ test_that("hstats() does not give an error with missing", {
expect_equal(rownames(h2_pairwise(r, zero = FALSE)), "x1:x2")
})


# library(gbm)
#
# fit <- gbm(Sepal.Length ~ ., data = iris, interaction.depth = 3, bag.fraction = 1)
# v <- names(iris)[-1]
# combs <- combn(v, 2, simplify = FALSE)
# p <- length(combs)
#
# res <- setNames(numeric(p), sapply(combs, paste, collapse = ":"))
# for (i in 1:p) {
# res[i] <- interact.gbm(fit, iris, i.var = combs[[i]], n.trees = fit$n.trees)
# }
# cbind(res[res > 0.0001])
# # Sepal.Width:Petal.Length 0.10982072
# # Sepal.Width:Petal.Width 0.17932506
# # Sepal.Width:Species 0.21480383
# # Petal.Length:Petal.Width 0.03702921
# # Petal.Length:Species 0.06382609
#
# # Crunching
# system.time( # 0.3 s
# s <- hstats(fit, v = v, X = iris, n.trees = fit$n.trees)
# )
# h2_pairwise(s, squared = FALSE, sort = FALSE)
# # Sepal.Width:Petal.Length 0.10532810
# # Sepal.Width:Petal.Width 0.16697609
# # Sepal.Width:Species 0.17335494
# # Petal.Length:Petal.Width 0.03245863
# # Petal.Length:Species 0.06678683
#
test_that("hstats() matches {iml} 0.11.1 in a specific case", {
fit <- lm(Sepal.Width ~ . + Sepal.Length:Species, data = iris)

# library(iml)
# mod <- Predictor$new(fit, data = iris[-2L])
# iml_overall <- Interaction$new(mod, grid.size = 150)
# Sepal.Length: 0.4634029
# iml_pairwise <- Interaction$new(mod, grid.size = 150, feature = "Species")
# Sepal.Length:Species 0.2624154

H <- hstats(fit, X = iris[-2L], verbose = FALSE)
expect_equal(
c(h2_overall(H, squared = FALSE)["Sepal.Length", ]$M),
0.4634029,
tolerance = 1e-5
)
expect_equal(
c(h2_pairwise(H, squared = FALSE)["Sepal.Length:Species", ]$M),
0.2624154,
tolerance = 1e-5
)
})

0 comments on commit adc8e9f

Please sign in to comment.