
[Hackathon 5th No.73] ToT #7660

Merged: 54 commits, Jan 26, 2024

Commits
522a7af
Hackathon TASK73 ToT
ErnestinaQiu Dec 14, 2023
a68dacb
update readme tutorial
ErnestinaQiu Dec 14, 2023
c65870b
modify according to Lint
ErnestinaQiu Dec 14, 2023
315cc3e
modify according to Lint
ErnestinaQiu Dec 14, 2023
b001719
Delete LICENSE
ErnestinaQiu Dec 14, 2023
8932aca
Update LICENSE
ErnestinaQiu Dec 14, 2023
da288af
black format
ErnestinaQiu Dec 14, 2023
916d70c
isort format
ErnestinaQiu Dec 14, 2023
e64c51e
Update search_crosswords-dfs.ipynb
ErnestinaQiu Dec 14, 2023
ef5cfa6
update files formats
ErnestinaQiu Dec 14, 2023
96a6d35
Update LICENSE
ErnestinaQiu Dec 14, 2023
6c95517
Update LICENSE
ErnestinaQiu Dec 14, 2023
1728255
Update LICENSE
ErnestinaQiu Dec 14, 2023
bd35dde
Update LICENSE
ErnestinaQiu Dec 14, 2023
f5ff4df
delete test data
ErnestinaQiu Dec 15, 2023
5621975
delete some unnecessary files
ErnestinaQiu Dec 21, 2023
e7d3ba6
add paddlenlp-llama2
ErnestinaQiu Dec 21, 2023
84ee4d0
fix one bug
ErnestinaQiu Dec 22, 2023
effc87b
fix outputs bug
ErnestinaQiu Dec 22, 2023
c514ca9
delete meta/llama2
ErnestinaQiu Dec 22, 2023
402fa97
modify according to comments
ErnestinaQiu Dec 22, 2023
ae8c242
change according to comments
ErnestinaQiu Dec 22, 2023
c7979ed
Delete .gitignore
ErnestinaQiu Dec 22, 2023
c8f79e2
Create .gitignore
ErnestinaQiu Dec 22, 2023
994286c
Move directory
Dec 22, 2023
1f9499a
Add tree of thoughts scripts
Dec 22, 2023
065a1e9
add first dir
ErnestinaQiu Dec 22, 2023
3fd243d
Merge branch 'develop' of https://github.com/ErnestinaQiu/tot into de…
ErnestinaQiu Dec 22, 2023
24982e4
add note
ErnestinaQiu Dec 22, 2023
cebe49e
Update README.md
ErnestinaQiu Jan 25, 2024
987117e
Update requirements.txt
ErnestinaQiu Jan 25, 2024
c74e478
Update demo.py
ErnestinaQiu Jan 25, 2024
26179dc
Update .gitignore
ErnestinaQiu Jan 25, 2024
e1fdd67
Update run.py
ErnestinaQiu Jan 25, 2024
e793463
Update __init__.py
ErnestinaQiu Jan 25, 2024
5e4dcf1
chat templates
ErnestinaQiu Jan 25, 2024
671c6b8
add Ernie
ErnestinaQiu Jan 25, 2024
1face8b
Update llama.py
ErnestinaQiu Jan 25, 2024
e9b2100
Update bfs.py
ErnestinaQiu Jan 25, 2024
94d7a82
Update models.py
ErnestinaQiu Jan 25, 2024
4472dd6
Update run.py
ErnestinaQiu Jan 25, 2024
780ed41
format style
ErnestinaQiu Jan 25, 2024
8e90744
format style
ErnestinaQiu Jan 25, 2024
8fc80a3
format style
ErnestinaQiu Jan 25, 2024
e53ad12
format style
ErnestinaQiu Jan 25, 2024
c29e61c
format style
ErnestinaQiu Jan 25, 2024
1e9c384
format style
ErnestinaQiu Jan 25, 2024
94153a9
format style
ErnestinaQiu Jan 25, 2024
0592931
format style
ErnestinaQiu Jan 25, 2024
b1a65a9
Remove the duplicated "Test results" section
ErnestinaQiu Jan 26, 2024
fac1c02
Remove the Ernie token; read it from an environment variable instead
ErnestinaQiu Jan 26, 2024
fc122c3
format style
ErnestinaQiu Jan 26, 2024
df25595
format style
ErnestinaQiu Jan 26, 2024
34c0953
Remove commented-out code
ErnestinaQiu Jan 26, 2024
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -123,4 +123,4 @@ FETCH_HEAD

# vscode
.vscode
./ppdiffusers/ppdiffusers/version.py
./ppdiffusers/ppdiffusers/version.py
195 changes: 195 additions & 0 deletions pipelines/examples/tree-of-thought/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,195 @@
# Tree of Thoughts (ToT)

![teaser](https://github.com/PaddlePaddle/PaddleNLP/assets/48557439/30f9e365-398a-4822-b3c2-a0768f70e310)

An implementation of the prompts and model outputs from the paper [Tree of Thoughts: Deliberate Problem Solving with Large Language Models](https://arxiv.org/abs/2305.10601).


## Setup
1. Install:
```bash
git clone git@github.com:PaddlePaddle/PaddleNLP.git
cd pipelines/examples/tree-of-thought/
pip install -r requirements.txt
```

2. Download the test data from https://github.com/ErnestinaQiu/tree-of-thought-llm/tree/master/src/tot/data and place it under pipelines/examples/tree-of-thought/tree/master/src/tot/data

## Quick Start
The following script attempts to solve the Game of 24 with the numbers 4 5 6 10 (it may be a bit slow since it uses llama-2-7b-chat).

Run it from the directory pipelines/examples/agents/tree-of-thought-llm:

```
python demo.py
```

It runs the equivalent of the following Python snippet:

```python
import argparse

from tot.methods.bfs import solve
from tot.tasks.game24 import Game24Task

args = argparse.Namespace(
    backend="llama-2-7b-chat",
    temperature=0.6,
    task="game24",
    naive_run=False,
    prompt_sample=None,
    method_generate="propose",
    method_evaluate="value",
    method_select="greedy",
    n_generate_sample=1,
    n_evaluate_sample=3,
    n_select_sample=5,
)

task = Game24Task()
ys, infos = solve(args, task, 900)
print(ys[0])
```

The output may look like the following (note that it is not deterministic, and it can sometimes be wrong):
```
10 - 4 = 6 (left: 5 6 6)
5 * 6 = 30 (left: 6 30)
30 - 6 = 24 (left: 24)
Answer: (5 * (10 - 4)) - 6 = 24
```
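The arithmetic in a trace like this can be checked mechanically. A quick sanity check of the steps above in plain Python (not part of the repo):

```python
# Verify each intermediate step of the sample trace.
assert 10 - 4 == 6   # 10 - 4 = 6 (left: 5 6 6)
assert 5 * 6 == 30   # 5 * 6 = 30 (left: 6 30)
assert 30 - 6 == 24  # 30 - 6 = 24 (left: 24)

# The combined answer expression also evaluates to 24.
assert (5 * (10 - 4)) - 6 == 24
print("trace checks out")
```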

## Paper Experiments

Run experiments via ``sh scripts/{game24, text, crosswords}/{standard_sampling, cot_sampling, bfs}.sh``.

The very simple ``run.py`` implements the ToT + BFS algorithm, as well as naive IO/CoT sampling. Some key arguments:

- ``--naive_run``: if True, run naive IO/CoT sampling instead of ToT + BFS.
- ``--prompt_sample`` (choices=[``standard``, ``cot``]): sampling prompt
- ``--method_generate`` (choices=[``sample``, ``propose``]): thought generator, whether to sample independent thoughts (used in Creative Writing) or propose sequential thoughts (used in Game of 24)
- ``--method_evaluate`` (choices=[``value``, ``vote``]): state evaluator, whether to use the value states independently (used in Game of 24) or vote on states together (used in Creative Writing)
- ``--n_generate_sample``: number of times to prompt for thought generation
- ``--n_evaluate_sample``: number of times to prompt for state evaluation
- ``--n_select_sample``: number of states to keep from each step (i.e. ``b`` in the paper's ToT + BFS algorithm)
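Putting these flags together, ToT + BFS can be sketched roughly as follows. Here `propose_thoughts` and `value_state` are hypothetical stand-ins for the model-backed thought generator and state evaluator, not functions from this repo:

```python
# Rough sketch of ToT + BFS: expand, score, and keep the best b states per step.
def tot_bfs(x, steps, b, propose_thoughts, value_state):
    frontier = [""]  # each state is a partial solution string
    for _ in range(steps):
        # thought generation (--method_generate): expand every frontier state
        candidates = [y + t for y in frontier for t in propose_thoughts(x, y)]
        # state evaluation (--method_evaluate): score each candidate state
        scored = sorted(candidates, key=lambda y: value_state(x, y), reverse=True)
        # selection (--method_select greedy): keep the best b (--n_select_sample)
        frontier = scored[:b]
    return frontier


# Toy example: "thoughts" are single characters, and states with more "a"s
# score higher, so greedy selection steers the search toward "aa".
best = tot_bfs(
    x="demo",
    steps=2,
    b=2,
    propose_thoughts=lambda x, y: ["a", "b"],
    value_state=lambda x, y: y.count("a"),
)
print(best)  # ['aa', 'ab']
```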

## Paper Trajectories

``logs/`` contains all the trajectories from the paper's experiments, except ``logs/game24/gpt-4_0.7_propose1_value3_greedy5_start900_end1000.json``, which was regenerated after the paper (as the original experiment was run in a notebook) and scored 69% instead of the original 74% due to randomness in GPT decoding. We hope to aggregate multiple runs in the future to account for sampling randomness and update the paper, but this should not affect the paper's main conclusions.

## Task Scripts for the Paper Experiments
### crosswords
```
# task: crosswords, running items 0-20 of the crosswords dataset
# naive CoT sampling, prompting 10 times for thought generation
python run.py \
--task crosswords \
--task_start_index 0 \
--task_end_index 20 \
--naive_run \
--prompt_sample cot \
--n_generate_sample 10
```

```
# naive IO sampling with the standard prompt, prompting 10 times
python run.py \
--task crosswords \
--task_start_index 0 \
--task_end_index 20 \
--naive_run \
--prompt_sample standard \
--n_generate_sample 10
```

### game24 (Game of 24)
```
# task: game24, running items 900-1000 of the Game of 24 dataset
# ToT: propose sequential thoughts, evaluate states by value, select greedily
# prompt 3 times per state evaluation; keep b=5 states per step
python run.py \
--task game24 \
--task_start_index 900 \
--task_end_index 1000 \
--method_generate propose \
--method_evaluate value \
--method_select greedy \
--n_evaluate_sample 3 \
--n_select_sample 5
```

```
# naive CoT sampling, prompting 100 times
python run.py \
--task game24 \
--task_start_index 900 \
--task_end_index 1000 \
--naive_run \
--prompt_sample cot \
--n_generate_sample 100
```

```
python run.py \
--task game24 \
--task_start_index 900 \
--task_end_index 1000 \
--naive_run \
--prompt_sample standard \
--n_generate_sample 100
```

### text (Creative Writing)
```
# task: text, running items 0-100 of the creative writing dataset
# ToT: sample independent thoughts, evaluate states by voting, select greedily
# prompt 5 times for generation and 5 for evaluation; keep b=1 state per step
python run.py \
--task text \
--task_start_index 0 \
--task_end_index 100 \
--method_generate sample \
--method_evaluate vote \
--method_select greedy \
--n_generate_sample 5 \
--n_evaluate_sample 5 \
--n_select_sample 1 \
--prompt_sample cot \
--temperature 1.0
```

```
# naive CoT sampling
python run.py \
--task text \
--task_start_index 0 \
--task_end_index 100 \
--naive_run \
--prompt_sample cot \
--n_generate_sample 10 \
--temperature 1.0
```

```
# naive IO sampling with the standard prompt
python run.py \
--task text \
--task_start_index 0 \
--task_end_index 100 \
--naive_run \
--prompt_sample standard \
--n_generate_sample 10 \
--temperature 1.0
```

## Test Results
The tests use facebook/llama-2-7b-chat and facebook/llama-2-13b-chat from paddlenlp, with temperature=0.6, decode_strategy="greedy_search", and max_new_tokens=512. Results:
|model|method|acc|
|----|----|----|
|llama-2-7b-chat|cot|0|
|llama-2-7b-chat|standard sampling| 0|
|llama-2-7b-chat|ToT| 3%|
|llama-2-13b-chat|cot|0|
|llama-2-13b-chat|standard sampling|0|
|llama-2-13b-chat|ToT|2%|


## How to Add a New Task

Setting up a new task is easy and mainly involves two steps:
* Set up a new task class in ``tot/tasks/`` and task files in ``tot/data/``. See ``tot/tasks/game24.py`` for an example. Add the task to ``tot/tasks/__init__.py``.
* Set up task-specific prompts in ``tot/prompts/``. See ``tot/prompts/game24.py`` for an example. Depending on the nature of the task, choose ``--method_generate`` (choices=[``sample``, ``propose``]) and ``--method_evaluate`` (choices=[``value``, ``vote``]) and their corresponding prompts.
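As an illustration of the shape such a task class might take: the toy task below is an assumption, not repo code, but ``test_output`` returning a dict with an ``"r"`` score mirrors how ``run.py`` appears to consume task results.

```python
# Hypothetical toy task: sort space-separated digits. Illustrative only.
class SortDigitsTask:
    def __init__(self):
        # a real task would load its data from a file under tot/data/
        self.data = ["3 1 2", "9 7 8"]
        self.steps = 1  # number of ToT steps for this task

    def __len__(self):
        return len(self.data)

    def get_input(self, idx):
        return self.data[idx]

    def test_output(self, idx, output):
        # score 1 if the model output is the sorted sequence, else 0
        target = " ".join(sorted(self.data[idx].split()))
        return {"r": int(output.strip() == target)}


task = SortDigitsTask()
print(task.test_output(0, "1 2 3"))  # {'r': 1}
```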


## Acknowledgements

We learned from the excellent framework design of Shunyu Yao et al., and we would like to thank the authors of Tree of Thoughts and their open-source community.

```bibtex
@misc{yao2023tree,
title={{Tree of Thoughts}: Deliberate Problem Solving with Large Language Models},
author={Shunyu Yao and Dian Yu and Jeffrey Zhao and Izhak Shafran and Thomas L. Griffiths and Yuan Cao and Karthik Narasimhan},
year={2023},
eprint={2305.10601},
archivePrefix={arXiv},
primaryClass={cs.CL}
}
```
43 changes: 43 additions & 0 deletions pipelines/examples/tree-of-thought/demo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse

from src.llm import Ernie, Ernie_llm_list, llamaChatCompletion, llm_config
from src.tot.methods.bfs import solve
from src.tot.tasks.game24 import Game24Task

args = argparse.Namespace(
backend="llama-2-7b-chat",
temperature=0.6,
task="game24",
naive_run=False,
prompt_sample=None,
method_generate="propose",
method_evaluate="value",
method_select="greedy",
n_generate_sample=1,
n_evaluate_sample=3,
n_select_sample=5,
log_fp="log.txt",
)

task = Game24Task()
if args.backend in llm_config.keys():
chatter = llamaChatCompletion(args.backend)
elif args.backend in Ernie_llm_list:
chatter = Ernie(model=args.backend)
ys, infos = solve(args, task, 900, chatter=chatter)
print(ys[0])
print(infos)
20 changes: 20 additions & 0 deletions pipelines/examples/tree-of-thought/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
aiohttp==3.8.4
aiosignal==1.3.1
async-timeout==4.0.2
attrs==23.1.0
certifi==2023.5.7
charset-normalizer==3.1.0
frozenlist==1.3.3
idna==3.4
mpmath==1.3.0
multidict==6.0.4
numpy==1.24.3
requests==2.31.0
sympy==1.12
tqdm==4.65.0
urllib3==2.0.2
yarl==1.9.2
pandas==2.0.3
erniebot==0.5.0
paddlenlp==2.7.1
paddlepaddle-gpu==2.6.0
126 changes: 126 additions & 0 deletions pipelines/examples/tree-of-thought/run.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
# coding=utf8, ErnestinaQiu

# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import json
import os
import time

from src.llm.llama import Ernie, Ernie_llm_list, llamaChatCompletion, llm_config
from src.tot.methods.bfs import naive_solve, solve
from src.tot.models import gpt_usage
from src.tot.tasks import get_task


def run(args, chatter):
task = get_task(args.task)
logs, cnt_avg, cnt_any = [], 0, 0
if args.naive_run:
file = f"./logs/{args.task}/{args.backend}_{args.temperature}_naive_{args.prompt_sample}_sample_{args.n_generate_sample}_start{args.task_start_index}_end{args.task_end_index}.json"
metric_fp = f"./logs/{args.task}/{args.backend}_{args.temperature}_{args.method_select}_{args.n_generate_sample}_start{args.task_start_index}_end{args.task_end_index}_metric.txt"
else:
file = f"./logs/{args.task}/{args.backend}_{args.temperature}_{args.method_generate}{args.n_generate_sample}_{args.method_evaluate}{args.n_evaluate_sample}_{args.method_select}{args.n_select_sample}_start{args.task_start_index}_end{args.task_end_index}.json"
metric_fp = f"./logs/{args.task}/{args.backend}_{args.temperature}_{args.method_generate}{args.n_generate_sample}_{args.method_evaluate}{args.n_evaluate_sample}_{args.method_select}{args.n_select_sample}_start{args.task_start_index}_end{args.task_end_index}_metric.txt"
os.makedirs(os.path.dirname(file), exist_ok=True)

for i in range(args.task_start_index, args.task_end_index):
args.log_fp = f"./logs/{args.task}/{args.backend}_{args.temperature}_{args.prompt_sample}_sample_{args.n_generate_sample}_start{args.task_start_index}_end{args.task_end_index}.log"
args.query_fp = f"./logs/{args.task}/{args.backend}_{args.temperature}_{args.prompt_sample}_sample_{args.n_generate_sample}_start{args.task_start_index}_end{args.task_end_index}_query.log"
f = open(args.log_fp, "a", encoding="utf8")
f.write(f"------ index: {i}")
f.close()

f = open(args.query_fp, "a", encoding="utf8")
f.write(f"------ index: {i}")
f.close()

chatter.query = []
chatter.tokenizer.init_chat_template(
os.path.join(os.getcwd(), "pipelines", "examples", "tree-of-thought", "src", "llm", "chat_template.json")
)

# solve
if args.naive_run:
ys, info = naive_solve(args, task, i, chatter=chatter)
else:
ys, info = solve(args, task, i, chatter=chatter)

# log
infos = [task.test_output(i, y) for y in ys]
info.update({"idx": i, "ys": ys, "infos": infos, "usage_so_far": gpt_usage(args.backend)})
logs.append(info)
with open(file, "w") as f:
json.dump(logs, f, indent=4)

# log main metric
accs = [info["r"] for info in infos]
cnt_avg += sum(accs) / len(accs)
cnt_any += any(accs)
mes = f"{i}, 'sum(accs)', {sum(accs)}, 'cnt_avg', {cnt_avg}, 'cnt_any', {cnt_any}, '\n'"
f = open(metric_fp, "a", encoding="utf8")
f.write(mes)
f.close()

f = open(args.query_fp, "a", encoding="utf8")
f.write(json.dumps(chatter.query))
f.close()

n = args.task_end_index - args.task_start_index
mes2 = f"cnt_avg / n: {cnt_avg / n}, cnt_any / n: {cnt_any / n}"
mes3 = f"'usage_so_far', {gpt_usage(args.backend)}"
f = open(metric_fp, "a", encoding="utf8")
f.write(mes2)
f.write(mes3)
f.close()


llm_backend_choices = list(llm_config.keys())


def parse_args():
args = argparse.ArgumentParser()
args.add_argument("--backend", type=str, choices=llm_backend_choices, default="llama-2-7b-chat")
args.add_argument("--temperature", type=float, default=0.6)

args.add_argument("--task", type=str, required=True, choices=["game24", "text", "crosswords"])
args.add_argument("--task_start_index", type=int, default=900)
args.add_argument("--task_end_index", type=int, default=1000)

args.add_argument("--naive_run", action="store_true")
args.add_argument(
"--prompt_sample", type=str, choices=["standard", "cot"]
) # only used when method_generate = sample, or naive_run

args.add_argument("--method_generate", type=str, choices=["sample", "propose"])
args.add_argument("--method_evaluate", type=str, choices=["value", "vote"])
args.add_argument("--method_select", type=str, choices=["sample", "greedy"], default="greedy")
args.add_argument("--n_generate_sample", type=int, default=1) # only thing needed if naive_run
args.add_argument("--n_evaluate_sample", type=int, default=1)
args.add_argument("--n_select_sample", type=int, default=1)

args.add_argument("--query_fp", type=str, default=f"./logs/default/query_{int(time.time())}.log")

args = args.parse_args()
return args


if __name__ == "__main__":
args = parse_args()
if args.backend in llm_backend_choices:
chatter = llamaChatCompletion(args.backend)
elif args.backend in Ernie_llm_list:
chatter = Ernie(model=args.backend)
run(args, chatter=chatter)