from jax import lax
import jax.random
import jax.numpy as jnp
from jax.scipy.stats import norm as jax_norm
from jax.scipy.special import ndtri, ndtr
import numpy as np
import scipy.optimize
import numpyro.distributions as dist
from numpyro.distributions.util import promote_shapes
from numpyro.distributions.util import validate_sample
from numpyro.distributions.util import is_prng_key
import roxy.likelihoods
[docs]
class Likelihood_MNR_uplims(dist.Distribution):
    """
    ``numpyro`` distribution used to evaluate the log-likelihood for an
    uncorrelated Gaussian likelihood with a Gaussian prior on the true x
    positions, when some of the points are upper limits (non-detections in y).

    Args:
        :xobs (jnp.ndarray): The observed x values
        :yobs (jnp.ndarray): The observed y values
        :y_is_detected (jnp.ndarray): Boolean array with the same length as xobs
            and yobs; True marks a detection and False marks an upper limit
        :xerr (jnp.ndarray): The error on the observed x values
        :yerr (jnp.ndarray): The error on the observed y values
        :f (jnp.ndarray): The fitted function f(x) evaluated at xobs
        :fprime (jnp.ndarray): The derivative df/dx evaluated at xobs
        :sig (float): The intrinsic scatter, which is added in quadrature with yerr
        :mu_gauss (float): The mean of the Gaussian prior on the true x positions
        :w_gauss (float): The standard deviation of the Gaussian prior on the
            true x positions
    """

    def __init__(self, xobs, yobs, y_is_detected, xerr, yerr, f, fprime, sig,
                 mu_gauss, w_gauss):
        inputs = (xobs, yobs, y_is_detected, xerr, yerr, f, fprime, sig,
                  mu_gauss, w_gauss)
        # Store rank-aligned copies of every input on the instance.
        (self.xobs, self.yobs, self.y_is_detected, self.xerr, self.yerr,
         self.f, self.fprime, self.sig, self.mu_gauss,
         self.w_gauss) = promote_shapes(*inputs)
        # The batch shape is the broadcast of the shapes of the raw inputs.
        batch_shape = lax.broadcast_shapes(*(jnp.shape(arg) for arg in inputs))
        super().__init__(batch_shape=batch_shape)

    def sample(self, key, sample_shape=()):
        """Sampling is not supported; always raises ``NotImplementedError``."""
        raise NotImplementedError

    def log_prob(self, value):
        """
        Return minus ``roxy.likelihoods.negloglike_mnr_uplims`` evaluated on the
        stored attributes. The ``value`` argument is not used.
        """
        neg_loglike = roxy.likelihoods.negloglike_mnr_uplims(
            self.xobs, self.yobs, self.y_is_detected, self.xerr, self.yerr,
            self.f, self.fprime, self.sig, self.mu_gauss, self.w_gauss)
        return - neg_loglike
[docs]
class Likelihood_MNR(dist.Distribution):
    """
    ``numpyro`` distribution used to evaluate the log-likelihood for an
    uncorrelated Gaussian likelihood with a Gaussian prior on the true x
    positions.

    Args:
        :xobs (jnp.ndarray): The observed x values
        :yobs (jnp.ndarray): The observed y values
        :xerr (jnp.ndarray): The error on the observed x values
        :yerr (jnp.ndarray): The error on the observed y values
        :f (jnp.ndarray): The fitted function f(x) evaluated at xobs
        :fprime (jnp.ndarray): The derivative df/dx evaluated at xobs
        :sig (float): The intrinsic scatter, which is added in quadrature with yerr
        :mu_gauss (float): The mean of the Gaussian prior on the true x positions
        :w_gauss (float): The standard deviation of the Gaussian prior on the
            true x positions
    """

    def __init__(self, xobs, yobs, xerr, yerr, f, fprime, sig, mu_gauss, w_gauss):
        inputs = (xobs, yobs, xerr, yerr, f, fprime, sig, mu_gauss, w_gauss)
        # Store rank-aligned copies of every input on the instance.
        (self.xobs, self.yobs, self.xerr, self.yerr, self.f, self.fprime,
         self.sig, self.mu_gauss, self.w_gauss) = promote_shapes(*inputs)
        # The batch shape is the broadcast of the shapes of the raw inputs.
        batch_shape = lax.broadcast_shapes(*(jnp.shape(arg) for arg in inputs))
        super().__init__(batch_shape=batch_shape)

    def sample(self, key, sample_shape=()):
        """Sampling is not supported; always raises ``NotImplementedError``."""
        raise NotImplementedError

    def log_prob(self, value):
        """
        Return minus ``roxy.likelihoods.negloglike_mnr`` evaluated on the stored
        attributes. The ``value`` argument is not used.
        """
        neg_loglike = roxy.likelihoods.negloglike_mnr(
            self.xobs, self.yobs, self.xerr, self.yerr, self.f, self.fprime,
            self.sig, self.mu_gauss, self.w_gauss)
        return - neg_loglike
[docs]
class Likelihood_MNR_MV(dist.Distribution):
    """
    Class to be used by ``numpyro`` to evaluate the log-likelihood under
    the assumption of a correlated Gaussian likelihood (i.e. arbitrary
    covariance matrix), with a Gaussian prior on the true x positions.

    Args:
        :xobs (jnp.ndarray): The observed x values
        :yobs (jnp.ndarray): The observed y values
        :Sxx (jnp.ndarray): The xx component of the covariance matrix giving the
            errors on the observed (x, y) values
        :Syy (jnp.ndarray): The yy component of the covariance matrix giving the
            errors on the observed (x, y) values
        :Sxy (jnp.ndarray): The xy component of the covariance matrix giving the
            errors on the observed (x, y) values
        :f (jnp.ndarray): If we are fitting the function f(x), this is f(x)
            evaluated at xobs
        :G (jnp.ndarray): If we are fitting the function f(x), this is
            G_{ij} = df_i/dx_j evaluated at xobs
        :sig (float): The intrinsic scatter
        :mu_gauss (float): The mean of the Gaussian prior on the true x positions
        :w_gauss (float): The standard deviation of the Gaussian prior on the
            true x positions
    """

    def __init__(self, xobs, yobs, Sxx, Syy, Sxy, f, G, sig, mu_gauss, w_gauss):
        # Give the vector inputs a trailing axis so that they have the same
        # rank as the matrix inputs before the shapes are promoted.
        xobs_p = xobs[..., jnp.newaxis]
        yobs_p = yobs[..., jnp.newaxis]
        f_p = f[..., jnp.newaxis]
        (xobs_p, yobs_p, Sxx, Syy, Sxy, f_p, self.G, sig, mu_gauss,
         w_gauss) = promote_shapes(xobs_p, yobs_p, Sxx, Syy, Sxy, f_p, G, sig,
                                   mu_gauss, w_gauss)
        # Everything in front of the last two axes is treated as batch.
        batch_shape = lax.broadcast_shapes(
            *(jnp.shape(arg)[:-2] for arg in (xobs_p, yobs_p, Sxx, Syy, Sxy,
                                              f_p, self.G, sig, mu_gauss,
                                              w_gauss))
        )
        # NOTE(review): xobs_p carries a trailing singleton axis, so this looks
        # like it evaluates to (1,) rather than the number of data points.
        # Kept unchanged; confirm this is intended.
        event_shape = jnp.shape(xobs_p)[-1:]
        self.xobs = xobs_p[..., 0]
        self.yobs = yobs_p[..., 0]
        # Assemble the block matrix [[Sxx, Sxy], [Sxy^T, Syy]] over the last two
        # axes only, so leading batch axes (if any) are left alone. For plain
        # 2-D inputs this is identical to stacking on axis 0 with ``Sxy.T``.
        Sxy_T = jnp.swapaxes(Sxy, -1, -2)
        self.Sigma = jnp.concatenate(
            [jnp.concatenate([Sxx, Sxy], axis=-1),
             jnp.concatenate([Sxy_T, Syy], axis=-1)],
            axis=-2,
        )
        self.f = f_p[..., 0]
        # Drop the trailing axis that shape promotion added to the scalars.
        self.sig = sig[..., 0]
        self.mu_gauss = mu_gauss[..., 0]
        self.w_gauss = w_gauss[..., 0]
        super().__init__(
            batch_shape=batch_shape,
            event_shape=event_shape
        )

    def sample(self, key, sample_shape=()):
        """Sampling is not supported; always raises ``NotImplementedError``."""
        raise NotImplementedError

    def log_prob(self, value):
        """
        Return minus ``roxy.likelihoods.negloglike_mnr_mv`` evaluated on the
        stored attributes. The ``value`` argument is not used.
        """
        return - roxy.likelihoods.negloglike_mnr_mv(self.xobs, self.yobs, self.Sigma,
                                                    self.f, self.G, self.sig,
                                                    self.mu_gauss, self.w_gauss)
[docs]
class Likelihood_prof(dist.Distribution):
    """
    ``numpyro`` distribution used to evaluate the profile likelihood: an
    uncorrelated Gaussian likelihood evaluated at the maximum likelihood
    values of xtrue.

    Args:
        :xobs (jnp.ndarray): The observed x values
        :yobs (jnp.ndarray): The observed y values
        :xerr (jnp.ndarray): The error on the observed x values
        :yerr (jnp.ndarray): The error on the observed y values
        :f (jnp.ndarray): The fitted function f(x) evaluated at xobs
        :fprime (jnp.ndarray): The derivative df/dx evaluated at xobs
        :sig (float): The intrinsic scatter, which is added in quadrature with yerr
        :include_logdet (bool, default=True): Whether to include the
            normalisation term in the likelihood proportional to log(det(S))
    """

    def __init__(self, xobs, yobs, xerr, yerr, f, fprime, sig, include_logdet=True):
        inputs = (xobs, yobs, xerr, yerr, f, fprime, sig)
        # Store rank-aligned copies of every array input on the instance.
        (self.xobs, self.yobs, self.xerr, self.yerr, self.f, self.fprime,
         self.sig) = promote_shapes(*inputs)
        # The batch shape is the broadcast of the shapes of the raw inputs.
        batch_shape = lax.broadcast_shapes(*(jnp.shape(arg) for arg in inputs))
        self.include_logdet = include_logdet
        super().__init__(batch_shape=batch_shape)

    def sample(self, key, sample_shape=()):
        """Sampling is not supported; always raises ``NotImplementedError``."""
        raise NotImplementedError

    def log_prob(self, value):
        """
        Return minus ``roxy.likelihoods.negloglike_prof`` evaluated on the stored
        attributes. The ``value`` argument is not used.
        """
        neg_loglike = roxy.likelihoods.negloglike_prof(
            self.xobs, self.yobs, self.xerr, self.yerr, self.f, self.fprime,
            self.sig, include_logdet=self.include_logdet)
        return - neg_loglike
[docs]
class Likelihood_prof_MV(dist.Distribution):
    """
    Class to be used by ``numpyro`` to evaluate the log-likelihood under
    the assumption of a correlated Gaussian likelihood (i.e. arbitrary covariance
    matrix), evaluated at the maximum likelihood values of xtrue (the profile
    likelihood).

    Args:
        :xobs (jnp.ndarray): The observed x values
        :yobs (jnp.ndarray): The observed y values
        :Sxx (jnp.ndarray): The xx component of the covariance matrix giving the
            errors on the observed (x, y) values
        :Syy (jnp.ndarray): The yy component of the covariance matrix giving the
            errors on the observed (x, y) values
        :Sxy (jnp.ndarray): The xy component of the covariance matrix giving the
            errors on the observed (x, y) values
        :f (jnp.ndarray): If we are fitting the function f(x), this is f(x)
            evaluated at xobs
        :G (jnp.ndarray): If we are fitting the function f(x), this is
            G_{ij} = df_i/dx_j evaluated at xobs
        :sig (float): The intrinsic scatter
        :include_logdet (bool, default=True): Whether to include the
            normalisation term in the likelihood proportional to log(det(S))
    """

    def __init__(self, xobs, yobs, Sxx, Syy, Sxy, f, G, sig,
                 include_logdet=True):
        # Give the vector inputs a trailing axis so that they have the same
        # rank as the matrix inputs before the shapes are promoted.
        xobs_p = xobs[..., jnp.newaxis]
        yobs_p = yobs[..., jnp.newaxis]
        f_p = f[..., jnp.newaxis]
        (xobs_p, yobs_p, Sxx, Syy, Sxy, f_p, self.G,
         sig) = promote_shapes(xobs_p, yobs_p, Sxx, Syy, Sxy, f_p, G, sig)
        # Everything in front of the last two axes is treated as batch.
        batch_shape = lax.broadcast_shapes(
            *(jnp.shape(arg)[:-2] for arg in (xobs_p, yobs_p, Sxx, Syy, Sxy,
                                              f_p, self.G, sig))
        )
        # NOTE(review): xobs_p carries a trailing singleton axis, so this looks
        # like it evaluates to (1,) rather than the number of data points.
        # Kept unchanged; confirm this is intended.
        event_shape = jnp.shape(xobs_p)[-1:]
        self.xobs = xobs_p[..., 0]
        self.yobs = yobs_p[..., 0]
        # Assemble the block matrix [[Sxx, Sxy], [Sxy^T, Syy]] over the last two
        # axes only, so leading batch axes (if any) are left alone. For plain
        # 2-D inputs this is identical to stacking on axis 0 with ``Sxy.T``.
        Sxy_T = jnp.swapaxes(Sxy, -1, -2)
        self.Sigma = jnp.concatenate(
            [jnp.concatenate([Sxx, Sxy], axis=-1),
             jnp.concatenate([Sxy_T, Syy], axis=-1)],
            axis=-2,
        )
        self.f = f_p[..., 0]
        # Drop the trailing axis that shape promotion added to the scalar.
        self.sig = sig[..., 0]
        self.include_logdet = include_logdet
        super().__init__(
            batch_shape=batch_shape,
            event_shape=event_shape
        )

    def sample(self, key, sample_shape=()):
        """Sampling is not supported; always raises ``NotImplementedError``."""
        raise NotImplementedError

    def log_prob(self, value):
        """
        Return minus ``roxy.likelihoods.negloglike_prof_mv`` evaluated on the
        stored attributes. The ``value`` argument is not used.
        """
        return - roxy.likelihoods.negloglike_prof_mv(self.xobs, self.yobs, self.Sigma,
                                                     self.f, self.G, self.sig,
                                                     include_logdet=self.include_logdet)
[docs]
class Likelihood_unif(dist.Distribution):
    """
    ``numpyro`` distribution used to evaluate the log-likelihood for an
    uncorrelated Gaussian likelihood in which the true x values have been
    marginalised over, assuming an infinite uniform prior on them.

    Args:
        :xobs (jnp.ndarray): The observed x values
        :yobs (jnp.ndarray): The observed y values
        :xerr (jnp.ndarray): The error on the observed x values
        :yerr (jnp.ndarray): The error on the observed y values
        :f (jnp.ndarray): The fitted function f(x) evaluated at xobs
        :fprime (jnp.ndarray): The derivative df/dx evaluated at xobs
        :sig (float): The intrinsic scatter, which is added in quadrature with yerr
    """

    def __init__(self, xobs, yobs, xerr, yerr, f, fprime, sig):
        inputs = (xobs, yobs, xerr, yerr, f, fprime, sig)
        # Store rank-aligned copies of every input on the instance.
        (self.xobs, self.yobs, self.xerr, self.yerr, self.f, self.fprime,
         self.sig) = promote_shapes(*inputs)
        # The batch shape is the broadcast of the shapes of the raw inputs.
        batch_shape = lax.broadcast_shapes(*(jnp.shape(arg) for arg in inputs))
        super().__init__(batch_shape=batch_shape)

    def sample(self, key, sample_shape=()):
        """Sampling is not supported; always raises ``NotImplementedError``."""
        raise NotImplementedError

    def log_prob(self, value):
        """
        Return minus ``roxy.likelihoods.negloglike_unif`` evaluated on the stored
        attributes. The ``value`` argument is not used.
        """
        neg_loglike = roxy.likelihoods.negloglike_unif(
            self.xobs, self.yobs, self.xerr, self.yerr, self.f, self.fprime,
            self.sig)
        return - neg_loglike
[docs]
class Likelihood_unif_MV(dist.Distribution):
    """
    Class to be used by ``numpyro`` to evaluate the log-likelihood under
    the assumption of a correlated Gaussian likelihood (i.e. arbitrary covariance
    matrix), where we have marginalised over the true x values, assuming an
    infinite uniform prior on these.

    Args:
        :xobs (jnp.ndarray): The observed x values
        :yobs (jnp.ndarray): The observed y values
        :Sxx (jnp.ndarray): The xx component of the covariance matrix giving the
            errors on the observed (x, y) values
        :Syy (jnp.ndarray): The yy component of the covariance matrix giving the
            errors on the observed (x, y) values
        :Sxy (jnp.ndarray): The xy component of the covariance matrix giving the
            errors on the observed (x, y) values
        :f (jnp.ndarray): If we are fitting the function f(x), this is f(x)
            evaluated at xobs
        :G (jnp.ndarray): If we are fitting the function f(x), this is
            G_{ij} = df_i/dx_j evaluated at xobs
        :sig (float): The intrinsic scatter
    """

    def __init__(self, xobs, yobs, Sxx, Syy, Sxy, f, G, sig):
        # Give the vector inputs a trailing axis so that they have the same
        # rank as the matrix inputs before the shapes are promoted.
        xobs_p = xobs[..., jnp.newaxis]
        yobs_p = yobs[..., jnp.newaxis]
        f_p = f[..., jnp.newaxis]
        (xobs_p, yobs_p, Sxx, Syy, Sxy, f_p, self.G,
         sig) = promote_shapes(xobs_p, yobs_p, Sxx, Syy, Sxy, f_p, G, sig)
        # Everything in front of the last two axes is treated as batch.
        batch_shape = lax.broadcast_shapes(
            *(jnp.shape(arg)[:-2] for arg in (xobs_p, yobs_p, Sxx, Syy, Sxy,
                                              f_p, self.G, sig))
        )
        # NOTE(review): xobs_p carries a trailing singleton axis, so this looks
        # like it evaluates to (1,) rather than the number of data points.
        # Kept unchanged; confirm this is intended.
        event_shape = jnp.shape(xobs_p)[-1:]
        self.xobs = xobs_p[..., 0]
        self.yobs = yobs_p[..., 0]
        # Assemble the block matrix [[Sxx, Sxy], [Sxy^T, Syy]] over the last two
        # axes only, so leading batch axes (if any) are left alone. For plain
        # 2-D inputs this is identical to stacking on axis 0 with ``Sxy.T``.
        Sxy_T = jnp.swapaxes(Sxy, -1, -2)
        self.Sigma = jnp.concatenate(
            [jnp.concatenate([Sxx, Sxy], axis=-1),
             jnp.concatenate([Sxy_T, Syy], axis=-1)],
            axis=-2,
        )
        self.f = f_p[..., 0]
        # Drop the trailing axis that shape promotion added to the scalar.
        self.sig = sig[..., 0]
        super().__init__(
            batch_shape=batch_shape,
            event_shape=event_shape
        )

    def sample(self, key, sample_shape=()):
        """Sampling is not supported; always raises ``NotImplementedError``."""
        raise NotImplementedError

    def log_prob(self, value):
        """
        Return minus ``roxy.likelihoods.negloglike_unif_mv`` evaluated on the
        stored attributes. The ``value`` argument is not used.
        """
        return - roxy.likelihoods.negloglike_unif_mv(self.xobs, self.yobs, self.Sigma,
                                                     self.f, self.G, self.sig)
[docs]
class Likelihood_GMM(dist.Distribution):
    """
    ``numpyro`` distribution used to evaluate the log-likelihood for an
    uncorrelated Gaussian likelihood with a Gaussian Mixture Model prior on
    the true x positions.

    Args:
        :xobs (jnp.ndarray): The observed x values
        :yobs (jnp.ndarray): The observed y values
        :xerr (jnp.ndarray): The error on the observed x values
        :yerr (jnp.ndarray): The error on the observed y values
        :f (jnp.ndarray): The fitted function f(x) evaluated at xobs
        :fprime (jnp.ndarray): The derivative df/dx evaluated at xobs
        :sig (float): The intrinsic scatter, which is added in quadrature with yerr
        :all_mu_gauss (jnp.ndarray): The means of the Gaussians in the GMM prior
            on the true x positions
        :all_w_gauss (jnp.ndarray): The standard deviations of the Gaussians in
            the GMM prior on the true x positions
        :all_weights (jnp.ndarray): The weights of the Gaussians in the GMM prior
            on the true x positions
    """

    def __init__(self, xobs, yobs, xerr, yerr, f, fprime, sig, all_mu_gauss,
                 all_w_gauss, all_weights):
        data_inputs = (xobs, yobs, xerr, yerr, f, fprime, sig)
        # The data-like inputs and the mixture parameters are rank-aligned as
        # two separate groups.
        (self.xobs, self.yobs, self.xerr, self.yerr, self.f, self.fprime,
         self.sig) = promote_shapes(*data_inputs)
        (self.all_mu_gauss, self.all_w_gauss,
         self.all_weights) = promote_shapes(all_mu_gauss, all_w_gauss, all_weights)
        # Only the data-like inputs contribute to the batch shape; the mixture
        # parameters are treated as having an empty shape here.
        batch_shape = lax.broadcast_shapes(
            *(jnp.shape(arg) for arg in data_inputs)
        )
        super().__init__(batch_shape=batch_shape)

    def sample(self, key, sample_shape=()):
        """Sampling is not supported; always raises ``NotImplementedError``."""
        raise NotImplementedError

    def log_prob(self, value):
        """
        Return minus ``roxy.likelihoods.negloglike_gmm`` evaluated on the stored
        attributes. The ``value`` argument is not used.
        """
        neg_loglike = roxy.likelihoods.negloglike_gmm(
            self.xobs, self.yobs, self.xerr, self.yerr, self.f, self.fprime,
            self.sig, self.all_mu_gauss, self.all_w_gauss, self.all_weights)
        return - neg_loglike
[docs]
def samples_to_array(samples):
    """
    Converts a dictionary of samples returned by ``numpyro`` to a list of names
    and an array of samples.

    One-dimensional entries contribute a single column labelled with the key.
    Any other entry contributes one column per index along its second axis,
    labelled ``<key>_<index>``. Entries with more than two dimensions are not
    supported.

    Args:
        :samples (dict): The MCMC samples, where the keys are the parameter names
            and values are ndarrays of the samples

    Returns:
        :labels (np.ndarray): The names of the sampled variables
        :all_samples (np.ndarray): The sampled values for these variables.
            Shape = (number of samples, number of parameters). An empty
            ``samples`` dictionary gives an empty ``labels`` array and an
            array of shape (0, 0).
    """
    if not samples:
        # Nothing to flatten; avoid indexing into an empty key list.
        return np.array([], dtype=str), np.empty((0, 0))

    # Get labels and number of columns for each parameter
    labels = []
    widths = []
    for key, vals in samples.items():
        if len(vals.shape) == 1:
            labels.append(key)
            widths.append(1)
        else:
            ncol = vals.shape[1]
            labels.extend(f'{key}_{n}' for n in range(ncol))
            widths.append(ncol)

    # The number of rows is taken from the first entry
    nsamp = next(iter(samples.values())).shape[0]
    all_samples = np.empty((nsamp, len(labels)))

    # Flatten the samples so the array is (# samples, # parameters)
    start = 0
    for vals, width in zip(samples.values(), widths):
        if len(vals.shape) == 1:
            all_samples[:, start] = vals
        else:
            all_samples[:, start:start + width] = vals
        start += width

    return np.array(labels), all_samples
[docs]
def compute_bias(samples, truths, verbose=True):
    """
    Computes the bias between MCMC samples and the true values of the parameters.
    For the intrinsic scatter, a truncated normal distribution is first fitted to
    the parameters since this parameter can only take positive values.

    Args:
        :samples (dict): The MCMC samples, where the keys are the parameter names
            and values are ndarrays of the samples
        :truths (dict): The true values of the parameters. The keys should be a
            subset of the keys of samples
        :verbose (bool, default=True): Whether to print biases or not

    Returns:
        :biases (dict): The biases in each of the parameters (units=number of sigmas)
    """
    # The module only imports scipy.optimize, so import the special-function
    # submodule explicitly rather than relying on it being loaded implicitly.
    from scipy.special import erf

    biases = {}
    for k, v in truths.items():
        x = np.asarray(samples[k])
        if k == 'sig':
            # Fit these samples to a Gaussian truncated to positive values
            def negloglike(pars, x=x):
                mu, sig = pars
                # Per-sample log-density of the truncated Gaussian
                loglike = (
                    np.log(2) - 0.5 * np.log(2 * np.pi * sig ** 2)
                    - np.log(1 + erf(mu / np.sqrt(2) / sig))
                    - (x - mu) ** 2 / 2 / sig ** 2
                )
                return - np.sum(loglike)

            initial = [np.mean(x), np.std(x)]
            bounds = [(None, None), (0, None)]  # sigma must be >= 0
            res = scipy.optimize.minimize(negloglike, initial, bounds=bounds,
                                          method='nelder-mead')
            mu, sig = (float(p) for p in res.x)
            if verbose:
                print('Truncated normal fit for sig:', mu, sig)
        else:
            mu = float(np.mean(x))
            sig = float(np.std(x))
        biases[k] = (mu - v) / sig

    if verbose:
        print('\nComputed biases (units=sigma):')
        for k, b in biases.items():
            print(f'{k}:\t{b}')

    return biases
[docs]
class OrderedNormal(dist.Distribution):
    """
    Normal distribution with location ``loc`` and scale ``scale`` whose
    ``sample`` method returns the draws sorted along the last axis. The density
    and CDF methods are the element-wise normal expressions.

    Args:
        :loc (float or jnp.ndarray, default=0.0): The location of the normal
        :scale (float or jnp.ndarray, default=1.0): The scale of the normal
        :validate_args (bool, optional): Forwarded to the ``numpyro`` base class
    """
    arg_constraints = {"loc": dist.constraints.real,
                       "scale": dist.constraints.positive}
    support = dist.constraints.ordered_vector
    reparametrized_params = ["loc", "scale"]

    def __init__(self, loc=0.0, scale=1.0, *, validate_args=None):
        self.loc, self.scale = promote_shapes(loc, scale)
        shape = lax.broadcast_shapes(jnp.shape(loc), jnp.shape(scale))
        super().__init__(batch_shape=shape, validate_args=validate_args)

    def sample(self, key, sample_shape=()):
        """Draw normal variates with the given PRNG ``key`` and return them sorted."""
        assert is_prng_key(key)
        draw_shape = sample_shape + self.batch_shape + self.event_shape
        standard_draws = jax.random.normal(key, shape=draw_shape)
        shifted = self.loc + standard_draws * self.scale
        return jnp.sort(shifted)

    @validate_sample
    def log_prob(self, value):
        """Return the element-wise normal log-density of ``value``."""
        z = (value - self.loc) / self.scale
        log_norm = jnp.log(jnp.sqrt(2 * jnp.pi) * self.scale)
        return -0.5 * z**2 - log_norm

    def cdf(self, value):
        """Return the normal CDF evaluated at ``value``."""
        return ndtr((value - self.loc) / self.scale)

    def log_cdf(self, value):
        """Return the log of the normal CDF evaluated at ``value``."""
        return jax_norm.logcdf(value, loc=self.loc, scale=self.scale)

    def icdf(self, q):
        """Return the normal quantile function evaluated at ``q``."""
        return self.loc + self.scale * ndtri(q)

    @property
    def mean(self):
        """``loc`` broadcast to the batch shape."""
        return jnp.broadcast_to(self.loc, self.batch_shape)

    @property
    def variance(self):
        """``scale`` squared, broadcast to the batch shape."""
        return jnp.broadcast_to(self.scale**2, self.batch_shape)