Skip to content

Commit d875099

Browse files
liuxukun2000Satan
and
Satan
authored
Allow loading of user-defined tools, parallelize tasks using threads (#33)
* Allow loading of user-defined tools * Allow loading of user-defined tools, parallelize tasks using threads --------- Co-authored-by: Satan <liuxk2019@mail.sustech.edu.cn>
1 parent 44b896e commit d875099

File tree

6 files changed

+70
-1895
lines changed

6 files changed

+70
-1895
lines changed

.gitignore

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#Custom
22
.idea/
3-
test.ipynb
3+
examples/test.ipynb
44
test.py
55
.chroma
66

configs/mathria.yaml

+3
Original file line numberDiff line numberDiff line change
@@ -33,3 +33,6 @@ target_tasks:
3333
plugins:
3434
- name: wolfram_alpha
3535

36+
auth:
37+
OPENAI_API_KEY: !file /home/api.key
38+
WOLFRAM_ALPHA_APPID: !file /home/wolfram.key

gentopia/agent/rewoo/agent.py

+58-45
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
import logging
2+
import os
23
import re
34
from typing import List, Dict, Union, Optional
45

56
from langchain import PromptTemplate
6-
7+
from concurrent.futures import ThreadPoolExecutor
78
from gentopia.agent.base_agent import BaseAgent
89
from gentopia.agent.rewoo.nodes.Planner import Planner
910
from gentopia.agent.rewoo.nodes.Solver import Solver
@@ -74,67 +75,79 @@ def _parse_planner_evidences(self, planner_response: str) -> (dict[str, str], Li
7475
{"#E1": "Tool1", "#E2": "Tool2", "#E3": "Tool3", "#E4": "Tool4"}, [[#E1, #E2], [#E3, #E4]]
7576
"""
7677
evidences, dependence = dict(), dict()
77-
num = 0
7878
for line in planner_response.splitlines():
7979
if line.startswith("#E") and line[2].isdigit():
8080
e, tool_call = line.split(":", 1)
8181
e, tool_call = e.strip(), tool_call.strip()
8282
if len(e) == 3:
8383
dependence[e] = []
84-
num += 1
8584
evidences[e] = tool_call
8685
for var in re.findall(r"#E\d+", tool_call):
8786
if var in evidences:
8887
dependence[e].append(var)
8988
else:
9089
evidences[e] = "No evidence found"
91-
level = [list(evidences.keys())]
92-
#TODO: Fix this
93-
94-
# while num > 0:
95-
# level.append([])
96-
# print(dependence)
97-
# for i in dependence:
98-
# if dependence[i] is None:
99-
# continue
100-
# if len(dependence[i]) == 0:
101-
# level[-1].append(i)
102-
# num -= 1
103-
# for j in dependence:
104-
# if dependence[j] is not None and i in dependence[j]:
105-
# dependence[j].remove(i)
106-
# if len(dependence[j]) == 0:
107-
# dependence[j] = None
108-
# print(level)
90+
level = []
91+
while dependence:
92+
select = [i for i in dependence if not dependence[i]]
93+
if len(select) == 0:
94+
raise ValueError("Circular dependency detected.")
95+
level.append(select)
96+
for item in select:
97+
dependence.pop(item)
98+
for item in dependence:
99+
for i in select:
100+
if i in dependence[item]:
101+
dependence[item].remove(i)
102+
109103
return evidences, level
110104

105+
106+
def _run_plugin(self, e, planner_evidences, worker_evidences, output=BaseOutput()):
107+
result = dict(e=e, plugin_cost=0, plugin_token=0, evidence="")
108+
tool_call = planner_evidences[e]
109+
if "[" not in tool_call:
110+
result['evidence'] = tool_call
111+
else:
112+
tool, tool_input = tool_call.split("[", 1)
113+
tool_input = tool_input[:-1]
114+
# find variables in input and replace with previous evidences
115+
for var in re.findall(r"#E\d+", tool_input):
116+
if var in worker_evidences:
117+
tool_input = tool_input.replace(var, "[" + worker_evidences.get(var, "") + "]")
118+
try:
119+
tool_response = self._find_plugin(tool).run(tool_input)
120+
# cumulate agent-as-plugin costs and tokens.
121+
if isinstance(tool_response, AgentOutput):
122+
result['plugin_cost'] = tool_response.cost
123+
result['plugin_token'] = tool_response.token_usage
124+
result['evidence'] = get_plugin_response_content(tool_response)
125+
except:
126+
result['evidence'] = "No evidence found."
127+
finally:
128+
output.panel_print(result['evidence'], f"[green] Function Response of [blue]{tool}: ")
129+
return result
130+
131+
111132
def _get_worker_evidence(self, planner_evidences, evidences_level, output=BaseOutput()):
112133
worker_evidences = dict()
113134
plugin_cost, plugin_token = 0.0, 0.0
114-
for level in evidences_level:
115-
# TODO: Run simultaneously
116-
for e in level:
117-
tool_call = planner_evidences[e]
118-
if "[" not in tool_call:
119-
worker_evidences[e] = tool_call
120-
continue
121-
tool, tool_input = tool_call.split("[", 1)
122-
tool_input = tool_input[:-1]
123-
# find variables in input and replace with previous evidences
124-
for var in re.findall(r"#E\d+", tool_input):
125-
if var in worker_evidences:
126-
tool_input = tool_input.replace(var, "[" + worker_evidences.get(var, "") + "]")
127-
try:
128-
tool_response = self._find_plugin(tool).run(tool_input)
129-
# cumulate agent-as-plugin costs and tokens.
130-
if isinstance(tool_response, AgentOutput):
131-
plugin_cost += tool_response.cost
132-
plugin_token += tool_response.token_usage
133-
worker_evidences[e] = get_plugin_response_content(tool_response)
134-
except:
135-
worker_evidences[e] = "No evidence found."
136-
finally:
137-
output.panel_print(worker_evidences[e], f"[green] Function Response of [blue]{tool}: ")
135+
with ThreadPoolExecutor(max_workers=2) as pool:
136+
for level in evidences_level:
137+
results = []
138+
for e in level:
139+
results.append(pool.submit(self._run_plugin, e, planner_evidences, worker_evidences, output))
140+
if len(results) > 1:
141+
output.update_status(f"Running tasks {level} in parallel.")
142+
else:
143+
output.update_status(f"Running task {level[0]}.")
144+
for r in results:
145+
resp = r.result()
146+
plugin_cost += resp['plugin_cost']
147+
plugin_token += resp['plugin_token']
148+
worker_evidences[resp['e']] = resp['evidence']
149+
output.done()
150+
138151
return worker_evidences, plugin_cost, plugin_token
139152

140153
def _find_plugin(self, name: str):

gentopia/assembler/loader.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,12 @@ def prompt(self, node: yaml.Node) -> Any:
3838

3939
def tool(self, node: yaml.Node) -> Any:
4040
tool = self.construct_scalar(node)
41-
tool_cls = eval(tool)
41+
if '.' in tool:
42+
_path = tool.split('.')
43+
module = importlib.import_module('.'.join(_path[:-1]))
44+
tool_cls = getattr(module, _path[-1])
45+
else:
46+
tool_cls = eval(tool)
4247
assert issubclass(tool_cls, BaseTool)
4348
return tool_cls
4449

0 commit comments

Comments
 (0)