diff --git a/DESCRIPTION b/DESCRIPTION index de11bccf..192efe8c 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,6 +1,6 @@ Package: hstats Title: Interaction Statistics -Version: 1.0.0 +Version: 1.0.1 Authors@R: person("Michael", "Mayer", , "mayermichael79@gmail.com", role = c("aut", "cre")) Description: Fast, model-agnostic implementation of different H-statistics diff --git a/NEWS.md b/NEWS.md index c67981da..748542bc 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,3 +1,9 @@ +# hstats 1.0.1 + +## Other changes + +- Add unit tests to compare against {iml}. + # hstats 1.0.0 ## Major changes diff --git a/README.md b/README.md index cb4f5f5b..e8155d54 100644 --- a/README.md +++ b/README.md @@ -39,7 +39,7 @@ The core functions `hstats()`, `partial_dep()`, `ice()`, `perm_importance()`, an {hstats} is not the first R package to explore interactions. Here is an incomplete selection: - [{gbm}](https://CRAN.R-project.org/package=gbm): Implementation of m-wise interaction statistics of [1] for {gbm} models using the weighted tree-traversal method of [2] to estimate partial dependence functions. -- [{iml}](https://CRAN.R-project.org/package=iml): Variant of pairwise interaction statistics of [1]. +- [{iml}](https://CRAN.R-project.org/package=iml): Implementation of overall and pairwise H-statistics. - [{EIX}](https://CRAN.R-project.org/package=EIX): Interaction statistics extracted from the tree structure of XGBoost and LightGBM. - [{randomForestExplainer}](https://CRAN.R-project.org/package=randomForestExplainer): Interaction statistics extracted from the tree structure of random forests. - [{vivid}](https://CRAN.R-project.org/package=vivid): Cool visualization of interaction patterns. Partly based on {flashlight}. diff --git a/backlog/benchmark.R b/backlog/benchmark.R index 20728c87..ff83827d 100644 --- a/backlog/benchmark.R +++ b/backlog/benchmark.R @@ -160,10 +160,13 @@ X_v500 <- X_valid[1:500, ] mod500 <- Predictor$new(fit, data = as.data.frame(X_v500), predict.function = predf) fl500 <- flashlight(fl, data = as.data.frame(valid[1:500, ])) -# iml # 90 s (no pairwise possible) -system.time( +# iml # 225s total, using slow exact calculations +system.time( # 90s iml_overall <- Interaction$new(mod500, grid.size = 500) ) +system.time( # 135s for all combinations of latitude + iml_pairwise <- Interaction$new(mod500, grid.size = 500, feature = "latitude") +) # flashlight: 14s total, doing only one pairwise calculation, otherwise would take 63s system.time( # 12s @@ -199,6 +202,10 @@ hstats_overall # 0.2458269 0.2458269 # Pairwise results match as well +iml_pairwise$results |> filter(.interaction > 1e-6) +# .feature .interaction +# 1: longitude:latitude 0.3942526 + fl_pairwise$data |> subset(value > 0, select = c(variable, value)) # latitude:longitude 0.394 diff --git a/packaging.R b/packaging.R index 18f66c91..351e8453 100644 --- a/packaging.R +++ b/packaging.R @@ -15,7 +15,7 @@ library(usethis) use_description( fields = list( Title = "Interaction Statistics", - Version = "1.0.0", + Version = "1.0.1", Description = "Fast, model-agnostic implementation of different H-statistics introduced by Jerome H. Friedman and Bogdan E. Popescu (2008) . These statistics quantify interaction strength per feature, feature pair, diff --git a/tests/testthat/test_hstats.R b/tests/testthat/test_hstats.R index bdc35a4b..94f4ef4a 100644 --- a/tests/testthat/test_hstats.R +++ b/tests/testthat/test_hstats.R @@ -333,33 +333,25 @@ test_that("hstats() does not give an error with missing", { expect_equal(rownames(h2_pairwise(r, zero = FALSE)), "x1:x2") }) - -# library(gbm) -# -# fit <- gbm(Sepal.Length ~ ., data = iris, interaction.depth = 3, bag.fraction = 1) -# v <- names(iris)[-1] -# combs <- combn(v, 2, simplify = FALSE) -# p <- length(combs) -# -# res <- setNames(numeric(p), sapply(combs, paste, collapse = ":")) -# for (i in 1:p) { -# res[i] <- interact.gbm(fit, iris, i.var = combs[[i]], n.trees = fit$n.trees) -# } -# cbind(res[res > 0.0001]) -# # Sepal.Width:Petal.Length 0.10982072 -# # Sepal.Width:Petal.Width 0.17932506 -# # Sepal.Width:Species 0.21480383 -# # Petal.Length:Petal.Width 0.03702921 -# # Petal.Length:Species 0.06382609 -# -# # Crunching -# system.time( # 0.3 s -# s <- hstats(fit, v = v, X = iris, n.trees = fit$n.trees) -# ) -# h2_pairwise(s, squared = FALSE, sort = FALSE) -# # Sepal.Width:Petal.Length 0.10532810 -# # Sepal.Width:Petal.Width 0.16697609 -# # Sepal.Width:Species 0.17335494 -# # Petal.Length:Petal.Width 0.03245863 -# # Petal.Length:Species 0.06678683 -# +test_that("hstats() matches {iml} 0.11.1 in a specific case", { + fit <- lm(Sepal.Width ~ . + Sepal.Length:Species, data = iris) + + # library(iml) + # mod <- Predictor$new(fit, data = iris[-2L]) + # iml_overall <- Interaction$new(mod, grid.size = 150) + # Sepal.Length: 0.4634029 + # iml_pairwise <- Interaction$new(mod, grid.size = 150, feature = "Species") + # Sepal.Length:Species 0.2624154 + + H <- hstats(fit, X = iris[-2L], verbose = FALSE) + expect_equal( + c(h2_overall(H, squared = FALSE)["Sepal.Length", ]$M), + 0.4634029, + tolerance = 1e-5 + ) + expect_equal( + c(h2_pairwise(H, squared = FALSE)["Sepal.Length:Species", ]$M), + 0.2624154, + tolerance = 1e-5 + ) +})