1
- """Support for the family of Gaussian likelihood ."""
1
+ """Support for the family of Gaussian likelihoods ."""
2
2
3
3
from __future__ import annotations
4
4
30
30
31
31
32
32
class State (Enum ):
33
- """The states used in GaussFamily."""
33
+ """The states used in GaussFamily.
34
+
35
+ GaussFamily and all subclasses enforce a statemachine behavior based on
36
+ these states to ensure that the necessary initialization and setup is done
37
+ in the correct order.
38
+ """
34
39
35
40
INITIALIZED = 1
36
41
READY = 2
@@ -62,6 +67,12 @@ def enforce_states(
62
67
If terminal is None the state of the object is not modified.
63
68
If terminal is not None and the call to the wrapped method returns
64
69
normally the state of the object is set to terminal.
70
+
71
+ :param initial: The initial states allowable for the wrapped method
72
+ :param terminal: The terminal state ensured for the wrapped method. None
73
+ indicates no state change happens.
74
+ :param failure_message: The failure message for the AssertionError raised
75
+ :return: The wrapped method
65
76
"""
66
77
initials : list [State ]
67
78
if isinstance (initial , list ):
@@ -74,6 +85,9 @@ def decorator_enforce_states(func: Callable[P, T]) -> Callable[P, T]:
74
85
75
86
This closure is what actually contains the values of initials, terminal, and
76
87
failure_message.
88
+
89
+ :param func: The method to be wrapped
90
+ :return: The wrapped method
77
91
"""
78
92
79
93
@wraps (func )
@@ -132,8 +146,11 @@ class GaussFamily(Likelihood):
132
146
def __init__ (
133
147
self ,
134
148
statistics : Sequence [Statistic ],
135
- ):
136
- """Initialize the base class parts of a GaussFamily object."""
149
+ ) -> None :
150
+ """Initialize the base class parts of a GaussFamily object.
151
+
152
+ :param statistics: A list of statistics to be include in chisquared calculations
153
+ """
137
154
super ().__init__ ()
138
155
self .state : State = State .INITIALIZED
139
156
if len (statistics ) == 0 :
@@ -160,7 +177,12 @@ def __init__(
160
177
def create_ready (
161
178
cls , statistics : Sequence [Statistic ], covariance : npt .NDArray [np .float64 ]
162
179
) -> GaussFamily :
163
- """Create a GaussFamily object in the READY state."""
180
+ """Create a GaussFamily object in the READY state.
181
+
182
+ :param statistics: A list of statistics to be include in chisquared calculations
183
+ :param covariance: The covariance matrix of the statistics
184
+ :return: A ready GaussFamily object
185
+ """
164
186
obj = cls (statistics )
165
187
obj ._set_covariance (covariance )
166
188
obj .state = State .READY
@@ -178,6 +200,8 @@ def _update(self, _: ParamsMap) -> None:
178
200
for its own reasons must be sure to do what this does: check the state
179
201
at the start of the method, and change the state at the end of the
180
202
method.
203
+
204
+ :param _: a ParamsMap object, not used
181
205
"""
182
206
183
207
@enforce_states (
@@ -201,7 +225,10 @@ def _reset(self) -> None:
201
225
failure_message = "read() must only be called once" ,
202
226
)
203
227
def read (self , sacc_data : sacc .Sacc ) -> None :
204
- """Read the covariance matrix for this likelihood from the SACC file."""
228
+ """Read the covariance matrix for this likelihood from the SACC file.
229
+
230
+ :param sacc_data: The SACC data object to be read
231
+ """
205
232
if sacc_data .covariance is None :
206
233
msg = (
207
234
f"The { type (self ).__name__ } likelihood requires a covariance, "
@@ -216,11 +243,13 @@ def read(self, sacc_data: sacc.Sacc) -> None:
216
243
217
244
self ._set_covariance (covariance )
218
245
219
- def _set_covariance (self , covariance ) :
246
+ def _set_covariance (self , covariance : npt . NDArray [ np . float64 ]) -> None :
220
247
"""Set the covariance matrix.
221
248
222
249
This method is used to set the covariance matrix and perform the
223
250
necessary calculations to prepare the likelihood for computation.
251
+
252
+ :param covariance: The covariance matrix for this likelihood
224
253
"""
225
254
indices_list = []
226
255
data_vector_list = []
@@ -276,6 +305,7 @@ def get_cov(
276
305
:param statistic: The statistic for which the sub-covariance matrix
277
306
should be returned. If not specified, return the covariance of all
278
307
statistics.
308
+ :return: The covariance matrix (or portion thereof)
279
309
"""
280
310
assert self .cov is not None
281
311
if statistic is None :
@@ -301,7 +331,10 @@ def get_cov(
301
331
failure_message = "read() must be called before get_data_vector()" ,
302
332
)
303
333
def get_data_vector (self ) -> npt .NDArray [np .float64 ]:
304
- """Get the data vector from all statistics in the right order."""
334
+ """Get the data vector from all statistics in the right order.
335
+
336
+ :return: The data vector
337
+ """
305
338
assert self .data_vector is not None
306
339
return self .data_vector
307
340
@@ -315,6 +348,7 @@ def compute_theory_vector(self, tools: ModelingTools) -> npt.NDArray[np.float64]
315
348
"""Computes the theory vector using the current instance of pyccl.Cosmology.
316
349
317
350
:param tools: Current ModelingTools object
351
+ :return: The computed theory vector
318
352
"""
319
353
theory_vector_list : list [npt .NDArray [np .float64 ]] = [
320
354
stat .compute_theory_vector (tools ) for stat in self .statistics
@@ -329,7 +363,10 @@ def compute_theory_vector(self, tools: ModelingTools) -> npt.NDArray[np.float64]
329
363
"get_theory_vector()" ,
330
364
)
331
365
def get_theory_vector (self ) -> npt .NDArray [np .float64 ]:
332
- """Get the theory vector from all statistics in the right order."""
366
+ """Get the already-computed theory vector from all statistics.
367
+
368
+ :return: The theory vector, with all statistics in the right order
369
+ """
333
370
assert (
334
371
self .theory_vector is not None
335
372
), "theory_vector is None after compute_theory_vector() has been called"
@@ -343,7 +380,14 @@ def get_theory_vector(self) -> npt.NDArray[np.float64]:
343
380
def compute (
344
381
self , tools : ModelingTools
345
382
) -> tuple [npt .NDArray [np .float64 ], npt .NDArray [np .float64 ]]:
346
- """Calculate and return both the data and theory vectors."""
383
+ """Calculate and return both the data and theory vectors.
384
+
385
+ This method is dprecated and will be removed in a future version of Firecrown.
386
+
387
+ :param tools: the ModelingTools to be used in the calculation of the
388
+ theory vector
389
+ :return: a tuple containing the data vector and the theory vector
390
+ """
347
391
warnings .warn (
348
392
"The use of the `compute` method on Statistic is deprecated."
349
393
"The Statistic objects should implement `get_data` and "
@@ -359,7 +403,12 @@ def compute(
359
403
failure_message = "update() must be called before compute_chisq()" ,
360
404
)
361
405
def compute_chisq (self , tools : ModelingTools ) -> float :
362
- """Calculate and return the chi-squared for the given cosmology."""
406
+ """Calculate and return the chi-squared for the given cosmology.
407
+
408
+ :param tools: the ModelingTools to be used in the calculation of the
409
+ theory vector
410
+ :return: the chi-squared
411
+ """
363
412
theory_vector : npt .NDArray [np .float64 ]
364
413
data_vector : npt .NDArray [np .float64 ]
365
414
residuals : npt .NDArray [np .float64 ]
@@ -386,6 +435,10 @@ def get_sacc_indices(
386
435
"""Get the SACC indices of the statistic or list of statistics.
387
436
388
437
If no statistic is given, get the indices of all statistics of the likelihood.
438
+
439
+ :param statistics: The statistic or list of statistics for which the
440
+ SACC indices are desired
441
+ :return: The SACC indices
389
442
"""
390
443
if statistic is None :
391
444
statistic = [stat .statistic for stat in self .statistics ]
@@ -409,7 +462,15 @@ def get_sacc_indices(
409
462
def make_realization (
410
463
self , sacc_data : sacc .Sacc , add_noise : bool = True , strict : bool = True
411
464
) -> sacc .Sacc :
412
- """Create a new realization of the model."""
465
+ """Create a new realization of the model.
466
+
467
+ :param sacc_data: The SACC data object containing the covariance matrix
468
+ to be read
469
+ :param add_noise: If True, add noise to the realization.
470
+ :param strict: If True, check that the indices of the realization cover
471
+ all the indices of the SACC data object.
472
+ :return: The SACC data object containing the new realization
473
+ """
413
474
sacc_indices = self .get_sacc_indices ()
414
475
415
476
if add_noise :
0 commit comments