Specifying priors

This tutorial demonstrates how to specify parameter priors, including:

  • uniform and log-uniform distributions

  • Gaussian and more complicated distributions

  • multi-dimensional priors (not factorized)

  • non-analytic priors

  • priors on fractions

[1]:
import numpy as np
import scipy.stats
import matplotlib.pyplot as plt
%matplotlib inline

Cumulative prior distributions

Any one-dimensional probability distribution is normalised so that its density integrates to 1. Equivalently, its cumulative distribution increases from 0 to 1. For example:

[2]:
gaussdistribution = scipy.stats.norm(2, 0.3)
uniformdistribution = scipy.stats.uniform(3.5, 1.2)
x = np.linspace(0, 5, 400)
plt.figure()
plt.plot(x, gaussdistribution.pdf(x), '--', label='density (Gauss)')
plt.plot(x, gaussdistribution.cdf(x), label='cumulative (Gauss)')
plt.plot(x, uniformdistribution.pdf(x), '--', label='density (uniform)')
plt.plot(x, uniformdistribution.cdf(x), label='cumulative (uniform)')
plt.ylabel('Probability')
plt.xlabel('Model parameter x')
plt.legend();
_images/priors_3_1.svg
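
As a quick numerical check of this normalisation (a sketch added here; scipy.integrate.quad is not otherwise used in this notebook), we can integrate both densities and evaluate the cumulative distributions at the ends of their ranges. The Gaussian is integrated over its mean plus or minus 20 standard deviations, which contains all but a negligible fraction of its mass:

```python
import numpy as np
import scipy.integrate
import scipy.stats

# Same distributions as in the cell above
gaussdistribution = scipy.stats.norm(2, 0.3)
uniformdistribution = scipy.stats.uniform(3.5, 1.2)

# Integrate the densities; the uniform density is non-zero only on [3.5, 4.7]
gauss_integral, _ = scipy.integrate.quad(gaussdistribution.pdf, 2 - 20 * 0.3, 2 + 20 * 0.3)
uniform_integral, _ = scipy.integrate.quad(uniformdistribution.pdf, 3.5, 4.7)
print(gauss_integral, uniform_integral)  # both very close to 1

# The cumulative distributions run from 0 (far left) to 1 (far right)
print(gaussdistribution.cdf(-10.0), gaussdistribution.cdf(10.0))
print(uniformdistribution.cdf(3.5), uniformdistribution.cdf(4.7))
```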

Transforming from the unit interval

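To draw samples, we invert the cumulative distribution: the inverse maps a quantile in the unit interval (0 to 1) to the corresponding model parameter value.

Let's start with the uniform distribution, whose cumulative distribution is linear and therefore easy to invert by hand.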

[3]:
def transform_1d_uniform(quantile):
    lower_bound = 3.5
    width = 1.2
    return quantile * width + lower_bound
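
For the uniform distribution on [3.5, 4.7], the cumulative distribution is F(x) = (x - 3.5) / 1.2, so its inverse is 3.5 + 1.2 q. A minimal check, assuming the uniformdistribution object from the first cell, that the hand-written transform agrees with scipy's inverse:

```python
import numpy as np
import scipy.stats

uniformdistribution = scipy.stats.uniform(3.5, 1.2)

def transform_1d_uniform(quantile):
    lower_bound = 3.5
    width = 1.2
    return quantile * width + lower_bound

# Compare the hand-written inverse with scipy's inverse cumulative distribution (.ppf)
quantiles = np.linspace(0, 1, 11)
hand_written = transform_1d_uniform(quantiles)
from_scipy = uniformdistribution.ppf(quantiles)
print(np.allclose(hand_written, from_scipy))  # True
```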

For other distributions, scipy.stats provides the inverse cumulative distribution as the .ppf (percent point function) method:

[4]:
def transform_1d(quantile):
    return gaussdistribution.ppf(quantile)
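
The .ppf method is the inverse of .cdf, which we can confirm by mapping a few quantiles to parameter values and back. The sketch below also shows what happens at the very ends of the unit interval for an unbounded distribution such as the Gaussian:

```python
import numpy as np
import scipy.stats

gaussdistribution = scipy.stats.norm(2, 0.3)

def transform_1d(quantile):
    return gaussdistribution.ppf(quantile)

# quantile -> parameter value -> quantile recovers the input
quantiles = np.array([0.01, 0.1, 0.5, 0.9, 0.99])
x = transform_1d(quantiles)
print(np.allclose(gaussdistribution.cdf(x), quantiles))  # True
print(x[2])  # the median, which for a Gaussian is the mean 2.0

# The Gaussian is unbounded, so quantiles of exactly 0 and 1 map to -inf and +inf
end_points = transform_1d(np.array([0.0, 1.0]))
print(end_points)
```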

UltraNest draws points uniformly from the unit interval and passes them through the transform to obtain prior samples. Let's try this with our two functions:

[5]:
uniform_samples = transform_1d_uniform(np.random.uniform(0, 1, size=100000))
gauss_samples = transform_1d(np.random.uniform(0, 1, size=100000))
plt.figure()
plt.hist(uniform_samples, bins=20, histtype='step', density=True, label='Uniform')
plt.hist(gauss_samples, bins=100, histtype='step', density=True, label='Gauss')
plt.xlabel('Model parameter x')
plt.ylabel('Density')
plt.legend();
_images/priors_9_1.svg

Beautiful! We obtained nice samples that follow the prior distribution.

The unit hypercube

Let's specify a prior for UltraNest with three parameters:

  • a uniform distribution from 3.5 to 4.7

  • a log-uniform distribution from 0.01 to 100

  • a Gaussian distribution with mean 2.0 and standard deviation 0.3

[6]:
# our transform function receives one quantile for each of the three parameters
def transform(quantile_cube):
    # prepare the output array, which has the same shape as the input
    transformed_parameters = np.empty_like(quantile_cube)
    # first parameter: a uniform distribution from 3.5 to 4.7
    transformed_parameters[0] = 3.5 + 1.2 * quantile_cube[0]
    # second parameter: a log-uniform distribution from 10^-2 to 10^2
    transformed_parameters[1] = 10**(-2 + 4 * quantile_cube[1])
    # third parameter: the Gaussian defined above, mean 2.0 and standard deviation 0.3
    transformed_parameters[2] = gaussdistribution.ppf(quantile_cube[2])

    return transformed_parameters
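
To see what this transform does, we can evaluate it at a few points of the unit cube. A sketch, with the Gaussian defined as gaussdistribution (mean 2.0, standard deviation 0.3) as in the earlier cell; the comparison with scipy.stats.loguniform is an extra cross-check that is not part of the original notebook:

```python
import numpy as np
import scipy.stats

gaussdistribution = scipy.stats.norm(2, 0.3)

def transform(quantile_cube):
    transformed_parameters = np.empty_like(quantile_cube)
    transformed_parameters[0] = 3.5 + 1.2 * quantile_cube[0]             # uniform 3.5 .. 4.7
    transformed_parameters[1] = 10**(-2 + 4 * quantile_cube[1])          # log-uniform 0.01 .. 100
    transformed_parameters[2] = gaussdistribution.ppf(quantile_cube[2])  # Gaussian 2.0 +- 0.3
    return transformed_parameters

# The centre of the cube maps to the middle of the uniform range,
# the geometric middle of the log-uniform range and the Gaussian mean
print(transform(np.array([0.5, 0.5, 0.5])))  # approximately [4.1, 1.0, 2.0]

# Equal steps in the quantile give equal steps in log10 of the second parameter
for q in [0.0, 0.25, 0.5, 0.75, 1.0]:
    print(q, transform(np.array([q, q, 0.5]))[1])  # approximately 0.01, 0.1, 1, 10, 100

# The log-uniform line is the inverse cumulative distribution of scipy.stats.loguniform
print(np.isclose(transform(np.array([0.3, 0.3, 0.5]))[1],
                 scipy.stats.loguniform(0.01, 100).ppf(0.3)))  # True
```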

Some recommendations:

  • scipy.stats provides many 1-d distributions that can be used like this, through their .ppf method.

  • avoid constructing scipy.stats distribution objects inside the transform, because this is slow: build them once outside, then only call their .ppf method inside the transform.
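
A rough sketch of the second recommendation, comparing a transform that rebuilds its distribution on every call with one that reuses a prebuilt object. The third variant, which scales scipy.special.ndtri (the inverse of the standard normal cumulative distribution) directly, is an additional option for Gaussians added here for illustration. The function names are hypothetical, and absolute timings depend on the machine:

```python
import timeit

import numpy as np
import scipy.special
import scipy.stats

quantiles = np.random.default_rng(1).uniform(size=1000)


# Slower pattern: a new scipy.stats object is constructed on every call
def transform_rebuild(q):
    return scipy.stats.norm(2.0, 0.3).ppf(q)


# Recommended pattern: build the distribution once, outside the transform
gaussdistribution = scipy.stats.norm(2.0, 0.3)


def transform_prebuilt(q):
    return gaussdistribution.ppf(q)


# Gaussian-only shortcut: shift and scale the standard normal inverse CDF
def transform_ndtri(q):
    return 2.0 + 0.3 * scipy.special.ndtri(q)


for f in (transform_rebuild, transform_prebuilt, transform_ndtri):
    # call each transform on a single quantile, as a non-vectorized transform would be called
    seconds = timeit.timeit(lambda: f(quantiles[0]), number=2000)
    print(f.__name__, f"{seconds:.4f} s for 2000 calls")
```

All three variants return the same values; only the cost per call differs.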

Dependent priors

In some cases, a previous experiment gives informative priors which we want to incorporate, and they may be inter-dependent. For example, consider a two-dimensional Gaussian prior distribution:

[7]:
means = np.array([2, 3])
cov = np.array([[1, 0.6], [0.6, 0.4]])
[8]:
rv = scipy.stats.multivariate_normal(means, cov)
x, y = np.linspace(-1, 5, 400), np.linspace(1.5, 5, 400)
X, Y = np.meshgrid(x, y)
Z = rv.pdf(np.transpose([X.flatten(), Y.flatten()])).reshape(X.shape)
plt.figure()
plt.title('Correlated prior')
plt.contourf(X, Y, Z, cmap='magma_r')
plt.xlabel('Parameter 1')
plt.ylabel('Parameter 2');
_images/priors_16_1.svg

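We recall:

  • Parameter 1 has a (marginal) cumulative distribution.

  • At each value of Parameter 1, Parameter 2 has a (conditional) cumulative distribution.

  • We can thus specify a dependent distribution using the unit hypercube: the first quantile is inverted through the marginal distribution of Parameter 1, the second through the conditional distribution of Parameter 2 at that value.

For a multivariate Gaussian there is an equivalent shortcut: convert the quantiles to independent standard normal variables, then rotate, scale and shift them so that they have the desired means and covariance.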

[9]:
# matrix square root of the covariance, built from the eigendecomposition of its inverse
a = np.linalg.inv(cov)
l, v = np.linalg.eigh(a)
rotation_matrix = np.dot(v, np.diag(1. / np.sqrt(l)))

def transform_correlated(quantiles):
    # sample an independent standard multivariate Gaussian
    independent_gaussian = scipy.stats.norm.ppf(quantiles)
    # rotate, scale and shift; works for one point of shape (2,) or a batch of shape (n, 2)
    return means + independent_gaussian @ rotation_matrix.T
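
Why does this work? The rotation matrix is a matrix square root of the covariance, i.e. rotation_matrix @ rotation_matrix.T equals cov, and any matrix with this property maps independent standard normal variables to the desired correlated Gaussian. A sketch verifying this, plus a variant (transform_cholesky, a hypothetical name) that uses the Cholesky factor from numpy.linalg.cholesky instead of the eigendecomposition:

```python
import numpy as np
import scipy.stats

means = np.array([2, 3])
cov = np.array([[1, 0.6], [0.6, 0.4]])

# Rotation matrix as constructed in the cell above
l, v = np.linalg.eigh(np.linalg.inv(cov))
rotation_matrix = np.dot(v, np.diag(1. / np.sqrt(l)))
print(np.allclose(rotation_matrix @ rotation_matrix.T, cov))  # True

# The lower-triangular Cholesky factor is another matrix square root of cov
cholesky_factor = np.linalg.cholesky(cov)
print(np.allclose(cholesky_factor @ cholesky_factor.T, cov))  # True

def transform_cholesky(quantiles):
    # independent standard normals, then scale, correlate and shift them
    independent_gaussian = scipy.stats.norm.ppf(quantiles)
    return means + independent_gaussian @ cholesky_factor.T
```

With the lower-triangular Cholesky factor, Parameter 1 depends only on the first quantile and Parameter 2 on both, matching the conditional structure described above.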

Let's try sampling!

[10]:
samples = transform_correlated(np.random.uniform(0, 1, size=(100, 2)))

plt.figure()
plt.title('Correlated prior')
plt.contourf(X, Y, Z, cmap='magma_r')
plt.plot(samples[:,0], samples[:,1], 'o', mew=1, mfc='w', mec='k')
plt.xlabel('Parameter 1')
plt.ylabel('Parameter 2');

_images/priors_20_1.svg
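
The conditional construction from the bullet points above can also be written out explicitly for this bivariate Gaussian: Parameter 1 is drawn from its marginal distribution, and Parameter 2 from its conditional Gaussian given Parameter 1, whose mean and width follow from the standard formulas for conditioning a Gaussian. A sketch (transform_conditional is a hypothetical name), checked against the target means and covariance by sampling:

```python
import numpy as np
import scipy.stats

means = np.array([2, 3])
cov = np.array([[1, 0.6], [0.6, 0.4]])

# Conditional distribution of Parameter 2 given Parameter 1 = x1:
#   mean  = means[1] + cov[1, 0] / cov[0, 0] * (x1 - means[0])
#   sigma = sqrt(cov[1, 1] - cov[1, 0]**2 / cov[0, 0])
slope = cov[1, 0] / cov[0, 0]
conditional_sigma = np.sqrt(cov[1, 1] - cov[1, 0]**2 / cov[0, 0])

def transform_conditional(quantiles):
    parameters = np.empty_like(quantiles)
    # Parameter 1: invert its marginal cumulative distribution
    parameters[..., 0] = scipy.stats.norm.ppf(quantiles[..., 0], means[0], np.sqrt(cov[0, 0]))
    # Parameter 2: invert its conditional cumulative distribution at the value of Parameter 1
    conditional_mean = means[1] + slope * (parameters[..., 0] - means[0])
    parameters[..., 1] = scipy.stats.norm.ppf(quantiles[..., 1], conditional_mean, conditional_sigma)
    return parameters

samples = transform_conditional(np.random.default_rng(42).uniform(size=(100000, 2)))
print(samples.mean(axis=0))  # close to means
print(np.cov(samples.T))     # close to cov
```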

A similar effect can be achieved by defining the transforms in sequence, so that the second parameter depends on the first (note that this is a different prior):

[11]:
gauss1 = scipy.stats.norm(2, 1)
gauss2 = scipy.stats.norm(0, 0.1)


def transform_correlated(quantiles):
    parameters = np.empty_like(quantiles)
    # first parameter is independent
    parameters[0] = gauss1.ppf(quantiles[0])
    # second parameter depends on first parameter, here with a shift
    parameters[1] = parameters[0] + gauss2.ppf(quantiles[1])
    return parameters
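
To see how this sequential prior differs from the one before, one can work out its covariance. A sketch, renamed transform_sequential here (a hypothetical name) so that it does not shadow the earlier function, in which each parameter uses its own quantile:

```python
import numpy as np
import scipy.stats

gauss1 = scipy.stats.norm(2, 1)
gauss2 = scipy.stats.norm(0, 0.1)

def transform_sequential(quantiles):
    parameters = np.empty_like(quantiles)
    # first parameter is independent
    parameters[..., 0] = gauss1.ppf(quantiles[..., 0])
    # second parameter: first parameter plus independent N(0, 0.1) noise
    parameters[..., 1] = parameters[..., 0] + gauss2.ppf(quantiles[..., 1])
    return parameters

# Analytically: Var(p1) = 1, Cov(p1, p2) = Var(p1) = 1, Var(p2) = 1 + 0.1**2 = 1.01
implied_cov = np.array([[1.0, 1.0], [1.0, 1.01]])
samples = transform_sequential(np.random.default_rng(0).uniform(size=(100000, 2)))
print(np.cov(samples.T))  # close to implied_cov
```

Parameter 2 is Parameter 1 plus small independent noise, so the two are almost perfectly correlated (correlation 1/sqrt(1.01), about 0.995), more tightly than in the correlated Gaussian above.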

Non-analytic priors

Sometimes the prior cannot easily be inverted, for example when it is given as posterior samples from a previous analysis. As an illustration, consider samples from a mixture of a uniform and a Gaussian component:

[12]:
posterior_samples = np.hstack((np.random.uniform(0, 3, 2000), np.random.normal(3, 0.2, 2000)))

plt.figure(figsize=(4,2))
plt.hist(posterior_samples, histtype='step', bins=100);

_images/priors_24_1.svg

In this case, you can compute the cumulative distribution numerically and invert it:

[13]:
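hist, bin_edges = np.histogram(posterior_samples, bins=100)
# cumulative probability at each bin edge, starting from 0 at the left-most edge
hist_cumulative = np.concatenate(([0.], np.cumsum(hist / hist.sum())))

def transform_histogram(quantile):
    # invert the piecewise-linear cumulative distribution by interpolation
    return np.interp(quantile, hist_cumulative, bin_edges)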

samples = transform_histogram(np.random.uniform(size=1000))
plt.figure(figsize=(4,2))
plt.hist(posterior_samples, histtype='step', bins=100, density=True);
plt.hist(samples, histtype='step', bins=100, density=True);

_images/priors_26_1.svg
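
Another way to invert the distribution of the samples, without binning, is to use their empirical quantiles directly. A sketch with np.quantile, where transform_empirical is a hypothetical name and the samples are regenerated with a seeded generator so the cell is self-contained:

```python
import numpy as np

rng = np.random.default_rng(3)
# stand-in for posterior samples from a previous analysis (same mixture as above)
posterior_samples = np.hstack((rng.uniform(0, 3, 2000), rng.normal(3, 0.2, 2000)))

def transform_empirical(quantile):
    # np.quantile inverts the empirical cumulative distribution of the samples,
    # interpolating linearly between sorted samples (numpy's default method)
    return np.quantile(posterior_samples, quantile)

samples = transform_empirical(rng.uniform(size=1000))
```

Like the histogram version, this can only return values inside the range spanned by the samples. np.quantile partitions the sample array on every call, so for large sample sets it may be faster to sort once and use np.interp on the sorted values.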

Fraction parameters

Some parameters such as fractions (or abundances) may be required to sum to 1. How can we specify such parameters?

One option is to use absolute numbers. For example, instead of specifying the total mass and the mass fraction of each chemical element, parameterise the mass of each element directly; the total mass and the fractions are then derived quantities. This avoids the sum constraint altogether, and the masses may be easier to infer. A drawback is that the prior range for each element mass may need to be wide.
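
A minimal sketch of this first option, assuming (hypothetically) a log-uniform prior between 10^-3 and 10^3 on each of three element masses; the total mass and the fractions are computed from the sampled masses rather than sampled themselves:

```python
import numpy as np

def transform_masses(quantiles):
    # hypothetical prior: each element mass is log-uniform between 1e-3 and 1e3 (arbitrary units)
    return 10**(-3 + 6 * quantiles)

def derived_fractions(masses):
    # total mass and mass fractions as derived quantities; the fractions sum to 1 by construction
    total = masses.sum(axis=-1, keepdims=True)
    return total, masses / total

masses = transform_masses(np.array([0.5, 0.75, 0.25]))
total, fractions = derived_fractions(masses)
print(masses)           # approximately [1, 31.6, 0.0316]
print(fractions.sum())  # 1, up to rounding
```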

The other option is to use a distribution made for exactly this purpose: the Dirichlet distribution, whose samples are non-negative fractions that sum to 1. Here we assume a flat Dirichlet distribution (all concentration parameters \(\alpha=1\)), which is uniform over all combinations of fractions that sum to 1.

[14]:
def transform_dirichlet(quantiles):
    # https://en.wikipedia.org/wiki/Dirichlet_distribution#Random_number_generation
    # first inverse-transform sample from Gamma(alpha=1, beta=1), which is Exponential(1)
    gamma_quantiles = -np.log(quantiles)
    # normalise along the last axis to obtain the Dirichlet fractions
    # (works for one point of shape (3,) or a batch of shape (n, 3))
    return gamma_quantiles / gamma_quantiles.sum(axis=-1, keepdims=True)
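
The same recipe extends to non-flat Dirichlet priors: sample each component from Gamma(alpha_i, 1) by inverting its cumulative distribution, then normalise. A sketch using scipy.stats.gamma.ppf, with hypothetical concentration parameters alpha = (2, 5, 1) and a hypothetical function name:

```python
import numpy as np
import scipy.stats

# hypothetical concentration parameters; alpha = 1 for every component gives the flat case
alpha = np.array([2.0, 5.0, 1.0])

def transform_dirichlet_general(quantiles):
    # inverse-CDF samples from Gamma(alpha_i, 1), one per component (broadcast along the last axis)
    gamma_quantiles = scipy.stats.gamma.ppf(quantiles, alpha)
    # normalising independent Gamma variables gives Dirichlet(alpha) fractions
    return gamma_quantiles / gamma_quantiles.sum(axis=-1, keepdims=True)

samples = transform_dirichlet_general(np.random.default_rng(7).uniform(size=(100000, 3)))
print(samples.sum(axis=1)[:3])  # each row sums to 1
print(samples.mean(axis=0))     # close to alpha / alpha.sum() = [0.25, 0.625, 0.125]
```

For alpha = 1 the Gamma inverse reduces to -log(1 - q), so the flat case can use the cheaper closed form; -log(q) works equally well because q and 1 - q have the same uniform distribution.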

Let's have a look at the samples:

[15]:
samples = transform_dirichlet(np.random.uniform(0, 1, size=(400, 3)))

from mpl_toolkits.mplot3d import Axes3D
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.view_init(elev=10., azim=-30)
ax.plot(samples[:,0], samples[:,1], samples[:,2], 'x ');
_images/priors_31_1.svg

The samples nicely lie on the plane where the sum is 1.
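
We can also check this numerically, and look at a single fraction. For a flat Dirichlet with K components, each individual fraction follows a Beta(1, K - 1) distribution, so the prior is flat over the allowed combinations but not flat for any one fraction on its own. A sketch for K = 3, repeating the flat transform so the cell runs on its own:

```python
import numpy as np
import scipy.stats

def transform_dirichlet(quantiles):
    # Exponential(1) draws by inverse transform, normalised to sum to 1
    gamma_quantiles = -np.log(quantiles)
    return gamma_quantiles / gamma_quantiles.sum(axis=-1, keepdims=True)

samples = transform_dirichlet(np.random.default_rng(11).uniform(size=(100000, 3)))

# All points lie on the plane where the three fractions sum to 1
print(np.allclose(samples.sum(axis=1), 1))  # True

# Marginal of a single fraction: Beta(1, 2) for K = 3
marginal = scipy.stats.beta(1, 2)
print(samples[:, 0].mean(), marginal.mean())             # both close to 1/3
print(np.mean(samples[:, 0] < 0.5), marginal.cdf(0.5))  # both close to 0.75
```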

Further topics

Check out the rest of the documentation and the tutorials.

They illustrate how to: