From 557b66cee258438adf69647ced4f3da2c0af8b3a Mon Sep 17 00:00:00 2001 From: Michael Mayer Date: Sat, 28 Oct 2023 10:44:41 +0200 Subject: [PATCH] Fix tibble problem --- NEWS.md | 4 ++++ R/utils_input.R | 18 +++++++++++++++--- tests/testthat/test_perm_importance.R | 3 ++- tests/testthat/test_statistics.R | 6 ++++-- 4 files changed, 25 insertions(+), 6 deletions(-) diff --git a/NEWS.md b/NEWS.md index 748542bc..a2ee8d84 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,5 +1,9 @@ # hstats 1.0.1 +## Bug fixes + +- Using `BY` and `w` via column names would fail for tibbles. This problem was described in [#92](https://github.com/mayer79/hstats/issues/92) by @RoelVerbelen. Thx! + ## Other changes - Add unit tests to compare against {iml}. diff --git a/R/utils_input.R b/R/utils_input.R index c59d5027..be353ee0 100644 --- a/R/utils_input.R +++ b/R/utils_input.R @@ -12,7 +12,11 @@ prepare_by <- function(BY, X, by_size) { if (length(BY) == 1L && BY %in% colnames(X)) { by_name <- BY - BY <- X[, by_name] + if (is.data.frame(X)) { + BY <- X[[by_name]] + } else { + BY <- X[, by_name] + } } else { stopifnot( NCOL(BY) == 1L, @@ -44,7 +48,11 @@ prepare_by <- function(BY, X, by_size) { prepare_w <- function(w, X) { if (length(w) == 1L && w %in% colnames(X)) { w_name <- w - w <- X[, w] + if (is.data.frame(X)) { + w <- X[[w]] + } else { + w <- X[, w] + } } else { stopifnot( NCOL(w) == 1L, @@ -69,7 +77,11 @@ prepare_w <- function(w, X) { prepare_y <- function(y, X) { if (NROW(y) < nrow(X) && all(y %in% colnames(X))) { y_names <- y - y <- X[, y] + if (is.data.frame(X) && length(y) == 1L) { + y <- X[[y]] + } else { + y <- X[, y] + } } else { stopifnot(NROW(y) == nrow(X)) y_names <- NULL diff --git a/tests/testthat/test_perm_importance.R b/tests/testthat/test_perm_importance.R index 155d7377..87156ef7 100644 --- a/tests/testthat/test_perm_importance.R +++ b/tests/testthat/test_perm_importance.R @@ -400,4 +400,5 @@ test_that("perm_importance() works with missing values", { r <- perm_importance(fit, X = X, y = y, pred_fun = pf, verbose = FALSE) ) expect_true(r$M[1L] > 0 && all(r$M[2:3] == 0)) -}) \ No newline at end of file +}) + diff --git a/tests/testthat/test_statistics.R b/tests/testthat/test_statistics.R index 06662b8c..6ed3748a 100644 --- a/tests/testthat/test_statistics.R +++ b/tests/testthat/test_statistics.R @@ -85,7 +85,9 @@ test_that(".zap_small() works for matrix input", { fit <- lm(cbind(up = uptake, up2 = 2 * uptake) ~ Type * Treatment * conc, data = CO2) H <- hstats(fit, X = CO2[2:4], verbose = FALSE) s <- h2_pairwise(H) -imp <- perm_importance(fit, CO2, v = c("Type", "Treatment", "conc"), y = "uptake") +imp <- perm_importance( + fit, CO2, v = c("Type", "Treatment", "conc"), y = "uptake", verbose = FALSE +) test_that("print() method does not give error", { capture_output(expect_no_error(print(s))) @@ -123,7 +125,7 @@ test_that("subsetting works", { fit <- lm(uptake ~ Type * Treatment * conc, data = CO2) set.seed(1L) -s <- perm_importance(fit, X = CO2[2:4], y = CO2$uptake) +s <- perm_importance(fit, X = CO2[2:4], y = CO2$uptake, verbose = FALSE) test_that("print() method does not give error", { capture_output(expect_no_error(print(s)))