From ecaa5a9e2580570687a3cf59f94c82f4b0e601fa Mon Sep 17 00:00:00 2001 From: Joan Maspons Date: Sat, 19 Mar 2022 19:58:03 +0100 Subject: [PATCH 1/8] feature_importance for multiinput models with data as a list of datasets. Datasets can be 2d or 3d arrays --- R/feature_importance.R | 226 ++++++++++++++++++++++++++++++++++++++++- 1 file changed, 221 insertions(+), 5 deletions(-) diff --git a/R/feature_importance.R b/R/feature_importance.R index fad1868..72aa8d6 100644 --- a/R/feature_importance.R +++ b/R/feature_importance.R @@ -6,7 +6,7 @@ #' Find more details in the \href{https://ema.drwhy.ai/featureImportance.html}{Feature Importance Chapter}. #' #' @param x an explainer created with function \code{DALEX::explain()}, or a model to be explained. -#' @param data validation dataset, will be extracted from \code{x} if it's an explainer +#' @param data validation dataset, will be extracted from \code{x} if it's an explainer. Can be a list of arrays for multiinput models. #' NOTE: It is best when target variable is not present in the \code{data} #' @param predict_function predict function, will be extracted from \code{x} if it's an explainer #' @param y true labels for \code{data}, will be extracted from \code{x} if it's an explainer @@ -20,10 +20,14 @@ #' If \code{NULL} then variable importance will be calculated on whole dataset (no sampling). #' @param n_sample alias for \code{N} held for backwards compatibility. number of observations that should be sampled for calculation of variable importance. #' @param B integer, number of permutation rounds to perform on each variable. By default it's \code{10}. -#' @param variables vector of variables. If \code{NULL} then variable importance will be tested for each variable from the \code{data} separately. By default \code{NULL} -#' @param variable_groups list of variables names vectors. This is for testing joint variable importance. +#' @param variables vector of variables or a list of vectors for multiinput models. 
If \code{NULL} then variable importance will be tested for each variable from the \code{data} separately. By default \code{NULL} +#' @param variable_groups list of variables names vectors or a list of vectors for multiinput models. This is for testing joint variable importance. #' If \code{NULL} then variable importance will be tested separately for \code{variables}. -#' By default \code{NULL}. If specified then it will override \code{variables} +#' By default \code{NULL}. If specified then it will override \code{variables}, \code{permDim} and \code{combDims}. +#' @param permDim the dimensions to perform the permutations when \code{data} is a 3d array (e.g. [case, time, variable]). +#' If \code{permDim = 2:3}, it calculates the importance for each variable in the 2nd and 3rd dimensions. +#' For multiinput models, a list of dimensions in the same order than in \code{data}. If \code{NULL}, the default, take all dimensions except the first one (i.e. rows) which correspond to cases. +#' @param combDims if \code{TRUE}, do the permutations for each combination of the levels of the variables from 2nd and 3rd dimensions for input data with 3 dimensions. By default \code{FALSE} #' #' @references Explanatory Model Analysis. Explore, Explain, and Examine Predictive Models. 
\url{https://ema.drwhy.ai/} #' @@ -163,13 +167,34 @@ feature_importance.default <- function(x, B = 10, variables = NULL, N = n_sample, - variable_groups = NULL) { + variable_groups = NULL, + permDim = NULL, + combDims = FALSE) { # start: checks for arguments ## if (is.null(N) & methods::hasArg("n_sample")) { ## warning("n_sample is deprecated, please update ingredients and DALEX packages to use N instead") ## N <- list(...)[["n_sample"]] ## } + if (inherits(data, "list")){ + res <- feature_importance.multiinput(x = x, + data = data, + y = y, + predict_function = predict_function, + loss_function = loss_function, + ..., + label = label, + type = type, + n_sample = n_sample, + B = B, + variables = variables, + N = n_sample, + variable_groups = variable_groups, + permDim = permDim, + combDims = combDims) + return (res) + } + if (!is.null(variable_groups)) { if (!inherits(variable_groups, "list")) stop("variable_groups should be of class list") @@ -279,3 +304,194 @@ feature_importance.default <- function(x, res } + +feature_importance.multiinput <- function(x, + data, + y, + predict_function = predict, + loss_function = DALEX::loss_root_mean_square, + ..., + label = class(x)[1], + type = c("raw", "ratio", "difference"), + n_sample = NULL, + B = 10, + variables = NULL, + N = n_sample, + variable_groups = NULL, + permDim = NULL, + combDims = FALSE) { + # start: checks for arguments + ## if (is.null(N) & methods::hasArg("n_sample")) { + ## warning("n_sample is deprecated, please update ingredients and DALEX packages to use N instead") + ## N <- list(...)[["n_sample"]] + ## } + + if (is.null(permDim) | !is.null(variable_groups)){ + permDim<- lapply(data, function(d) setNames(2:length(dim(d)), nm=names(dimnames(d))[-1])) # all dims except first (rows) which correspond to cases + } + + # Variables for the dimensions to permute + varsL<- mapply(function(d, dim){ + dimnames(d)[dim] + }, d=data, dim=permDim, SIMPLIFY=FALSE) + + if (!is.null(variable_groups)) { + if 
(!inherits(variable_groups, "list") | !all(sapply(variable_groups, inherits, "list"))) + stop("variable_groups should be of class list contining lists for each data input") + + wrong_names <- !all(mapply(function(variable_set, vars) { + all(unlist(variable_set) %in% unlist(vars)) + }, variable_set=variable_groups, vars=varsL[names(variable_groups)])) + + if (wrong_names) stop("You have passed wrong variables names in variable_groups argument") + if (!all(unlist(sapply(variable_groups, sapply, sapply, class)) == "character")) + stop("Elements of variable_groups argument should be of class character") + if (any(sapply(sapply(variable_groups, names), is.null))){ + warning("You have passed an unnamed list. The names of variable groupings will be created from variables names.") + # Adding names for variable_groups if not specified + # names(variable_groups) <- + variable_groups<- lapply(variable_groups, function(variable_sets_input) { + if (is.null(names(variable_sets_input))){ + groupNames<- sapply(variable_sets_input, function(v) paste(paste(names(v), v, sep="."), collapse = "; ")) + names(variable_sets_input) <- groupNames + } + variable_sets_input + }) + } + } + type <- match.arg(type) + B <- max(1, round(B)) + + # if `variable_groups` are not specified, then extract from `variables` + if (is.null(variable_groups)) { + # if `variables` are not specified, then extract from data + if (is.null(variables)) { + variables <- lapply(varsL, function(vars){ + if (combDims){ + vars <- expand.grid(vars, stringsAsFactors=FALSE, KEEP.OUT.ATTRS=FALSE) # All combinations for all dimensions in a dataset + rownames(vars) <- apply(vars, 1, function(v) paste(v, collapse="|")) + vars<- split(vars, rownames(vars)) + vars<- lapply(vars, as.list) + } else { + vars <- mapply(function(dimVar, dimNames) { + v<- lapply(dimVar, function(v) setNames(list(v), dimNames)) + setNames(v, nm = dimVar) + }, dimVar=vars, dimNames=names(vars), SIMPLIFY=FALSE) + vars <- do.call(c, vars) + } + vars + }) + 
} + } else { + variables <- variable_groups + } + + # start: actual calculations + # one permutation round: subsample data, permute variables and compute losses + nCases<- unique(sapply(data, nrow)) + if (length(nCases) > 1){ + stop("Number of cases among inputs in data are different.") + } + sampled_rows <- 1:nCases + + loss_after_permutation <- function() { + if (!is.null(N)) { + if (N < nCases) { + # sample N points + sampled_rows <- sample(1:nCases, N) + } + } + sampled_data <- lapply(data, function(d){ + if (length(dim(d)) == 2) { + sampled_data <- d[sampled_rows, , drop = FALSE] + } else if (length(dim(d)) == 3) { + sampled_data <- d[sampled_rows, , , drop = FALSE] + } + sampled_data + }) + observed <- y[sampled_rows] + # loss on the full model or when outcomes are permuted + loss_full <- loss_function(observed, predict_function(x, sampled_data)) + loss_baseline <- loss_function(sample(observed), predict_function(x, sampled_data)) + # loss upon dropping a single variable (or a single group) + loss_featuresL <- mapply(function(d, vars, inputData){ + loss_features <- sapply(vars, function(variables_set) { + ndf <- d + dimPerm <- names(dimnames(ndf)) %in% names(variables_set) + dims <- list() + for (i in 2:length(dimPerm)){ # First dimension for cases + if (dimPerm[i]){ + dims[[i]] <- variables_set[[names(dimnames(ndf))[i]]] + } else { + dims[[i]] <- 1:dim(ndf)[i] + } + } + names(dims) <- names(dimnames(ndf)) + + if (length(dimPerm) == 2){ + ndf[, dims[[2]]] <- ndf[sample(1:nrow(ndf)), dims[[2]]] + } else if (length(dimPerm) == 3){ + ndf[, dims[[2]], dims[[3]]] <- ndf[sample(1:nrow(ndf)), dims[[2]], dims[[3]]] + } else { + stop("Dimensions for this kind of data is not implemented but should be easy. 
Contact with the developers.") + } + sampled_data[[inputData]] <- ndf + predicted <- predict_function(x, sampled_data) + loss_function(observed, predicted) + }) + }, d=sampled_data, vars=variables, inputData=seq_along(sampled_data), SIMPLIFY=FALSE) + + unlist(c("_full_model_" = loss_full, loss_featuresL, "_baseline_" = loss_baseline)) + } + # permute B times, collect results into single matrix + raw <- replicate(B, loss_after_permutation()) + + # main result df with dropout_loss averages, with _full_model_ first and _baseline_ last + res <- apply(raw, 1, mean) + res_baseline <- res["_baseline_"] + res_full <- res["_full_model_"] + res <- sort(res[!names(res) %in% c("_full_model_", "_baseline_")]) + res <- data.frame( + variable = c("_full_model_", names(res), "_baseline_"), + permutation = 0, + dropout_loss = c(res_full, res, res_baseline), + label = label, + row.names = NULL + ) + if (type == "ratio") { + res$dropout_loss = res$dropout_loss / res_full + } + if (type == "difference") { + res$dropout_loss = res$dropout_loss - res_full + } + + + # record details of permutations + attr(res, "B") <- B + + if (B > 1) { + res_B <- data.frame( + variable = rep(rownames(raw), ncol(raw)), + permutation = rep(seq_len(B), each = nrow(raw)), + dropout_loss = as.vector(raw), + label = label + ) + + # here mean full model is used (full model for given permutation is an option) + if (type == "ratio") { + res_B$dropout_loss = res_B$dropout_loss / res_full + } + if (type == "difference") { + res_B$dropout_loss = res_B$dropout_loss - res_full + } + + res <- rbind(res, res_B) + } + + class(res) <- c("feature_importance_explainer", "data.frame") + + if(!is.null(attr(loss_function, "loss_name"))) { + attr(res, "loss_name") <- attr(loss_function, "loss_name") + } + res +} From 54f24ff3213876c1eea806b2ac9cabaf71224fea Mon Sep 17 00:00:00 2001 From: Joan Maspons Date: Mon, 21 Mar 2022 11:44:33 +0100 Subject: [PATCH 2/8] Improve variable_groups names + code style --- R/feature_importance.R 
| 47 +++++++++++++++++++++--------------------- 1 file changed, 23 insertions(+), 24 deletions(-) diff --git a/R/feature_importance.R b/R/feature_importance.R index 72aa8d6..0b11799 100644 --- a/R/feature_importance.R +++ b/R/feature_importance.R @@ -176,7 +176,7 @@ feature_importance.default <- function(x, ## N <- list(...)[["n_sample"]] ## } - if (inherits(data, "list")){ + if (inherits(data, "list")) { res <- feature_importance.multiinput(x = x, data = data, y = y, @@ -326,12 +326,12 @@ feature_importance.multiinput <- function(x, ## N <- list(...)[["n_sample"]] ## } - if (is.null(permDim) | !is.null(variable_groups)){ - permDim<- lapply(data, function(d) setNames(2:length(dim(d)), nm=names(dimnames(d))[-1])) # all dims except first (rows) which correspond to cases + if (is.null(permDim) | !is.null(variable_groups)) { + permDim <- lapply(data, function(d) setNames(2:length(dim(d)), nm=names(dimnames(d))[-1])) # all dims except first (rows) which correspond to cases } # Variables for the dimensions to permute - varsL<- mapply(function(d, dim){ + varsL <- mapply(function(d, dim) { dimnames(d)[dim] }, d=data, dim=permDim, SIMPLIFY=FALSE) @@ -346,13 +346,12 @@ feature_importance.multiinput <- function(x, if (wrong_names) stop("You have passed wrong variables names in variable_groups argument") if (!all(unlist(sapply(variable_groups, sapply, sapply, class)) == "character")) stop("Elements of variable_groups argument should be of class character") - if (any(sapply(sapply(variable_groups, names), is.null))){ + if (any(sapply(sapply(variable_groups, names), is.null))) { warning("You have passed an unnamed list. 
The names of variable groupings will be created from variables names.") # Adding names for variable_groups if not specified - # names(variable_groups) <- - variable_groups<- lapply(variable_groups, function(variable_sets_input) { - if (is.null(names(variable_sets_input))){ - groupNames<- sapply(variable_sets_input, function(v) paste(paste(names(v), v, sep="."), collapse = "; ")) + variable_groups <- lapply(variable_groups, function(variable_sets_input) { + if (is.null(names(variable_sets_input))) { + groupNames <- sapply(variable_sets_input, function(v) paste(paste(names(v), sapply(v, paste, collapse="; "), sep="."), collapse = " | ")) names(variable_sets_input) <- groupNames } variable_sets_input @@ -366,15 +365,15 @@ feature_importance.multiinput <- function(x, if (is.null(variable_groups)) { # if `variables` are not specified, then extract from data if (is.null(variables)) { - variables <- lapply(varsL, function(vars){ - if (combDims){ + variables <- lapply(varsL, function(vars) { + if (combDims) { vars <- expand.grid(vars, stringsAsFactors=FALSE, KEEP.OUT.ATTRS=FALSE) # All combinations for all dimensions in a dataset rownames(vars) <- apply(vars, 1, function(v) paste(v, collapse="|")) - vars<- split(vars, rownames(vars)) - vars<- lapply(vars, as.list) + vars <- split(vars, rownames(vars)) + vars <- lapply(vars, as.list) } else { vars <- mapply(function(dimVar, dimNames) { - v<- lapply(dimVar, function(v) setNames(list(v), dimNames)) + v <- lapply(dimVar, function(v) setNames(list(v), dimNames)) setNames(v, nm = dimVar) }, dimVar=vars, dimNames=names(vars), SIMPLIFY=FALSE) vars <- do.call(c, vars) @@ -388,8 +387,8 @@ feature_importance.multiinput <- function(x, # start: actual calculations # one permutation round: subsample data, permute variables and compute losses - nCases<- unique(sapply(data, nrow)) - if (length(nCases) > 1){ + nCases <- unique(sapply(data, nrow)) + if (length(nCases) > 1) { stop("Number of cases among inputs in data are different.") } 
sampled_rows <- 1:nCases @@ -401,7 +400,7 @@ feature_importance.multiinput <- function(x, sampled_rows <- sample(1:nCases, N) } } - sampled_data <- lapply(data, function(d){ + sampled_data <- lapply(data, function(d) { if (length(dim(d)) == 2) { sampled_data <- d[sampled_rows, , drop = FALSE] } else if (length(dim(d)) == 3) { @@ -414,13 +413,13 @@ feature_importance.multiinput <- function(x, loss_full <- loss_function(observed, predict_function(x, sampled_data)) loss_baseline <- loss_function(sample(observed), predict_function(x, sampled_data)) # loss upon dropping a single variable (or a single group) - loss_featuresL <- mapply(function(d, vars, inputData){ + loss_featuresL <- mapply(function(d, vars, inputData) { loss_features <- sapply(vars, function(variables_set) { ndf <- d dimPerm <- names(dimnames(ndf)) %in% names(variables_set) dims <- list() - for (i in 2:length(dimPerm)){ # First dimension for cases - if (dimPerm[i]){ + for (i in 2:length(dimPerm)) { # First dimension for cases + if (dimPerm[i]) { dims[[i]] <- variables_set[[names(dimnames(ndf))[i]]] } else { dims[[i]] <- 1:dim(ndf)[i] @@ -428,9 +427,9 @@ feature_importance.multiinput <- function(x, } names(dims) <- names(dimnames(ndf)) - if (length(dimPerm) == 2){ + if (length(dimPerm) == 2) { ndf[, dims[[2]]] <- ndf[sample(1:nrow(ndf)), dims[[2]]] - } else if (length(dimPerm) == 3){ + } else if (length(dimPerm) == 3) { ndf[, dims[[2]], dims[[3]]] <- ndf[sample(1:nrow(ndf)), dims[[2]], dims[[3]]] } else { stop("Dimensions for this kind of data is not implemented but should be easy. 
Contact with the developers.") @@ -452,7 +451,7 @@ feature_importance.multiinput <- function(x, res_full <- res["_full_model_"] res <- sort(res[!names(res) %in% c("_full_model_", "_baseline_")]) res <- data.frame( - variable = c("_full_model_", names(res), "_baseline_"), + variable = gsub(paste0("^(", paste(names(data), collapse="|"), ")\\."), "\\1: ", c("_full_model_", names(res), "_baseline_")), permutation = 0, dropout_loss = c(res_full, res, res_baseline), label = label, @@ -471,7 +470,7 @@ feature_importance.multiinput <- function(x, if (B > 1) { res_B <- data.frame( - variable = rep(rownames(raw), ncol(raw)), + variable = gsub(paste0("^(", paste(names(data), collapse="|"), ")\\."), "\\1: ", rep(rownames(raw), ncol(raw))), permutation = rep(seq_len(B), each = nrow(raw)), dropout_loss = as.vector(raw), label = label From 602de078035005935b4c2a61c801e7b712085a43 Mon Sep 17 00:00:00 2001 From: Joan Maspons Date: Mon, 21 Mar 2022 11:46:05 +0100 Subject: [PATCH 3/8] update docs --- man/feature_importance.Rd | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/man/feature_importance.Rd b/man/feature_importance.Rd index b67c8a6..6326d3d 100644 --- a/man/feature_importance.Rd +++ b/man/feature_importance.Rd @@ -34,7 +34,9 @@ feature_importance(x, ...) B = 10, variables = NULL, N = n_sample, - variable_groups = NULL + variable_groups = NULL, + permDim = NULL, + combDims = FALSE ) } \arguments{ @@ -52,23 +54,29 @@ while "difference" returns \code{drop_loss - drop_loss_full_model}} \item{B}{integer, number of permutation rounds to perform on each variable. By default it's \code{10}.} -\item{variables}{vector of variables. If \code{NULL} then variable importance will be tested for each variable from the \code{data} separately. By default \code{NULL}} +\item{variables}{vector of variables or a list of vectors for multiinput models. If \code{NULL} then variable importance will be tested for each variable from the \code{data} separately. 
By default \code{NULL}} -\item{variable_groups}{list of variables names vectors. This is for testing joint variable importance. +\item{variable_groups}{list of variables names vectors or a list of vectors for multiinput models. This is for testing joint variable importance. If \code{NULL} then variable importance will be tested separately for \code{variables}. -By default \code{NULL}. If specified then it will override \code{variables}} +By default \code{NULL}. If specified then it will override \code{variables}, \code{permDim} and \code{combDims}.} \item{N}{number of observations that should be sampled for calculation of variable importance. If \code{NULL} then variable importance will be calculated on whole dataset (no sampling).} \item{label}{name of the model. By default it's extracted from the \code{class} attribute of the model} -\item{data}{validation dataset, will be extracted from \code{x} if it's an explainer +\item{data}{validation dataset, will be extracted from \code{x} if it's an explainer. Can be a list of arrays for multiinput models. NOTE: It is best when target variable is not present in the \code{data}} \item{y}{true labels for \code{data}, will be extracted from \code{x} if it's an explainer} \item{predict_function}{predict function, will be extracted from \code{x} if it's an explainer} + +\item{permDim}{the dimensions to perform the permutations when \code{data} is a 3d array (e.g. [case, time, variable]). +If \code{permDim = 2:3}, it calculates the importance for each variable in the 2nd and 3rd dimensions. +For multiinput models, a list of dimensions in the same order than in \code{data}. If \code{NULL}, the default, take all dimensions except the first one (i.e. rows) which correspond to cases.} + +\item{combDims}{if \code{TRUE}, do the permutations for each combination of the levels of the variables from 2nd and 3rd dimensions for input data with 3 dimensions. 
By default \code{FALSE}} } \value{ an object of the class \code{feature_importance} From c521d428b776d425e3fed5b54c46688e41cd2a58 Mon Sep 17 00:00:00 2001 From: Joan Maspons Date: Mon, 21 Mar 2022 15:42:19 +0100 Subject: [PATCH 4/8] CamelCase -> snake_case --- R/feature_importance.R | 66 +++++++++++++++++++++--------------------- 1 file changed, 33 insertions(+), 33 deletions(-) diff --git a/R/feature_importance.R b/R/feature_importance.R index 0b11799..d65b639 100644 --- a/R/feature_importance.R +++ b/R/feature_importance.R @@ -23,11 +23,11 @@ #' @param variables vector of variables or a list of vectors for multiinput models. If \code{NULL} then variable importance will be tested for each variable from the \code{data} separately. By default \code{NULL} #' @param variable_groups list of variables names vectors or a list of vectors for multiinput models. This is for testing joint variable importance. #' If \code{NULL} then variable importance will be tested separately for \code{variables}. -#' By default \code{NULL}. If specified then it will override \code{variables}, \code{permDim} and \code{combDims}. -#' @param permDim the dimensions to perform the permutations when \code{data} is a 3d array (e.g. [case, time, variable]). -#' If \code{permDim = 2:3}, it calculates the importance for each variable in the 2nd and 3rd dimensions. +#' By default \code{NULL}. If specified then it will override \code{variables}, \code{perm_dim} and \code{comb_dims}. +#' @param perm_dim the dimensions to perform the permutations when \code{data} is a 3d array (e.g. [case, time, variable]). +#' If \code{perm_dim = 2:3}, it calculates the importance for each variable in the 2nd and 3rd dimensions. #' For multiinput models, a list of dimensions in the same order than in \code{data}. If \code{NULL}, the default, take all dimensions except the first one (i.e. rows) which correspond to cases. 
-#' @param combDims if \code{TRUE}, do the permutations for each combination of the levels of the variables from 2nd and 3rd dimensions for input data with 3 dimensions. By default \code{FALSE} +#' @param comb_dims if \code{TRUE}, do the permutations for each combination of the levels of the variables from 2nd and 3rd dimensions for input data with 3 dimensions. By default \code{FALSE} #' #' @references Explanatory Model Analysis. Explore, Explain, and Examine Predictive Models. \url{https://ema.drwhy.ai/} #' @@ -168,8 +168,8 @@ feature_importance.default <- function(x, variables = NULL, N = n_sample, variable_groups = NULL, - permDim = NULL, - combDims = FALSE) { + perm_dim = NULL, + comb_dims = FALSE) { # start: checks for arguments ## if (is.null(N) & methods::hasArg("n_sample")) { ## warning("n_sample is deprecated, please update ingredients and DALEX packages to use N instead") @@ -190,8 +190,8 @@ feature_importance.default <- function(x, variables = variables, N = n_sample, variable_groups = variable_groups, - permDim = permDim, - combDims = combDims) + perm_dim = perm_dim, + comb_dims = comb_dims) return (res) } @@ -318,22 +318,22 @@ feature_importance.multiinput <- function(x, variables = NULL, N = n_sample, variable_groups = NULL, - permDim = NULL, - combDims = FALSE) { + perm_dim = NULL, + comb_dims = FALSE) { # start: checks for arguments ## if (is.null(N) & methods::hasArg("n_sample")) { ## warning("n_sample is deprecated, please update ingredients and DALEX packages to use N instead") ## N <- list(...)[["n_sample"]] ## } - if (is.null(permDim) | !is.null(variable_groups)) { - permDim <- lapply(data, function(d) setNames(2:length(dim(d)), nm=names(dimnames(d))[-1])) # all dims except first (rows) which correspond to cases + if (is.null(perm_dim) | !is.null(variable_groups)) { + perm_dim <- lapply(data, function(d) setNames(2:length(dim(d)), nm=names(dimnames(d))[-1])) # all dims except first (rows) which correspond to cases } # Variables for the 
dimensions to permute varsL <- mapply(function(d, dim) { dimnames(d)[dim] - }, d=data, dim=permDim, SIMPLIFY=FALSE) + }, d=data, dim=perm_dim, SIMPLIFY=FALSE) if (!is.null(variable_groups)) { if (!inherits(variable_groups, "list") | !all(sapply(variable_groups, inherits, "list"))) @@ -351,8 +351,8 @@ feature_importance.multiinput <- function(x, # Adding names for variable_groups if not specified variable_groups <- lapply(variable_groups, function(variable_sets_input) { if (is.null(names(variable_sets_input))) { - groupNames <- sapply(variable_sets_input, function(v) paste(paste(names(v), sapply(v, paste, collapse="; "), sep="."), collapse = " | ")) - names(variable_sets_input) <- groupNames + group_names <- sapply(variable_sets_input, function(v) paste(paste(names(v), sapply(v, paste, collapse="; "), sep="."), collapse = " | ")) + names(variable_sets_input) <- group_names } variable_sets_input }) @@ -366,16 +366,16 @@ feature_importance.multiinput <- function(x, # if `variables` are not specified, then extract from data if (is.null(variables)) { variables <- lapply(varsL, function(vars) { - if (combDims) { + if (comb_dims) { vars <- expand.grid(vars, stringsAsFactors=FALSE, KEEP.OUT.ATTRS=FALSE) # All combinations for all dimensions in a dataset rownames(vars) <- apply(vars, 1, function(v) paste(v, collapse="|")) vars <- split(vars, rownames(vars)) vars <- lapply(vars, as.list) } else { - vars <- mapply(function(dimVar, dimNames) { - v <- lapply(dimVar, function(v) setNames(list(v), dimNames)) - setNames(v, nm = dimVar) - }, dimVar=vars, dimNames=names(vars), SIMPLIFY=FALSE) + vars <- mapply(function(dim_var, dim_names) { + v <- lapply(dim_var, function(v) setNames(list(v), dim_names)) + setNames(v, nm = dim_var) + }, dim_var=vars, dim_names=names(vars), SIMPLIFY=FALSE) vars <- do.call(c, vars) } vars @@ -387,17 +387,17 @@ feature_importance.multiinput <- function(x, # start: actual calculations # one permutation round: subsample data, permute variables and compute 
losses - nCases <- unique(sapply(data, nrow)) - if (length(nCases) > 1) { + n_cases <- unique(sapply(data, nrow)) + if (length(n_cases) > 1) { stop("Number of cases among inputs in data are different.") } - sampled_rows <- 1:nCases + sampled_rows <- 1:n_cases loss_after_permutation <- function() { if (!is.null(N)) { - if (N < nCases) { + if (N < n_cases) { # sample N points - sampled_rows <- sample(1:nCases, N) + sampled_rows <- sample(1:n_cases, N) } } sampled_data <- lapply(data, function(d) { @@ -413,13 +413,13 @@ feature_importance.multiinput <- function(x, loss_full <- loss_function(observed, predict_function(x, sampled_data)) loss_baseline <- loss_function(sample(observed), predict_function(x, sampled_data)) # loss upon dropping a single variable (or a single group) - loss_featuresL <- mapply(function(d, vars, inputData) { + loss_featuresL <- mapply(function(d, vars, input_data) { loss_features <- sapply(vars, function(variables_set) { ndf <- d - dimPerm <- names(dimnames(ndf)) %in% names(variables_set) + dim_perm <- names(dimnames(ndf)) %in% names(variables_set) dims <- list() - for (i in 2:length(dimPerm)) { # First dimension for cases - if (dimPerm[i]) { + for (i in 2:length(dim_perm)) { # First dimension for cases + if (dim_perm[i]) { dims[[i]] <- variables_set[[names(dimnames(ndf))[i]]] } else { dims[[i]] <- 1:dim(ndf)[i] @@ -427,18 +427,18 @@ feature_importance.multiinput <- function(x, } names(dims) <- names(dimnames(ndf)) - if (length(dimPerm) == 2) { + if (length(dim_perm) == 2) { ndf[, dims[[2]]] <- ndf[sample(1:nrow(ndf)), dims[[2]]] - } else if (length(dimPerm) == 3) { + } else if (length(dim_perm) == 3) { ndf[, dims[[2]], dims[[3]]] <- ndf[sample(1:nrow(ndf)), dims[[2]], dims[[3]]] } else { stop("Dimensions for this kind of data is not implemented but should be easy. 
Contact with the developers.") } - sampled_data[[inputData]] <- ndf + sampled_data[[input_data]] <- ndf predicted <- predict_function(x, sampled_data) loss_function(observed, predicted) }) - }, d=sampled_data, vars=variables, inputData=seq_along(sampled_data), SIMPLIFY=FALSE) + }, d=sampled_data, vars=variables, input_data=seq_along(sampled_data), SIMPLIFY=FALSE) unlist(c("_full_model_" = loss_full, loss_featuresL, "_baseline_" = loss_baseline)) } From fd3ef8d8101da16d846113fb2db92d7729567a29 Mon Sep 17 00:00:00 2001 From: Joan Maspons Date: Mon, 21 Mar 2022 16:02:07 +0100 Subject: [PATCH 5/8] Update docs with updated parameter names --- man/feature_importance.Rd | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/man/feature_importance.Rd b/man/feature_importance.Rd index 6326d3d..c6f1b58 100644 --- a/man/feature_importance.Rd +++ b/man/feature_importance.Rd @@ -35,8 +35,8 @@ feature_importance(x, ...) variables = NULL, N = n_sample, variable_groups = NULL, - permDim = NULL, - combDims = FALSE + perm_dim = NULL, + comb_dims = FALSE ) } \arguments{ @@ -58,7 +58,7 @@ while "difference" returns \code{drop_loss - drop_loss_full_model}} \item{variable_groups}{list of variables names vectors or a list of vectors for multiinput models. This is for testing joint variable importance. If \code{NULL} then variable importance will be tested separately for \code{variables}. -By default \code{NULL}. If specified then it will override \code{variables}, \code{permDim} and \code{combDims}.} +By default \code{NULL}. If specified then it will override \code{variables}, \code{perm_dim} and \code{comb_dims}.} \item{N}{number of observations that should be sampled for calculation of variable importance. 
If \code{NULL} then variable importance will be calculated on whole dataset (no sampling).} @@ -72,11 +72,11 @@ NOTE: It is best when target variable is not present in the \code{data}} \item{predict_function}{predict function, will be extracted from \code{x} if it's an explainer} -\item{permDim}{the dimensions to perform the permutations when \code{data} is a 3d array (e.g. [case, time, variable]). -If \code{permDim = 2:3}, it calculates the importance for each variable in the 2nd and 3rd dimensions. +\item{perm_dim}{the dimensions to perform the permutations when \code{data} is a 3d array (e.g. [case, time, variable]). +If \code{perm_dim = 2:3}, it calculates the importance for each variable in the 2nd and 3rd dimensions. For multiinput models, a list of dimensions in the same order than in \code{data}. If \code{NULL}, the default, take all dimensions except the first one (i.e. rows) which correspond to cases.} -\item{combDims}{if \code{TRUE}, do the permutations for each combination of the levels of the variables from 2nd and 3rd dimensions for input data with 3 dimensions. By default \code{FALSE}} +\item{comb_dims}{if \code{TRUE}, do the permutations for each combination of the levels of the variables from 2nd and 3rd dimensions for input data with 3 dimensions. 
By default \code{FALSE}} } \value{ an object of the class \code{feature_importance} From a69bbfdcf2479096efd12543d73a45600ff1db65 Mon Sep 17 00:00:00 2001 From: Joan Maspons Date: Tue, 22 Mar 2022 18:04:27 +0100 Subject: [PATCH 6/8] Tests for feature_importance.multiinput --- DESCRIPTION | 5 +- .../test_variable_dropout-multiinput.R | 363 ++++++++++++++++++ 2 files changed, 367 insertions(+), 1 deletion(-) create mode 100644 tests/testthat/test_variable_dropout-multiinput.R diff --git a/DESCRIPTION b/DESCRIPTION index 8397924..7fd1f13 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -37,7 +37,10 @@ Suggests: jsonlite, knitr, rmarkdown, - covr + covr, + abind, + data.table, + keras URL: https://ModelOriented.github.io/ingredients/, https://github.com/ModelOriented/ingredients BugReports: https://github.com/ModelOriented/ingredients/issues VignetteBuilder: knitr diff --git a/tests/testthat/test_variable_dropout-multiinput.R b/tests/testthat/test_variable_dropout-multiinput.R new file mode 100644 index 0000000..92602ee --- /dev/null +++ b/tests/testthat/test_variable_dropout-multiinput.R @@ -0,0 +1,363 @@ +context("Check feature_importance() function") + +library(keras) + +df<- data.frame(id=rep(LETTERS[1:10], each=5), static=rep(1:10, each=5), time=rep(1:5, times=10)) +df.cat<- data.frame(id=LETTERS[1:10], cat1=rep(LETTERS[1:5], times=2), cat2=letters[1:10]) +df<- merge(df, df.cat) +df$x1<- df$time * df$static +df$x2<- rnorm(nrow(df), mean=df$time * df$static + 10, sd=5) +df$x3<- rnorm(nrow(df), mean=df$time * df$static * 3, sd=2) +df$y<- rnorm(nrow(df), mean=(df$x1 + df$x2) / df$x3, sd=2) + +timevar<- "time" +idVars<- "id" +responseVars<- "y" +staticVars<- c("static", "cat1", "cat2") +predTemp<- c("x1", "x2", "x3") +responseTime<- max(df[, timevar], na.rm=TRUE) +regex_time<- "[0-9]+" +hidden_shape.RNN<- 8 +hidden_shape.static<- 8 +hidden_shape.main<- 16 +epochs<- 3 +batch_size<- length(unique(df$id)) +verbose<- 0 + +wideTo3Darray.ts<- function(d, vars, idCols){ + d<- 
as.data.frame(d) + timevals<- unique(gsub(paste0("^(", paste(vars, collapse="|"), ")_"), "", setdiff(colnames(d), idCols))) + + # Reshape to a 3D array [samples, timesteps, features] Format for RNN layers in NN + a<- lapply(vars, function(x){ + varTS<- d[, grep(paste0("^", x, "_(", paste(timevals, collapse="|"), ")$"), colnames(d)), drop=FALSE] # drop=FALSE keeps a data.frame even when a single time column matches + a<- array(as.matrix(varTS), dim=c(nrow(varTS), ncol(varTS), 1), dimnames=list(case=NULL, t=gsub(paste0("^", x, "_"), "", colnames(varTS)), var=x)) + }) + names(a)<- vars + a<- abind::abind(a) + names(dimnames(a))<- c("case", "t", "var") + dimnames(a)$case<- do.call(paste, c(d[, idCols, drop=FALSE], list(sep="_"))) + + return(a) +} + +build_modelLTSM<- function(input_shape.ts, input_shape.static=0, output_shape=1, + hidden_shape.RNN=32, hidden_shape.static=16, hidden_shape.main=32){ # keras model: LSTM branch on the TS input, dense branch on the static input (only when input_shape.static > 0), then shared dense layers; compiled with mse + rmsprop + inputs.ts<- layer_input(shape=input_shape.ts, name="TS_input") + inputs.static<- layer_input(shape=input_shape.static, name="Static_input") + + predictions.ts<- inputs.ts + for (i in seq_along(hidden_shape.RNN)){ + predictions.ts<- predictions.ts %>% layer_lstm(units=hidden_shape.RNN[i], name=paste0("LSTM_", i)) + } + + if (input_shape.static > 0){ + predictions.static<- inputs.static + for (i in seq_along(hidden_shape.static)){ + predictions.static<- predictions.static %>% layer_dense(units=hidden_shape.static[i], name=paste0("Dense_", i)) + } + output<- layer_concatenate(c(predictions.ts, predictions.static)) + } else { + output<- predictions.ts + } + + for (i in seq_along(hidden_shape.main)){ + output<- output %>% layer_dense(units=hidden_shape.main[i], name=paste0("main_dense_", i)) + } + output<- output %>% layer_dense(units=output_shape, name="main_output") + + if (input_shape.static > 0){ + model<- keras_model(inputs=c(inputs.ts, inputs.static), outputs=output) + } else { + model<- keras_model(inputs=inputs.ts, outputs=output) + } + + compile(model, loss="mse", optimizer=optimizer_rmsprop()) + + model +} + +predVars<- setdiff(colnames(df), c(idVars, 
timevar)) +predVars.cat<- names(which(!vapply(df[, predVars, drop=FALSE], is.numeric, logical(1)))) +predVars.num<- setdiff(predVars, predVars.cat) + +df.catBin<- stats::model.matrix(stats::as.formula(paste("~ -1 +", paste(predVars.cat, collapse="+"))), data=df) +predVars.catBin<- colnames(df.catBin) +df<- cbind(df[, setdiff(colnames(df), predVars.cat)], df.catBin) +predVars<- c(predVars.num, predVars.catBin) +staticVars.cat<- staticVars[staticVars %in% predVars.cat] +staticVars<- c(setdiff(staticVars, staticVars.cat), predVars.catBin) + +# crossvalidation for timeseries must be done in the wide format data +responseVars.ts<- paste0(responseVars, "_", responseTime) +predVars.tf<- paste0(setdiff(predVars, staticVars), "_", responseTime) + +## df to wide format +dt<- data.table::as.data.table(df) +staticCols<- c(idVars, staticVars) +vars<- setdiff(colnames(dt), c(staticCols, timevar)) +timevals<- unique(dt[[timevar]]) # plain [[ extraction; avoids reaching into data.table internals with ::: +LHS<- setdiff(staticCols, timevar) +form<- paste0(paste(LHS, collapse=" + "), " ~ ", timevar) +dt<- data.table::dcast(dt, formula=stats::formula(form), value.var=vars) # To wide format (var_time columns) +df.wide<- as.data.frame(dt) +predVars.ts<- setdiff(colnames(df.wide), c(idVars, staticVars)) # WARNING: Includes responseVars.ts +timevals<- unique(df[[timevar]]) + +idxTrain<- sample(seq_len(nrow(df.wide)), 6) + +train_y<- df.wide[idxTrain, c(idVars, responseVars.ts), drop=FALSE] +train_data<- df.wide[idxTrain, c(idVars, staticVars, predVars.ts), drop=FALSE] + +test_y<- df.wide[-idxTrain, c(idVars, responseVars.ts), drop=FALSE] +test_data<- df.wide[-idxTrain, c(idVars, staticVars, predVars.ts), drop=FALSE] + + +# Reshape data to 3D arrays [samples, timesteps, features] as expected by LSTM layer +train_data.3d<- wideTo3Darray.ts(d=train_data, vars=setdiff(predVars, staticVars), idCols=idVars) +test_data.3d<- wideTo3Darray.ts(d=test_data, vars=setdiff(predVars, staticVars), 
idCols=idVars) + +train_data.static<- structure(as.matrix(train_data[, staticVars, drop=FALSE]), + dimnames=list(case=do.call(paste, c(train_data[, idVars, drop=FALSE], list(sep="_"))), var=staticVars)) +test_data.static<- structure(as.matrix(test_data[, staticVars, drop=FALSE]), + dimnames=list(case=do.call(paste, c(test_data[, idVars, drop=FALSE], list(sep="_"))), var=staticVars)) + +train_y<- structure(as.matrix(train_y[, responseVars.ts, drop=FALSE]), + dimnames=list(case=do.call(paste, c(train_y[, idVars, drop=FALSE], list(sep="_"))), var=responseVars.ts)) +test_y<- structure(as.matrix(test_y[, responseVars.ts, drop=FALSE]), + dimnames=list(case=do.call(paste, c(test_y[, idVars, drop=FALSE], list(sep="_"))), var=responseVars.ts)) + +train_data.3d<- train_data.3d[, setdiff(dimnames(train_data.3d)[[2]], responseTime), ] +test_data.3d<- test_data.3d[, setdiff(dimnames(test_data.3d)[[2]], responseTime), ] # use the test array's own time labels + +train_data<- list(TS_input=train_data.3d, Static_input=train_data.static) +test_data<- list(TS_input=test_data.3d, Static_input=test_data.static) + +modelNN<- build_modelLTSM(input_shape.ts=dim(train_data.3d)[-1], input_shape.static=length(staticVars), output_shape=length(responseVars), + hidden_shape.RNN=hidden_shape.RNN, hidden_shape.static=hidden_shape.static, hidden_shape.main=hidden_shape.main) +history<- keras::fit(object=modelNN, x=train_data, y=train_y, batch_size=batch_size, + epochs=epochs, verbose=verbose, validation_data=list(test_data, test_y)) + +## 3D data only in a separate PR +# modelNN.LSTM<- build_modelLTSM(input_shape.ts=dim(train_data.3d)[-1], input_shape.static=0, output_shape=length(responseVars), +# hidden_shape.RNN=hidden_shape.RNN, hidden_shape.static=0, hidden_shape.main=hidden_shape.main) +# history<- keras::fit(object=modelNN.LSTM, x=train_data.3d, y=train_y, batch_size=batch_size, +# epochs=epochs, verbose=verbose, validation_data=list(test_data.3d, test_y)) + + +# Basics - tests with improper and proper inputs + 
+test_that("Output glm",{ + vd_keras <- feature_importance(x = modelNN, data = test_data, y = test_y, + type = "raw", loss_function = loss_root_mean_square) + expect_true(inherits(vd_keras, "feature_importance_explainer")) +}) + + +# Permutations and subsampling + +test_that("feature_importance gives slightly different output on subsequent runs", { + result_1 <- feature_importance(x = modelNN, data = test_data, y = test_y) # no trailing comma: it would pass an empty argument into `...` + result_2 <- feature_importance(x = modelNN, data = test_data, y = test_y) + change_12 <- abs(result_1$dropout_loss - result_2$dropout_loss) + expect_gt(sum(change_12), 0) +}) + + +test_that("feature_importance records number of permutations", { + result <- feature_importance(x = modelNN, data = test_data, y = test_y, B = 2) + expect_false(is.null(attr(result, "B"))) + expect_equal(attr(result, "B"), 2) + expect_equal(max(result$permutation), 2) + expect_equal(dim(result[result$permutation != 0,]), c(2*nrow(result)/3, 4)) + # because there is no sub-sampling, all the full-model results should be equal + loss_full <- result[result$variable=="_full_model_",] + expect_equal(length(unique(loss_full$dropout_loss)), 1) +}) + +# +# test_that("feature_importance avoids reporting permutations when only one performed", { +# # by default, one permutation leads to no raw_permutations component +# result_default <- feature_importance(explainer_rf, B = 1) +# expect_true(is.null(attr(result_default, "raw_permutations"))) +# result_keep <- feature_importance(explainer_rf, B = 1, keep_raw_permutations = TRUE) +# expect_false(is.null(attr(result_keep, "raw_permutations"))) +# }) +# +# +# test_that("feature_importance can avoid recording permutation details", { +# result <- feature_importance(explainer_rf, B = 2, keep_raw_permutations = FALSE) +# expect_true(is.null(attr(result, "raw_permutations"))) +# # when keep_raw_permutations is off, output should still signal number of permutations +# expect_false(is.null(attr(result, "B"))) +# expect_equal(attr(result, 
"B"), 2) +# }) + + +## Too few cases? +# test_that("feature_importance with subsampling gives different full-model results ", { +# result <- feature_importance(x = modelNN, data = test_data, y = test_y, B = 2, N=2) +# # the full model losses should be different in the first and second round +# # because each round is based on different rows in the data... +# # but in principle there is a tiny probability the two rounds are based on the same rows +# loss_full <- result[result$variable=="_full_model_", ] +# expect_equal(length(unique(loss_full$dropout_loss)), 3) +# }) + + +test_that("feature_importance performs at least one permutation", { + result <- feature_importance(x = modelNN, data = test_data, y = test_y, B = 0.1) + expect_false(is.null(attr(result, "B"))) + expect_equal(attr(result, "B"), 1) +}) + + +## Too few cases? +# test_that("feature_importance averaged over many permutations are stable", { +# # this test uses many permutations, so make a very small titanic dataset for speed +# tiny <- titanic_small[titanic_small$age > 50,] +# tiny$country <- tiny$class <- tiny$sibsp <- tiny$embarked <- tiny$gender <- NULL +# +# tiny_rf <- ranger(survived ~ parch + fare + age, data = tiny, probability = TRUE) +# tiny_explainer = explain(tiny_rf, data = tiny, +# y = tiny$survived == "yes", label = "RF") +# # compute single permutations importance values +# result_1 <- feature_importance(tiny_explainer, B = 1) +# result_2 <- feature_importance(tiny_explainer, B = 1) +# # compute feature importance with several permutations +# result_A <- feature_importance(tiny_explainer, B = 40) +# result_B <- feature_importance(tiny_explainer, B = 40) +# # two rounds with many permutation should give results closer together +# # than two rounds with single permutations +# change_12 <- abs(result_1[result_1$permutation == 0, "dropout_loss"] - result_2[result_2$permutation == 0, "dropout_loss"]) +# change_AB <- abs(result_A[result_A$permutation == 0, "dropout_loss"] - 
result_B[result_B$permutation == 0, "dropout_loss"]) +# # this test should succeed most of the time... but in principle could fail by accident +# expect_lt(sum(change_AB), sum(change_12)) +# }) + + + + +# Variable grouping + +v_groups.ts <- mapply(function(dimVar, dimNames) { + v<- lapply(dimVar, function(v) setNames(list(v), dimNames)) + setNames(v, nm = dimVar) +}, dimVar=dimnames(test_data$TS_input)[-1], dimNames=names(dimnames(test_data$TS_input))[-1], SIMPLIFY=FALSE) +v_groups.ts <- do.call(c, v_groups.ts) + +v_groups.tsCombDim <- expand.grid(dimnames(test_data$TS_input)[-1], stringsAsFactors=FALSE, KEEP.OUT.ATTRS=FALSE) # All combinations for all dimensions in a dataset +rownames(v_groups.tsCombDim) <- apply(v_groups.tsCombDim, 1, function(v) paste(v, collapse="|")) +v_groups.tsCombDim <- split(v_groups.tsCombDim, rownames(v_groups.tsCombDim)) +v_groups.tsCombDim <- lapply(v_groups.tsCombDim, as.list) + +v_groups.static<- list(list("static"), + list(grep("^cat1", dimnames(test_data$Static_input)$var, value=TRUE)), + list(grep("^cat2", dimnames(test_data$Static_input)$var, value=TRUE))) +names(v_groups.static)<- c("static", "cat1", "cat2") +variable_groups<- list(TS_input=v_groups.ts, Static_input=v_groups.static) +variable_groups.combDim<- list(TS_input=v_groups.tsCombDim, Static_input=v_groups.static) + +variable_groups.noGrNames<-lapply(variable_groups, function(input){ + names(input)<- NULL + input +}) + +variable_groups.combDim.noGrNames<-lapply(variable_groups.combDim, function(input){ + names(input)<- NULL + input +}) + +## For visually inspect the construction of variable_groups names +# variable_groups.noVarNames<-lapply(variable_groups, function(input){ +# lapply(input, function(gr){ +# names(gr)<- NULL +# gr +# }) +# }) +# variable_groups.noGrVarNames<-lapply(variable_groups, function(input){ +# names(input)<- NULL +# lapply(input, function(gr){ +# names(gr)<- NULL +# gr +# }) +# }) +# 
variable_groups.combDim.noVarNames<-lapply(variable_groups.combDim, function(input){ +# lapply(input, function(gr){ +# names(gr)<- NULL +# gr +# }) +# }) +# variable_groups.combDim.noGrVarNames<-lapply(variable_groups.combDim, function(input){ +# names(input)<- NULL +# lapply(input, function(gr){ +# names(gr)<- NULL +# gr +# }) +# }) + +test_that("Variable groupings validation", { + result <- feature_importance(x = modelNN, data = test_data, y = test_y, + loss_function = loss_root_mean_square, + variable_groups = variable_groups) + expect_is(result, "feature_importance_explainer") +}) + + +test_that("Variable groupings validation with combinations of variables from 2 different dimensions", { + result <- feature_importance(x = modelNN, data = test_data, y = test_y, + variable_groups = variable_groups.combDim) + expect_is(result, "feature_importance_explainer") +}) + + +test_that("Variable groupings input validation checks", { + expect_warning(feature_importance(x = modelNN, data = test_data, y = test_y, + loss_function = loss_root_mean_square, + variable_groups = variable_groups.noGrNames), + "You have passed an unnamed list. 
The names of variable groupings will be created from variables names.") + + expect_error( feature_importance(x = modelNN, data = test_data, y = test_y, + loss_function = loss_root_mean_square, + variable_groups = c("x1", "cat1A") + ), + "variable_groups should be of class list contining lists for each data input") + + variable_groups.wrong<- variable_groups + variable_groups.wrong$TS_input$t.1<- "-1" + variable_groups.wrong$Static_input$cat1[1]<- "wrong" + expect_error(feature_importance(x = modelNN, data = test_data, y = test_y, + loss_function = loss_root_mean_square, + variable_groups = variable_groups.wrong), + "You have passed wrong variables names in variable_groups argument") + variable_groups.wrong<- variable_groups + variable_groups.wrong$Static_input$cat1[[1]]<- as.list(variable_groups.wrong$Static_input$cat1[[1]]) + expect_error(feature_importance(x = modelNN, data = test_data, y = test_y, + loss_function = loss_root_mean_square, + variable_groups = variable_groups.wrong), + "Elements of variable_groups argument should be of class character") +}) + + + + +# Output types + +test_that("feature_importance with type ratio", { + # type "ratio" gives $dropout_loss normalized by _full_model_ + result <- feature_importance(x = modelNN, data = test_data, y = test_y, type="ratio") + expect_equal(result$dropout_loss[result$variable=="_full_model_" & result$permutation == 0], 1) +}) + + +test_that("feature_importance with type difference", { + # type "difference" gives $dropout_loss with _full_model_ subtracted + result <- feature_importance(x = modelNN, data = test_data, y = test_y, type="difference") + expect_equal(result$dropout_loss[result$variable=="_full_model_" & result$permutation == 0], 0) +}) + +test_that("Inverse sorting of bars",{ + result <- feature_importance(x = modelNN, data = test_data, y = test_y, type="difference") + + expect_error(plot(result, desc_sorting = "desc")) +}) From 41922e1d0b3a2f4ff68265fb4e8acc5b5ab44e3c Mon Sep 17 00:00:00 2001 From: 
Joan Maspons Date: Fri, 1 Apr 2022 17:21:52 +0200 Subject: [PATCH 7/8] Fix call to feature_importance.multiinput by passing N=N --- R/feature_importance.R | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/R/feature_importance.R b/R/feature_importance.R index d65b639..1a291f5 100644 --- a/R/feature_importance.R +++ b/R/feature_importance.R @@ -185,10 +185,9 @@ feature_importance.default <- function(x, ..., label = label, type = type, - n_sample = n_sample, B = B, variables = variables, - N = n_sample, + N = N, variable_groups = variable_groups, perm_dim = perm_dim, comb_dims = comb_dims) From 89b83e40151dd2ab7e9dd668f65b6ebd11819cd7 Mon Sep 17 00:00:00 2001 From: Joan Maspons Date: Fri, 1 Apr 2022 17:24:54 +0200 Subject: [PATCH 8/8] Remove deprecated parameter No need in an internal function --- R/feature_importance.R | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/R/feature_importance.R b/R/feature_importance.R index 1a291f5..de05921 100644 --- a/R/feature_importance.R +++ b/R/feature_importance.R @@ -312,10 +312,9 @@ feature_importance.multiinput <- function(x, ..., label = class(x)[1], type = c("raw", "ratio", "difference"), - n_sample = NULL, B = 10, variables = NULL, - N = n_sample, + N = NULL, variable_groups = NULL, perm_dim = NULL, comb_dims = FALSE) {