feature_importance for multiinput models with data as a list of arrays #142

Open · wants to merge 8 commits into base: master
5 changes: 4 additions & 1 deletion DESCRIPTION
@@ -37,7 +37,10 @@ Suggests:
jsonlite,
knitr,
rmarkdown,
covr
covr,
abind,
data.table,
keras
URL: https://ModelOriented.github.io/ingredients/, https://github.com/ModelOriented/ingredients
BugReports: https://github.com/ModelOriented/ingredients/issues
VignetteBuilder: knitr
223 changes: 218 additions & 5 deletions R/feature_importance.R
@@ -6,7 +6,7 @@
#' Find more details in the \href{https://ema.drwhy.ai/featureImportance.html}{Feature Importance Chapter}.
#'
#' @param x an explainer created with function \code{DALEX::explain()}, or a model to be explained.
#' @param data validation dataset, will be extracted from \code{x} if it's an explainer
#' @param data validation dataset, will be extracted from \code{x} if it's an explainer. Can be a list of arrays for multiinput models.
#' NOTE: It is best when target variable is not present in the \code{data}
#' @param predict_function predict function, will be extracted from \code{x} if it's an explainer
#' @param y true labels for \code{data}, will be extracted from \code{x} if it's an explainer
@@ -20,10 +20,14 @@
#' If \code{NULL} then variable importance will be calculated on whole dataset (no sampling).
#' @param n_sample alias for \code{N} held for backwards compatibility. number of observations that should be sampled for calculation of variable importance.
#' @param B integer, number of permutation rounds to perform on each variable. By default it's \code{10}.
#' @param variables vector of variables. If \code{NULL} then variable importance will be tested for each variable from the \code{data} separately. By default \code{NULL}
#' @param variable_groups list of variables names vectors. This is for testing joint variable importance.
#' @param variables vector of variables or a list of vectors for multiinput models. If \code{NULL} then variable importance will be tested for each variable from the \code{data} separately. By default \code{NULL}
#' @param variable_groups list of variables names vectors or a list of vectors for multiinput models. This is for testing joint variable importance.
#' If \code{NULL} then variable importance will be tested separately for \code{variables}.
#' By default \code{NULL}. If specified then it will override \code{variables}
#' By default \code{NULL}. If specified then it will override \code{variables}, \code{perm_dim} and \code{comb_dims}.
#' @param perm_dim the dimensions along which to permute values when \code{data} is a 3d array (e.g. [case, time, variable]).
#' If \code{perm_dim = 2:3}, the importance is calculated for each variable in the 2nd and 3rd dimensions.
#' For multiinput models, a list of dimensions in the same order as in \code{data}. If \code{NULL}, the default, all dimensions except the first one (i.e. rows, which correspond to cases) are taken.
#' @param comb_dims if \code{TRUE}, permutations are done for each combination of the levels of the variables from the 2nd and 3rd dimensions, for input data with 3 dimensions. By default \code{FALSE}
#'
#' @references Explanatory Model Analysis. Explore, Explain, and Examine Predictive Models. \url{https://ema.drwhy.ai/}
#'
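To make the new arguments concrete, here is a short sketch of the shapes they take for a hypothetical two-input model (the input names `tab` and `seq`, and all dimension names, are invented for illustration):

# data: one array per input, all sharing the first (case) dimension
data <- list(tab = x_tab,  # [case, var] matrix
             seq = x_seq)  # [case, time, var] 3d array with named dimnames

# perm_dim: one entry per input, in the same order as `data`
perm_dim <- list(tab = 2, seq = 2:3)

# variable_groups: per input, named groups mapping dimension names to levels
variable_groups <- list(seq = list(vitals = list(var = c("hr", "bp"))))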
@@ -163,13 +167,33 @@ feature_importance.default <- function(x,
B = 10,
variables = NULL,
N = n_sample,
variable_groups = NULL) {
variable_groups = NULL,
perm_dim = NULL,
comb_dims = FALSE) {
# start: checks for arguments
## if (is.null(N) & methods::hasArg("n_sample")) {
## warning("n_sample is deprecated, please update ingredients and DALEX packages to use N instead")
## N <- list(...)[["n_sample"]]
## }

# `data` given as a list of arrays signals a multiinput model;
# delegate to the dedicated implementation below
if (inherits(data, "list")) {
res <- feature_importance.multiinput(x = x,
data = data,
y = y,
predict_function = predict_function,
loss_function = loss_function,
...,
label = label,
type = type,
B = B,
variables = variables,
N = N,
variable_groups = variable_groups,
perm_dim = perm_dim,
comb_dims = comb_dims)
return(res)
}

if (!is.null(variable_groups)) {
if (!inherits(variable_groups, "list")) stop("variable_groups should be of class list")

@@ -279,3 +303,192 @@ feature_importance.default <- function(x,
res
}


feature_importance.multiinput <- function(x,
data,
y,
predict_function = predict,
loss_function = DALEX::loss_root_mean_square,
...,
label = class(x)[1],
type = c("raw", "ratio", "difference"),
B = 10,
variables = NULL,
N = NULL,
variable_groups = NULL,
perm_dim = NULL,
comb_dims = FALSE) {
# start: checks for arguments
## if (is.null(N) & methods::hasArg("n_sample")) {
## warning("n_sample is deprecated, please update ingredients and DALEX packages to use N instead")
## N <- list(...)[["n_sample"]]
## }

if (is.null(perm_dim) || !is.null(variable_groups)) {
perm_dim <- lapply(data, function(d) setNames(2:length(dim(d)), nm=names(dimnames(d))[-1])) # all dims except the first (rows), which corresponds to cases
}

# Variables for the dimensions to permute
varsL <- mapply(function(d, dim) {
dimnames(d)[dim]
}, d=data, dim=perm_dim, SIMPLIFY=FALSE)
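# Illustration (hypothetical shapes): for data = list(tab = <matrix [case, var]>,
# seq = <array [case, time, var]>), both with fully named dimnames (the code
# below relies on names(dimnames(...))), the defaults computed above are
# perm_dim = list(tab = c(var = 2), seq = c(time = 2, var = 3)) and
# varsL = list(tab = list(var = <levels>), seq = list(time = <levels>, var = <levels>))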

if (!is.null(variable_groups)) {
if (!inherits(variable_groups, "list") || !all(sapply(variable_groups, inherits, "list")))
stop("variable_groups should be of class list containing lists for each data input")

wrong_names <- !all(mapply(function(variable_set, vars) {
all(unlist(variable_set) %in% unlist(vars))
}, variable_set=variable_groups, vars=varsL[names(variable_groups)]))

if (wrong_names) stop("You have passed wrong variables names in variable_groups argument")
if (!all(unlist(sapply(variable_groups, sapply, sapply, class)) == "character"))
stop("Elements of variable_groups argument should be of class character")
if (any(sapply(sapply(variable_groups, names), is.null))) {
warning("You have passed an unnamed list. The names of variable groupings will be created from variables names.")
# Adding names for variable_groups if not specified
variable_groups <- lapply(variable_groups, function(variable_sets_input) {
if (is.null(names(variable_sets_input))) {
group_names <- sapply(variable_sets_input, function(v) paste(paste(names(v), sapply(v, paste, collapse="; "), sep="."), collapse = " | "))
names(variable_sets_input) <- group_names
}
variable_sets_input
})
}
}
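# Example of the generated group names (hypothetical dimensions): a group
# list(var = c("hr", "bp"), time = "t1") is labelled "var.hr; bp | time.t1"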
type <- match.arg(type)
B <- max(1, round(B))

# if `variable_groups` are not specified, then extract from `variables`
if (is.null(variable_groups)) {
# if `variables` are not specified, then extract from data
if (is.null(variables)) {
variables <- lapply(varsL, function(vars) {
if (comb_dims) {
vars <- expand.grid(vars, stringsAsFactors=FALSE, KEEP.OUT.ATTRS=FALSE) # All combinations for all dimensions in a dataset
rownames(vars) <- apply(vars, 1, function(v) paste(v, collapse="|"))
vars <- split(vars, rownames(vars))
vars <- lapply(vars, as.list)
} else {
vars <- mapply(function(dim_var, dim_names) {
v <- lapply(dim_var, function(v) setNames(list(v), dim_names))
setNames(v, nm = dim_var)
}, dim_var=vars, dim_names=names(vars), SIMPLIFY=FALSE)
vars <- do.call(c, vars)
}
vars
})
}
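# Resulting structure of `variables`, per input (hypothetical levels):
# comb_dims = FALSE -> one entry per level of each dimension, e.g.
#   list("time.t1" = list(time = "t1"), "var.a" = list(var = "a"), ...)
# comb_dims = TRUE  -> one entry per combination of levels, e.g.
#   list("t1|a" = list(time = "t1", var = "a"), ...)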
} else {
variables <- variable_groups
}

# start: actual calculations
# one permutation round: subsample data, permute variables and compute losses
n_cases <- unique(sapply(data, nrow))
if (length(n_cases) > 1) {
stop("Number of cases among inputs in data are different.")
}
sampled_rows <- 1:n_cases

loss_after_permutation <- function() {
if (!is.null(N)) {
if (N < n_cases) {
# sample N points
sampled_rows <- sample(1:n_cases, N)
}
}
sampled_data <- lapply(data, function(d) {
if (length(dim(d)) == 2) {
sampled_data <- d[sampled_rows, , drop = FALSE]
} else if (length(dim(d)) == 3) {
sampled_data <- d[sampled_rows, , , drop = FALSE]
}
sampled_data
})
observed <- y[sampled_rows]
# loss on the full model or when outcomes are permuted
loss_full <- loss_function(observed, predict_function(x, sampled_data))
loss_baseline <- loss_function(sample(observed), predict_function(x, sampled_data))
# loss upon dropping a single variable (or a single group)
loss_featuresL <- mapply(function(d, vars, input_data) {
loss_features <- sapply(vars, function(variables_set) {
ndf <- d
dim_perm <- names(dimnames(ndf)) %in% names(variables_set)
dims <- list()
for (i in 2:length(dim_perm)) { # First dimension for cases
if (dim_perm[i]) {
dims[[i]] <- variables_set[[names(dimnames(ndf))[i]]]
} else {
dims[[i]] <- 1:dim(ndf)[i]
}
}
names(dims) <- names(dimnames(ndf))

if (length(dim_perm) == 2) {
ndf[, dims[[2]]] <- ndf[sample(1:nrow(ndf)), dims[[2]]]
} else if (length(dim_perm) == 3) {
ndf[, dims[[2]], dims[[3]]] <- ndf[sample(1:nrow(ndf)), dims[[2]], dims[[3]]]
} else {
stop("Dimensions for this kind of data is not implemented but should be easy. Contact with the developers.")
}
sampled_data[[input_data]] <- ndf
predicted <- predict_function(x, sampled_data)
loss_function(observed, predicted)
})
}, d=sampled_data, vars=variables, input_data=seq_along(sampled_data), SIMPLIFY=FALSE)

unlist(c("_full_model_" = loss_full, loss_featuresL, "_baseline_" = loss_baseline))
}
# permute B times, collect results into single matrix
raw <- replicate(B, loss_after_permutation())
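# raw is a matrix with one row per term ("_full_model_", each variable or
# group, "_baseline_") and one column per permutation round, i.e. B columns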

# main result df with dropout_loss averages, with _full_model_ first and _baseline_ last
res <- apply(raw, 1, mean)
res_baseline <- res["_baseline_"]
res_full <- res["_full_model_"]
res <- sort(res[!names(res) %in% c("_full_model_", "_baseline_")])
res <- data.frame(
variable = gsub(paste0("^(", paste(names(data), collapse="|"), ")\\."), "\\1: ", c("_full_model_", names(res), "_baseline_")),
permutation = 0,
dropout_loss = c(res_full, res, res_baseline),
label = label,
row.names = NULL
)
if (type == "ratio") {
res$dropout_loss = res$dropout_loss / res_full
}
if (type == "difference") {
res$dropout_loss = res$dropout_loss - res_full
}


# record details of permutations
attr(res, "B") <- B

if (B > 1) {
res_B <- data.frame(
variable = gsub(paste0("^(", paste(names(data), collapse="|"), ")\\."), "\\1: ", rep(rownames(raw), ncol(raw))),
permutation = rep(seq_len(B), each = nrow(raw)),
dropout_loss = as.vector(raw),
label = label
)

# here mean full model is used (full model for given permutation is an option)
if (type == "ratio") {
res_B$dropout_loss = res_B$dropout_loss / res_full
}
if (type == "difference") {
res_B$dropout_loss = res_B$dropout_loss - res_full
}

res <- rbind(res, res_B)
}

class(res) <- c("feature_importance_explainer", "data.frame")

if(!is.null(attr(loss_function, "loss_name"))) {
attr(res, "loss_name") <- attr(loss_function, "loss_name")
}
res
}
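As a reviewer aid, here is a minimal end-to-end sketch of the new multiinput path. The toy model, the input names `tab` and `seq`, and all dimnames are invented for illustration; with a real multi-input model (e.g. a keras model, now in Suggests) you would pass the fitted model and a suitable predict wrapper instead:

library(ingredients)

set.seed(1)
n <- 50
x_tab <- matrix(rnorm(n * 2), ncol = 2,
                dimnames = list(case = NULL, var = c("age", "dose")))
x_seq <- array(rnorm(n * 10 * 3), dim = c(n, 10, 3),
               dimnames = list(case = NULL, time = paste0("t", 1:10),
                               var = c("hr", "bp", "temp")))
y <- rowMeans(x_tab) + apply(x_seq[, , "hr"], 1, mean) + rnorm(n, sd = 0.1)

# stand-in for a fitted multi-input model (e.g. keras)
model <- structure(list(), class = "toy_model")
pred <- function(m, d) rowMeans(d$tab) + apply(d$seq[, , "hr"], 1, mean)

fi <- feature_importance(model,
                         data = list(tab = x_tab, seq = x_seq),
                         y = y,
                         predict_function = pred,
                         perm_dim = list(tab = 2, seq = 3),  # permute variables, not time steps
                         B = 5)
plot(fi)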
18 changes: 13 additions & 5 deletions man/feature_importance.Rd
