sped up GLMM mode -- related to #257
jr-leary7 committed Oct 19, 2024
1 parent 8d7b7be commit 30e5d92
Showing 4 changed files with 64 additions and 48 deletions.
2 changes: 2 additions & 0 deletions NAMESPACE
@@ -105,8 +105,10 @@ importFrom(ggplot2,scale_y_continuous)
importFrom(ggplot2,theme)
importFrom(ggplot2,theme_classic)
importFrom(glmmTMB,glmmTMB)
importFrom(glmmTMB,glmmTMBControl)
importFrom(glmmTMB,nbinom2)
importFrom(glmmTMB,ranef)
importFrom(mpath,glmregNB)
importFrom(parallel,clusterEvalQ)
importFrom(parallel,clusterExport)
importFrom(parallel,clusterSetRNGStream)
103 changes: 56 additions & 47 deletions R/fitGLMM.R
@@ -3,21 +3,24 @@
#' @name fitGLMM
#' @author Jack R. Leary
#' @description Fits a negative binomial generalized linear mixed model using truncated power basis function splines as input. The basis matrix can be created adaptively via subject-specific estimation of optimal knots with \code{\link{marge2}}, or basis functions can be evenly spaced across quantiles. The resulting model can output subject-specific and population-level fitted values.
#' @importFrom purrr map_dfr pmap_dfc
#' @importFrom purrr map map_dbl map_dfr pmap_dfc reduce
#' @importFrom dplyr mutate select if_else
#' @importFrom mpath glmregNB
#' @importFrom stats quantile logLik fitted as.formula
#' @importFrom glmmTMB glmmTMB nbinom2
#' @importFrom glmmTMB glmmTMB nbinom2 glmmTMBControl
#' @param X_pred A matrix with one column containing cell ordering values. Defaults to NULL.
#' @param Y A vector of raw single cell counts. Defaults to NULL.
#' @param Y.offset (Optional) An offset to be included in the final model fit. Defaults to NULL.
#' @param id.vec A vector of subject IDs. Defaults to NULL.
#' @param adaptive Should basis functions be chosen adaptively? Defaults to TRUE.
#' @param approx.knot Should knot approximation be used in the calls to \code{\link{marge2}}? This speeds up computation somewhat. Defaults to TRUE.
#' @param M.glm The number of possible basis functions to use in the calls to \code{\link{marge2}} when choosing basis functions adaptively. Defaults to 3.
#' @param reg.penalty (Optional) String specifying the penalty type to be used when fitting a regularized negative-binomial model to select optimal basis functions. Defaults to "snet".
#' @param return.basis (Optional) Whether the basis model matrix (denoted \code{B_final}) should be returned as part of the \code{marge} model object. Defaults to FALSE.
#' @param return.GCV (Optional) Whether the final GCV value should be returned as part of the \code{marge} model object. Defaults to FALSE.
#' @param verbose (Optional) Should intermediate output be printed to the console? Defaults to FALSE.
#' @return An object of class \code{marge} containing the fitted model & other optional quantities of interest (basis function matrix, GCV, etc.).
#' @seealso \code{\link[mpath]{glmregNB}}
#' @seealso \code{\link[glmmTMB]{glmmTMB}}
#' @seealso \code{\link{testDynamic}}
#' @seealso \code{\link{modelLRT}}
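A minimal usage sketch of the interface documented above. The input objects (`sim_pt`, `sim_counts`, `sim_ids`) are hypothetical stand-ins, not part of this changeset; only the `reg.penalty` argument is new in this commit.

```r
# sketch only: sim_pt is a one-column matrix of cell ordering (pseudotime) values,
# sim_counts a vector of raw counts for one gene, sim_ids a vector of subject IDs
glmm_fit <- fitGLMM(X_pred = sim_pt,
                    Y = sim_counts,
                    id.vec = sim_ids,
                    adaptive = TRUE,
                    approx.knot = TRUE,
                    M.glm = 3,
                    reg.penalty = "snet",  # new argument introduced in this commit
                    return.basis = FALSE,
                    return.GCV = FALSE,
                    verbose = FALSE)
```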
@@ -38,6 +41,7 @@ fitGLMM <- function(X_pred = NULL,
adaptive = TRUE,
approx.knot = TRUE,
M.glm = 3,
reg.penalty = "snet",
return.basis = FALSE,
return.GCV = FALSE,
verbose = FALSE) {
@@ -48,40 +52,39 @@
# fit NB GLMM
if (adaptive) {
glm_marge_knots <- purrr::map(unique(id.vec),
function(x) {
marge_mod_sub <- marge2(X_pred = X_pred[which(id.vec == x), , drop = FALSE],
Y = Y[which(id.vec == x)],
Y.offset = Y.offset[which(id.vec == x)],
approx.knot = approx.knot,
M = M.glm,
return.basis = TRUE)
})
keepmods <- which(sapply(glm_marge_knots, function(x) length(x$coef_names)) > 1)
function(x) {
marge_mod_sub <- marge2(X_pred = X_pred[which(id.vec == x), , drop = FALSE],
Y = Y[which(id.vec == x)],
Y.offset = Y.offset[which(id.vec == x)],
approx.knot = approx.knot,
M = M.glm,
is.glmm = TRUE,
return.basis = TRUE)
return(marge_mod_sub)
})
keepmods <- which(purrr::map_dbl(glm_marge_knots, \(x) ncol(x$basis_mtx)) > 1)
glm_marge_knots <- glm_marge_knots[keepmods]

allKnots <- lapply(glm_marge_knots, function(x) extractBreakpoints(x)$Breakpoint)
allCoef <- lapply(glm_marge_knots, function(x) names(coef(x$final_mod)[-1]))
allOldCoef <- lapply(glm_marge_knots, function(x) paste0("B_final", x$marge_coef_names[-1]))
tp_fun <- lapply(allCoef, function(x) dplyr::if_else(grepl("h_[0-9]", x), "tp2", "tp1"))

glm_marge_knots <- data.frame(knot = do.call(c, allKnots),
coef = do.call(c, allCoef),
old_coef = do.call(c, allOldCoef),
tp_fun = do.call(c, tp_fun))

allKnots <- purrr::map(glm_marge_knots, \(x) extractBreakpoints(x)$Breakpoint)
allCoef <- purrr::map(glm_marge_knots, \(x) x$coef_names[-1])
allOldCoef <- purrr::map(glm_marge_knots, \(x) paste0("B_final", x$marge_coef_names[-1]))
tp_fun <- purrr::map(allCoef, \(x) dplyr::if_else(grepl("h_[0-9]", x), "tp2", "tp1"))
glm_marge_knots <- data.frame(knot = purrr::reduce(allKnots, c),
coef = purrr::reduce(allCoef, c),
old_coef = purrr::reduce(allOldCoef, c),
tp_fun = purrr::reduce(tp_fun, c))
glmm_basis_df <- purrr::pmap_dfc(list(glm_marge_knots$knot,
glm_marge_knots$tp_fun,
seq_len(nrow(glm_marge_knots))),
function(k, f, i) {
if (f == "tp1") {
basis <- tp1(x = X_pred[, 1], t = k)
} else {
basis <- tp2(x = X_pred[, 1], t = k)
}
basis_df <- data.frame(basis)
colnames(basis_df) <- paste0("X", i)
return(basis_df)
})
glm_marge_knots$tp_fun,
seq_len(nrow(glm_marge_knots))),
function(k, f, i) {
if (f == "tp1") {
basis <- tp1(x = X_pred[, 1], t = k)
} else {
basis <- tp2(x = X_pred[, 1], t = k)
}
basis_df <- data.frame(basis)
colnames(basis_df) <- paste0("X", i)
return(basis_df)
})
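For context on the basis construction above: `tp1()` and `tp2()` are not part of this diff, but assuming they behave like the standard truncated power (hinge) functions used in MARS-style models, each column of `glmm_basis_df` is a piecewise-linear transform of the cell ordering values around one selected knot. A rough sketch under that assumption:

```r
# sketch only: stand-ins for the package's tp1()/tp2(), assumed here to be
# truncated power basis (hinge) functions around a knot t
tp1_sketch <- function(x, t) pmax(0, x - t)  # nonzero to the right of the knot
tp2_sketch <- function(x, t) pmax(0, t - x)  # nonzero to the left of the knot

x <- seq(0, 1, length.out = 5)
tp1_sketch(x, t = 0.5)  # 0.00 0.00 0.00 0.25 0.50
tp2_sketch(x, t = 0.5)  # 0.50 0.25 0.00 0.00 0.00
```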
if (ncol(glmm_basis_df) == 1) {
mod_formula <- stats::as.formula(paste0("Y ~ ",
paste(colnames(glmm_basis_df), collapse = " + "),
@@ -109,50 +112,54 @@
data = glmm_basis_df,
parallel = FALSE,
nlambda = 50,
penalty = "snet",
penalty = reg.penalty,
alpha = .5,
standardize = TRUE,
trace = FALSE,
maxit.theta = 1,
link = log)
link = log,
maxit = 500)
} else {
pruned_model <- mpath::glmregNB(lasso_formula,
data = glmm_basis_df,
offset = log(1 / Y.offset),
parallel = FALSE,
nlambda = 50,
penalty = "snet",
penalty = reg.penalty,
alpha = .5,
standardize = TRUE,
trace = FALSE,
maxit.theta = 1,
link = log)
link = log,
maxit = 500)
}
# identify nonzero basis functions in minimum AIC model
nonzero_coefs <- which(as.numeric(pruned_model$beta[, which.min(pruned_model$aic)]) != 0)
# identify nonzero basis functions in minimum BIC model
nonzero_coefs <- which(as.numeric(pruned_model$beta[, which.min(pruned_model$bic)]) != 0)
# build formula automatically
mod_formula <- stats::as.formula(paste0("Y ~ ",
paste(colnames(glmm_basis_df)[nonzero_coefs], collapse = " + "),
" + (1 + ", paste(colnames(glmm_basis_df)[nonzero_coefs], collapse = " + "),
" | subject)"))
glmm_basis_df_new <- dplyr::mutate(glmm_basis_df,
Y = Y,
subject = id.vec,
.before = 1)
Y = Y,
subject = id.vec,
.before = 1)
}
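To make the formula construction above concrete: if, for example, only the second and fourth basis columns have nonzero coefficients at the minimum-BIC lambda, the assembled mixed-model formula contains those columns as both fixed and subject-level random effects. Illustrative values only:

```r
# illustrative only: suppose the min-BIC fit keeps basis columns 2 and 4
nonzero_coefs <- c(2, 4)
basis_names <- paste0("X", nonzero_coefs)
stats::as.formula(paste0("Y ~ ",
                         paste(basis_names, collapse = " + "),
                         " + (1 + ", paste(basis_names, collapse = " + "),
                         " | subject)"))
#> Y ~ X2 + X4 + (1 + X2 + X4 | subject)
```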
if (is.null(Y.offset)) {
glmm_mod <- glmmTMB::glmmTMB(mod_formula,
data = glmm_basis_df_new,
family = glmmTMB::nbinom2(link = "log"),
se = TRUE,
REML = FALSE)
REML = FALSE,
control = glmmTMB::glmmTMBControl(profile = TRUE, eigval_check = FALSE))
} else {
glmm_mod <- glmmTMB::glmmTMB(mod_formula,
data = glmm_basis_df_new,
offset = log(1 / Y.offset),
family = glmmTMB::nbinom2(link = "log"),
se = TRUE,
REML = FALSE)
REML = FALSE,
control = glmmTMB::glmmTMBControl(profile = TRUE, eigval_check = FALSE))
}
} else {
glmm_basis_df_new <- data.frame(X1 = tp1(X_pred[, 1], t = round(as.numeric(stats::quantile(X_pred[, 1], 1/3)), 4)),
@@ -175,14 +182,16 @@
data = glmm_basis_df_new,
family = glmmTMB::nbinom2(link = "log"),
se = TRUE,
REML = FALSE)
REML = FALSE,
control = glmmTMB::glmmTMBControl(profile = TRUE, eigval_check = FALSE))
} else {
glmm_mod <- glmmTMB::glmmTMB(Y ~ X1 + X2 + X3 + X4 + (1 + X1 + X2 + X3 + X4 | subject),
data = glmm_basis_df_new,
offset = log(1 / Y.offset),
family = glmmTMB::nbinom2(link = "log"),
se = TRUE,
REML = FALSE)
REML = FALSE,
control = glmmTMB::glmmTMBControl(profile = TRUE, eigval_check = FALSE))
}
}
# set up results
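The speedup referenced in the commit title comes mainly from the new `glmmTMB::glmmTMBControl()` call shown in the hunks above: `profile = TRUE` profiles the fixed-effect parameters out of the general nonlinear optimization (intended to speed up fitting), and `eigval_check = FALSE` skips an eigenvalue check on the estimated covariance matrices. Isolated for clarity below, with a placeholder data frame `df` holding `Y`, basis columns `X1`/`X2`, and `subject`:

```r
library(glmmTMB)

# control settings added in this commit, shown in isolation
ctrl <- glmmTMB::glmmTMBControl(profile = TRUE,        # profile out fixed effects during optimization
                                eigval_check = FALSE)  # skip eigenvalue check on covariance matrices

mod <- glmmTMB::glmmTMB(Y ~ X1 + X2 + (1 + X1 + X2 | subject),
                        data = df,
                        family = glmmTMB::nbinom2(link = "log"),
                        se = TRUE,
                        REML = FALSE,
                        control = ctrl)
```

Since profiling changes how the likelihood is optimized, it may be worth spot-checking a few genes to confirm that the profiled fits stay close to the previous non-profiled ones.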
2 changes: 1 addition & 1 deletion R/marge2.R
@@ -878,7 +878,7 @@ marge2 <- function(X_pred = NULL,
WIC_mtx = NULL,
GCV = NULL,
model_type = "GLMM",
coef_names = NULL,
coef_names = colnames(model_df),
marge_coef_names = colnames(B_final))
}

5 changes: 5 additions & 0 deletions man/fitGLMM.Rd

