@@ -32,6 +32,7 @@ class FakeExtractionModel(SimpleChatModel):
32
32
backup_responses : List [AIMessage ] = []
33
33
i : int = 0
34
34
bound_count : int = 0
35
+ bound : Optional ["FakeExtractionModel" ] = None
35
36
tools : list = []
36
37
37
38
def _call (
@@ -78,6 +79,7 @@ def bind_tools(self, tools: list, **kwargs: Any) -> "FakeExtractionModel": # ty
78
79
backup_responses = backup_responses ,
79
80
tools = tools ,
80
81
i = self .i ,
82
+ bound = self ,
81
83
** kwargs ,
82
84
)
83
85
@@ -651,3 +653,59 @@ class MyRecognizedSchema(BaseModel):
651
653
assert len (recognized_responses ) == 1
652
654
recognized_item = recognized_responses [0 ]
653
655
assert recognized_item .notes == "updated notes"
656
+
657
+
658
+ @pytest .mark .asyncio
659
+ @pytest .mark .parametrize ("enable_inserts" , [True , False ])
660
+ async def test_enable_deletes_flow (enable_inserts : bool ) -> None :
661
+ class MySchema (BaseModel ):
662
+ """Schema for recognized docs."""
663
+
664
+ data : str
665
+
666
+ existing_docs = [
667
+ ("Doc1" , "MySchema" , {"data" : "contents of doc1" }),
668
+ ("Doc2" , "MySchema" , {"data" : "contents of doc2" }),
669
+ ]
670
+
671
+ remove_doc_call_id = str (uuid .uuid4 ())
672
+ remove_message = AIMessage (
673
+ content = "I want to remove Doc1" ,
674
+ tool_calls = [
675
+ {
676
+ "id" : remove_doc_call_id ,
677
+ "name" : "RemoveDoc" , # This is recognized only if enable_deletes=True
678
+ "args" : {"json_doc_id" : "Doc1" },
679
+ }
680
+ ],
681
+ )
682
+
683
+ fake_llm = FakeExtractionModel (
684
+ responses = [remove_message ], backup_responses = [remove_message ] * 3
685
+ )
686
+
687
+ extractor = create_extractor (
688
+ llm = fake_llm ,
689
+ tools = [MySchema ],
690
+ enable_inserts = enable_inserts ,
691
+ enable_deletes = True ,
692
+ )
693
+
694
+ # Invoke the pipeline with some dummy "system" prompt and existing docs
695
+ result = await extractor .ainvoke (
696
+ {
697
+ "messages" : [("system" , "System instructions: handle doc removal." )],
698
+ "existing" : existing_docs ,
699
+ }
700
+ )
701
+
702
+ # The pipeline always returns final "messages" in result["messages"].
703
+ # Because "RemoveDoc" isn't a recognized schema in the final output,
704
+ # we won't see it among result["responses"] either way.
705
+ assert len (result ["messages" ]) == 1
706
+ final_ai_msg = result ["messages" ][0 ]
707
+ assert isinstance (final_ai_msg , AIMessage )
708
+
709
+ assert len (final_ai_msg .tool_calls ) == 1
710
+ assert len (result ["responses" ]) == 1
711
+ assert result ["responses" ][0 ].__repr_name__ () == "RemoveDoc" # type: ignore
0 commit comments