Skip to content

Commit 6edf0c8

Browse files
committed
remove chains that fail to initialize while sampling in parallel
leaving the other chains untouched
1 parent 496df8d commit 6edf0c8

File tree

3 files changed

+10
-3
lines changed

3 files changed

+10
-3
lines changed

News.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ output: html_document
66
# under development
77
----------------------------------------------------------------
88

9+
* remove chains that fail to initialize while sampling in parallel leaving the other chains untouched
910

1011
# brms 0.4.1
1112
----------------------------------------------------------------

R/main.R

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -301,9 +301,15 @@ brm <- function(formula, data = NULL, family = c("gaussian", "identity"), prior
301301
cl <- makeCluster(n.cluster)
302302
clusterEvalQ(cl, require(rstan))
303303
clusterExport(cl = cl, c("x", "inits", "n.iter", "n.warmup", "n.thin"), envir = environment())
304-
x$fit <- rstan::sflist2stanfit(parLapply(cl, 1:n.chains, fun = function(i)
304+
sflist <- parLapply(cl, 1:n.chains, fun = function(i)
305305
rstan::sampling(x$fit, data = x$data, iter = n.iter, pars = x$exclude, init = inits[i],
306-
warmup = n.warmup, thin = n.thin, chains = 1, chain_id = i, include = FALSE)))
306+
warmup = n.warmup, thin = n.thin, chains = 1, chain_id = i, include = FALSE))
307+
x$fit <- rstan::sflist2stanfit(rmNULL(lapply(1:length(sflist), function(i) {
308+
if (!is(sflist[[i]], "stanfit") || length(sflist[[i]]@sim$samples) == 0) {
309+
warning(paste("chain", i, "did not contain samples and was removed from the fitted model"))
310+
return(NULL)
311+
} else return(sflist[[i]])
312+
})))
307313
stopCluster(cl)
308314
}
309315
else x$fit <- rstan::sampling(x$fit, data = x$data, iter = n.iter, pars = x$exclude, init = inits,

R/misc.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ array2list <- function(x) {
1010
l
1111
}
1212

13-
isNULL <- function(x) is.null(x) || all(sapply(x, is.null))
13+
isNULL <- function(x) is.null(x) || ifelse(is.vector(x), all(sapply(x, is.null)), FALSE)
1414

1515
rmNULL <- function(x) {
1616
x <- Filter(Negate(isNULL), x)

0 commit comments

Comments
 (0)