Source code for likelihoods

import warnings
import numpyro.distributions as dist
from jax.scipy.special import erfc
import jax.numpy as jnp
import numpy as np


def likelihood_warnings(method, infer_intrinsic, nx, errors, covmat):
    """
    Warn when the chosen likelihood is asymptotically biased for the setup.

    The recommended methods depend on whether the x errors are all zero and
    on whether the intrinsic scatter is being inferred.

    Args:
        :method (str, default='mnr'): The name of the likelihood method to
            use ('mnr', 'gmm', 'unif' or 'prof').
        :infer_intrinsic (bool, default=True): Whether to infer the intrinsic
            scatter in the y direction
        :nx (int): The number of observed x values
        :errors (jnp.ndarray): If covmat=False, then this is [xerr, yerr],
            giving the error on the observed x and y values. Otherwise, this
            is the covariance matrix in the order (x, y)
        :covmat (bool, default=False): This determines whether the errors
            argument is [xerr, yerr] (False) or a covariance matrix (True).
    """
    # Make sure every warning is shown, even when it has been raised before.
    warnings.filterwarnings('always')

    # Pull out the x part of the errors: either the x-x covariance block or
    # the first entry of [xerr, yerr].
    x_errors = errors[:nx, :nx] if covmat else errors[0]

    if isinstance(x_errors, (float, int)):
        x_is_exact = x_errors == 0
    else:
        x_is_exact = jnp.all(jnp.array(x_errors) == 0)

    # Choose the set of recommended methods and the matching hint.
    if x_is_exact:
        recommended = ('unif', 'prof')
        hint = 'Use "unif" or "prof" instead.'
    elif infer_intrinsic:
        recommended = ('mnr', 'gmm')
        hint = 'Use "mnr" or "gmm" instead.'
    else:
        recommended = ('prof',)
        hint = 'Use "prof" instead.'

    if method in recommended:
        return

    warnings.warn(
        f'Not recommended method "{method}" for this setup. ' + hint,
        UserWarning)
def negloglike_mnr_uplims(xobs, yobs, y_is_detected, xerr, yerr, f, fprime,
                          sig, mu_gauss, w_gauss):
    """
    Computes the negative log-likelihood under the assumption of an
    uncorrelated Gaussian likelihood with a Gaussian prior on the true x
    positions, where some of the y values are upper limits.

    Detections contribute the bivariate Gaussian density of (xobs, yobs).
    Upper limits contribute the Gaussian density of xobs multiplied by the
    probability that the true y lies below the quoted limit.

    Args:
        :xobs (jnp.ndarray): The observed x values
        :yobs (jnp.ndarray): The observed y values
        :y_is_detected (jnp.ndarray): A boolean array of the same length as
            xobs and yobs, giving whether each point is a detection (True)
            or an upper limit (False)
        :xerr (jnp.ndarray): The error on the observed x values
        :yerr (jnp.ndarray): The error on the observed y values
        :f (jnp.ndarray): If we are fitting the function f(x), this is f(x)
            evaluated at xobs
        :fprime (jnp.ndarray): If we are fitting the function f(x), this is
            df/dx evaluated at xobs
        :sig (float): The intrinsic scatter, which is added in quadrature
            with yerr
        :mu_gauss (float): The mean of the Gaussian prior on the true x
            positions
        :w_gauss (float): The standard deviation of the Gaussian prior on
            the true x positions

    Returns:
        :neglog_p (float): The negative log-likelihood
    """
    # Local import: only this function needs the normal log-pdf / log-cdf.
    from jax.scipy.stats import norm

    N = len(xobs)

    # Broadcast scalars, 0-d arrays and length-1 arrays to one value per
    # point so that boolean masking below works for every input.
    Ai = jnp.broadcast_to(jnp.atleast_1d(fprime), (N,))
    xerr = jnp.broadcast_to(jnp.atleast_1d(xerr), (N,))
    yerr = jnp.broadcast_to(jnp.atleast_1d(yerr), (N,))
    Bi = f - Ai * xobs

    # The mask is a concrete numpy array so the selected sizes are static.
    detected = np.asarray(y_is_detected, dtype=bool)
    censored = ~detected

    neglog_p = 0.0

    # DETECTIONS
    if np.any(detected):
        xdet = xobs[detected]
        ydet = yobs[detected]
        xerr_det = xerr[detected]
        Ai_det = Ai[detected]
        Bi_det = Bi[detected]

        s2 = yerr[detected] ** 2 + sig ** 2
        # Determinant of the covariance of (xobs, yobs).
        den = (Ai_det ** 2 * xerr_det ** 2 * w_gauss ** 2
               + s2 * (xerr_det ** 2 + w_gauss ** 2))
        chi2 = (
            w_gauss ** 2 * (Ai_det * xdet + Bi_det - ydet) ** 2
            + xerr_det ** 2 * (Ai_det * mu_gauss + Bi_det - ydet) ** 2
            + s2 * (xdet - mu_gauss) ** 2
        )
        log_norm = jnp.log(2 * jnp.pi * jnp.sqrt(den))
        neglog_p = jnp.sum(chi2 / (2 * den) + log_norm)

    # UPPER LIMITS
    if np.any(censored):
        xuplim = xobs[censored]
        yuplim = yobs[censored]
        xerr_uplim = xerr[censored]
        Ai_uplim = Ai[censored]
        Bi_uplim = Bi[censored]

        # Conditional mean / variance of the true x given xobs. Written
        # without 1/xerr**2 so that xerr == 0 gives (xobs, 0) and not NaN.
        var_sum = xerr_uplim ** 2 + w_gauss ** 2
        sigma_c_squared = w_gauss ** 2 * xerr_uplim ** 2 / var_sum
        mu_c = (mu_gauss * xerr_uplim ** 2 + xuplim * w_gauss ** 2) / var_sum
        sigma_squared = yerr[censored] ** 2 + sig ** 2

        # Density of xobs marginalised over the prior on the true x.
        t1 = jnp.sum(norm.logpdf(mu_gauss, loc=xuplim,
                                 scale=jnp.sqrt(var_sum)))

        # log P(y_true < yuplim). norm.logcdf equals log(0.5 * erfc(.)) but
        # stays finite in the far tail where erfc underflows to zero.
        y_scale = jnp.sqrt(sigma_squared + Ai_uplim ** 2 * sigma_c_squared)
        t2 = jnp.sum(norm.logcdf(
            (yuplim - (Ai_uplim * mu_c + Bi_uplim)) / y_scale))

        neglog_p = neglog_p - t1 - t2

    return neglog_p
def negloglike_mnr(xobs, yobs, xerr, yerr, f, fprime, sig, mu_gauss, w_gauss):
    """
    Computes the negative log-likelihood under the assumption of an
    uncorrelated Gaussian likelihood with a Gaussian prior on the true x
    positions.

    Args:
        :xobs (jnp.ndarray): The observed x values
        :yobs (jnp.ndarray): The observed y values
        :xerr (jnp.ndarray): The error on the observed x values
        :yerr (jnp.ndarray): The error on the observed y values
        :f (jnp.ndarray): If we are fitting the function f(x), this is f(x)
            evaluated at xobs
        :fprime (jnp.ndarray): If we are fitting the function f(x), this is
            df/dx evaluated at xobs. A scalar, 0-d array or length-1 array
            is broadcast to every point.
        :sig (float): The intrinsic scatter, which is added in quadrature
            with yerr
        :mu_gauss (float): The mean of the Gaussian prior on the true x
            positions
        :w_gauss (float): The standard deviation of the Gaussian prior on
            the true x positions

    Returns:
        :neglog_p (float): The negative log-likelihood
    """
    N = len(xobs)

    # One slope per point. atleast_1d + broadcast_to also covers 0-d arrays,
    # on which len() would raise TypeError.
    Ai = jnp.broadcast_to(jnp.atleast_1d(fprime), (N,))
    Bi = f - Ai * xobs

    s2 = yerr ** 2 + sig ** 2
    # Determinant of the covariance of (xobs, yobs); shape (N,) because Ai is.
    den = Ai ** 2 * w_gauss ** 2 * xerr ** 2 + s2 * (w_gauss ** 2 + xerr ** 2)

    resid_obs = Ai * xobs + Bi - yobs
    resid_prior = Ai * mu_gauss + Bi - yobs

    neglog_p = (
        N * jnp.log(2 * jnp.pi)
        + 1/2 * jnp.sum(jnp.log(den))
        + 1/2 * jnp.sum(w_gauss ** 2 * resid_obs ** 2 / den)
        + 1/2 * jnp.sum(xerr ** 2 * resid_prior ** 2 / den)
        + 1/2 * jnp.sum(s2 * (xobs - mu_gauss) ** 2 / den)
    )

    return neglog_p
def negloglike_gmm(xobs, yobs, xerr, yerr, f, fprime, sig, all_mu_gauss,
                   all_w_gauss, all_weights):
    """
    Computes the negative log-likelihood under the assumption of an
    uncorrelated Gaussian likelihood with a GMM prior on the true x
    positions.

    Each mixture component contributes a bivariate Gaussian density in
    (xobs, yobs), so its normalisation is 1 / (2 pi sqrt(den)). With a single
    component of unit weight, the result equals the single-Gaussian-prior
    likelihood.

    Args:
        :xobs (jnp.ndarray): The observed x values
        :yobs (jnp.ndarray): The observed y values
        :xerr (jnp.ndarray): The error on the observed x values
        :yerr (jnp.ndarray): The error on the observed y values
        :f (jnp.ndarray): If we are fitting the function f(x), this is f(x)
            evaluated at xobs
        :fprime (jnp.ndarray): If we are fitting the function f(x), this is
            df/dx evaluated at xobs. A scalar, 0-d array or length-1 array
            is broadcast to every point.
        :sig (float): The intrinsic scatter, which is added in quadrature
            with yerr
        :all_mu_gauss (jnp.ndarray): The means of the Gaussians in the GMM
            prior on the true x positions
        :all_w_gauss (jnp.ndarray): The standard deviations of the Gaussians
            in the GMM prior on the true x positions
        :all_weights (jnp.ndarray): The weights of the Gaussians in the GMM
            prior on the true x positions

    Returns:
        :neglog_p (float): The negative log-likelihood
    """
    ngauss = len(all_weights)
    N = len(xobs)

    # Quantities shared by every component, shape (N,). atleast_1d +
    # broadcast_to also covers 0-d slopes, on which len() would raise.
    Ai = jnp.broadcast_to(jnp.atleast_1d(fprime), (N,))
    Bi = f - Ai * xobs
    s2 = yerr ** 2 + sig ** 2

    # Component parameters as columns so everything broadcasts to
    # shape (ngauss, N).
    mu = jnp.asarray(all_mu_gauss)[:ngauss, None]
    w2 = jnp.asarray(all_w_gauss)[:ngauss, None] ** 2
    log_weight = jnp.log(jnp.asarray(all_weights))[:ngauss, None]

    # Determinant of the covariance of (xobs, yobs) for each component.
    den = Ai ** 2 * w2 * xerr ** 2 + s2 * (w2 + xerr ** 2)
    chi2 = (
        w2 * (Ai * xobs + Bi - yobs) ** 2
        + xerr ** 2 * (Ai * mu + Bi - yobs) ** 2
        + s2 * (xobs - mu) ** 2
    ) / den

    # Log of (weight * bivariate Gaussian density) per component and point.
    all_log_p = (
        log_weight
        - jnp.log(2 * jnp.pi)
        - 1/2 * jnp.log(den)
        - 1/2 * chi2
    )

    # Combine the Gaussians with a max-shifted log-sum-exp over components.
    max_log_p = jnp.amax(all_log_p, axis=0)
    log_p = max_log_p + jnp.log(
        jnp.sum(jnp.exp(all_log_p - max_log_p), axis=0))

    return -jnp.sum(log_p)
def negloglike_prof(xobs, yobs, xerr, yerr, f, fprime, sig,
                    include_logdet=True):
    """
    Computes the negative log-likelihood under the assumption of an
    uncorrelated Gaussian likelihood, evaluated at the maximum likelihood
    values of xtrue (the profile likelihood)

    Args:
        :xobs (jnp.ndarray): The observed x values
        :yobs (jnp.ndarray): The observed y values
        :xerr (jnp.ndarray): The error on the observed x values
        :yerr (jnp.ndarray): The error on the observed y values
        :f (jnp.ndarray): If we are fitting the function f(x), this is f(x)
            evaluated at xobs
        :fprime (jnp.ndarray): If we are fitting the function f(x), this is
            df/dx evaluated at xobs
        :sig (float): The intrinsic scatter, which is added in quadrature
            with yerr
        :include_logdet (bool, default=True): Whether to include the
            normalisation term in the likelihood proportional to log(det(S))

    Returns:
        :neglog_p (float): The negative log-likelihood
    """
    n_points = len(xobs)
    slope = fprime
    intercept = f - slope * xobs

    # Total y scatter per point; a single value is repeated for every point.
    sigy = jnp.atleast_1d(jnp.sqrt(yerr ** 2 + sig ** 2))
    if len(sigy) == 1:
        sigy = jnp.full(n_points, sigy[0])

    # Chi-square-like term, used with or without the normalisation.
    residual = slope * xobs + intercept - yobs
    chi2_term = 1/2 * jnp.sum(
        residual ** 2 / (slope ** 2 * xerr ** 2 + sigy ** 2))

    if not include_logdet:
        return chi2_term

    return (
        n_points / 2 * jnp.log(2 * jnp.pi)
        + jnp.sum(jnp.log(sigy))
        + chi2_term
    )
def negloglike_unif(xobs, yobs, xerr, yerr, f, fprime, sig):
    """
    Computes the negative log-likelihood under the assumption of an
    uncorrelated Gaussian likelihood, where we have marginalised over the
    true x values, assuming an infinite uniform prior on these.

    Args:
        :xobs (jnp.ndarray): The observed x values
        :yobs (jnp.ndarray): The observed y values
        :xerr (jnp.ndarray): The error on the observed x values
        :yerr (jnp.ndarray): The error on the observed y values
        :f (jnp.ndarray): If we are fitting the function f(x), this is f(x)
            evaluated at xobs
        :fprime (jnp.ndarray): If we are fitting the function f(x), this is
            df/dx evaluated at xobs
        :sig (float): The intrinsic scatter, which is added in quadrature
            with yerr

    Returns:
        :neglog_p (float): The negative log-likelihood
    """
    n_points = len(xobs)

    # After atleast_1d the slope always has a length, so only the length-1
    # case needs expanding to one value per point.
    slope = jnp.atleast_1d(fprime)
    if len(slope) == 1:
        slope = jnp.full(n_points, jnp.squeeze(jnp.array(slope)))
    intercept = f - slope * xobs

    # Effective variance in y after projecting the x errors through the slope.
    variance = slope ** 2 * xerr ** 2 + yerr ** 2 + sig ** 2
    residual = slope * xobs + intercept - yobs

    return (
        n_points / 2 * jnp.log(2 * jnp.pi)
        + 1/2 * jnp.sum(jnp.log(variance))
        + 1/2 * jnp.sum(residual ** 2 / variance)
    )
def check_valid_covmat(D, tol=1e-8):
    """
    Check if a covariance matrix is valid, i.e. both

    - symmetric (within tolerance), and
    - positive semi-definite (all eigenvalues >= -tol)

    Args:
        :D (jnp.ndarray): The covariance matrix to check
        :tol (float, default=1e-8): The tolerance for numerical checks

    Returns:
        :is_valid (bool): Whether the covariance matrix is valid
    """
    # The matrix must equal its own transpose up to the absolute tolerance.
    is_symmetric = jnp.allclose(D, jnp.transpose(D), atol=tol)

    # No eigenvalue may be more negative than the tolerance allows.
    spectrum = jnp.linalg.eigvalsh(D)
    is_psd = jnp.all(spectrum >= -tol)

    return jnp.logical_and(is_symmetric, is_psd)
def negloglike_mnr_mv(xobs, yobs, Sigma, f, G, sig, mu_gauss, w_gauss):
    """
    Computes the negative log-likelihood under the assumption of a correlated
    Gaussian likelihood (i.e. arbitrary covariance matrix) with a Gaussian
    prior on the true x positions.

    Args:
        :xobs (jnp.ndarray): The observed x values
        :yobs (jnp.ndarray): The observed y values
        :Sigma (jnp.ndarray): The covariance matrix giving the errors on the
            observed (x, y) values
        :f (jnp.ndarray): If we are fitting the function f(x), this is f(x)
            evaluated at xobs
        :G (jnp.ndarray): If we are fitting the function f(x), this is
            G_{ij} = df_i/dx_j evaluated at xobs
        :sig (float): The intrinsic scatter, which is added in quadrature
            with yerr
        :mu_gauss (float): The mean of the Gaussian prior on the true x
            positions
        :w_gauss (float): The standard deviation of the Gaussian prior on
            the true x positions

    Returns:
        :neglog_p (float): The negative log-likelihood
    """
    nx, ny = len(xobs), len(yobs)

    # Prior covariance of the true x values and its image under G.
    W = jnp.identity(nx) * w_gauss ** 2
    GW = G @ W

    # Covariance: Sigma plus the block matrix built from W, GW and the
    # intrinsic scatter.
    top_row = jnp.concatenate([W, GW.T], axis=-1)
    bottom_row = jnp.concatenate(
        [GW, GW @ G.T + jnp.identity(ny) * sig ** 2], axis=-1)
    M = Sigma + jnp.concatenate([top_row, bottom_row])

    _, logdet2piM = jnp.linalg.slogdet(2 * jnp.pi * M)
    Minv = jnp.linalg.inv(M)

    # Vector: offsets in x, then the same offsets mapped through G for y.
    x_offset = mu_gauss - xobs
    z = jnp.concatenate([x_offset, f + G @ x_offset - yobs])

    neglog_p = 1/2 * logdet2piM + 1/2 * jnp.sum(z * (Minv @ z))

    # Penalise invalid covariance matrices with a large constant value.
    penalty = 1e20
    return jnp.where(check_valid_covmat(M), neglog_p, penalty)
def negloglike_prof_mv(xobs, yobs, Sigma, f, G, sig, include_logdet=True):
    """
    Computes the negative log-likelihood under the assumption of a correlated
    Gaussian likelihood (i.e. arbitrary covariance matrix), evaluated at the
    maximum likelihood values of xtrue (the profile likelihood)

    Args:
        :xobs (jnp.ndarray): The observed x values
        :yobs (jnp.ndarray): The observed y values
        :Sigma (jnp.ndarray): The covariance matrix giving the errors on the
            observed (x, y) values
        :f (jnp.ndarray): If we are fitting the function f(x), this is f(x)
            evaluated at xobs
        :G (jnp.ndarray): If we are fitting the function f(x), this is
            G_{ij} = df_i/dx_j evaluated at xobs
        :sig (float): The intrinsic scatter, which is added in quadrature
            with yerr
        :include_logdet (bool, default=True): Whether to include the
            normalisation term in the likelihood proportional to log(det(S))

    Returns:
        :neglog_p (float): The negative log-likelihood
    """
    nx, ny = len(xobs), len(yobs)
    scatter = jnp.identity(ny) * sig ** 2

    # Blocks of the (x, y) covariance matrix.
    Sxx = Sigma[:nx, :nx]
    Sxy = Sigma[:nx, nx:]
    Syx = Sigma[nx:, :nx]
    Syy = Sigma[nx:, nx:]

    # Covariance of the residual f - yobs.
    D = (
        Syy
        + scatter
        + G @ (Sxx @ G.T)
        - Syx @ G.T
        - G @ Sxy
    )

    Dinv = jnp.linalg.inv(D)
    z = f - yobs
    neglog_p = 1/2 * jnp.sum(z * (Dinv @ z))

    if include_logdet:
        # Normalisation uses the full covariance with the intrinsic scatter
        # added to its y-y block.
        S = jnp.array(Sigma)
        S = S.at[nx:, nx:].set(S[nx:, nx:] + jnp.identity(ny) * sig ** 2)
        _, logdet2piS = jnp.linalg.slogdet(2 * jnp.pi * S)
        neglog_p = 1/2 * logdet2piS + neglog_p

    # Penalise invalid covariance matrices with a large constant value.
    penalty = 1e20
    return jnp.where(check_valid_covmat(D), neglog_p, penalty)
def negloglike_unif_mv(xobs, yobs, Sigma, f, G, sig):
    """
    Computes the negative log-likelihood under the assumption of a correlated
    Gaussian likelihood (i.e. arbitrary covariance matrix), where we have
    marginalised over the true x values, assuming an infinite uniform prior
    on these.

    Args:
        :xobs (jnp.ndarray): The observed x values
        :yobs (jnp.ndarray): The observed y values
        :Sigma (jnp.ndarray): The covariance matrix giving the errors on the
            observed (x, y) values
        :f (jnp.ndarray): If we are fitting the function f(x), this is f(x)
            evaluated at xobs
        :G (jnp.ndarray): If we are fitting the function f(x), this is
            G_{ij} = df_i/dx_j evaluated at xobs
        :sig (float): The intrinsic scatter, which is added in quadrature
            with yerr

    Returns:
        :neglog_p (float): The negative log-likelihood
    """
    nx, ny = len(xobs), len(yobs)

    # Blocks of the (x, y) covariance matrix.
    Sxx = Sigma[:nx, :nx]
    Sxy = Sigma[:nx, nx:]
    Syx = Sigma[nx:, :nx]
    Syy = Sigma[nx:, nx:]

    # Covariance of the residual f - yobs.
    D = (
        Syy
        + jnp.identity(ny) * sig ** 2
        + G @ (Sxx @ G.T)
        - Syx @ G.T
        - G @ Sxy
    )

    _, logdet2piD = jnp.linalg.slogdet(2 * jnp.pi * D)
    Dinv = jnp.linalg.inv(D)

    z = f - yobs
    neglog_p = 1/2 * logdet2piD + 1/2 * jnp.sum(z * (Dinv @ z))

    # Penalise invalid covariance matrices with a large constant value.
    penalty = 1e20
    return jnp.where(check_valid_covmat(D), neglog_p, penalty)