MCMC ensemble convergence test
```python
import sys
import numpy as np
import arviz
import emcee
import warnings
import matplotlib.pyplot as plt
```
Running an ensemble
Let's take the example from the emcee tutorial and run it:
```python
# log-density of a Gaussian with zero mean and precision ivar in each dimension
def log_prob(x, ivar):
    return -0.5 * np.sum(ivar * x ** 2)

ndim, nwalkers = 5, 32
ivar = 1. / np.random.rand(ndim)
p0 = np.random.randn(nwalkers, ndim)
nsteps = 10000

sampler = emcee.EnsembleSampler(nwalkers, ndim, log_prob, args=[ivar])
sampler.run_mcmc(p0, nsteps)
print("Chain shape:", sampler.get_chain().shape)
```
```
Chain shape: (10000, 32, 5)
```
There are two things we need to improve here:
1. We want to remove some warm-up points from the beginning of each chain. Let's run a warm-up phase a quarter as long as the main run and throw it away.
2. To reliably test the stationarity of a chain, we need several independent chains that should appear indistinguishable. Because the ensemble proposals entangle the walkers with each other, one ensemble is not enough. We need a few independently run ensembles; four is usually enough.
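For the first point: if you have already run a single long chain, dropping the first quarter as warm-up is just slicing along the step axis. In the sketch below, random numbers stand in for `sampler.get_chain()` (an assumption for illustration only); with a real sampler, `sampler.get_chain(discard=nsteps // 4, flat=True)` should give the same result. The sketch also shows how flattening orders the samples, which matters whenever you slice quarters out of a flat chain.

```python
import numpy as np

# Stand-in for sampler.get_chain(): emcee stores chains as (nsteps, nwalkers, ndim).
# The values are random numbers, used purely to illustrate the array layout.
rng = np.random.default_rng(42)
nsteps, nwalkers, ndim = 1000, 32, 5
chain = rng.normal(size=(nsteps, nwalkers, ndim))

# Drop the first quarter of the steps as warm-up, for every walker at once.
burn = nsteps // 4
kept = chain[burn:]

# Flattening merges walkers into the step axis, step by step:
# rows 0..nwalkers-1 are step 0, the next nwalkers rows are step 1, and so on.
flat = kept.reshape(-1, ndim)

# Hence the first quarter of the flat rows is the first quarter of the steps.
n_quarter = len(kept) // 4
first_quarter = flat[: n_quarter * nwalkers]
assert np.array_equal(first_quarter, kept[:n_quarter].reshape(-1, ndim))
print(kept.shape, flat.shape)
```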
```python
samplers = []
for i in range(4):
    sampler = emcee.EnsembleSampler(nwalkers, ndim, log_prob, args=[ivar])
    # give every ensemble its own random starting positions
    p0 = np.random.randn(nwalkers, ndim)
    print("ensemble %d: warm-up ..." % (i + 1))
    # warm-up run of nsteps // 4 steps, which is then discarded
    state = sampler.run_mcmc(p0, nsteps // 4)
    sampler.reset()
    print("ensemble %d: sampling ..." % (i + 1))
    sampler.run_mcmc(state, nsteps)
    samplers.append(sampler)
```
```
ensemble 1: warm-up ...
ensemble 1: sampling ...
ensemble 2: warm-up ...
ensemble 2: sampling ...
ensemble 3: warm-up ...
ensemble 3: sampling ...
ensemble 4: warm-up ...
ensemble 4: sampling ...
```
Convergence testing
First, within each ensemble, we test that each walker has a short autocorrelation time compared to its length. Second, we check for Geweke drift between the first and last quarters of the chain. Both checks are done for each parameter.
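`emcee.autocorr.integrated_time` follows the approach described in emcee's autocorrelation tutorial: estimate the autocorrelation function with an FFT, sum it, and truncate the sum with Sokal's automatic window (the `c=5` argument). The sketch below re-implements that idea for a single 1-D series and adds the drift statistic used in the next cell, taking its absolute value so that drift in either direction counts. It is a simplification: emcee additionally averages the autocorrelation function over walkers and raises `AutocorrError` when the chain is shorter than `tol` times the estimate. A synthetic AR(1) series with coefficient 0.9 stands in for a walker; its true integrated autocorrelation time is (1 + 0.9) / (1 - 0.9) = 19.

```python
import numpy as np

def autocorr_1d(x):
    """Normalized autocorrelation function of a 1-D series, via zero-padded FFT."""
    x = np.asarray(x, dtype=float)
    n = len(x)
    size = 2 ** int(np.ceil(np.log2(2 * n)))  # padding to >= 2n avoids circular wrap-around
    f = np.fft.rfft(x - x.mean(), n=size)
    acf = np.fft.irfft(f * np.conjugate(f), n=size)[:n]
    return acf / acf[0]

def integrated_time(x, c=5.0):
    """Integrated autocorrelation time, truncated with Sokal's automatic window."""
    taus = 2.0 * np.cumsum(autocorr_1d(x)) - 1.0
    # Smallest window M with M >= c * tau(M); fall back to the full series.
    small = np.arange(len(taus)) < c * taus
    window = np.argmin(small) if not np.all(small) else len(taus) - 1
    return taus[window]

def geweke_drift(x):
    """Mean difference of the first and last quarters, in units of their spread."""
    q = len(x) // 4
    a, b = x[:q], x[-q:]
    return abs(a.mean() - b.mean()) / np.sqrt(a.var() + b.var())

# Synthetic AR(1) chain standing in for one walker.
rng = np.random.default_rng(0)
phi, n = 0.9, 100_000
noise = rng.normal(size=n)
x = np.empty(n)
x[0] = noise[0]
for t in range(1, n):
    x[t] = phi * x[t - 1] + noise[t]

tau = integrated_time(x)
print("tau = %.1f, chain is %.0f tau long" % (tau, n / tau))
print("drift, stationary chain: %.2f" % geweke_drift(x))
print("drift, with linear trend: %.2f" % geweke_drift(x + np.linspace(0, 10, n)))
```

Note that this drift statistic divides by the spread of the samples themselves, not by the standard error of the means as in the textbook Geweke z-score, so a threshold of 1.0 flags drift of about one posterior standard deviation.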
```python
converged = True
# require each chain to be at least 5 autocorrelation lengths long
min_autocorr_times = 5
# Geweke drift threshold (in units of the posterior standard deviation)
geweke_max = 1.0
# whether you already want some plots showing the issue (not used below)
plot = False

for c, sampler in enumerate(samplers):
    print("looking for issues within chain %d ..." % (c + 1))
    chain = sampler.get_chain()  # shape (num_steps, num_walkers, ndim)
    flat_chain = sampler.get_chain(flat=True)
    num_steps, num_walkers, ndim = chain.shape
    # 0. analyse each variable
    max_autocorrlength = 1
    for i in range(ndim):
        chain_variable = chain[:, :, i]
        # 1. treat each walker as an independent chain
        try:
            for w in range(num_walkers):
                chain_walker = chain_variable[:, w]
                with warnings.catch_warnings():
                    warnings.simplefilter("ignore")
                    # returns a one-element array; raises AutocorrError
                    # if the chain is shorter than tol * tau
                    tau = emcee.autocorr.integrated_time(chain_walker, c=5, tol=50, quiet=False)
                tau = max(float(tau[0]), 1.0)
                max_autocorrlength = max(max_autocorrlength, tau)
                if len(chain_walker) / tau < min_autocorr_times:
                    print("autocorrelation is long for parameter %d: tau=%.1f -> %dx lengths" % (i + 1, tau, num_steps / tau))
                    converged = False
                    # you could plot chain_walker to visualise
                    break
        except emcee.autocorr.AutocorrError:
            max_autocorrlength = np.inf
            if min_autocorr_times > 0:
                print("autocorrelation is too long for parameter %d to be estimated" % (i + 1))
                converged = False
                # you could plot chain_walker to visualise
                break
        if not converged:
            break

        # 2. detect drift between the first and last quarter (Geweke-style)
        a = flat_chain[:len(flat_chain) // 4, i]
        b = flat_chain[-(len(flat_chain) // 4):, i]
        geweke_z = (a.mean() - b.mean()) / (np.var(a) + np.var(b)) ** 0.5
        # drift in either direction is a problem
        if abs(geweke_z) > geweke_max:
            print("geweke drift (z=%.1f) detected for parameter %d" % (geweke_z, i + 1))
            converged = False
            # you can plot histograms of a and b to visualise
```
```
looking for issues within chain 1 ...
looking for issues within chain 2 ...
looking for issues within chain 3 ...
looking for issues within chain 4 ...
```
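These per-ensemble checks are only a first smoke test: if they fail, you really are in trouble! Passing them is not enough, though, so next we compare the four independently run ensembles against each other with the rank-normalized Rhat statistic from ArviZ.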
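```python
# one flat chain per ensemble: shape (4, nsteps * nwalkers, ndim)
chains = np.asarray([sampler.get_chain(flat=True) for sampler in samplers])
# ArviZ reads a NumPy array as (chain, draw, ...) and stores it as variable "x"
rhat = arviz.rhat(arviz.convert_to_dataset(chains)).x.data
```
```python
print("Rhat: %.2f" % rhat.max())
```
```
Rhat: 1.00
```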
Interpreting the result
You can find out more about the rank-normalized Rhat test in the `arviz.rhat` documentation. As a rule of thumb:
- If Rhat is below 1.01, no convergence problem was detected.
- If Rhat is higher, you need to run your ensembles longer.
And yes, I have seen badly wrong posteriors with Rhat=1.1.
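To make the rule of thumb less of a black box, the sketch below computes the classic split-Rhat for a single parameter with NumPy: it compares the variance between (half-)chains with the variance within them. It is a simplification of what ArviZ does by default, which additionally rank-normalizes and folds the draws, so its numbers will not match `arviz.rhat` exactly. The chains here are synthetic normal draws, not output from the ensembles above.

```python
import numpy as np

def split_rhat(chains):
    """Classic split-Rhat for one parameter; chains has shape (n_chains, n_draws)."""
    chains = np.asarray(chains, dtype=float)
    half = chains.shape[1] // 2
    # Split every chain into a first and a second half, so that drift within
    # a chain also shows up as disagreement between (half-)chains.
    split = np.concatenate([chains[:, :half], chains[:, half:2 * half]], axis=0)
    n = split.shape[1]
    within = split.var(axis=1, ddof=1).mean()       # W: mean within-chain variance
    between = n * split.mean(axis=1).var(ddof=1)    # B: between-chain variance
    var_plus = (n - 1) / n * within + between / n   # pooled variance estimate
    return np.sqrt(var_plus / within)

rng = np.random.default_rng(1)
good = rng.normal(size=(4, 5000))  # four chains sampling the same distribution
bad = good.copy()
bad[0] += 1.0                      # one chain stuck one standard deviation away
print("Rhat, agreeing chains: %.3f" % split_rhat(good))
print("Rhat, one shifted chain: %.3f" % split_rhat(bad))
```

Four well-mixed chains give a value very close to 1, while the single shifted chain pushes the statistic well above the 1.01 threshold.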
Combining all posteriors
Now we can put all the ensembles together into one large posterior sample.
```python
full_chain = np.concatenate([sampler.get_chain(flat=True) for sampler in samplers])
```
```python
full_chain.shape
```
```
(1280000, 5)
```
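Because `log_prob` encodes a Gaussian with zero mean and variance `1 / ivar` in each dimension, the combined sample of this toy problem can be sanity-checked directly against the target. In the sketch below, exact Gaussian draws stand in for the four ensembles so that it runs without emcee; with the objects from the cells above you would call `check_gaussian_target(full_chain, ivar)`. The helper name is hypothetical and not part of emcee or ArviZ.

```python
import numpy as np

def check_gaussian_target(samples, ivar):
    """Hypothetical helper: compare sample moments with the N(0, 1/ivar) target."""
    samples = np.asarray(samples)
    mean_offset = np.abs(samples.mean(axis=0))  # should be close to 0
    var_ratio = samples.var(axis=0) * ivar      # should be close to 1
    return mean_offset, var_ratio

rng = np.random.default_rng(2)
ndim = 5
ivar = 1.0 / rng.random(ndim)
# Stand-ins for the four ensembles: exact draws from the target, each of shape (20000, ndim).
ensembles = [rng.normal(scale=1.0 / np.sqrt(ivar), size=(20000, ndim)) for _ in range(4)]
full_chain = np.concatenate(ensembles)

mean_offset, var_ratio = check_gaussian_target(full_chain, ivar)
print(full_chain.shape)
print("max |mean|:", mean_offset.max())
print("variance ratios:", np.round(var_ratio, 3))
```

With real MCMC output, expect somewhat larger deviations than with independent draws: autocorrelation makes the effective number of samples smaller than the number of rows in `full_chain`.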
Automating this
The autoemcee package implements a wrapper for emcee (Affine-invariant Ensemble Sampling) and zeus (Ensemble Slice Sampling).
It keeps increasing the number of MCMC steps until no convergence issues are found. The ensembles are run and initialised separately to be conservative.
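The "keep running longer until converged" strategy can be sketched in a few lines. This is not autoemcee's code and does not use its API: to run with NumPy alone, it swaps in a plain random-walk Metropolis sampler (one independent chain per row, instead of an affine-invariant ensemble) and the classic non-split Gelman-Rubin Rhat. The function names, step size, starting spread, fixed `ivar` and `max_steps` are illustrative assumptions. It also simply reruns from scratch with twice as many steps; continuing the existing chains would be cheaper.

```python
import numpy as np

def log_prob(x, ivar):
    # Same Gaussian target as above, evaluated for a batch of points (rows).
    return -0.5 * np.sum(ivar * x ** 2, axis=-1)

def run_independent_chains(nsteps, ivar, n_chains=4, step=0.5, seed=0):
    """Random-walk Metropolis with one independent chain per row.

    Each chain starts from its own dispersed random point; the first quarter
    of the steps is discarded as warm-up. Returns (n_chains, kept, ndim).
    """
    rng = np.random.default_rng(seed)
    ndim = len(ivar)
    x = 5.0 * rng.normal(size=(n_chains, ndim))
    lp = log_prob(x, ivar)
    out = np.empty((n_chains, nsteps, ndim))
    for t in range(nsteps):
        prop = x + step * rng.normal(size=x.shape)
        lp_prop = log_prob(prop, ivar)
        accept = np.log(rng.random(n_chains)) < lp_prop - lp
        x[accept] = prop[accept]
        lp[accept] = lp_prop[accept]
        out[:, t] = x
    return out[:, nsteps // 4:]

def gelman_rubin(chains):
    """Classic (non-split) Rhat per parameter for shape (n_chains, n, ndim)."""
    n = chains.shape[1]
    within = chains.var(axis=1, ddof=1).mean(axis=0)
    between = n * chains.mean(axis=1).var(axis=0, ddof=1)
    return np.sqrt(((n - 1) / n * within + between / n) / within)

def run_until_converged(ivar, nsteps=1000, max_steps=256000, threshold=1.01):
    """Double the number of steps until Rhat is below threshold for every parameter."""
    while nsteps <= max_steps:
        chains = run_independent_chains(nsteps, ivar)
        rhat = gelman_rubin(chains).max()
        if rhat < threshold:
            return chains, nsteps, rhat
        nsteps *= 2
    raise RuntimeError("no convergence within max_steps")

ivar = np.array([1.0, 1.0, 2.0, 2.0, 4.0])
chains, nsteps, rhat = run_until_converged(ivar)
print("converged after %d steps per chain, Rhat = %.3f" % (nsteps, rhat))
```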