@@ -276,6 +276,93 @@ def __getstate__(self):
276
276
return state
277
277
278
278
279
+ class StoppingCriterion :
280
+ """Object for storing information about a stopping criterion.
281
+
282
+ Includes the tolerance, current value and whether the tolerance
283
+ has been reached based on how it should be checked.
284
+
285
+ Parameters
286
+ ----------
287
+ name : str
288
+ Name of the stopping criterion
289
+ check : {"leq", "geq"}
290
+ Indicates whether to check if the value is less than or equal (leq) or
291
+ or greater than or equal (geq) than the tolerance.
292
+ aliases : List[str]
293
+ List of aliases (alternative names) for the criterion.
294
+ tolerance : float
295
+ Tolerance for the criterion.
296
+ value : Optional[None]
297
+ Current value. Does not have to be specified.
298
+ """
299
+
300
+ def __init__ (
301
+ self ,
302
+ name : str ,
303
+ check : str ,
304
+ aliases : list ,
305
+ tolerance : float = None ,
306
+ value : float = None ,
307
+ ):
308
+ if check .lower () not in ["leq" , "geq" ]:
309
+ raise ValueError (
310
+ f"Invalid value for `check`: { check } . "
311
+ f"Choose from ['leq', 'geq']."
312
+ )
313
+ self ._name = name
314
+ self ._check = check
315
+ self ._aliases = aliases
316
+ self ._tolerance = tolerance
317
+ self ._value = value
318
+
319
+ @property
320
+ def name (self ) -> str :
321
+ return self ._name
322
+
323
+ @property
324
+ def check (self ) -> str :
325
+ return self ._check
326
+
327
+ @property
328
+ def aliases (self ) -> List [str ]:
329
+ return self ._aliases
330
+
331
+ @property
332
+ def tolerance (self ) -> float :
333
+ return self ._tolerance
334
+
335
+ @property
336
+ def value (self ) -> float :
337
+ return self ._value
338
+
339
+ def update_tolerance (self , tolerance ) -> None :
340
+ """Update the tolerance"""
341
+ self ._tolerance = tolerance
342
+
343
+ def update_value (self , value ) -> None :
344
+ """Update the current value."""
345
+ self ._value = value
346
+
347
+ def update_value_from_sampler (self , sampler ) -> None :
348
+ """Udate"""
349
+ value = getattr (sampler , self .name , None )
350
+ if value is None :
351
+ raise RuntimeError (f"{ self .name } has not been computed!" )
352
+ self .update_value (value )
353
+
354
+ @property
355
+ def reached_tolerance (self ) -> bool :
356
+ """Indicates if the stopping criterion has been reached"""
357
+ if self .check == "leq" :
358
+ return self ._value <= self ._tolerance
359
+ else :
360
+ return self ._value >= self ._tolerance
361
+
362
+ def summary (self ) -> str :
363
+ return f"{ self .name } : { self .value :.4g} ({ self .tolerance :.2g} )"
364
+
365
+
279
366
class ImportanceNestedSampler (BaseNestedSampler ):
280
367
"""
281
368
@@ -326,33 +413,46 @@ class ImportanceNestedSampler(BaseNestedSampler):
326
413
If False, this can help reduce the disk usage.
327
414
"""
328
415
329
- stopping_criterion_config = dict (
330
- ratio = dict (
331
- type = "min" ,
416
+ stopping_criteria = {
417
+ "ratio" : StoppingCriterion (
418
+ name = "ratio" ,
419
+ check = "leq" ,
332
420
aliases = ["ratio" , "ratio_all" ],
421
+ tolerance = 0.0 ,
422
+ value = np .inf ,
333
423
),
334
- ratio_ns = dict (
335
- type = "min" ,
424
+ "ratio_ns" : StoppingCriterion (
425
+ name = "ratio_ns" ,
426
+ check = "leq" ,
336
427
aliases = ["ratio_ns" ],
428
+ tolerance = 0.0 ,
429
+ value = np .inf ,
337
430
),
338
- Z_err = dict (
339
- type = "min" ,
431
+ "Z_err" : StoppingCriterion (
432
+ name = "Z_err" ,
433
+ check = "leq" ,
340
434
aliases = ["Z_err" , "evidence_error" ],
435
+ value = np .inf ,
341
436
),
342
- log_dZ = dict (
343
- type = "min" ,
437
+ "log_dZ" : StoppingCriterion (
438
+ name = "log_dZ" ,
439
+ check = "leq" ,
344
440
aliases = ["log_dZ" , "log_evidence" ],
441
+ value = np .inf ,
345
442
),
346
- ess = dict (
347
- type = "max" ,
443
+ "ess" : StoppingCriterion (
444
+ name = "ess" ,
445
+ check = "geq" ,
348
446
aliases = ["ess" ],
447
+ value = 0.0 ,
349
448
),
350
- fractional_error = dict (
351
- type = "min" ,
449
+ "fractional_error" : StoppingCriterion (
450
+ name = "fractional_error" ,
451
+ check = "leq" ,
352
452
aliases = ["fractional_error" ],
453
+ value = 0.0 ,
353
454
),
354
- )
355
- """Dictionary of available stopping criteria and their aliases."""
455
+ }
356
456
357
457
def __init__ (
358
458
self ,
@@ -378,7 +478,7 @@ def __init__(
378
478
min_remove : int = 1 ,
379
479
max_samples : Optional [int ] = None ,
380
480
stopping_criterion : str = "ratio" ,
381
- tolerance : float = 0.0 ,
481
+ tolerance : Optional [ float ] = None ,
382
482
n_update : Optional [int ] = None ,
383
483
plot_pool : bool = False ,
384
484
plot_level_cdf : bool = False ,
@@ -667,16 +767,14 @@ def reached_tolerance(self) -> bool:
667
767
Checks if any or all of the criteria have been met, this depends on the
668
768
value of :code:`check_criteria`.
669
769
"""
670
- flags = {}
671
- for sc , value in self .criterion .items ():
672
- if self .stopping_criterion_config [sc ]["type" ] == "min" :
673
- flags [sc ] = value <= self .tolerance [sc ]
674
- else :
675
- flags [sc ] = value > self .tolerance [sc ]
770
+ flags = [
771
+ self .stopping_criteria [name ].reached_tolerance
772
+ for name in self .criteria_to_check
773
+ ]
676
774
if self ._stop_any :
677
- return any (flags . values () )
775
+ return any (flags )
678
776
else :
679
- return all (flags . values () )
777
+ return all (flags )
680
778
681
779
@staticmethod
682
780
def add_fields ():
@@ -698,41 +796,17 @@ def configure_stopping_criterion(
698
796
else :
699
797
tolerance = [float (tolerance )]
700
798
701
- self .stopping_criterion = []
702
- for c in stopping_criterion :
703
- for criterion , cfg in self .stopping_criterion_config .items ():
704
- if c in cfg ["aliases" ]:
705
- self .stopping_criterion .append (criterion )
706
- if len (self .stopping_criterion ) != len (stopping_criterion ):
707
- raise ValueError (
708
- f"Unknown stopping criterion: { stopping_criterion } "
709
- )
710
- for c , c_use in zip (stopping_criterion , self .stopping_criterion ):
711
- if c != c_use :
712
- logger .info (
713
- f"Stopping criterion specified ({ c } ) is "
714
- f"an alias for { c_use } . Using { c_use } ."
715
- )
716
-
717
- self .tolerance = {
718
- sc : t for sc , t in zip (stopping_criterion , tolerance )
719
- }
720
- if len (self .stopping_criterion ) != len (self .tolerance ):
721
- raise ValueError (
722
- "Number of stopping criteria must match tolerances"
723
- )
724
-
725
- types = {
726
- sc : self .stopping_criterion_config [sc ]["type" ]
727
- for sc in stopping_criterion
728
- }
729
- self .criterion = {
730
- sc : np .inf if t == "min" else - np .inf for sc , t in types .items ()
731
- }
732
-
733
- logger .info (f"Stopping criteria: { self .stopping_criterion } " )
734
- logger .info (f"Tolerance: { self .tolerance } " )
799
+ self .criteria_to_check = []
800
+ for name , tol in zip (stopping_criterion , tolerance ):
801
+ for sc in self .stopping_criteria .values ():
802
+ if name in sc .aliases :
803
+ sc .update_tolerance (tol )
804
+ self .criteria_to_check .append (name )
805
+ break
806
+ else :
807
+ raise ValueError (f"Unknown stopping criterion: { name } " )
735
808
809
+ logger .info (f"Stopping criteria to check: { self .criteria_to_check } " )
736
810
if check_criteria not in {"any" , "all" }:
737
811
raise ValueError ("check_criteria must be any or all" )
738
812
if check_criteria == "any" :
@@ -870,7 +944,7 @@ def initialise_history(self) -> None:
870
944
samples_entropy = [],
871
945
proposal_entropy = [],
872
946
stopping_criteria = {
873
- k : [] for k in self .stopping_criterion_config .keys ()
947
+ name : [] for name in self .stopping_criteria .keys ()
874
948
},
875
949
)
876
950
)
@@ -900,10 +974,8 @@ def update_history(self) -> None:
900
974
self .model .likelihood_evaluations
901
975
)
902
976
903
- for k in self .stopping_criterion_config .keys ():
904
- self .history ["stopping_criteria" ][k ].append (
905
- getattr (self , k , np .nan )
906
- )
977
+ for name , sc in self .stopping_criteria .items ():
978
+ self .history ["stopping_criteria" ][name ].append (sc .value )
907
979
908
980
def determine_threshold_quantile (
909
981
self ,
@@ -1456,11 +1528,18 @@ def compute_stopping_criterion(self) -> List[float]:
1456
1528
self .ess = self .state .effective_n_posterior_samples
1457
1529
self .Z_err = np .exp (self .log_evidence_error )
1458
1530
self .fractional_error = self .state .evidence_error / self .state .evidence
1459
- cond = {sc : getattr (self , sc ) for sc in self .stopping_criterion }
1460
1531
1461
- logger .info (
1462
- f"Stopping criteria: { cond } " f"- Tolerance: { self .tolerance } "
1463
- )
1532
+ cond = {}
1533
+ for name , sc in self .stopping_criteria .items ():
1534
+ sc .update_value_from_sampler (self )
1535
+ if name in self .criteria_to_check :
1536
+ cond [name ] = sc .value
1537
+
1538
+ status = [
1539
+ self .stopping_criteria [sc ].summary ()
1540
+ for sc in self .criteria_to_check
1541
+ ]
1542
+ logger .info (f"Stopping criteria: { status } " )
1464
1543
return cond
1465
1544
1466
1545
def checkpoint (self , periodic : bool = False , force : bool = False ):
@@ -2049,17 +2128,17 @@ def plot_state(
2049
2128
ax [m ].legend ()
2050
2129
m += 1
2051
2130
2052
- for (i , sc ), tol in zip (
2053
- enumerate (self .stopping_criterion ), self .tolerance
2054
- ):
2131
+ for i , sc_name in enumerate (self .criteria_to_check ):
2055
2132
ax [m ].plot (
2056
2133
its ,
2057
- self .history ["stopping_criteria" ][sc ],
2058
- label = sc ,
2134
+ self .history ["stopping_criteria" ][sc_name ],
2135
+ label = sc_name ,
2059
2136
c = f"C{ i } " ,
2060
2137
ls = config .plotting .line_styles [i ],
2061
2138
)
2062
- ax [m ].axhline (tol , ls = ":" , c = f"C{ i } " )
2139
+ ax [m ].axhline (
2140
+ self .stopping_criteria [sc_name ].tolerance , ls = ":" , c = f"C{ i } "
2141
+ )
2063
2142
ax [m ].legend ()
2064
2143
ax [m ].set_ylabel ("Stopping criterion" )
2065
2144
0 commit comments