-
Notifications
You must be signed in to change notification settings - Fork 15
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
b6ed053
commit a7be6df
Showing
7 changed files
with
441 additions
and
111 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,37 +1,91 @@ | ||
from dataclasses import dataclass, field | ||
|
||
from pydantic.dataclasses import dataclass | ||
from pydantic import Field | ||
from evals.base_eval_output import BaseEvalConfig | ||
|
||
@dataclass | ||
class EvalConfig: | ||
random_seed: int = 42 | ||
class UnlearningEvalConfig(BaseEvalConfig): | ||
random_seed: int = Field(default=42, title="Random Seed", description="Random seed") | ||
|
||
all_dataset_names: list[str] = field( | ||
dataset_names: list[str] = Field( | ||
default_factory=lambda: [ | ||
"wmdp-bio", | ||
"high_school_us_history", | ||
"college_computer_science", | ||
"high_school_geography", | ||
"human_aging", | ||
"college_biology", | ||
] | ||
], | ||
title="Dataset Names", | ||
description="List of dataset names", | ||
) | ||
|
||
intervention_method: str = "clamp_feature_activation" | ||
intervention_method: str = Field( | ||
default="clamp_feature_activation", | ||
title="Intervention Method", | ||
description="Intervention method", | ||
) | ||
|
||
retain_thresholds: list[str] = field(default_factory=lambda: [0.001, 0.01]) | ||
n_features_list: list[str] = field(default_factory=lambda: [10, 20]) | ||
multipliers: list[str] = field(default_factory=lambda: [25, 50, 100, 200]) | ||
retain_thresholds: list[float] = Field( | ||
default_factory=lambda: [0.001, 0.01], | ||
title="Retain Thresholds", | ||
description="Retain thresholds", | ||
) | ||
n_features_list: list[int] = Field( | ||
default_factory=lambda: [10, 20], | ||
title="N Features List", | ||
description="N features list", | ||
) | ||
multipliers: list[int] = Field( | ||
default_factory=lambda: [25, 50, 100, 200], | ||
title="Multipliers", | ||
description="Multipliers", | ||
) | ||
|
||
llm_batch_size: int = 4 | ||
# multiple choice questions are shorter, so we can afford a larger batch size | ||
mcq_batch_size: int = llm_batch_size * 2 | ||
llm_batch_size: int = Field( | ||
default=4, | ||
title="LLM Batch Size", | ||
description="LLM batch size", | ||
) | ||
mcq_batch_size: int = Field( | ||
default=8, | ||
title="MCQ Batch Size", | ||
description="MCQ batch size. Multiple choice questions are shorter, so we can afford a larger batch size", | ||
) | ||
|
||
dataset_size: int = 1024 | ||
seq_len: int = 1024 | ||
dataset_size: int = Field( | ||
default=1024, | ||
title="Dataset Size", | ||
description="Dataset size", | ||
) | ||
seq_len: int = Field( | ||
default=1024, | ||
title="Sequence Length", | ||
description="Sequence length", | ||
) | ||
|
||
n_batch_loss_added: int = 50 | ||
target_metric: str = "correct" | ||
save_metrics: bool = True | ||
n_batch_loss_added: int = Field( | ||
default=50, | ||
title="N Batch Loss Added", | ||
description="N batch loss added", | ||
) | ||
target_metric: str = Field( | ||
default="correct", | ||
title="Target Metric", | ||
description="Target metric", | ||
) | ||
save_metrics: bool = Field( | ||
default=True, | ||
title="Save Metrics", | ||
description="Save metrics", | ||
) | ||
|
||
model_name: str = "gemma-2-2b-it" | ||
llm_dtype: str = "bfloat16" | ||
model_name: str = Field( | ||
default="gemma-2-2b-it", | ||
title="Model Name", | ||
description="Model name", | ||
) | ||
llm_dtype: str = Field( | ||
default="bfloat16", | ||
title="LLM Data Type", | ||
description="LLM data type", | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
from pydantic.dataclasses import dataclass | ||
from pydantic import ConfigDict, Field | ||
from evals.unlearning.eval_config import UnlearningEvalConfig | ||
from evals.base_eval_output import ( | ||
BaseEvalOutput, | ||
BaseMetricCategories, | ||
BaseMetrics, | ||
DEFAULT_DISPLAY, | ||
BaseResultDetail, | ||
) | ||
|
||
# Identifier string for this eval type; used below as the default of UnlearningEvalOutput.eval_type_id | ||
EVAL_TYPE_ID_UNLEARNING = "unlearning" | ||
|
||
|
||
# Metrics container for the unlearning eval; it declares a single float field | ||
@dataclass | ||
class UnlearningMetrics(BaseMetrics): | ||
# No default is given, so a score must be supplied when the metrics are constructed | ||
unlearning_score: float = Field( | ||
title="Unlearning Score", | ||
description="Unlearning score", | ||
# NOTE(review): DEFAULT_DISPLAY presumably flags this metric for default display - confirm in evals.base_eval_output | ||
json_schema_extra=DEFAULT_DISPLAY, | ||
) | ||
|
||
# Define the categories themselves: the unlearning metrics are grouped under one `unlearning` category | ||
@dataclass | ||
class UnlearningMetricCategories(BaseMetricCategories): | ||
# No default is given, so an UnlearningMetrics value must be supplied | ||
unlearning: UnlearningMetrics = Field( | ||
title="Unlearning", | ||
description="Metrics related to unlearning", | ||
) | ||
|
||
# Define the eval output | ||
@dataclass(config=ConfigDict(title="Unlearning")) | ||
class UnlearningEvalOutput( | ||
BaseEvalOutput[UnlearningEvalConfig, UnlearningMetricCategories, BaseResultDetail] | ||
): | ||
""" | ||
The output of core SAE evaluations measuring reconstruction quality, sparsity, and model preservation. | ||
""" | ||
|
||
eval_config: UnlearningEvalConfig | ||
eval_id: str | ||
datetime_epoch_millis: int | ||
eval_result_metrics: UnlearningMetricCategories | ||
|
||
eval_type_id: str = Field( | ||
default=EVAL_TYPE_ID_UNLEARNING, | ||
title="Eval Type ID", | ||
description="The type of the evaluation", | ||
) |
Oops, something went wrong.