# Source code for ultranest.pathsampler

"""MCMC-like step sampling on a trajectory

These features are experimental.
"""

import matplotlib.pyplot as plt
import numpy as np

from ultranest.flatnuts import (ClockedBisectSampler, ClockedNUTSSampler,
                                ClockedStepSampler, DirectJumper,
                                IntervalJumper, SingleJumper)
from ultranest.samplingpath import (ContourSamplingPath, SamplingPath,
                                    extrapolate_ahead)
from ultranest.stepsampler import (StepSampler, generate_random_direction,
                                   generate_region_oriented_direction,
                                   generate_region_random_direction)


class SamplingPathSliceSampler(StepSampler):
    """Slice sampler, respecting the region, on the sampling path.

    This first builds up a complete trajectory, respecting reflections.
    Then, from the trajectory a new point is drawn with slice sampling.

    The trajectory is built by doubling the length to each side and
    checking if the point is still inside. If not, reflection is
    attempted with the gradient (either provided or region-based estimate).
    """

    def __init__(self, nsteps):
        """Initialise sampler.

        Parameters
        -----------
        nsteps: int
            number of accepted steps until the sample is considered independent.

        """
        StepSampler.__init__(self, nsteps=nsteps)
        # current slice as (left, right, last proposed index) in path index
        # units, or None when a new direction has to be chosen
        self.interval = None
        self.path = None

    def generate_direction(self, ui, region, scale=1):
        """Choose new initial direction according to region.transformLayer axes."""
        return generate_region_oriented_direction(ui, region, tscale=1, scale=scale)

    def adjust_accept(self, accepted, unew, pnew, Lnew, nc):
        """Adjust proposal given that we have been *accepted* at a new point after *nc* calls."""
        if accepted:
            # start with a new interval next time
            self.interval = None
            self.last = unew, Lnew
            self.history.append((unew, Lnew))
        else:
            # continue on the current interval; move() shrinks it
            self.nrejects += 1
        self.logstat.append([accepted, self.scale])

    def adjust_outside_region(self):
        """Adjust proposal given that we have stepped out of region."""
        self.logstat.append([False, self.scale])

    def move(self, ui, region, ndraw=1, plot=False):
        """Advance by slice sampling on the path.

        On the first call for a new direction, the interval of path indices
        is found by doubling outwards from *ui* until leaving the unit cube
        or the region (or exceeding the unit-cube diagonal). On later calls
        the previous proposal was rejected, so the interval is shrunk.

        Returns the proposed point, with shape (1, d).
        """
        def is_inside(x):
            """Whether *x* is inside the open unit cube and the region."""
            return (x > 0).all() and (x < 1).all() and region.inside(x.reshape((1, -1)))

        if self.interval is None:
            v = self.generate_direction(ui, region, scale=self.scale)
            self.path = ContourSamplingPath(SamplingPath(ui, v, 0.0), region)
            assert is_inside(ui), ui

            # unit hypercube diagonal gives a reasonable maximum path length
            maxlength = len(ui)**0.5

            # expand direction (by doubling) until it is surely outside
            left = -1
            right = +1
            while abs(left * self.scale) < maxlength:
                xj, _ = self.path.extrapolate(left)
                if not is_inside(xj):
                    break
                left *= 2

            while abs(right * self.scale) < maxlength:
                xj, _ = self.path.extrapolate(right)
                if not is_inside(xj):
                    break
                right *= 2

            # if the boundary is reached within a few steps on both sides,
            # the step scale is coarse compared to the slice: shrink it
            extent = max(-left, right)
            if extent < 5:
                self.scale /= 1.1
            assert self.scale > 1e-10, self.scale
            self.interval = (left, right, None)
        else:
            left, right, mid = self.interval
            # the previous proposal *mid* was rejected: shrink the side it
            # lies on, excluding *mid* so that it cannot be proposed again.
            # Index 0 (the starting point) always remains in [left, right].
            if mid is not None:
                if mid < 0:
                    left = mid + 1
                elif mid > 0:
                    right = mid - 1

        # draw uniformly from the interval; shrink towards the starting
        # point while the drawn point is outside the region
        while True:
            mid = np.random.randint(left, right + 1)
            if mid == 0:
                _, xj, _, _ = self.path.points[0]
            else:
                xj, _ = self.path.extrapolate(mid)

            if region.inside(xj.reshape((1, -1))):
                self.interval = (left, right, mid)
                return xj.reshape((1, -1))

            if mid < 0:
                left = mid + 1
            elif mid > 0:
                right = mid - 1
            else:
                # shrinking cannot help if the starting point itself is outside
                # (e.g. the region changed); fail instead of looping forever
                assert False, ("starting point is outside the region", xj)
            self.interval = (left, right, mid)
class SamplingPathStepSampler(StepSampler):
    """Step sampler on a sampling path."""

    def __init__(self, nresets, nsteps, scale=1.0, balance=0.01, nudge=1.1, log=False):
        """Initialise sampler.

        Parameters
        ------------
        nresets: int
            number of trajectories (directions) to walk before the
            current point is returned as a new sample
        nsteps: int
            number of steps to take along each trajectory before a
            new direction is selected
        scale: float
            initial step size
        balance: float
            acceptance rate to target.
            If the observed acceptance rate is below, scale is decreased,
            otherwise it is increased.
        nudge: float
            factor for increasing scale (must be >=1)
            nudge=1 implies no step size adaptation.
        log: bool
            whether to print diagnostic messages.

        """
        StepSampler.__init__(self, nsteps=nsteps)
        self.path = None
        self.nresets = nresets
        # initial step scale in transformed space
        self.scale = scale
        # target acceptance rate (see adjust_scale)
        self.balance = balance
        # relative increase in step scale
        self.nudge = nudge
        assert nudge >= 1
        self.log = log
        self.grad_function = None

        self.istep = 0
        self.iresets = 0
        self.start()
        self.terminate_path()
        self.logstat_labels = ['acceptance rate', 'reflection rate', 'scale', 'nstuck']

    def __str__(self):
        """Get string representation."""
        # balance is the targeted acceptance rate (see adjust_scale)
        return '%s(nsteps=%d, nresets=%d, AR=%d%%)' % (
            type(self).__name__, self.nsteps, self.nresets, self.balance * 100)

    def start(self):
        """Start sampler, reset all counters."""
        if hasattr(self, 'naccepts') and self.nrejects + self.naccepts > 0:
            # record the statistics of the finished sequence before resetting
            ntotal = self.nrejects + self.naccepts
            self.logstat.append([
                self.naccepts / ntotal,
                self.nreflects / (self.nreflects + ntotal),
                self.scale,
                self.nstuck])

        self.nrejects = 0
        self.naccepts = 0
        self.nreflects = 0
        self.nstuck = 0
        self.istep = 0
        self.iresets = 0
        self.noutside_regions = 0
        self.last = None, None
        self.history = []

        self.direction = +1
        self.deadends = set()
        self.path = None

    def start_path(self, ui, region):
        """Start new trajectory path."""
        v = self.generate_direction(ui, region, scale=self.scale)
        assert (v**2).sum() > 0, (v, self.scale)
        assert region.inside(ui.reshape((1, -1))).all(), ui
        self.path = ContourSamplingPath(SamplingPath(ui, v, 0.0), region)
        if self.grad_function is not None:
            self.path.gradient = self.grad_function

        if not (ui > 0).all() or not (ui < 1).all() or not region.inside(ui.reshape((1, -1))):
            assert False, ui

        self.direction = +1
        self.lasti = 0
        # path index -> (accepted, point, loglikelihood); index 0 is the start
        self.cache = {0: (True, ui, self.last[1])}
        self.deadends = set()
        if self.log:
            print()
            print("starting new direction", v, 'from', ui)

    def terminate_path(self):
        """Terminate current path, and reset path counting variable."""
        # check if we went anywhere:
        if -1 in self.deadends and +1 in self.deadends:
            self.nstuck += 1

        self.direction = +1
        self.deadends = set()
        self.path = None
        self.iresets += 1
        if self.log:
            print("reset %d" % self.iresets)

    def set_gradient(self, grad_function):
        """Set gradient function."""
        print("set gradient function to %s" % grad_function.__name__)

        def plot_gradient_wrapper(x, plot=False):
            """wrapper that makes plots (when desired)"""
            v = grad_function(x)
            if plot:
                plt.plot(x[0], x[1], '+ ', color='k', ms=10)
                plt.plot([x[0], v[0] * 1e-2 + x[0]], [x[1], v[1] * 1e-2 + x[1]], color='gray')
            return v
        self.grad_function = plot_gradient_wrapper

    def generate_direction(self, ui, region, scale):
        """Choose a random axis from region.transformLayer."""
        # alternative: generate_random_direction(ui, region, scale=scale)
        return generate_region_random_direction(ui, region, scale=scale)

    def adjust_accept(self, accepted, unew, pnew, Lnew, nc):
        """Adjust proposal given that we have been *accepted* at a new point after *nc* calls."""
        self.cache[self.nexti] = (accepted, unew, Lnew)
        if accepted:
            # start at new point next time
            self.lasti = self.nexti
            self.last = unew, Lnew
            self.history.append((unew, Lnew))
            self.naccepts += 1
        else:
            # continue on current point, do not update self.last
            self.nrejects += 1
            self.history.append((unew, Lnew))
        assert self.scale > 1e-10, (self.scale, self.istep, self.nrejects)

    def adjust_outside_region(self):
        """Adjust proposal given that we landed outside region."""
        self.noutside_regions += 1
        self.nrejects += 1

    def adjust_scale(self, maxlength):
        """Adjust scale, but not above maxlength.

        NOTE: the cap at *maxlength* is currently disabled; the scale is
        increased whenever the acceptance rate is at least *balance*.
        """
        assert len(self.history) > 1

        if self.naccepts < (self.nrejects + self.naccepts) * self.balance:
            if self.log:
                print("adjusting scale %f down: istep=%d inside=%d outside=%d region=%d nstuck=%d" % (
                    self.scale, len(self.history), self.naccepts, self.nrejects, self.noutside_regions, self.nstuck))
            self.scale /= self.nudge
        else:
            if self.log:
                print("adjusting scale %f up: istep=%d inside=%d outside=%d region=%d nstuck=%d" % (
                    self.scale, len(self.history), self.naccepts, self.nrejects, self.noutside_regions, self.nstuck))
            self.scale *= self.nudge
        assert self.scale > 1e-10, self.scale

    def movei(self, ui, region, ndraw=1, plot=False):
        """Make a move and return the proposed index."""
        if self.path is not None:
            if self.lasti - 1 in self.deadends and self.lasti + 1 in self.deadends:
                # stuck, cannot go anywhere. Stay.
                self.nexti = self.lasti
                return self.nexti

        if self.path is None:
            self.start_path(ui, region)

        assert not (self.lasti - 1 in self.deadends and self.lasti + 1 in self.deadends), \
            (self.deadends, self.lasti)
        # bounce back from a dead end
        if self.lasti + self.direction in self.deadends:
            self.direction *= -1

        self.nexti = self.lasti + self.direction
        return self.nexti

    def move(self, ui, region, ndraw=1, plot=False):
        """Advance move."""
        u, _ = self.get_point(self.movei(ui, region=region, ndraw=ndraw, plot=plot))
        return u.reshape((1, -1))

    def reflect(self, reflpoint, v, region, plot=False):
        """Reflect at *reflpoint* going in direction *v*. Return new direction."""
        normal = self.path.gradient(reflpoint, plot=plot)
        if normal is None:
            return -v
        return v - 2 * (normal * v).sum() * normal

    def get_point(self, inew):
        """Get point corresponding to index *inew*.

        Returns (position, direction): the stored point if *inew* was added
        to the path, otherwise the extrapolation along the path.
        """
        # path points are stored as (index, position, direction, loglikelihood)
        for i, u, v, _ in self.path.points:
            if i == inew:
                return u, v
        return self.path.extrapolate(inew)

    def __next__(self, region, Lmin, us, Ls, transform, loglike, ndraw=40, plot=False):
        """Get next point.

        Parameters
        ----------
        region: MLFriends
            region.
        Lmin: float
            loglikelihood threshold
        us: array of vectors
            current live points
        Ls: array of floats
            current live point likelihoods
        transform: function
            transform function
        loglike: function
            loglikelihood function
        ndraw: int
            number of draws to attempt simultaneously.
        plot: bool
            whether to produce debug plots.

        Returns
        -------
        (u, p, L, nc) once *nresets* trajectories have been walked,
        otherwise (None, None, None, nc); *nc* is the number of
        likelihood calls made in this invocation.
        """
        # find most recent point in history conforming to current Lmin
        ui, Li = self.last
        if Li is not None and not Li >= Lmin:
            if self.log:
                print("wandered out of L constraint; resetting", ui[0])
            ui, Li = None, None

        if Li is not None and not region.inside(ui.reshape((1, -1))):
            # region was updated and we are not inside anymore
            # so reset
            if self.log:
                print("region change; resetting")
            ui, Li = None, None

        if Li is None and self.history:
            # try to resume from a previous point above the current contour
            for uj, Lj in reversed(self.history):
                if Lj >= Lmin and region.inside(uj.reshape((1, -1))):
                    ui, Li = uj, Lj
                    if self.log:
                        print("recovered using history", ui)
                    # the current trajectory continues from the invalidated
                    # point; restart it from the recovered point instead
                    self.last = (ui, Li)
                    self.path = None
                    break

        # select starting point
        if Li is None:
            # choose a new random starting point
            mask = region.inside(us)
            assert mask.any(), (
                "None of the live points satisfies the current region!",
                region.maxradiussq, region.u, region.unormed, us)
            i = np.random.randint(mask.sum())
            self.starti = i
            ui = us[mask, :][i]
            if self.log:
                print("starting at", ui)
            assert np.logical_and(ui > 0, ui < 1).all(), ui
            Li = Ls[mask][i]
            self.start()
            self.history.append((ui, Li))
            self.last = (ui, Li)

        inew = self.movei(ui, region, ndraw=ndraw)
        if self.log:
            print("i: %d->%d (step %d)" % (self.lasti, inew, self.istep))
        _, uold, Lold = self.cache[self.lasti]
        if plot:
            plt.plot(uold[0], uold[1], 'd', color='brown', ms=4)

        # by default, stay at the current point
        uret, pret, Lret = uold, transform(uold), Lold

        nc = 0
        if inew != self.lasti:
            if inew not in self.cache:
                unew, _ = self.get_point(inew)
                if plot:
                    plt.plot(unew[0], unew[1], 'x', color='k', ms=4)
                accept = np.logical_and(unew > 0, unew < 1).all() and region.inside(unew.reshape((1, -1)))
                if accept:
                    if plot:
                        plt.plot(unew[0], unew[1], '+', color='orange', ms=4)
                    pnew = transform(unew)
                    Lnew = loglike(pnew.reshape((1, -1)))[0]
                    nc = 1
                else:
                    Lnew = -np.inf
                    if self.log:
                        print("outside region: ", unew, "from", ui)
                    self.deadends.add(inew)
                    self.adjust_outside_region()
            else:
                _, unew, Lnew = self.cache[self.nexti]

            if self.log:
                print(" suggested point:", unew)

            if Lnew >= Lmin:
                if self.log:
                    print(" -> inside.")
                if plot:
                    plt.plot(unew[0], unew[1], 'o', color='g', ms=4)
                # only transform points that are actually accepted
                pnew = transform(unew)
                self.adjust_accept(True, unew, pnew, Lnew, nc)
                uret, pret, Lret = unew, pnew, Lnew
            else:
                if plot:
                    plt.plot(unew[0], unew[1], '+', color='k', ms=2, alpha=0.3)
                if self.log:
                    print(" -> outside.")
                jump_successful = False
                if inew not in self.cache and inew not in self.deadends:
                    # first time we try to go beyond
                    # try to reflect:
                    reflpoint, v = self.get_point(inew)
                    if self.log:
                        print(" trying to reflect at", reflpoint)
                    self.nreflects += 1

                    # reflect in the direction of travel (backwards for inew < 0)
                    sign = -1 if inew < 0 else +1
                    vnew = self.reflect(reflpoint, v * sign, region=region) * sign

                    xk, vk = extrapolate_ahead(sign, reflpoint, vnew, contourpath=self.path)

                    if plot:
                        plt.plot([reflpoint[0], (-v + reflpoint)[0]], [reflpoint[1], (-v + reflpoint)[1]], '-', color='k', lw=0.5, alpha=0.5)
                        plt.plot([reflpoint[0], (vnew + reflpoint)[0]], [reflpoint[1], (vnew + reflpoint)[1]], '-', color='k', lw=1)

                    if self.log:
                        print(" trying", xk)
                    accept = np.logical_and(xk > 0, xk < 1).all() and region.inside(xk.reshape((1, -1)))
                    if accept:
                        pk = transform(xk)
                        Lk = loglike(pk.reshape((1, -1)))[0]
                        nc += 1
                        if Lk >= Lmin:
                            jump_successful = True
                            uret, pret, Lret = xk, pk, Lk
                            if self.log:
                                print("successful reflect!")
                            self.path.add(inew, xk, vk, Lk)
                            self.adjust_accept(True, xk, pk, Lk, nc)
                        else:
                            if self.log:
                                print("unsuccessful reflect")
                            self.adjust_accept(False, xk, pk, Lk, nc)
                    else:
                        if self.log:
                            print("unsuccessful reflect out of region")
                        self.adjust_outside_region()

                    if plot:
                        plt.plot(xk[0], xk[1], 'x', color='g' if jump_successful else 'r', ms=8)

                    if not jump_successful:
                        # unsuccessful. mark as deadend
                        self.deadends.add(inew)
                else:
                    self.adjust_accept(False, uret, pret, Lret, nc)

            assert inew in self.cache or inew in self.deadends, (inew in self.cache, inew in self.deadends)
        else:
            # stuck, proposal did not move us
            self.nstuck += 1
            self.adjust_accept(False, uret, pret, Lret, nc)

        # increase step count
        self.istep += 1
        if self.istep == self.nsteps:
            if self.log:
                print("triggering re-orientation")
            # reset path so we go in a new direction
            self.terminate_path()
            self.istep = 0

        # if had enough resets, return final point
        if self.iresets >= self.nresets:
            if self.log:
                print("walked %d paths; returning sample" % self.iresets)
            self.adjust_scale(maxlength=len(uret)**0.5)
            self.start()
            self.last = None, None
            return uret, pret, Lret, nc

        # do not have a independent sample yet
        return None, None, None, nc
class OtherSamplerProxy(object):
    """Proxy for ClockedSamplers."""

    def __init__(self, nnewdirections, sampler='steps', nsteps=0,
                 balance=0.9, scale=0.1, nudge=1.1, log=False):
        """Initialise sampler.

        Parameters
        -----------
        nnewdirections: int
            number of directions (trajectories) to walk before the
            current point is returned as an independent sample.
        sampler: str
            which sampler to use: 'steps', 'bisect' or 'nuts'
        nsteps:
            number of steps in sampler
        balance:
            acceptance rate to target
        scale:
            initial proposal scale
        nudge:
            adjustment factor for scale when acceptance rate is too low or high.
            must be >=1.
        log: bool
            whether to print diagnostic messages.

        """
        self.nsteps = nsteps
        self.samplername = sampler
        self.sampler = None
        self.stepper = None
        self.scale = scale
        self.nudge = nudge
        self.balance = balance
        self.log = log
        self.last = None, None
        self.ncalls = 0
        self.nnewdirections = nnewdirections
        self.nrestarts = 0
        self.nreflections = 0
        self.nreverses = 0
        self.nsteps_done = 0
        self.naccepts = 0
        self.nrejects = 0
        self.logstat = []
        self.logstat_labels = ['accepted', 'scale']

    def __str__(self):
        """Get string representation."""
        return 'Proxy[%s](%dx%d steps, AR=%d%%)' % (
            self.samplername, self.nnewdirections, self.nsteps, self.balance * 100)

    def accumulate_statistics(self):
        """Accumulate statistics at end of step sequence."""
        self.nreflections += self.sampler.nreflections
        self.nreverses += self.sampler.nreverses
        # range of path indices explored; compare indices only, so that
        # ties never fall through to comparing the stored arrays
        indices = [i for i, _, _, _ in self.sampler.points]
        self.nsteps_done += max(indices) - min(indices)
        self.naccepts += self.stepper.naccepts
        self.nrejects += self.stepper.nrejects
        if self.log:
            print("%2d direction encountered %2d accepts, %2d rejects" % (
                self.nrestarts, self.stepper.naccepts, self.stepper.nrejects))

    def adjust_scale(self, maxlength):
        """Adjust proposal scale, but not above maxlength.

        NOTE: the cap at *maxlength* is currently disabled; the scale is
        increased whenever the acceptance rate is at least *balance*.
        If no steps were recorded, nothing is adjusted.
        """
        log = self.log
        if log:
            print("%2d | %2d %2d %2d | %f" % (self.nrestarts,
                self.naccepts, self.nrejects, self.nreflections, self.scale))

        ntotal = self.naccepts + self.nrejects
        if ntotal == 0:
            # no acceptance statistics: nothing to adapt to (avoid 0/0)
            return
        self.logstat.append([self.naccepts / ntotal, self.scale])

        if self.naccepts < ntotal * self.balance:
            if log:
                print("adjusting scale %f down" % self.scale)
            self.scale /= self.nudge
        else:
            if log:
                print("adjusting scale %f up" % self.scale)
            self.scale *= self.nudge
        assert self.scale > 1e-10, self.scale

    def startup(self, region, us, Ls):
        """Choose a new random starting point."""
        if self.log:
            print("starting from scratch...")
        mask = region.inside(us)
        assert mask.any(), (
            "None of the live points satisfies the current region!",
            region.maxradiussq, region.u, region.unormed, us)
        i = np.random.randint(mask.sum())
        self.starti = i
        ui = us[mask, :][i]
        assert np.logical_and(ui > 0, ui < 1).all(), ui
        Li = Ls[mask][i]
        self.last = ui, Li

        self.ncalls = 0
        self.nrestarts = 0
        self.nreflections = 0
        self.nreverses = 0
        self.nsteps_done = 0
        self.naccepts = 0
        self.nrejects = 0
        self.sampler = None
        self.stepper = None

    def start_direction(self, region):
        """Choose a new random direction."""
        if self.log:
            print("choosing random direction")
        ui, Li = self.last
        v = generate_random_direction(ui, region, scale=self.scale)
        self.nrestarts += 1

        contourpath = ContourSamplingPath(SamplingPath(ui, v, Li), region)
        # sampler name -> (clocked sampler class, jumper class)
        choices = {
            'steps': (ClockedStepSampler, DirectJumper),
            'bisect': (ClockedBisectSampler, DirectJumper),
            'nuts': (ClockedNUTSSampler, IntervalJumper),
        }
        assert self.samplername in choices, ("unknown sampler", self.samplername)
        sampler_class, jumper_class = choices[self.samplername]
        self.sampler = sampler_class(contourpath, log=self.log)
        self.stepper = jumper_class(self.sampler, self.nsteps, log=self.log)

    def __next__(self, region, Lmin, us, Ls, transform, loglike, ndraw=40, plot=False):
        """Get next point.

        Parameters
        ----------
        region: MLFriends
            region.
        Lmin: float
            loglikelihood threshold
        us: array of vectors
            current live points
        Ls: array of floats
            current live point likelihoods
        transform: function
            transform function
        loglike: function
            loglikelihood function
        ndraw: int
            number of draws to attempt simultaneously.
        plot: bool
            whether to produce debug plots.

        Returns
        -------
        (u, p, L, ncalls) once *nnewdirections* directions were walked,
        otherwise (None, None, None, 0).
        """
        # find most recent point in history conforming to current Lmin
        ui, Li = self.last
        if Li is not None and not Li >= Lmin:
            ui, Li = None, None

        if Li is not None and not region.inside(ui.reshape((1, -1))):
            # region was updated and we are not inside anymore
            # so reset
            ui, Li = None, None

        if Li is None:
            self.startup(region, us, Ls)
        if self.sampler is None:
            self.start_direction(region)

        self.stepper.prepare_jump()
        Llast = None
        # path index -> True if the point is unusable, False if it is fine
        gaps = {}
        while True:
            if not self.sampler.is_done():
                u, is_independent = self.sampler.next(Llast=Llast)
                if not is_independent and u is not None:
                    # should evaluate point
                    Llast = None
                    if region.inside(u.reshape((1, -1))):
                        p = transform(u.reshape((1, -1)))
                        L = loglike(p)[0]
                        self.ncalls += 1
                        if L > Lmin:
                            Llast = L
                        else:
                            Llast = None
            else:
                u, i = self.stepper.check_gaps(gaps)
                if u is None:
                    unew, Lnew = self.stepper.make_jump(gaps)
                    break  # done!

                # check that u is allowed:
                assert i not in gaps
                gaps[i] = True
                if region.inside(u.reshape((1, -1))):
                    p = transform(u.reshape((1, -1)))
                    L = loglike(p)[0]
                    self.ncalls += 1
                    if L > Lmin:
                        # point is OK
                        gaps[i] = False
                        unew, Lnew = u, L
                        break

        assert np.isfinite(unew).all(), unew
        assert np.isfinite(Lnew).all(), Lnew

        self.accumulate_statistics()

        # forget sampler
        self.last = unew, Lnew
        self.sampler = None
        self.stepper = None
        # done, reset:
        if self.nrestarts >= self.nnewdirections:
            xnew = transform(unew)
            self.adjust_scale(maxlength=len(unew)**0.5)
            # forget as starting point
            self.last = None, None
            self.nrestarts = 0
            return unew, xnew, Lnew, self.ncalls
        else:
            return None, None, None, 0

    def plot(self, filename):
        """Plot sampler statistics."""
        if len(self.logstat) == 0:
            return

        parts = np.transpose(self.logstat)
        plt.figure(figsize=(10, 1 + 3 * len(parts)))
        for i, (label, part) in enumerate(zip(self.logstat_labels, parts)):
            plt.subplot(len(parts), 1, 1 + i)
            plt.ylabel(label)
            plt.plot(part)
            if np.min(part) > 0:
                plt.yscale('log')
        plt.savefig(filename, bbox_inches='tight')
        plt.close()