Note:

Development for this project is on pause, partly due to personal time constraints and partly because I'm aiming to move the core out of RStudio Jobs to a Python-based background-process framework that should be more environment-agnostic (e.g. it could run in RStudio or VS Code). I'm also finding the NetCDF file format somewhat buggy and am considering alternatives like Zarr.

aria

An R package implementing an idiosyncratic Stan workflow.

Install via:

remotes::install_github('mike-lawrence/aria/aria')

(Yup, that's an extra /aria relative to what you might be used to in terms of install_github() argument strings)

Why "aria"?

Just as Stan is named for Stanislaw Ulam, this package's name was inspired by Gelman, 2021 to pay homage to Arianna Rosenbluth, with liberties taken to use a shorter sobriquet that also happens to connote the fact that this is a solo project of mine.

Why use aria?

You probably shouldn't, not yet at least. This package is still in its development infancy, focused first on feature implementation and only secondarily on things like cross-platform reliability and the ability to handle all possible workflow scenarios. I like to think of it as a kind of workflow proof-of-concept that lets me implement my ideas quickly; once I have things I think are worthwhile to others, I'll work on advocating their inclusion in the more popular, stability-focused packages.

Implemented features

  • Use of stanc3 for syntax checking in RStudio (automatically enabled on package load; see ?aria::enable_rstudio_syntax_compile)
  • Option to trigger model compilation on save with an aria: compile = 1 string at the top of your Stan file
  • Smart compilation, whereby the saved model is compared to a stored representation (kept in a folder called aria) and compilation only proceeds if functional changes have been made to the code (n.b. including changes to any included files!). So go ahead and add comments and modify whitespace without fear that they will cause unnecessary model recompilation.
  • Both compilation and sampling occur in background processes with outputs/progress monitored by an RStudio Job.
  • Automatic check for runtime errors at compilation using a special debugging exe and dummy data.
  • Automatic check for runtime errors at the outset of sampling using a special debugging exe and the real data.
  • If no runtime errors are encountered, use of a performance-tuned exe for sampling.
  • Data are cached for faster start of sampling when the same data are sampled repeatedly.
  • Nicer progress indicators including estimated time remaining and parsimonious display of any important error messages.
  • Progress indication including diagnostics.
  • During-sampling redirection of output to an ArviZ-compliant file format, enabling faster post-sampling access to the output data as well as during-sampling monitoring of diagnostics & posterior samples.
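A minimal sketch of the comparison behind smart compilation, assuming it boils down to normalizing away comments and whitespace before comparing the new code against the stored copy. The helper names strip_stan_noise() and needs_recompile() are hypothetical, not aria's API, and the sketch ignores #include resolution as well as comment markers that happen to sit inside string literals.

```r
# Hypothetical helper: reduce Stan code to its "functional" content.
strip_stan_noise <- function(code) {
  code <- paste(code, collapse = "\n")
  # drop /* ... */ block comments; (?s) lets `.` span newlines
  code <- gsub("(?s)/\\*.*?\\*/", "", code, perl = TRUE)
  # drop // line comments
  code <- gsub("//[^\n]*", "", code, perl = TRUE)
  # collapse every run of whitespace (spaces, tabs, newlines) to one space
  trimws(gsub("[[:space:]]+", " ", code))
}

# Hypothetical check: recompile only if the normalized code differs.
needs_recompile <- function(new_code, stored_code) {
  !identical(strip_stan_noise(new_code), strip_stan_noise(stored_code))
}
```

Collapsing whitespace (rather than deleting it) is the conservative choice: re-indenting or adding blank lines compares equal, while an edit such as x=1 becoming x = 1 would still count as a change.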

Features under development

  • Diagnostics-driven sampling, whereby the model performance is monitored on a variety of criteria (divergences encountered, rhats, ESS; also standard sample-count as well as wall-time) and terminates only when those criteria are met.
  • Resuming sampling of unexpectedly-terminated chains.
  • When compiling performance exe: Moving transformed parameters to model block and removing generated quantities entirely. This yields slightly less compute/write time at the cost of requiring a subsequent aria::generate_quantities() run.
  • aria::generate_quantities(), which extracts the quantities that would have been computed/saved by the code as written but moved/removed from the performance exe, puts them all in the generated quantities, compiles and runs with just the post-warmup (and possibly thinned) samples.
  • Automated simulation-based calibration (SBC) by extracting user-supplied parameters (and their downstream dependencies; requires models be written in generative order) and placing them in GQ with _rng functions replacing their priors.
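Since diagnostics-driven sampling is still under development, the following is only a hypothetical sketch of what such a stopping rule could look like. The function name, the shape of the diag list, and the thresholds (R-hat below 1.01 and bulk/tail ESS above 400, which are commonly recommended defaults) are all assumptions; here, sample count and wall time are treated as budgets that end sampling regardless of the other criteria.

```r
# Hypothetical stopping rule for diagnostics-driven sampling.
# `diag` is assumed to be a list holding per-variable diagnostics
# (rhat, ess_bulk, ess_tail) plus run-level counters
# (n_divergent, n_draws, elapsed_sec).
should_stop_sampling <- function(
  diag,
  max_rhat = 1.01,     # every R-hat must fall below this
  min_ess = 400,       # every bulk and tail ESS must exceed this
  max_divergent = 0,   # tolerated number of divergent transitions
  max_draws = Inf,     # sample-count budget
  max_wall_sec = Inf   # wall-time budget, in seconds
) {
  # Budgets: once either is exhausted, stop regardless of diagnostics
  if (diag$n_draws >= max_draws || diag$elapsed_sec >= max_wall_sec) {
    return(TRUE)
  }
  converged <- all(diag$rhat < max_rhat)
  enough_ess <- all(diag$ess_bulk > min_ess) && all(diag$ess_tail > min_ess)
  few_divergent <- diag$n_divergent <= max_divergent
  # isTRUE() means NA diagnostics count as "criteria not yet met"
  isTRUE(converged && enough_ess && few_divergent)
}
```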

Glaring omissions

  • cross-platform support; aria currently should run on unix-like systems (Linux, MacOS, WSL) but certainly won't work on Windows yet.
  • handling compile arguments for within-chain parallelization; this should be easy, I just never use them, being on a meagre 4-core machine myself
  • tests; 😬

How to use aria

aria uses features that were introduced in CmdStan 2.27.0, so if you haven't installed that version yet, run:

cmdstanr::install_cmdstan(version='2.27.0')

When first opening a project, run:

aria::enable_rstudio_syntax_compile()

This enables the enhanced "Check on Save" in RStudio. You can then work on your Stan file in RStudio, saving as you go to obtain syntax checks from the stanc3 syntax checker.

When you're ready to compile, add

/*
aria: compile = 1
*/

at the top of the file and save, at which point an RStudio Job will launch to show you progress of a multi-step compilation. By default, compilation will involve:

  1. Syntax check
  2. Compiling a "debugging" executable
  3. Running a "debug check" using this executable and automatically generated synthetic data
  4. Compiling a "performance" executable for use in sampling
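The aria: lines live in an ordinary Stan block comment, so stanc3 never sees them. A minimal sketch of how such directives could be read, assuming key = value sets an option and key += value appends to it; parse_aria_directives() is a hypothetical name, and values are kept as character strings.

```r
# Hypothetical parser for "aria:" directives found anywhere in a Stan file.
parse_aria_directives <- function(lines) {
  opts <- list()
  # captures: (1) key, (2) "=" or "+=", (3) value with surrounding space trimmed
  pattern <- "^\\s*aria:\\s*([A-Za-z_]+)\\s*(\\+?=)\\s*(.*?)\\s*$"
  for (line in lines[grepl(pattern, lines, perl = TRUE)]) {
    parts <- regmatches(line, regexec(pattern, line, perl = TRUE))[[1]]
    key <- parts[2]
    op <- parts[3]
    value <- parts[4]
    if (op == "+=") {
      opts[[key]] <- c(opts[[key]], value)  # accumulate repeated directives
    } else {
      opts[[key]] <- value                  # plain assignment overwrites
    }
  }
  opts
}
```

For example, parse_aria_directives(readLines('stan/my_mod.stan'))$compile would return "1" for a file carrying the comment block shown above.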

At this point you might hear auditory feedback on successful completion or failure. To suppress these sounds, run:

options('aria_sotto_vocce'=TRUE)

There are other comments you can put inside the aria comment block above, enabling modification of the syntax-checking and compilation process away from the defaults enumerated above:

  • aria: syntax_ignore += The parameter iZc has no priors.

Any line starting with aria: syntax_ignore += tells the syntax checker to ignore any warning matching the rest of the line's content. (n.b. Fuzzy matching from prior aria versions has been removed in favor of printing messages when warnings are encountered.) You should only really need this option when you're trying to compile (syntax errors block compilation) and have already verified that the warning fails to account for your more advanced Stan code.
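A minimal sketch of that filtering step, assuming "matching" means the ignore string appears verbatim (as a fixed substring) in the warning text. The filter_warnings() name and the example warning strings are hypothetical.

```r
# Hypothetical filter: drop warnings containing any user-supplied ignore
# string (exact substring match, no fuzzy matching) and report each drop.
filter_warnings <- function(warnings, ignore) {
  ignored <- vapply(
    warnings,
    function(w) any(vapply(ignore, function(p) grepl(p, w, fixed = TRUE), logical(1))),
    logical(1),
    USE.NAMES = FALSE
  )
  for (w in warnings[ignored]) message("aria: ignoring warning: ", w)
  warnings[!ignored]
}
```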

  • aria: compile_debug = 0

A value of 0 prevents the default compilation of a "debug" exe. It is recommended that you not use this unless you're really pressed for time.

  • aria: run_debug = 0

A value of 0 prevents the default running of the debug exe (if compiled) using auto-generated synthetic data. This is for models whose complicated structure foils the synthetic-data generator and yields NaNs in the target.

  • aria: make_local += STAN_NO_RANGE_CHECKS=true

Lines prefaced with aria: make_local += have the remainder of their content appended as new lines to a make/local makefile used for compilation of the performance exe. If there are no such lines, the existing make/local is used. If there are such lines and a make/local exists, the existing file is temporarily moved aside and the lines specified in the .stan file are written to a temporary, otherwise empty make/local. A maximally-optimized make/local would be used via:

aria: make_local += CXXFLAGS+=-O3
aria: make_local += CXXFLAGS+=-g0
aria: make_local += STAN_NO_RANGE_CHECKS=true
aria: make_local += STAN_CPP_OPTIMS=true
aria: make_local += STANCFLAGS+=--Oexperimental
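The move-and-restore behaviour described above can be sketched as follows. This is an illustration under assumptions, not aria's code: with_temporary_make_local() is a hypothetical name, build is a stand-in for whatever triggers compilation, and on.exit() is used so the original make/local is restored even if the build fails.

```r
# Hypothetical sketch: swap in a temporary make/local for one build.
with_temporary_make_local <- function(cmdstan_dir, make_local_lines, build) {
  # No directives: compile with whatever make/local already exists
  if (length(make_local_lines) == 0) return(build())
  make_local <- file.path(cmdstan_dir, "make", "local")
  backup <- paste0(make_local, ".aria_backup")
  had_original <- file.exists(make_local)
  if (had_original) stopifnot(file.rename(make_local, backup))
  # Restore the original (or leave none) however `build` exits
  on.exit({
    unlink(make_local)
    if (had_original) file.rename(backup, make_local)
  }, add = TRUE)
  writeLines(make_local_lines, make_local)
  build()
}
```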

Now that your model is compiled, you can use aria::compose() to sample:

# compose a posterior given data and model
#   Note:
#     - pipe-compatible data-first arguments
#     - we pass the path to the Stan code; aria will go find the exe
aria::compose( 
	data = my_data
	, code_path = 'stan/my_mod.stan' 
	, out_path = 'sampled/my_data_my_mod_out.nc'
)

This will return NULL invisibly but launch sampling in the background, with an RStudio Job to monitor progress, including during-sampling diagnostics. During sampling, the CSVs generated by CmdStan are parsed and the output is stored in a NetCDF4 file (hence the .nc extension in the above example). Some additional information (timing, messages from CmdStan) is stored in a file at aria/marginalia.qs that can be viewed via:

aria::marginalia()

The output of sampling can be accessed via:

post = aria::coda( out_path = 'sampled/my_data_my_mod_out.nc')

This initializes an R6 object with pointers to the pertinent internals of the NetCDF4 file. Ultimately I'll be working toward making said internals compliant with the InferenceData spec, but for now you can view what's there via:

post$nc_info()

And get rvar representations via:

post$draws() #all variables
post$draws('mu') #just mu
post$draws(variables=c('mu','sigma')) #mu & sigma
post$draws(groups='parameters') #just the parameters (i.e. no TP or GQ)

# plays well with posterior (%>% needs magrittr, which the tidyverse also provides):
library(magrittr)
post$draws('mu') %>% posterior::summarise_draws()

Finally, here's some code to run some diagnostics and summarize the posterior of each variable:

library(tidyverse)


# Check treedepth, divergences, & rebfmi
(
	post$draws(groups='sample_stats')
	%>% posterior::as_draws_df()
	%>% group_by(.chain)
	%>% summarise(
		max_treedepth = max(treedepth)
		, num_divergent = sum(divergent)
		, rebfmi = var(energy)/(sum(diff(energy)^2)/n()) #n.b. reciprocal of typical EBFMI, so bigger=bad, like rhat
	)
)

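# gather summary for core parameters (inc. rhat & ess), including the
#   q2.5/q25/q50/q75/q97.5 quantile columns that the default summary lacks
(
	post$draws(groups='parameters')
	%>% posterior::summarise_draws(
		mean
		, ~ posterior::quantile2(.x, probs = c(.025, .25, .5, .75, .975))
		, rhat = posterior::rhat
		, ess_bulk = posterior::ess_bulk
		, ess_tail = posterior::ess_tail
	)
) ->
	fit_summary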

# check the range of rhat & ess
(
	fit_summary
	%>% select(rhat,contains('ess'))
	%>% summary()
)
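The rebfmi column in the diagnostics snippet is the reciprocal of the energy Bayesian fraction of missing information (E-BFMI), computed per chain. As a standalone check, here is a minimal sketch of both quantities for a single chain's energy trace; the function names are illustrative, not part of aria. Low E-BFMI (values below roughly 0.3 are a common warning threshold) corresponds to high rebfmi.

```r
# E-BFMI: mean squared change in energy between successive draws,
# relative to the marginal variance of the energy.
ebfmi <- function(energy) {
  (sum(diff(energy)^2) / length(energy)) / var(energy)
}

# Reciprocal, so that bigger = worse (like rhat)
rebfmi <- function(energy) 1 / ebfmi(energy)
```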

And here's a nice diagnostics-and-quantiles visualization that is unrelated to this package but that I like and wanted to share (code below):

(
	fit_summary
	%>% filter(str_starts(variable,fixed('z_m.')))
	%>% ggplot()
	+ geom_hline(yintercept = 0)
	+ geom_linerange(
		mapping = aes(
			x = variable
			, ymin = q2.5
			, ymax = q97.5
			, colour = ess_tail
		)
	)
	+ geom_linerange(
		mapping = aes(
			x = variable
			, ymin = q25
			, ymax = q75
			, colour = ess_bulk
		)
		, size = 3
	)
	+ geom_point(
		mapping = aes(
			x = variable
			, y = q50
			, fill = rhat
		)
		, shape = 21
		, size = 2
	)
	+ coord_flip()
	+ scale_color_gradient(
		high = 'white'
		, low = scales::muted('red')
	)
	+ scale_fill_gradient(
		low = 'white'
		, high = scales::muted('red')
	)
	+ labs(
		y = 'Median'
		, colour = 'ESS'
		, fill = 'Rhat'
	)
	+ theme(
		panel.background = element_rect(fill='grey50')
		, panel.grid = element_blank()
	)
)