diff --git a/evals/unlearning/eval_config.py b/evals/unlearning/eval_config.py
index 2e672e7..5e36f45 100644
--- a/evals/unlearning/eval_config.py
+++ b/evals/unlearning/eval_config.py
@@ -1,11 +1,12 @@
-from dataclasses import dataclass, field
-
+from pydantic.dataclasses import dataclass
+from pydantic import Field
+from evals.base_eval_output import BaseEvalConfig
 @dataclass
-class EvalConfig:
-    random_seed: int = 42
+class UnlearningEvalConfig(BaseEvalConfig):
+    random_seed: int = Field(default=42, title="Random Seed", description="Random seed")

-    all_dataset_names: list[str] = field(
+    dataset_names: list[str] = Field(
         default_factory=lambda: [
             "wmdp-bio",
             "high_school_us_history",
@@ -13,25 +14,78 @@ class EvalConfig:
             "high_school_geography",
             "human_aging",
             "college_biology",
-        ]
+        ],
+        title="Dataset Names",
+        description="List of dataset names",
     )

-    intervention_method: str = "clamp_feature_activation"
+    intervention_method: str = Field(
+        default="clamp_feature_activation",
+        title="Intervention Method",
+        description="Intervention method",
+    )

-    retain_thresholds: list[str] = field(default_factory=lambda: [0.001, 0.01])
-    n_features_list: list[str] = field(default_factory=lambda: [10, 20])
-    multipliers: list[str] = field(default_factory=lambda: [25, 50, 100, 200])
+    retain_thresholds: list[float] = Field(
+        default_factory=lambda: [0.001, 0.01],
+        title="Retain Thresholds",
+        description="Retain thresholds",
+    )
+    n_features_list: list[int] = Field(
+        default_factory=lambda: [10, 20],
+        title="N Features List",
+        description="N features list",
+    )
+    multipliers: list[int] = Field(
+        default_factory=lambda: [25, 50, 100, 200],
+        title="Multipliers",
+        description="Multipliers",
+    )

-    llm_batch_size: int = 4
-    # multiple choice questions are shorter, so we can afford a larger batch size
-    mcq_batch_size: int = llm_batch_size * 2
+    llm_batch_size: int = Field(
+        default=4,
+        title="LLM Batch Size",
+        description="LLM batch size",
+    )
+    mcq_batch_size: int = Field(
+        default=8,
+        title="MCQ Batch Size",
+        description="MCQ batch size. Multiple choice questions are shorter, so we can afford a larger batch size",
+    )

-    dataset_size: int = 1024
-    seq_len: int = 1024
+    dataset_size: int = Field(
+        default=1024,
+        title="Dataset Size",
+        description="Dataset size",
+    )
+    seq_len: int = Field(
+        default=1024,
+        title="Sequence Length",
+        description="Sequence length",
+    )

-    n_batch_loss_added: int = 50
-    target_metric: str = "correct"
-    save_metrics: bool = True
+    n_batch_loss_added: int = Field(
+        default=50,
+        title="N Batch Loss Added",
+        description="N batch loss added",
+    )
+    target_metric: str = Field(
+        default="correct",
+        title="Target Metric",
+        description="Target metric",
+    )
+    save_metrics: bool = Field(
+        default=True,
+        title="Save Metrics",
+        description="Save metrics",
+    )

-    model_name: str = "gemma-2-2b-it"
-    llm_dtype: str = "bfloat16"
+    model_name: str = Field(
+        default="gemma-2-2b-it",
+        title="Model Name",
+        description="Model name",
+    )
+    llm_dtype: str = Field(
+        default="bfloat16",
+        title="LLM Data Type",
+        description="LLM data type",
+    )
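The switch from a plain dataclass to a pydantic dataclass makes the config self-documenting. A minimal sketch (assuming pydantic v2, where `TypeAdapter` accepts pydantic dataclasses) of how the `Field` metadata surfaces, and of one behavioral change worth noting: `mcq_batch_size` is now a fixed default of 8 rather than being derived as `llm_batch_size * 2`, so overriding `llm_batch_size` no longer scales it implicitly.

```python
# Sketch only: assumes pydantic v2; TypeAdapter works on pydantic dataclasses.
from pydantic import TypeAdapter

from evals.unlearning.eval_config import UnlearningEvalConfig

config = UnlearningEvalConfig()
# mcq_batch_size is now an independent default (8), no longer llm_batch_size * 2.
assert config.mcq_batch_size == 8

# The Field titles/descriptions become machine-readable schema entries.
schema = TypeAdapter(UnlearningEvalConfig).json_schema()
print(schema["properties"]["mcq_batch_size"]["title"])  # "MCQ Batch Size"
```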
+ """ + + eval_config: UnlearningEvalConfig + eval_id: str + datetime_epoch_millis: int + eval_result_metrics: UnlearningMetricCategories + + eval_type_id: str = Field( + default=EVAL_TYPE_ID_UNLEARNING, + title="Eval Type ID", + description="The type of the evaluation", + ) diff --git a/evals/unlearning/eval_output_schema_unlearning.json b/evals/unlearning/eval_output_schema_unlearning.json new file mode 100644 index 0000000..e601e5f --- /dev/null +++ b/evals/unlearning/eval_output_schema_unlearning.json @@ -0,0 +1,244 @@ +{ + "$defs": { + "BaseResultDetail": { + "properties": {}, + "title": "BaseResultDetail", + "type": "object" + }, + "UnlearningEvalConfig": { + "properties": { + "random_seed": { + "default": 42, + "description": "Random seed", + "title": "Random Seed", + "type": "integer" + }, + "dataset_names": { + "description": "List of dataset names", + "items": { + "type": "string" + }, + "title": "Dataset Names", + "type": "array" + }, + "intervention_method": { + "default": "clamp_feature_activation", + "description": "Intervention method", + "title": "Intervention Method", + "type": "string" + }, + "retain_thresholds": { + "description": "Retain thresholds", + "items": { + "type": "number" + }, + "title": "Retain Thresholds", + "type": "array" + }, + "n_features_list": { + "description": "N features list", + "items": { + "type": "integer" + }, + "title": "N Features List", + "type": "array" + }, + "multipliers": { + "description": "Multipliers", + "items": { + "type": "integer" + }, + "title": "Multipliers", + "type": "array" + }, + "llm_batch_size": { + "default": 4, + "description": "LLM batch size", + "title": "LLM Batch Size", + "type": "integer" + }, + "mcq_batch_size": { + "default": 8, + "description": "MCQ batch size. Multiple choice questions are shorter, so we can afford a larger batch size", + "title": "MCQ Batch Size", + "type": "integer" + }, + "dataset_size": { + "default": 1024, + "description": "Dataset size", + "title": "Dataset Size", + "type": "integer" + }, + "seq_len": { + "default": 1024, + "description": "Sequence length", + "title": "Sequence Length", + "type": "integer" + }, + "n_batch_loss_added": { + "default": 50, + "description": "N batch loss added", + "title": "N Batch Loss Added", + "type": "integer" + }, + "target_metric": { + "default": "correct", + "description": "Target metric", + "title": "Target Metric", + "type": "string" + }, + "save_metrics": { + "default": true, + "description": "Save metrics", + "title": "Save Metrics", + "type": "boolean" + }, + "model_name": { + "default": "gemma-2-2b-it", + "description": "Model name", + "title": "Model Name", + "type": "string" + }, + "llm_dtype": { + "default": "bfloat16", + "description": "LLM data type", + "title": "LLM Data Type", + "type": "string" + } + }, + "title": "UnlearningEvalConfig", + "type": "object" + }, + "UnlearningMetricCategories": { + "properties": { + "unlearning": { + "$ref": "#/$defs/UnlearningMetrics", + "description": "Metrics related to unlearning", + "title": "Unlearning" + } + }, + "required": [ + "unlearning" + ], + "title": "UnlearningMetricCategories", + "type": "object" + }, + "UnlearningMetrics": { + "properties": { + "unlearning_score": { + "description": "Unlearning score", + "title": "Unlearning Score", + "type": "number", + "ui_default_display": true + } + }, + "required": [ + "unlearning_score" + ], + "title": "UnlearningMetrics", + "type": "object" + } + }, + "description": "The output of core SAE evaluations measuring reconstruction quality, sparsity, and 
diff --git a/evals/unlearning/eval_output_schema_unlearning.json b/evals/unlearning/eval_output_schema_unlearning.json
new file mode 100644
index 0000000..e601e5f
--- /dev/null
+++ b/evals/unlearning/eval_output_schema_unlearning.json
@@ -0,0 +1,244 @@
+{
+  "$defs": {
+    "BaseResultDetail": {
+      "properties": {},
+      "title": "BaseResultDetail",
+      "type": "object"
+    },
+    "UnlearningEvalConfig": {
+      "properties": {
+        "random_seed": {
+          "default": 42,
+          "description": "Random seed",
+          "title": "Random Seed",
+          "type": "integer"
+        },
+        "dataset_names": {
+          "description": "List of dataset names",
+          "items": {
+            "type": "string"
+          },
+          "title": "Dataset Names",
+          "type": "array"
+        },
+        "intervention_method": {
+          "default": "clamp_feature_activation",
+          "description": "Intervention method",
+          "title": "Intervention Method",
+          "type": "string"
+        },
+        "retain_thresholds": {
+          "description": "Retain thresholds",
+          "items": {
+            "type": "number"
+          },
+          "title": "Retain Thresholds",
+          "type": "array"
+        },
+        "n_features_list": {
+          "description": "N features list",
+          "items": {
+            "type": "integer"
+          },
+          "title": "N Features List",
+          "type": "array"
+        },
+        "multipliers": {
+          "description": "Multipliers",
+          "items": {
+            "type": "integer"
+          },
+          "title": "Multipliers",
+          "type": "array"
+        },
+        "llm_batch_size": {
+          "default": 4,
+          "description": "LLM batch size",
+          "title": "LLM Batch Size",
+          "type": "integer"
+        },
+        "mcq_batch_size": {
+          "default": 8,
+          "description": "MCQ batch size. Multiple choice questions are shorter, so we can afford a larger batch size",
+          "title": "MCQ Batch Size",
+          "type": "integer"
+        },
+        "dataset_size": {
+          "default": 1024,
+          "description": "Dataset size",
+          "title": "Dataset Size",
+          "type": "integer"
+        },
+        "seq_len": {
+          "default": 1024,
+          "description": "Sequence length",
+          "title": "Sequence Length",
+          "type": "integer"
+        },
+        "n_batch_loss_added": {
+          "default": 50,
+          "description": "N batch loss added",
+          "title": "N Batch Loss Added",
+          "type": "integer"
+        },
+        "target_metric": {
+          "default": "correct",
+          "description": "Target metric",
+          "title": "Target Metric",
+          "type": "string"
+        },
+        "save_metrics": {
+          "default": true,
+          "description": "Save metrics",
+          "title": "Save Metrics",
+          "type": "boolean"
+        },
+        "model_name": {
+          "default": "gemma-2-2b-it",
+          "description": "Model name",
+          "title": "Model Name",
+          "type": "string"
+        },
+        "llm_dtype": {
+          "default": "bfloat16",
+          "description": "LLM data type",
+          "title": "LLM Data Type",
+          "type": "string"
+        }
+      },
+      "title": "UnlearningEvalConfig",
+      "type": "object"
+    },
+    "UnlearningMetricCategories": {
+      "properties": {
+        "unlearning": {
+          "$ref": "#/$defs/UnlearningMetrics",
+          "description": "Metrics related to unlearning",
+          "title": "Unlearning"
+        }
+      },
+      "required": [
+        "unlearning"
+      ],
+      "title": "UnlearningMetricCategories",
+      "type": "object"
+    },
+    "UnlearningMetrics": {
+      "properties": {
+        "unlearning_score": {
+          "description": "Unlearning score",
+          "title": "Unlearning Score",
+          "type": "number",
+          "ui_default_display": true
+        }
+      },
+      "required": [
+        "unlearning_score"
+      ],
+      "title": "UnlearningMetrics",
+      "type": "object"
+    }
+  },
+  "description": "The output of the unlearning evaluation, measuring the ability of an SAE to remove knowledge of the WMDP-bio dataset from a model while preserving performance on unrelated datasets.",
+  "properties": {
+    "eval_type_id": {
+      "default": "unlearning",
+      "description": "The type of the evaluation",
+      "title": "Eval Type ID",
+      "type": "string"
+    },
+    "eval_config": {
+      "$ref": "#/$defs/UnlearningEvalConfig",
+      "description": "The configuration of the evaluation.",
+      "title": "Eval Config Type"
+    },
+    "eval_id": {
+      "description": "A unique UUID identifying this specific eval run",
+      "title": "ID",
+      "type": "string"
+    },
+    "datetime_epoch_millis": {
+      "description": "The datetime of the evaluation in epoch milliseconds",
+      "title": "DateTime (epoch ms)",
+      "type": "integer"
+    },
+    "eval_result_metrics": {
+      "$ref": "#/$defs/UnlearningMetricCategories",
+      "description": "The metrics of the evaluation, organized by category. Define your own categories and the metrics that go inside them.",
+      "title": "Result Metrics Categorized"
+    },
+    "eval_result_details": {
+      "default": null,
+      "description": "Optional. The details of the evaluation. A list of objects that stores nested or more detailed data, such as details about the absorption of each letter.",
+      "items": {
+        "$ref": "#/$defs/BaseResultDetail"
+      },
+      "title": "Result Details",
+      "type": "array"
+    },
+    "sae_bench_commit_hash": {
+      "description": "The commit hash of the SAE Bench that ran the evaluation.",
+      "title": "SAE Bench Commit Hash",
+      "type": "string"
+    },
+    "sae_lens_id": {
+      "anyOf": [
+        {
+          "type": "string"
+        },
+        {
+          "type": "null"
+        }
+      ],
+      "description": "The ID of the SAE in SAE Lens.",
+      "title": "SAE Lens ID"
+    },
+    "sae_lens_release_id": {
+      "anyOf": [
+        {
+          "type": "string"
+        },
+        {
+          "type": "null"
+        }
+      ],
+      "description": "The release ID of the SAE in SAE Lens.",
+      "title": "SAE Lens Release ID"
+    },
+    "sae_lens_version": {
+      "anyOf": [
+        {
+          "type": "string"
+        },
+        {
+          "type": "null"
+        }
+      ],
+      "description": "The version of SAE Lens that ran the evaluation.",
+      "title": "SAE Lens Version"
+    },
+    "eval_result_unstructured": {
+      "anyOf": [
+        {},
+        {
+          "type": "null"
+        }
+      ],
+      "default": null,
+      "description": "Optional. Any additional outputs that don't fit into the structured eval_result_metrics or eval_result_details fields. Since these are unstructured, don't expect this to be easily renderable in UIs, or contain any titles or descriptions.",
+      "title": "Unstructured Results"
+    }
+  },
+  "required": [
+    "eval_config",
+    "eval_id",
+    "datetime_epoch_millis",
+    "eval_result_metrics",
+    "sae_bench_commit_hash",
+    "sae_lens_id",
+    "sae_lens_release_id",
+    "sae_lens_version"
+  ],
+  "title": "Unlearning",
+  "type": "object"
+}
\ No newline at end of file
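One plausible consumer of the checked-in schema (assuming the third-party `jsonschema` package, which this repo may or may not already depend on): validating a results file before publishing it.

```python
# Sketch only: `jsonschema` is an assumed dependency; paths are illustrative.
import json

import jsonschema

with open("evals/unlearning/eval_output_schema_unlearning.json") as f:
    schema = json.load(f)

with open("unlearning_eval_results.json") as f:  # hypothetical results file
    instance = json.load(f)

# Raises jsonschema.exceptions.ValidationError on a missing required field
# (e.g. eval_id) or a type mismatch.
jsonschema.validate(instance=instance, schema=schema)
```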
diff --git a/evals/unlearning/main.py b/evals/unlearning/main.py
index 0d9c12f..d0bc3b6 100644
--- a/evals/unlearning/main.py
+++ b/evals/unlearning/main.py
@@ -15,12 +15,9 @@ from datetime import datetime
 from transformer_lens import HookedTransformer
 from sae_lens import SAE
-
 from evals.unlearning.utils.eval import run_eval_single_sae
-import evals.unlearning.eval_config as eval_config
 import sae_bench_utils.activation_collection as activation_collection
-import sae_bench_utils.formatting_utils as formatting_utils
-import evals.unlearning.eval_config as eval_config
+from evals.unlearning.eval_config import UnlearningEvalConfig
 from sae_bench_utils import (
     get_eval_uuid,
     get_sae_lens_version,
@@ -104,7 +101,7 @@ def convert_ndarrays_to_lists(obj):


 def run_eval(
-    config: eval_config.EvalConfig,
+    config: UnlearningEvalConfig,
     selected_saes_dict: dict[str, list[str]],
     device: str,
     output_path: str,
@@ -223,8 +220,8 @@ def setup_environment():

 def create_config_and_selected_saes(
     args,
-) -> tuple[eval_config.EvalConfig, dict[str, list[str]]]:
-    config = eval_config.EvalConfig(
+) -> tuple[UnlearningEvalConfig, dict[str, list[str]]]:
+    config = UnlearningEvalConfig(
         random_seed=args.random_seed,
         model_name=args.model_name,
     )
diff --git a/evals/unlearning/utils/eval.py b/evals/unlearning/utils/eval.py
index 285c45f..fc88da6 100644
--- a/evals/unlearning/utils/eval.py
+++ b/evals/unlearning/utils/eval.py
@@ -1,16 +1,14 @@
-import argparse
 import os
 import numpy as np
 from transformer_lens import HookedTransformer
 from sae_lens import SAE
-
 from evals.unlearning.utils.feature_activation import (
     get_top_features,
     load_sparsity_data,
     save_feature_sparsity,
 )
 from evals.unlearning.utils.metrics import calculate_metrics_list
-import evals.unlearning.eval_config as eval_config
+from evals.unlearning.eval_config import UnlearningEvalConfig


 def run_metrics_calculation(
@@ -21,10 +19,10 @@ def run_metrics_calculation(
     retain_sparsity: np.ndarray,
     artifacts_folder: str,
     sae_name: str,
-    config: eval_config.EvalConfig,
+    config: UnlearningEvalConfig,
     force_rerun: bool,
 ):
-    all_dataset_names = config.all_dataset_names
+    dataset_names = config.dataset_names

     for retain_threshold in config.retain_thresholds:
         top_features_custom = get_top_features(
@@ -53,7 +51,7 @@ def run_metrics_calculation(
             sweep,
             artifacts_folder,
             force_rerun,
-            all_dataset_names,
+            dataset_names,
             n_batch_loss_added=config.n_batch_loss_added,
             activation_store=activation_store,
             target_metric=config.target_metric,
@@ -68,7 +66,7 @@
 def run_eval_single_sae(
     model: HookedTransformer,
     sae: SAE,
-    config: eval_config.EvalConfig,
+    config: UnlearningEvalConfig,
     artifacts_folder: str,
     sae_release_and_id: str,
     force_rerun: bool,
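For orientation, a sketch of driving the renamed entry point end to end, mirroring the call in tests/test_unlearning.py below; the SAE release/ID pair comes from the test fixture, and the device and output path are illustrative.

```python
# Sketch only: mirrors the test below; paths and the SAE selection are examples.
import evals.unlearning.main as unlearning
from evals.unlearning.eval_config import UnlearningEvalConfig

config = UnlearningEvalConfig()
config.retain_thresholds = [0.01]  # shrink the sweep for a quick run
config.n_features_list = [10]
config.multipliers = [25]

selected_saes_dict = {
    "sae_bench_gemma-2-2b_sweep_topk_ctx128_ef8_0824": [
        "blocks.3.hook_resid_post__trainer_2"
    ]
}

unlearning.run_eval(
    config,
    selected_saes_dict,
    device="cuda",
    output_path="evals/unlearning/results/",
    force_rerun=True,
    clean_up_artifacts=False,  # keep artifacts for inspection, as the test now does
)
```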
diff --git a/tests/test_data/unlearning/sae_bench_gemma-2-2b_sweep_topk_ctx128_ef8_0824_blocks.3.hook_resid_post__trainer_2_eval_results.json b/tests/test_data/unlearning/sae_bench_gemma-2-2b_sweep_topk_ctx128_ef8_0824_blocks.3.hook_resid_post__trainer_2_eval_results.json
index 8685dfc..16ddded 100644
--- a/tests/test_data/unlearning/sae_bench_gemma-2-2b_sweep_topk_ctx128_ef8_0824_blocks.3.hook_resid_post__trainer_2_eval_results.json
+++ b/tests/test_data/unlearning/sae_bench_gemma-2-2b_sweep_topk_ctx128_ef8_0824_blocks.3.hook_resid_post__trainer_2_eval_results.json
@@ -1,77 +1,65 @@
 {
-    "sae_bench_gemma-2-2b_sweep_topk_ctx128_ef8_0824_blocks.3.hook_resid_post__trainer_2": {
-        "eval_instance_id": "d5b4749a-d95f-4efd-b20f-e61bc7fb4ca7",
-        "sae_lens_release": "sae_bench_gemma-2-2b_sweep_topk_ctx128_ef8_0824",
-        "sae_lens_id": "blocks.3.hook_resid_post__trainer_2",
-        "eval_type_id": "unlearning",
-        "sae_lens_version": "4.1.0",
-        "sae_bench_version": "5d545d199309334ac9985c036cbfb3b2f1a4638d",
-        "date_time": "2024-11-05T21:28:04.387840",
-        "eval_config": {
-            "random_seed": 44,
-            "all_dataset_names": [
-                "wmdp-bio",
-                "high_school_us_history",
-                "college_computer_science",
-                "high_school_geography",
-                "human_aging",
-                "college_biology"
-            ],
-            "intervention_method": "clamp_feature_activation",
-            "retain_thresholds": [
-                0.01
-            ],
-            "n_features_list": [
-                10
-            ],
-            "multipliers": [
-                25
-            ],
-            "llm_batch_size": 4,
-            "mcq_batch_size": 8,
-            "dataset_size": 256,
-            "seq_len": 1024,
-            "n_batch_loss_added": 50,
-            "target_metric": "correct",
-            "save_metrics": true,
-            "model_name": "gemma-2-2b-it",
-            "llm_dtype": "bfloat16"
-        },
-        "eval_results": {
-            "unlearning_score": 0.17782026529312134
-        },
-        "eval_artifacts": {
-            "artifacts": "artifacts/unlearning/gemma-2-2b-it"
-        }
-    },
-    "custom_eval_config": {
-        "random_seed": 44,
-        "all_dataset_names": [
-            "wmdp-bio",
-            "high_school_us_history",
-            "college_computer_science",
-            "high_school_geography",
-            "human_aging",
-            "college_biology"
-        ],
-        "intervention_method": "clamp_feature_activation",
-        "retain_thresholds": [
-            0.01
-        ],
-        "n_features_list": [
-            10
-        ],
-        "multipliers": [
-            25
-        ],
-        "llm_batch_size": 4,
-        "mcq_batch_size": 8,
-        "dataset_size": 256,
-        "seq_len": 1024,
-        "n_batch_loss_added": 50,
-        "target_metric": "correct",
-        "save_metrics": true,
-        "model_name": "gemma-2-2b-it",
-        "llm_dtype": "bfloat16"
-    }
-}
\ No newline at end of file
+  "sae_bench_gemma-2-2b_sweep_topk_ctx128_ef8_0824_blocks.3.hook_resid_post__trainer_2": {
+    "eval_instance_id": "d5b4749a-d95f-4efd-b20f-e61bc7fb4ca7",
+    "sae_lens_release": "sae_bench_gemma-2-2b_sweep_topk_ctx128_ef8_0824",
+    "sae_lens_id": "blocks.3.hook_resid_post__trainer_2",
+    "eval_type_id": "unlearning",
+    "sae_lens_version": "4.1.0",
+    "sae_bench_version": "5d545d199309334ac9985c036cbfb3b2f1a4638d",
+    "date_time": "2024-11-05T21:28:04.387840",
+    "eval_config": {
+      "random_seed": 44,
+      "dataset_names": [
+        "wmdp-bio",
+        "high_school_us_history",
+        "college_computer_science",
+        "high_school_geography",
+        "human_aging",
+        "college_biology"
+      ],
+      "intervention_method": "clamp_feature_activation",
+      "retain_thresholds": [0.01],
+      "n_features_list": [10],
+      "multipliers": [25],
+      "llm_batch_size": 4,
+      "mcq_batch_size": 8,
+      "dataset_size": 256,
+      "seq_len": 1024,
+      "n_batch_loss_added": 50,
+      "target_metric": "correct",
+      "save_metrics": true,
+      "model_name": "gemma-2-2b-it",
+      "llm_dtype": "bfloat16"
+    },
+    "eval_results": {
+      "unlearning_score": 0.17782026529312134
+    },
+    "eval_artifacts": {
+      "artifacts": "artifacts/unlearning/gemma-2-2b-it"
+    }
+  },
+  "custom_eval_config": {
+    "random_seed": 44,
+    "dataset_names": [
+      "wmdp-bio",
+      "high_school_us_history",
+      "college_computer_science",
+      "high_school_geography",
+      "human_aging",
+      "college_biology"
+    ],
+    "intervention_method": "clamp_feature_activation",
"retain_thresholds": [0.01], + "n_features_list": [10], + "multipliers": [25], + "llm_batch_size": 4, + "mcq_batch_size": 8, + "dataset_size": 256, + "seq_len": 1024, + "n_batch_loss_added": 50, + "target_metric": "correct", + "save_metrics": true, + "model_name": "gemma-2-2b-it", + "llm_dtype": "bfloat16" + } +} diff --git a/tests/test_unlearning.py b/tests/test_unlearning.py index 4088628..0b1245d 100644 --- a/tests/test_unlearning.py +++ b/tests/test_unlearning.py @@ -1,6 +1,6 @@ import json import torch -from evals.unlearning.eval_config import EvalConfig +from evals.unlearning.eval_config import UnlearningEvalConfig import evals.unlearning.main as unlearning import sae_bench_utils.testing_utils as testing_utils from sae_bench_utils.sae_selection_utils import select_saes_multiple_patterns @@ -18,7 +18,7 @@ def test_end_to_end_different_seed(): print(f"Using device: {device}") - test_config = EvalConfig() + test_config = UnlearningEvalConfig() test_config.retain_thresholds = [0.01] test_config.n_features_list = [10] @@ -45,7 +45,7 @@ def test_end_to_end_different_seed(): device, output_path="evals/unlearning/test_results/", force_rerun=True, - clean_up_artifacts=True, + clean_up_artifacts=False, ) with open(results_filename, "r") as f: