Skip to content

Commit f66e54b

Browse files
committed
MAINT: rework stopping criteria
1 parent a9c18a4 commit f66e54b

File tree

1 file changed

+152
-73
lines changed

1 file changed

+152
-73
lines changed

nessai/samplers/importancesampler.py

+152-73
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,93 @@ def __getstate__(self):
276276
return state
277277

278278

279+
class StoppingCriterion:
280+
"""Object for storing information about a stopping criterion.
281+
282+
Includes the tolerance, current value and whether the tolerance
283+
has been reached based on how it should be checked.
284+
285+
Parameters
286+
----------
287+
name : str
288+
Name of the stopping criterion
289+
check : {"leq", "geq"}
290+
Indicates whether to check if the value is less than or equal (leq) or
291+
or greater than or equal (geq) than the tolerance.
292+
aliases : List[str]
293+
List of aliases (alternative names) for the criterion.
294+
tolerance : float
295+
Tolerance for the criterion.
296+
value : Optional[None]
297+
Current value. Does not have to be specified.
298+
"""
299+
300+
def __init__(
301+
self,
302+
name: str,
303+
check: str,
304+
aliases: list,
305+
tolerance: float = None,
306+
value: float = None,
307+
):
308+
if check.lower() not in ["leq", "geq"]:
309+
raise ValueError(
310+
f"Invalid value for `check`: {check}. "
311+
f"Choose from ['leq', 'geq']."
312+
)
313+
self._name = name
314+
self._check = check
315+
self._aliases = aliases
316+
self._tolerance = tolerance
317+
self._value = value
318+
319+
@property
320+
def name(self) -> str:
321+
return self._name
322+
323+
@property
324+
def check(self) -> str:
325+
return self._check
326+
327+
@property
328+
def aliases(self) -> List[str]:
329+
return self._aliases
330+
331+
@property
332+
def tolerance(self) -> float:
333+
return self._tolerance
334+
335+
@property
336+
def value(self) -> float:
337+
return self._value
338+
339+
def update_tolerance(self, tolerance) -> None:
340+
"""Update the tolerance"""
341+
self._tolerance = tolerance
342+
343+
def update_value(self, value) -> None:
344+
"""Update the current value."""
345+
self._value = value
346+
347+
def update_value_from_sampler(self, sampler) -> None:
348+
"""Udate"""
349+
value = getattr(sampler, self.name, None)
350+
if value is None:
351+
raise RuntimeError(f"{self.name} has not been computed!")
352+
self.update_value(value)
353+
354+
@property
355+
def reached_tolerance(self) -> bool:
356+
"""Indicates if the stopping criterion has been reached"""
357+
if self.check == "leq":
358+
return self._value <= self._tolerance
359+
else:
360+
return self._value >= self._tolerance
361+
362+
def summary(self) -> str:
363+
return f"{self.name}: {self.value:.4g} ({self.tolerance:.2g})"
364+
365+
279366
class ImportanceNestedSampler(BaseNestedSampler):
280367
"""
281368
@@ -326,33 +413,46 @@ class ImportanceNestedSampler(BaseNestedSampler):
326413
If False, this can help reduce the disk usage.
327414
"""
328415

329-
stopping_criterion_config = dict(
330-
ratio=dict(
331-
type="min",
416+
stopping_criteria = {
417+
"ratio": StoppingCriterion(
418+
name="ratio",
419+
check="leq",
332420
aliases=["ratio", "ratio_all"],
421+
tolerance=0.0,
422+
value=np.inf,
333423
),
334-
ratio_ns=dict(
335-
type="min",
424+
"ratio_ns": StoppingCriterion(
425+
name="ratio_ns",
426+
check="leq",
336427
aliases=["ratio_ns"],
428+
tolerance=0.0,
429+
value=np.inf,
337430
),
338-
Z_err=dict(
339-
type="min",
431+
"Z_err": StoppingCriterion(
432+
name="Z_err",
433+
check="leq",
340434
aliases=["Z_err", "evidence_error"],
435+
value=np.inf,
341436
),
342-
log_dZ=dict(
343-
type="min",
437+
"log_dZ": StoppingCriterion(
438+
name="log_dZ",
439+
check="leq",
344440
aliases=["log_dZ", "log_evidence"],
441+
value=np.inf,
345442
),
346-
ess=dict(
347-
type="max",
443+
"ess": StoppingCriterion(
444+
name="ess",
445+
check="geq",
348446
aliases=["ess"],
447+
value=0.0,
349448
),
350-
fractional_error=dict(
351-
type="min",
449+
"fractional_error": StoppingCriterion(
450+
name="fractional_error",
451+
check="leq",
352452
aliases=["fractional_error"],
453+
value=0.0,
353454
),
354-
)
355-
"""Dictionary of available stopping criteria and their aliases."""
455+
}
356456

357457
def __init__(
358458
self,
@@ -378,7 +478,7 @@ def __init__(
378478
min_remove: int = 1,
379479
max_samples: Optional[int] = None,
380480
stopping_criterion: str = "ratio",
381-
tolerance: float = 0.0,
481+
tolerance: Optional[float] = None,
382482
n_update: Optional[int] = None,
383483
plot_pool: bool = False,
384484
plot_level_cdf: bool = False,
@@ -667,16 +767,14 @@ def reached_tolerance(self) -> bool:
667767
Checks if any or all of the criteria have been met, this depends on the
668768
value of :code:`check_criteria`.
669769
"""
670-
flags = {}
671-
for sc, value in self.criterion.items():
672-
if self.stopping_criterion_config[sc]["type"] == "min":
673-
flags[sc] = value <= self.tolerance[sc]
674-
else:
675-
flags[sc] = value > self.tolerance[sc]
770+
flags = [
771+
self.stopping_criteria[name].reached_tolerance
772+
for name in self.criteria_to_check
773+
]
676774
if self._stop_any:
677-
return any(flags.values())
775+
return any(flags)
678776
else:
679-
return all(flags.values())
777+
return all(flags)
680778

681779
@staticmethod
682780
def add_fields():
@@ -698,41 +796,17 @@ def configure_stopping_criterion(
698796
else:
699797
tolerance = [float(tolerance)]
700798

701-
self.stopping_criterion = []
702-
for c in stopping_criterion:
703-
for criterion, cfg in self.stopping_criterion_config.items():
704-
if c in cfg["aliases"]:
705-
self.stopping_criterion.append(criterion)
706-
if len(self.stopping_criterion) != len(stopping_criterion):
707-
raise ValueError(
708-
f"Unknown stopping criterion: {stopping_criterion}"
709-
)
710-
for c, c_use in zip(stopping_criterion, self.stopping_criterion):
711-
if c != c_use:
712-
logger.info(
713-
f"Stopping criterion specified ({c}) is "
714-
f"an alias for {c_use}. Using {c_use}."
715-
)
716-
717-
self.tolerance = {
718-
sc: t for sc, t in zip(stopping_criterion, tolerance)
719-
}
720-
if len(self.stopping_criterion) != len(self.tolerance):
721-
raise ValueError(
722-
"Number of stopping criteria must match tolerances"
723-
)
724-
725-
types = {
726-
sc: self.stopping_criterion_config[sc]["type"]
727-
for sc in stopping_criterion
728-
}
729-
self.criterion = {
730-
sc: np.inf if t == "min" else -np.inf for sc, t in types.items()
731-
}
732-
733-
logger.info(f"Stopping criteria: {self.stopping_criterion}")
734-
logger.info(f"Tolerance: {self.tolerance}")
799+
self.criteria_to_check = []
800+
for name, tol in zip(stopping_criterion, tolerance):
801+
for sc in self.stopping_criteria.values():
802+
if name in sc.aliases:
803+
sc.update_tolerance(tol)
804+
self.criteria_to_check.append(name)
805+
break
806+
else:
807+
raise ValueError(f"Unknown stopping criterion: {name}")
735808

809+
logger.info(f"Stopping criteria to check: {self.criteria_to_check}")
736810
if check_criteria not in {"any", "all"}:
737811
raise ValueError("check_criteria must be any or all")
738812
if check_criteria == "any":
@@ -870,7 +944,7 @@ def initialise_history(self) -> None:
870944
samples_entropy=[],
871945
proposal_entropy=[],
872946
stopping_criteria={
873-
k: [] for k in self.stopping_criterion_config.keys()
947+
name: [] for name in self.stopping_criteria.keys()
874948
},
875949
)
876950
)
@@ -900,10 +974,8 @@ def update_history(self) -> None:
900974
self.model.likelihood_evaluations
901975
)
902976

903-
for k in self.stopping_criterion_config.keys():
904-
self.history["stopping_criteria"][k].append(
905-
getattr(self, k, np.nan)
906-
)
977+
for name, sc in self.stopping_criteria.items():
978+
self.history["stopping_criteria"][name].append(sc.value)
907979

908980
def determine_threshold_quantile(
909981
self,
@@ -1456,11 +1528,18 @@ def compute_stopping_criterion(self) -> List[float]:
14561528
self.ess = self.state.effective_n_posterior_samples
14571529
self.Z_err = np.exp(self.log_evidence_error)
14581530
self.fractional_error = self.state.evidence_error / self.state.evidence
1459-
cond = {sc: getattr(self, sc) for sc in self.stopping_criterion}
14601531

1461-
logger.info(
1462-
f"Stopping criteria: {cond} " f"- Tolerance: {self.tolerance}"
1463-
)
1532+
cond = {}
1533+
for name, sc in self.stopping_criteria.items():
1534+
sc.update_value_from_sampler(self)
1535+
if name in self.criteria_to_check:
1536+
cond[name] = sc.value
1537+
1538+
status = [
1539+
self.stopping_criteria[sc].summary()
1540+
for sc in self.criteria_to_check
1541+
]
1542+
logger.info(f"Stopping criteria: {status}")
14641543
return cond
14651544

14661545
def checkpoint(self, periodic: bool = False, force: bool = False):
@@ -2049,17 +2128,17 @@ def plot_state(
20492128
ax[m].legend()
20502129
m += 1
20512130

2052-
for (i, sc), tol in zip(
2053-
enumerate(self.stopping_criterion), self.tolerance
2054-
):
2131+
for i, sc_name in enumerate(self.criteria_to_check):
20552132
ax[m].plot(
20562133
its,
2057-
self.history["stopping_criteria"][sc],
2058-
label=sc,
2134+
self.history["stopping_criteria"][sc_name],
2135+
label=sc_name,
20592136
c=f"C{i}",
20602137
ls=config.plotting.line_styles[i],
20612138
)
2062-
ax[m].axhline(tol, ls=":", c=f"C{i}")
2139+
ax[m].axhline(
2140+
self.stopping_criteria[sc_name].tolerance, ls=":", c=f"C{i}"
2141+
)
20632142
ax[m].legend()
20642143
ax[m].set_ylabel("Stopping criterion")
20652144

0 commit comments

Comments
 (0)