Source code for autoemcee

"""Calculates the Bayesian evidence and posterior samples of arbitrary monomodal models."""

from __future__ import print_function
from __future__ import division

import os
import sys
import logging
import warnings
import corner

import numpy as np

import emcee
import arviz

__all__ = ['ReactiveAffineInvariantSampler']
__author__ = """Johannes Buchner"""
__email__ = 'johannes.buchner.acad@gmx.com'
__version__ = '0.4.0'


# Some parts are from the nnest library by Adam Moss (https://github.com/adammoss/nnest)
def create_logger(module_name, log_dir=None, level=logging.INFO):
    """
    Set up the logging channel `module_name`.

    Append to ``debug.log`` in `log_dir` (if not ``None``).
    Write to stdout with output level `level`.

    If logging handlers are already registered, no new handlers are
    registered.
    """
    logger = logging.getLogger(str(module_name))
    logger.setLevel(logging.DEBUG)
    first_logger = logger.handlers == []
    if log_dir is not None and first_logger:
        # create file handler which logs even debug messages
        handler = logging.FileHandler(os.path.join(log_dir, 'debug.log'))
        formatter = logging.Formatter(
            '%(asctime)s [{}] [%(levelname)s] %(message)s'.format(module_name),
            datefmt='%H:%M:%S')
        handler.setFormatter(formatter)
        handler.setLevel(logging.DEBUG)
        logger.addHandler(handler)
    if first_logger:
        # if it is new, register to write to stdout
        handler = logging.StreamHandler(sys.stdout)
        handler.setLevel(level)
        formatter = logging.Formatter('[{}] %(message)s'.format(module_name))
        handler.setFormatter(formatter)
        logger.addHandler(handler)

    return logger
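
# Illustrative usage sketch (not part of the library): with `log_dir` set,
# DEBUG messages are appended to <log_dir>/debug.log, while stdout only shows
# messages at `level` (INFO by default):
#
#     logger = create_logger('mymodule', log_dir='.')
#     logger.info('shown on stdout and appended to ./debug.log')
#     logger.debug('appended to ./debug.log only')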


def vectorize(function):
    """Vectorize likelihood or prior_transform function."""
    def vectorized(args):
        """ vectorized version of function"""
        return np.asarray([function(arg) for arg in args])

    vectorized.__name__ = function.__name__
    return vectorized
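
# Example sketch (hypothetical function names): `vectorize` turns a function
# of a single parameter vector into one that maps an array of vectors to an
# array of results, the convention used internally when vectorized=False:
#
#     def loglike_scalar(theta):
#         return -0.5 * ((theta - 0.5)**2).sum()
#
#     loglike_vec = vectorize(loglike_scalar)
#     points = np.random.uniform(size=(10, 3))
#     loglike_vec(points).shape   # -> (10,)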


class ReactiveAffineInvariantSampler(object):
    """Emcee sampler with reactive exploration strategy."""

    def __init__(self,
                 param_names,
                 loglike,
                 transform=None,
                 num_test_samples=2,
                 vectorized=False,
                 sampler='goodman-weare',
                 ):
        """Initialise the sampler.

        Parameters
        ----------
        param_names: list of str
            Names of the parameters.
            Length gives dimensionality of the sampling problem.
        loglike: function
            log-likelihood function.
            Receives multiple parameter vectors, returns a vector of likelihoods.
        transform: function
            parameter transform from unit cube to physical parameters.
            Receives multiple cube vectors, returns multiple parameter vectors.
        num_test_samples: int
            test transform and likelihood with this number of random points
            for errors first. Useful to catch bugs.
        vectorized: bool
            if true, likelihood and transform receive arrays of points,
            and return arrays.
        sampler: str
            if 'goodman-weare': use Goodman & Weare's affine-invariant MCMC
            ensemble sampler.
            if 'slice': use Karamanis & Beutler (2020)'s ensemble slice sampler.
        """
        self.paramnames = param_names
        self.x_dim = len(self.paramnames)
        self.ncall = 0
        self.use_mpi = False
        self.sampler = sampler
        if sampler not in ('goodman-weare', 'slice'):
            raise ValueError("sampler needs to be one of ('goodman-weare', 'slice')")
        try:
            from mpi4py import MPI
            self.comm = MPI.COMM_WORLD
            self.mpi_size = self.comm.Get_size()
            self.mpi_rank = self.comm.Get_rank()
            if self.mpi_size > 1:
                self.use_mpi = True
                self._setup_distributed_seeds()
        except Exception:
            self.mpi_size = 1
            self.mpi_rank = 0

        self.log = self.mpi_rank == 0
        self.logger = create_logger('autoemcee')

        if not vectorized:
            loglike = vectorize(loglike)
            if transform is not None:
                transform = vectorize(transform)

        self._set_likelihood_function(transform, loglike, num_test_samples)

    def _set_likelihood_function(self, transform, loglike, num_test_samples, make_safe=False):
        """Store the transform and log-likelihood functions.

        Tests with `num_test_samples` random points whether they work
        and give output of the correct shape.

        If `make_safe` is set, make the functions safer by accepting
        misformed return shapes and non-finite likelihood values.
""" # do some checks on the likelihood function # this makes debugging easier by failing early with meaningful errors # test with num_test_samples random points u = np.random.uniform(size=(num_test_samples, self.x_dim)) p = transform(u) if transform is not None else u assert p.shape == (num_test_samples, self.x_dim,), ("Error in transform function: returned shape is %s, expected %s" % (p.shape, self.x_dim)) logl = loglike(p) assert np.logical_and(u > 0, u < 1).all(), ("Error in transform function: u was modified!") assert logl.shape == (num_test_samples,), ("Error in loglikelihood function: returned shape is %s, expected %s" % (logl.shape, num_test_samples)) assert np.isfinite(logl).all(), ("Error in loglikelihood function: returned non-finite number: %s for input u=%s p=%s" % (logl, u, p)) self.loglike = loglike if transform is None: self.transform = lambda x: x else: self.transform = transform def _setup_distributed_seeds(self): if not self.use_mpi: return seed = 0 if self.mpi_rank == 0: seed = np.random.randint(0, 1000000) seed = self.comm.bcast(seed, root=0) if self.mpi_rank > 0: # from http://arxiv.org/abs/1005.4117 seed = int(abs(((seed * 181) * ((self.mpi_rank - 83) * 359)) % 104729)) # print('setting seed:', self.mpi_rank, seed) np.random.seed(seed) def _emcee_logprob(self, u): mask = np.logical_and((u > 0).all(axis=1), (u < 1).all(axis=1)) L = -np.inf * np.ones(len(u)) p = self.transform(u[mask, :]) L[mask] = self.loglike(p) return L
    def find_starting_walkers(self, num_global_samples, num_walkers):
        """Draw `num_global_samples` points from the prior and return the
        `num_walkers` highest-likelihood ones."""
        assert num_global_samples > num_walkers, (num_global_samples, num_walkers)

        ndim, loglike, transform = self.x_dim, self.loglike, self.transform
        if self.log:
            self.logger.debug("global sampling for starting points ...")

        u = np.random.uniform(size=(num_global_samples, ndim))
        p = transform(u)
        L = loglike(p)

        # find the indices of the highest-likelihood points
        i = np.argsort(L)[::-1][:num_walkers]
        return u[i, :], p[i, :], L[i]
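
    # Worked mini-example (hypothetical values) for the top-k selection above:
    # np.argsort sorts ascending, [::-1] reverses to descending, [:k] keeps
    # the k highest-likelihood indices:
    #
    #     L = np.array([-3.0, -1.0, -2.0, -0.5])
    #     np.argsort(L)[::-1][:2]   # -> array([3, 1]): the two best points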
    def run(self,
            num_global_samples=10000,
            num_chains=4,
            num_walkers=None,
            max_ncalls=1000000,
            max_improvement_loops=4,
            num_initial_steps=100,
            min_autocorr_times=0,
            rhat_max=1.01,
            geweke_max=2.,
            progress=True):
        """Sample until the MCMC chains have converged.

        The steps are:

        1. Draw *num_global_samples* points from the prior.
           The *num_walkers* highest-likelihood points are selected.
        2. Set *num_steps* to *num_initial_steps*.
        3. Run *num_chains* MCMC ensembles for *num_steps* steps.
        4. For each walker chain, compute the auto-correlation length
           (convergence requires *num_steps* / autocorrelation length > *min_autocorr_times*).
        5. For each parameter, compute the Geweke convergence diagnostic
           (convergence requires |z| < *geweke_max*).
        6. For each ensemble, compute the Gelman-Rubin rank convergence
           diagnostic (convergence requires rhat < *rhat_max*).
        7. If converged, stop and return the results.
        8. Multiply *num_steps* by 10, and repeat from (3),
           up to *max_improvement_loops* times.

        Parameters
        ----------
        num_global_samples: int
            Number of samples to draw from the prior to find good
            starting points.
        num_chains: int
            Number of independent ensembles to run. If running with MPI,
            this is set to the number of MPI processes.
        num_walkers: int
            Ensemble size. If None, max(100, 4 * dim) is used.
        max_ncalls: int
            Maximum number of likelihood function evaluations.
        num_initial_steps: int
            Number of sampler steps to take in the first iteration.
        max_improvement_loops: int
            Number of times MCMC should be re-attempted (see above).
        min_autocorr_times: float
            If positive, additionally require for convergence that the
            number of samples is larger than *min_autocorr_times* times
            the autocorrelation length.
        geweke_max: float
            Maximum absolute z-score of the Geweke test allowed for convergence.
        rhat_max: float
            Maximum rhat allowed for convergence.
        progress: bool
            if True, show progress bars.
        """
        if num_walkers is None:
            num_walkers = max(100, 4 * self.x_dim)
        num_steps = num_initial_steps
        if self.use_mpi:
            num_chains = self.mpi_size
            num_chains_here = 1
        else:
            num_chains_here = num_chains

        if self.log:
            self.logger.info("finding starting points and running initial %d MCMC steps" % (num_steps))

        self.ncall = 0
        ncall_here = 0
        self.samplers = []
        for chain in range(num_chains_here):
            u, p, L = self.find_starting_walkers(num_global_samples, num_walkers)
            ncall_here += num_global_samples

            if self.sampler == 'goodman-weare':
                sampler = emcee.EnsembleSampler(
                    num_walkers, self.x_dim, self._emcee_logprob, vectorize=True)
            elif self.sampler == 'slice':
                import zeus
                sampler = zeus.EnsembleSampler(
                    nwalkers=num_walkers, ndim=self.x_dim,
                    logprob_fn=self._emcee_logprob,
                    vectorize=True, maxiter=1e10, maxsteps=1e10)

            self.samplers.append(sampler)
            sampler.run_mcmc(u, num_steps, progress=self.log and progress)
            ncall_here += num_walkers
            ncall_here += getattr(sampler, 'ncall', num_steps * num_walkers)

        if self.use_mpi:
            recv_ncall = self.comm.gather(ncall_here, root=0)
            ncall_here = sum(self.comm.bcast(recv_ncall, root=0))
        assert ncall_here > 0, ncall_here
        self.ncall += ncall_here

        for it in range(max_improvement_loops):
            if self.log:
                self.logger.debug("checking convergence (iteration %d) ..." % (it + 1))
            converged = True
            # check the state of the chains:
            for sampler in self.samplers:
                chain = sampler.get_chain()
                assert chain.shape == (num_steps, num_walkers, self.x_dim), (
                    chain.shape, (num_steps, num_walkers, self.x_dim))

                accepts = (chain[1:, :, :] != chain[:-1, :, :]).any(axis=2).sum(axis=0)
                assert accepts.shape == (num_walkers,)
                if self.log:
                    i = np.argsort(accepts)
                    self.logger.debug(
                        "acceptance rates: %s%% (worst few)",
                        (accepts[i[:8]] * 100. / (num_steps - 1)).astype(int))
                flat_chain = sampler.get_chain(flat=True)

                # diagnose this chain: analyse each parameter
                max_autocorrlength = 1
                for i in range(self.x_dim):
                    chain_variable = chain[:, :, i]
                    # first, treat each walker as an independent chain
                    try:
                        for w in range(num_walkers):
                            chain_walker = chain_variable[:, w]
                            with warnings.catch_warnings():
                                warnings.simplefilter("ignore")
                                tau = emcee.autocorr.integrated_time(chain_walker, c=5, tol=50, quiet=False)
                                tau = max(tau, 1)
                            max_autocorrlength = max(max_autocorrlength, tau)
                            if len(chain_walker) / tau < min_autocorr_times:
                                self.logger.debug("autocorrelation is long for parameter '%s': tau=%.1f -> %dx lengths" % (
                                    self.paramnames[i], tau, num_steps / tau))
                                converged = False
                                break
                    except emcee.autocorr.AutocorrError:
                        max_autocorrlength = np.inf
                        if min_autocorr_times > 0:
                            self.logger.debug("autocorrelation is too long for parameter '%s' to be estimated" % (
                                self.paramnames[i]))
                            converged = False
                            break

                    if not converged:
                        break

                    # secondly, detect drift with the Geweke diagnostic:
                    # compare the means of the first and last quarter of the chain
                    a = flat_chain[:len(flat_chain) // 4, i]
                    b = flat_chain[-len(flat_chain) // 4:, i]
                    geweke_z = (a.mean() - b.mean()) / (np.var(a) + np.var(b))**0.5
                    if abs(geweke_z) > geweke_max:
                        self.logger.debug("geweke drift (z=%.1f) detected for parameter '%s'" % (
                            geweke_z, self.paramnames[i]))
                        converged = False

                self.logger.debug("autocorrelation length: tau=%.1f -> %dx lengths" % (
                    max_autocorrlength, num_steps / max_autocorrlength))
                if not converged:
                    break

            # merge the converged flags across MPI processes
            if self.use_mpi:
                recv_converged = self.comm.gather(converged, root=0)
                converged = all(self.comm.bcast(recv_converged, root=0))

            if converged:
                # finally, the Gelman-Rubin diagnostic across the ensembles
                chains = np.asarray([sampler.get_chain(flat=True) for sampler in self.samplers])
                if self.use_mpi:
                    recv_chains = self.comm.gather(chains, root=0)
                    chains = np.concatenate(self.comm.bcast(recv_chains, root=0))
                assert chains.shape == (num_chains, num_steps * num_walkers, self.x_dim), (
                    chains.shape, (num_chains, num_steps * num_walkers, self.x_dim))

                rhat = arviz.rhat(arviz.convert_to_dataset(chains)).x.data
                if self.log:
                    self.logger.info("rhat chain diagnostic: %s (<%.3f is good)", rhat, rhat_max)
                converged = np.all(rhat < rhat_max)
                if self.use_mpi:
                    converged = self.comm.bcast(converged, root=0)

            if converged:
                if self.log:
                    self.logger.info("converged!!!")
                break

            if self.ncall > max_ncalls:
                if self.log:
                    self.logger.warning("maximum number of likelihood calls reached")
                break

            if self.log:
                self.logger.info("not converged yet at iteration %d after %d evals" % (it + 1, self.ncall))

            last_num_steps = num_steps
            num_steps = int(last_num_steps * 10)

            next_ncalls = ncall_here * 10
            if next_ncalls > max_ncalls:
                if self.log:
                    self.logger.warning("would need more likelihood calls (%d) than the maximum (%d) for the next step" % (
                        next_ncalls, max_ncalls))
                break
            self.logger.debug("expected memory usage: %.2f GiB" % (
                num_chains * num_steps * num_walkers * self.x_dim * 4 / 1024**3))
            if num_chains * num_steps * num_walkers * self.x_dim * 4 >= 5 * 1024**3:
                if self.log:
                    self.logger.warning("would need too much memory for the next step")
                break

            if self.log:
                self.logger.info("Running %d MCMC steps ..." % (num_steps))
            ncall_here = 0
            for sampler in self.samplers:
                last_samples = sampler.get_chain()[-1, :, :]
                # get a scale small compared to the width of the current posterior
                std = np.clip(last_samples.std(axis=0) / (num_walkers * self.x_dim), a_min=1e-30, a_max=1e-1)
                assert std.shape == (self.x_dim,), std.shape

                # restart all walkers from a single point of the last chain
                i = np.ones(num_walkers, dtype=int) * np.random.randint(0, len(last_samples))
                self.logger.info("Starting points chosen: %s, L=%.1f", set(i), L.max())
                # select points
                u = np.clip(last_samples[i, :], 1e-10, 1 - 1e-10)
                # add a bit of noise
                noise = np.random.normal(0, std, size=(num_walkers, self.x_dim))
                u = u + noise
                # avoid the borders of the unit cube
                u = np.clip(u, 1e-10, 1 - 1e-10)
                assert u.shape == (num_walkers, self.x_dim), (u.shape, (num_walkers, self.x_dim))
                if self.log:
                    self.logger.info("Starting at %s +- %s", u.mean(axis=0), u.std(axis=0))

                # burn-in run of the previous length ...
                sampler.reset()
                sampler.run_mcmc(u, last_num_steps, progress=self.log)
                ncall_here += num_walkers
                ncall_here += getattr(sampler, 'ncall', last_num_steps * num_walkers)
                assert ncall_here > 0, ncall_here
                last_samples = sampler.get_chain()[-1, :, :]
                assert last_samples.shape == (num_walkers, self.x_dim), (
                    last_samples.shape, (num_walkers, self.x_dim))

                # ... then the actual run, ten times longer than before
                sampler.reset()
                sampler.run_mcmc(last_samples, num_steps, progress=self.log and progress)
                ncall_here += num_walkers
                ncall_here += getattr(sampler, 'ncall', num_steps * num_walkers)
                assert ncall_here > 0, (ncall_here, getattr(sampler, 'ncall', num_steps * num_walkers))

            if self.use_mpi:
                recv_ncall = self.comm.gather(ncall_here, root=0)
                ncall_here = sum(self.comm.bcast(recv_ncall, root=0))
            if self.log:
                self.logger.info("Used %d calls in the last MCMC run", ncall_here)
            self.ncall += ncall_here

        if self.transform is None:
            eqsamples = np.concatenate([sampler.get_chain(flat=True) for sampler in self.samplers])
        else:
            eqsamples = np.concatenate([self.transform(sampler.get_chain(flat=True)) for sampler in self.samplers])

        if self.use_mpi:
            recv_eqsamples = self.comm.gather(eqsamples, root=0)
            eqsamples = np.concatenate(self.comm.bcast(recv_eqsamples, root=0))

        self.results = dict(
            paramnames=self.paramnames,
            posterior=dict(
                mean=eqsamples.mean(axis=0).tolist(),
                stdev=eqsamples.std(axis=0).tolist(),
                median=np.percentile(eqsamples, 50, axis=0).tolist(),
                errlo=np.percentile(eqsamples, 15.8655, axis=0).tolist(),
                errup=np.percentile(eqsamples, 84.1345, axis=0).tolist(),
            ),
            samples=eqsamples,
            ncall=int(self.ncall),
            converged=int(converged),
        )
        return self.results
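
    # Minimal end-to-end usage sketch (hypothetical model; the parameter
    # names, likelihood and transform below are assumptions for illustration):
    #
    #     def my_loglike(theta):
    #         return -0.5 * (((theta - 0.5) / 0.01)**2).sum()
    #
    #     def my_transform(cube):
    #         return cube * 2 - 1   # uniform(-1, 1) priors
    #
    #     sampler = ReactiveAffineInvariantSampler(
    #         ['a', 'b'], my_loglike, transform=my_transform)
    #     results = sampler.run()
    #     sampler.print_results()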
    def print_results(self):
        """Print a summary of the parameter constraints."""
        if self.log:
            print()
            for i, p in enumerate(self.paramnames):
                v = self.results['samples'][:, i]
                sigma = v.std()
                mean = v.mean()
                # choose the number of digits from the uncertainty
                if sigma == 0:
                    ndigits = 3
                else:
                    ndigits = max(0, int(-np.floor(np.log10(sigma))) + 1)
                fmt = '%%.%df' % ndigits
                fmts = '    %-20s' + fmt + " +- " + fmt
                print(fmts % (p, mean, sigma))
    def plot(self, **kwargs):
        """Make a corner plot of the posterior samples."""
        if self.log:
            corner.corner(
                self.results['samples'],
                labels=self.results['paramnames'],
                show_titles=True)
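
# Usage note: plot() draws a corner plot of the posterior samples (on the
# rank-0 process only when running under MPI). A sketch, assuming matplotlib
# is available to save the current figure:
#
#     sampler.plot()
#     import matplotlib.pyplot as plt
#     plt.savefig('corner.pdf')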