[Hackathon 5th No.73] ToT #7660

Merged
54 commits merged on Jan 26, 2024

Commits
522a7af
Hackathon TASK73 ToT
ErnestinaQiu Dec 14, 2023
a68dacb
update readme tutorial
ErnestinaQiu Dec 14, 2023
c65870b
modify according to Lint
ErnestinaQiu Dec 14, 2023
315cc3e
modify according to Lint
ErnestinaQiu Dec 14, 2023
b001719
Delete LICENSE
ErnestinaQiu Dec 14, 2023
8932aca
Update LICENSE
ErnestinaQiu Dec 14, 2023
da288af
black format
ErnestinaQiu Dec 14, 2023
916d70c
isort format
ErnestinaQiu Dec 14, 2023
e64c51e
Update search_crosswords-dfs.ipynb
ErnestinaQiu Dec 14, 2023
ef5cfa6
update files formats
ErnestinaQiu Dec 14, 2023
96a6d35
Update LICENSE
ErnestinaQiu Dec 14, 2023
6c95517
Update LICENSE
ErnestinaQiu Dec 14, 2023
1728255
Update LICENSE
ErnestinaQiu Dec 14, 2023
bd35dde
Update LICENSE
ErnestinaQiu Dec 14, 2023
f5ff4df
delete test data
ErnestinaQiu Dec 15, 2023
5621975
delete some unnecessary files
ErnestinaQiu Dec 21, 2023
e7d3ba6
add paddlenlp-llama2
ErnestinaQiu Dec 21, 2023
84ee4d0
fix one bug
ErnestinaQiu Dec 22, 2023
effc87b
fix outputs bug
ErnestinaQiu Dec 22, 2023
c514ca9
delete meta/llama2
ErnestinaQiu Dec 22, 2023
402fa97
modify according to comments
ErnestinaQiu Dec 22, 2023
ae8c242
change according to comments
ErnestinaQiu Dec 22, 2023
c7979ed
Delete .gitignore
ErnestinaQiu Dec 22, 2023
c8f79e2
Create .gitignore
ErnestinaQiu Dec 22, 2023
994286c
Move directory
Dec 22, 2023
1f9499a
Add tree of thoughts scripts
Dec 22, 2023
065a1e9
add first dir
ErnestinaQiu Dec 22, 2023
3fd243d
Merge branch 'develop' of https://github.com/ErnestinaQiu/tot into de…
ErnestinaQiu Dec 22, 2023
24982e4
add note
ErnestinaQiu Dec 22, 2023
cebe49e
Update README.md
ErnestinaQiu Jan 25, 2024
987117e
Update requirements.txt
ErnestinaQiu Jan 25, 2024
c74e478
Update demo.py
ErnestinaQiu Jan 25, 2024
26179dc
Update .gitignore
ErnestinaQiu Jan 25, 2024
e1fdd67
Update run.py
ErnestinaQiu Jan 25, 2024
e793463
Update __init__.py
ErnestinaQiu Jan 25, 2024
5e4dcf1
chat templates
ErnestinaQiu Jan 25, 2024
671c6b8
add Ernie
ErnestinaQiu Jan 25, 2024
1face8b
Update llama.py
ErnestinaQiu Jan 25, 2024
e9b2100
Update bfs.py
ErnestinaQiu Jan 25, 2024
94d7a82
Update models.py
ErnestinaQiu Jan 25, 2024
4472dd6
Update run.py
ErnestinaQiu Jan 25, 2024
780ed41
format style
ErnestinaQiu Jan 25, 2024
8e90744
format style
ErnestinaQiu Jan 25, 2024
8fc80a3
format style
ErnestinaQiu Jan 25, 2024
e53ad12
format style
ErnestinaQiu Jan 25, 2024
c29e61c
format style
ErnestinaQiu Jan 25, 2024
1e9c384
format style
ErnestinaQiu Jan 25, 2024
94153a9
format style
ErnestinaQiu Jan 25, 2024
0592931
format style
ErnestinaQiu Jan 25, 2024
b1a65a9
Remove the duplicated "test results" section
ErnestinaQiu Jan 26, 2024
fac1c02
Remove the Ernie token; use an environment variable instead
ErnestinaQiu Jan 26, 2024
fc122c3
format style
ErnestinaQiu Jan 26, 2024
df25595
format style
ErnestinaQiu Jan 26, 2024
34c0953
Remove commented-out code
ErnestinaQiu Jan 26, 2024
7 changes: 7 additions & 0 deletions pipelines/examples/agents/tree-of-thought-llm/.gitignore
@@ -0,0 +1,7 @@
/*__pycache__/
dist/
src/tree_of_thoughts_llm.egg-info/
.env
*.pyc
*.DS_Store
spark/test.py
4 changes: 4 additions & 0 deletions pipelines/examples/agents/tree-of-thought-llm/MANIFEST.in
@@ -0,0 +1,4 @@
include src/tot/data/24/24.csv
include src/tot/data/crosswords/mini0505_0_100_5.json
include src/tot/data/crosswords/mini0505.json
include src/tot/data/text/data_100_random_text.txt
85 changes: 85 additions & 0 deletions pipelines/examples/agents/tree-of-thought-llm/README.md
@@ -0,0 +1,85 @@
# Tree of Thoughts (ToT)

![teaser](pics/teaser.png)

Official implementation of the paper [Tree of Thoughts: Deliberate Problem Solving with Large Language Models](https://arxiv.org/abs/2305.10601), with code, prompts, and model outputs.
Also see [its tweet thread](https://twitter.com/ShunyuYao12/status/1659357547474681857) for a one-minute summary.

## Setup
1. Set up an OpenAI API key and store it in the environment variable ``OPENAI_API_KEY`` (see [here](https://help.openai.com/en/articles/5112595-best-practices-for-api-key-safety)).

2. Install the `tot` package in one of two ways:
- Option 1: Install from PyPI
```bash
pip install tree-of-thoughts-llm
```
- Option 2: Install from source
```bash
git clone https://github.com/PaddlePaddle/PaddleNLP.git
cd PaddleNLP/pipelines/examples/agents/tree-of-thought-llm
pip install -r requirements.txt
pip install -e . # install `tot` package
```
3. Install meta/llama2 following Meta's official tutorial, and then set the model path in `llm_config.yml` (a sample config follows).
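
For reference, a minimal `llm_config.yml` might look like the sketch below. The per-model `ckpt_dir` and `tokenizer_path` keys are what `llama2/__init__.py` reads; the paths are placeholders to replace with your own:

```yaml
# Placeholder paths -- point these at your local Llama 2 checkpoints.
llama-2-7b-chat:
  ckpt_dir: /path/to/llama-2-7b-chat/
  tokenizer_path: /path/to/tokenizer.model
```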

## Quick Start
The following minimal script attempts to solve the Game of 24 with `4 5 6 10` (it might be a bit slow, as it uses llama-2-7b-chat):


Run it from `pipelines/examples/agents/tree-of-thought-llm`:

```bash
python demo.py
```

The full code of the demo is as follows:

```python
import argparse

from src.tot.methods.bfs import solve
from src.tot.tasks.game24 import Game24Task

# ToT + BFS configuration used by the demo.
args = argparse.Namespace(backend="llama-2-7b-chat", temperature=0.6, task="game24", naive_run=False, prompt_sample=None, method_generate="propose", method_evaluate="value", method_select="greedy", n_generate_sample=1, n_evaluate_sample=3, n_select_sample=5)

task = Game24Task()
ys, infos = solve(args, task, 900)  # puzzle index 900 corresponds to "4 5 6 10"
print(ys[0])
```

The output will look something like the following (note that generation is not deterministic, and the output can sometimes be wrong):
```
10 - 4 = 6 (left: 5 6 6)
5 * 6 = 30 (left: 6 30)
30 - 6 = 24 (left: 24)
Answer: (5 * (10 - 4)) - 6 = 24
```

## Paper Experiments

Run experiments via ``sh scripts/{game24, text, crosswords}/{standard_sampling, cot_sampling, bfs}.sh``. The exception is crosswords, where ToT uses a DFS algorithm, which can be run via ``scripts/crosswords/search_crosswords-dfs.ipynb``.
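
For example, the ToT + BFS experiment on Game of 24 expands to:

```bash
sh scripts/game24/bfs.sh
```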

The very simple ``run.py`` implements the ToT + BFS algorithm as well as naive IO/CoT sampling. Some key arguments (an example invocation follows the list):

- ``--naive_run``: if True, run naive IO/CoT sampling instead of ToT + BFS.
- ``--prompt_sample`` (choices=[``standard``, ``cot``]): sampling prompt
- ``--method_generate`` (choices=[``sample``, ``propose``]): thought generator, whether to sample independent thoughts (used in Creative Writing) or propose sequential thoughts (used in Game of 24)
- ``--method_evaluate`` (choices=[``value``, ``vote``]): state evaluator, whether to value states independently (used in Game of 24) or vote on states together (used in Creative Writing)
- ``--n_generate_sample``: number of times to prompt for thought generation
- ``--n_evaluate_sample``: number of times to prompt for state evaluation
- ``--n_select_sample``: number of states to keep from each step (i.e. ``b`` in the paper's ToT + BFS algorithm)
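
Assuming ``run.py`` exposes the arguments above as command-line flags (the names are taken from the demo's ``argparse.Namespace``, so treat the exact spelling as illustrative), a Game of 24 run with the demo's settings might look like:

```bash
python run.py \
    --task game24 \
    --backend llama-2-7b-chat \
    --method_generate propose \
    --method_evaluate value \
    --method_select greedy \
    --n_generate_sample 1 \
    --n_evaluate_sample 3 \
    --n_select_sample 5
```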



## Paper Trajectories
``logs/`` contains all the trajectories from the paper's experiments, except for ``logs/game24/gpt-4_0.7_propose1_value3_greedy5_start900_end1000.json``, which was reproduced after the paper (as the original experiment was done in a notebook) and achieved a 69% score instead of the original 74% score due to randomness in GPT decoding. We hope to aggregate multiple runs in the future to account for sampling randomness and update the paper, but this shouldn't affect the main conclusions of the paper.

## How to Add A New Task
Setting up a new task is easy, and mainly involves two steps (a sketch follows the list).
* Set up a new task class in ``tot/tasks/`` and task files in ``tot/data/``. See ``tot/tasks/game24.py`` for an example. Add the task to ``tot/tasks/__init__.py``.
* Set up task-specific prompts in ``tot/prompts/``. See ``tot/prompts/game24.py`` for an example. Depending on the nature of the task, choose ``--method_generate`` (choices=[``sample``, ``propose``]) and ``--method_evaluate`` (choices=[``value``, ``vote``]) and their corresponding prompts.
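
For illustration, here is a minimal sketch of what a new task class might look like. The method names and the base-class import are assumptions modeled loosely on ``tot/tasks/game24.py``; check that file for the actual interface expected by the solver.

```python
# Hypothetical task: sort the words of a line alphabetically.
# NOTE: method names and base-class location are assumptions; see
# tot/tasks/game24.py for the real interface used by the solver.
from tot.tasks.base import Task  # assumed base-class module


class SortWordsTask(Task):
    def __init__(self, file: str = "sort_words.txt"):
        super().__init__()
        with open(file) as f:
            self.data = f.read().splitlines()

    def __len__(self) -> int:
        return len(self.data)

    def get_input(self, idx: int) -> str:
        return self.data[idx]

    def test_output(self, idx: int, output: str) -> dict:
        # Score 1 if the last line of the model output sorts the words correctly.
        words = self.data[idx].split()
        answer = output.strip().split("\n")[-1].split()
        return {"r": int(answer == sorted(words))}
```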

If you have any questions, please contact ErnestinaQiu at ernestinaqiu@gmail.com.
23 changes: 23 additions & 0 deletions pipelines/examples/agents/tree-of-thought-llm/demo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import argparse

from src.tot.methods.bfs import solve
from src.tot.tasks.game24 import Game24Task

# ToT + BFS configuration for the Game of 24 demo.
args = argparse.Namespace(
    backend="llama-2-7b-chat",  # model key defined in llm_config.yml
    temperature=0.6,
    task="game24",
    naive_run=False,
    prompt_sample=None,
    method_generate="propose",  # propose sequential thoughts
    method_evaluate="value",  # value states independently
    method_select="greedy",
    n_generate_sample=1,
    n_evaluate_sample=3,
    n_select_sample=5,
)

task = Game24Task()
ys, infos = solve(args, task, 900)  # solve puzzle index 900 ("4 5 6 10")
print(ys[0])  # best final answer
print(infos)  # per-step search information
104 changes: 104 additions & 0 deletions pipelines/examples/agents/tree-of-thought-llm/llama2/__init__.py
@@ -0,0 +1,104 @@
"""
author: Ernestina
des: 1) set configure 2) initiate llama2
"""
import os
import time
from typing import List, Optional

import yaml
from llama2.llama.llama import Dialog, Llama

os.environ["WORLD_SIZE"] = "1"
os.environ["RANK"] = "0"
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "8020"

llm_config_path = os.path.join(os.getcwd(), "llm_config.yml")
with open(llm_config_path, "r") as f:
log_config = yaml.full_load(f.read())


class ChatCompletion:
    def __init__(self, model="llama-2-7b-chat") -> None:
        ckpt_dir = log_config[model]["ckpt_dir"]
        tokenizer_path = log_config[model]["tokenizer_path"]
        max_seq_len = 1000  # maximum sequence length for input prompts
        max_batch_size = 6  # maximum batch size for generating sequences
        self.generator = Llama.build(
            ckpt_dir=ckpt_dir,
            tokenizer_path=tokenizer_path,
            max_seq_len=max_seq_len,
            max_batch_size=max_batch_size,
        )

    def create(
        self,
        messages: List[Dialog],
        temperature: float = 0.6,
        top_p: float = 0.9,
        max_gen_len: Optional[int] = None,
    ):
        """
        Generate chat completions for a batch of dialogs using the pretrained model.

        Args:
            messages (List[Dialog]): A batch of dialogs. Each dialog contains
                "system" and/or "user" messages, e.g.
                [[{"role": "user", "content": "what is the recipe of mayonnaise?"},
                  {"role": "system", "content": "Always answer with Haiku"}]]
            temperature (float, optional): The temperature value for controlling
                randomness in generation. Defaults to 0.6.
            top_p (float, optional): The top-p sampling parameter for controlling
                diversity in generation. Defaults to 0.9.
            max_gen_len (int, optional): The maximum length of generated sequences.
                If None, it is set to the model's max sequence length. Defaults to None.
        """
        results = self.generator.chat_completion(
            messages,  # type: ignore
            max_gen_len=max_gen_len,
            temperature=temperature,
            top_p=top_p,
        )

        # OpenAI-style chat.completion response; token usage is not tracked here.
        completion = {
            "choices": [],
            "created": time.time(),
            "id": "llama2_{}".format(int(time.time())),
            "model": "llama-2-7b-chat",
            "object": "chat.completion",
            "usage": {"completion_tokens": 0, "prompt_tokens": 0, "total_tokens": 0},
        }

        assert len(messages) == len(results)
        for i, (dialog, result) in enumerate(zip(messages, results)):
            print(f"dialog: \n {dialog}")
            # Mirror the OpenAI API: the last choice finishes with "stop",
            # earlier ones with "length".
            if i == len(results) - 1:
                finish_reason = "stop"
            else:
                finish_reason = "length"
            tmp = {
                "finish_reason": finish_reason,
                "index": i,
                "message": {
                    "role": result["generation"]["role"],
                    # Strip newlines so each completion is a single line.
                    "content": result["generation"]["content"].replace("\n", ""),
                },
            }
            completion["choices"].append(tmp)
            print(f"\n result: \n {result}")

        return completion
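
A minimal usage sketch of this wrapper (assuming checkpoints are configured in `llm_config.yml`; the dialog format follows the docstring above):

```python
from llama2 import ChatCompletion  # assumes the repo root is on sys.path

chat = ChatCompletion(model="llama-2-7b-chat")
completion = chat.create(
    messages=[[{"role": "user", "content": "What is the recipe of mayonnaise?"}]],
    temperature=0.6,
    top_p=0.9,
)
print(completion["choices"][0]["message"]["content"])
```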
160 changes: 160 additions & 0 deletions pipelines/examples/agents/tree-of-thought-llm/llama2/llama/.gitignore
@@ -0,0 +1,160 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/