-
Notifications
You must be signed in to change notification settings - Fork 41
/
Copy pathagent.py
171 lines (152 loc) · 7.08 KB
/
agent.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
import logging
import re
from typing import List, Dict, Union
from langchain import PromptTemplate
from langchain.tools import BaseTool
from gentopia.agent.base_agent import BaseAgent
from gentopia.agent.rewoo.nodes.Planner import Planner
from gentopia.agent.rewoo.nodes.Solver import Solver
from gentopia.llm.base_llm import BaseLLM
from gentopia.model.agent_model import AgentType
from gentopia.util.cost_helpers import *
from gentopia.util.text_helpers import *
# Matches one evidence placeholder such as "*E1" or "*E12" as a whole token.
_EVIDENCE_VAR_PATTERN = re.compile(r"\*E\d+")


class RewooAgent(BaseAgent):
    """ReWOO agent: plan every tool call up front, execute them, then solve.

    A run has three stages:

    1. Plan - a Planner LLM (given ``self.plugins`` as workers) writes a plan in
       which ``*PlanN: ...`` lines are followed by ``*EN: Tool[input]`` evidence
       lines. A tool input may reference earlier evidence via ``*EN``.
    2. Work - each evidence line is executed with the plugin of that name, in
       dependency order, with earlier results substituted for placeholders.
    3. Solve - a Solver LLM receives the instruction plus a log of plans and
       evidence and produces the final answer.
    """

    name: str = "RewooAgent"
    type: AgentType = AgentType.REWOO
    version: str
    description: str
    target_tasks: list[str]
    # Either one LLM shared by both stages, or {"Planner": llm, "Solver": llm}.
    llm: Union[BaseLLM, Dict[str, BaseLLM]]
    # Prompt templates keyed by stage: {"Planner": ..., "Solver": ...}.
    prompt_template: Dict[str, PromptTemplate]
    # Tools / sub-agents the plan may call; looked up by their ``name`` attribute.
    plugins: List[Union[BaseTool, BaseAgent]]
    # Optional examples handed to the Planner / Solver, keyed by stage name.
    examples: Dict[str, Union[str, List[str]]] = None

    def _get_llms(self):
        """Return the stage LLMs as ``{"Planner": llm, "Solver": llm}``.

        :raises ValueError: if ``self.llm`` is neither a ``BaseLLM`` nor a dict
            holding both a ``"Planner"`` and a ``"Solver"`` entry.
        """
        if isinstance(self.llm, BaseLLM):
            return {"Planner": self.llm, "Solver": self.llm}
        if isinstance(self.llm, dict) and "Planner" in self.llm and "Solver" in self.llm:
            return {"Planner": self.llm["Planner"], "Solver": self.llm["Solver"]}
        raise ValueError("llm must be a BaseLLM or a dict with Planner and Solver.")

    def _parse_plan_map(self, planner_response: str) -> tuple[dict[str, List[str]], dict[str, str]]:
        """Group the planner's evidence lines under the plan line they follow.

        Parsing is lenient because the LLM does not always follow the strict
        one-plan-one-evidence format. For example::

            *Plan1: ...
            *E1: ...
            *E2: ...

        yields ``{"*Plan1": ["*E1", "*E2"]}``, while::

            *Plan1: ...
            *Plan2: ...
            *E1: ...

        yields ``{"*Plan1": [], "*Plan2": ["*E1"]}``.

        :param planner_response: raw planner output.
        :return: ``(plan_to_es, plans)``. ``plan_to_es`` maps each plan label to
            the evidence labels listed under it. ``plans`` maps each plan label
            to its text after the first ``:`` (``""`` if the line has none).
            Evidence lines appearing before any plan line are ignored.
        """
        plan_to_es: dict[str, List[str]] = {}
        plans: dict[str, str] = {}
        current_plan = None
        for line in planner_response.splitlines():
            if line.startswith("*Plan"):
                label, _, text = line.partition(":")
                current_plan = label.strip()
                plans[current_plan] = text.strip()
                plan_to_es[current_plan] = []
            elif line.startswith("*E") and current_plan is not None:
                # An evidence line with no preceding plan has no owner; skipping it
                # avoids referencing an unbound plan label.
                plan_to_es[current_plan].append(line.partition(":")[0].strip())
        return plan_to_es, plans

    def _parse_planner_evidences(self, planner_response: str) -> tuple[dict[str, str], List[List[str]]]:
        """Map each ``*EN`` label to its tool call and group labels into dependency levels.

        An evidence depends on every evidence defined on an *earlier* line whose
        placeholder appears in its tool call. Level 0 holds evidences without
        dependencies, level 1 those depending only on level 0, and so on.
        Example::

            ({"*E1": "Tool1[a]", "*E2": "Tool2[b]", "*E3": "Tool3[*E1]"},
             [["*E1", "*E2"], ["*E3"]])

        :param planner_response: raw planner output.
        :return: ``(evidences, levels)``. ``evidences`` maps labels to tool
            calls. Lines starting with ``*E`` plus a digit whose label is not a
            bare ``*E<digits>`` token, or that lack a ``:``, map to
            ``"No evidence found"`` and are not scheduled. ``levels`` lists the
            scheduled labels in executable order.
        """
        evidences: dict[str, str] = {}
        dependence: dict[str, set[str]] = {}
        for line in planner_response.splitlines():
            # Slice instead of index so a bare "*E" line cannot raise IndexError.
            if not (line.startswith("*E") and line[2:3].isdigit()):
                continue
            label, sep, tool_call = line.partition(":")
            label, tool_call = label.strip(), tool_call.strip()
            if sep and _EVIDENCE_VAR_PATTERN.fullmatch(label):
                # Only already-scheduled earlier evidences count as dependencies;
                # self-references are excluded so they cannot block scheduling.
                # A set keeps repeated references from being counted twice.
                dependence[label] = {
                    var for var in _EVIDENCE_VAR_PATTERN.findall(tool_call)
                    if var in dependence and var != label
                }
                evidences[label] = tool_call
            else:
                evidences[label] = "No evidence found"

        levels: List[List[str]] = []
        while dependence:
            ready = [label for label, deps in dependence.items() if not deps]
            if not ready:
                # Only reachable when a redefined label makes two evidences wait
                # on each other; run what is left instead of looping forever.
                ready = list(dependence)
            levels.append(ready)
            for label in ready:
                del dependence[label]
            for deps in dependence.values():
                deps.difference_update(ready)
        return evidences, levels

    def _get_worker_evidence(self, planner_evidences, evidences_level):
        """Execute the planned tool calls level by level and collect their results.

        A tool call has the form ``Tool[input]``. Before it runs, every ``*EN``
        placeholder in its input that names an already-collected result is
        replaced by that result wrapped in square brackets. A call containing
        no ``[`` is not executed; its text is used as the evidence verbatim.

        :param planner_evidences: evidence label -> tool call, as returned by
            ``_parse_planner_evidences``.
        :param evidences_level: evidence labels grouped into dependency levels.
        :return: evidence label -> result. An unknown tool or a call that raises
            yields ``"No evidence found."`` (the failure is logged).
        """
        worker_evidences = {}

        def substitute(match):
            var = match.group(0)
            if var in worker_evidences:
                return f"[{worker_evidences[var]}]"
            return var

        for level in evidences_level:
            # TODO: run the calls within one level concurrently; they are independent.
            for e in level:
                tool_call = planner_evidences[e]
                if "[" not in tool_call:
                    worker_evidences[e] = tool_call
                    continue
                tool, tool_input = tool_call.split("[", 1)
                tool = tool.strip()
                # Cut at the last "]" (if any) rather than blindly dropping the
                # final character, so a missing closing bracket keeps the input.
                tool_input = tool_input.rsplit("]", 1)[0]
                # One regex pass substitutes whole placeholders only: "*E1" can no
                # longer clobber the prefix of "*E10", and substituted text is not
                # scanned again.
                tool_input = _EVIDENCE_VAR_PATTERN.sub(substitute, tool_input)

                plugin = self._find_plugin(tool)
                if plugin is None:
                    logging.warning("No plugin named %r for evidence %s", tool, e)
                    worker_evidences[e] = "No evidence found."
                    continue
                try:
                    worker_evidences[e] = get_plugin_response_content(plugin.run(tool_input))
                except Exception:
                    # Best-effort: one failing tool must not abort the whole plan.
                    logging.exception("Plugin %r failed for evidence %s", tool, e)
                    worker_evidences[e] = "No evidence found."
        return worker_evidences

    def _find_plugin(self, name: str):
        """Return the first plugin whose ``name`` equals *name*, or ``None`` if none does."""
        return next((p for p in self.plugins if p.name == name), None)

    def run(self, instruction: str) -> AgentOutput:
        """Answer *instruction* with the plan -> work -> solve pipeline.

        :param instruction: the user task.
        :return: ``AgentOutput`` with the solver's answer, plus the summed cost
            and token usage of the planner and solver LLM calls (tool calls
            made in the work stage are not counted).
        """
        logging.info(f"Running {self.name + ':' + self.version} with instruction: {instruction}")
        total_cost = 0.0
        total_token = 0
        llms = self._get_llms()
        planner_llm, solver_llm = llms["Planner"], llms["Solver"]
        # ``examples`` is optional and defaults to None.
        examples = self.examples or {}
        planner = Planner(model=planner_llm,
                          workers=self.plugins,
                          prompt_template=self.prompt_template.get("Planner", None),
                          examples=examples.get("Planner", None))
        solver = Solver(model=solver_llm,
                        prompt_template=self.prompt_template.get("Solver", None),
                        examples=examples.get("Solver", None))

        # Plan
        planner_output = planner.run(instruction)
        total_cost += calculate_cost(planner_llm.model_name, planner_output.prompt_token,
                                     planner_output.completion_token)
        total_token += planner_output.prompt_token + planner_output.completion_token
        plan_to_es, plans = self._parse_plan_map(planner_output.content)
        planner_evidences, evidence_level = self._parse_planner_evidences(planner_output.content)

        # Work
        worker_evidences = self._get_worker_evidence(planner_evidences, evidence_level)
        log_lines = []
        for plan, es in plan_to_es.items():
            log_lines.append(f"{plan}: {plans[plan]}\n")
            # A label listed under a plan may never have been executed (malformed
            # evidence line); report it like a failed call instead of raising.
            log_lines.extend(f"{e}: {worker_evidences.get(e, 'No evidence found.')}\n" for e in es)
        worker_log = "".join(log_lines)

        # Solve
        solver_output = solver.run(instruction, worker_log)
        total_cost += calculate_cost(solver_llm.model_name, solver_output.prompt_token,
                                     solver_output.completion_token)
        total_token += solver_output.prompt_token + solver_output.completion_token
        return AgentOutput(output=solver_output.content, cost=total_cost, token_usage=total_token)