Skip to content

Commit bd70476

Browse files
authored
Yet even more docstrings (#445)
* Improve docstrings and type annotations * Fix bug in NumberCounts property computation
1 parent a6c217a commit bd70476

File tree

11 files changed

+294
-89
lines changed

11 files changed

+294
-89
lines changed

firecrown/likelihood/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
"""Classes used to represent likelihoods and functions to support them.
1+
"""Classes used to represent likelihoods, and functions to support them.
22
33
Subpackages contain specific likelihood implementations, e.g., Gaussian and Student-t.
44
The submodule :mod:`firecrown.likelihood.likelihood` contain the abstract base class for

firecrown/likelihood/binned_cluster_number_counts.py

+28-15
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,4 @@
1-
"""This module holds classes needed to predict the binned cluster number counts.
2-
3-
The binned cluster number counts statistic predicts the number of galaxy
4-
clusters within a single redshift and mass bin.
5-
"""
1+
"""Binned cluster number counts statistic support."""
62

73
from __future__ import annotations
84

@@ -26,11 +22,7 @@
2622

2723

2824
class BinnedClusterNumberCounts(Statistic):
29-
"""The Binned Cluster Number Counts statistic.
30-
31-
This class will make a prediction for the number of clusters in a z, mass bin
32-
and compare that prediction to the data provided in the sacc file.
33-
"""
25+
"""A statistic representing the number of clusters in a z, mass bin."""
3426

3527
def __init__(
3628
self,
@@ -39,6 +31,13 @@ def __init__(
3931
cluster_recipe: ClusterRecipe,
4032
systematics: None | list[SourceSystematic] = None,
4133
):
34+
"""Initialize this statistic.
35+
36+
:param cluster_properties: The cluster observables to use.
37+
:param survey_name: The name of the survey to use.
38+
:param cluster_recipe: The cluster recipe to use.
39+
:param systematics: The systematics to apply to this statistic.
40+
"""
4241
super().__init__()
4342
self.systematics = systematics or []
4443
self.theory_vector: None | TheoryVector = None
@@ -50,7 +49,10 @@ def __init__(
5049
self.bins: list[SaccBin] = []
5150

5251
def read(self, sacc_data: sacc.Sacc) -> None:
53-
"""Read the data for this statistic and mark it as ready for use."""
52+
"""Read the data for this statistic and mark it as ready for use.
53+
54+
:param sacc_data: The data in the sacc format.
55+
"""
5456
# Build the data vector and indices needed for the likelihood
5557
if self.cluster_properties == ClusterProperty.NONE:
5658
raise ValueError("You must specify at least one cluster property.")
@@ -77,12 +79,19 @@ def read(self, sacc_data: sacc.Sacc) -> None:
7779
super().read(sacc_data)
7880

7981
def get_data_vector(self) -> DataVector:
80-
"""Gets the statistic data vector."""
82+
"""Gets the statistic data vector.
83+
84+
:return: The statistic data vector.
85+
"""
8186
assert self.data_vector is not None
8287
return self.data_vector
8388

8489
def _compute_theory_vector(self, tools: ModelingTools) -> TheoryVector:
85-
"""Compute a statistic from sources, concrete implementation."""
90+
"""Compute a statistic from sources, concrete implementation.
91+
92+
:param tools: The modeling tools used to compute the statistic.
93+
:return: The computed statistic.
94+
"""
8695
assert tools.cluster_abundance is not None
8796

8897
theory_vector_list: list[float] = []
@@ -116,6 +125,9 @@ def get_binned_cluster_property(
116125
Using the data from the sacc file, this function evaluates the likelihood for
117126
a single point of the parameter space, and returns the predicted mean mass of
118127
the clusters in each bin.
128+
129+
:param cluster_counts: The number of clusters in each bin.
130+
:param cluster_properties: The cluster observables to use.
119131
"""
120132
assert tools.cluster_abundance is not None
121133

@@ -124,8 +136,6 @@ def get_binned_cluster_property(
124136
total_observable = self.cluster_recipe.evaluate_theory_prediction(
125137
tools.cluster_abundance, this_bin, self.sky_area, cluster_properties
126138
)
127-
cluster_counts.append(counts)
128-
129139
mean_observable = total_observable / counts
130140
mean_values.append(mean_observable)
131141

@@ -137,6 +147,9 @@ def get_binned_cluster_counts(self, tools: ModelingTools) -> list[float]:
137147
Using the data from the sacc file, this function evaluates the likelihood for
138148
a single point of the parameter space, and returns the predicted number of
139149
clusters in each bin.
150+
151+
:param tools: The modeling tools used to compute the statistic.
152+
:return: The number of clusters in each bin.
140153
"""
141154
assert tools.cluster_abundance is not None
142155

Original file line numberDiff line numberDiff line change
@@ -1 +1,3 @@
1+
"""Backward compatibility support for deprecated directory structure."""
2+
13
# flake8: noqa
Original file line numberDiff line numberDiff line change
@@ -1 +1,3 @@
1+
"""Backward compatibility support for deprecated directory structure."""
2+
13
# flake8: noqa

firecrown/likelihood/gaussfamily.py

+73-12
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
"""Support for the family of Gaussian likelihood."""
1+
"""Support for the family of Gaussian likelihoods."""
22

33
from __future__ import annotations
44

@@ -30,7 +30,12 @@
3030

3131

3232
class State(Enum):
33-
"""The states used in GaussFamily."""
33+
"""The states used in GaussFamily.
34+
35+
GaussFamily and all subclasses enforce a statemachine behavior based on
36+
these states to ensure that the necessary initialization and setup is done
37+
in the correct order.
38+
"""
3439

3540
INITIALIZED = 1
3641
READY = 2
@@ -62,6 +67,12 @@ def enforce_states(
6267
If terminal is None the state of the object is not modified.
6368
If terminal is not None and the call to the wrapped method returns
6469
normally the state of the object is set to terminal.
70+
71+
:param initial: The initial states allowable for the wrapped method
72+
:param terminal: The terminal state ensured for the wrapped method. None
73+
indicates no state change happens.
74+
:param failure_message: The failure message for the AssertionError raised
75+
:return: The wrapped method
6576
"""
6677
initials: list[State]
6778
if isinstance(initial, list):
@@ -74,6 +85,9 @@ def decorator_enforce_states(func: Callable[P, T]) -> Callable[P, T]:
7485
7586
This closure is what actually contains the values of initials, terminal, and
7687
failure_message.
88+
89+
:param func: The method to be wrapped
90+
:return: The wrapped method
7791
"""
7892

7993
@wraps(func)
@@ -132,8 +146,11 @@ class GaussFamily(Likelihood):
132146
def __init__(
133147
self,
134148
statistics: Sequence[Statistic],
135-
):
136-
"""Initialize the base class parts of a GaussFamily object."""
149+
) -> None:
150+
"""Initialize the base class parts of a GaussFamily object.
151+
152+
:param statistics: A list of statistics to be include in chisquared calculations
153+
"""
137154
super().__init__()
138155
self.state: State = State.INITIALIZED
139156
if len(statistics) == 0:
@@ -160,7 +177,12 @@ def __init__(
160177
def create_ready(
161178
cls, statistics: Sequence[Statistic], covariance: npt.NDArray[np.float64]
162179
) -> GaussFamily:
163-
"""Create a GaussFamily object in the READY state."""
180+
"""Create a GaussFamily object in the READY state.
181+
182+
:param statistics: A list of statistics to be include in chisquared calculations
183+
:param covariance: The covariance matrix of the statistics
184+
:return: A ready GaussFamily object
185+
"""
164186
obj = cls(statistics)
165187
obj._set_covariance(covariance)
166188
obj.state = State.READY
@@ -178,6 +200,8 @@ def _update(self, _: ParamsMap) -> None:
178200
for its own reasons must be sure to do what this does: check the state
179201
at the start of the method, and change the state at the end of the
180202
method.
203+
204+
:param _: a ParamsMap object, not used
181205
"""
182206

183207
@enforce_states(
@@ -201,7 +225,10 @@ def _reset(self) -> None:
201225
failure_message="read() must only be called once",
202226
)
203227
def read(self, sacc_data: sacc.Sacc) -> None:
204-
"""Read the covariance matrix for this likelihood from the SACC file."""
228+
"""Read the covariance matrix for this likelihood from the SACC file.
229+
230+
:param sacc_data: The SACC data object to be read
231+
"""
205232
if sacc_data.covariance is None:
206233
msg = (
207234
f"The {type(self).__name__} likelihood requires a covariance, "
@@ -216,11 +243,13 @@ def read(self, sacc_data: sacc.Sacc) -> None:
216243

217244
self._set_covariance(covariance)
218245

219-
def _set_covariance(self, covariance):
246+
def _set_covariance(self, covariance: npt.NDArray[np.float64]) -> None:
220247
"""Set the covariance matrix.
221248
222249
This method is used to set the covariance matrix and perform the
223250
necessary calculations to prepare the likelihood for computation.
251+
252+
:param covariance: The covariance matrix for this likelihood
224253
"""
225254
indices_list = []
226255
data_vector_list = []
@@ -276,6 +305,7 @@ def get_cov(
276305
:param statistic: The statistic for which the sub-covariance matrix
277306
should be returned. If not specified, return the covariance of all
278307
statistics.
308+
:return: The covariance matrix (or portion thereof)
279309
"""
280310
assert self.cov is not None
281311
if statistic is None:
@@ -301,7 +331,10 @@ def get_cov(
301331
failure_message="read() must be called before get_data_vector()",
302332
)
303333
def get_data_vector(self) -> npt.NDArray[np.float64]:
304-
"""Get the data vector from all statistics in the right order."""
334+
"""Get the data vector from all statistics in the right order.
335+
336+
:return: The data vector
337+
"""
305338
assert self.data_vector is not None
306339
return self.data_vector
307340

@@ -315,6 +348,7 @@ def compute_theory_vector(self, tools: ModelingTools) -> npt.NDArray[np.float64]
315348
"""Computes the theory vector using the current instance of pyccl.Cosmology.
316349
317350
:param tools: Current ModelingTools object
351+
:return: The computed theory vector
318352
"""
319353
theory_vector_list: list[npt.NDArray[np.float64]] = [
320354
stat.compute_theory_vector(tools) for stat in self.statistics
@@ -329,7 +363,10 @@ def compute_theory_vector(self, tools: ModelingTools) -> npt.NDArray[np.float64]
329363
"get_theory_vector()",
330364
)
331365
def get_theory_vector(self) -> npt.NDArray[np.float64]:
332-
"""Get the theory vector from all statistics in the right order."""
366+
"""Get the already-computed theory vector from all statistics.
367+
368+
:return: The theory vector, with all statistics in the right order
369+
"""
333370
assert (
334371
self.theory_vector is not None
335372
), "theory_vector is None after compute_theory_vector() has been called"
@@ -343,7 +380,14 @@ def get_theory_vector(self) -> npt.NDArray[np.float64]:
343380
def compute(
344381
self, tools: ModelingTools
345382
) -> tuple[npt.NDArray[np.float64], npt.NDArray[np.float64]]:
346-
"""Calculate and return both the data and theory vectors."""
383+
"""Calculate and return both the data and theory vectors.
384+
385+
This method is dprecated and will be removed in a future version of Firecrown.
386+
387+
:param tools: the ModelingTools to be used in the calculation of the
388+
theory vector
389+
:return: a tuple containing the data vector and the theory vector
390+
"""
347391
warnings.warn(
348392
"The use of the `compute` method on Statistic is deprecated."
349393
"The Statistic objects should implement `get_data` and "
@@ -359,7 +403,12 @@ def compute(
359403
failure_message="update() must be called before compute_chisq()",
360404
)
361405
def compute_chisq(self, tools: ModelingTools) -> float:
362-
"""Calculate and return the chi-squared for the given cosmology."""
406+
"""Calculate and return the chi-squared for the given cosmology.
407+
408+
:param tools: the ModelingTools to be used in the calculation of the
409+
theory vector
410+
:return: the chi-squared
411+
"""
363412
theory_vector: npt.NDArray[np.float64]
364413
data_vector: npt.NDArray[np.float64]
365414
residuals: npt.NDArray[np.float64]
@@ -386,6 +435,10 @@ def get_sacc_indices(
386435
"""Get the SACC indices of the statistic or list of statistics.
387436
388437
If no statistic is given, get the indices of all statistics of the likelihood.
438+
439+
:param statistics: The statistic or list of statistics for which the
440+
SACC indices are desired
441+
:return: The SACC indices
389442
"""
390443
if statistic is None:
391444
statistic = [stat.statistic for stat in self.statistics]
@@ -409,7 +462,15 @@ def get_sacc_indices(
409462
def make_realization(
410463
self, sacc_data: sacc.Sacc, add_noise: bool = True, strict: bool = True
411464
) -> sacc.Sacc:
412-
"""Create a new realization of the model."""
465+
"""Create a new realization of the model.
466+
467+
:param sacc_data: The SACC data object containing the covariance matrix
468+
to be read
469+
:param add_noise: If True, add noise to the realization.
470+
:param strict: If True, check that the indices of the realization cover
471+
all the indices of the SACC data object.
472+
:return: The SACC data object containing the new realization
473+
"""
413474
sacc_indices = self.get_sacc_indices()
414475

415476
if add_noise:

firecrown/likelihood/gaussian.py

+10-3
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,19 @@
1313
class ConstGaussian(GaussFamily):
1414
"""A Gaussian log-likelihood with a constant covariance matrix."""
1515

16-
def compute_loglike(self, tools: ModelingTools):
17-
"""Compute the log-likelihood."""
16+
def compute_loglike(self, tools: ModelingTools) -> float:
17+
"""Compute the log-likelihood.
18+
19+
:params tools: The modeling tools used to compute the likelihood.
20+
:return: The log-likelihood.
21+
"""
1822
return -0.5 * self.compute_chisq(tools)
1923

2024
def make_realization_vector(self) -> np.ndarray:
21-
"""Create a new realization of the model."""
25+
"""Create a new (randomized) realization of the model.
26+
27+
:return: A new realization of the model
28+
"""
2229
theory_vector = self.get_theory_vector()
2330
assert self.cholesky is not None
2431
new_data_vector = theory_vector + np.dot(

0 commit comments

Comments
 (0)