Skip to content

Commit cb00506

Browse files
Update run.py
1 parent 964f9f6 commit cb00506

File tree

1 file changed: +11 −8 lines

run.py

+11-8
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,15 @@
99
if __name__ == "__main__":
1010

1111
parser = argparse.ArgumentParser(description="") #parser is an object of the class Argument Parser.
12-
parser.add_argument("--dataset_name", type=str, default="sentence-transformers/wikipedia-en-sentences", required=False)
13-
#parser.add_argument("--dataset_name", type=str, default="bios", required=False)
12+
#parser.add_argument("--dataset_name", type=str, default="sentence-transformers/wikipedia-en-sentences", required=False)
13+
parser.add_argument("--dataset_name", type=str, default="bios", required=False)
1414
parser.add_argument("--bios_zs_to_keep", type=list, default=[1], required=False)
1515
parser.add_argument("--bios_ys_to_keep", type=list, default=["professor"], required=False)
1616
parser.add_argument("--num_sents", type=int, default=500,required=False)
1717
parser.add_argument("--prompt", type=str, default="first_k",required=False)
18-
parser.add_argument("--prompt_first_k", type=int, default=5,required=False)
19-
parser.add_argument("--max_new_tokens", type=int, default=25,required=False)
18+
parser.add_argument("--prompt_first_k", type=int, default=7,required=False)
19+
parser.add_argument("--max_new_tokens", type=int, default=40,required=False)
20+
parser.add_argument("--num_counterfactuals", type=int, default=1,required=False)
2021
parser.add_argument("--models", type=list, default=[
2122
("openai-community/gpt2-xl", "mimic_gender_gpt2_instruct"),
2223
("meta-llama/Meta-Llama-3-8B-Instruct", "mimic_gender_llama3_instruct"),
@@ -55,17 +56,19 @@
5556
prompt = tokenzier.bos_token
5657

5758
# get counterfactual
58-
59+
counterfactuals = []
5960
original_continuation_tokens, original_continuation = utils.get_continuation(original_model, tokenizer, prompt, max_new_tokens=args.max_new_tokens, return_only_continuation=True,num_beams=1, do_sample=True, token_healing=True)
60-
count_tokens, count_text = utils.get_counterfactual_output(counterfactual_model, original_model, tokenizer, prompt, original_continuation, args.max_new_tokens)
61-
all_outputs.append({"tokens": count_tokens, "text": count_text})
61+
for l in range(args.num_counterfactuals):
62+
count_tokens, count_text = utils.get_counterfactual_output(counterfactual_model, original_model, tokenizer, prompt, original_continuation, args.max_new_tokens)
63+
counterfactuals.append({"tokens": count_tokens, "text": count_text})
64+
all_outputs.append(counterfactuals)
6265
orig_str = prompt.replace(tokenizer.bos_token,"")+original_continuation
6366
orig_str_tokens = tokenizer.encode(orig_str, return_tensors="pt", add_special_tokens=False).detach().cpu().numpy()[0]
6467
all_sents.append({"tokens": orig_str_tokens, "text": orig_str})
6568

6669
print("Original: {}\n--------------------\nCounterfactual: {}".format(orig_str, count_text))
6770
print("==================================")
6871
dataset_name = "wiki" if "wiki" in args.dataset_name else "bios"
69-
fname = f"counterfactuals2/{dataset_name}_{orig.split("/")[1]}->{intervention_type}_prompt:{args.prompt}_sents:{args.num_sents}_prompt_first_k:{args.prompt_first_k}_max_new_tokens:{args.max_new_tokens}.pkl"
72+
fname = f"counterfactuals/{dataset_name}_{orig.split("/")[1]}->{intervention_type}_prompt:{args.prompt}_sents:{args.num_sents}_prompt_first_k:{args.prompt_first_k}_max_new_tokens:{args.max_new_tokens}.pkl"
7073
with open(fname, "wb") as f:
7174
pickle.dump({"original": all_sents, "counter": all_outputs}, f)

0 commit comments

Comments
 (0)