Skip to content

Commit

Permalink
Merge pull request #93 from mayer79/tibble
Browse files Browse the repository at this point in the history
Fix tibble problem
  • Loading branch information
mayer79 authored Oct 28, 2023
2 parents 0767d35 + 557b66c commit 9fcab31
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 6 deletions.
4 changes: 4 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# hstats 1.0.1

## Bug fixes

- Using `BY` and `w` via column names would fail for tibbles. This problem was described in [#92](https://github.com/mayer79/hstats/issues/92) by @RoelVerbelen. Thx!

## Other changes

- Add unit tests to compare against {iml}.
Expand Down
18 changes: 15 additions & 3 deletions R/utils_input.R
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,11 @@
prepare_by <- function(BY, X, by_size) {
if (length(BY) == 1L && BY %in% colnames(X)) {
by_name <- BY
BY <- X[, by_name]
if (is.data.frame(X)) {
BY <- X[[by_name]]
} else {
BY <- X[, by_name]
}
} else {
stopifnot(
NCOL(BY) == 1L,
Expand Down Expand Up @@ -44,7 +48,11 @@ prepare_by <- function(BY, X, by_size) {
prepare_w <- function(w, X) {
if (length(w) == 1L && w %in% colnames(X)) {
w_name <- w
w <- X[, w]
if (is.data.frame(X)) {
w <- X[[w]]
} else {
w <- X[, w]
}
} else {
stopifnot(
NCOL(w) == 1L,
Expand All @@ -69,7 +77,11 @@ prepare_w <- function(w, X) {
prepare_y <- function(y, X) {
if (NROW(y) < nrow(X) && all(y %in% colnames(X))) {
y_names <- y
y <- X[, y]
if (is.data.frame(X) && length(y) == 1L) {
y <- X[[y]]
} else {
y <- X[, y]
}
} else {
stopifnot(NROW(y) == nrow(X))
y_names <- NULL
Expand Down
3 changes: 2 additions & 1 deletion tests/testthat/test_perm_importance.R
Original file line number Diff line number Diff line change
Expand Up @@ -400,4 +400,5 @@ test_that("perm_importance() works with missing values", {
r <- perm_importance(fit, X = X, y = y, pred_fun = pf, verbose = FALSE)
)
expect_true(r$M[1L] > 0 && all(r$M[2:3] == 0))
})
})

6 changes: 4 additions & 2 deletions tests/testthat/test_statistics.R
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,9 @@ test_that(".zap_small() works for matrix input", {
fit <- lm(cbind(up = uptake, up2 = 2 * uptake) ~ Type * Treatment * conc, data = CO2)
H <- hstats(fit, X = CO2[2:4], verbose = FALSE)
s <- h2_pairwise(H)
imp <- perm_importance(fit, CO2, v = c("Type", "Treatment", "conc"), y = "uptake")
imp <- perm_importance(
fit, CO2, v = c("Type", "Treatment", "conc"), y = "uptake", verbose = FALSE
)

test_that("print() method does not give error", {
capture_output(expect_no_error(print(s)))
Expand Down Expand Up @@ -123,7 +125,7 @@ test_that("subsetting works", {

fit <- lm(uptake ~ Type * Treatment * conc, data = CO2)
set.seed(1L)
s <- perm_importance(fit, X = CO2[2:4], y = CO2$uptake)
s <- perm_importance(fit, X = CO2[2:4], y = CO2$uptake, verbose = FALSE)

test_that("print() method does not give error", {
capture_output(expect_no_error(print(s)))
Expand Down

0 comments on commit 9fcab31

Please sign in to comment.