Skip to content

Commit 6fa5c02

Browse files
committed
reduce number of calls to stats::terms and update.formula
1 parent 40acbff commit 6fa5c02

File tree

3 files changed

+74
-33
lines changed

3 files changed

+74
-33
lines changed

R/brmsterms.R

Lines changed: 66 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,11 @@ brmsterms.brmsformula <- function(formula, check_response = TRUE,
204204
unused_vars
205205
)
206206
if (check_response) {
207-
y$allvars <- update(y$respform, y$allvars)
207+
# add y$respform to the left-hand side of y$allvars
208+
# avoid using update.formula as it is inefficient for longer formulas
209+
formula_allvars <- y$respform
210+
formula_allvars[[3]] <- y$allvars[[2]]
211+
y$allvars <- formula_allvars
208212
}
209213
environment(y$allvars) <- environment(formula)
210214
y
@@ -241,8 +245,9 @@ brmsterms.mvbrmsformula <- function(formula, ...) {
241245
# @return a 'btl' object
242246
terms_lf <- function(formula) {
243247
formula <- rhs(as.formula(formula))
244-
check_accidental_helper_functions(formula)
245248
y <- nlist(formula)
249+
formula <- terms(formula)
250+
check_accidental_helper_functions(formula)
246251
types <- setdiff(all_term_types(), excluded_term_types(formula))
247252
for (t in types) {
248253
tmp <- do_call(paste0("terms_", t), list(formula))
@@ -338,16 +343,19 @@ terms_ad <- function(formula, family = NULL, check_response = TRUE) {
338343

339344
# extract fixed effects terms
340345
terms_fe <- function(formula) {
346+
if (!is.terms(formula)) {
347+
formula <- terms(formula)
348+
}
341349
all_terms <- all_terms(formula)
342350
sp_terms <- find_terms(all_terms, "all", complete = FALSE)
343351
re_terms <- all_terms[grepl("\\|", all_terms)]
344-
int_term <- attr(terms(formula), "intercept")
352+
int_term <- attr(formula, "intercept")
345353
fe_terms <- setdiff(all_terms, c(sp_terms, re_terms))
346354
out <- paste(c(int_term, fe_terms), collapse = "+")
347355
out <- str2formula(out)
348356
attr(out, "allvars") <- allvars_formula(out)
349357
attr(out, "decomp") <- get_decomp(formula)
350-
if (has_rsv_intercept(out)) {
358+
if (has_rsv_intercept(out, has_intercept(formula))) {
351359
attr(out, "int") <- FALSE
352360
}
353361
if (no_cmc(formula)) {
@@ -494,12 +502,14 @@ terms_ac <- function(formula) {
494502

495503
# extract offset terms
496504
terms_offset <- function(formula) {
497-
terms <- terms(as.formula(formula))
498-
pos <- attr(terms, "offset")
505+
if (!is.terms(formula)) {
506+
formula <- terms(as.formula(formula))
507+
}
508+
pos <- attr(formula, "offset")
499509
if (is.null(pos)) {
500510
return(NULL)
501511
}
502-
vars <- attr(terms, "variables")
512+
vars <- attr(formula, "variables")
503513
out <- ulapply(pos, function(i) deparse(vars[[i + 1]]))
504514
out <- str2formula(out)
505515
attr(out, "allvars") <- str2formula(all_vars(out))
@@ -703,8 +713,7 @@ allvars_formula <- function(...) {
703713
stop2("The following variable names are invalid: ",
704714
collapse_comma(invalid_vars))
705715
}
706-
out <- str2formula(c(out, all_vars))
707-
update(out, ~ .)
716+
str2formula(c(out, all_vars))
708717
}
709718

710719
# conveniently extract a formula of all relevant variables
@@ -740,6 +749,20 @@ plus_rhs <- function(x) {
740749
out
741750
}
742751

752+
# like stats::terms but keeps attributes if possible
753+
terms <- function(formula, ...) {
754+
old_attributes <- attributes(formula)
755+
formula <- stats::terms(formula, ...)
756+
new_attributes <- attributes(formula)
757+
sel_names <- setdiff(names(old_attributes), names(new_attributes))
758+
attributes(formula)[sel_names] <- old_attributes[sel_names]
759+
formula
760+
}
761+
762+
is.terms <- function(x) {
763+
inherits(x, "terms")
764+
}
765+
743766
# combine formulas for distributional parameters
744767
# @param formula1 primary formula from which to take the RHS
745768
# @param formula2 secondary formula used to update the RHS of formula1
@@ -887,7 +910,7 @@ all_terms <- function(x) {
887910
if (!length(x)) {
888911
return(character(0))
889912
}
890-
if (!inherits(x, "terms")) {
913+
if (!is.terms(x)) {
891914
x <- terms(as.formula(x))
892915
}
893916
trim_wsp(attr(x, "term.labels"))
@@ -963,10 +986,10 @@ find_terms <- function(x, type, complete = TRUE, ranef = FALSE) {
963986
validate_terms <- function(x) {
964987
no_int <- no_int(x)
965988
no_cmc <- no_cmc(x)
966-
if (is.formula(x) && !inherits(x, "terms")) {
989+
if (is.formula(x) && !is.terms(x)) {
967990
x <- terms(x)
968991
}
969-
if (!inherits(x, "terms")) {
992+
if (!is.terms(x)) {
970993
return(NULL)
971994
}
972995
if (no_int || !has_intercept(x) && no_cmc) {
@@ -979,32 +1002,48 @@ validate_terms <- function(x) {
9791002

9801003
# checks if the formula contains an intercept
9811004
has_intercept <- function(formula) {
982-
formula <- as.formula(formula)
983-
try_terms <- try(terms(formula), silent = TRUE)
984-
if (is(try_terms, "try-error")) {
985-
out <- FALSE
1005+
if (is.terms(formula)) {
1006+
out <- as.logical(attr(formula, "intercept"))
9861007
} else {
987-
out <- as.logical(attr(try_terms, "intercept"))
1008+
formula <- as.formula(formula)
1009+
try_terms <- try(terms(formula), silent = TRUE)
1010+
if (is(try_terms, "try-error")) {
1011+
out <- FALSE
1012+
} else {
1013+
out <- as.logical(attr(try_terms, "intercept"))
1014+
}
9881015
}
9891016
out
9901017
}
9911018

9921019
# check if model makes use of the reserved intercept variables
993-
has_rsv_intercept <- function(formula) {
1020+
# @param has_intercept does the model have an intercept?
1021+
# if NULL this will be inferred from formula itself
1022+
has_rsv_intercept <- function(formula, has_intercept = NULL) {
1023+
.has_rsv_intercept <- function(terms, has_intercept) {
1024+
has_intercept <- as_one_logical(has_intercept)
1025+
intercepts <- c("intercept", "Intercept")
1026+
out <- !has_intercept && any(intercepts %in% all_vars(rhs(terms)))
1027+
return(out)
1028+
}
1029+
if (is.terms(formula)) {
1030+
if (is.null(has_intercept)) {
1031+
has_intercept <- has_intercept(formula)
1032+
}
1033+
return(.has_rsv_intercept(formula, has_intercept))
1034+
}
9941035
formula <- try(as.formula(formula), silent = TRUE)
9951036
if (is(formula, "try-error")) {
996-
out <- FALSE
997-
} else {
1037+
return(FALSE)
1038+
}
1039+
if (is.null(has_intercept)) {
9981040
try_terms <- try(terms(formula), silent = TRUE)
9991041
if (is(try_terms, "try-error")) {
1000-
out <- FALSE
1001-
} else {
1002-
has_intercept <- attr(try_terms, "intercept")
1003-
intercepts <- c("intercept", "Intercept")
1004-
out <- !has_intercept && any(intercepts %in% all_vars(rhs(formula)))
1005-
}
1042+
return(FALSE)
1043+
}
1044+
has_intercept <- has_intercept(try_terms)
10061045
}
1007-
out
1046+
.has_rsv_intercept(formula, has_intercept)
10081047
}
10091048

10101049
# names of reserved variables

R/formula-re.R

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -306,11 +306,11 @@ split_re_terms <- function(re_terms) {
306306
}
307307
}
308308
# prepare effects of basic terms
309-
fe_form <- terms_fe(lhs_form)
309+
lhs_terms <- terms(lhs_form)
310+
fe_form <- terms_fe(lhs_terms)
310311
fe_terms <- all_terms(fe_form)
311-
has_intercept <- attr(terms(fe_form), "intercept")
312312
# the intercept lives within not outside of 'cs' terms
313-
has_intercept <- has_intercept && !"cs" %in% type[[i]]
313+
has_intercept <- has_intercept(lhs_terms) && !"cs" %in% type[[i]]
314314
if (length(fe_terms) || has_intercept) {
315315
new_lhs <- c(new_lhs, formula2str(fe_form, rm = 1))
316316
type[[i]] <- c(type[[i]], "")

tests/testthat/tests.brmsterms.R

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
context("Tests for formula parsing functions")
22

33
test_that("brmsterms finds all variables in very long formulas", {
4-
expect_equal(brmsterms(t2_brand_recall ~ psi_expsi + psi_api_probsolv +
5-
psi_api_ident + psi_api_intere + psi_api_groupint)$all,
6-
t2_brand_recall ~ t2_brand_recall + psi_expsi + psi_api_probsolv + psi_api_ident +
4+
expect_equal(
5+
all.vars(brmsterms(t2_brand_recall ~ psi_expsi + psi_api_probsolv +
6+
psi_api_ident + psi_api_intere + psi_api_groupint)$all),
7+
all.vars(t2_brand_recall ~ t2_brand_recall + psi_expsi + psi_api_probsolv + psi_api_ident +
78
psi_api_intere + psi_api_groupint)
9+
)
810
})
911

1012
test_that("brmsterms handles very long RE terms", {

0 commit comments

Comments
 (0)