|
1 | 1 | import json
|
2 | 2 | import logging
|
3 |
| -import re |
4 | 3 |
|
5 | 4 | from kani.ai_function import AIFunction
|
6 |
| -from kani.engines import Completion, WrapperEngine |
7 |
| -from kani.models import ChatMessage, ChatRole, FunctionCall, ToolCall |
| 5 | +from kani.models import ChatMessage, ChatRole, ToolCall |
8 | 6 | from kani.prompts import ApplyContext, PromptPipeline
|
9 | 7 |
|
10 | 8 | log = logging.getLogger(__name__)
|
@@ -186,93 +184,7 @@ def ensure_available_tools(msgs: list[ChatMessage], functions: list[AIFunction])
|
186 | 184 |
|
187 | 185 |
|
188 | 186 | # ==== function call parsing ====
|
189 |
| -# [TOOL_CALLS][{'name': 'get_current_weather', 'arguments': {'location': 'Paris, France', 'format': 'celsius'}}]</s> |
190 |
| -class MixtralFunctionCallingAdapter(WrapperEngine): |
191 |
| - """Common Mixtral-8x22B function calling parsing wrapper.""" |
| 187 | +# implemented in tool_parsers/mistral - re-exported here for backwards compatibility |
| 188 | +from kani.tool_parsers.mistral import MistralToolCallParser as MistralFunctionCallingAdapter # noqa E402 |
192 | 189 |
|
193 |
| - def __init__(self, *args, tool_call_token="[TOOL_CALLS]", eos_token="</s>", **kwargs): |
194 |
| - super().__init__(*args, **kwargs) |
195 |
| - self.tool_call_token = tool_call_token |
196 |
| - self.eos_token = eos_token |
197 |
| - |
198 |
| - def _parse_tool_calls(self, content: str) -> tuple[str, list[ToolCall]]: |
199 |
| - tool_json = re.search( |
200 |
| - rf"{re.escape(self.tool_call_token)}\s*(.+?)\s*({re.escape(self.eos_token)})?$", |
201 |
| - content, |
202 |
| - re.IGNORECASE | re.DOTALL, |
203 |
| - ) |
204 |
| - if tool_json is None: |
205 |
| - return content, [] |
206 |
| - log.debug(f"Found tool JSON while parsing: {tool_json.group(1)}") |
207 |
| - actions = json.loads(tool_json.group(1)) |
208 |
| - |
209 |
| - # translate back to kani spec |
210 |
| - tool_calls = [] |
211 |
| - for action in actions: |
212 |
| - tool_name = action["name"] |
213 |
| - tool_args = json.dumps(action["arguments"]) |
214 |
| - tool_id = action.get("id") |
215 |
| - tool_call = ToolCall.from_function_call(FunctionCall(name=tool_name, arguments=tool_args), call_id_=tool_id) |
216 |
| - tool_calls.append(tool_call) |
217 |
| - |
218 |
| - # return trimmed content and tool calls |
219 |
| - return content[: tool_json.start()], tool_calls |
220 |
| - |
221 |
| - async def predict(self, messages: list[ChatMessage], functions: list[AIFunction] | None = None, **hyperparams): |
222 |
| - hyperparams.setdefault("decode_kwargs", dict(skip_special_tokens=False)) |
223 |
| - completion = await super().predict(messages, functions, **hyperparams) |
224 |
| - |
225 |
| - # if we have tools, parse |
226 |
| - if functions: |
227 |
| - completion.message.content, completion.message.tool_calls = self._parse_tool_calls(completion.message.text) |
228 |
| - completion.message.content = completion.message.content.removesuffix(self.eos_token).strip() |
229 |
| - |
230 |
| - return completion |
231 |
| - |
232 |
| - async def stream(self, messages: list[ChatMessage], functions: list[AIFunction] | None = None, **hyperparams): |
233 |
| - content_parts = [] |
234 |
| - in_tool_call = False |
235 |
| - inner_completion = None |
236 |
| - hyperparams.setdefault("decode_kwargs", dict(skip_special_tokens=False)) |
237 |
| - |
238 |
| - # consume from the inner iterator, yielding as normal until we see a tool call or a completion |
239 |
| - async for elem in super().stream(messages, functions, **hyperparams): |
240 |
| - log.debug(f"Got stream element: {elem!r}") |
241 |
| - if isinstance(elem, str): |
242 |
| - content_parts.append(elem) |
243 |
| - # if we see the start of a tool call, stop yielding and start buffering |
244 |
| - if self.tool_call_token in elem: |
245 |
| - yield elem[: elem.index(self.tool_call_token)] |
246 |
| - in_tool_call = True |
247 |
| - # otherwise yield the string |
248 |
| - if not in_tool_call: |
249 |
| - yield elem.removesuffix(self.eos_token) |
250 |
| - else: |
251 |
| - # save the inner completion |
252 |
| - inner_completion = elem |
253 |
| - |
254 |
| - # we have consumed all the elements - construct a new completion |
255 |
| - # if we don't have a tool call we can just yield the inner completion |
256 |
| - if not in_tool_call and inner_completion: |
257 |
| - yield inner_completion |
258 |
| - # otherwise, parse tool calls from the content (preserving inner tool calls if necessary) |
259 |
| - else: |
260 |
| - content = "".join(content_parts) |
261 |
| - log.debug(f"Content before parsing tool calls: {content!r}") |
262 |
| - content, tool_calls = self._parse_tool_calls(content) |
263 |
| - if inner_completion: |
264 |
| - tool_calls = (inner_completion.message.tool_calls or []) + tool_calls |
265 |
| - prompt_tokens = inner_completion.prompt_tokens |
266 |
| - completion_tokens = inner_completion.completion_tokens |
267 |
| - else: |
268 |
| - prompt_tokens = None |
269 |
| - completion_tokens = None |
270 |
| - clean_content = content.removesuffix(self.eos_token).strip() |
271 |
| - yield Completion( |
272 |
| - ChatMessage.assistant(clean_content, tool_calls=tool_calls), |
273 |
| - prompt_tokens=prompt_tokens, |
274 |
| - completion_tokens=completion_tokens, |
275 |
| - ) |
276 |
| - |
277 |
| - |
278 |
| -MistralFunctionCallingAdapter = MixtralFunctionCallingAdapter |
| 190 | +MixtralFunctionCallingAdapter = MistralFunctionCallingAdapter |
0 commit comments