Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

some refactor #16

Merged
merged 2 commits into from
Jun 19, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 12 additions & 14 deletions config.yaml
Original file line number Diff line number Diff line change
@@ -4,21 +4,19 @@ type: rewoo
version: 0.0.1
description: example
prompt_template:
- Planner:
prompt: !prompt ZeroShotPlannerPrompt
- Solver:
prompt: !prompt ZeroShotSolverPrompt
Planner: !prompt ZeroShotPlannerPrompt
Solver: !prompt ZeroShotSolverPrompt
llm:
- Planner:
model_name: gpt-4
params:
temperature: 0.7
- Solver:
model_name: gpt-4
params:
max_tokens: 1024
temperature: 0.0
frequency_penalty: 0.0
Planner:
model_name: gpt-4
params:
temperature: 0.7
Solver:
model_name: gpt-4
params:
max_tokens: 1024
temperature: 0.0
frequency_penalty: 0.0
target_tasks:
- print
- find
12 changes: 6 additions & 6 deletions gentopia/agent/react/agent.py
Original file line number Diff line number Diff line change
@@ -4,12 +4,13 @@

from langchain import PromptTemplate
from langchain.schema import AgentFinish
from langchain.tools import BaseTool
from gentopia.tools import BaseTool
from pydantic import create_model, BaseModel

from gentopia.agent.base_agent import BaseAgent
from gentopia.assembler.task import AgentAction
from gentopia.llm.client.openai import OpenAIGPTClient
from gentopia.llm.base_llm import BaseLLM
from gentopia.model.agent_model import AgentType, AgentOutput
from gentopia.util.cost_helpers import calculate_cost

@@ -22,7 +23,7 @@ class ReactAgent(BaseAgent):
version: str
description: str
target_tasks: list[str]
llm: OpenAIGPTClient
llm: BaseLLM
prompt_template: PromptTemplate
plugins: List[Union[BaseTool, BaseAgent]]
examples: Union[str, List[str]] = None
@@ -43,7 +44,7 @@ def _compose_plugin_description(self) -> str:
for plugin in self.plugins:
prompt += f"{plugin.name}[input]: {plugin.description}\n"
except Exception:
raise ValueError("Worker must have a name and description.")
raise ValueError("Plugin must have a name and description.")
return prompt

def _construct_scratchpad(
@@ -121,10 +122,9 @@ def run(self, instruction):
logging.info(f"Prompt: {prompt}")
response = self.llm.completion(prompt)
if response.state == "error":
logging.error("Planner failed to retrieve response from LLM")
raise ValueError("Planner failed to retrieve response from LLM")
logging.error("Failed to retrieve response from LLM")
raise ValueError("Failed to retrieve response from LLM")

logging.info(f"Planner run successful.")
total_cost += calculate_cost(self.llm.model_name, response.prompt_token,
response.completion_token)
total_token += response.prompt_token + response.completion_token
51 changes: 48 additions & 3 deletions gentopia/agent/vanilla/agent.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,52 @@
from typing import List, Union

from gentopia.agent.base_agent import BaseAgent
from gentopia.llm.base_llm import BaseLLM
from gentopia.model.agent_model import AgentType
from gentopia.prompt.vanilla import *
from gentopia.util.cost_helpers import *
from gentopia.util.text_helpers import *


# TODO: Implement this class
class VanillaAgent(BaseAgent):
def run(self, instruction):
pass
name: str = "VanillaAgent"
type: AgentType = AgentType.Vanilla
version: str
description: str
target_tasks: list[str]
llm: BaseLLM
prompt_template: PromptTemplate = None
examples: Union[str, List[str]] = None

def _compose_fewshot_prompt(self) -> str:
if self.examples is None:
return ""
if isinstance(self.examples, str):
return self.examples
else:
return "\n\n".join([e.strip("\n") for e in self.examples])

def _compose_prompt(self, instruction: str) -> str:
fewshot = self._compose_fewshot_prompt()
if self.prompt_template is not None:
if "fewshot" in self.prompt_template.input_variables:
return self.prompt_template.format(fewshot=fewshot, instruction=instruction)
else:
return self.prompt_template.format(instruction=instruction)
else:
if self.examples is None:
return ZeroShotVanillaPrompt.format(instruction=instruction)
else:
return FewShotVanillaPrompt.format(fewshot=fewshot, instruction=instruction)

def run(self, instruction: str) -> AgentOutput:
prompt = self._compose_prompt(instruction)
response = self.llm.completion(prompt)
total_cost = calculate_cost(self.llm.model_name, response.prompt_token,
response.completion_token)
total_token = response.prompt_token + response.completion_token

return AgentOutput(
output=response.content,
cost=total_cost,
token_usage=total_token)
35 changes: 21 additions & 14 deletions gentopia/assembler/agent_assembler.py
Original file line number Diff line number Diff line change
@@ -54,42 +54,49 @@ def get_agent(self, config=None):
return agent

def _get_llm(self, obj) -> Union[BaseLLM, Dict[str, BaseLLM]]:
assert isinstance(obj, dict) or isinstance(obj, list)
if isinstance(obj, list):
assert isinstance(obj, dict) or isinstance(obj, str)
if isinstance(obj, dict) and ("model_name" not in obj):
return {
list(item.keys())[0]: self._parse_llm(list(item.values())[0]) for item in obj
#list(item.keys())[0]: self._parse_llm(list(item.values())[0]) for item in obj
key: self._parse_llm(obj[key]) for key in obj
}
else:
return self._parse_llm(obj)

def _parse_llm(self, obj) -> BaseLLM:
name = obj['model_name']
model_param = obj.get('params', dict())
if isinstance(obj, str):
name = obj
model_param = dict()
else:
print(obj)
name = obj['model_name']
model_param = obj.get('params', dict())
if TYPES.get(name, None) == "OpenAI":
key = obj.get('key', None)
#key = obj.get('key', None)
params = OpenAIParamModel(**model_param)
return OpenAIGPTClient(model_name=name, params=params, api_key=key)
return OpenAIGPTClient(model_name=name, params=params)
elif TYPES.get(name, None) == "Huggingface":
print(obj)
device = obj.get('device', 'gpu' if torch.cuda.is_available() else 'cpu')
params = HuggingfaceParamModel(**model_param)
return HuggingfaceLLMClient(model_name=name, params=params, device=device)
else:
raise ValueError(f"LLM {name} is not supported currently.")

def _get_prompt_template(self, obj):
assert isinstance(obj, dict) or isinstance(obj, list)
if isinstance(obj, list):
assert isinstance(obj, dict) or isinstance(obj, PromptTemplate)
if isinstance(obj, dict):
return {
list(item.keys())[0]: self._parse_prompt_template(list(item.values())[0]) for item in obj
key: self._parse_prompt_template(obj[key]) for key in obj
}
else:
ans = self._parse_prompt_template(obj)
return ans

def _parse_prompt_template(self, obj):
assert isinstance(obj, dict)
if 'prompt' in obj:
return obj['prompt']
assert isinstance(obj, dict) or isinstance(obj, PromptTemplate)
if isinstance(obj, PromptTemplate):
return obj
input_variables = obj['input_variables']
template = obj['template']
template_format = obj.get('template_format', 'f-string')
@@ -102,7 +109,7 @@ def _parse_plugins(self, obj):
result = []
for i in obj:
# If referring to a tool class then directly load it
if issubclass(i, BaseTool):
if isinstance(i, BaseTool):
result.append(i)
continue

12 changes: 6 additions & 6 deletions gentopia/assembler/loader.py
Original file line number Diff line number Diff line change
@@ -22,12 +22,12 @@ def include(self, node: yaml.Node) -> Any:

def prompt(self, node: yaml.Node) -> Any:
prompt = self.construct_scalar(node)
prompt_cls = eval(prompt)
assert issubclass(prompt_cls, PromptTemplate)
return prompt_cls
prompt_ins = eval(prompt)
assert isinstance(prompt_ins, PromptTemplate)
return prompt_ins

def tool(self, node: yaml.Node) -> Any:
tool = self.construct_scalar(node)
tool_cls = eval(tool)
assert issubclass(tool_cls, BaseTool)
return tool_cls
tool_ins = eval(tool)
assert isinstance(tool_ins, BaseTool)
return tool_ins
14 changes: 6 additions & 8 deletions gentopia/llm/client/openai.py
Original file line number Diff line number Diff line change
@@ -9,15 +9,13 @@
from gentopia.model.param_model import *


class OpenAIGPTClient(BaseLLM):
def __init__(self, model_name: str, params: OpenAIParamModel, api_key: str = None):
assert TYPES.get(model_name, None) == "OpenAI"
self.api_key = api_key
self.params = params
self.model_name = model_name
class OpenAIGPTClient(BaseLLM, BaseModel):
model_name: str
params: OpenAIParamModel

def __init__(self, **data):
super().__init__(**data)
openai.api_key = os.environ.get("OPENAI_API_KEY", "")
if api_key is not None:
openai.api_key = api_key

def get_model_name(self) -> str:
return self.model_name
4 changes: 2 additions & 2 deletions gentopia/llm/llm_info.py
Original file line number Diff line number Diff line change
@@ -44,6 +44,6 @@
}

COSTS = {
"gpt-3.5-turbo": {"prompt": 0.002, "completion": 0.002},
"gpt-4": {"prompt": 0.03, "completion": 0.06},
"gpt-3.5-turbo": {"prompt": 0.002/1000, "completion": 0.002/1000},
"gpt-4": {"prompt": 0.03/1000, "completion": 0.06/1000},
}
2 changes: 1 addition & 1 deletion gentopia/llm/loaders/airoboros.py
Original file line number Diff line number Diff line change
@@ -10,5 +10,5 @@ def load_model(loader_model: HuggingfaceLoaderModel):
model = AutoModelForCausalLM.from_pretrained(*args, **kwargs)
if loader_model.device == "gpu":
model.half()
model = BetterTransformer.transform(model)
#model = BetterTransformer.transform(model)
return model, tokenizer
6 changes: 3 additions & 3 deletions gentopia/model/agent_model.py
Original file line number Diff line number Diff line change
@@ -2,7 +2,7 @@
from enum import Enum
from pydantic import BaseModel

from gentopia.agent.base_agent import BaseAgent
#from gentopia.agent.base_agent import BaseAgent


class AgentType(Enum):
@@ -11,7 +11,7 @@ class AgentType(Enum):
"""
REACT = "react"
REWOO = "rewoo"
DIRECT = "direct"
Vanilla = "vanilla"

@staticmethod
def get_agent_class(_type: AgentType):
@@ -26,7 +26,7 @@ def get_agent_class(_type: AgentType):
elif _type == AgentType.REWOO:
from gentopia.agent.rewoo import RewooAgent
return RewooAgent
elif _type == AgentType.DIRECT:
elif _type == AgentType.Vanilla:
from gentopia.agent.vanilla import VanillaAgent
return VanillaAgent
else:
3 changes: 2 additions & 1 deletion gentopia/prompt/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .fewshots import *
from .rewoo import *
# from .react import ZeroShotReactPrompt
from .vanilla import *
from .react import *
2 changes: 1 addition & 1 deletion gentopia/prompt/react.py
Original file line number Diff line number Diff line change
@@ -16,7 +16,7 @@
Final Answer: the final answer to the original input question
Begin!

Question: {input}
Question: {instruction}
Thought:{agent_scratchpad}
"""
)
13 changes: 13 additions & 0 deletions gentopia/prompt/vanilla.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from langchain import PromptTemplate

ZeroShotVanillaPrompt = PromptTemplate(
input_variables=["instruction"],
template="""{instruction}"""
)

FewShotVanillaPrompt = PromptTemplate(
input_variables=["instruction", "fewshot"],
template="""{instruction}

{fewshot}"""
)
70 changes: 0 additions & 70 deletions gentopia/prompt/wiki_prompt.py

This file was deleted.

Loading