import warnings
import numpyro.distributions as dist
from jax.scipy.special import erfc
import jax.numpy as jnp
import numpy as np
[docs]
def likelihood_warnings(method, infer_intrinsic, nx, errors, covmat):
    """
    Emit a ``UserWarning`` when the requested likelihood is not the
    recommended (asymptotically unbiased) choice for the given setup.

    The recommendation depends on whether any x error is non-zero and on
    whether the intrinsic scatter is being inferred:

    - all x errors zero: 'unif' or 'prof'
    - non-zero x errors, intrinsic scatter inferred: 'mnr' or 'gmm'
    - non-zero x errors, intrinsic scatter not inferred: 'prof'

    Note that this function sets the global warnings filter to 'always'.

    Args:
        :method (str): The name of the likelihood method to use
            ('mnr', 'gmm', 'unif' or 'prof').
        :infer_intrinsic (bool): Whether to infer the intrinsic scatter in
            the y direction
        :nx (int): The number of observed x values
        :errors (jnp.ndarray): If covmat is False, this is [xerr, yerr],
            giving the errors on the observed x and y values. Otherwise it
            is the covariance matrix in the order (x, y)
        :covmat (bool): Whether the errors argument is [xerr, yerr] (False)
            or a covariance matrix (True).
    """
    # Make sure every warning is shown, not just its first occurrence
    warnings.filterwarnings('always')

    # x uncertainties: the leading (nx, nx) block of the covariance matrix,
    # or the first entry of [xerr, yerr]
    xerr = errors[:nx, :nx] if covmat else errors[0]

    if isinstance(xerr, (float, int)):
        no_xerr = xerr == 0
    else:
        no_xerr = jnp.all(jnp.array(xerr) == 0)

    # Pick the recommended methods (and the matching advice) for this setup
    if no_xerr:
        recommended = ('unif', 'prof')
        advice = 'Use "unif" or "prof" instead.'
    elif infer_intrinsic:
        recommended = ('mnr', 'gmm')
        advice = 'Use "mnr" or "gmm" instead.'
    else:
        recommended = ('prof',)
        advice = 'Use "prof" instead.'

    if method in recommended:
        return

    warnings.warn(
        f'Not recommended method "{method}" for this setup. ' + advice,
        UserWarning)
[docs]
def negloglike_mnr_uplims(xobs, yobs, y_is_detected, xerr, yerr, f, fprime, sig, mu_gauss, w_gauss):
    """
    Computes the negative log-likelihood under the assumption of an uncorrelated
    Gaussian likelihood with a Gaussian prior on the true x positions,
    where some of the y values are upper limits.

    Detected points contribute the usual bivariate Gaussian term. Upper limits
    contribute the Gaussian density of the observed x value times the
    probability that y lies below the quoted limit.

    Args:
        :xobs (jnp.ndarray): The observed x values
        :yobs (jnp.ndarray): The observed y values (the value of the limit
            for upper limits)
        :y_is_detected (jnp.ndarray): A boolean array of the same length as xobs and
            yobs, giving whether each point is a detection (True) or an upper
            limit (False). It is converted with numpy, so it must hold concrete
            (non-traced) values.
        :xerr (jnp.ndarray): The error on the observed x values (a scalar,
            0-d array, length-1 array or length-N array)
        :yerr (jnp.ndarray): The error on the observed y values (a scalar,
            0-d array, length-1 array or length-N array)
        :f (jnp.ndarray): If we are fitting the function f(x), this is f(x) evaluated
            at xobs
        :fprime (jnp.ndarray): If we are fitting the function f(x), this is df/dx
            evaluated at xobs
        :sig (float): The intrinsic scatter, which is added in quadrature with yerr
        :mu_gauss (float): The mean of the Gaussian prior on the true x positions
        :w_gauss (float): The standard deviation of the Gaussian prior on the true x
            positions

    Returns:
        :neglog_p (float): The negative log-likelihood
    """
    xobs = jnp.asarray(xobs)
    yobs = jnp.asarray(yobs)
    N = len(xobs)

    # Bring every per-point quantity to shape (N,) so that it can be split by
    # the detection mask. broadcast_to handles Python scalars, 0-d arrays,
    # length-1 arrays and length-N arrays alike.
    Ai = jnp.broadcast_to(jnp.asarray(fprime), (N,))
    Bi = jnp.broadcast_to(jnp.asarray(f) - Ai * xobs, (N,))
    xerr = jnp.broadcast_to(jnp.asarray(xerr), (N,))
    yerr = jnp.broadcast_to(jnp.asarray(yerr), (N,))

    mask = np.asarray(y_is_detected, dtype=bool)
    mask_uplim = ~mask

    w2 = w_gauss ** 2
    neglog_p = 0.0

    # DETECTIONS
    if mask.any():
        xdet = xobs[mask]
        ydet = yobs[mask]
        xerr2 = xerr[mask] ** 2
        Ai_det = Ai[mask]
        Bi_det = Bi[mask]

        s2 = yerr[mask] ** 2 + sig ** 2
        # Determinant of the 2x2 covariance of (xobs, yobs)
        den = Ai_det ** 2 * xerr2 * w2 + s2 * (xerr2 + w2)
        chi2 = (
            w2 * (Ai_det * xdet + Bi_det - ydet) ** 2
            + xerr2 * (Ai_det * mu_gauss + Bi_det - ydet) ** 2
            + s2 * (xdet - mu_gauss) ** 2
        )
        neglog_p = jnp.sum(chi2 / (2 * den) + jnp.log(2 * jnp.pi * jnp.sqrt(den)))

    # UPPER LIMITS
    if mask_uplim.any():
        xuplim = xobs[mask_uplim]
        yuplim = yobs[mask_uplim]
        xerr2 = xerr[mask_uplim] ** 2
        Ai_uplim = Ai[mask_uplim]
        Bi_uplim = Bi[mask_uplim]

        # Variance of xobs after marginalising over the true x
        var_x = xerr2 + w2

        # Conditional variance and mean of the true x given xobs. Written as
        # w^2 xerr^2 / (w^2 + xerr^2) rather than 1 / (1/w^2 + 1/xerr^2) so
        # that points with xerr = 0 stay finite instead of producing 0 * inf.
        sigma_c_squared = w2 * xerr2 / var_x
        mu_c = (mu_gauss * xerr2 + xuplim * w2) / var_x

        sigma_squared = yerr[mask_uplim] ** 2 + sig ** 2

        # Gaussian log-density N(mu_gauss; xobs, xerr^2 + w^2)
        t1 = jnp.sum(
            -0.5 * (mu_gauss - xuplim) ** 2 / var_x
            - 0.5 * jnp.log(2 * jnp.pi * var_x)
        )
        # log of the probability that y lies below the upper limit
        t2 = jnp.sum(jnp.log(0.5 * erfc(
            (Ai_uplim * mu_c + Bi_uplim - yuplim)
            / jnp.sqrt(2 * sigma_squared + 2 * Ai_uplim ** 2 * sigma_c_squared)
        )))
        neglog_p = neglog_p - t1 - t2

    return neglog_p
[docs]
def negloglike_mnr(xobs, yobs, xerr, yerr, f, fprime, sig, mu_gauss, w_gauss):
    """
    Computes the negative log-likelihood under the assumption of an uncorrelated
    Gaussian likelihood with a Gaussian prior on the true x positions.

    Args:
        :xobs (jnp.ndarray): The observed x values
        :yobs (jnp.ndarray): The observed y values
        :xerr (jnp.ndarray): The error on the observed x values
        :yerr (jnp.ndarray): The error on the observed y values
        :f (jnp.ndarray): If we are fitting the function f(x), this is f(x) evaluated
            at xobs
        :fprime (jnp.ndarray): If we are fitting the function f(x), this is df/dx
            evaluated at xobs. May be a scalar, a 0-d array, a length-1 array
            or a length-N array.
        :sig (float): The intrinsic scatter, which is added in quadrature with yerr
        :mu_gauss (float): The mean of the Gaussian prior on the true x positions
        :w_gauss (float): The standard deviation of the Gaussian prior on the true x
            positions

    Returns:
        :neglog_p (float): The negative log-likelihood
    """
    N = len(xobs)

    # Force the slope to shape (N,) so that every sum below runs over all N
    # points even when xerr and yerr are scalars. broadcast_to also accepts
    # 0-d arrays, on which len() would raise.
    Ai = jnp.broadcast_to(jnp.asarray(fprime), (N,))
    Bi = f - Ai * xobs

    s2 = yerr ** 2 + sig ** 2
    w2 = w_gauss ** 2
    xerr2 = xerr ** 2

    # Determinant of the 2x2 covariance of (xobs, yobs) for each point
    den = Ai ** 2 * w2 * xerr2 + s2 * (w2 + xerr2)

    chi2 = (
        w2 * (Ai * xobs + Bi - yobs) ** 2
        + xerr2 * (Ai * mu_gauss + Bi - yobs) ** 2
        + s2 * (xobs - mu_gauss) ** 2
    ) / den

    neglog_p = (
        N * jnp.log(2 * jnp.pi)
        + 1/2 * jnp.sum(jnp.log(den))
        + 1/2 * jnp.sum(chi2)
    )
    return neglog_p
[docs]
def negloglike_gmm(xobs, yobs, xerr, yerr, f, fprime, sig, all_mu_gauss, all_w_gauss,
                   all_weights):
    """
    Computes the negative log-likelihood under the assumption of an uncorrelated
    Gaussian likelihood with a GMM prior on the true x positions.

    Each mixture component contributes a bivariate Gaussian in (xobs, yobs),
    so the per-point normalisation is log(2 pi) plus half the log-determinant.
    With a single component of weight one, the result therefore equals the
    single-Gaussian-prior (MNR) likelihood.

    Args:
        :xobs (jnp.ndarray): The observed x values
        :yobs (jnp.ndarray): The observed y values
        :xerr (jnp.ndarray): The error on the observed x values
        :yerr (jnp.ndarray): The error on the observed y values
        :f (jnp.ndarray): If we are fitting the function f(x), this is f(x) evaluated
            at xobs
        :fprime (jnp.ndarray): If we are fitting the function f(x), this is df/dx
            evaluated at xobs. May be a scalar, a 0-d array, a length-1 array
            or a length-N array.
        :sig (float): The intrinsic scatter, which is added in quadrature with yerr
        :all_mu_gauss (jnp.ndarray): The means of the Gaussians in the GMM prior on
            the true x positions
        :all_w_gauss (jnp.ndarray): The standard deviations of the Gaussians in the GMM
            prior on the true x positions
        :all_weights (jnp.ndarray): The weights of the Gaussians in the GMM prior on the
            true x positions

    Returns:
        :neglog_p (float): The negative log-likelihood
    """
    N = len(xobs)

    # Quantities that do not depend on the mixture component are computed
    # once. The slope is forced to shape (N,) so that each component yields
    # one log-probability per data point.
    Ai = jnp.broadcast_to(jnp.asarray(fprime), (N,))
    Bi = f - Ai * xobs
    s2 = yerr ** 2 + sig ** 2
    xerr2 = xerr ** 2
    resid_obs2 = (Ai * xobs + Bi - yobs) ** 2

    component_log_p = []
    for mu_gauss, w_gauss, weight in zip(all_mu_gauss, all_w_gauss, all_weights):
        w2 = w_gauss ** 2
        # Determinant of the 2x2 covariance of (xobs, yobs) for this component
        den = Ai ** 2 * w2 * xerr2 + s2 * (w2 + xerr2)
        chi2 = (
            w2 * resid_obs2
            + xerr2 * (Ai * mu_gauss + Bi - yobs) ** 2
            + s2 * (xobs - mu_gauss) ** 2
        ) / den
        component_log_p.append(
            jnp.log(weight)
            - jnp.log(2 * jnp.pi)
            - 1/2 * jnp.log(den)
            - 1/2 * chi2
        )

    # Shape (ngauss, N)
    all_log_p = jnp.stack(component_log_p)

    # Combine the Gaussians with a log-sum-exp, subtracting the per-point
    # maximum first to avoid underflow in the exponentials
    max_log_p = jnp.amax(all_log_p, axis=0)
    log_p = max_log_p + jnp.log(jnp.sum(jnp.exp(all_log_p - max_log_p), axis=0))
    return - jnp.sum(log_p)
[docs]
def negloglike_prof(xobs, yobs, xerr, yerr, f, fprime, sig, include_logdet=True):
    """
    Profile negative log-likelihood for an uncorrelated Gaussian likelihood,
    i.e. the likelihood evaluated at the maximum likelihood values of xtrue.

    Args:
        :xobs (jnp.ndarray): The observed x values
        :yobs (jnp.ndarray): The observed y values
        :xerr (jnp.ndarray): The error on the observed x values
        :yerr (jnp.ndarray): The error on the observed y values
        :f (jnp.ndarray): If we are fitting the function f(x), this is f(x)
            evaluated at xobs
        :fprime (jnp.ndarray): If we are fitting the function f(x), this is
            df/dx evaluated at xobs
        :sig (float): The intrinsic scatter, which is added in quadrature
            with yerr
        :include_logdet (bool, default=True): Whether to include the
            normalisation term in the likelihood proportional to log(det(S))

    Returns:
        :neglog_p (float): The negative log-likelihood
    """
    npoints = len(xobs)
    slope = fprime
    intercept = f - slope * xobs

    # Total scatter in y, expanded to one value per point so that the
    # normalisation sum runs over every point
    sigy = jnp.atleast_1d(jnp.sqrt(yerr ** 2 + sig ** 2))
    if len(sigy) == 1:
        sigy = jnp.full(npoints, sigy[0])

    # The chi-squared term is shared by both variants of the likelihood
    residual_sq = (slope * xobs + intercept - yobs) ** 2
    variance = slope ** 2 * xerr ** 2 + sigy ** 2
    half_chi2 = 1/2 * jnp.sum(residual_sq / variance)

    if not include_logdet:
        return half_chi2

    return (
        npoints / 2 * jnp.log(2 * jnp.pi)
        + jnp.sum(jnp.log(sigy))
        + half_chi2
    )
[docs]
def negloglike_unif(xobs, yobs, xerr, yerr, f, fprime, sig):
    """
    Negative log-likelihood for an uncorrelated Gaussian likelihood after
    marginalising over the true x values with an infinite uniform prior.

    Args:
        :xobs (jnp.ndarray): The observed x values
        :yobs (jnp.ndarray): The observed y values
        :xerr (jnp.ndarray): The error on the observed x values
        :yerr (jnp.ndarray): The error on the observed y values
        :f (jnp.ndarray): If we are fitting the function f(x), this is f(x)
            evaluated at xobs
        :fprime (jnp.ndarray): If we are fitting the function f(x), this is
            df/dx evaluated at xobs
        :sig (float): The intrinsic scatter, which is added in quadrature
            with yerr

    Returns:
        :neglog_p (float): The negative log-likelihood
    """
    npoints = len(xobs)

    # A single slope value is expanded to one entry per data point so that
    # the sums below always run over all points
    slope = jnp.atleast_1d(fprime)
    if len(slope) == 1:
        slope = jnp.full(npoints, jnp.squeeze(jnp.array(slope)))
    intercept = f - slope * xobs

    # Effective variance in y: propagated x error + y error + intrinsic scatter
    total_var = slope ** 2 * xerr ** 2 + yerr ** 2 + sig ** 2
    residual = slope * xobs + intercept - yobs

    return (
        npoints / 2 * jnp.log(2 * jnp.pi)
        + 1/2 * jnp.sum(jnp.log(total_var))
        + 1/2 * jnp.sum(residual ** 2 / total_var)
    )
[docs]
def check_valid_covmat(D, tol=1e-8):
    """
    Test whether a matrix is an admissible covariance matrix, i.e. it is

    - symmetric (compared with its transpose using absolute tolerance tol)
    - positive semi-definite (all eigenvalues >= -tol)

    Args:
        :D (jnp.ndarray): The covariance matrix to check
        :tol (float, default=1e-8): The tolerance for numerical checks

    Returns:
        :is_valid (bool): Whether the covariance matrix is valid
    """
    # No eigenvalue may be more negative than the tolerance
    spectrum = jnp.linalg.eigvalsh(D)
    is_psd = jnp.all(spectrum >= -tol)

    # The matrix must match its own transpose
    is_symmetric = jnp.allclose(D, D.T, atol=tol)

    return jnp.logical_and(is_symmetric, is_psd)
[docs]
def negloglike_mnr_mv(xobs, yobs, Sigma, f, G, sig, mu_gauss, w_gauss):
    """
    Negative log-likelihood for a correlated Gaussian likelihood (i.e. an
    arbitrary covariance matrix) with a Gaussian prior on the true x
    positions.

    If the total covariance matrix fails the validity check, a large constant
    penalty is returned in place of the likelihood.

    Args:
        :xobs (jnp.ndarray): The observed x values
        :yobs (jnp.ndarray): The observed y values
        :Sigma (jnp.ndarray): The covariance matrix giving the errors on the
            observed (x, y) values
        :f (jnp.ndarray): If we are fitting the function f(x), this is f(x)
            evaluated at xobs
        :G (jnp.ndarray): If we are fitting the function f(x), this is
            G_{ij} = df_i/dx_j evaluated at xobs
        :sig (float): The intrinsic scatter, which is added in quadrature
            with yerr
        :mu_gauss (float): The mean of the Gaussian prior on the true x
            positions
        :w_gauss (float): The standard deviation of the Gaussian prior on the
            true x positions

    Returns:
        :neglog_p (float): The negative log-likelihood
    """
    nx, ny = len(xobs), len(yobs)

    # Covariance contributed by the prior on xtrue and the intrinsic scatter
    prior_cov = jnp.identity(nx) * w_gauss ** 2
    G_prior = jnp.matmul(G, prior_cov)
    yy_block = jnp.matmul(G_prior, G.T) + jnp.identity(ny) * sig ** 2

    upper_rows = jnp.concatenate([prior_cov, G_prior.T], axis=-1)
    lower_rows = jnp.concatenate([G_prior, yy_block], axis=-1)

    # Total covariance of the stacked (x, y) data vector
    M = Sigma + jnp.concatenate([upper_rows, lower_rows])

    # Stacked residual vector
    dx = mu_gauss - xobs
    z = jnp.concatenate([dx, f + jnp.matmul(G, dx) - yobs])

    _, logdet2piM = jnp.linalg.slogdet(2 * jnp.pi * M)
    quad_form = jnp.sum(z * jnp.matmul(jnp.linalg.inv(M), z))
    neglog_p = 1/2 * logdet2piM + 1/2 * quad_form

    # Penalise invalid covariance matrices
    return jnp.where(check_valid_covmat(M), neglog_p, 1e20)
[docs]
def negloglike_prof_mv(xobs, yobs, Sigma, f, G, sig, include_logdet=True):
    """
    Profile negative log-likelihood for a correlated Gaussian likelihood
    (i.e. an arbitrary covariance matrix), evaluated at the maximum
    likelihood values of xtrue.

    If the projected covariance matrix D fails the validity check, a large
    constant penalty is returned in place of the likelihood.

    Args:
        :xobs (jnp.ndarray): The observed x values
        :yobs (jnp.ndarray): The observed y values
        :Sigma (jnp.ndarray): The covariance matrix giving the errors on the
            observed (x, y) values
        :f (jnp.ndarray): If we are fitting the function f(x), this is f(x)
            evaluated at xobs
        :G (jnp.ndarray): If we are fitting the function f(x), this is
            G_{ij} = df_i/dx_j evaluated at xobs
        :sig (float): The intrinsic scatter, which is added in quadrature
            with yerr
        :include_logdet (bool, default=True): Whether to include the
            normalisation term in the likelihood proportional to log(det(S))

    Returns:
        :neglog_p (float): The negative log-likelihood
    """
    nx, ny = len(xobs), len(yobs)
    intrinsic = jnp.identity(ny) * sig ** 2

    # Blocks of the observational covariance matrix
    cov_xx = Sigma[:nx, :nx]
    cov_xy = Sigma[:nx, nx:]
    cov_yx = Sigma[nx:, :nx]
    cov_yy = Sigma[nx:, nx:]

    # Covariance of the y residuals once the x errors are propagated through G
    D = (
        cov_yy + intrinsic
        + jnp.matmul(G, jnp.matmul(cov_xx, G.T))
        - jnp.matmul(cov_yx, G.T) - jnp.matmul(G, cov_xy)
    )

    z = f - yobs
    neglog_p = 1/2 * jnp.sum(z * jnp.matmul(jnp.linalg.inv(D), z))

    if include_logdet:
        # Normalisation from the full covariance with the intrinsic scatter
        # added to the y block
        S = jnp.array(Sigma)
        S = S.at[nx:, nx:].set(S[nx:, nx:] + intrinsic)
        _, logdet2piS = jnp.linalg.slogdet(2 * jnp.pi * S)
        neglog_p = 1/2 * logdet2piS + neglog_p

    # Penalise invalid covariance matrices
    return jnp.where(check_valid_covmat(D), neglog_p, 1e20)
[docs]
def negloglike_unif_mv(xobs, yobs, Sigma, f, G, sig):
    """
    Negative log-likelihood for a correlated Gaussian likelihood (i.e. an
    arbitrary covariance matrix) after marginalising over the true x values
    with an infinite uniform prior.

    If the projected covariance matrix D fails the validity check, a large
    constant penalty is returned in place of the likelihood.

    Args:
        :xobs (jnp.ndarray): The observed x values
        :yobs (jnp.ndarray): The observed y values
        :Sigma (jnp.ndarray): The covariance matrix giving the errors on the
            observed (x, y) values
        :f (jnp.ndarray): If we are fitting the function f(x), this is f(x)
            evaluated at xobs
        :G (jnp.ndarray): If we are fitting the function f(x), this is
            G_{ij} = df_i/dx_j evaluated at xobs
        :sig (float): The intrinsic scatter, which is added in quadrature
            with yerr

    Returns:
        :neglog_p (float): The negative log-likelihood
    """
    nx, ny = len(xobs), len(yobs)

    # Blocks of the observational covariance matrix
    cov_xx = Sigma[:nx, :nx]
    cov_xy = Sigma[:nx, nx:]
    cov_yx = Sigma[nx:, :nx]
    cov_yy = Sigma[nx:, nx:]

    # Covariance of the y residuals once the x errors are propagated through G
    D = (
        cov_yy + jnp.identity(ny) * sig ** 2
        + jnp.matmul(G, jnp.matmul(cov_xx, G.T))
        - jnp.matmul(cov_yx, G.T) - jnp.matmul(G, cov_xy)
    )

    z = f - yobs
    _, logdet2piD = jnp.linalg.slogdet(2 * jnp.pi * D)
    quad_form = jnp.sum(z * jnp.matmul(jnp.linalg.inv(D), z))
    neglog_p = 1/2 * logdet2piD + 1/2 * quad_form

    # Penalise invalid covariance matrices
    return jnp.where(check_valid_covmat(D), neglog_p, 1e20)