Update news
mayer79 committed Nov 7, 2023
1 parent 2d36df7 commit 2062c6f
Showing 2 changed files with 33 additions and 30 deletions.
6 changes: 3 additions & 3 deletions NEWS.md
@@ -2,7 +2,7 @@

 ## Enhancements

-- {hstats} now also work for factor predictions. The levels are represented by one-hot-encoded columns.
+- {hstats} now also work for factor predictions. The levels are represented by one-hot-encoded columns ([PR#101](https://github.com/mayer79/hstats/pull/101)).
 - The plot method of a two-dimensional PDP has received the option `d2_geom = "line"`. Instead of a heatmap of the two features, one of the features is moved to color grouping. Combined with `swap_dim = TRUE`, you can swap the role of the two `v` variables without recalculating anything. The idea was proposed by [Roel Verbelen](https://github.com/RoelVerbelen) in [issue #91](https://github.com/mayer79/hstats/issues/91), see also [issue #94](https://github.com/mayer79/hstats/issues/94).

## Bug fixes
@@ -11,8 +11,8 @@

 ## Other changes

-- Much faster one-hot-encoding, thanks to Mathias Ambühl.
-- Most functions are slightly faster.
+- Much faster one-hot-encoding, thanks to Mathias Ambühl ([PR#101](https://github.com/mayer79/hstats/pull/101)).
+- Most functions are slightly faster ([PR#101](https://github.com/mayer79/hstats/pull/101)).
 - Add unit tests to compare against {iml}.
 - Made all examples "tibble" and "data.table" friendly.
 - Revised input checks in loss functions (relevant for `perm_importance()` and `average_loss()`).
57 changes: 30 additions & 27 deletions backlog/calibration.R
@@ -58,9 +58,12 @@ calibration <- function(object, ...) {

 #' @describeIn calibration Default method.
 #' @export
-calibration.default <- function(object, v, X, y = NULL, pred_fun = stats::predict,
+calibration.default <- function(object, v, X, y = NULL,
+                                pred_fun = stats::predict,
                                 BY = NULL, by_size = 4L,
-                                grid_size = 17L,
+                                breaks = 17L, trim = c(0.01, 0.99),
+                                include.lowest = TRUE,
+                                right = TRUE, na.rm = TRUE,
                                 pred = NULL,
                                 n_max = 1000L, w = NULL, ...) {
   stopifnot(
@@ -70,11 +73,7 @@ calibration.default <- function(object, v, X, y = NULL, pred_fun = stats::predic
   )

   if (!is.null(y)) {
-    y <- prepare_y(y = y, X = X)[["y"]]
-    if (is.factor(y) || is.character(y)) {
-      y <- stats::model.matrix(~ as.factor(y) + 0)
-    }
-    y <- align_pred(y)
+    y <- prepare_y(y = y, X = X, ohe = TRUE)[["y"]]
   }
   if (!is.null(w)) {
     w <- prepare_w(w = w, X = X)[["w"]]
@@ -83,49 +82,53 @@ calibration.default <- function(object, v, X, y = NULL, pred_fun = stats::predic
     BY2 <- prepare_by(BY = BY, X = X, by_size = by_size)
     BY <- BY2[["BY"]]
   }
-  g <- v_grouped <- approx_vector(X[[v]], m = grid_size)
-  grid <- sort(unique(v_grouped), na.last = TRUE)

+  h <- hist2(
+    X[[v]],
+    breaks = breaks,
+    trim = trim,
+    include.lowest = include.lowest,
+    right = right,
+    na.rm = TRUE
+  )
+
   if (!is.null(BY)) {
-    g <- paste(BY, g, sep = ":")
+    g <- paste(BY, h$x, sep = ":")
+  } else {
+    g <- h$x
   }

   # Average predicted
   if (is.null(pred)) {
     pred <- pred_fun(object, X, ...)
   }
-  pred <- align_pred(pred)
-  tmp <- gwColMeans(pred, g = g, w = w, mean_only = FALSE)
-  avg_pred <- tmp[["mean"]]
-
-  # Exposure
-  exposure <- tmp[["denom"]]
-
+  pred <- prepare_pred(pred, ohe = TRUE)
+  pr <- gwColMeans(pred, g = g, w = w)

   # Average observed
-  avg_obs <- if (!is.null(y)) gwColMeans(y, g = g, w = w)
+  avg_obs <- if (!is.null(y)) gwColMeans(y, g = g, w = w)$M

   # Partial dependence
   pd <- partial_dep(
     object = object,
     v = v,
     X = X,
-    grid = grid,
+    grid = h$grid,
     pred_fun = pred_fun,
     BY = BY,
     w = w,
     ...
-  )[["data"]]
+  )

   out <- list(
     v = v,
-    K = ncol(pred),
-    pred_names = colnames(pred),
-    grid = grid,
+    K = ncol(pr$M),
+    pred_names = colnames(pr$M),
+    grid = h[-1L],
     BY,
     avg_obs = avg_obs,
-    avg_pred = avg_pred,
-    pd = pd,
-    exposure = exposure
+    avg_pred = pr$M,
+    pd = pd[["data"]],
+    exposure = pr$w
   )
   return(structure(out, class = "calibration"))
 }
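
Note on the grouped means: the new code reads the averaged predictions from `pr$M` and the exposure from `pr$w`, which suggests that `gwColMeans()` returns grouped, weighted column means together with the per-group weight sums. The sketch below illustrates that idea in base R. `gw_col_means_sketch()` is a hypothetical stand-in written for this note, not the package's implementation, and the `M`/`w` element names are only inferred from the diff.

```r
# Hypothetical helper (not part of {hstats}): grouped, weighted column means.
# Assumption: returns a list with the means (M) and per-group weight sums (w),
# mirroring how the diff accesses pr$M and pr$w.
gw_col_means_sketch <- function(x, g, w = NULL) {
  if (!is.matrix(x)) {
    x <- as.matrix(x)
  }
  if (is.null(w)) {
    w <- rep(1, nrow(x))  # unweighted case: every row counts once
  }
  # rowsum() adds up rows within each group; groups are sorted by default
  denom <- as.numeric(rowsum(w, group = g))
  num <- rowsum(x * w, group = g)  # x * w scales each row by its weight
  # denom has one entry per group (= row of num), so it recycles column-wise
  list(M = num / denom, w = denom)
}

x <- cbind(a = c(1, 2, 3, 4), b = c(10, 20, 30, 40))
g <- c("A", "A", "B", "B")
w <- c(1, 3, 1, 1)

res <- gw_col_means_sketch(x, g = g, w = w)
print(res$M)  # group A: a = 1.75, b = 17.5; group B: a = 3.5, b = 35
print(res$w)  # 4 2
```

In `calibration.default()`, the grouping vector `g` would be the bin label from `hist2()`, optionally pasted together with the `BY` level, so each row of `M` corresponds to one bin (or one bin within one `BY` group).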