Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Preserve Non-AI Messages in Message Operations #27

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 113 additions & 1 deletion tests/unit_tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
from pydantic import ValidationError
from typing_extensions import Annotated, TypedDict

from trustcall._base import _convert_any_typed_dicts_to_pydantic
from trustcall._base import _convert_any_typed_dicts_to_pydantic, _apply_message_ops, MessageOp
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage


def test_convert_any_typed_dicts_to_pydantic():
Expand Down Expand Up @@ -89,3 +90,114 @@ class RecursiveType(TypedDict):
cyclic["next"] = cyclic
with pytest.raises(ValueError): # or RecursionError, depending on implementation
model(**cyclic)


def test_message_ops_update_tool_name():
"""Test various scenarios for updating tool names in messages."""

# Test case 1: Mixed message types
messages = [
SystemMessage(content="system message"),
HumanMessage(content="user message"),
AIMessage(
content="",
tool_calls=[{
"id": "tool1",
"name": "old_name",
"args": {"arg1": "value1"}
}]
),
ToolMessage(
content="tool response",
tool_call_id="tool1",
name="old_name"
)
]

message_ops = [
MessageOp(
op="update_tool_name",
target={
"id": "tool1",
"name": "new_name"
}
)
]

result = _apply_message_ops(messages, message_ops)

# Verify message count and types
assert len(result) == 4, "All messages should be preserved"
assert isinstance(result[0], SystemMessage)
assert isinstance(result[1], HumanMessage)
assert isinstance(result[2], AIMessage)
assert isinstance(result[3], ToolMessage)

# Verify content preservation
assert result[0].content == "system message"
assert result[1].content == "user message"
assert result[2].tool_calls[0]["name"] == "new_name"
assert result[3].content == "tool response"

# Test case 2: Multiple tool calls in single AIMessage
messages = [
AIMessage(
content="",
tool_calls=[
{
"id": "tool1",
"name": "old_name1",
"args": {"arg1": "value1"}
},
{
"id": "tool2",
"name": "old_name2",
"args": {"arg2": "value2"}
}
]
)
]

message_ops = [
MessageOp(
op="update_tool_name",
target={
"id": "tool1",
"name": "new_name1"
}
)
]

result = _apply_message_ops(messages, message_ops)

# Verify selective update
assert len(result) == 1
assert result[0].tool_calls[0]["name"] == "new_name1" # Updated
assert result[0].tool_calls[1]["name"] == "old_name2" # Unchanged

# Test case 3: No matching tool_id
messages = [
AIMessage(
content="",
tool_calls=[{
"id": "tool1",
"name": "old_name",
"args": {"arg1": "value1"}
}]
)
]

message_ops = [
MessageOp(
op="update_tool_name",
target={
"id": "non_existent_tool",
"name": "new_name"
}
)
]

result = _apply_message_ops(messages, message_ops)

# Verify no changes for non-matching tool_id
assert result[0].tool_calls[0]["name"] == "old_name"
2 changes: 2 additions & 0 deletions trustcall/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1205,6 +1205,8 @@ def _apply_message_ops(
m = m.copy()
m.tool_calls = new
messages_.append(m)
else:
messages_.append(m)
messages = messages_

else:
Expand Down