Skip to content

Commit 955cb88

Browse files
author
Yang Zonglin Phd CIL
committed
add claude3 experiments
1 parent 4b32cee commit 955cb88

File tree

34 files changed

+146
-53
lines changed

34 files changed

+146
-53
lines changed
-40 Bytes
Binary file not shown.

__pycache__/evaluator.cpython-38.pyc

157 Bytes
Binary file not shown.
-40 Bytes
Binary file not shown.

__pycache__/tomato.cpython-38.pyc

840 Bytes
Binary file not shown.

__pycache__/utils.cpython-38.pyc

50 Bytes
Binary file not shown.

compare_score.py

+40-4
Original file line numberDiff line numberDiff line change
@@ -103,8 +103,9 @@ def read_file_find_score_concat_score(model_name, start_end_id_1, num_CoLM_feedb
103103

104104

105105
def find_hyperparameter_for_display_results(model_name, method_name):
106-
assert method_name == "MOOSE_base" or method_name == "MOOSE" or method_name == "rand_background_baseline" or method_name == "rand_background_rand_inspiration_baseline" or method_name == "rand_background_BM25_inspiration_baseline" or method_name == "gpt35_background_gpt35_inspiration" or method_name == "groundtruth_background_groundtruth_inspiration" or method_name == "MOOSE_wo_ff1" or method_name == "MOOSE_wo_ff2" or method_name == "MOOSE_wo_survey" or method_name == "MOOSE_w_random_corpus"
106+
assert method_name == "MOOSE_base" or method_name == "MOOSE" or method_name == "rand_background_baseline" or method_name == "rand_background_rand_inspiration_baseline" or method_name == "rand_background_BM25_inspiration_baseline" or method_name == "gpt35_background_gpt35_inspiration" or method_name == "groundtruth_background_groundtruth_inspiration" or method_name == "MOOSE_wo_ff1" or method_name == "MOOSE_wo_ff2" or method_name == "MOOSE_wo_survey" or method_name == "MOOSE_w_random_corpus" or method_name == "MOOSE_base_claude" or method_name == "MOOSE_based_with_ff1_and_ff2_claude" or method_name == "MOOSE_claude_onlyindirect2" or method_name == "MOOSE_claude_onlyindirect0" or method_name == "MOOSE_baseline2_claude"
107107

108+
### chatgpt ckpts
108109
## baseline ckpts
109110
ckpt_baseline1_0_50 = "chatgpt_50bkg_0itr_bkgnoter0_indirect0_onlyindirect0_close0_ban1_baseline1_hypEqlInsp_manualTitleSuggester_clearSplit_pastfdbkmodified_hypSuggestor"
110111
ckpt_baseline2_0_50 = "chatgpt_50bkg_0itr_bkgnoter0_indirect0_onlyindirect0_close0_ban1_baseline2_hypEqlInsp_manualTitleSuggester_clearSplit_pastfdbkmodified_hypSuggestor"
@@ -123,6 +124,20 @@ def find_hyperparameter_for_display_results(model_name, method_name):
123124
ckpt_tomato_pf_0_50_without_selfeval_with_hypSuggestor = "chatgpt_50bkg_4itr_bkgnoter0_indirect1_onlyindirect2_close0_ban1_baseline0_survey1_bkgInspPasgSwap0_hypSuggestor1_hypEqlInsp_manualTitleSuggester_clearSplit_pastfdbkmodified_hypSuggestor"
124125
ckpt_tomato_pf_0_50_noSurvey = "chatgpt_50bkg_4itr_bkgnoter0_indirect1_onlyindirect2_close0_ban0_baseline0_survey0_hypEqlInsp_manualTitleSuggester_clearSplit_pastfdbkmodified_hypSuggestor"
125126
ckpt_tomato_pf_0_50_bkg_insp_pasg_swap = "chatgpt_50bkg_4itr_bkgnoter0_indirect1_onlyindirect2_close0_ban0_baseline0_survey1_bkgInspPasgSwap1_hypEqlInsp_manualTitleSuggester_clearSplit_pastfdbkmodified_hypSuggestor"
127+
### claude ckpts
128+
## MOOSE-based ckpts
129+
ckpt_tomato_base_0_5_claude = "claude_5bkg_4itr_bkgnoter0_indirect0_onlyindirect0_close0_ban1_baseline0_survey1_bkgInspPasgSwap0_hypSuggestor0"
130+
ckpt_tomato_base_5_50_claude = "claude_45bkg_4itr_bkgnoter5_indirect0_onlyindirect0_close0_ban1_baseline0_survey1_bkgInspPasgSwap0_hypSuggestor0"
131+
## MOOSE-future ckpts
132+
ckpt_tomato_base_ff1_ff2_0_5_claude = "claude_5bkg_4itr_bkgnoter0_indirect0_onlyindirect0_close0_ban0_baseline0_survey1_bkgInspPasgSwap0_hypSuggestor1"
133+
## MOOSE-future-past-indirect2 ckpts
134+
ckpt_tomato_base_ff1_ff2_past_onlyindirect2_0_5_claude = "claude_5bkg_4itr_bkgnoter0_indirect1_onlyindirect2_close0_ban0_baseline0_survey1_bkgInspPasgSwap0_hypSuggestor1"
135+
## MOOSE-future-past-indirect0 ckpts
136+
ckpt_tomato_base_ff1_ff2_past_onlyindirect0_0_5_claude = "claude_5bkg_4itr_bkgnoter0_indirect1_onlyindirect0_close0_ban0_baseline0_survey1_bkgInspPasgSwap0_hypSuggestor1"
137+
ckpt_tomato_base_ff1_ff2_past_onlyindirect0_5_50_claude = "claude_45bkg_4itr_bkgnoter5_indirect1_onlyindirect0_close0_ban0_baseline0_survey1_bkgInspPasgSwap0_hypSuggestor1"
138+
## MOOSE_baseline2_claude
139+
ckpt_baseline2_0_50_claude = "claude_50bkg_0itr_bkgnoter0_indirect0_onlyindirect0_close0_ban1_baseline2_survey1_bkgInspPasgSwap0_hypSuggestor0"
140+
126141

127142
if method_name == "MOOSE_base":
128143
start_end_id = [[0,5], [5,25], [25,50]]
@@ -168,6 +183,26 @@ def find_hyperparameter_for_display_results(model_name, method_name):
168183
start_end_id = [[0, 50]]
169184
num_CoLM_feedback_times = 4
170185
ckpt_addr_full = [ckpt_tomato_pf_0_50_bkg_insp_pasg_swap]
186+
elif method_name == "MOOSE_base_claude":
187+
start_end_id = [[0, 5], [5, 50]]
188+
num_CoLM_feedback_times = 4
189+
ckpt_addr_full = [ckpt_tomato_base_0_5_claude, ckpt_tomato_base_5_50_claude]
190+
elif method_name == "MOOSE_based_with_ff1_and_ff2_claude":
191+
start_end_id = [[0, 5]]
192+
num_CoLM_feedback_times = 4
193+
ckpt_addr_full = [ckpt_tomato_base_ff1_ff2_0_5_claude]
194+
elif method_name == "MOOSE_claude_onlyindirect2":
195+
start_end_id = [[0, 5]]
196+
num_CoLM_feedback_times = 4
197+
ckpt_addr_full = [ckpt_tomato_base_ff1_ff2_past_onlyindirect2_0_5_claude]
198+
elif method_name == "MOOSE_claude_onlyindirect0":
199+
start_end_id = [[0, 5], [5, 50]]
200+
num_CoLM_feedback_times = 4
201+
ckpt_addr_full = [ckpt_tomato_base_ff1_ff2_past_onlyindirect0_0_5_claude, ckpt_tomato_base_ff1_ff2_past_onlyindirect0_5_50_claude]
202+
elif method_name == "MOOSE_baseline2_claude":
203+
start_end_id = [[0, 50]]
204+
num_CoLM_feedback_times = 0
205+
ckpt_addr_full = [ckpt_baseline2_0_50_claude]
171206
else:
172207
raise NotImplementedError
173208

@@ -179,9 +214,10 @@ def main():
179214
# 'chatgpt' or 'gpt4'
180215
model_name = 'gpt4'
181216
# "MOOSE_base", "rand_background_baseline", "rand_background_rand_inspiration_baseline", "rand_background_BM25_inspiration_baseline", "gpt35_background_gpt35_inspiration", "MOOSE_wo_ff1", "MOOSE_wo_ff2", "MOOSE_wo_survey", "MOOSE_w_random_corpus"
182-
method_name1 = "MOOSE_base"
217+
# "MOOSE_baseline2_claude", "MOOSE_base_claude", "MOOSE_claude_onlyindirect0"
218+
method_name1 = "MOOSE_base_claude"
183219
# "MOOSE", "MOOSE_claude_onlyindirect0"
184-
method_name2 = "MOOSE"
220+
method_name2 = "MOOSE_claude_onlyindirect0"
185221
## load data and find score
186222
start_end_id_1, num_CoLM_feedback_times_1, ckpt_addr1_full = find_hyperparameter_for_display_results(model_name, method_name1)
187223
start_end_id_2, num_CoLM_feedback_times_2, ckpt_addr2_full = find_hyperparameter_for_display_results(model_name, method_name2)
@@ -204,7 +240,7 @@ def main():
204240
print("ave_score2_w_ind: ", ave_score2_w_ind)
205241

206242
# score_all_itrs
207-
if method_name1 == "MOOSE_base" and method_name2 == "MOOSE":
243+
if (method_name1 == "MOOSE_base" and method_name2 == "MOOSE") or (method_name1 == "MOOSE_base_claude" and method_name2 == "MOOSE_claude_onlyindirect0"):
208244
score_all_itrs = np.concatenate((score1_wo_ind_itrs, score2_wo_ind_itrs, score2_w_ind_itrs), axis=1)
209245
print("\nscore_all_itrs: ", score_all_itrs.shape)
210246
ave_score_all_itrs = np.nanmean(score_all_itrs, axis=1)

evaluate_main.py

+4
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ def main():
1212
parser.add_argument("--num_CoLM_feedback_times", type=int, default=1, help="number of re-generation times given new feedbacks for CoLM")
1313
parser.add_argument("--start_id", type=int, default=0, help="To evaluate [start_id : end_id] of the Checkpoint file; -1 when not using it")
1414
parser.add_argument("--end_id", type=int, default=10, help="To evaluate [start_id : end_id] of the Checkpoint file; -1 when not using it")
15+
parser.add_argument("--if_indirect_feedback", type=int, default=1, help="whether conduct indirect feedback modules such as inspiration_changer and background_changer; also can be called --if_past_feedback")
16+
parser.add_argument("--if_only_indirect_feedback", type=int, default=0, help="0: tomato-base will perform; 1: Do NOT perform tomato-base because tomato-base has been performed in this checkpoint (prev data will be load up); 2: Do NOT perform tomato-base, but at least tomato-base + past feedback")
1517
# used for prev_eval_output_dir: ~/Outs/Tomato/gpt4_eval_chatgpt_25bkg_4itr_bkgnoter0_indirect0_onlyindirect0_close0_ban1_hypEqlInsp_manualTitleSuggester_clearSplit_pastfdbkmodified_hypSuggestor_5_25.out
1618
parser.add_argument("--prev_eval_output_dir", type=str, default="", help="In case previous evaluation code has exception, but we don't want to waste money on openai API to re-evaluate the already evaluated hypotheses -- we pick up the previous score from the 'x.out' file")
1719
parser.add_argument("--if_azure_api", type=int, default=0, help="0: Use openai api from openai website; 1: use openai api from azure")
@@ -21,6 +23,8 @@ def main():
2123

2224
assert args.model_name == 'gpt4' or args.model_name == 'chatgpt'
2325
assert args.start_id >= -1 and args.end_id >= -1
26+
assert args.if_indirect_feedback == 1 or args.if_indirect_feedback == 0
27+
assert args.if_only_indirect_feedback == 0 or args.if_only_indirect_feedback == 1 or args.if_only_indirect_feedback == 2
2428
assert args.if_azure_api == 0 or args.if_azure_api == 1
2529
assert args.if_groundtruth_hypotheses == 0 or args.if_groundtruth_hypotheses == 1
2630
if args.start_id == -1 or args.end_id == -1:

evaluate_main.sh

+5-4
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,13 @@
33
#SBATCH --partition=DGXq
44
#SBATCH -w node19
55
#SBATCH --gres=gpu:1
6-
#SBATCH --output /export/home/zonglin001/Outs/Tomato/gpt4_eval_chatgpt_50bkg_4itr_bkgnoter0_indirect1_onlyindirect2_close0_ban0_baseline0_survey1_bkgInspPasgSwap0_hypSuggestor0_hypEqlInsp_manualTitleSuggester_clearSplit_pastfdbkmodified_hypSuggestor.out
6+
#SBATCH --output /export/home/zonglin001/Outs/Tomato/gpt4_eval_claude_45bkg_4itr_bkgnoter5_indirect1_onlyindirect0_close0_ban0_baseline0_survey1_bkgInspPasgSwap0_hypSuggestor1.out
77

88
# chatgpt / gpt4
99
python -u evaluate_main.py --if_groundtruth_hypotheses 0 \
1010
--model_name gpt4 --num_CoLM_feedback_times 4 \
11-
--start_id 0 --end_id 50 \
12-
--if_azure_api 0 \
13-
--output_dir ~/Checkpoints/Tomato/chatgpt_50bkg_4itr_bkgnoter0_indirect1_onlyindirect2_close0_ban0_baseline0_survey1_bkgInspPasgSwap0_hypSuggestor0_hypEqlInsp_manualTitleSuggester_clearSplit_pastfdbkmodified_hypSuggestor \
11+
--if_indirect_feedback 1 --if_only_indirect_feedback 0 \
12+
--start_id 5 --end_id 50 \
13+
--if_azure_api 1 \
14+
--output_dir ~/Checkpoints/Tomato/claude_45bkg_4itr_bkgnoter5_indirect1_onlyindirect0_close0_ban0_baseline0_survey1_bkgInspPasgSwap0_hypSuggestor1 \
1415
--api_key sk-

evaluator.py

+39-11
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import os, time, re
22
import torch
33
import openai
4+
from openai import AzureOpenAI
45
import transformers
56
from transformers import GPT2LMHeadModel, GPT2Tokenizer
67
import numpy as np
@@ -17,10 +18,15 @@ def __init__(self, args):
1718
if args.if_azure_api == 0:
1819
openai.api_key = self.args.api_key
1920
else:
20-
openai.api_type = ""
21-
openai.api_base = ""
22-
openai.api_version = ""
23-
openai.api_key = self.args.api_key
21+
# openai.api_type = ""
22+
# openai.api_base = ""
23+
# openai.api_version = "2024-02-15-preview"
24+
# openai.api_key = self.args.api_key
25+
self.client = AzureOpenAI(
26+
azure_endpoint = "https://declaregpt4.openai.azure.com/",
27+
api_key=self.args.api_key,
28+
api_version="2024-02-15-preview"
29+
)
2430
assert openai.api_key != ""
2531
# self.hypotheses is a sub-element of self.result
2632
self.result = None
@@ -95,6 +101,19 @@ def evaluate(self):
95101
score_reasons = {}
96102
cnt_finished = 0
97103
if self.args.if_groundtruth_hypotheses == 0:
104+
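# num_chunks_with_and_without_past_feedback_per_bkg: number of hypothesis chunks stored per data item of a background -- 2 when if_indirect_feedback == 1 and if_only_indirect_feedback is 0 or 1; 1 when if_indirect_feedback == 0 or if_only_indirect_feedback == 2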
105+
if self.args.if_indirect_feedback == 0:
106+
num_chunks_with_and_without_past_feedback_per_bkg = 1
107+
elif self.args.if_indirect_feedback == 1:
108+
if self.args.if_only_indirect_feedback == 0 or self.args.if_only_indirect_feedback == 1:
109+
num_chunks_with_and_without_past_feedback_per_bkg = 2
110+
elif self.args.if_only_indirect_feedback == 2:
111+
num_chunks_with_and_without_past_feedback_per_bkg = 1
112+
else:
113+
raise NotImplementedError
114+
else:
115+
raise NotImplementedError
116+
# start looping
98117
for cur_id_bkg, cur_bkg_ori in enumerate(self.background):
99118
if cur_bkg_ori not in scores:
100119
cur_bkg = cur_bkg_ori
@@ -103,16 +122,17 @@ def evaluate(self):
103122
score_reasons[cur_bkg] = []
104123
cur_bkg = cur_bkg_ori
105124
cur_hyp_for_cur_bkg = self.hypotheses[cur_bkg_ori]
106-
# in case a bkg has more than one data item in our dataset
107-
if len(cur_hyp_for_cur_bkg) > 1:
108-
cur_hyp_for_cur_bkg = cur_hyp_for_cur_bkg[:1]
125+
# in case a bkg has more than one data item (annotated publication) in our dataset
126+
if len(cur_hyp_for_cur_bkg) > 1*num_chunks_with_and_without_past_feedback_per_bkg:
127+
cur_hyp_for_cur_bkg = cur_hyp_for_cur_bkg[:1*num_chunks_with_and_without_past_feedback_per_bkg]
109128
else:
110129
# raise Exception("repeated key in scores: {}; cur_bkg: {}".format(scores, cur_bkg))
130+
assert len(self.hypotheses[cur_bkg_ori]) == 2*num_chunks_with_and_without_past_feedback_per_bkg
111131
cur_bkg = cur_bkg_ori + " "
112132
assert cur_bkg not in score_reasons
113133
scores[cur_bkg] = []
114134
score_reasons[cur_bkg] = []
115-
cur_hyp_for_cur_bkg = self.hypotheses[cur_bkg_ori][1:2]
135+
cur_hyp_for_cur_bkg = self.hypotheses[cur_bkg_ori][1*num_chunks_with_and_without_past_feedback_per_bkg:2*num_chunks_with_and_without_past_feedback_per_bkg]
116136
if cur_id_bkg == 0:
117137
print("len(cur_hyp_for_cur_bkg): ", len(cur_hyp_for_cur_bkg))
118138
for cur_id_hyp_direct_or_indirect , cur_hyp_direct_or_indirect in enumerate(cur_hyp_for_cur_bkg):
@@ -225,13 +245,21 @@ def llm_generation(self, input_txt):
225245
reply = response["choices"][0]['message']['content']
226246
if_api_completed = True
227247
else:
228-
response = openai.ChatCompletion.create(
229-
engine=api_model_name,
248+
# response = openai.ChatCompletion.create(
249+
# engine=api_model_name,
250+
# messages=[{"role": "user", "content": input_txt}],
251+
# top_p=0.90,
252+
# temperature=temperature,
253+
# max_tokens=max_tokens)
254+
# reply = response["choices"][0]['message']['content']
255+
# if_api_completed = True
256+
response = self.client.chat.completions.create(
257+
model=api_model_name,
230258
messages=[{"role": "user", "content": input_txt}],
231259
top_p=0.90,
232260
temperature=temperature,
233261
max_tokens=max_tokens)
234-
reply = response["choices"][0]['message']['content']
262+
reply = response.choices[0].message.content
235263
if_api_completed = True
236264
except:
237265
print("OpenAI reach its rate limit")

main.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
def main():
99
parser = argparse.ArgumentParser()
1010
parser.add_argument("--model_name", type=str, default="vicuna",
11-
help="model name: gpt2/llama/vicuna/vicuna13/chatgpt/falcon")
11+
help="model name: gpt2/llama/vicuna/vicuna13/chatgpt/falcon/claude")
1212
parser.add_argument("--root_data_dir", type=str, default="./Data/")
1313
parser.add_argument("--survey_data_dir", type=str, default="./Data/Surveys/")
1414
parser.add_argument("--output_dir", type=str, default="~/Checkpoints/Tomato/try")
@@ -28,7 +28,7 @@ def main():
2828
args = parser.parse_args()
2929

3030
# check hyper-parameters
31-
assert args.model_name == 'llama' or args.model_name == 'vicuna' or args.model_name == 'vicuna13' or args.model_name == 'gpt2' or args.model_name == 'chatgpt' or args.model_name == 'falcon'
31+
assert args.model_name == 'llama' or args.model_name == 'vicuna' or args.model_name == 'vicuna13' or args.model_name == 'gpt2' or args.model_name == 'chatgpt' or args.model_name == 'falcon' or args.model_name == "claude"
3232
assert args.if_indirect_feedback == 1 or args.if_indirect_feedback == 0
3333
assert args.if_only_indirect_feedback == 0 or args.if_only_indirect_feedback == 1 or args.if_only_indirect_feedback == 2
3434
assert args.if_close_domain == 1 or args.if_close_domain == 0
@@ -59,7 +59,7 @@ def main():
5959
# check gpu
6060
n_gpu = torch.cuda.device_count()
6161
print("n_gpu: ", n_gpu)
62-
if not args.model_name == 'chatgpt':
62+
if not (args.model_name == 'chatgpt' or args.model_name == "claude"):
6363
print_nvidia_smi()
6464
assert n_gpu >= 1
6565

main.sh

+7-7
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,18 @@
11
#!/bin/bash
2-
#SBATCH -J Tomato
2+
#SBATCH -J baseline
33
#SBATCH --partition=DGXq
44
#SBATCH -w node18
55
#SBATCH --gres=gpu:1
6-
#SBATCH --output /export/home/zonglin001/Outs/Tomato/chatgpt_50bkg_4itr_bkgnoter0_indirect0_onlyindirect0_close0_ban1_baseline0_survey1_bkgInspPasgSwap0_hypSuggestor0_hypEqlInsp_manualTitleSuggester_clearSplit_pastfdbkmodified_hypSuggestor.out
6+
#SBATCH --output /export/home/zonglin001/Outs/Tomato/claude_50bkg_0itr_bkgnoter0_indirect0_onlyindirect0_close0_ban1_baseline2_survey1_bkgInspPasgSwap0_hypSuggestor0.out
77

88

9-
# vicuna / gpt2 / chatgpt / vicuna13 / falcon
10-
python -u main.py --model_name chatgpt \
11-
--output_dir ~/Checkpoints/Tomato/chatgpt_50bkg_4itr_bkgnoter0_indirect0_onlyindirect0_close0_ban1_baseline0_survey1_bkgInspPasgSwap0_hypSuggestor0_hypEqlInsp_manualTitleSuggester_clearSplit_pastfdbkmodified_hypSuggestor \
12-
--num_background_for_hypotheses 50 --num_CoLM_feedback_times 4 --bkg_corpus_chunk_noter 0 \
9+
# vicuna / gpt2 / chatgpt / vicuna13 / falcon / claude
10+
python -u main.py --model_name claude \
11+
--output_dir ~/Checkpoints/Tomato/claude_50bkg_0itr_bkgnoter0_indirect0_onlyindirect0_close0_ban1_baseline2_survey1_bkgInspPasgSwap0_hypSuggestor0 \
12+
--num_background_for_hypotheses 50 --num_CoLM_feedback_times 0 --bkg_corpus_chunk_noter 0 \
1313
--if_indirect_feedback 0 --if_only_indirect_feedback 0 \
1414
--if_close_domain 0 --if_ban_selfeval 1 \
15-
--if_baseline 0 \
15+
--if_baseline 2 \
1616
--if_novelty_module_have_access_to_surveys 1 --if_insp_pasg_for_bkg_and_bkg_pasg_included_in_insp 0 \
1717
--if_hypothesis_suggstor 0 \
1818
--api_key sk-

0 commit comments

Comments
 (0)