@@ -9,14 +9,15 @@
 if __name__ == "__main__":
 
     parser = argparse.ArgumentParser(description="")  # parser is an argparse.ArgumentParser instance
-    parser.add_argument("--dataset_name", type=str, default="sentence-transformers/wikipedia-en-sentences", required=False)
-    #parser.add_argument("--dataset_name", type=str, default="bios", required=False)
+    #parser.add_argument("--dataset_name", type=str, default="sentence-transformers/wikipedia-en-sentences", required=False)
+    parser.add_argument("--dataset_name", type=str, default="bios", required=False)
     parser.add_argument("--bios_zs_to_keep", type=list, default=[1], required=False)
     parser.add_argument("--bios_ys_to_keep", type=list, default=["professor"], required=False)
     parser.add_argument("--num_sents", type=int, default=500, required=False)
     parser.add_argument("--prompt", type=str, default="first_k", required=False)
-    parser.add_argument("--prompt_first_k", type=int, default=5, required=False)
-    parser.add_argument("--max_new_tokens", type=int, default=25, required=False)
+    parser.add_argument("--prompt_first_k", type=int, default=7, required=False)
+    parser.add_argument("--max_new_tokens", type=int, default=40, required=False)
+    parser.add_argument("--num_counterfactuals", type=int, default=1, required=False)
     parser.add_argument("--models", type=list, default=[
         ("openai-community/gpt2-xl", "mimic_gender_gpt2_instruct"),
         ("meta-llama/Meta-Llama-3-8B-Instruct", "mimic_gender_llama3_instruct"),
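A side note on the flags above: `argparse` with `type=list` calls `list()` on the raw command-line string, so the list-typed options here only behave as intended when left at their defaults. A minimal standalone sketch of the pitfall (the value `"nurse"` is hypothetical):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--bios_ys_to_keep", type=list, default=["professor"], required=False)

# Default is used: the value stays the intended ['professor'].
print(parser.parse_args([]).bios_ys_to_keep)  # ['professor']

# Passed on the CLI: list() splits the string into characters.
print(parser.parse_args(["--bios_ys_to_keep", "nurse"]).bios_ys_to_keep)  # ['n', 'u', 'r', 's', 'e']
```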
@@ -55,17 +56,19 @@
     prompt = tokenizer.bos_token
 
     # get counterfactual
-
+    counterfactuals = []
     original_continuation_tokens, original_continuation = utils.get_continuation(original_model, tokenizer, prompt, max_new_tokens=args.max_new_tokens, return_only_continuation=True, num_beams=1, do_sample=True, token_healing=True)
-    count_tokens, count_text = utils.get_counterfactual_output(counterfactual_model, original_model, tokenizer, prompt, original_continuation, args.max_new_tokens)
-    all_outputs.append({"tokens": count_tokens, "text": count_text})
+    for l in range(args.num_counterfactuals):
+        count_tokens, count_text = utils.get_counterfactual_output(counterfactual_model, original_model, tokenizer, prompt, original_continuation, args.max_new_tokens)
+        counterfactuals.append({"tokens": count_tokens, "text": count_text})
+    all_outputs.append(counterfactuals)
     orig_str = prompt.replace(tokenizer.bos_token, "") + original_continuation
     orig_str_tokens = tokenizer.encode(orig_str, return_tensors="pt", add_special_tokens=False).detach().cpu().numpy()[0]
     all_sents.append({"tokens": orig_str_tokens, "text": orig_str})
 
     print("Original: {}\n--------------------\nCounterfactual: {}".format(orig_str, count_text))
     print("==================================")
     dataset_name = "wiki" if "wiki" in args.dataset_name else "bios"
-    fname = f"counterfactuals2/{dataset_name}_{orig.split('/')[1]}->{intervention_type}_prompt:{args.prompt}_sents:{args.num_sents}_prompt_first_k:{args.prompt_first_k}_max_new_tokens:{args.max_new_tokens}.pkl"
+    fname = f"counterfactuals/{dataset_name}_{orig.split('/')[1]}->{intervention_type}_prompt:{args.prompt}_sents:{args.num_sents}_prompt_first_k:{args.prompt_first_k}_max_new_tokens:{args.max_new_tokens}.pkl"
     with open(fname, "wb") as f:
         pickle.dump({"original": all_sents, "counter": all_outputs}, f)
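For downstream use, a minimal sketch of reading the saved pickle back. The file name below is hypothetical (an actual run derives it from `fname` above); note that after this change each entry of `"counter"` is a list of `num_counterfactuals` dicts rather than a single dict:

```python
import pickle

# Hypothetical path; substitute the fname produced by an actual run.
with open("counterfactuals/bios_gpt2-xl->mimic_gender_gpt2_instruct_prompt:first_k_sents:500_prompt_first_k:7_max_new_tokens:40.pkl", "rb") as f:
    data = pickle.load(f)

for orig_entry, counters in zip(data["original"], data["counter"]):
    print("Original:", orig_entry["text"])
    for c in counters:  # one dict per sampled counterfactual
        print("  Counterfactual:", c["text"])
```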