Source code for ultranest.viz

# noqa: D400 D205
"""
Live point visualisations
-------------------------

Gives a live impression of current exploration.
This is powerful because the user can abort partial runs if the fit
converges to unreasonable values.

"""

from __future__ import division, print_function

import shutil
import string
import sys
from xml.sax.saxutils import escape as html_escape

import numpy as np
from numpy import log10

clusteridstrings = ['%d' % i for i in range(10)] + list(string.ascii_uppercase) + list(string.ascii_lowercase)

spearman = None
try:
    import scipy.stats
    spearman = scipy.stats.spearmanr
except ImportError:
    pass


[docs] def round_parameterlimits(plo, phi, paramlimitguess=None): """Guess the current parameter range. Parameters ---------- plo: array of floats for each parameter, current minimum value phi: array of floats for each parameter, current maximum value paramlimitguess: array of float tuples for each parameter, guess of parameter range if available Returns ------- plo_rounded: array of floats for each parameter, rounded minimum value phi_rounded: array of floats for each parameter, rounded maximum value formats: array of float tuples for each parameter, string format for representing it. """ with np.errstate(divide='ignore'): expos = log10(np.abs([plo, phi])) expolo = np.floor(np.min(expos, axis=0)) expohi = np.ceil(np.max(expos, axis=0)) is_negative = plo < 0 plo_rounded = np.where(is_negative, -10**expohi, 0) phi_rounded = np.where(is_negative, 10**expohi, 10**expohi) if paramlimitguess is not None: for i, (plo_guess, phi_guess) in enumerate(paramlimitguess): # if plo_guess is higher than what we thought, we can increase to match if plo_guess <= plo[i] and plo_guess >= plo_rounded[i]: plo_rounded[i] = plo_guess if phi_guess >= phi[i] and phi_guess <= phi_rounded[i]: phi_rounded[i] = phi_guess formats = [] for i in range(len(plo)): fmt = '%+.1e' if -1 <= expolo[i] <= 2 and -1 <= expohi[i] <= 2: fmt = '%+.1f' if -4 <= expolo[i] <= 0 and -4 <= expohi[i] <= 0: fmt = '%%+.%df' % (max(0, -min(expolo[i], expohi[i]))) if phi[i] == plo[i]: fmt = '%+.1f' elif fmt % plo[i] == fmt % phi[i]: fmt = '%%+.%df' % (max(0, -int(np.floor(log10(phi[i] - plo[i]))))) formats.append(fmt) return plo_rounded, phi_rounded, formats
[docs] def nicelogger(points, info, region, transformLayer, region_fresh=False): """Log current live points and integration progress to stdout. Parameters ----------- points: dict with keys "u", "p", "logl" live points (u: cube coordinates, p: transformed coordinates, logl: loglikelihood values) info: dict integration information. Keys are: - paramlims (optional): parameter ranges - logvol: expected volume at this iteration region: MLFriends Current region. transformLayer: ScaleLayer or AffineLayer or MaxPrincipleGapAffineLayer Current transformLayer (for clustering information). region_fresh: bool Whether the region was just updated. """ p = points['p'] paramnames = info['paramnames'] # print() # print('lnZ = %.1f, remainder = %.1f, lnLike = %.1f | Efficiency: %d/%d = %.4f%%\r' % ( # logz, logz_remain, np.max(logl), ncall, it, it * 100 / ncall)) plo = p.min(axis=0) phi = p.max(axis=0) plo_rounded, phi_rounded, paramformats = round_parameterlimits(plo, phi, paramlimitguess=info.get('paramlims')) if sys.stderr.isatty() and hasattr(shutil, 'get_terminal_size'): columns, _ = shutil.get_terminal_size(fallback=(80, 25)) else: columns, _ = 80, 25 paramwidth = max([len(pname) for pname in paramnames]) width = columns - 23 - paramwidth width = max(width, 10) indices = ((p - plo_rounded) * width / (phi_rounded - plo_rounded).reshape((1, -1))).astype(int) indices[indices >= width] = width - 1 indices[indices < 0] = 0 ndim = len(plo) print() print() clusterids = transformLayer.clusterids % len(clusteridstrings) nmodes = transformLayer.nclusters print( "Mono-modal" if nmodes == 1 else "Have %d modes" % nmodes, "Volume: ~exp(%.2f)" % region.estimate_volume(), '*' if region_fresh else ' ', "Expected Volume: exp(%.2f)" % info['logvol'], '' if 'order_test_correlation' not in info else ("Quality: correlation length: %d (%s)" % (info['order_test_correlation'], '+' if info['order_test_direction'] >= 0 else '-')) if np.isfinite(info['order_test_correlation']) else "Quality: ok", ) if info.get('stepsampler_info', {}).get('num_logs', 0) > 0: stepsampler_info = dict(info['stepsampler_info']) stepsampler_info['frac_far_enough'] *= 100 if 'mean_distance' in stepsampler_info: print(( 'Step sampler performance: %(rejection_rate).1f rej/step, %(mean_nsteps)d steps/it, ' 'rel jump distance: %(mean_distance).2f (should be >1), %(frac_far_enough).2f%% (should be >50%%)') % stepsampler_info ) else: print() print() if ndim == 1: pass elif ndim == 2 and spearman is not None: rho, pval = spearman(p) if pval < 0.01 and abs(rho) > 0.75: print(" %s between %s and %s: rho=%.2f" % ( 'positive degeneracy' if rho > 0 else 'negative degeneracy', paramnames[0], paramnames[1], rho)) elif spearman is not None: rho, pval = spearman(p) if np.isfinite(pval).all() and pval.ndim == 2: for i, param in enumerate(paramnames): for j, param2 in enumerate(paramnames[:i]): if pval[i,j] < 0.01 and abs(rho[i,j]) > 0.99: s = 'positive relation' if rho[i,j] > 0 else 'negative relation' print(" perfect %s between %s and %s" % (s, param, param2)) elif pval[i,j] < 0.01 and abs(rho[i,j]) > 0.75: s = 'positive degeneracy' if rho[i,j] > 0 else 'negative degeneracy' print(" %s between %s and %s: rho=%.2f" % (s, param, param2, rho[i,j])) for i, (param, fmt) in enumerate(zip(paramnames, paramformats)): if nmodes == 1: line = [' ' for _ in range(width)] for j in np.unique(indices[:,i]): line[j] = '*' linestr = ''.join(line) else: line = [' ' for _ in range(width)] for clusterid, j in zip(clusterids, indices[:,i]): if clusterid > 0 and line[j] in (' ', '0'): # set it to correct cluster id line[j] = clusteridstrings[clusterid] elif clusterid == 0 and line[j] == ' ': # empty, so set it although we don't know the cluster id line[j] = '0' # else: # line[j] = '*' linestr = ''.join(line) line = linestr ilo, ihi = indices[:,i].min(), indices[:,i].max() if ilo > 10: assert line[:10] == ' ' * 10 leftstr = fmt % plo[i] j = ilo - 2 - len(leftstr) # left-bound if j < width and j > 0: line = line[:j] + leftstr + line[j + len(leftstr):] if ihi < width - 10: rightstr = fmt % phi[i] j = ihi + 3 # right-bound if j < width and j > 0: line = line[:j] + rightstr + line[j + len(rightstr):] parampadded = ('%%-%ds' % paramwidth) % param print('%s: %09s|%s|%9s' % (parampadded, fmt % plo_rounded[i], line, fmt % phi_rounded[i])) print()
[docs] def isnotebook(): """Check if running in a Jupyter notebook.""" try: shell = get_ipython().__class__.__name__ if shell == 'ZMQInteractiveShell': # noqa: SIM103 return True # Jupyter notebook or qtconsole elif shell == 'TerminalInteractiveShell': # noqa: SIM103 return False # Terminal running IPython else: return False # Other type (?) except NameError: return False # Probably standard Python interpreter
[docs] class LivePointsWidget: """ Widget for ipython and jupyter notebooks. Shows where the live points are currently in parameter space. """ def __init__(self): """Initialise. To draw, call .initialize().""" self.grid = None self.label = None self.laststatus = None
[docs] def initialize(self, paramnames, width): """Set up and display widget. Parameters ---------- paramnames: list of str Parameter names width: int number of html table columns. """ from IPython.display import display from ipywidgets import HTML, GridspecLayout, Layout, VBox grid = GridspecLayout(len(paramnames), width + 3) self.laststatus = [] for a, paramname in enumerate(paramnames): self.laststatus.append('*' * width) htmlcode = "<div style='background-color:#6E6BF4;'>&nbsp;</div>" for b in range(width): grid[a, b + 2] = HTML(htmlcode, layout=Layout(margin="0")) htmlcode = "<div style='background-color:#FFB858; font-weight:bold; padding-right: 2em;'>%s</div>" grid[a, 0] = HTML(htmlcode % html_escape(paramname), layout=Layout(margin="0")) grid[a, 1] = HTML("...", layout=Layout(margin="0")) grid[a,-1] = HTML("...", layout=Layout(margin="0")) self.grid = grid self.label = HTML() box = VBox(children=[self.label, grid]) display(box)
def __call__(self, points, info, region, transformLayer, region_fresh=False): """Update widget to show current live points and integration progress to stdout. Parameters ----------- points: dict with keys u, p, logl live points (u: cube coordinates, p: transformed coordinates, logl: loglikelihood values) info: dict integration information. Keys are: - paramlims (optional): parameter ranges - logvol: expected volume at this iteration region: MLFriends Current region. transformLayer: ScaleLayer or AffineLayer or MaxPrincipleGapAffineLayer Current transformLayer (for clustering information). region_fresh: bool Whether the region was just updated. """ # t = time.time() # if self.lastupdate is not None and self.lastupdate < t - 5: # return # self.lastupdate = t # u, p, logl = points['u'], points['p'], points['logl'] p = points['p'] paramnames = info['paramnames'] # print() # print('lnZ = %.1f, remainder = %.1f, lnLike = %.1f | Efficiency: %d/%d = %.4f%%\r' % ( # logz, logz_remain, np.max(logl), ncall, it, it * 100 / ncall)) plo = p.min(axis=0) phi = p.max(axis=0) plo_rounded, phi_rounded, paramformats = round_parameterlimits(plo, phi, paramlimitguess=info.get('paramlims')) width = 50 if self.grid is None: self.initialize(paramnames, width) with np.errstate(invalid="ignore"): indices = ((p - plo_rounded) * width / (phi_rounded - plo_rounded).reshape((1, -1))).astype(int) indices[indices >= width] = width - 1 indices[indices < 0] = 0 ndim = len(plo) clusterids = transformLayer.clusterids % len(clusteridstrings) nmodes = transformLayer.nclusters labeltext = ("Mono-modal" if nmodes == 1 else "Have %d modes" % nmodes) + \ (" | Volume: ~exp(%.2f) " % region.estimate_volume()) + ('*' if region_fresh else ' ') + \ " | Expected Volume: exp(%.2f)" % info['logvol'] + \ ('' if 'order_test_correlation' not in info else (" | Quality: correlation length: %d (%s)" % (info['order_test_correlation'], '+' if info['order_test_direction'] >= 0 else '-')) if np.isfinite(info['order_test_correlation']) else " | Quality: ok") if info.get('stepsampler_info', {}).get('num_logs', 0) > 0: stepsampler_info = dict(info['stepsampler_info']) stepsampler_info['frac_far_enough'] *= 100 if 'mean_distance' in stepsampler_info: labeltext += ( "<br/>" 'Step sampler performance: %(rejection_rate).1f%% rej/step, %(mean_nsteps)d steps/it' 'mean rel jump distance: %(mean_distance).2f (should be >1), %(frac_far_enough).2f%% (should be >50%%)' ) % stepsampler_info if ndim == 1: pass elif ndim == 2 and spearman is not None: rho, pval = spearman(p) if pval < 0.01 and abs(rho) > 0.75: labeltext += ("<br/> %s between %s and %s: rho=%.2f" % ( 'positive degeneracy' if rho > 0 else 'negative degeneracy', paramnames[0], paramnames[1], rho)) elif spearman is not None: rho, pval = spearman(p) for i, param in enumerate(paramnames): for j, param2 in enumerate(paramnames[:i]): if pval[i,j] < 0.01 and abs(rho[i,j]) > 0.99: labeltext += ("<br/> perfect %s between %s and %s" % ( 'positive relation' if rho[i,j] > 0 else 'negative relation', param2, param)) elif pval[i,j] < 0.01 and abs(rho[i,j]) > 0.75: labeltext += ("<br/> %s between %s and %s: rho=%.2f" % ( 'positive degeneracy' if rho[i,j] > 0 else 'negative degeneracy', param2, param, rho[i,j])) for i, (_param, fmt) in enumerate(zip(paramnames, paramformats)): if nmodes == 1: line = [' ' for _ in range(width)] for j in np.unique(indices[:,i]): line[j] = '*' linestr = ''.join(line) else: line = [' ' for _ in range(width)] for clusterid, j in zip(clusterids, indices[:,i]): if clusterid > 0 and line[j] in (' ', '0'): # set it to correct cluster id line[j] = clusteridstrings[clusterid] elif clusterid == 0 and line[j] == ' ': # empty, so set it although we don't know the cluster id line[j] = '0' # else: # line[j] = '*' linestr = ''.join(line) oldlinestr = self.laststatus[i] for j, (c, d) in enumerate(zip(linestr, oldlinestr)): if c != d: if c == ' ': self.grid[i, j + 2].value = "<div style='background-color:white;'>&nbsp;</div>" else: self.grid[i, j + 2].value = "<div style='background-color:#6E6BF4; font-family:monospace'>%s</div>" % c.replace('*', '&nbsp;') self.laststatus[i] = linestr # self.grid[i,0].value = param self.grid[i, 1].value = fmt % plo_rounded[i] self.grid[i,-1].value = fmt % phi_rounded[i] self.label.value = labeltext
[docs] def get_default_viz_callback(): """Get default callback. LivePointsWidget for Jupyter notebook, nicelogger otherwise. """ if isnotebook(): return LivePointsWidget() else: return nicelogger