Setting up

The examples are run using the SBC R package; consult the Getting Started with SBC vignette for the basics of the package.

knitr::opts_chunk$set(cache = TRUE)
library(SBC)
library(ggplot2)
library(patchwork)
library(tidyverse)
library(cmdstanr)
theme_set(cowplot::theme_cowplot())

# Use every core reported by parallel::detectCores() and set the SBC package
# option SBC.min_chunk_size.
# NOTE(review): presumably SBC.min_chunk_size is the minimum number of fits
# handed to a worker at once - confirm against the SBC documentation.
options(mc.cores = parallel::detectCores(), SBC.min_chunk_size = 5)

library(future)
# Select the multisession strategy for futures.
plan(multisession)

# Directory passed to compute_SBC() as the location of cached results.
cache_dir <- "./_SBC_cache_ordered_simplex"

# Directory where figures are written with ggsave().
fig_dir <- "./_figs" 

# Create both directories on first run.
if(!dir.exists(cache_dir)) {
  dir.create(cache_dir)
}
if(!dir.exists(fig_dir)) {
  dir.create(fig_dir)
}

# NOTE(review): library(SBC) is already attached above; this presumably
# replaces it with the development version from the working directory - confirm.
devtools::load_all()

# Plot dimensions for saved figures.
# NOTE(review): the ggsave() call visible later in this file uses literal
# width = 8, height = 3 rather than these variables - confirm which is intended.
hist_plot_width <- 8
hist_plot_height <- 4

We recall that the model works over the ordered simplex:

\[ \text{OrdSimplex}_K = \{\mathbf{x} \in \mathbb{R}^K | 0 < x_1 < \ldots < x_K < 1, \sum_{i=1}^K x_i = 1 \} \]

and the full model is

\[ \begin{aligned} \mathbf{x} &\in \text{OrdSimplex}_4, \quad \pi(\mathbf{x}) \propto \text{Dirichlet}(2, 2, 2, 2) \\ \mathbf{y} &\sim \text{Multinomial}(10, \mathbf{x}) \end{aligned} \]

To generate data, we note that because the prior is symmetric over the unrestricted simplex, we can sample from the prior by taking a draw from the Dirichlet distribution and sorting it (if the prior were not symmetric, some form of rejection sampling would be necessary).

The code to generate datasets is below:

# Simulate one dataset from the ordered-simplex model.
#
# Arguments:
#   N           - number of multinomial trials.
#   K           - number of categories (length of x).
#   prior_alpha - scalar Dirichlet concentration, replicated K times.
#
# Returns a list with two elements:
#   variables - list(x = ...), the sorted simplex draw used as the true value;
#   generated - list(K, observed, prior_alpha), the data for the model.
generate_one_dataset <- function(N, K, prior_alpha = 1) {
  alpha_vec <- rep(prior_alpha, K)

  # Draw from the symmetric Dirichlet, then sort into ascending order
  simplex_draw <- MCMCpack::rdirichlet(1, alpha = alpha_vec)
  x <- sort(simplex_draw)

  # A single multinomial draw of N trials, flattened to an integer vector
  counts <- as.integer(rmultinom(1, size = N, prob = x))

  variables <- list(x = x)
  generated <- list(K = K, observed = counts, prior_alpha = alpha_vec)
  list(variables = variables, generated = generated)
}

# Fix the RNG seed before simulating the datasets.
set.seed(56823974)
# Simulate 6000 datasets with N = 10 trials and K = 4 categories.
# NOTE(review): prior_alpha = 3 is used here, whereas the model description
# earlier in this document states Dirichlet(2, 2, 2, 2) - confirm which is intended.
ds_long <- generate_datasets(
    SBC_generator_function(generate_one_dataset, N = 10, K = 4, prior_alpha = 3),
    n_sims = 6000)

# The first 1000 simulations, used for most checks.
ds <- ds_long[1:1000]

We will use 1000 datasets (the ds variable) for most checks, but for detailed investigations, we’ll also use the ds_long version with 6000 datasets.

Additionally, we define derived test quantities for the log prior and the log likelihood:

# Log density of the Dirichlet distribution with parameter vector `alpha`
# evaluated at `x` (a vector of the same length as `alpha`).
# Returns a single numeric value.
log_ddirichlet <- function(x, alpha) {
  # Log of the normalising constant: lgamma(sum(alpha)) - sum(lgamma(alpha))
  log_norm_const <- lgamma(sum(alpha)) - sum(lgamma(alpha))
  # Log of the unnormalised kernel: sum((alpha_i - 1) * log(x_i))
  log_kernel <- sum((alpha - 1) * log(x))
  log_norm_const + log_kernel
}

# Derived test quantities passed to compute_SBC() via `dquants`:
#   log_lik   - multinomial log-likelihood of `observed` given `x`;
#   log_prior - Dirichlet log density of `x` under `prior_alpha`.
# NOTE(review): .globals presumably names objects that must be made available
# where the expressions are evaluated (e.g. parallel workers) - confirm in the
# SBC documentation.
dq <- derived_quantities(log_lik = dmultinom(observed, prob = x, log = TRUE),
                         log_prior = log_ddirichlet(x, prior_alpha), 
                         .globals = "log_ddirichlet")

We will not repeat the mathematical description of the individual variants; please refer to the paper.

Min

The Stan code for the min variant of the model is:

cat(readLines("stan/ordered_simplex_min.stan"), sep = "\n")


functions {
 //Input: vector of numbers constrained to [0,1]
 vector ordered_simplex_constrain_min_lp(vector u) {
    int Km1 = rows(u);
    vector[Km1 + 1] x;
    real remaining = 1; // Remaining amount to be distributed
    real base = 0; // The minimum for the next element
    for(i in 1:Km1) {
      if(u[i] <= 0 || u[i] >= 1) {
        reject("All elements of u have to be in [0,1]");
      }
      int K_prime = Km1 + 2 - i; // Number of remaining elements
      //First constrain to [0; remaining / K_prime]
      real x_cons = remaining * inv(K_prime) * u[i];
      // Jacobian for the constraint
      target += log(remaining) - log(K_prime);

      x[i] = base + x_cons;
      base = x[i];
      //We added  x_cons to each of the K_prime elements yet to be processed
      //remaining -= x_cons * K_prime;
      remaining *= 1 - u[i];
    }
    x[Km1 + 1] = base + remaining;

    return x;
 }
}
data {
  int K;
  array[K] int<lower=0> observed;
  vector<lower=0>[K] prior_alpha;
}



parameters {
  vector<lower=0, upper=1>[K - 1] u;
}

transformed parameters {
  simplex[K] x = ordered_simplex_constrain_min_lp(u);
}

model {
  x ~ dirichlet(prior_alpha);
  observed ~ multinomial(x);
}

Compile the model, build the backend

m_min <- cmdstan_model("stan/ordered_simplex_min.stan")
## Model executable is up to date!
backend_min <- SBC_backend_cmdstan_sample(m_min, chains = 2)

Run SBC

res_min <- compute_SBC(ds, backend_min, keep_fits = FALSE, dquants = dq,
                              cache_location = file.path(cache_dir, "ordered_simplex_min.rds"),
                              cache_mode = "results")
## Results loaded from cache file 'ordered_simplex_min.rds'
##  - 1 (0%) fits had at least one Rhat > 1.01. Largest Rhat was 1.012.
##  - 1000 (100%) fits had some steps rejected. Maximum number of rejections was 15.
## Not all diagnostics are OK.
## You can learn more by inspecting $default_diagnostics, $backend_diagnostics 
## and/or investigating $outputs/$messages/$warnings for detailed output from the backend.
plot_rank_hist(res_min)

plot_ecdf_diff(res_min)

We see that there are no problems apparent after 1000 simulations.

However, we also see in the plot below that the data are not very informative about any of the model parameters.

plot_sim_estimated(res_min, alpha = 0.2)

Softmax Bad

Now, we’ll test the incorrect version of the softmax approach. The Stan code is:

cat(readLines("stan/ordered_simplex_softmax_bad.stan"), sep = "\n")
functions {
  vector ordered_simplex_constrain_softmax_lp(vector v) {
     int K = size(v) + 1;
     vector[K] v0 = append_row(0, v);
     // Jacobian
     target += sum(v) - (K - 1) * log_sum_exp(v0);
     return softmax(v0);
  }
}

data {
  int K;
  array[K] int<lower=0> observed;
  vector<lower=0>[K] prior_alpha;
}


parameters {
  positive_ordered[K - 1] v;
}

transformed parameters {
  simplex[K] x =  ordered_simplex_constrain_softmax_lp(v);
}

model {
  x ~ dirichlet(prior_alpha);
  observed ~ multinomial(x);
}

Compile the model, build the backend

m_softmax_bad <- cmdstan_model("stan/ordered_simplex_softmax_bad.stan")
## Model executable is up to date!
backend_softmax_bad <- SBC_backend_cmdstan_sample(m_softmax_bad, chains = 2)

Run SBC (we’re using ds_long to show some long-run behaviour)

res_softmax_bad <- compute_SBC(ds_long, backend_softmax_bad, keep_fits = FALSE, dquants = dq,
                              cache_location = file.path(cache_dir, "ordered_simplex_softmax_bad.rds"),
                              cache_mode = "results")
## Results loaded from cache file 'ordered_simplex_softmax_bad.rds'
##  - 45 (1%) fits had at least one Rhat > 1.01. Largest Rhat was 1.018.
##  - 9 (0%) fits had divergent transitions. Maximum number of divergences was 1.
##  - 30 (0%) fits had some steps rejected. Maximum number of rejections was 1.
## Not all diagnostics are OK.
## You can learn more by inspecting $default_diagnostics, $backend_diagnostics 
## and/or investigating $outputs/$messages/$warnings for detailed output from the backend.
plot_rank_hist(res_softmax_bad[1:1000])

plot_ecdf_diff(res_softmax_bad[1:1000])

We see that the plot of true values vs. fitted posteriors is very similar to the correct case - any single simulation is likely to get an OK-ish recovery of the true parameters and so would be unlikely to discover the problem.

plot_sim_estimated(res_softmax_bad, alpha = 0.2)

Here we show the history of the gamma statistic for various quantities. Eventually all quantities detect the problem, but note the different horizontal axis between top row (quantities that detect the problem quickly) and bottom row (quantities that detect the problem slowly). The vertical red dashed line marks 400 simulations.

# Vertical reference line at 400 simulations, added to both panels.
shared_mark <- geom_vline(color = "red", linetype = "dashed", xintercept = 400)
# Common vertical range so the two panels are directly comparable.
ylim <-  c(-28, 5)

# Top row: quantities that detect the problem quickly (log_prior, x[1], x[3],
# x[4]), shown for the first 400 simulations.
# The index alternation is grouped and the closing bracket is matched
# explicitly; the ungrouped form "x\\[1|3|4" would match any variable name that
# merely contains "3" or "4".
plot_softmax_bad_quick <- plot_log_gamma_history(
  res_softmax_bad[1:400],
  variables_regex = "log_prior|x\\[(1|3|4)\\]",
  ylim = ylim
) + theme(axis.title = element_blank()) + shared_mark

# Bottom row: quantities that detect the problem slowly (log_lik, x[2]),
# shown for the first 3000 simulations.
plot_softmax_bad_slow <- plot_log_gamma_history(
  res_softmax_bad[1:3000],
  variables_regex = "log_lik|x\\[(2)\\]",
  ylim = ylim
) + theme(axis.title = element_blank()) + shared_mark

# A single shared y-axis title, drawn as a separate narrow text-only plot.
#axis title: https://stackoverflow.com/questions/65291723/merging-two-y-axes-titles-in-patchwork
p_label <- ggplot(data.frame(l = "Log Gamma - Threshold", x = 1, y = 1)) +
      geom_text(aes(x, y, label = l), angle = 90, size = 5) + 
      theme_void() +
      coord_cartesian(clip = "off")

# Label column on the left, the two history panels stacked on the right.
p_hist_softmax_bad <- p_label + (plot_softmax_bad_quick / plot_softmax_bad_slow) + plot_layout(widths = c(0.4, 25))
p_hist_softmax_bad

ggsave(file.path(fig_dir, "hist_softmax_bad.pdf"), p_hist_softmax_bad, width = 8, height = 3)

Softmax - corrected

Now, the correct version of the softmax approach. The Stan code is:

cat(readLines("stan/ordered_simplex_softmax.stan"), sep = "\n")
functions {
  vector ordered_simplex_constrain_softmax_lp(vector v) {
     int K = size(v) + 1;
     vector[K] v0 = append_row(0, v);
     // Jacobian
     target += sum(v) - K * log_sum_exp(v0);
     return softmax(v0);
  }
}

data {
  int K;
  array[K] int<lower=0> observed;
  vector<lower=0>[K] prior_alpha;
}


parameters {
  positive_ordered[K - 1] v;
}

transformed parameters {
  simplex[K] x =  ordered_simplex_constrain_softmax_lp(v);
}

model {
  x ~ dirichlet(prior_alpha);
  observed ~ multinomial(x);
}

(The only change is using K instead of K - 1 on line 6.)

Compile the model, build the backend

m_softmax <- cmdstan_model("stan/ordered_simplex_softmax.stan")
## Model executable is up to date!
backend_softmax <- SBC_backend_cmdstan_sample(m_softmax, chains = 2)

Run SBC

res_softmax <- compute_SBC(ds, backend_softmax, keep_fits = FALSE, dquants = dq,
                              cache_location = file.path(cache_dir, "ordered_simplex_softmax.rds"),
                              cache_mode = "results")
## Results loaded from cache file 'ordered_simplex_softmax.rds'
##  - 4 (0%) fits had at least one Rhat > 1.01. Largest Rhat was 1.012.
##  - 2 (0%) fits had divergent transitions. Maximum number of divergences was 1.
##  - 9 (1%) fits had some steps rejected. Maximum number of rejections was 1.
## Not all diagnostics are OK.
## You can learn more by inspecting $default_diagnostics, $backend_diagnostics 
## and/or investigating $outputs/$messages/$warnings for detailed output from the backend.
plot_rank_hist(res_softmax)

plot_ecdf_diff(res_softmax)

plot_sim_estimated(res_softmax, alpha = 0.2)

Gamma

Finally the gamma variant. The Stan code is:

cat(readLines("stan/ordered_simplex_gamma.stan"), sep = "\n")
data {
  int K;
  array[K] int<lower=0> observed;
  vector<lower=0>[K] prior_alpha;
}

parameters {
  positive_ordered[K] w;
}

transformed parameters {
  simplex[K] x = w / sum(w);
}

model {
  w ~ gamma(prior_alpha, 1);
  observed ~ multinomial(x);
}

Compile the model, build the backend

m_gamma <- cmdstan_model("stan/ordered_simplex_gamma.stan")
## Model executable is up to date!
backend_gamma <- SBC_backend_cmdstan_sample(m_gamma, chains = 2)

Run SBC

res_gamma <- compute_SBC(ds, backend_gamma, keep_fits = FALSE, dquants = dq,
                              cache_location = file.path(cache_dir, "ordered_simplex_gamma.rds"),
                              cache_mode = "results")
## Results loaded from cache file 'ordered_simplex_gamma.rds'
##  - 3 (0%) fits had at least one Rhat > 1.01. Largest Rhat was 1.013.
##  - 34 (3%) fits had some steps rejected. Maximum number of rejections was 1.
## Not all diagnostics are OK.
## You can learn more by inspecting $default_diagnostics, $backend_diagnostics 
## and/or investigating $outputs/$messages/$warnings for detailed output from the backend.
plot_rank_hist(res_gamma)

plot_ecdf_diff(res_gamma)

plot_sim_estimated(res_gamma, alpha = 0.2)

Performance

SBC gave us a simulation study for free, so let us examine some performance characteristics (for the correct implementations only):

# SBC results of the three correct implementations, keyed by variant name.
all_results <- list(
  softmax = res_softmax,
  min = res_min,
  gamma = res_gamma
)

# Build a per-simulation performance table for one SBC result.
#
# Arguments:
#   res     - an SBC result; its $stats, $backend_diagnostics and
#             $default_diagnostics tables are used.
#   variant - label stored in the `variant` column of the output.
#
# Returns the backend diagnostics joined with the default diagnostics and with
# `min_x_ess`, the smallest bulk ESS among variables whose name starts with "x".
perf_from_result <- function(res, variant) {
  # Smallest bulk ESS over the x[...] variables, one row per simulation
  x_stats <- filter(res$stats, grepl("^x", variable))
  min_ess_by_sim <- summarise(
    group_by(x_stats, sim_id),
    min_x_ess = min(ess_bulk)
  )

  with_defaults <- inner_join(
    res$backend_diagnostics,
    res$default_diagnostics,
    by = "sim_id"
  )
  combined <- inner_join(with_defaults, min_ess_by_sim, by = "sim_id")

  # The joins must neither drop nor reorder simulations
  stopifnot(identical(combined$sim_id, res$backend_diagnostics$sim_id))

  combined$variant <- variant
  combined
}

# One row per simulation and variant; ess_per_time is the minimum ESS over x
# divided by the maximum chain time.
performance_data <- mutate(
  imap_dfr(all_results, perf_from_result),
  ess_per_time = min_x_ess / max_chain_time
)
# Distribution of ESS per unit time, one panel per variant.
ggplot(performance_data, aes(x = ess_per_time)) +
  geom_histogram(bins = 50) +
  facet_wrap(~variant, ncol = 1)

# Per-variant summary: mean ESS per unit time and the share of fits with a
# high Rhat, with divergent transitions, or with either problem.
performance_data %>%
  group_by(variant) %>%
  mutate(
    high_rhat = max_rhat > 1.01,
    divergences = n_divergent > 0,
    non_converged = high_rhat | divergences
  ) %>%
  summarise(
    `Mean ESS per s` = mean(ess_per_time),
    `High Rhat` = scales::percent(mean(high_rhat), accuracy = 0.1),
    `Divergent transitions` = scales::percent(mean(divergences), accuracy = 0.1),
    `Any convergence problem` = scales::percent(mean(non_converged), accuracy = 0.1)
  )
## # A tibble: 3 × 5
##   variant `Mean ESS per s` `High Rhat` `Divergent transitions` `Any convergence problem`
##   <chr>              <dbl> <chr>       <chr>                   <chr>                    
## 1 gamma              6295. 0.3%        0.0%                    0.3%                     
## 2 min               10491. 0.1%        0.0%                    0.1%                     
## 3 softmax            4304. 0.4%        0.2%                    0.6%