Skip to content

Commit

Permalink
Clean up duplicated functions
Browse files Browse the repository at this point in the history
  • Loading branch information
adamkarvonen committed Nov 16, 2024
1 parent d3e8e87 commit 28d2f2f
Show file tree
Hide file tree
Showing 6 changed files with 14 additions and 79 deletions.
18 changes: 3 additions & 15 deletions evals/absorption/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,18 +226,6 @@ def arg_parser():
return parser


def setup_environment():
    """Configure the CUDA allocator and pick the best available torch device.

    Returns:
        str: "mps", "cuda", or "cpu", in that order of preference.
    """
    # Let the CUDA caching allocator grow segments instead of failing on fragmentation.
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

    if torch.backends.mps.is_available():
        device = "mps"
    elif torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"

    print(f"Using device: {device}")
    return device


def create_config_and_selected_saes(args) -> tuple[AbsorptionEvalConfig, list[tuple[str, str]]]:
config = AbsorptionEvalConfig(
random_seed=args.random_seed,
Expand Down Expand Up @@ -269,7 +257,7 @@ def create_config_and_selected_saes(args) -> tuple[AbsorptionEvalConfig, list[tu
--model_name pythia-70m-deduped
"""
args = arg_parser().parse_args()
device = setup_environment()
device = formatting_utils.setup_environment()

start_time = time.time()

Expand All @@ -291,15 +279,15 @@ def create_config_and_selected_saes(args) -> tuple[AbsorptionEvalConfig, list[tu
print(f"Finished evaluation in {end_time - start_time:.2f} seconds")


# # Use this code snippet to use custom SAE objects
# Use this code snippet to use custom SAE objects
# if __name__ == "__main__":
# import baselines.identity_sae as identity_sae
# import baselines.jumprelu_sae as jumprelu_sae

# """
# python evals/absorption/main.py
# """
# device = setup_environment()
# device = formatting_utils.setup_environment()

# start_time = time.time()

Expand Down
16 changes: 3 additions & 13 deletions evals/autointerp/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
)
import sae_bench_utils.dataset_utils as dataset_utils
import sae_bench_utils.activation_collection as activation_collection
import sae_bench_utils.formatting_utils as formatting_utils


from sae_bench_utils import (
Expand Down Expand Up @@ -625,17 +626,6 @@ def run_eval(
return results_dict


def setup_environment():
    """Configure the CUDA allocator and pick the best available torch device.

    Returns:
        str: "mps", "cuda", or "cpu", in that order of preference.
    """
    # Let the CUDA caching allocator grow segments instead of failing on fragmentation.
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

    if torch.backends.mps.is_available():
        device = "mps"
    elif torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"

    print(f"Using device: {device}")
    return device


def create_config_and_selected_saes(
args,
) -> tuple[AutoInterpEvalConfig, list[tuple[str, str]]]:
Expand Down Expand Up @@ -707,7 +697,7 @@ def arg_parser():
"""
args = arg_parser().parse_args()
device = setup_environment()
device = formatting_utils.setup_environment()

start_time = time.time()

Expand Down Expand Up @@ -748,7 +738,7 @@ def arg_parser():
# import baselines.identity_sae as identity_sae
# import baselines.jumprelu_sae as jumprelu_sae

# device = setup_environment()
# device = formatting_utils.setup_environment()

# start_time = time.time()

Expand Down
13 changes: 1 addition & 12 deletions evals/mdl/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -497,17 +497,6 @@ def run_eval(
return results_dict


def setup_environment():
    """Configure the CUDA allocator and pick the best available torch device.

    Returns:
        str: "mps", "cuda", or "cpu", in that order of preference.
    """
    # Let the CUDA caching allocator grow segments instead of failing on fragmentation.
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

    if torch.backends.mps.is_available():
        device = "mps"
    elif torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"

    print(f"Using device: {device}")
    return device


def create_config_and_selected_saes(
args,
) -> tuple[MDLEvalConfig, list[tuple[str, str]]]:
Expand Down Expand Up @@ -570,7 +559,7 @@ def arg_parser():
logger.add(sys.stdout, level="INFO")

args = arg_parser().parse_args()
device = setup_environment()
device = formatting_utils.setup_environment()

start_time = time.time()

Expand Down
15 changes: 2 additions & 13 deletions evals/shift_and_tpp/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -831,17 +831,6 @@ def run_eval(
return results_dict


def setup_environment():
    """Configure the CUDA allocator and pick the best available torch device.

    Returns:
        str: "mps", "cuda", or "cpu", in that order of preference.
    """
    # Let the CUDA caching allocator grow segments instead of failing on fragmentation.
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

    if torch.backends.mps.is_available():
        device = "mps"
    elif torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"

    print(f"Using device: {device}")
    return device


def create_config_and_selected_saes(
args,
) -> tuple[ShiftAndTppEvalConfig, list[tuple[str, str]]]:
Expand Down Expand Up @@ -932,7 +921,7 @@ def str_to_bool(value):
--perform_scr true
"""
args = arg_parser().parse_args()
device = setup_environment()
device = formatting_utils.setup_environment()

start_time = time.time()

Expand Down Expand Up @@ -971,7 +960,7 @@ def str_to_bool(value):
# """
# python evals/shift_and_tpp/main.py
# """
# device = setup_environment()
# device = formatting_utils.setup_environment()

# start_time = time.time()

Expand Down
15 changes: 2 additions & 13 deletions evals/sparse_probing/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,17 +353,6 @@ def run_eval(
return results_dict


def setup_environment():
    """Configure the CUDA allocator and pick the best available torch device.

    Returns:
        str: "mps", "cuda", or "cpu", in that order of preference.
    """
    # Let the CUDA caching allocator grow segments instead of failing on fragmentation.
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

    if torch.backends.mps.is_available():
        device = "mps"
    elif torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"

    print(f"Using device: {device}")
    return device


def create_config_and_selected_saes(
args,
) -> tuple[SparseProbingEvalConfig, list[tuple[str, str]]]:
Expand Down Expand Up @@ -427,7 +416,7 @@ def arg_parser():
"""
args = arg_parser().parse_args()
device = setup_environment()
device = formatting_utils.setup_environment()

start_time = time.time()

Expand Down Expand Up @@ -465,7 +454,7 @@ def arg_parser():
# """
# python evals/sparse_probing/main.py
# """
# device = setup_environment()
# device = formatting_utils.setup_environment()

# start_time = time.time()

Expand Down
16 changes: 3 additions & 13 deletions evals/unlearning/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
get_saes_from_regex,
select_saes_multiple_patterns,
)
import sae_bench_utils.formatting_utils as formatting_utils

EVAL_TYPE = "unlearning"

Expand Down Expand Up @@ -208,17 +209,6 @@ def run_eval(
return results_dict


def setup_environment():
    """Configure the CUDA allocator and pick the best available torch device.

    Returns:
        str: "mps", "cuda", or "cpu", in that order of preference.
    """
    # Let the CUDA caching allocator grow segments instead of failing on fragmentation.
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

    if torch.backends.mps.is_available():
        device = "mps"
    elif torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"

    print(f"Using device: {device}")
    return device


def create_config_and_selected_saes(
args,
) -> tuple[UnlearningEvalConfig, list[tuple[str, str]]]:
Expand Down Expand Up @@ -287,7 +277,7 @@ def arg_parser():
--model_name gemma-2-2b-it
"""
args = arg_parser().parse_args()
device = setup_environment()
device = formatting_utils.setup_environment()

start_time = time.time()

Expand Down Expand Up @@ -323,7 +313,7 @@ def arg_parser():
# """
# python evals/unlearning/main.py
# """
# device = setup_environment()
# device = formatting_utils.setup_environment()

# start_time = time.time()

Expand Down

0 comments on commit 28d2f2f

Please sign in to comment.