
refactor: ChatAgent #1424

Merged
merged 39 commits into master from refactor/chatagent on Feb 17, 2025
Changes from 33 commits

Commits (39)
e36211f
[WIP] refactor tools in ChatAgent
WHALEEYE Dec 27, 2024
a533339
openai refactored
WHALEEYE Dec 28, 2024
c0e6a01
move functions into utils.py
WHALEEYE Dec 29, 2024
e04368d
Merge branch 'master' into refactor/chatagent
WHALEEYE Dec 29, 2024
47f2c40
small fix
WHALEEYE Dec 30, 2024
22ce90a
merge master
WHALEEYE Jan 1, 2025
80f6587
openai worked
WHALEEYE Jan 1, 2025
7544b5c
enable async in some modelbackend
liuxukun2000 Jan 3, 2025
ffe13c9
tool refactored
WHALEEYE Jan 3, 2025
2e0cf7c
add async run
liuxukun2000 Jan 6, 2025
036092b
Merge branch 'refactor/chatagent' of https://github.com/camel-ai/came…
liuxukun2000 Jan 6, 2025
b397824
add _async run in models
liuxukun2000 Jan 9, 2025
2ce8758
add async run in chat agent
liuxukun2000 Jan 9, 2025
40007c8
precommit fix
liuxukun2000 Jan 9, 2025
f6e4041
sort out codes in ChatAgent
WHALEEYE Jan 10, 2025
b9eb5be
extract functions and types
WHALEEYE Jan 10, 2025
00d1bf0
delete None values in config
WHALEEYE Jan 10, 2025
5cd8eca
merge master
WHALEEYE Jan 11, 2025
e8c8148
fix incompatible types
WHALEEYE Jan 11, 2025
520c29a
add function to qwen
WHALEEYE Jan 13, 2025
8faf857
add function to qwen
WHALEEYE Jan 13, 2025
2c96e48
add qwen support
WHALEEYE Jan 13, 2025
995bbb3
add support for mistral
WHALEEYE Jan 17, 2025
859216b
add response format support
WHALEEYE Jan 23, 2025
72ff026
add response format support
WHALEEYE Jan 23, 2025
ccbaee0
Merge branch 'master' into refactor/chatagent
WHALEEYE Jan 24, 2025
52744b9
merge master
WHALEEYE Jan 25, 2025
0c28cf1
Merge branch 'master' into refactor/chatagent
WHALEEYE Jan 25, 2025
03e494b
Update camel/models/gemini_model.py
WHALEEYE Feb 4, 2025
5667bb4
update fixes
WHALEEYE Feb 4, 2025
720cec2
Merge branch 'master' into refactor/chatagent
WHALEEYE Feb 10, 2025
03fc264
add aiml tool calling support
WHALEEYE Feb 11, 2025
255ae77
Merge branch 'master' into refactor/chatagent
Wendong-Fan Feb 14, 2025
d001a10
fix: Issues based on review comment for ChatAgent refactor (#1602)
Wendong-Fan Feb 14, 2025
c0e43c1
fix according to reviews
WHALEEYE Feb 14, 2025
567bc05
Merge branch 'master' into refactor/chatagent
Wendong-Fan Feb 16, 2025
1ab516a
update models
WHALEEYE Feb 16, 2025
08935d0
unit test and mypy fix
Wendong-Fan Feb 17, 2025
6bc08b6
Merge branch 'master' into refactor/chatagent
Wendong-Fan Feb 17, 2025
2 changes: 2 additions & 0 deletions .gitignore
@@ -431,3 +431,5 @@ benchmark/gaia/results.jsonl

# Secret files for docker
.container/.env

examples/datagen/star/outputs/
41 changes: 41 additions & 0 deletions camel/agents/_types.py
@@ -0,0 +1,41 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from typing import Any, Dict, List, Optional, Union

from openai import AsyncStream, Stream
from pydantic import BaseModel, ConfigDict

from camel.messages import BaseMessage
from camel.types import ChatCompletion


class ToolCallRequest(BaseModel):
    r"""The request for tool calling."""

    func_name: str
    args: Dict[str, Any]
    tool_call_id: str


class ModelResponse(BaseModel):
    r"""The response from the model."""

    model_config = ConfigDict(arbitrary_types_allowed=True)

    response: Union[ChatCompletion, Stream, AsyncStream]
    tool_call_request: Optional[ToolCallRequest]
    output_messages: List[BaseMessage]
    finish_reasons: List[str]
    usage_dict: Dict[str, Any]
    response_id: str
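
As a quick illustration of the new types (the values below are invented for the example, not taken from this PR):

from camel.agents._types import ToolCallRequest

# A parsed tool call, as the agent would hand it to tool execution.
request = ToolCallRequest(
    func_name="get_weather",
    args={"city": "Paris"},
    tool_call_id="call_0",
)
assert request.args["city"] == "Paris"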
230 changes: 230 additions & 0 deletions camel/agents/_utils.py
@@ -0,0 +1,230 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
import json
import logging
import re
import textwrap
import uuid
from typing import Any, Callable, Dict, List, Optional, Union

from openai.types.chat.chat_completion_message_tool_call import (
    ChatCompletionMessageToolCall,
    Function,
)

from camel.toolkits import FunctionTool
from camel.types import Choice
from camel.types.agents import ToolCallingRecord

logger = logging.getLogger(__name__)


def generate_tool_prompt(tool_schema_list: List[Dict[str, Any]]) -> str:
Member: Where is this function used?

Contributor Author: Not yet used. It is supposed to be used by models that do not natively support tool calling, but we haven't added support for them yet.

r"""Generates a tool prompt based on the provided tool schema list.

Returns:
str: A string representing the tool prompt.
"""
    tool_prompts = []

    for tool in tool_schema_list:
        tool_info = tool["function"]
        tool_name = tool_info["name"]
        tool_description = tool_info["description"]
        tool_json = json.dumps(tool_info, indent=4)

        prompt = (
            f"Use the function '{tool_name}' to '{tool_description}':\n"
            f"{tool_json}\n"
        )
        tool_prompts.append(prompt)

    tool_prompt_str = "\n".join(tool_prompts)

    final_prompt = textwrap.dedent(
        f"""\
        You have access to the following functions:

        {tool_prompt_str}

        If you choose to call a function ONLY reply in the following format with no prefix or suffix:

        <function=example_function_name>{{"example_name": "example_value"}}</function>

        Reminder:
        - Function calls MUST follow the specified format, start with <function= and end with </function>
        - Required parameters MUST be specified
        - Only call one function at a time
        - Put the entire function call reply on one line
        - If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls.
        """  # noqa: E501
    )
    return final_prompt
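
For illustration, a minimal sketch of how generate_tool_prompt could be called once it is wired in; the schema below is an invented example in the OpenAI tool-schema shape the function expects:

schema_list = [
    {
        "function": {
            "name": "get_weather",
            "description": "get the current weather for a city",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        }
    }
]
prompt = generate_tool_prompt(schema_list)
# The prompt instructs the model to reply in the
# <function=name>{...}</function> format that _parse_tool_response reads.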


def _parse_tool_response(response: str) -> Optional[Dict[str, Any]]:
    r"""Parses the tool response to extract the function name and
    arguments.

    Args:
        response (str): The response from the model containing the
            function call.

    Returns:
        Optional[Dict[str, Any]]: The parsed function name and arguments
            if found, otherwise :obj:`None`.
    """
    function_regex = r"<function=(\w+)>(.*?)</function>"
    match = re.search(function_regex, response)

    if match:
        function_name, args_string = match.groups()
        try:
            args = json.loads(args_string)
            return {"function": function_name, "arguments": args}
        except json.JSONDecodeError as error:
            logger.error(f"Error parsing function arguments: {error}")
            return None
    return None
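
A round-trip sketch for _parse_tool_response (the input strings are invented):

raw = '<function=get_weather>{"city": "Paris"}</function>'
parsed = _parse_tool_response(raw)
# parsed == {"function": "get_weather", "arguments": {"city": "Paris"}}

bad = _parse_tool_response("<function=get_weather>not json</function>")
# bad is None; the JSON decode error is logged rather than raised.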


def extract_tool_call(
Member: Why is the argument `self`, and do we use it anywhere?

Member: Not used, refer to: #1621

Contributor Author: I moved this from inside ChatAgent and forgot to remove it. I will update this in the follow-up PRs.
    self, response: Any
) -> Optional[ChatCompletionMessageToolCall]:
    r"""Extract the tool call from the model response, if present.

    Args:
        response (Any): The model's response object.

    Returns:
        Optional[ChatCompletionMessageToolCall]: The parsed tool call if
            present, otherwise None.
    """
    # Check if the response contains tool calls
    if (
        self.has_tools
        and not self.model_type.support_native_tool_calling
        and "</function>" in response.choices[0].message.content
    ):
        parsed_content = _parse_tool_response(
            response.choices[0].message.content
        )
        if parsed_content:
            return ChatCompletionMessageToolCall(
                id=str(uuid.uuid4()),
                function=Function(
                    arguments=str(parsed_content["arguments"]).replace(
                        "'", '"'
                    ),
                    name=str(parsed_content["function"]),
                ),
                type="function",
            )
    elif (
        self.has_tools
        and self.model_type.support_native_tool_calling
        and response.choices[0].message.tool_calls
    ):
        return response.choices[0].message.tool_calls[0]

    # No tool call found
    return None


def safe_model_dump(obj) -> Dict[str, Any]:
    r"""Safely dump a Pydantic model to a dictionary.

    This method attempts to use the `model_dump` method if available,
    otherwise it falls back to the `dict` method.
    """
    # Check if the `model_dump` method exists (Pydantic v2)
    if hasattr(obj, "model_dump"):
        return obj.model_dump()
    # Fallback to `dict()` method (Pydantic v1)
    elif hasattr(obj, "dict"):
        return obj.dict()
    else:
        raise TypeError("The object is not a Pydantic model")
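
A small sketch of safe_model_dump's fallback behavior; Point is a throwaway model defined only for this example:

from pydantic import BaseModel

class Point(BaseModel):
    x: int
    y: int

assert safe_model_dump(Point(x=1, y=2)) == {"x": 1, "y": 2}
# An object with neither `model_dump` nor `dict` raises TypeError.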


def convert_to_function_tool(
    tool: Union[FunctionTool, Callable],
) -> FunctionTool:
    r"""Convert a tool to a FunctionTool from Callable."""
    return tool if isinstance(tool, FunctionTool) else FunctionTool(tool)


def convert_to_schema(
    tool: Union[FunctionTool, Callable, Dict[str, Any]],
) -> Dict[str, Any]:
    r"""Convert a tool to a schema from Callable or FunctionTool."""
    if isinstance(tool, FunctionTool):
        return tool.get_openai_tool_schema()
    elif callable(tool):
        return FunctionTool(tool).get_openai_tool_schema()
    else:
        return tool
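
A sketch of the two conversion helpers; `add` is a throwaway function invented for the example:

def add(a: int, b: int) -> int:
    r"""Add two integers."""
    return a + b

tool = convert_to_function_tool(add)  # wraps the callable in a FunctionTool
schema = convert_to_schema(add)  # builds the OpenAI tool schema dict
assert convert_to_schema(schema) is schema  # dict schemas pass through as-is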


def get_info_dict(
    session_id: Optional[str],
    usage: Optional[Dict[str, int]],
    termination_reasons: List[str],
    num_tokens: int,
    tool_calls: List[ToolCallingRecord],
) -> Dict[str, Any]:
    r"""Returns a dictionary containing information about the chat session.

    Args:
        session_id (str, optional): The ID of the chat session.
        usage (Dict[str, int], optional): Information about the usage of
            the LLM.
        termination_reasons (List[str]): The reasons for the termination
            of the chat session.
        num_tokens (int): The number of tokens used in the chat session.
        tool_calls (List[ToolCallingRecord]): The list of function
            calling records, containing the information of called tools.

    Returns:
        Dict[str, Any]: The chat session information.
    """
    return {
        "id": session_id,
        "usage": usage,
        "termination_reasons": termination_reasons,
        "num_tokens": num_tokens,
        "tool_calls": tool_calls,
    }
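
A usage sketch for get_info_dict (all values invented for illustration):

info = get_info_dict(
    session_id="session-1",
    usage={"prompt_tokens": 12, "completion_tokens": 30},
    termination_reasons=["stop"],
    num_tokens=42,
    tool_calls=[],
)
assert info["id"] == "session-1"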


def handle_logprobs(choice: Choice) -> Optional[List[Dict[str, Any]]]:
    r"""Flattens the log probabilities of a choice into a list of
    per-token dictionaries, or returns :obj:`None` if logprobs are absent.
    """
    if choice.logprobs is None:
        return None

    tokens_logprobs = choice.logprobs.content

    if tokens_logprobs is None:
        return None

    return [
        {
            "token": token_logprob.token,
            "logprob": token_logprob.logprob,
            "top_logprobs": [
                (top_logprob.token, top_logprob.logprob)
                for top_logprob in token_logprob.top_logprobs
            ],
        }
        for token_logprob in tokens_logprobs
    ]
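
For reference, the shape handle_logprobs produces when logprobs are present (values invented for illustration):

# [
#     {
#         "token": "Hello",
#         "logprob": -0.01,
#         "top_logprobs": [("Hello", -0.01), ("Hi", -4.6)],
#     },
#     ...
# ]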