Skip to content

experimental: support for power spectrum data #165

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions src/elisa/infer/fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,15 @@ def get_stat(d: FixedData) -> Statistic:
def check_stat(d: FixedData, s: Statistic):
"""Check if data type and likelihood are matched."""
name = d.name

if s == 'whittle':
if d.spec_poisson:
raise ValueError(
f'{name} data has Poisson uncertainties, '
'and using Whittle likelihood (whittle) is invalid'
)
return

if not d.spec_poisson and s != 'chi2':
raise ValueError(
f'{name} data has Gaussian uncertainties, '
Expand Down
20 changes: 16 additions & 4 deletions src/elisa/infer/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
cstat,
pgstat,
pstat,
whittle,
wstat,
)
from elisa.util.config import get_parallel_number
Expand Down Expand Up @@ -212,7 +213,7 @@ def get_counts_data(counts: dict[str, JAXArray]) -> dict[str, JAXArray]:
obs_counts = {
f'{k}_Non': (
v.net_counts
if stat[k] in _STATISTIC_SPEC_NORMAL
if stat[k] in _STATISTIC_SPEC_NORMAL or stat[k] == 'whittle'
else v.spec_counts
)
for k, v in data.items()
Expand All @@ -225,7 +226,9 @@ def get_counts_data(counts: dict[str, JAXArray]) -> dict[str, JAXArray]:
obs_data = get_counts_data(obs_counts)

# ======================== count data simulator ===========================
def simulator_factory(data_dist: Literal['norm', 'poisson'], *dist_args):
def simulator_factory(
data_dist: Literal['norm', 'poisson', 'exp'], *dist_args
):
"""Factory to create data simulator."""

def simulator(
Expand All @@ -244,18 +247,26 @@ def simulator(
return rng.normal(model_values, *dist_args, shape)
elif data_dist == 'poisson':
return rng.poisson(model_values, shape)
elif data_dist == 'exp':
return rng.exponential(model_values, shape)
else:
raise NotImplementedError(f'{data_dist = }')

return simulator

simulators = {}
sampling_dist: dict[str, tuple[Literal['norm', 'poisson'], tuple]] = {}
sampling_dist: dict[
str,
tuple[Literal['norm', 'poisson', 'exp'], tuple],
] = {}
for k, s in stat.items():
d = data[k]

name = f'{k}_Non'
if s in _STATISTIC_SPEC_NORMAL:
if s == 'whittle':
simulators[name] = simulator_factory('exp')
sampling_dist[name] = ('exp', ())
elif s in _STATISTIC_SPEC_NORMAL:
simulators[name] = simulator_factory('norm', d.spec_errors)
sampling_dist[name] = ('norm', (d.spec_errors,))
else:
Expand Down Expand Up @@ -323,6 +334,7 @@ def simulate(
'pstat': pstat,
'wstat': wstat,
'pgstat': pgstat,
'whittle': whittle,
}
likelihood: dict[str, Callable[[JAXArray], None]] = {
k: likelihood_wrapper[stat[k]](v, model[k].eval)
Expand Down
49 changes: 47 additions & 2 deletions src/elisa/infer/likelihood.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from jax import lax
from jax.experimental.sparse import BCSR
from jax.scipy.special import xlogy
from numpyro.distributions import Normal, Poisson
from numpyro.distributions import Exponential, Normal, Poisson
from numpyro.distributions.util import validate_sample

if TYPE_CHECKING:
Expand All @@ -30,7 +30,7 @@
# for source estimation, which is probably due to the choice of conjugate
# prior of Poisson background data.
# 'lstat' will be included here with a proper prior at some point.
Statistic = Literal['chi2', 'cstat', 'pstat', 'pgstat', 'wstat']
Statistic = Literal['chi2', 'cstat', 'pstat', 'pgstat', 'wstat', 'whittle']

_STATISTIC_OPTIONS: frozenset[str] = frozenset(get_args(Statistic))
_STATISTIC_SPEC_NORMAL: frozenset[str] = frozenset({'chi2'})
Expand Down Expand Up @@ -174,6 +174,13 @@ def log_prob(self, value):
return jnp.clip(logp - gof, a_max=0.0)


class BetterExponential(Exponential):
@validate_sample
def log_prob(self, value):
gof = -jnp.log(value) - 1.0
return jnp.log(self.rate) - self.rate * value - gof


def _get_resp_matrix(data: FixedData) -> JAXArray | BCSR:
if data.response_sparse:
return BCSR.from_scipy_sparse(data.sparse_matrix.T)
Expand Down Expand Up @@ -448,3 +455,41 @@ def likelihood(params: ParamNameValMapping, predictive: bool = False):
)

return likelihood


def whittle(
data: FixedData,
model: ModelCompiledFn,
) -> Callable[[ParamNameValMapping, bool], None]:
"""Whittle likelihood for power spectrum (periodogram)."""
name = str(data.name)
power = jnp.array(data.net_counts, float)
freq_bins = jnp.array(data.photon_egrid, float)
df = jnp.diff(freq_bins)

def likelihood(
params: ParamNameValMapping,
predictive: bool = False,
) -> None:
"""Whittle likelihood defined via numpyro primitives."""
pmodel = model(freq_bins, params)
numpyro.deterministic(name, pmodel / df)
numpyro.deterministic(f'{name}_Non_model', pmodel)
pdata = numpyro.primitives.mutable(f'{name}_Non_data', power)

with numpyro.plate(f'{name}_plate', len(power)):
dist_on = BetterExponential(1.0 / pmodel)
numpyro.sample(
name=f'{name}_Non',
fn=dist_on,
obs=None if predictive else pdata,
)

# record log likelihood into chains to avoid re-computation
if not predictive:
loglike_on = numpyro.deterministic(
name=f'{name}_Non_loglike', value=dist_on.log_prob(pdata)
)
numpyro.deterministic(name=f'{name}_loglike', value=loglike_on)

return likelihood
11 changes: 10 additions & 1 deletion src/elisa/plot/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,6 +448,10 @@ def pit(self) -> tuple[Array, Array]:
on_data = self.spec_counts
on_model = self.model('on', 'mle')

if stat == 'whittle':
pit = stats.expon.cdf(on_data, scale=on_model)
return pit, pit

if stat in _STATISTIC_SPEC_NORMAL: # chi2
pit = stats.norm.cdf((on_data - on_model) / self.net_errors)
return pit, pit
Expand Down Expand Up @@ -575,7 +579,7 @@ def _pearson_residuals(
stat = self.statistic

if rtype == 'mle':
if stat in _STATISTIC_SPEC_NORMAL:
if stat in _STATISTIC_SPEC_NORMAL or stat == 'whittle':
on_data = self.net_counts
else:
on_data = self.spec_counts
Expand All @@ -584,6 +588,8 @@ def _pearson_residuals(

if stat in _STATISTIC_SPEC_NORMAL:
std = self.net_errors
elif stat == 'whittle':
std = np.sqrt(self.model('on', rtype))
else:
std = None

Expand Down Expand Up @@ -654,6 +660,7 @@ def quantile_residuals_mle(
lower = np.full(r.shape, False)
lower[lower_mask] = True

assert np.isfinite(r).all()
return r, lower, upper


Expand Down Expand Up @@ -967,6 +974,8 @@ def _pearson_residuals(

if stat in _STATISTIC_SPEC_NORMAL:
std = self.net_errors
elif stat == 'whittle':
std = np.sqrt(on_model)
else:
std = None

Expand Down