Skip to content

Commit 6dc57d6

Browse files
Add files via upload
1 parent 9ce0165 commit 6dc57d6

File tree

1 file changed

+121
-0
lines changed

1 file changed

+121
-0
lines changed

analyze_edit.py

+121
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
import pickle
2+
import matplotlib.pyplot as plt
3+
import numpy as np
4+
5+
models = ["counterfactuals2/wiki_Meta-Llama-3-8B-Instruct->mimic_gender_llama3_instruct_prompt:first_k_sents:500_prompt_first_k:5_max_new_tokens:25.pkl",
6+
"counterfactuals2/wiki_Meta-Llama-3-8B-Instruct->honest_steering_llama3_instruct_prompt:first_k_sents:500_prompt_first_k:5_max_new_tokens:25.pkl",
7+
"counterfactuals2/wiki_Meta-Llama-3-8B->chat_llama3_prompt:first_k_sents:500_prompt_first_k:5_max_new_tokens:25.pkl",
8+
"counterfactuals2/wiki_gpt2-xl->mimic_gender_gpt2_instruct_prompt:first_k_sents:500_prompt_first_k:5_max_new_tokens:25.pkl",
9+
"counterfactuals2/wiki_gpt2-xl->GPT2-memit-louvre-rome_prompt:first_k_sents:500_prompt_first_k:5_max_new_tokens:25.pkl",
10+
"counterfactuals2/wiki_gpt2-xl->GPT2-memit-koalas-new_zealand_prompt:first_k_sents:500_prompt_first_k:5_max_new_tokens:25.pkl"]
11+
names = ["LLaMA3-Steering-Gender", "LLaMA3-Steering-Honest", "LLaMA3-Instruct", "GPT2-XL-Steering-Gender", "GPT2-XL-MEMIT-Louvre", "GPT2-XL-MEMIT-Koalas"] #["Honest-LLama", "GPT-XL-ROME", "LLama2-Chat", "GPT2-XL-MEMIT"
12+
13+
models = ["counterfactuals/wiki_Meta-Llama-3-8B-Instruct->mimic_gender_llama3_instruct_prompt:first_k_sents:500_prompt_first_k:5_max_new_tokens:25.pkl",
14+
"counterfactuals/wiki_gpt2-xl->mimic_gender_gpt2_prompt:first_k_sents:500_prompt_first_k:5_max_new_tokens:25.pkl",
15+
"counterfactuals/wiki_Meta-Llama-3-8B-Instruct->honest_steering_llama3_instruct_prompt:first_k_sents:500_prompt_first_k:5_max_new_tokens:25.pkl",
16+
"counterfactuals/wiki_Meta-Llama-3-8B->chat_llama3_prompt:first_k_sents:500_prompt_first_k:5_max_new_tokens:25.pkl",
17+
"counterfactuals/wiki_gpt2-xl->GPT2-memit-koalas-new_zealand_prompt:first_k_sents:500_prompt_first_k:5_max_new_tokens:25.pkl",
18+
"counterfactuals/wiki_gpt2-xl->GPT2-memit-louvre-rome_prompt:first_k_sents:500_prompt_first_k:5_max_new_tokens:25.pkl"]
19+
names = ["LLaMA3-Steering-Gender", "GPT2-XL-Steering-Gender", "LLaMA3-Steering-Honest", "LLaMA3-Instruct", "GPT2-XL-MEMIT-Koalas", "GPT2-XL-MEMIT-Louvre"]
20+
21+
#models = ["counterfactuals3/wiki_Meta-Llama-3-8B-Instruct->mimic_gender_llama3_instruct_prompt:first_k_sents:50_prompt_first_k:5_max_new_tokens:25.pkl"]
22+
#names = ["LLaMA3-Steering-Gender"]
23+
24+
name2data = {}
25+
colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd', "cyan", "purple", "red"] #['blue', 'orange', 'green', "red", "cyan"] # Define colors for consistency
26+
27+
plt.rcParams["font.family"] = "serif"
28+
plt.rcParams.update({'font.size': 15})
29+
plt.figure(figsize=(8, 6))
30+
31+
32+
def levenshteinDistance(s1, s2):
33+
if len(s1) > len(s2):
34+
s1, s2 = s2, s1
35+
36+
distances = range(len(s1) + 1)
37+
for i2, c2 in enumerate(s2):
38+
distances_ = [i2+1]
39+
for i1, c1 in enumerate(s1):
40+
if c1 == c2:
41+
distances_.append(distances[i1])
42+
else:
43+
distances_.append(1 + min((distances[i1], distances[i1 + 1], distances_[-1])))
44+
distances = distances_
45+
return distances[-1]
46+
47+
EDIT_DISTANCE=True
48+
49+
print(names, models)
50+
for idx, (name, model) in enumerate(zip(names, models)):
51+
print(name)
52+
with open(model, "rb") as f:
53+
data = pickle.load(f)
54+
orig, count = data["original"], data["counter"]
55+
#counter = [d["counter"] for d in counter]
56+
#orig = [o.split(" ") for o in original]
57+
#count = [c.split(" ") for c in counter]
58+
59+
orig = [d["text"] for d in orig]
60+
print(count[0])
61+
count = [d["text"] for d in count]
62+
name2data[name] = (orig, count)
63+
64+
65+
diffs=[]
66+
#print(len(orig), len(count))
67+
68+
for o,c in zip(orig, count):
69+
70+
#print(o,c)
71+
if EDIT_DISTANCE:
72+
diffs.append(levenshteinDistance(o,c)/len(c))
73+
else:
74+
i=0
75+
for oo,cc in zip(o,c):
76+
#print("try", cc,oo)
77+
if cc != oo:
78+
#print(i, len(oo))
79+
diffs.append(i/len(o))
80+
break
81+
i+=1
82+
#print(diffs)
83+
84+
85+
plt.hist(
86+
diffs,
87+
density=False,
88+
bins=15,
89+
alpha=0.5,
90+
label=name,
91+
color=colors[idx]
92+
)
93+
94+
# Calculate and plot median
95+
median_diff = np.median(diffs)
96+
mean_diff = np.mean(diffs)
97+
print(np.mean(diffs))
98+
plt.axvline(
99+
mean_diff,
100+
color=colors[idx],
101+
linestyle='dashed',
102+
linewidth=2
103+
)
104+
"""
105+
plt.text(
106+
median_diff,
107+
plt.ylim()[1]*0.9 - idx*plt.ylim()[1]*0.08, # Adjust y-position for each label
108+
f'Median {name}: {median_diff:.2f}',
109+
rotation=0,
110+
color=colors[idx],
111+
verticalalignment='top',
112+
horizontalalignment='center',
113+
fontsize=20, # Increase font size of median labels
114+
bbox=dict(facecolor='white', alpha=0.5, edgecolor='none')
115+
)
116+
"""
117+
plt.xlabel("Edit Distance (characters)", fontsize=14)
118+
plt.ylabel("Counts", fontsize=14)
119+
plt.grid()
120+
plt.legend(fontsize=13)
121+
plt.savefig("edit_distance_new.pdf", dpi=800)

0 commit comments

Comments
 (0)