Skip to content

Commit

Permalink
Merge pull request #97 from mayer79/tibble-datatable
Browse files Browse the repository at this point in the history
Turn examples tibble and data.table friendly
  • Loading branch information
mayer79 authored Oct 29, 2023
2 parents 3dd2df2 + c56d59b commit 98a3f15
Show file tree
Hide file tree
Showing 23 changed files with 74 additions and 69 deletions.
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
## Other changes

- Add unit tests to compare against {iml}.
- Made all examples "tibble" and "data.table" friendly.

# hstats 1.0.0

Expand Down
8 changes: 4 additions & 4 deletions R/H2.R
Original file line number Diff line number Diff line change
Expand Up @@ -40,17 +40,17 @@
#' @examples
#' # MODEL 1: Linear regression
#' fit <- lm(Sepal.Length ~ . + Petal.Width:Species, data = iris)
#' s <- hstats(fit, X = iris[-1])
#' s <- hstats(fit, X = iris[, -1])
#' h2(s)
#'
#' # MODEL 2: Multi-response linear regression
#' fit <- lm(as.matrix(iris[1:2]) ~ Petal.Length + Petal.Width * Species, data = iris)
#' s <- hstats(fit, X = iris[3:5])
#' fit <- lm(as.matrix(iris[, 1:2]) ~ Petal.Length + Petal.Width * Species, data = iris)
#' s <- hstats(fit, X = iris[, 3:5])
#' h2(s)
#'
#' # MODEL 3: No interactions
#' fit <- lm(Sepal.Length ~ ., data = iris)
#' s <- hstats(fit, X = iris[-1], verbose = FALSE)
#' s <- hstats(fit, X = iris[, -1], verbose = FALSE)
#' h2(s)
h2 <- function(object, ...) {
UseMethod("h2")
Expand Down
6 changes: 3 additions & 3 deletions R/H2_overall.R
Original file line number Diff line number Diff line change
Expand Up @@ -60,13 +60,13 @@
#' @examples
#' # MODEL 1: Linear regression
#' fit <- lm(Sepal.Length ~ . + Petal.Width:Species, data = iris)
#' s <- hstats(fit, X = iris[-1])
#' s <- hstats(fit, X = iris[, -1])
#' h2_overall(s)
#' plot(h2_overall(s))
#'
#' # MODEL 2: Multi-response linear regression
#' fit <- lm(as.matrix(iris[1:2]) ~ Petal.Length + Petal.Width * Species, data = iris)
#' s <- hstats(fit, X = iris[3:5], verbose = FALSE)
#' fit <- lm(as.matrix(iris[, 1:2]) ~ Petal.Length + Petal.Width * Species, data = iris)
#' s <- hstats(fit, X = iris[, 3:5], verbose = FALSE)
#' plot(h2_overall(s, zero = FALSE))
h2_overall <- function(object, ...) {
UseMethod("h2_overall")
Expand Down
6 changes: 3 additions & 3 deletions R/H2_pairwise.R
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
#' @examples
#' # MODEL 1: Linear regression
#' fit <- lm(Sepal.Length ~ . + Petal.Width:Species, data = iris)
#' s <- hstats(fit, X = iris[-1])
#' s <- hstats(fit, X = iris[, -1])
#'
#' # Proportion of joint effect coming from pairwise interaction
#' # (for features with strongest overall interactions)
Expand All @@ -65,8 +65,8 @@
#' abs_h$M
#'
#' # MODEL 2: Multi-response linear regression
#' fit <- lm(as.matrix(iris[1:2]) ~ Petal.Length + Petal.Width * Species, data = iris)
#' s <- hstats(fit, X = iris[3:5], verbose = FALSE)
#' fit <- lm(as.matrix(iris[, 1:2]) ~ Petal.Length + Petal.Width * Species, data = iris)
#' s <- hstats(fit, X = iris[, 3:5], verbose = FALSE)
#' x <- h2_pairwise(s)
#' plot(x)
h2_pairwise <- function(object, ...) {
Expand Down
4 changes: 2 additions & 2 deletions R/H2_threeway.R
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,12 @@
#' @examples
#' # MODEL 1: Linear regression
#' fit <- lm(uptake ~ Type * Treatment * conc, data = CO2)
#' s <- hstats(fit, X = CO2[2:4], threeway_m = 5)
#' s <- hstats(fit, X = CO2[, 2:4], threeway_m = 5)
#' h2_threeway(s)
#'
#' # MODEL 2: Multivariate output (taking just twice the same response as example)
#' fit <- lm(cbind(up = uptake, up2 = 2 * uptake) ~ Type * Treatment * conc, data = CO2)
#' s <- hstats(fit, X = CO2[2:4], threeway_m = 5)
#' s <- hstats(fit, X = CO2[, 2:4], threeway_m = 5)
#' h2_threeway(s)
#' h2_threeway(s, normalize = FALSE, squared = FALSE) # Unnormalized H
#' plot(h2_threeway(s))
Expand Down
6 changes: 3 additions & 3 deletions R/average_loss.R
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,10 @@
#' average_loss(fit, X = iris, y = "Sepal.Length", BY = "Sepal.Width")
#'
#' # MODEL 2: Multi-response linear regression
#' fit <- lm(as.matrix(iris[1:2]) ~ Petal.Length + Petal.Width + Species, data = iris)
#' average_loss(fit, X = iris, y = iris[1:2])
#' fit <- lm(as.matrix(iris[, 1:2]) ~ Petal.Length + Petal.Width + Species, data = iris)
#' average_loss(fit, X = iris, y = iris[, 1:2])
#' L <- average_loss(
#' fit, X = iris, y = iris[1:2], loss = "gamma", BY = "Species"
#' fit, X = iris, y = iris[, 1:2], loss = "gamma", BY = "Species"
#' )
#' L
#' plot(L)
Expand Down
10 changes: 5 additions & 5 deletions R/hstats.R
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@
#' @examples
#' # MODEL 1: Linear regression
#' fit <- lm(Sepal.Length ~ . + Petal.Width:Species, data = iris)
#' s <- hstats(fit, X = iris[-1])
#' s <- hstats(fit, X = iris[, -1])
#' s
#' plot(s)
#' plot(s, zero = FALSE) # Drop 0
Expand All @@ -115,21 +115,21 @@
#' h2_pairwise(s, normalize = FALSE, squared = FALSE, zero = FALSE)
#'
#' # MODEL 2: Multi-response linear regression
#' fit <- lm(as.matrix(iris[1:2]) ~ Petal.Length + Petal.Width * Species, data = iris)
#' s <- hstats(fit, X = iris[3:5], verbose = FALSE)
#' fit <- lm(as.matrix(iris[, 1:2]) ~ Petal.Length + Petal.Width * Species, data = iris)
#' s <- hstats(fit, X = iris[, 3:5], verbose = FALSE)
#' plot(s)
#' summary(s)
#'
#' # MODEL 3: Gamma GLM with log link
#' fit <- glm(Sepal.Length ~ ., data = iris, family = Gamma(link = log))
#'
#' # No interactions for additive features, at least on link scale
#' s <- hstats(fit, X = iris[-1], verbose = FALSE)
#' s <- hstats(fit, X = iris[, -1], verbose = FALSE)
#' summary(s)
#'
#' # On original scale, we have interactions everywhere.
#' # To see three-way interactions, we set threeway_m to a value above 2.
#' s <- hstats(fit, X = iris[-1], type = "response", threeway_m = 5)
#' s <- hstats(fit, X = iris[, -1], type = "response", threeway_m = 5)
#' plot(s, ncol = 1) # All three types use different denominators
#'
#' # All statistics on same scale (of predictions)
Expand Down
2 changes: 1 addition & 1 deletion R/ice.R
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
#' plot(ic, center = TRUE)
#'
#' # MODEL 2: Multi-response linear regression
#' fit <- lm(as.matrix(iris[1:2]) ~ Petal.Length + Petal.Width * Species, data = iris)
#' fit <- lm(as.matrix(iris[, 1:2]) ~ Petal.Length + Petal.Width * Species, data = iris)
#' ic <- ice(fit, v = "Petal.Width", X = iris, BY = iris$Species)
#' plot(ic)
#' plot(ic, center = TRUE)
Expand Down
2 changes: 1 addition & 1 deletion R/partial_dep.R
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@
#' plot(pd, rotate_x = TRUE, d2_geom = "line", swap_dim = TRUE)
#'
#' # MODEL 2: Multi-response linear regression
#' fit <- lm(as.matrix(iris[1:2]) ~ Petal.Length + Petal.Width * Species, data = iris)
#' fit <- lm(as.matrix(iris[, 1:2]) ~ Petal.Length + Petal.Width * Species, data = iris)
#' pd <- partial_dep(fit, v = "Petal.Width", X = iris, BY = "Species")
#' plot(pd, show_points = FALSE)
#' pd <- partial_dep(fit, v = c("Species", "Petal.Width"), X = iris)
Expand Down
6 changes: 3 additions & 3 deletions R/pd_importance.R
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,12 @@
#' @examples
#' # MODEL 1: Linear regression
#' fit <- lm(Sepal.Length ~ . , data = iris)
#' s <- hstats(fit, X = iris[-1])
#' s <- hstats(fit, X = iris[, -1])
#' plot(pd_importance(s))
#'
#' # MODEL 2: Multi-response linear regression
#' fit <- lm(as.matrix(iris[1:2]) ~ Petal.Length + Petal.Width + Species, data = iris)
#' s <- hstats(fit, X = iris[3:5])
#' fit <- lm(as.matrix(iris[, 1:2]) ~ Petal.Length + Petal.Width + Species, data = iris)
#' s <- hstats(fit, X = iris[, 3:5])
#' plot(pd_importance(s))
pd_importance <- function(object, ...) {
UseMethod("pd_importance")
Expand Down
4 changes: 2 additions & 2 deletions R/perm_importance.R
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@
#' plot(s)
#'
#' # MODEL 2: Multi-response linear regression
#' fit <- lm(as.matrix(iris[1:2]) ~ Petal.Length + Petal.Width + Species, data = iris)
#' s <- perm_importance(fit, X = iris[3:5], y = iris[1:2], normalize = TRUE)
#' fit <- lm(as.matrix(iris[, 1:2]) ~ Petal.Length + Petal.Width + Species, data = iris)
#' s <- perm_importance(fit, X = iris[, 3:5], y = iris[, 1:2], normalize = TRUE)
#' s
#' plot(s)
#' plot(s, swap_dim = TRUE, top_m = 2)
Expand Down
18 changes: 9 additions & 9 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ library(ranger)
set.seed(1)

fit <- ranger(Sepal.Length ~ ., data = iris)
ex <- DALEX::explain(fit, data = iris[-1], y = iris[, 1])
ex <- DALEX::explain(fit, data = iris[, -1], y = iris[, 1])

s <- hstats(ex)
s # 0.054
Expand Down Expand Up @@ -246,10 +246,10 @@ ix <- c(1:40, 51:90, 101:140)
train <- iris[ix, ]
valid <- iris[-ix, ]

X_train <- data.matrix(train[-5])
X_valid <- data.matrix(valid[-5])
y_train <- train[, 5]
y_valid <- valid[, 5]
X_train <- data.matrix(train[, -5])
X_valid <- data.matrix(valid[, -5])
y_train <- train[[5]]
y_valid <- valid[[5]]
```

### ranger
Expand All @@ -264,7 +264,7 @@ average_loss(fit, X = valid, y = "Species", loss = "mlogloss") # 0.02

perm_importance(fit, X = iris, y = "Species", loss = "mlogloss")

(s <- hstats(fit, X = iris[-5]))
(s <- hstats(fit, X = iris[, -5]))
plot(s, normalize = FALSE, squared = FALSE)

ice(fit, v = "Petal.Length", X = iris, BY = "Petal.Width") |>
Expand Down Expand Up @@ -400,7 +400,7 @@ iris_wf <- workflow() |>
fit <- iris_wf |>
fit(iris)

s <- hstats(fit, X = iris[-1])
s <- hstats(fit, X = iris[, -1])
s # 0 -> no interactions
plot(partial_dep(fit, v = "Petal.Width", X = iris))

Expand Down Expand Up @@ -429,7 +429,7 @@ fit <- train(
trControl = trainControl(method = "none")
)

h2(hstats(fit, X = iris[-1])) # 0
h2(hstats(fit, X = iris[, -1])) # 0

plot(ice(fit, v = "Petal.Width", X = iris), center = TRUE)
plot(perm_importance(fit, X = iris, y = "Sepal.Length"))
Expand All @@ -448,7 +448,7 @@ set.seed(1)
task_iris <- TaskClassif$new(id = "class", backend = iris, target = "Species")
fit_rf <- lrn("classif.ranger", predict_type = "prob")
fit_rf$train(task_iris)
s <- hstats(fit_rf, X = iris[-5])
s <- hstats(fit_rf, X = iris[, -5])
plot(s)

# Permutation importance
Expand Down
16 changes: 10 additions & 6 deletions backlog/calibration.R
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,11 @@ calibration.default <- function(object, v, X, y = NULL, pred_fun = stats::predic
)

if (!is.null(y)) {
y <- align_pred(prepare_y(y = y, X = X)[["y"]])
y <- prepare_y(y = y, X = X)[["y"]]
if (is.factor(y) || is.character(y)) {
y <- stats::model.matrix(~ as.factor(y) + 0)
}
y <- align_pred(y)
}
if (!is.null(w)) {
w <- prepare_w(w = w, X = X)[["w"]]
Expand All @@ -91,15 +95,15 @@ calibration.default <- function(object, v, X, y = NULL, pred_fun = stats::predic
pred <- pred_fun(object, X, ...)
}
pred <- align_pred(pred)
avg_pred <- gwColMeans(pred, g = g, w = w)
tmp <- gwColMeans(pred, g = g, w = w, mean_only = FALSE)
avg_pred <- tmp[["mean"]]

# Exposure
exposure <- tmp[["denom"]]

# Average observed
avg_obs <- if (!is.null(y)) gwColMeans(y, g = g, w = w)

# Exposure
ww <- if (is.null(w)) rep.int(1, NROW(X)) else w
exposure <- rowsum(ww, group = g)

# Partial dependence
pd <- partial_dep(
object = object,
Expand Down
8 changes: 4 additions & 4 deletions man/H2.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 3 additions & 3 deletions man/H2_overall.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 3 additions & 3 deletions man/H2_pairwise.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions man/H2_threeway.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 3 additions & 3 deletions man/average_loss.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 5 additions & 5 deletions man/hstats.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 98a3f15

Please sign in to comment.