Skip to content

Commit a9c18a4

Browse files
committed
BUG: fix ESS stopping criterion
1 parent 87108fa commit a9c18a4

File tree

1 file changed

+56
-28
lines changed

1 file changed

+56
-28
lines changed

nessai/samplers/importancesampler.py

+56-28
Original file line numberDiff line numberDiff line change
@@ -326,15 +326,31 @@ class ImportanceNestedSampler(BaseNestedSampler):
326326
If False, this can help reduce the disk usage.
327327
"""
328328

329-
stopping_criterion_aliases = dict(
330-
ratio=["ratio", "ratio_all"],
331-
ratio_ns=["ratio_ns"],
332-
Z_err=["Z_err", "evidence_error"],
333-
log_dZ=["log_dZ", "log_evidence"],
334-
ess=[
335-
"ess",
336-
],
337-
fractional_error=["fractional_error"],
329+
stopping_criterion_config = dict(
330+
ratio=dict(
331+
type="min",
332+
aliases=["ratio", "ratio_all"],
333+
),
334+
ratio_ns=dict(
335+
type="min",
336+
aliases=["ratio_ns"],
337+
),
338+
Z_err=dict(
339+
type="min",
340+
aliases=["Z_err", "evidence_error"],
341+
),
342+
log_dZ=dict(
343+
type="min",
344+
aliases=["log_dZ", "log_evidence"],
345+
),
346+
ess=dict(
347+
type="max",
348+
aliases=["ess"],
349+
),
350+
fractional_error=dict(
351+
type="min",
352+
aliases=["fractional_error"],
353+
),
338354
)
339355
"""Dictionary of available stopping criteria and their aliases."""
340356

@@ -455,7 +471,7 @@ def __init__(
455471
self.log_dZ = np.inf
456472
self.ratio = np.inf
457473
self.ratio_ns = np.inf
458-
self.ess = 0.0
474+
self.ess = 0
459475
self.Z_err = np.inf
460476

461477
self._final_samples = None
@@ -651,14 +667,16 @@ def reached_tolerance(self) -> bool:
651667
Checks if any or all of the criteria have been met, this depends on the
652668
value of :code:`check_criteria`.
653669
"""
670+
flags = {}
671+
for sc, value in self.criterion.items():
672+
if self.stopping_criterion_config[sc]["type"] == "min":
673+
flags[sc] = value <= self.tolerance[sc]
674+
else:
675+
flags[sc] = value > self.tolerance[sc]
654676
if self._stop_any:
655-
return any(
656-
[c <= t for c, t in zip(self.criterion, self.tolerance)]
657-
)
677+
return any(flags.values())
658678
else:
659-
return all(
660-
[c <= t for c, t in zip(self.criterion, self.tolerance)]
661-
)
679+
return all(flags.values())
662680

663681
@staticmethod
664682
def add_fields():
@@ -676,16 +694,16 @@ def configure_stopping_criterion(
676694
stopping_criterion = [stopping_criterion]
677695

678696
if isinstance(tolerance, list):
679-
self.tolerance = [float(t) for t in tolerance]
697+
tolerance = [float(t) for t in tolerance]
680698
else:
681-
self.tolerance = [float(tolerance)]
699+
tolerance = [float(tolerance)]
682700

683701
self.stopping_criterion = []
684702
for c in stopping_criterion:
685-
for criterion, aliases in self.stopping_criterion_aliases.items():
686-
if c in aliases:
703+
for criterion, cfg in self.stopping_criterion_config.items():
704+
if c in cfg["aliases"]:
687705
self.stopping_criterion.append(criterion)
688-
if not self.stopping_criterion:
706+
if len(self.stopping_criterion) != len(stopping_criterion):
689707
raise ValueError(
690708
f"Unknown stopping criterion: {stopping_criterion}"
691709
)
@@ -695,11 +713,22 @@ def configure_stopping_criterion(
695713
f"Stopping criterion specified ({c}) is "
696714
f"an alias for {c_use}. Using {c_use}."
697715
)
716+
717+
self.tolerance = {
718+
sc: t for sc, t in zip(stopping_criterion, tolerance)
719+
}
698720
if len(self.stopping_criterion) != len(self.tolerance):
699721
raise ValueError(
700722
"Number of stopping criteria must match tolerances"
701723
)
702-
self.criterion = len(self.tolerance) * [np.inf]
724+
725+
types = {
726+
sc: self.stopping_criterion_config[sc]["type"]
727+
for sc in stopping_criterion
728+
}
729+
self.criterion = {
730+
sc: np.inf if t == "min" else -np.inf for sc, t in types.items()
731+
}
703732

704733
logger.info(f"Stopping criteria: {self.stopping_criterion}")
705734
logger.info(f"Tolerance: {self.tolerance}")
@@ -841,7 +870,7 @@ def initialise_history(self) -> None:
841870
samples_entropy=[],
842871
proposal_entropy=[],
843872
stopping_criteria={
844-
k: [] for k in self.stopping_criterion_aliases.keys()
873+
k: [] for k in self.stopping_criterion_config.keys()
845874
},
846875
)
847876
)
@@ -871,7 +900,7 @@ def update_history(self) -> None:
871900
self.model.likelihood_evaluations
872901
)
873902

874-
for k in self.stopping_criterion_aliases.keys():
903+
for k in self.stopping_criterion_config.keys():
875904
self.history["stopping_criteria"][k].append(
876905
getattr(self, k, np.nan)
877906
)
@@ -1427,11 +1456,10 @@ def compute_stopping_criterion(self) -> List[float]:
14271456
self.ess = self.state.effective_n_posterior_samples
14281457
self.Z_err = np.exp(self.log_evidence_error)
14291458
self.fractional_error = self.state.evidence_error / self.state.evidence
1430-
cond = [getattr(self, sc) for sc in self.stopping_criterion]
1459+
cond = {sc: getattr(self, sc) for sc in self.stopping_criterion}
14311460

14321461
logger.info(
1433-
f"Stopping criteria ({self.stopping_criterion}): {cond} "
1434-
f"- Tolerance: {self.tolerance}"
1462+
f"Stopping criteria: {cond} " f"- Tolerance: {self.tolerance}"
14351463
)
14361464
return cond
14371465

@@ -1583,7 +1611,7 @@ def nested_sampling_loop(self):
15831611

15841612
logger.info(
15851613
f"Finished nested sampling loop after {self.iteration} iterations "
1586-
f"with {self.stopping_criterion} = {self.criterion}"
1614+
f"with {self.criterion}"
15871615
)
15881616
self.finalise()
15891617
logger.info(f"Training time: {self.training_time}")

0 commit comments

Comments
 (0)