Skip to content

Mlr3 simplification #112

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Dec 26, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 1 addition & 7 deletions .github/workflows/test-coverage.yaml
Original file line number Diff line number Diff line change
@@ -33,22 +33,16 @@ jobs:
clean = FALSE,
install_path = file.path(Sys.getenv("RUNNER_TEMP"), "package"),
function_exclusions = c(
"partial_dep\\.Learner",
"partial_dep\\.ranger",
"partial_dep\\.explainer",
"ice\\.Learner",
"ice\\.ranger",
"ice\\.explainer",
"hstats\\.Learner",
"hstats\\.ranger",
"hstats\\.explainer",
"perm_importance\\.Learner",
"perm_importance\\.ranger",
"perm_importance\\.explainer",
"average_loss\\.Learner",
"average_loss\\.ranger",
"average_loss\\.explainer",
"mlr3_pred_fun"
"average_loss\\.explainer"
)
)
shell: Rscript {0}
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: hstats
Title: Interaction Statistics
Version: 1.1.1
Version: 1.1.2
Authors@R:
person("Michael", "Mayer", , "mayermichael79@gmail.com", role = c("aut", "cre"))
Description: Fast, model-agnostic implementation of different H-statistics
5 changes: 0 additions & 5 deletions NAMESPACE
Original file line number Diff line number Diff line change
@@ -2,7 +2,6 @@

S3method("[",hstats_matrix)
S3method("dimnames<-",hstats_matrix)
S3method(average_loss,Learner)
S3method(average_loss,default)
S3method(average_loss,explainer)
S3method(average_loss,ranger)
@@ -16,21 +15,17 @@ S3method(h2_pairwise,default)
S3method(h2_pairwise,hstats)
S3method(h2_threeway,default)
S3method(h2_threeway,hstats)
S3method(hstats,Learner)
S3method(hstats,default)
S3method(hstats,explainer)
S3method(hstats,ranger)
S3method(ice,Learner)
S3method(ice,default)
S3method(ice,explainer)
S3method(ice,ranger)
S3method(partial_dep,Learner)
S3method(partial_dep,default)
S3method(partial_dep,explainer)
S3method(partial_dep,ranger)
S3method(pd_importance,default)
S3method(pd_importance,hstats)
S3method(perm_importance,Learner)
S3method(perm_importance,default)
S3method(perm_importance,explainer)
S3method(perm_importance,ranger)
7 changes: 7 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
# hstats 1.1.2

## API

- {mlr3}: Non-probabilistic classification now works.
- {mlr3}: For *probabilistic* classification, you now have to pass `predict_type = "prob"`.

# hstats 1.1.1

## Performance improvements
25 changes: 0 additions & 25 deletions R/average_loss.R
Original file line number Diff line number Diff line change
@@ -135,31 +135,6 @@ average_loss.ranger <- function(object, X, y,
)
}

#' @describeIn average_loss Method for "mlr3" models.
#' @export
average_loss.Learner <- function(object, X, y,
pred_fun = NULL,
loss = "squared_error",
agg_cols = FALSE,
BY = NULL, by_size = 4L,
w = NULL, ...) {
if (is.null(pred_fun)) {
pred_fun <- mlr3_pred_fun(object, X = X)
}
average_loss.default(
object = object,
X = X,
y = y,
pred_fun = pred_fun,
loss = loss,
agg_cols = agg_cols,
BY = BY,
by_size = by_size,
w = w,
...
)
}

#' @describeIn average_loss Method for DALEX "explainer".
#' @export
average_loss.explainer <- function(object,
28 changes: 0 additions & 28 deletions R/hstats.R
Original file line number Diff line number Diff line change
@@ -300,34 +300,6 @@ hstats.ranger <- function(object, X, v = NULL,
)
}

#' @describeIn hstats Method for "mlr3" models.
#' @export
hstats.Learner <- function(object, X, v = NULL,
pred_fun = NULL,
pairwise_m = 5L, threeway_m = 0L,
approx = FALSE, grid_size = 50L,
n_max = 500L, eps = 1e-10,
w = NULL, verbose = TRUE, ...) {
if (is.null(pred_fun)) {
pred_fun <- mlr3_pred_fun(object, X = X)
}
hstats.default(
object = object,
X = X,
v = v,
pred_fun = pred_fun,
pairwise_m = pairwise_m,
threeway_m = threeway_m,
approx = approx,
grid_size = grid_size,
n_max = n_max,
eps = eps,
w = w,
verbose = verbose,
...
)
}

#' @describeIn hstats Method for DALEX "explainer".
#' @export
hstats.explainer <- function(object, X = object[["data"]],
26 changes: 0 additions & 26 deletions R/ice.R
Original file line number Diff line number Diff line change
@@ -173,32 +173,6 @@ ice.ranger <- function(object, v, X,
)
}

#' @describeIn ice Method for "mlr3" models.
#' @export
ice.Learner <- function(object, v, X,
pred_fun = NULL,
BY = NULL, grid = NULL, grid_size = 49L, trim = c(0.01, 0.99),
strategy = c("uniform", "quantile"), na.rm = TRUE,
n_max = 100L, ...) {
if (is.null(pred_fun)) {
pred_fun <- mlr3_pred_fun(object, X = X)
}
ice.default(
object = object,
v = v,
X = X,
pred_fun = pred_fun,
BY = BY,
grid = grid,
grid_size = grid_size,
trim = trim,
strategy = strategy,
na.rm = na.rm,
n_max = n_max,
...
)
}

#' @describeIn ice Method for DALEX "explainer".
#' @export
ice.explainer <- function(object, v = v, X = object[["data"]],
29 changes: 0 additions & 29 deletions R/partial_dep.R
Original file line number Diff line number Diff line change
@@ -213,35 +213,6 @@ partial_dep.ranger <- function(object, v, X,
)
}

#' @describeIn partial_dep Method for "mlr3" models.
#' @export
partial_dep.Learner <- function(object, v, X,
pred_fun = NULL,
BY = NULL, by_size = 4L, grid = NULL, grid_size = 49L,
trim = c(0.01, 0.99),
strategy = c("uniform", "quantile"), na.rm = TRUE,
n_max = 1000L, w = NULL, ...) {
if (is.null(pred_fun)) {
pred_fun <- mlr3_pred_fun(object, X = X)
}
partial_dep.default(
object = object,
v = v,
X = X,
pred_fun = pred_fun,
BY = BY,
by_size = by_size,
grid = grid,
grid_size = grid_size,
trim = trim,
strategy = strategy,
na.rm = na.rm,
n_max = n_max,
w = w,
...
)
}

#' @describeIn partial_dep Method for DALEX "explainer".
#' @export
partial_dep.explainer <- function(object, v, X = object[["data"]],
28 changes: 0 additions & 28 deletions R/perm_importance.R
Original file line number Diff line number Diff line change
@@ -228,34 +228,6 @@ perm_importance.ranger <- function(object, X, y, v = NULL,
)
}

#' @describeIn perm_importance Method for "mlr3" models.
#' @export
perm_importance.Learner <- function(object, X, y, v = NULL,
pred_fun = NULL,
loss = "squared_error", m_rep = 4L,
agg_cols = FALSE,
normalize = FALSE, n_max = 10000L,
w = NULL, verbose = TRUE, ...) {
if (is.null(pred_fun)) {
pred_fun <- mlr3_pred_fun(object, X = X)
}
perm_importance.default(
object = object,
X = X,
y = y,
v = v,
pred_fun = pred_fun,
loss = loss,
m_rep = m_rep,
agg_cols = agg_cols,
normalize = normalize,
n_max = n_max,
w = w,
verbose = verbose,
...
)
}

#' @describeIn perm_importance Method for DALEX "explainer".
#' @export
perm_importance.explainer <- function(object,
23 changes: 0 additions & 23 deletions R/utils_input.R
Original file line number Diff line number Diff line change
@@ -114,26 +114,3 @@ prepare_y <- function(y, X, ohe = FALSE) {
list(y = prepare_pred(y, ohe = ohe), y_names = y_names)
}

#' mlr3 Helper
#'
#' Returns the prediction function of a mlr3 Learner.
#'
#' @noRd
#' @keywords internal
#'
#' @param object Learner object.
#' @param X Dataframe like object.
#'
#' @returns A function.
mlr3_pred_fun <- function(object, X) {
if ("classif" %in% object$task_type) {
# Check if probabilities are available
test_pred <- object$predict_newdata(utils::head(X))
if ("prob" %in% test_pred$predict_types) {
return(function(m, X) m$predict_newdata(X)$prob)
} else {
stop("Set lrn(..., predict_type = 'prob') to allow for probabilistic classification.")
}
}
function(m, X) m$predict_newdata(X)$response
}
18 changes: 13 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
@@ -229,7 +229,7 @@ Strongest relative interaction shown as ICE plot.

## Multivariate responses

{hstats} works also with multivariate output, see examples with
{hstats} works also with multivariate output, see examples for probabilistic classification with

- ranger,
- LightGBM, and
@@ -377,7 +377,9 @@ plot(H, normalize = FALSE, squared = FALSE, facet_scales = "free_y", ncol = 1)

![](man/figures/xgboost.svg)

### (Non-probabilistic) classification works as well
### Non-probabilistic classification

When predictions are factor levels, {hstats} uses internal one-hot-encoding.

```r
library(ranger)
@@ -404,7 +406,7 @@ partial_dep(fit, v = "Petal.Length", X = train, BY = "Petal.Width") |>

## Meta-learning packages

Here, we provide some working examples for "tidymodels", "caret", and "mlr3".
Here, we provide examples for {tidymodels}, {caret}, and {mlr3}.

### tidymodels

@@ -478,8 +480,14 @@ fit_rf$train(task_iris)
s <- hstats(fit_rf, X = iris[, -5])
plot(s)

# Permutation importance
perm_importance(fit_rf, X = iris, y = "Species", loss = "mlogloss") |>
# Permutation importance (probabilistic using multi-logloss)
p <- perm_importance(
fit_rf, X = iris, y = "Species", loss = "mlogloss", predict_type = "prob"
)
plot(p)

# Non-probabilistic using classification error
perm_importance(fit_rf, X = iris, y = "Species", loss = "classification_error") |>
plot()
```

16 changes: 0 additions & 16 deletions man/average_loss.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

19 changes: 0 additions & 19 deletions man/hstats.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

18 changes: 0 additions & 18 deletions man/ice.Rd
20 changes: 0 additions & 20 deletions man/partial_dep.Rd
19 changes: 0 additions & 19 deletions man/perm_importance.Rd
2 changes: 1 addition & 1 deletion packaging.R
Original file line number Diff line number Diff line change
@@ -15,7 +15,7 @@ library(usethis)
use_description(
fields = list(
Title = "Interaction Statistics",
Version = "1.1.1",
Version = "1.1.2",
Description = "Fast, model-agnostic implementation of different H-statistics
introduced by Jerome H. Friedman and Bogdan E. Popescu (2008) <doi:10.1214/07-AOAS148>.
These statistics quantify interaction strength per feature, feature pair,