-
Notifications
You must be signed in to change notification settings - Fork 41
/
Copy pathagent_assembler.py
executable file
·238 lines (208 loc) · 9.35 KB
/
agent_assembler.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
import os
from typing import Union, Dict, Optional
import torch.cuda
from gentopia.prompt import PromptTemplate
from gentopia.agent.base_agent import BaseAgent
from gentopia.assembler.config import Config
from gentopia.llm import HuggingfaceLLMClient, OpenAIGPTClient
from gentopia.llm.base_llm import BaseLLM
from gentopia.llm.llm_info import TYPES
from gentopia.manager.base_llm_manager import BaseLLMManager
from gentopia.memory.api import MemoryWrapper
from gentopia.memory.api import create_memory
from gentopia.model.agent_model import AgentType
from gentopia.model.param_model import OpenAIParamModel, HuggingfaceParamModel
from gentopia.tools import *
from gentopia.tools import BaseTool
from gentopia.tools.basetool import ToolMetaclass
class AgentAssembler:
    """
    This class is responsible for assembling an agent instance from a configuration file or dictionary
    and its dependencies (LLMs, prompt templates, plugins and memory).

    :param file: A path to a configuration file.
    :type file: str, optional
    :param config: A configuration dictionary.
    :type config: dict, optional
    """

    def __init__(self, file=None, config=None):
        """
        Constructor method.

        Initializes an instance of the AgentAssembler class. ``file`` takes precedence over ``config``
        when both are given. When neither is given, ``self.config`` is ``None`` and a configuration
        must be passed explicitly to :meth:`get_agent`.

        :param file: A path to a configuration file.
        :type file: str, optional
        :param config: A configuration dictionary.
        :type config: dict, optional
        """
        # Always define the attribute so get_agent() reaches its explicit None check
        # instead of failing with AttributeError.
        self.config = None
        if file is not None:
            self.config = Config.from_file(file)
        elif config is not None:
            self.config = Config.from_dict(config)
        # Plugins already built, keyed by plugin name, so a plugin referenced several
        # times is only built once (see _parse_plugins).
        self.plugins: Dict[str, Union[BaseAgent, BaseTool]] = dict()
        # Optional LLM manager; when set, Huggingface models are obtained through it.
        self.manager: Optional[BaseLLMManager] = None

    def get_agent(self, config=None):
        """
        This method returns an agent instance based on the provided configuration.

        The ``auth`` section (if any) is exported to the process environment before the
        LLM, prompt template, plugins and memory are built.

        :param config: A configuration dictionary. Defaults to the configuration given to the constructor.
        :type config: dict, optional
        :raises AssertionError: If no configuration is available.
        :return: An agent instance.
        :rtype: BaseAgent
        """
        if config is None:
            config = self.config
        assert config is not None, "No agent configuration was provided."
        self._set_auth_env(config.get('auth', {}))

        # Agent config
        name = config.get('name')
        _type = AgentType(config.get('type'))
        version = config.get('version', "")
        description = config.get('description', "")
        AgentClass = AgentType.get_agent_class(_type)
        prompt_template = self._get_prompt_template(config.get('prompt_template'))
        agent = AgentClass(
            name=name,
            type=_type,
            version=version,
            description=description,
            target_tasks=config.get('target_tasks', []),
            llm=self._get_llm(config['llm']),
            prompt_template=prompt_template,
            plugins=self._parse_plugins(config.get('plugins', [])),
            memory=self._parse_memory(config.get('memory', [])),
        )
        return agent

    def _parse_memory(self, obj) -> Optional[MemoryWrapper]:
        """
        This method parses the memory configuration and returns a memory wrapper instance.

        :param obj: A configuration dictionary containing ``memory_type``, ``conversation_threshold``,
            ``reasoning_threshold`` and optionally ``params``; or an empty/``None`` value for no memory.
        :type obj: dict
        :raises KeyError: If a required memory key is missing.
        :return: A memory wrapper instance, or ``None`` when no memory is configured.
        :rtype: Optional[MemoryWrapper]
        """
        # An absent ([] default), null or empty memory section means the agent has no memory.
        if not obj:
            return None
        memory_type = obj["memory_type"]  # e.g. "pinecone"
        # Memory-specific parameters; different memories may take different params.
        params = obj.get("params") or {}
        return create_memory(memory_type, obj['conversation_threshold'], obj['reasoning_threshold'], **params)

    def _get_llm(self, obj) -> Union[BaseLLM, Dict[str, BaseLLM]]:
        """
        This method returns a language model (LLM) instance based on the provided configuration.

        A dict without a ``model_name`` key is treated as a mapping of name -> LLM configuration and
        yields a dict of LLM instances; anything else describes a single LLM.

        :param obj: A configuration dictionary or string.
        :type obj: dict or str
        :raises AssertionError: If the configuration is not a dictionary or string.
        :raises ValueError: If the specified LLM is not supported.
        :return: An LLM instance or dictionary of LLM instances.
        :rtype: Union[BaseLLM, Dict[str, BaseLLM]]
        """
        assert isinstance(obj, (dict, str))
        if isinstance(obj, dict) and "model_name" not in obj:
            return {key: self._parse_llm(value) for key, value in obj.items()}
        return self._parse_llm(obj)

    def _parse_llm(self, obj) -> BaseLLM:
        """
        This method parses a single LLM configuration and returns an LLM instance.

        :param obj: A model name, or a dictionary with ``model_name`` and optional ``params``
            and ``device`` keys.
        :type obj: dict or str
        :raises ValueError: If the specified LLM is not supported.
        :return: An LLM instance.
        :rtype: BaseLLM
        """
        if isinstance(obj, str):
            name = obj
            model_param = dict()
            device = None
        else:
            name = obj['model_name']
            model_param = obj.get('params', dict())
            device = obj.get('device')

        llm_type = TYPES.get(name, None)
        if llm_type == "OpenAI":
            # The manager call below is Huggingface-specific (it passes cls=HuggingfaceLLMClient
            # and a device), so OpenAI clients are always returned directly.
            return OpenAIGPTClient(model_name=name, params=OpenAIParamModel(**model_param))
        if llm_type == "Huggingface":
            if device is None:
                device = 'gpu' if torch.cuda.is_available() else 'cpu'
            params = HuggingfaceParamModel(**model_param)
            if self.manager is None:
                return HuggingfaceLLMClient(model_name=name, params=params, device=device)
            # Let the manager supply the client instead of building a throwaway one here.
            return self.manager.get_llm(name, params, cls=HuggingfaceLLMClient, device=device)
        raise ValueError(f"LLM {name} is not supported currently.")

    def _get_prompt_template(self, obj):
        """
        This method returns a prompt template instance based on the provided configuration.

        A dict carrying a string ``template`` key is a single template definition; any other dict
        is a mapping of name -> template configuration and yields a dict of templates.

        :param obj: A configuration dictionary or prompt template instance.
        :type obj: dict or PromptTemplate
        :raises AssertionError: If the configuration is not a dictionary or prompt template instance.
        :return: A prompt template instance, or a dictionary of them.
        :rtype: PromptTemplate or dict
        """
        assert isinstance(obj, (dict, PromptTemplate))
        if isinstance(obj, dict) and not isinstance(obj.get('template'), str):
            return {key: self._parse_prompt_template(value) for key, value in obj.items()}
        return self._parse_prompt_template(obj)

    def _parse_prompt_template(self, obj):
        """
        This method parses the prompt template configuration and returns a prompt template instance.

        :param obj: A configuration dictionary (``input_variables``, ``template`` and optional
            ``template_format`` / ``validate_template``) or a prompt template instance.
        :type obj: dict or PromptTemplate
        :raises AssertionError: If the configuration is not a dictionary or prompt template instance.
        :return: A prompt template instance.
        :rtype: PromptTemplate
        """
        assert isinstance(obj, (dict, PromptTemplate))
        if isinstance(obj, PromptTemplate):
            return obj
        return PromptTemplate(
            input_variables=obj['input_variables'],
            template=obj['template'],
            template_format=obj.get('template_format', 'f-string'),
            validate_template=bool(obj.get('validate_template', True)),
        )

    def _parse_plugins(self, obj):
        """
        This method parses the plugin configuration and returns a list of plugin instances.

        Each entry may be a tool class, an already-built agent/tool instance (used as-is), or a
        configuration dictionary with a ``name`` key. Dictionaries whose ``type`` names an
        ``AgentType`` member are built as agents; all others are loaded as tools with their
        optional ``params``. Built plugins are cached by name in ``self.plugins``.

        :param obj: A list of plugin entries, or ``None`` for no plugins.
        :type obj: list
        :raises AssertionError: If the configuration is not a list.
        :return: A list of plugin instances.
        :rtype: list
        """
        if obj is None:
            return []
        assert isinstance(obj, list)
        result = []
        for plugin in obj:
            # Tool classes (whose metaclass is ToolMetaclass) and already-built
            # agent/tool instances are used directly.
            if isinstance(plugin, (ToolMetaclass, BaseAgent, BaseTool)):
                result.append(plugin)
                continue
            name = plugin['name']
            # Reuse a plugin that was already built under the same name.
            if name in self.plugins:
                result.append(self.plugins[name])
                continue
            if plugin.get('type', "") in AgentType.__members__:
                # Agent as plugin: build it from its own configuration.
                built = self.get_agent(plugin)
            else:
                # Tool as plugin: resolve the tool by name and instantiate it with its params.
                built = load_tools(name)(**(plugin.get('params') or {}))
            result.append(built)
            self.plugins[name] = built
        return result

    def _set_auth_env(self, obj):
        """
        This method sets environment variables for authentication.

        Entries whose value is ``None`` are skipped, leaving any existing variable untouched;
        other values are converted to ``str`` because ``os.environ`` only accepts strings.

        :param obj: A dictionary mapping environment variable names to values, or ``None``.
        :type obj: dict
        """
        if not obj:
            return
        for key in obj:
            value = obj.get(key)
            if value is None:
                continue
            os.environ[key] = str(value)