diff --git a/custom_saes/run_all_evals_custom_saes.py b/custom_saes/run_all_evals_custom_saes.py index 750d124..58df6a1 100644 --- a/custom_saes/run_all_evals_custom_saes.py +++ b/custom_saes/run_all_evals_custom_saes.py @@ -146,7 +146,10 @@ def run_evals( "unlearning": ( lambda: unlearning.run_eval( unlearning.UnlearningEvalConfig( - model_name="gemma-2-2b-it", random_seed=RANDOM_SEED, llm_dtype=llm_dtype + model_name="gemma-2-2b-it", + random_seed=RANDOM_SEED, + llm_dtype=llm_dtype, + llm_batch_size=llm_batch_size, ), selected_saes, device,