---
title: "checkpointing: brms"
output: rmarkdown::html_vignette
vignette: >
  %\VignetteIndexEntry{checkpointing: brms}
  %\VignetteEngine{knitr::rmarkdown}
  %\VignetteEncoding{UTF-8}
---

```{r, include = FALSE}
knitr::opts_chunk$set(
  collapse = TRUE,
  comment = "#>"
)
```

The following examples walk through using **chkptstanr** with the popular `R` package [**brms**](https://paul-buerkner.github.io/brms/). The basic idea is to (1) generate the [**Stan**](https://mc-stan.org/users/interfaces/rstan) code with **brms**, (2) fit the model with [**cmdstanr**](https://mc-stan.org/cmdstanr/) (with the desired number of checkpoints), and then (3) return a `brmsfit` object. This is all done internally, so the workflow is very similar to using **brms**.

## Packages

```r
library(chkptstanr)
library(posterior)
library(bayesplot)
library(ggplot2)
library(brms)
```

# Example 1: No Stopping

## Model Fitting

You can use `chkpt_brms` in nearly the same way as `brm`. The only difference is that you need to specify the path to the folder where the checkpoints will be stored, and the number of iterations between each checkpoint.

```r
fit1 <- chkpt_brms(count ~ zAge + zBase * Trt + (1|patient),
                   data = epilepsy,
                   family = poisson(),
                   iter_per_chkpt = 200,
                   path = 'checkpoints/epilepsy')
```

When running the above, a custom progress bar is printed that includes information about the checkpoints.

```r
#> Initial Warmup (Typical Set)
#> Chkpt: 1 / 10; Iteration: 200 / 2000 (warmup)
#> Chkpt: 2 / 10; Iteration: 400 / 2000 (warmup)
#> Chkpt: 3 / 10; Iteration: 600 / 2000 (warmup)
#> Chkpt: 4 / 10; Iteration: 800 / 2000 (warmup)
#> Chkpt: 5 / 10; Iteration: 1000 / 2000 (warmup)
#> Chkpt: 6 / 10; Iteration: 1200 / 2000 (sample)
#> Chkpt: 7 / 10; Iteration: 1400 / 2000 (sample)
#> Chkpt: 8 / 10; Iteration: 1600 / 2000 (sample)
#> Chkpt: 9 / 10; Iteration: 1800 / 2000 (sample)
#> Chkpt: 10 / 10; Iteration: 2000 / 2000 (sample)
#> Checkpointing complete
```

In this case, checkpointing is complete.

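The checkpoint counts in the progress output above follow from simple arithmetic, which the base-R sketch below reproduces. It assumes 1000 warmup and 1000 sampling iterations per chain (consistent with the `iter = 2000; warmup = 1000` reported in the model summary) and that the total is a multiple of `iter_per_chkpt`. The helper `checkpoint_schedule()` is hypothetical and is not part of **chkptstanr**.

```r
# Hypothetical helper (not part of chkptstanr): list the iteration at which
# each checkpoint ends and whether it falls in warmup or sampling.
checkpoint_schedule <- function(iter_warmup = 1000,
                                iter_sampling = 1000,
                                iter_per_chkpt = 200) {
  total <- iter_warmup + iter_sampling
  # Assumption of this sketch: checkpoints divide the run evenly.
  stopifnot(total %% iter_per_chkpt == 0)
  last_iter <- seq(iter_per_chkpt, total, by = iter_per_chkpt)
  data.frame(
    chkpt = seq_along(last_iter),
    iteration = last_iter,
    phase = ifelse(last_iter <= iter_warmup, "warmup", "sample")
  )
}

schedule <- checkpoint_schedule()
print(schedule)
```

With `iter_per_chkpt = 200` this gives 10 checkpoints, the first 5 of which cover warmup, matching the progress output.
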
## Summary

`fit1` is a `brmsfit` object, which means that all of the functionality of **brms** can still be used. Here is the summary output:

```
 Family: poisson
  Links: mu = log
Formula: count ~ zAge + zBase * Trt + (1 | patient)
   Data: data (Number of observations: 236)
  Draws: 4 chains, each with iter = 2000; warmup = 1000; thin = 1;
         total post-warmup draws = 4000

Multilevel Hyperparameters:
~patient (Number of levels: 59)
              Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
sd(Intercept)     0.58      0.07     0.45     0.73 1.00      956     1696

Regression Coefficients:
           Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
Intercept      1.77      0.12     1.53     2.00 1.00      887     1575
zAge           0.10      0.09    -0.07     0.26 1.00      871     1256
zBase          0.70      0.12     0.47     0.94 1.00      986     1675
Trt1          -0.26      0.16    -0.60     0.05 1.01      987     1170
zBase:Trt1     0.05      0.16    -0.26     0.37 1.00     1075     1824

Draws were sampled using sample(hmc). For each parameter, Bulk_ESS
and Tail_ESS are effective sample size measures, and Rhat is the potential
scale reduction factor on split chains (at convergence, Rhat = 1).
```

## Posterior Predictive Check

Because `fit1` is a `brmsfit` object, performing a posterior predictive check is seamless.

```r
pp_check(fit1)
```

![](../man/figures/pp_check_f1.png)

# Example 2: Stopping, restarting, and resetting

We can interrupt the sampling manually:

```r
fit1 <- chkpt_brms(count ~ zAge + zBase * Trt + (1|patient),
                   data = epilepsy,
                   family = poisson(),
                   iter_per_chkpt = 200,
                   path = 'checkpoints/epilepsy')
#> Compiling Stan program...
#> Initial Warmup (Typical Set)
#> Chkpt: 1 / 10; Iteration: 200 / 2000 (warmup)
#> Chkpt: 2 / 10; Iteration: 400 / 2000 (warmup)
#> Chkpt: 3 / 10; Iteration: 600 / 2000 (warmup)
#> Chkpt: 4 / 10; Iteration: 800 / 2000 (warmup)
#> Chkpt: 5 / 10; Iteration: 1000 / 2000 (warmup)
#> Chkpt: 6 / 10; Iteration: 1200 / 2000 (sample)
#> Sampling aborted. You can examine the results or continue sampling by rerunning the same code.
```

If the sampler is past the warmup stage, it returns a `brmsfit` object, so you can examine the results:

```r
summary(fit1)
```

```
 Family: poisson
  Links: mu = log
Formula: count ~ zAge + zBase * Trt + (1 | patient)
   Data: data (Number of observations: 236)
  Draws: 4 chains, each with iter = 1200; warmup = 1000; thin = 1;
         total post-warmup draws = 800

Multilevel Hyperparameters:
~patient (Number of levels: 59)
              Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
sd(Intercept)     0.57      0.07     0.45     0.72 1.02      209      236

Regression Coefficients:
           Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
Intercept      1.75      0.11     1.52     1.98 1.02      206      371
zAge           0.10      0.08    -0.07     0.25 1.02      240      443
zBase          0.70      0.11     0.47     0.95 1.00      220      370
Trt1          -0.24      0.16    -0.56     0.07 1.02      199      239
zBase:Trt1     0.05      0.15    -0.26     0.35 1.00      243      350

Draws were sampled using sample(hmc). For each parameter, Bulk_ESS
and Tail_ESS are effective sample size measures, and Rhat is the potential
scale reduction factor on split chains (at convergence, Rhat = 1).
```

With only 800 post-warmup draws, several Rhat values are at 1.02 and the effective sample sizes are low, so the model has not yet converged. We can continue sampling by rerunning the same code.

```r
fit1 <- chkpt_brms(count ~ zAge + zBase * Trt + (1|patient),
                   data = epilepsy,
                   family = poisson(),
                   iter_per_chkpt = 200,
                   path = 'checkpoints/epilepsy')
#> Model executable is up to date!
#> Chkpt: 7 / 10; Iteration: 1400 / 2000 (sample)
#> Chkpt: 8 / 10; Iteration: 1600 / 2000 (sample)
#> Chkpt: 9 / 10; Iteration: 1800 / 2000 (sample)
#> Chkpt: 10 / 10; Iteration: 2000 / 2000 (sample)
#> Checkpointing complete
```

We can then examine the final results:

```r
summary(fit1)
```

```
 Family: poisson
  Links: mu = log
Formula: count ~ zAge + zBase * Trt + (1 | patient)
   Data: data (Number of observations: 236)
  Draws: 4 chains, each with iter = 2000; warmup = 1000; thin = 1;
         total post-warmup draws = 4000

Multilevel Hyperparameters:
~patient (Number of levels: 59)
              Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
sd(Intercept)     0.58      0.07     0.45     0.73 1.00      956     1696

Regression Coefficients:
           Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
Intercept      1.77      0.12     1.53     2.00 1.00      887     1575
zAge           0.10      0.09    -0.07     0.26 1.00      871     1256
zBase          0.70      0.12     0.47     0.94 1.00      986     1675
Trt1          -0.26      0.16    -0.60     0.05 1.01      987     1170
zBase:Trt1     0.05      0.16    -0.26     0.37 1.00     1075     1824
```

### Predetermine stopping point

In addition to manually aborting the run, we can predetermine the stopping point: the `stop_after` argument specifies the number of iterations after which the sampler stops.

```r
fit1 <- chkpt_brms(count ~ zAge + zBase * Trt + (1|patient),
                   data = epilepsy,
                   family = poisson(),
                   iter_per_chkpt = 200,
                   stop_after = 1400,
                   path = 'checkpoints/epilepsy')
#> Compiling Stan program...
#> Initial Warmup (Typical Set)
#> Chkpt: 1 / 10; Iteration: 200 / 2000 (warmup)
#> Chkpt: 2 / 10; Iteration: 400 / 2000 (warmup)
#> Chkpt: 3 / 10; Iteration: 600 / 2000 (warmup)
#> Chkpt: 4 / 10; Iteration: 800 / 2000 (warmup)
#> Chkpt: 5 / 10; Iteration: 1000 / 2000 (warmup)
#> Chkpt: 6 / 10; Iteration: 1200 / 2000 (sample)
#> Chkpt: 7 / 10; Iteration: 1400 / 2000 (sample)
#> Sampling aborted. You can examine the results or continue sampling by rerunning the same code.
```

### Reset sampling

If we want to reset the sampling, we can use the `reset` argument, as long as we have not changed any of the key arguments. For example, we can reset the sampling and start from scratch, but we cannot change the formula, data, or family (although we can change `stop_after`):

```r
fit1 <- chkpt_brms(count ~ zAge + zBase * Trt + (1|patient),
                   data = epilepsy,
                   family = poisson(),
                   iter_per_chkpt = 200,
                   path = 'checkpoints/epilepsy',
                   stop_after = 1600,
                   reset = TRUE)
```

If we try to change the formula, data, or family, we will get an error:

```r
fit1 <- chkpt_brms(count ~ 1 + (1|patient),
                   data = epilepsy,
                   family = poisson(),
                   iter_per_chkpt = 200,
                   path = 'checkpoints/epilepsy',
                   stop_after = 1600,
                   reset = TRUE)
```

```
Error: Important arguments have been changed. Please completely reset the checkpointing via reset_checkpoints(path, recompile = TRUE).
Interupted before or during warmup. No samples available.
```

This is because we cannot use the existing compiled model. We need to reset the checkpoints and recompile the model:

```r
reset_checkpoints('checkpoints/epilepsy', recompile = TRUE)
```
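
As a recap of the stopping examples, the base-R sketch below estimates how many post-warmup draws are available when sampling stops at a given iteration. It assumes 4 chains and 1000 warmup iterations, as reported in the summaries above, and that a run stopped during warmup yields no draws (in line with the "No samples available" message). The helper `draws_available()` is hypothetical and is not part of **chkptstanr**.

```r
# Hypothetical helper (not part of chkptstanr): post-warmup draws across all
# chains when sampling stops after `stop_after` iterations per chain.
draws_available <- function(stop_after, iter_warmup = 1000, chains = 4) {
  # Iterations at or before the end of warmup contribute no draws.
  pmax(0, stop_after - iter_warmup) * chains
}

# Stopping points used in this vignette, plus one inside warmup.
stops <- c(600, 1200, 1400, 1600, 2000)
draws <- draws_available(stops)
print(data.frame(stop_after = stops, draws = draws))
```

Stopping at iteration 1200 gives 800 draws and a complete run gives 4000, matching the `total post-warmup draws` lines in the two summaries.
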