Source code for askcarl.utils

"""Utility functions for dealing with Gaussians and their covariances."""
import jax
import jax.numpy as jnp
import numpy as np
from scipy.linalg import cholesky, solve_triangular


def mvn_logpdf(X, mean, prec_chol):
    """Evaluate the log-density of a multivariate Gaussian.

    Parameters
    ----------
    X: array
        data, of shape (N, D); a single point of shape (D,) is
        promoted to shape (1, D)
    mean: array
        Mean of Gaussian, of shape (D)
    prec_chol: array
        triangular factor P of the precision matrix, used here such that
        ``P.T @ P`` is the precision matrix (the form returned by
        cov_to_prec_cholesky), of shape (D, D)

    Returns
    -------
    logprob: array
        log-probability, one entry for each entry in X, of shape (N)
    """
    # Promote a lone point (D,) to a batch of one (1, D).
    points = X[None, :] if X.ndim == 1 else X
    n_dims = points.shape[1]

    # Whitening: each row becomes P @ (x - mean), so its squared norm is
    # the Mahalanobis distance (x - mean)^T P^T P (x - mean).
    whitened = jnp.dot(points - mean, prec_chol.T)
    mahalanobis_sq = jnp.sum(whitened**2, axis=1)

    # For a triangular factor, the sum of the log-diagonal is
    # log|det P| = 0.5 * log det(precision).
    half_log_det_prec = jnp.sum(jnp.log(jnp.diag(prec_chol)))

    norm_term = 0.5 * (n_dims * jnp.log(2 * jnp.pi))
    return half_log_det_prec - norm_term - 0.5 * mahalanobis_sq
def mvn_pdf(X, mean, prec_chol):
    """Evaluate the probability density of a multivariate Gaussian.

    This is the exponential of :func:`mvn_logpdf` for the same arguments.

    Parameters
    ----------
    X: array
        data, of shape (N, D)
    mean: array
        Mean of Gaussian, of shape (D)
    prec_chol: array
        factor of the precision matrix, passed unchanged to
        mvn_logpdf, of shape (D, D)

    Returns
    -------
    prob: array
        probability density, one entry for each entry in X, of shape (N)
    """
    log_density = mvn_logpdf(X, mean, prec_chol)
    return jnp.exp(log_density)
def is_positive_definite(cov, tol=1e-10, condthresh=1e6):
    """Check that the covariance matrix is well behaved.

    The matrix passes when all of its entries are finite, its condition
    number is below `condthresh`, and all eigenvalues reported by
    ``np.linalg.eigvalsh`` exceed `tol`.

    Symmetry is not verified: ``eigvalsh`` assumes a symmetric matrix
    and, by default, only reads its lower triangle.

    Parameters
    ----------
    cov: array
        covariance matrix. shape (D, D)
    tol: float
        eigenvalues must be strictly larger than this value
    condthresh: float
        upper limit on the matrix condition number; the condition
        number must be strictly below this value

    Returns
    -------
    bool
        True if the matrix is invertible and positive definite
    """
    # Non-finite entries can never form a valid covariance; reject them
    # up front rather than feeding NaN/inf to the LAPACK routines.
    if not np.all(np.isfinite(cov)):
        return False

    # Written as `not (cond < thresh)` so a NaN condition number also
    # counts as ill-conditioned.
    if not np.linalg.cond(cov) < condthresh:
        return False

    # Convert numpy.bool_ to the built-in bool promised by the docstring.
    return bool(np.all(np.linalg.eigvalsh(cov) > tol))
def cov_to_prec_cholesky(cov):
    """Convert covariance matrix to Cholesky factors of the precision matrix.

    With ``cov = L @ L.T`` (L lower triangular), the result is the inverse
    of L, obtained by a triangular solve against the identity. The
    precision matrix is therefore ``result.T @ result``.

    Parameters
    ----------
    cov: array
        covariance matrix. shape (D, D)

    Returns
    -------
    prec_cholesky: array
        Cholesky factors of the precision matrix. shape (D, D)
    """
    n_dims = cov.shape[0]
    cov_lower = cholesky(cov, lower=True)
    # Solving L @ Z = I for Z yields Z = inv(L) without a general inverse.
    identity = np.eye(n_dims)
    return solve_triangular(cov_lower, identity, lower=True)