unlearning start
hijohnnylin committed Nov 8, 2024
1 parent b6ed053 commit a7be6df
Showing 7 changed files with 441 additions and 111 deletions.
evals/unlearning/eval_config.py (94 changes: 74 additions & 20 deletions)
@@ -1,37 +1,91 @@
-from dataclasses import dataclass, field
-
+from pydantic.dataclasses import dataclass
+from pydantic import Field
+from evals.base_eval_output import BaseEvalConfig
 
 @dataclass
-class EvalConfig:
-    random_seed: int = 42
+class UnlearningEvalConfig(BaseEvalConfig):
+    random_seed: int = Field(default=42, title="Random Seed", description="Random seed")
 
-    all_dataset_names: list[str] = field(
+    dataset_names: list[str] = Field(
         default_factory=lambda: [
             "wmdp-bio",
             "high_school_us_history",
             "college_computer_science",
             "high_school_geography",
             "human_aging",
             "college_biology",
-        ]
+        ],
+        title="Dataset Names",
+        description="List of dataset names",
     )
 
-    intervention_method: str = "clamp_feature_activation"
+    intervention_method: str = Field(
+        default="clamp_feature_activation",
+        title="Intervention Method",
+        description="Intervention method",
+    )
 
-    retain_thresholds: list[str] = field(default_factory=lambda: [0.001, 0.01])
-    n_features_list: list[str] = field(default_factory=lambda: [10, 20])
-    multipliers: list[str] = field(default_factory=lambda: [25, 50, 100, 200])
+    retain_thresholds: list[float] = Field(
+        default_factory=lambda: [0.001, 0.01],
+        title="Retain Thresholds",
+        description="Retain thresholds",
+    )
+    n_features_list: list[int] = Field(
+        default_factory=lambda: [10, 20],
+        title="N Features List",
+        description="N features list",
+    )
+    multipliers: list[int] = Field(
+        default_factory=lambda: [25, 50, 100, 200],
+        title="Multipliers",
+        description="Multipliers",
+    )
 
-    llm_batch_size: int = 4
-    # multiple choice questions are shorter, so we can afford a larger batch size
-    mcq_batch_size: int = llm_batch_size * 2
+    llm_batch_size: int = Field(
+        default=4,
+        title="LLM Batch Size",
+        description="LLM batch size",
+    )
+    mcq_batch_size: int = Field(
+        default=8,
+        title="MCQ Batch Size",
+        description="MCQ batch size. Multiple choice questions are shorter, so we can afford a larger batch size",
+    )
 
-    dataset_size: int = 1024
-    seq_len: int = 1024
+    dataset_size: int = Field(
+        default=1024,
+        title="Dataset Size",
+        description="Dataset size",
+    )
+    seq_len: int = Field(
+        default=1024,
+        title="Sequence Length",
+        description="Sequence length",
+    )
 
-    n_batch_loss_added: int = 50
-    target_metric: str = "correct"
-    save_metrics: bool = True
+    n_batch_loss_added: int = Field(
+        default=50,
+        title="N Batch Loss Added",
+        description="N batch loss added",
+    )
+    target_metric: str = Field(
+        default="correct",
+        title="Target Metric",
+        description="Target metric",
+    )
+    save_metrics: bool = Field(
+        default=True,
+        title="Save Metrics",
+        description="Save metrics",
+    )
 
-    model_name: str = "gemma-2-2b-it"
-    llm_dtype: str = "bfloat16"
+    model_name: str = Field(
+        default="gemma-2-2b-it",
+        title="Model Name",
+        description="Model name",
+    )
+    llm_dtype: str = Field(
+        default="bfloat16",
+        title="LLM Data Type",
+        description="LLM data type",
+    )
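The move from stdlib dataclass fields to pydantic Field gives every option a title and description that land in the generated JSON schema. Below is a minimal usage sketch (not part of this commit), assuming pydantic v2 and that BaseEvalConfig adds no required fields of its own:

# Hypothetical sketch (not from this commit). Assumes pydantic v2 and that
# BaseEvalConfig introduces no additional required fields.
from dataclasses import asdict

from pydantic import TypeAdapter

from evals.unlearning.eval_config import UnlearningEvalConfig

# Pydantic dataclasses are real dataclasses, so stdlib helpers still apply.
config = UnlearningEvalConfig(llm_batch_size=8, mcq_batch_size=16)
print(asdict(config)["dataset_names"][0])  # wmdp-bio

# The per-field titles and descriptions surface in the JSON schema, which is
# what makes the config renderable as a labeled form or table downstream.
schema = TypeAdapter(UnlearningEvalConfig).json_schema()
print(schema["properties"]["random_seed"])
# {'default': 42, 'description': 'Random seed', 'title': 'Random Seed', 'type': 'integer'}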
evals/unlearning/eval_output.py (49 changes: 49 additions & 0 deletions)
@@ -0,0 +1,49 @@
+from pydantic.dataclasses import dataclass
+from pydantic import ConfigDict, Field
+from evals.unlearning.eval_config import UnlearningEvalConfig
+from evals.base_eval_output import (
+    BaseEvalOutput,
+    BaseMetricCategories,
+    BaseMetrics,
+    DEFAULT_DISPLAY,
+    BaseResultDetail,
+)
+
+EVAL_TYPE_ID_UNLEARNING = "unlearning"
+
+
+@dataclass
+class UnlearningMetrics(BaseMetrics):
+    unlearning_score: float = Field(
+        title="Unlearning Score",
+        description="Unlearning score",
+        json_schema_extra=DEFAULT_DISPLAY,
+    )
+
+# Define the categories themselves
+@dataclass
+class UnlearningMetricCategories(BaseMetricCategories):
+    unlearning: UnlearningMetrics = Field(
+        title="Unlearning",
+        description="Metrics related to unlearning",
+    )
+
+# Define the eval output
+@dataclass(config=ConfigDict(title="Unlearning"))
+class UnlearningEvalOutput(
+    BaseEvalOutput[UnlearningEvalConfig, UnlearningMetricCategories, BaseResultDetail]
+):
+    """
+    The output of the unlearning evaluation, measuring how well targeted knowledge is removed while unrelated knowledge is retained.
+    """
+
+    eval_config: UnlearningEvalConfig
+    eval_id: str
+    datetime_epoch_millis: int
+    eval_result_metrics: UnlearningMetricCategories
+
+    eval_type_id: str = Field(
+        default=EVAL_TYPE_ID_UNLEARNING,
+        title="Eval Type ID",
+        description="The type of the evaluation",
+    )
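A hypothetical sketch of how these nested output dataclasses compose (again not from this commit; assumes the Base classes add no further required fields):

# Hypothetical sketch (not from this commit). Assumes BaseMetrics and
# BaseMetricCategories introduce no additional required fields.
from evals.unlearning.eval_output import (
    UnlearningMetricCategories,
    UnlearningMetrics,
)

# The single top-level score is wrapped in its category, which the eval
# output then carries under eval_result_metrics.
metrics = UnlearningMetricCategories(
    unlearning=UnlearningMetrics(unlearning_score=0.42),  # 0.42 is a placeholder value
)
print(metrics.unlearning.unlearning_score)  # 0.42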
(Diffs for the remaining 5 changed files are not shown.)
