5
5
from modelscope_agent import Agent
6
6
from modelscope_agent .agent_env_util import AgentEnvMixin
7
7
from modelscope_agent .llm .base import BaseChatModel
8
+ from modelscope_agent .tools .base import BaseTool
8
9
from modelscope_agent .utils .tokenization_utils import count_tokens
9
10
from modelscope_agent .utils .utils import check_and_limit_input_length
10
11
107
108
'en' : '. you can use tools: [{tool_names}]' ,
108
109
}
109
110
111
+ SPECIAL_PREFIX_TEMPLATE_TOOL_FOR_CHAT = {
112
+ 'zh' : '。你必须使用工具中的一个或多个:[{tool_names}]' ,
113
+ 'en' : '. you must use one or more tools: [{tool_names}]' ,
114
+ }
115
+
110
116
SPECIAL_PREFIX_TEMPLATE_KNOWLEDGE = {
111
117
'zh' : '。请查看前面的知识库' ,
112
118
'en' : '. Please read the knowledge base at the beginning' ,
@@ -146,10 +152,26 @@ def _run(self,
146
152
lang : str = 'zh' ,
147
153
** kwargs ):
148
154
149
- self .tool_descs = '\n \n ' .join (tool .function_plain_text
150
- for tool in self .function_map .values ())
151
- self .tool_names = ',' .join (tool .name
152
- for tool in self .function_map .values ())
155
+ chat_mode = kwargs .get ('chat_mode' , False )
156
+ tools = kwargs .get ('tools' , None )
157
+ tool_choice = kwargs .get ('tool_choice' , 'auto' )
158
+
159
+ if tools is not None :
160
+ self .tool_descs = BaseTool .parser_function (tools )
161
+ tool_name_list = []
162
+ for tool in tools :
163
+ func_info = tool .get ('function' , {})
164
+ if func_info == {}:
165
+ continue
166
+ if 'name' in func_info :
167
+ tool_name_list .append (func_info ['name' ])
168
+ self .tool_names = ',' .join (tool_name_list )
169
+ else :
170
+ self .tool_descs = '\n \n ' .join (
171
+ tool .function_plain_text
172
+ for tool in self .function_map .values ())
173
+ self .tool_names = ',' .join (tool .name
174
+ for tool in self .function_map .values ())
153
175
154
176
self .system_prompt = ''
155
177
self .query_prefix = ''
@@ -172,7 +194,7 @@ def _run(self,
172
194
'knowledge' ] = SPECIAL_PREFIX_TEMPLATE_KNOWLEDGE [lang ]
173
195
174
196
# concat tools information
175
- if self .function_map and not self .llm .support_function_calling ():
197
+ if self .tool_descs and not self .llm .support_function_calling ():
176
198
self .system_prompt += TOOL_TEMPLATE [lang ].format (
177
199
tool_descs = self .tool_descs , tool_names = self .tool_names )
178
200
self .query_prefix_dict ['tool' ] = SPECIAL_PREFIX_TEMPLATE_TOOL [
@@ -215,10 +237,18 @@ def _run(self,
215
237
messages .extend (history )
216
238
217
239
# concat the new messages
218
- messages .append ({
219
- 'role' : 'user' ,
220
- 'content' : self .query_prefix + user_request
221
- })
240
+ if chat_mode and tool_choice == 'required' :
241
+ required_prefix = SPECIAL_PREFIX_TEMPLATE_TOOL_FOR_CHAT [
242
+ lang ].format (tool_names = self .tool_names )
243
+ messages .append ({
244
+ 'role' : 'user' ,
245
+ 'content' : required_prefix + user_request
246
+ })
247
+ else :
248
+ messages .append ({
249
+ 'role' : 'user' ,
250
+ 'content' : self .query_prefix + user_request
251
+ })
222
252
223
253
planning_prompt = ''
224
254
if self .llm .support_raw_prompt () and hasattr (self .llm ,
@@ -265,6 +295,12 @@ def _run(self,
265
295
else :
266
296
assert 'llm_result must be an instance of dict or str'
267
297
298
+ if chat_mode :
299
+ if use_tool and tool_choice != 'none' :
300
+ return f'Action: { action } \n Action Input: { action_input } \n Result: { output } '
301
+ else :
302
+ return f'Result: { output } '
303
+
268
304
# yield output
269
305
if use_tool :
270
306
if self .llm .support_function_calling ():
0 commit comments