Mlr3 simplification #112

Merged · 2 commits · Dec 26, 2023
8 changes: 1 addition & 7 deletions .github/workflows/test-coverage.yaml
@@ -33,22 +33,16 @@ jobs:
clean = FALSE,
install_path = file.path(Sys.getenv("RUNNER_TEMP"), "package"),
function_exclusions = c(
-"partial_dep\\.Learner",
"partial_dep\\.ranger",
"partial_dep\\.explainer",
-"ice\\.Learner",
"ice\\.ranger",
"ice\\.explainer",
-"hstats\\.Learner",
"hstats\\.ranger",
"hstats\\.explainer",
-"perm_importance\\.Learner",
"perm_importance\\.ranger",
"perm_importance\\.explainer",
-"average_loss\\.Learner",
"average_loss\\.ranger",
-"average_loss\\.explainer",
-"mlr3_pred_fun"
+"average_loss\\.explainer"
)
)
shell: Rscript {0}
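For context: `function_exclusions` is a vector of regular expressions matched against function names, so the trimmed-down list above stops excluding the now-deleted `\.Learner` methods from coverage. A minimal sketch of the underlying {covr} call (not the full workflow step; the variable name `cov` is illustrative):

```r
# Sketch: exclude S3 methods for suggested packages from coverage.
# function_exclusions is a vector of regexes matched against function names.
cov <- covr::package_coverage(
  function_exclusions = c(
    "partial_dep\\.ranger", "partial_dep\\.explainer",
    "ice\\.ranger", "ice\\.explainer",
    "hstats\\.ranger", "hstats\\.explainer",
    "perm_importance\\.ranger", "perm_importance\\.explainer",
    "average_loss\\.ranger", "average_loss\\.explainer"
  )
)
covr::percent_coverage(cov)
```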
2 changes: 1 addition & 1 deletion DESCRIPTION
@@ -1,6 +1,6 @@
Package: hstats
Title: Interaction Statistics
-Version: 1.1.1
+Version: 1.1.2
Authors@R:
person("Michael", "Mayer", , "[email protected]", role = c("aut", "cre"))
Description: Fast, model-agnostic implementation of different H-statistics
5 changes: 0 additions & 5 deletions NAMESPACE
@@ -2,7 +2,6 @@

S3method("[",hstats_matrix)
S3method("dimnames<-",hstats_matrix)
-S3method(average_loss,Learner)
S3method(average_loss,default)
S3method(average_loss,explainer)
S3method(average_loss,ranger)
@@ -16,21 +15,17 @@ S3method(h2_pairwise,default)
S3method(h2_pairwise,hstats)
S3method(h2_threeway,default)
S3method(h2_threeway,hstats)
-S3method(hstats,Learner)
S3method(hstats,default)
S3method(hstats,explainer)
S3method(hstats,ranger)
-S3method(ice,Learner)
S3method(ice,default)
S3method(ice,explainer)
S3method(ice,ranger)
-S3method(partial_dep,Learner)
S3method(partial_dep,default)
S3method(partial_dep,explainer)
S3method(partial_dep,ranger)
S3method(pd_importance,default)
S3method(pd_importance,hstats)
-S3method(perm_importance,Learner)
S3method(perm_importance,default)
S3method(perm_importance,explainer)
S3method(perm_importance,ranger)
7 changes: 7 additions & 0 deletions NEWS.md
@@ -1,3 +1,10 @@
+# hstats 1.1.2
+
+## API
+
+- {mlr3}: Non-probabilistic classification now works.
+- {mlr3}: For *probabilistic* classification, you now have to pass `predict_type = "prob"`.
+
# hstats 1.1.1

## Performance improvements
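The two NEWS items translate to the following usage; a hedged sketch assuming {mlr3} and {hstats} ≥ 1.1.2 (the object names `fit_prob` and `fit_resp` are illustrative):

```r
library(mlr3)
library(hstats)

task <- TaskClassif$new(id = "iris", backend = iris, target = "Species")

# Probabilistic classification: build the learner with predict_type = "prob"
fit_prob <- lrn("classif.rpart", predict_type = "prob")
fit_prob$train(task)
s <- hstats(fit_prob, X = iris[, -5])  # one output dimension per class

# Non-probabilistic classification (factor predictions) now works as well;
# the factor levels are one-hot-encoded internally
fit_resp <- lrn("classif.rpart")
fit_resp$train(task)
partial_dep(fit_resp, v = "Petal.Length", X = iris[, -5])
```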
25 changes: 0 additions & 25 deletions R/average_loss.R
@@ -135,31 +135,6 @@ average_loss.ranger <- function(object, X, y,
)
}

#' @describeIn average_loss Method for "mlr3" models.
#' @export
average_loss.Learner <- function(object, X, y,
pred_fun = NULL,
loss = "squared_error",
agg_cols = FALSE,
BY = NULL, by_size = 4L,
w = NULL, ...) {
if (is.null(pred_fun)) {
pred_fun <- mlr3_pred_fun(object, X = X)
}
average_loss.default(
object = object,
X = X,
y = y,
pred_fun = pred_fun,
loss = loss,
agg_cols = agg_cols,
BY = BY,
by_size = by_size,
w = w,
...
)
}

#' @describeIn average_loss Method for DALEX "explainer".
#' @export
average_loss.explainer <- function(object,
28 changes: 0 additions & 28 deletions R/hstats.R
@@ -300,34 +300,6 @@ hstats.ranger <- function(object, X, v = NULL,
)
}

#' @describeIn hstats Method for "mlr3" models.
#' @export
hstats.Learner <- function(object, X, v = NULL,
pred_fun = NULL,
pairwise_m = 5L, threeway_m = 0L,
approx = FALSE, grid_size = 50L,
n_max = 500L, eps = 1e-10,
w = NULL, verbose = TRUE, ...) {
if (is.null(pred_fun)) {
pred_fun <- mlr3_pred_fun(object, X = X)
}
hstats.default(
object = object,
X = X,
v = v,
pred_fun = pred_fun,
pairwise_m = pairwise_m,
threeway_m = threeway_m,
approx = approx,
grid_size = grid_size,
n_max = n_max,
eps = eps,
w = w,
verbose = verbose,
...
)
}

#' @describeIn hstats Method for DALEX "explainer".
#' @export
hstats.explainer <- function(object, X = object[["data"]],
26 changes: 0 additions & 26 deletions R/ice.R
@@ -173,32 +173,6 @@ ice.ranger <- function(object, v, X,
)
}

#' @describeIn ice Method for "mlr3" models.
#' @export
ice.Learner <- function(object, v, X,
pred_fun = NULL,
BY = NULL, grid = NULL, grid_size = 49L, trim = c(0.01, 0.99),
strategy = c("uniform", "quantile"), na.rm = TRUE,
n_max = 100L, ...) {
if (is.null(pred_fun)) {
pred_fun <- mlr3_pred_fun(object, X = X)
}
ice.default(
object = object,
v = v,
X = X,
pred_fun = pred_fun,
BY = BY,
grid = grid,
grid_size = grid_size,
trim = trim,
strategy = strategy,
na.rm = na.rm,
n_max = n_max,
...
)
}

#' @describeIn ice Method for DALEX "explainer".
#' @export
ice.explainer <- function(object, v = v, X = object[["data"]],
29 changes: 0 additions & 29 deletions R/partial_dep.R
@@ -213,35 +213,6 @@ partial_dep.ranger <- function(object, v, X,
)
}

#' @describeIn partial_dep Method for "mlr3" models.
#' @export
partial_dep.Learner <- function(object, v, X,
pred_fun = NULL,
BY = NULL, by_size = 4L, grid = NULL, grid_size = 49L,
trim = c(0.01, 0.99),
strategy = c("uniform", "quantile"), na.rm = TRUE,
n_max = 1000L, w = NULL, ...) {
if (is.null(pred_fun)) {
pred_fun <- mlr3_pred_fun(object, X = X)
}
partial_dep.default(
object = object,
v = v,
X = X,
pred_fun = pred_fun,
BY = BY,
by_size = by_size,
grid = grid,
grid_size = grid_size,
trim = trim,
strategy = strategy,
na.rm = na.rm,
n_max = n_max,
w = w,
...
)
}

#' @describeIn partial_dep Method for DALEX "explainer".
#' @export
partial_dep.explainer <- function(object, v, X = object[["data"]],
Expand Down
28 changes: 0 additions & 28 deletions R/perm_importance.R
@@ -228,34 +228,6 @@ perm_importance.ranger <- function(object, X, y, v = NULL,
)
}

#' @describeIn perm_importance Method for "mlr3" models.
#' @export
perm_importance.Learner <- function(object, X, y, v = NULL,
pred_fun = NULL,
loss = "squared_error", m_rep = 4L,
agg_cols = FALSE,
normalize = FALSE, n_max = 10000L,
w = NULL, verbose = TRUE, ...) {
if (is.null(pred_fun)) {
pred_fun <- mlr3_pred_fun(object, X = X)
}
perm_importance.default(
object = object,
X = X,
y = y,
v = v,
pred_fun = pred_fun,
loss = loss,
m_rep = m_rep,
agg_cols = agg_cols,
normalize = normalize,
n_max = n_max,
w = w,
verbose = verbose,
...
)
}

#' @describeIn perm_importance Method for DALEX "explainer".
#' @export
perm_importance.explainer <- function(object,
23 changes: 0 additions & 23 deletions R/utils_input.R
@@ -114,26 +114,3 @@ prepare_y <- function(y, X, ohe = FALSE) {
list(y = prepare_pred(y, ohe = ohe), y_names = y_names)
}

#' mlr3 Helper
#'
#' Returns the prediction function of a mlr3 Learner.
#'
#' @noRd
#' @keywords internal
#'
#' @param object Learner object.
#' @param X Dataframe like object.
#'
#' @returns A function.
mlr3_pred_fun <- function(object, X) {
if ("classif" %in% object$task_type) {
# Check if probabilities are available
test_pred <- object$predict_newdata(utils::head(X))
if ("prob" %in% test_pred$predict_types) {
return(function(m, X) m$predict_newdata(X)$prob)
} else {
stop("Set lrn(..., predict_type = 'prob') to allow for probabilistic classification.")
}
}
function(m, X) m$predict_newdata(X)$response
}
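Since `mlr3_pred_fun()` is removed, its behavior can be recovered by passing `pred_fun` explicitly; a sketch built from the same `predict_newdata()` calls the deleted helper used (the function names `prob_fun` and `resp_fun` are illustrative):

```r
# Illustrative replacements for the removed helper
prob_fun <- function(m, X) m$predict_newdata(X)$prob      # class probabilities
resp_fun <- function(m, X) m$predict_newdata(X)$response  # hard labels

# Usage, e.g.:
# hstats(fit, X = X, pred_fun = prob_fun)
```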
18 changes: 13 additions & 5 deletions README.md
@@ -229,7 +229,7 @@ Strongest relative interaction shown as ICE plot.

## Multivariate responses

-{hstats} works also with multivariate output, see examples with
+{hstats} also works with multivariate output; see examples for probabilistic classification with

- ranger,
- LightGBM, and
@@ -377,7 +377,9 @@ plot(H, normalize = FALSE, squared = FALSE, facet_scales = "free_y", ncol = 1)

![](man/figures/xgboost.svg)

-### (Non-probabilistic) classification works as well
+### Non-probabilistic classification
+
+When predictions are factor levels, {hstats} uses internal one-hot encoding.

```r
library(ranger)
@@ -404,7 +406,7 @@ partial_dep(fit, v = "Petal.Length", X = train, BY = "Petal.Width") |>

## Meta-learning packages

-Here, we provide some working examples for "tidymodels", "caret", and "mlr3".
+Here, we provide examples for {tidymodels}, {caret}, and {mlr3}.

### tidymodels

@@ -478,8 +480,14 @@ fit_rf$train(task_iris)
s <- hstats(fit_rf, X = iris[, -5])
plot(s)

-# Permutation importance
-perm_importance(fit_rf, X = iris, y = "Species", loss = "mlogloss") |>
+# Permutation importance (probabilistic using multi-logloss)
+p <- perm_importance(
+  fit_rf, X = iris, y = "Species", loss = "mlogloss", predict_type = "prob"
+)
+plot(p)
+
+# Non-probabilistic using classification error
+perm_importance(fit_rf, X = iris, y = "Species", loss = "classification_error") |>
plot()
```

16 changes: 0 additions & 16 deletions man/average_loss.Rd

Some generated files are not rendered by default.

19 changes: 0 additions & 19 deletions man/hstats.Rd

