import corner
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import rcParams
from getdist import plots, MCSamples
import arviz as az
from fgivenx import plot_contours
import roxy.mcmc
rcParams['text.usetex'] = False
rcParams.update({'font.size': 14})
[docs]
def triangle_plot(samples, labels=None, to_plot='all', module='corner',
truths=None, param_prior=None, savename=None, show=True):
"""
Plot the 1D and 2D posterior distributions of the parameters in a triangle plot.
Args:
:samples (dict): The MCMC samples, where the keys are the parameter names and
values are ndarrays of the samples
:labels (dict, default=None): Dictionary of parameter labels ot use in the plot.
If None, then use the names given as keys in samples.
:to_plot (list, default='all'): If 'all', then use all parameters. If a list,
then only use the parameters given in that list
:module (str, default='corner'): Which module to use to make the triangle plot
('corner' or 'getdist' currently available)
:truths (dict, default=None): If not None, use this to specify the true
values of the parameters to plot.
:param_prior (dict, default=None): If not None and using 'getdist', use this to
specify the range of the varibales to prevent undesirable smoothing effects.
:savename (str, default=None): If not None, save the figure to the file given
by this argument.
:show (bool, default=True): If True, display the figure with plt.show()
"""
names, all_samples = roxy.mcmc.samples_to_array(samples)
if to_plot != 'all':
idx = [np.squeeze(np.where(names == p)) for p in to_plot]
names = names[idx]
all_samples = all_samples[:, idx]
if labels is None:
labs = list(names)
for p, label in zip(['mu_gauss', 'w_gauss', 'sig'],
[r'\mu_{\rm gauss}', r'w_{\rm gauss}', r'\sigma_{\rm int}']):
if (p in names) and ((p in to_plot) or (to_plot == 'all')):
i = np.squeeze(np.where(names == p))
labs[i] = label
# GMM parameters
if 'weights_0' in names:
# Extract number of Gaussians
ngauss = len([n for n in names if n.startswith('weights')])
for i in range(ngauss):
for p, label in zip([f'mu_gauss_{i}', f'w_gauss_{i}', f'weights_{i}'],
[r'\mu_{%i}' % i, r'w_{%i}' % i, r'\nu_{%i}' % i]):
j = np.squeeze(np.where(names == p))
labs[j] = label
# Kelly prior parameters
if 'hyper_mu' in names:
for p, label in zip(['hyper_mu', 'hyper_w2', 'hyper_u2'],
[r'\mu_\star', r'w_\star^2', r'u_\star^2']):
j = np.squeeze(np.where(names == p))
labs[j] = label
else:
labs = [labels[n] for n in names]
if module == 'corner':
labs = ['$' + label + '$' for label in labs]
fig, _ = plt.subplots(len(labs), len(labs), figsize=(8, 8))
markers = None
if truths is not None:
markers = [truths[n] if n in truths else None for n in names]
corner.corner(all_samples, labels=labs, fig=fig, truths=markers)
elif module == 'getdist':
if param_prior is None:
ranges = {}
else:
ranges = param_prior
ranges['w_gauss'] = [0, None]
if ('sig' in ranges and
((param_prior['sig'][0] is None) or (param_prior['sig'][1] is None))):
ranges['sig'] = [0, param_prior['sig'][1]]
if 'weights_0' in names:
for i in range(ngauss):
ranges[f'w_gauss_{i}'] = [0, None]
ranges[f'weights_{i}'] = [0, 1]
if 'hyper_mu' in names:
ranges['hyper_w2'] = [0, None]
ranges['hyper_u2'] = [0, None]
samps = MCSamples(
samples=all_samples,
names=names,
labels=labs,
ranges=ranges
)
g = plots.get_subplot_plotter(width_inch=8)
g.triangle_plot(samps, filled=True, markers=truths)
else:
raise NotImplementedError
plt.gcf().align_labels()
if savename is not None:
plt.savefig(savename, transparent=False)
if show:
plt.show()
plt.clf()
plt.close(plt.gcf())
[docs]
def trace_plot(samples, to_plot='all', truths=None, savename=None, show=True):
"""
Plot the trace of the parameter values as a function of MCMC step
Args:
:samples (dict): The MCMC samples, where the keys are the parameter names and
values are ndarrays of the samples
:to_plot (list, default='all'): If 'all', then use all parameters. If a list,
then only use the parameters given in that list
:truths (dict, default=None): If not None, use this to specify the true
values of the parameters to plot.
:savename (str, default=None): If not None, save the figure to the file given by
this argument.
:show (bool, default=True): If True, display the figure with plt.show()
"""
# Check for GMM
if 'weights' in samples.keys():
new_samples = samples.copy()
for k in ['mu_gauss', 'w_gauss', 'weights']:
new_samples.pop(k)
v = samples[k]
for i in range(v.shape[1]):
new_samples[f'{k}_{i}'] = v[:, i]
npar = len(new_samples.keys())
res = az.from_dict(new_samples)
else:
res = az.from_dict(samples)
npar = len(samples.keys())
if to_plot != 'all':
npar = len(to_plot)
figsize = (12, min(2 * npar, 10))
lines = {}
if truths is not None:
lines = [ (k, {}, [truths[k]]) for k in list(res['posterior'].data_vars) if k in truths]
if to_plot == 'all':
az.plot_trace(res, compact=True, figsize=figsize, lines=lines)
else:
az.plot_trace(res, compact=True, var_names=to_plot, figsize=figsize, lines=lines)
plt.tight_layout()
if savename is not None:
plt.savefig(savename, transparent=False)
if show:
plt.show()
plt.clf()
plt.close(plt.gcf())
[docs]
def posterior_predictive_plot(reg, samples, xobs, yobs, xerr, yerr, y_is_detected=[],
savename=None, show=True, xlabel=r'$x$', ylabel=r'$y$',
errorbar_kwargs={'fmt': '.', 'markersize': 1,
'zorder': 10, 'capsize': 1,
'elinewidth': 0.5, 'color': 'k',
'alpha': 1},
fgivenx_kwargs={}, xscale='linear',
yscale='linear', xlim=None, ylim=None):
"""
Make the posterior predictive plot showing the 1, 2 and 3 sigma predictions
of the function given the inferred parameters and plot the observed points on
the same plot.
Args:
:reg (roxy.regressor.RoxyRegressor): The regressor object used for the inference
:samples (dict): The MCMC samples, where the keys are the parameter names and
values are ndarrays of the samples
:xobs (jnp.ndarray): The observed x values
:yobs (jnp.ndarray): The observed y values
:xerr (jnp.ndarray): The error on the observed x values
:yerr (jnp.ndarray): The error on the observed y values
:y_is_detected (array-like, default=[]): Boolean array of the same length as
yobs, where True indicates a detected point and False indicates an upper limit.
:savename (str, default=None): If not None, save the figure to the file given
by this argument.
:show (bool, default=True): If True, display the figure with plt.show()
:xlabel (str, default='$x$'): The label to use for the x axis
:ylabel (str, default='$x$'): The label to use for the y axis
:errorbar_kwargs (dict): Dictionary of kwargs to pass to plt.errorbar
:fgivenx_kwargs (dict): Dictionary of kwargs to pass to fgivenx.plot_contours
:xscale (str, default='linear'): Scale to use for x axis ('linear' or 'log')
:yscale (str, default='linear'): Scale to use for y axis ('linear' or 'log')
:xlim (tuple, default=None): If not None, set the x limits to this value
:ylim (tuple, default=None): If not None, set the y limits to this value
Returns:
:fig (matplotlib.figure.Figure): The figure containing the posterior predictive
plot
"""
names, all_samples = roxy.mcmc.samples_to_array(samples)
pidx = reg.get_param_index(names, verbose=False)
def f(x, theta):
t = reg.param_default
t = t.at[pidx].set(theta[:len(pidx)])
return reg.value(x, t)
print('\nMaking posterior predictive plot')
fig, ax = plt.subplots(1, 1)
if len(y_is_detected) > 0:
ax.errorbar(xobs[y_is_detected], yobs[y_is_detected],
xerr=xerr, yerr=yerr, **errorbar_kwargs)
ax.errorbar(xobs[~y_is_detected], yobs[~y_is_detected],
xerr=xerr, yerr=yerr, uplims=True, **errorbar_kwargs)
else:
ax.errorbar(xobs, yobs, xerr=xerr, yerr=yerr, **errorbar_kwargs)
ax.set_xscale(xscale)
ax.set_yscale(yscale)
if xlim is not None:
ax.set_xlim(xlim)
if ylim is not None:
ax.set_ylim(ylim)
xmin, xmax = ax.get_xlim()
if xscale == 'log':
x = np.logspace(np.log10(xmin), np.log10(xmax), 200)
else:
x = np.linspace(xmin, xmax, 200)
cbar = plot_contours(f, x, all_samples, ax, **fgivenx_kwargs)
cbar = plt.colorbar(cbar, ticks=[0, 1, 2, 3])
cbar.set_ticklabels(['', r'$1\sigma$', r'$2\sigma$', r'$3\sigma$'])
ax.set_xlabel(xlabel)
ax.set_ylabel(ylabel)
ax.set_xlim(xmin, xmax)
fig.tight_layout()
if savename is not None:
plt.savefig(savename, transparent=False)
if show:
plt.show()
return fig