
Commit d18b2d3 (1 parent: ea4822c)

refactor: tool parsers for HF models

23 files changed: +532 -126 lines

docs/api_reference.rst (+14)

@@ -102,3 +102,17 @@ Message Formatters
 ^^^^^^^^^^^^^^^^^^
 .. automodule:: kani.utils.message_formatters
     :members:
+
+.. _tool-parsers:
+
+Tool Parsers
+^^^^^^^^^^^^
+Tool parsers are used when you have an LLM's text output, which may contain tool calls in their raw format (e.g., JSON).
+They translate the raw text format into Kani's tool calling specification.
+
+.. autoclass:: kani.tool_parsers.BaseToolCallParser
+    :members:
+
+.. autoclass:: kani.tool_parsers.NaiveJSONToolCallParser
+
+.. autoclass:: kani.tool_parsers.MistralToolCallParser
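
For orientation, a tool parser wraps an engine whose raw text output contains tool calls; a minimal usage sketch, mirroring the Mistral setup in examples/4_engines_zoo.py below:

from kani.engines.huggingface import HuggingEngine
from kani.prompts.impl.mistral import MISTRAL_V3_PIPELINE
from kani.tool_parsers import MistralToolCallParser

# the inner engine builds the prompt and generates raw text;
# the wrapper parses the [TOOL_CALLS] JSON payload into kani's ToolCall spec
model = HuggingEngine(model_id="mistralai/Mistral-Small-Instruct-2409", prompt_pipeline=MISTRAL_V3_PIPELINE)
engine = MistralToolCallParser(model)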

docs/engines/huggingface.rst (+1)

@@ -13,6 +13,7 @@ If your language model backend is available on HuggingFace or is compatible with
 This means you can safely ignore this section of the documentation for most use cases! Just use:
 
 .. code-block:: python
+
     from kani.engines.huggingface import HuggingEngine
     engine = HuggingEngine(model_id="your-org/your-model-id")
 

docs/engines/implementing.rst (+6)

@@ -7,6 +7,7 @@ Implementing an Engine
 prompt format.
 
 .. code-block:: python
+
     from kani.engines.huggingface import HuggingEngine
     engine = HuggingEngine(model_id="your-org/your-model-id")
 
@@ -42,6 +43,11 @@ the underlying model, and kani needs to know about the extra tokens added by this
 
 Adding Function Calling
 -----------------------
+
+.. important::
+    Already have a way to build function calling prompts but just need a way to parse the outputs? Check out the list
+    of :ref:`tool-parsers`.
+
 If you're writing an engine for a model with function calling, there are a couple additional steps you need to take.
 
 Generally, to use function calling, you need to do the following:

examples/4_engines_zoo.py (+25 -30)

@@ -10,18 +10,26 @@
 
 # ==== OpenAI (GPT) ====
 from kani.engines.openai import OpenAIEngine
-engine = OpenAIEngine(api_key=os.getenv("OPENAI_API_KEY"), model="gpt-4")
+engine = OpenAIEngine(api_key=os.getenv("OPENAI_API_KEY"), model="gpt-4o-mini")
 
 # ==== Anthropic (Claude) ====
 # see https://docs.anthropic.com/claude/docs/models-overview for a list of model IDs
 from kani.engines.anthropic import AnthropicEngine
-engine = AnthropicEngine(api_key=os.getenv("ANTHROPIC_API_KEY"), model="claude-3-opus-20240229")
+engine = AnthropicEngine(api_key=os.getenv("ANTHROPIC_API_KEY"), model="claude-3-5-sonnet-latest")
 
 # ========== Hugging Face ==========
 # ---- Any Model (Chat Templates) ----
 from kani.engines.huggingface import HuggingEngine
 engine = HuggingEngine(model_id="org-id/model-id")
 
+# ---- DeepSeek R1 (Hugging Face) ----
+from kani.engines.huggingface import HuggingEngine
+from kani.tool_parsers.deepseek import DeepSeekR1ToolCallParser
+# this method is the same for all distills of R1 as well - simply replace the model ID!
+model = HuggingEngine(model_id="deepseek-ai/DeepSeek-R1")
+engine = DeepSeekR1ToolCallParser(model)
+
+
 # ---- LLaMA v3 (Hugging Face) ----
 import torch
 from kani.engines.huggingface import HuggingEngine
@@ -37,44 +45,31 @@
 # NOTE: If you're running transformers<4.40 and LLaMA 3 continues generating after the <|eot_id|> token,
 # add `eos_token_id=[128001, 128009]` or upgrade transformers
 
-# ---- LLaMA v2 (Hugging Face) ----
-from kani.engines.huggingface.llama2 import LlamaEngine
-engine = LlamaEngine(model_id="meta-llama/Llama-2-7b-chat-hf", use_auth_token=True)  # log in with huggingface-cli
-
 # ---- Mistral Small/Large (Hugging Face) ----
 from kani.engines.huggingface import HuggingEngine
-from kani.prompts.impl.mistral import MISTRAL_V3_PIPELINE, MistralFunctionCallingAdapter
+from kani.prompts.impl.mistral import MISTRAL_V3_PIPELINE
+from kani.tool_parsers.mistral import MistralToolCallParser
 # small (22B): mistralai/Mistral-Small-Instruct-2409
 # large (123B): mistralai/Mistral-Large-Instruct-2407
 model = HuggingEngine(model_id="mistralai/Mistral-Small-Instruct-2409", prompt_pipeline=MISTRAL_V3_PIPELINE)
-engine = MistralFunctionCallingAdapter(model)
-
-# ---- Mistral-7B (Hugging Face) ----
-# v0.3 (supports function calling):
-from kani.engines.huggingface import HuggingEngine
-from kani.prompts.impl.mistral import MISTRAL_V3_PIPELINE, MistralFunctionCallingAdapter
-model = HuggingEngine(model_id="mistralai/Mistral-7B-Instruct-v0.3", prompt_pipeline=MISTRAL_V3_PIPELINE)
-engine = MistralFunctionCallingAdapter(model)
-
-# v0.2:
-from kani.engines.huggingface import HuggingEngine
-from kani.prompts.impl import MISTRAL_V1_PIPELINE
-engine = HuggingEngine(model_id="mistralai/Mistral-7B-Instruct-v0.2", prompt_pipeline=MISTRAL_V1_PIPELINE)
-
-# Also use the MISTRAL_V1_PIPELINE for Mixtral-8x7B (i.e. mistralai/Mixtral-8x7B-Instruct-v0.1).
+engine = MistralToolCallParser(model)
 
 # ---- Command R (Hugging Face) ----
 from kani.engines.huggingface.cohere import CommandREngine
-engine = CommandREngine(model_id="CohereForAI/c4ai-command-r-v01")
+engine = CommandREngine(model_id="CohereForAI/c4ai-command-r-08-2024")
 
-# ---- Gemma (Hugging Face) ----
-from kani.engines.huggingface import HuggingEngine
-from kani.prompts.impl import GEMMA_PIPELINE
-engine = HuggingEngine(model_id="google/gemma-1.1-7b-it", prompt_pipeline=GEMMA_PIPELINE, use_auth_token=True)
+# --------- older models ----------
+# ---- LLaMA v2 (Hugging Face) ----
+from kani.engines.huggingface.llama2 import LlamaEngine
+engine = LlamaEngine(model_id="meta-llama/Llama-2-7b-chat-hf", use_auth_token=True)  # log in with huggingface-cli
 
-# ---- Vicuna v1.3 (Hugging Face) ----
-from kani.engines.huggingface.vicuna import VicunaEngine
-engine = VicunaEngine(model_id="lmsys/vicuna-7b-v1.3")
+# ---- Mistral-7B (Hugging Face) ----
+# v0.3 (supports function calling):
+from kani.engines.huggingface import HuggingEngine
+from kani.prompts.impl.mistral import MISTRAL_V3_PIPELINE
+from kani.tool_parsers.mistral import MistralToolCallParser
+model = HuggingEngine(model_id="mistralai/Mistral-7B-Instruct-v0.3", prompt_pipeline=MISTRAL_V3_PIPELINE)
+engine = MistralToolCallParser(model)
 
 # ========== llama.cpp ==========
 # ---- LLaMA v2 (llama.cpp) ----
2 files renamed without changes.

kani/prompts/impl/README.md (+4)

@@ -0,0 +1,4 @@
+This directory contains concrete implementations of prompting pipelines for some models. It is now deprecated - use
+HuggingEngine to automatically load chat templates for these models instead.
+
+See `tool_adapters` for tool calling adapters for popular models.

kani/prompts/impl/mistral.py (+4 -92)

@@ -1,10 +1,8 @@
 import json
 import logging
-import re
 
 from kani.ai_function import AIFunction
-from kani.engines import Completion, WrapperEngine
-from kani.models import ChatMessage, ChatRole, FunctionCall, ToolCall
+from kani.models import ChatMessage, ChatRole, ToolCall
 from kani.prompts import ApplyContext, PromptPipeline
 
 log = logging.getLogger(__name__)
@@ -186,93 +184,7 @@ def ensure_available_tools(msgs: list[ChatMessage], functions: list[AIFunction])
 
 
 # ==== function call parsing ====
-# [TOOL_CALLS][{'name': 'get_current_weather', 'arguments': {'location': 'Paris, France', 'format': 'celsius'}}]</s>
-class MixtralFunctionCallingAdapter(WrapperEngine):
-    """Common Mixtral-8x22B function calling parsing wrapper."""
+# implemented in tool_adapters/mistral - here for back-compat
+from kani.tool_parsers.mistral import MistralToolCallParser as MistralFunctionCallingAdapter  # noqa E402
 
-    def __init__(self, *args, tool_call_token="[TOOL_CALLS]", eos_token="</s>", **kwargs):
-        super().__init__(*args, **kwargs)
-        self.tool_call_token = tool_call_token
-        self.eos_token = eos_token
-
-    def _parse_tool_calls(self, content: str) -> tuple[str, list[ToolCall]]:
-        tool_json = re.search(
-            rf"{re.escape(self.tool_call_token)}\s*(.+?)\s*({re.escape(self.eos_token)})?$",
-            content,
-            re.IGNORECASE | re.DOTALL,
-        )
-        if tool_json is None:
-            return content, []
-        log.debug(f"Found tool JSON while parsing: {tool_json.group(1)}")
-        actions = json.loads(tool_json.group(1))
-
-        # translate back to kani spec
-        tool_calls = []
-        for action in actions:
-            tool_name = action["name"]
-            tool_args = json.dumps(action["arguments"])
-            tool_id = action.get("id")
-            tool_call = ToolCall.from_function_call(FunctionCall(name=tool_name, arguments=tool_args), call_id_=tool_id)
-            tool_calls.append(tool_call)
-
-        # return trimmed content and tool calls
-        return content[: tool_json.start()], tool_calls
-
-    async def predict(self, messages: list[ChatMessage], functions: list[AIFunction] | None = None, **hyperparams):
-        hyperparams.setdefault("decode_kwargs", dict(skip_special_tokens=False))
-        completion = await super().predict(messages, functions, **hyperparams)
-
-        # if we have tools, parse
-        if functions:
-            completion.message.content, completion.message.tool_calls = self._parse_tool_calls(completion.message.text)
-            completion.message.content = completion.message.content.removesuffix(self.eos_token).strip()
-
-        return completion
-
-    async def stream(self, messages: list[ChatMessage], functions: list[AIFunction] | None = None, **hyperparams):
-        content_parts = []
-        in_tool_call = False
-        inner_completion = None
-        hyperparams.setdefault("decode_kwargs", dict(skip_special_tokens=False))
-
-        # consume from the inner iterator, yielding as normal until we see a tool call or a completion
-        async for elem in super().stream(messages, functions, **hyperparams):
-            log.debug(f"Got stream element: {elem!r}")
-            if isinstance(elem, str):
-                content_parts.append(elem)
-                # if we see the start of a tool call, stop yielding and start buffering
-                if self.tool_call_token in elem:
-                    yield elem[: elem.index(self.tool_call_token)]
-                    in_tool_call = True
-                # otherwise yield the string
-                if not in_tool_call:
-                    yield elem.removesuffix(self.eos_token)
-            else:
-                # save the inner completion
-                inner_completion = elem
-
-        # we have consumed all the elements - construct a new completion
-        # if we don't have a tool call we can just yield the inner completion
-        if not in_tool_call and inner_completion:
-            yield inner_completion
-        # otherwise, parse tool calls from the content (preserving inner tool calls if necessary)
-        else:
-            content = "".join(content_parts)
-            log.debug(f"Content before parsing tool calls: {content!r}")
-            content, tool_calls = self._parse_tool_calls(content)
-            if inner_completion:
-                tool_calls = (inner_completion.message.tool_calls or []) + tool_calls
-                prompt_tokens = inner_completion.prompt_tokens
-                completion_tokens = inner_completion.completion_tokens
-            else:
-                prompt_tokens = None
-                completion_tokens = None
-            clean_content = content.removesuffix(self.eos_token).strip()
-            yield Completion(
-                ChatMessage.assistant(clean_content, tool_calls=tool_calls),
-                prompt_tokens=prompt_tokens,
-                completion_tokens=completion_tokens,
-            )
-
-
-MistralFunctionCallingAdapter = MixtralFunctionCallingAdapter
+MixtralFunctionCallingAdapter = MistralFunctionCallingAdapter
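
As a sanity check on the back-compat shim above (a sketch that only exercises the re-exported aliases shown in this diff):

from kani.prompts.impl.mistral import MistralFunctionCallingAdapter, MixtralFunctionCallingAdapter
from kani.tool_parsers.mistral import MistralToolCallParser

# both legacy names now resolve to the new parser class
assert MistralFunctionCallingAdapter is MistralToolCallParser
assert MixtralFunctionCallingAdapter is MistralToolCallParser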

kani/tool_parsers/__init__.py (+4)

@@ -0,0 +1,4 @@
+from .base import BaseToolCallParser
+from .deepseek import DeepSeekR1ToolCallParser
+from .json import NaiveJSONToolCallParser
+from .mistral import MistralToolCallParser

kani/tool_parsers/base.py (+89)

@@ -0,0 +1,89 @@
+import logging
+from abc import ABC
+
+from kani.engines import Completion, WrapperEngine
+from kani.engines.base import BaseCompletion
+from kani.models import ChatMessage, ToolCall
+
+log = logging.getLogger(__name__)
+
+
+class BaseToolCallParser(WrapperEngine, ABC):
+    """
+    Abstract base class for tool call parsers.
+
+    To implement your own tool call parser, subclass this class and:
+
+    * implement ``parse_tool_calls(self, content: str) -> tuple[str, list[ToolCall]]``
+    * pass default values of ``tool_call_start_token`` and ``tool_call_end_token`` to ``super().__init__(...)``
+
+    This class will handle calling the parser and interrupting streams when tool calls are detected.
+    """
+
+    def __init__(self, *args, tool_call_start_token: str, tool_call_end_token: str, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.tool_call_start_token = tool_call_start_token
+        self.tool_call_end_token = tool_call_end_token
+
+    def parse_tool_calls(self, content: str) -> tuple[str, list[ToolCall]]:
+        """Given the string completion of the model, return the content without tool calls and the parsed tool calls."""
+        raise NotImplementedError
+
+    async def predict(self, messages, functions=None, **hyperparams) -> BaseCompletion:
+        completion = await super().predict(messages, functions, **hyperparams)
+
+        # if we have tools, parse them
+        if functions:
+            completion.message.content, completion.message.tool_calls = self.parse_tool_calls(completion.message.text)
+
+        return completion
+
+    async def stream(self, messages, functions=None, **hyperparams):
+        content_parts = []
+        in_tool_call = False
+        inner_completion = None
+
+        # consume from the inner iterator, yielding as normal until we see a tool call or a completion
+        async for elem in super().stream(messages, functions, **hyperparams):
+            log.debug(f"Got stream element: {elem!r}")
+            if isinstance(elem, str):
+                content_parts.append(elem)
+                # if we see the start of a tool call, stop yielding and start buffering
+                if self.tool_call_start_token in elem:
+                    if len(elem) > len(self.tool_call_start_token):
+                        yield elem[: elem.index(self.tool_call_start_token)]
+                    in_tool_call = True
+                # if we see the end of a tool call, start yielding and stop buffering
+                if self.tool_call_end_token in elem:
+                    if len(elem) > len(self.tool_call_end_token):
+                        yield elem[elem.index(self.tool_call_end_token) + len(self.tool_call_end_token) :]
+                    in_tool_call = False
+                # otherwise yield the string
+                if not in_tool_call:
+                    yield elem
+            else:
+                # save the inner completion
+                inner_completion = elem
+
+        # we have consumed all the elements - construct a new completion
+        # if we don't have a tool call we can just yield the inner completion
+        if not in_tool_call and inner_completion:
+            yield inner_completion
+        # otherwise, parse tool calls from the content (preserving inner tool calls if necessary)
+        else:
+            content = "".join(content_parts)
+            log.debug(f"Content before parsing tool calls: {content!r}")
+            content, tool_calls = self.parse_tool_calls(content)
+            if inner_completion:
+                tool_calls = (inner_completion.message.tool_calls or []) + tool_calls
+                prompt_tokens = inner_completion.prompt_tokens
+                completion_tokens = inner_completion.completion_tokens
+            else:
+                prompt_tokens = None
+                completion_tokens = None
+            clean_content = content.strip()
+            yield Completion(
+                ChatMessage.assistant(clean_content, tool_calls=tool_calls),
+                prompt_tokens=prompt_tokens,
+                completion_tokens=completion_tokens,
+            )
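
To illustrate the subclassing contract described in the docstring above, here is a hypothetical parser for a model that wraps a JSON tool call in <tool_call>...</tool_call> tags; the tag strings and class name are illustrative, not part of this commit:

import json

from kani.models import FunctionCall, ToolCall
from kani.tool_parsers import BaseToolCallParser


class XMLTagToolCallParser(BaseToolCallParser):
    """Hypothetical example parser for <tool_call>{"name": ..., "arguments": {...}}</tool_call> outputs."""

    def __init__(self, *args, **kwargs):
        # pass default start/end tokens to the base class, as the docstring requires
        kwargs.setdefault("tool_call_start_token", "<tool_call>")
        kwargs.setdefault("tool_call_end_token", "</tool_call>")
        super().__init__(*args, **kwargs)

    def parse_tool_calls(self, content: str) -> tuple[str, list[ToolCall]]:
        # naive parse: take the first span between the tokens as one JSON tool call
        before, sep, rest = content.partition(self.tool_call_start_token)
        if not sep:
            return content, []  # no tool call in this completion
        payload = rest.split(self.tool_call_end_token, 1)[0]
        data = json.loads(payload)
        call = ToolCall.from_function_call(FunctionCall(name=data["name"], arguments=json.dumps(data["arguments"])))
        return before, [call]

Wrapping then works the same as the built-in parsers: engine = XMLTagToolCallParser(HuggingEngine(model_id=...)).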
