MCMC ensemble convergence test

[1]:
import sys
import numpy as np
import arviz
import emcee
import warnings
import matplotlib.pyplot as plt

Running an ensemble

Let's take the tutorial example of running emcee:

[2]:
def log_prob(x, ivar):
    return -0.5 * np.sum(ivar * x ** 2)

ndim, nwalkers = 5, 32
ivar = 1. / np.random.rand(ndim)
p0 = np.random.randn(nwalkers, ndim)

nsteps = 10000
sampler = emcee.EnsembleSampler(nwalkers, ndim, log_prob, args=[ivar])
sampler.run_mcmc(p0, nsteps)

print("Chain shape:", sampler.get_chain().shape)
Chain shape: (10000, 32, 5)

There are two things we need to improve here:

  1. We want to remove some warm-up points from the beginning of each chain. Let's discard the first quarter.

  2. To reliably test the stationarity of a chain, we need several independent chains that should appear indistinguishable. Because the ensemble proposals entangle the walkers with each other, one ensemble is not enough; we need a few independently run ensembles. Four is usually enough.

[3]:
samplers = []

for i in range(4):
    sampler = emcee.EnsembleSampler(nwalkers, ndim, log_prob, args=[ivar])
    print("ensemble %d: warm-up ..." % (i+1))
    state = sampler.run_mcmc(p0, nsteps // 4)
    sampler.reset()
    print("ensemble %d: sampling ..." % (i+1))
    sampler.run_mcmc(state, nsteps)
    samplers.append(sampler)
ensemble 1: warm-up ...
ensemble 1: sampling ...
ensemble 2: warm-up ...
ensemble 2: sampling ...
ensemble 3: warm-up ...
ensemble 3: sampling ...
ensemble 4: warm-up ...
ensemble 4: sampling ...

Convergence testing

First, within each ensemble, we test that each walker has a short auto-correlation time. Secondly, we check the Geweke drift from the first quarter to the last quarter of the chain. These checks are done for each parameter.

[4]:
converged = True

# require chain to be at least 5 auto-correlation lengths long
min_autocorr_times = 5

# Geweke convergence test threshold
geweke_max = 1.0

# whether you already want some plots showing the issue
plot = False

for c, sampler in enumerate(samplers):
    print("looking for issues within chain %d ..." % (c+1))
    chain = sampler.get_chain()
    flat_chain = sampler.get_chain(flat=True)
    num_steps, num_walkers, ndim = chain.shape
    # 0. analyse each variable
    max_autocorrlength = 1
    for i in range(ndim):
        chain_variable = chain[:, :, i]
        # 1. treat each walker as a independent chain
        try:
            for w in range(num_walkers):
                chain_walker = chain_variable[:, w]
                with warnings.catch_warnings():
                    warnings.simplefilter("ignore")
                    tau = emcee.autocorr.integrated_time(chain_walker, c=5, tol=50, quiet=False)
                    tau = max(float(tau), 1)  # integrated_time returns an array; take a scalar
                max_autocorrlength = max(max_autocorrlength, tau)
                if len(chain_walker) / tau < min_autocorr_times:
                    print("autocorrelation is long for parameter %d: tau=%.1f -> only %dx lengths" % (i+1, tau, num_steps / tau))
                    converged = False
                    # you could plot chain_walker to visualise
                    break
        except emcee.autocorr.AutocorrError:
            max_autocorrlength = np.inf
            if min_autocorr_times > 0:
                print("autocorrelation is too long for parameter %d to be estimated" % (i+1))
                converged = False
                # you could plot chain_walker to visualise
                break

        if not converged:
            break

        # secondly, detect drift with geweke
        a = flat_chain[:len(flat_chain) // 4, i]
        b = flat_chain[-len(flat_chain) // 4:, i]
        geweke_z = (a.mean() - b.mean()) / (np.var(a) + np.var(b))**0.5
        if abs(geweke_z) > geweke_max:  # drift in either direction is a problem
            print("geweke drift (z=%.1f) detected for parameter %d" % (geweke_z, i+1))
            converged = False
            # you can plot histograms of a and b to visualise

looking for issues within chain 1 ...
looking for issues within chain 2 ...
looking for issues within chain 3 ...
looking for issues within chain 4 ...

The above is just a first smoke test: if these checks do not pass, you really are in trouble!

[5]:
chains = np.asarray([sampler.get_chain(flat=True) for sampler in samplers])

rhat = arviz.rhat(arviz.convert_to_dataset(chains)).x.data
[6]:
print("Rhat: %.2f" % rhat.max())
Rhat: 1.00

Interpreting the result

You can find out more about the Rhat rank test. As a rule of thumb:

If Rhat is below 1.01, then no convergence problem was detected.

If Rhat is higher, you need to run your ensembles longer.

And yes, I have seen badly incorrect posteriors with Rhat=1.1.
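To see what Rhat measures, here is a minimal hand-rolled sketch of the classic split-Rhat statistic (between-chain vs. within-chain variance), using only NumPy. The rank-normalized version in arviz is more robust and should be preferred in practice; the function name `split_rhat` below is ours, not a library API.

```python
import numpy as np

def split_rhat(chains):
    """Classic split-Rhat for chains of shape (n_chains, n_draws)."""
    # split each chain in half, so within-chain drift also inflates Rhat
    n = chains.shape[1] // 2
    halves = np.concatenate([chains[:, :n], chains[:, n:2*n]], axis=0)
    m, n = halves.shape
    chain_means = halves.mean(axis=1)
    W = halves.var(axis=1, ddof=1).mean()  # within-chain variance
    B = n * chain_means.var(ddof=1)        # between-chain variance
    var_est = (n - 1) / n * W + B / n      # pooled variance estimate
    return np.sqrt(var_est / W)

rng = np.random.default_rng(42)
good = rng.normal(size=(4, 1000))                # four well-mixed chains
bad = good + np.array([[0.], [0.], [0.], [2.]])  # one chain stuck elsewhere
print("well-mixed Rhat: %.2f" % split_rhat(good))
print("drifting   Rhat: %.2f" % split_rhat(bad))
```

The well-mixed chains give Rhat near 1, while the shifted chain pushes Rhat well above the 1.01 threshold, which is exactly the signal that the ensembles disagree and need to be run longer.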

Combining all posteriors

Now we can put all the ensembles together to obtain a very well-sampled posterior.

[7]:
full_chain = np.concatenate([sampler.get_chain(flat=True) for sampler in samplers])
[8]:
full_chain.shape
[8]:
(1280000, 5)

Automating this

The autoemcee package implements a wrapper for emcee (Affine-invariant Ensemble Sampling) and zeus (Ensemble Slice Sampling).

It keeps increasing the number of MCMC steps until no convergence issues are found. To be conservative, the ensembles are initialised and run separately.
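For our toy problem, usage would look roughly like the sketch below. This follows the autoemcee README (a loglikelihood plus a unit-cube prior transform, as in ultranest); check the autoemcee documentation for the exact class and method signatures before relying on it.

```python
import numpy as np
# requires `pip install autoemcee`; API per the autoemcee README
from autoemcee import ReactiveAffineInvariantSampler

paramnames = ["x%d" % i for i in range(5)]
ivar = 1.0 / np.random.rand(len(paramnames))

def loglike(x):
    # same Gaussian toy likelihood as above
    return -0.5 * np.sum(ivar * x ** 2)

def transform(cube):
    # map the unit cube to the prior range [-10, 10] per parameter
    return cube * 20 - 10

sampler = ReactiveAffineInvariantSampler(paramnames, loglike, transform=transform)
result = sampler.run()
sampler.print_results()
```

The sampler repeats the warm-up/sample/check cycle from this notebook automatically, doubling the number of steps until the auto-correlation, Geweke, and Rhat checks all pass.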