@@ -326,15 +326,31 @@ class ImportanceNestedSampler(BaseNestedSampler):
326
326
If False, this can help reduce the disk usage.
327
327
"""
328
328
329
- stopping_criterion_aliases = dict (
330
- ratio = ["ratio" , "ratio_all" ],
331
- ratio_ns = ["ratio_ns" ],
332
- Z_err = ["Z_err" , "evidence_error" ],
333
- log_dZ = ["log_dZ" , "log_evidence" ],
334
- ess = [
335
- "ess" ,
336
- ],
337
- fractional_error = ["fractional_error" ],
329
+ stopping_criterion_config = dict (
330
+ ratio = dict (
331
+ type = "min" ,
332
+ aliases = ["ratio" , "ratio_all" ],
333
+ ),
334
+ ratio_ns = dict (
335
+ type = "min" ,
336
+ aliases = ["ratio_ns" ],
337
+ ),
338
+ Z_err = dict (
339
+ type = "min" ,
340
+ aliases = ["Z_err" , "evidence_error" ],
341
+ ),
342
+ log_dZ = dict (
343
+ type = "min" ,
344
+ aliases = ["log_dZ" , "log_evidence" ],
345
+ ),
346
+ ess = dict (
347
+ type = "max" ,
348
+ aliases = ["ess" ],
349
+ ),
350
+ fractional_error = dict (
351
+ type = "min" ,
352
+ aliases = ["fractional_error" ],
353
+ ),
338
354
)
339
355
"""Dictionary of available stopping criteria and their aliases."""
340
356
@@ -455,7 +471,7 @@ def __init__(
455
471
self .log_dZ = np .inf
456
472
self .ratio = np .inf
457
473
self .ratio_ns = np .inf
458
- self .ess = 0.0
474
+ self .ess = 0
459
475
self .Z_err = np .inf
460
476
461
477
self ._final_samples = None
@@ -651,14 +667,16 @@ def reached_tolerance(self) -> bool:
651
667
Checks if any or all of the criteria have been met, this depends on the
652
668
value of :code:`check_criteria`.
653
669
"""
670
+ flags = {}
671
+ for sc , value in self .criterion .items ():
672
+ if self .stopping_criterion_config [sc ]["type" ] == "min" :
673
+ flags [sc ] = value <= self .tolerance [sc ]
674
+ else :
675
+ flags [sc ] = value > self .tolerance [sc ]
654
676
if self ._stop_any :
655
- return any (
656
- [c <= t for c , t in zip (self .criterion , self .tolerance )]
657
- )
677
+ return any (flags .values ())
658
678
else :
659
- return all (
660
- [c <= t for c , t in zip (self .criterion , self .tolerance )]
661
- )
679
+ return all (flags .values ())
662
680
663
681
@staticmethod
664
682
def add_fields ():
@@ -676,16 +694,16 @@ def configure_stopping_criterion(
676
694
stopping_criterion = [stopping_criterion ]
677
695
678
696
if isinstance (tolerance , list ):
679
- self . tolerance = [float (t ) for t in tolerance ]
697
+ tolerance = [float (t ) for t in tolerance ]
680
698
else :
681
- self . tolerance = [float (tolerance )]
699
+ tolerance = [float (tolerance )]
682
700
683
701
self .stopping_criterion = []
684
702
for c in stopping_criterion :
685
- for criterion , aliases in self .stopping_criterion_aliases .items ():
686
- if c in aliases :
703
+ for criterion , cfg in self .stopping_criterion_config .items ():
704
+ if c in cfg [ " aliases" ] :
687
705
self .stopping_criterion .append (criterion )
688
- if not self .stopping_criterion :
706
+ if len ( self .stopping_criterion ) != len ( stopping_criterion ) :
689
707
raise ValueError (
690
708
f"Unknown stopping criterion: { stopping_criterion } "
691
709
)
@@ -695,11 +713,22 @@ def configure_stopping_criterion(
695
713
f"Stopping criterion specified ({ c } ) is "
696
714
f"an alias for { c_use } . Using { c_use } ."
697
715
)
716
+
717
+ self .tolerance = {
718
+ sc : t for sc , t in zip (stopping_criterion , tolerance )
719
+ }
698
720
if len (self .stopping_criterion ) != len (self .tolerance ):
699
721
raise ValueError (
700
722
"Number of stopping criteria must match tolerances"
701
723
)
702
- self .criterion = len (self .tolerance ) * [np .inf ]
724
+
725
+ types = {
726
+ sc : self .stopping_criterion_config [sc ]["type" ]
727
+ for sc in stopping_criterion
728
+ }
729
+ self .criterion = {
730
+ sc : np .inf if t == "min" else - np .inf for sc , t in types .items ()
731
+ }
703
732
704
733
logger .info (f"Stopping criteria: { self .stopping_criterion } " )
705
734
logger .info (f"Tolerance: { self .tolerance } " )
@@ -841,7 +870,7 @@ def initialise_history(self) -> None:
841
870
samples_entropy = [],
842
871
proposal_entropy = [],
843
872
stopping_criteria = {
844
- k : [] for k in self .stopping_criterion_aliases .keys ()
873
+ k : [] for k in self .stopping_criterion_config .keys ()
845
874
},
846
875
)
847
876
)
@@ -871,7 +900,7 @@ def update_history(self) -> None:
871
900
self .model .likelihood_evaluations
872
901
)
873
902
874
- for k in self .stopping_criterion_aliases .keys ():
903
+ for k in self .stopping_criterion_config .keys ():
875
904
self .history ["stopping_criteria" ][k ].append (
876
905
getattr (self , k , np .nan )
877
906
)
@@ -1427,11 +1456,10 @@ def compute_stopping_criterion(self) -> List[float]:
1427
1456
self .ess = self .state .effective_n_posterior_samples
1428
1457
self .Z_err = np .exp (self .log_evidence_error )
1429
1458
self .fractional_error = self .state .evidence_error / self .state .evidence
1430
- cond = [ getattr (self , sc ) for sc in self .stopping_criterion ]
1459
+ cond = { sc : getattr (self , sc ) for sc in self .stopping_criterion }
1431
1460
1432
1461
logger .info (
1433
- f"Stopping criteria ({ self .stopping_criterion } ): { cond } "
1434
- f"- Tolerance: { self .tolerance } "
1462
+ f"Stopping criteria: { cond } " f"- Tolerance: { self .tolerance } "
1435
1463
)
1436
1464
return cond
1437
1465
@@ -1583,7 +1611,7 @@ def nested_sampling_loop(self):
1583
1611
1584
1612
logger .info (
1585
1613
f"Finished nested sampling loop after { self .iteration } iterations "
1586
- f"with { self .stopping_criterion } = { self . criterion } "
1614
+ f"with { self .criterion } "
1587
1615
)
1588
1616
self .finalise ()
1589
1617
logger .info (f"Training time: { self .training_time } " )
0 commit comments