diff --git a/evals/shift_and_tpp/eval_config.py b/evals/shift_and_tpp/eval_config.py index 3e3570e..75381d4 100644 --- a/evals/shift_and_tpp/eval_config.py +++ b/evals/shift_and_tpp/eval_config.py @@ -14,6 +14,10 @@ class EvalConfig: ) spurious_corr: bool = True + + # This reduces randomness in the SCR results + early_stopping_patience: int = 40 + # Load datset and probes train_set_size: int = 4000 test_set_size: int = 1000 # This is limited as the test set is smaller than the train set diff --git a/evals/shift_and_tpp/main.py b/evals/shift_and_tpp/main.py index a8985e6..ee24e73 100644 --- a/evals/shift_and_tpp/main.py +++ b/evals/shift_and_tpp/main.py @@ -521,6 +521,7 @@ def run_eval_single_dataset( epochs=config.probe_epochs, lr=config.probe_lr, spurious_corr=config.spurious_corr, + early_stopping_patience=config.early_stopping_patience, ) torch.set_grad_enabled(False)