Skip to content

Commit 0c27fe5

Browse files
committed
Add support for deletions (#30)
1 parent d13151f commit 0c27fe5

File tree

6 files changed

+246
-43
lines changed

6 files changed

+246
-43
lines changed

pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ dependencies = [
1010
"jsonpatch<2.0,>=1.33",
1111
]
1212
name = "trustcall"
13-
version = "0.0.30"
13+
version = "0.0.32"
1414
description = "Tenacious & trustworthy tool calling built on LangGraph."
1515
readme = "README.md"
1616

tests/evals/test_evals.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,7 @@ def __iter__(self):
191191
"model_name",
192192
[
193193
"gpt-4o",
194-
"gpt-4o-mini",
194+
# "gpt-4o-mini",
195195
# "gpt-3.5-turbo",
196196
"claude-3-5-sonnet-20240620",
197197
# "accounts/fireworks/models/firefunction-v2",

tests/unit_tests/test_extraction.py

+58
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ class FakeExtractionModel(SimpleChatModel):
3232
backup_responses: List[AIMessage] = []
3333
i: int = 0
3434
bound_count: int = 0
35+
bound: Optional["FakeExtractionModel"] = None
3536
tools: list = []
3637

3738
def _call(
@@ -78,6 +79,7 @@ def bind_tools(self, tools: list, **kwargs: Any) -> "FakeExtractionModel": # ty
7879
backup_responses=backup_responses,
7980
tools=tools,
8081
i=self.i,
82+
bound=self,
8183
**kwargs,
8284
)
8385

@@ -651,3 +653,59 @@ class MyRecognizedSchema(BaseModel):
651653
assert len(recognized_responses) == 1
652654
recognized_item = recognized_responses[0]
653655
assert recognized_item.notes == "updated notes"
656+
657+
658+
@pytest.mark.asyncio
659+
@pytest.mark.parametrize("enable_inserts", [True, False])
660+
async def test_enable_deletes_flow(enable_inserts: bool) -> None:
661+
class MySchema(BaseModel):
662+
"""Schema for recognized docs."""
663+
664+
data: str
665+
666+
existing_docs = [
667+
("Doc1", "MySchema", {"data": "contents of doc1"}),
668+
("Doc2", "MySchema", {"data": "contents of doc2"}),
669+
]
670+
671+
remove_doc_call_id = str(uuid.uuid4())
672+
remove_message = AIMessage(
673+
content="I want to remove Doc1",
674+
tool_calls=[
675+
{
676+
"id": remove_doc_call_id,
677+
"name": "RemoveDoc", # This is recognized only if enable_deletes=True
678+
"args": {"json_doc_id": "Doc1"},
679+
}
680+
],
681+
)
682+
683+
fake_llm = FakeExtractionModel(
684+
responses=[remove_message], backup_responses=[remove_message] * 3
685+
)
686+
687+
extractor = create_extractor(
688+
llm=fake_llm,
689+
tools=[MySchema],
690+
enable_inserts=enable_inserts,
691+
enable_deletes=True,
692+
)
693+
694+
# Invoke the pipeline with some dummy "system" prompt and existing docs
695+
result = await extractor.ainvoke(
696+
{
697+
"messages": [("system", "System instructions: handle doc removal.")],
698+
"existing": existing_docs,
699+
}
700+
)
701+
702+
# The pipeline always returns final "messages" in result["messages"].
703+
# Because "RemoveDoc" isn't a recognized schema in the final output,
704+
# we won't see it among result["responses"] either way.
705+
assert len(result["messages"]) == 1
706+
final_ai_msg = result["messages"][0]
707+
assert isinstance(final_ai_msg, AIMessage)
708+
709+
assert len(final_ai_msg.tool_calls) == 1
710+
assert len(result["responses"]) == 1
711+
assert result["responses"][0].__repr_name__() == "RemoveDoc" # type: ignore

tests/unit_tests/test_strict_existing.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -308,6 +308,8 @@ def test_validate_existing_strictness(
308308
if isinstance(coerced, dict):
309309
assert all(k in tools or k == "__any__" for k in coerced)
310310
elif isinstance(coerced, list):
311-
assert all(s.schema_name in tools or s.schema_name == "__any__" for s in coerced)
311+
assert all(
312+
s.schema_name in tools or s.schema_name == "__any__" for s in coerced
313+
)
312314
elif existing_schema_policy is False:
313-
pass
315+
pass

0 commit comments

Comments
 (0)