Skip to content

Commit ae6250b

Browse files
committed
na.rm in grids
1 parent ebf6e38 commit ae6250b

File tree

9 files changed

+70
-36
lines changed

9 files changed

+70
-36
lines changed

NEWS.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
- `average_loss()` is more flexible regarding the group `BY` argument. It can also be a variable *name*. Non-discrete `BY` variables are now automatically binned. Like `partial_dep()`, binning is controlled by the `by_size = 4` argument.
2121
- `average_loss()` also returns a "hstats_matrix" object with `print()` and `plot()` method. The values can be extracted via `$M`.
2222
- The default `v` of `hstats()` and `perm_importance()` is now `NULL`. Internally, it is set to `colnames(X)` (minus the column names of `w` and `y` if passed as name).
23+
- Missing grid values: `partial_dep()` and `ice()` have received a `na.rm = TRUE` argument that controls if missing values are dropped during grid creation. The default is compatible with earlier releases.
2324

2425
# hstats 0.3.0
2526

R/ice.R

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,8 @@ ice <- function(object, ...) {
5959
ice.default <- function(object, v, X, pred_fun = stats::predict,
6060
BY = NULL, grid = NULL, grid_size = 49L,
6161
trim = c(0.01, 0.99),
62-
strategy = c("uniform", "quantile"), n_max = 100L, ...) {
62+
strategy = c("uniform", "quantile"), na.rm = TRUE,
63+
n_max = 100L, ...) {
6364
stopifnot(
6465
is.matrix(X) || is.data.frame(X),
6566
is.function(pred_fun),
@@ -69,7 +70,7 @@ ice.default <- function(object, v, X, pred_fun = stats::predict,
6970
# Prepare grid
7071
if (is.null(grid)) {
7172
grid <- multivariate_grid(
72-
x = X[, v], grid_size = grid_size, trim = trim, strategy = strategy
73+
x = X[, v], grid_size = grid_size, trim = trim, strategy = strategy, na.rm = na.rm
7374
)
7475
} else {
7576
check_grid(g = grid, v = v, X_is_matrix = is.matrix(X))
@@ -142,7 +143,8 @@ ice.ranger <- function(object, v, X,
142143
pred_fun = function(m, X, ...) stats::predict(m, X, ...)$predictions,
143144
BY = NULL, grid = NULL, grid_size = 49L,
144145
trim = c(0.01, 0.99),
145-
strategy = c("uniform", "quantile"), n_max = 100, ...) {
146+
strategy = c("uniform", "quantile"), na.rm = TRUE,
147+
n_max = 100, ...) {
146148
ice.default(
147149
object = object,
148150
v = v,
@@ -153,6 +155,7 @@ ice.ranger <- function(object, v, X,
153155
grid_size = grid_size,
154156
trim = trim,
155157
strategy = strategy,
158+
na.rm = na.rm,
156159
n_max = n_max,
157160
...
158161
)
@@ -163,7 +166,8 @@ ice.ranger <- function(object, v, X,
163166
ice.Learner <- function(object, v, X,
164167
pred_fun = NULL,
165168
BY = NULL, grid = NULL, grid_size = 49L, trim = c(0.01, 0.99),
166-
strategy = c("uniform", "quantile"), n_max = 100L, ...) {
169+
strategy = c("uniform", "quantile"), na.rm = TRUE,
170+
n_max = 100L, ...) {
167171
if (is.null(pred_fun)) {
168172
pred_fun <- mlr3_pred_fun(object, X = X)
169173
}
@@ -177,6 +181,7 @@ ice.Learner <- function(object, v, X,
177181
grid_size = grid_size,
178182
trim = trim,
179183
strategy = strategy,
184+
na.rm = na.rm,
180185
n_max = n_max,
181186
...
182187
)
@@ -188,7 +193,8 @@ ice.explainer <- function(object, v = v, X = object[["data"]],
188193
pred_fun = object[["predict_function"]],
189194
BY = NULL, grid = NULL, grid_size = 49L,
190195
trim = c(0.01, 0.99),
191-
strategy = c("uniform", "quantile"), n_max = 100, ...) {
196+
strategy = c("uniform", "quantile"), na.rm = TRUE,
197+
n_max = 100, ...) {
192198
ice.default(
193199
object = object[["model"]],
194200
v = v,
@@ -199,6 +205,7 @@ ice.explainer <- function(object, v = v, X = object[["data"]],
199205
grid_size = grid_size,
200206
trim = trim,
201207
strategy = strategy,
208+
na.rm = na.rm,
202209
n_max = n_max,
203210
...
204211
)

R/partial_dep.R

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -99,8 +99,8 @@ partial_dep <- function(object, ...) {
9999
partial_dep.default <- function(object, v, X, pred_fun = stats::predict,
100100
BY = NULL, by_size = 4L, grid = NULL, grid_size = 49L,
101101
trim = c(0.01, 0.99),
102-
strategy = c("uniform", "quantile"), n_max = 1000L,
103-
w = NULL, ...) {
102+
strategy = c("uniform", "quantile"), na.rm = TRUE,
103+
n_max = 1000L, w = NULL, ...) {
104104
stopifnot(
105105
is.matrix(X) || is.data.frame(X),
106106
is.function(pred_fun),
@@ -110,7 +110,7 @@ partial_dep.default <- function(object, v, X, pred_fun = stats::predict,
110110
# Care about grid
111111
if (is.null(grid)) {
112112
grid <- multivariate_grid(
113-
x = X[, v], grid_size = grid_size, trim = trim, strategy = strategy
113+
x = X[, v], grid_size = grid_size, trim = trim, strategy = strategy, na.rm = na.rm
114114
)
115115
} else {
116116
check_grid(g = grid, v = v, X_is_matrix = is.matrix(X))
@@ -130,7 +130,7 @@ partial_dep.default <- function(object, v, X, pred_fun = stats::predict,
130130
out <- partial_dep.default(
131131
object = object,
132132
v = v,
133-
X = X[BY2$BY %in% b, , drop = FALSE],
133+
X = X[BY2$BY %in% b, , drop = FALSE], # works also when by is NA
134134
pred_fun = pred_fun,
135135
grid = grid,
136136
n_max = n_max,
@@ -185,8 +185,8 @@ partial_dep.ranger <- function(object, v, X,
185185
pred_fun = function(m, X, ...) stats::predict(m, X, ...)$predictions,
186186
BY = NULL, by_size = 4L, grid = NULL, grid_size = 49L,
187187
trim = c(0.01, 0.99),
188-
strategy = c("uniform", "quantile"), n_max = 1000L,
189-
w = NULL, ...) {
188+
strategy = c("uniform", "quantile"), na.rm = TRUE,
189+
n_max = 1000L, w = NULL, ...) {
190190
partial_dep.default(
191191
object = object,
192192
v = v,
@@ -198,6 +198,7 @@ partial_dep.ranger <- function(object, v, X,
198198
grid_size = grid_size,
199199
trim = trim,
200200
strategy = strategy,
201+
na.rm = na.rm,
201202
n_max = n_max,
202203
w = w,
203204
...
@@ -210,8 +211,8 @@ partial_dep.Learner <- function(object, v, X,
210211
pred_fun = NULL,
211212
BY = NULL, by_size = 4L, grid = NULL, grid_size = 49L,
212213
trim = c(0.01, 0.99),
213-
strategy = c("uniform", "quantile"), n_max = 1000L,
214-
w = NULL, ...) {
214+
strategy = c("uniform", "quantile"), na.rm = TRUE,
215+
n_max = 1000L, w = NULL, ...) {
215216
if (is.null(pred_fun)) {
216217
pred_fun <- mlr3_pred_fun(object, X = X)
217218
}
@@ -226,6 +227,7 @@ partial_dep.Learner <- function(object, v, X,
226227
grid_size = grid_size,
227228
trim = trim,
228229
strategy = strategy,
230+
na.rm = na.rm,
229231
n_max = n_max,
230232
w = w,
231233
...
@@ -238,8 +240,8 @@ partial_dep.explainer <- function(object, v, X = object[["data"]],
238240
pred_fun = object[["predict_function"]],
239241
BY = NULL, by_size = 4L, grid = NULL, grid_size = 49L,
240242
trim = c(0.01, 0.99),
241-
strategy = c("uniform", "quantile"), n_max = 1000L,
242-
w = object[["weights"]], ...) {
243+
strategy = c("uniform", "quantile"), na.rm = TRUE,
244+
n_max = 1000L, w = object[["weights"]], ...) {
243245
partial_dep.default(
244246
object = object[["model"]],
245247
v = v,
@@ -251,6 +253,7 @@ partial_dep.explainer <- function(object, v, X = object[["data"]],
251253
grid_size = grid_size,
252254
trim = trim,
253255
strategy = strategy,
256+
na.rm = na.rm,
254257
n_max = n_max,
255258
w = w,
256259
...

R/utils_grid.R

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
#' of grid values. Set to `0:1` for no trimming.
2525
#' @param strategy How to find grid values of non-discrete numeric columns?
2626
#' Either "uniform" or "quantile", see description of [univariate_grid()].
27+
#' @param na.rm Should missing values be dropped from grid? Default is `TRUE`.
2728
#' @returns A vector or factor of evaluation points.
2829
#' @seealso [multivariate_grid()]
2930
#' @export
@@ -35,24 +36,29 @@
3536
#' univariate_grid(x, grid_size = 5) # Quantile binning
3637
#' univariate_grid(x, grid_size = 5, strategy = "uniform") # Uniform
3738
univariate_grid <- function(z, grid_size = 49L, trim = c(0.01, 0.99),
38-
strategy = c("uniform", "quantile")) {
39+
strategy = c("uniform", "quantile"), na.rm = TRUE) {
3940
strategy <- match.arg(strategy)
4041
uni <- unique(z)
4142
if (!is.numeric(z) || length(uni) <= grid_size) {
42-
return(sort(uni))
43+
out <- if (na.rm) sort(uni) else sort(uni, na.last = TRUE)
44+
return(out)
4345
}
4446

4547
# Non-discrete numeric
4648
if (strategy == "quantile") {
4749
p <- seq(trim[1L], trim[2L], length.out = grid_size)
4850
g <- stats::quantile(z, probs = p, names = FALSE, type = 1L, na.rm = TRUE)
49-
return(unique(g))
51+
out <- unique(g)
52+
} else {
53+
# strategy = "uniform" (could use range() if trim = 0:1)
54+
r <- stats::quantile(z, probs = trim, names = FALSE, type = 1L, na.rm = TRUE)
55+
# pretty(r, n = grid_size) # Until version 0.2.0
56+
out <- seq(r[1L], r[2L], length.out = grid_size)
5057
}
51-
52-
# strategy = "uniform" (could use range() if trim = 0:1)
53-
r <- stats::quantile(z, probs = trim, names = FALSE, type = 1L, na.rm = TRUE)
54-
# pretty(r, n = grid_size) # Until version 0.2.0
55-
seq(r[1L], r[2L], length.out = grid_size)
58+
if (!na.rm && anyNA(z)) {
59+
out <- c(out, NA)
60+
}
61+
return(out)
5662
}
5763

5864
#' Multivariate Grid
@@ -72,14 +78,17 @@ univariate_grid <- function(z, grid_size = 49L, trim = c(0.01, 0.99),
7278
#' multivariate_grid(iris$Species) # Works also in the univariate case
7379
#' @export
7480
multivariate_grid <- function(x, grid_size = 49L, trim = c(0.01, 0.99),
75-
strategy = c("uniform", "quantile")) {
81+
strategy = c("uniform", "quantile"), na.rm = TRUE) {
7682
strategy <- match.arg(strategy)
7783
p <- NCOL(x)
7884
if (p == 1L) {
7985
if (is.data.frame(x)) {
8086
x <- x[[1L]]
8187
}
82-
return(univariate_grid(x, grid_size = grid_size, trim = trim, strategy = strategy))
88+
out <- univariate_grid(
89+
x, grid_size = grid_size, trim = trim, strategy = strategy, na.rm = na.rm
90+
)
91+
return(out)
8392
}
8493
grid_size <- ceiling(grid_size^(1/p)) # take p's root of grid_size
8594
is_mat <- is.matrix(x)
@@ -89,7 +98,11 @@ multivariate_grid <- function(x, grid_size = 49L, trim = c(0.01, 0.99),
8998
out <- expand.grid(
9099
lapply(
91100
x,
92-
FUN = univariate_grid, grid_size = grid_size, trim = trim, strategy = strategy
101+
FUN = univariate_grid,
102+
grid_size = grid_size,
103+
trim = trim,
104+
strategy = strategy,
105+
na.rm = na.rm
93106
)
94107
)
95108
if (is_mat) as.matrix(out) else out

man/ice.Rd

Lines changed: 6 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/multivariate_grid.Rd

Lines changed: 4 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/partial_dep.Rd

Lines changed: 6 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/univariate_grid.Rd

Lines changed: 4 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/test_utils.R

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -215,14 +215,6 @@ test_that("wcenter() works for vectors", {
215215
expect_equal(wcenter(x, w = w), xpected)
216216
})
217217

218-
test_that("basic_checks fire some errors", {
219-
expect_error(basic_check(X = 1:3, v = "a", pred_fun = predict, w = NULL))
220-
expect_error(basic_check(X = iris[0], v = "a", pred_fun = predict, w = NULL))
221-
expect_error(basic_check(X = iris, v = "a", pred_fun = predict, w = NULL))
222-
expect_error(basic_check(X = iris, v = "Species", pred_fun = "mean", w = NULL))
223-
expect_error(basic_check(X = iris, v = "Species", pred_fun = predict, w = 1:3))
224-
})
225-
226218
test_that("poor_man_stack() works (test could be improved", {
227219
y <- c("a", "b", "c")
228220
z <- c("aa", "bb", "cc")

0 commit comments

Comments
 (0)