
Commit 6769c0d

test: Add coverage improvement test for tests/test_generate_answer_node.py
1 parent 71053bc commit 6769c0d

1 file changed: +270 -0 lines changed

tests/test_generate_answer_node.py

@@ -0,0 +1,270 @@
import json
import pytest
from langchain.prompts import (
    PromptTemplate,
)
from langchain_community.chat_models import (
    ChatOllama,
)
from langchain_core.runnables import (
    RunnableParallel,
)
from requests.exceptions import (
    Timeout,
)
from scrapegraphai.nodes.generate_answer_node import (
    GenerateAnswerNode,
)


class DummyLLM:
    def __call__(self, *args, **kwargs):
        return "dummy response"


class DummyLogger:
    def info(self, msg):
        pass

    def error(self, msg):
        pass


@pytest.fixture
def dummy_node():
    """
    Fixture for a GenerateAnswerNode instance using DummyLLM.
    Uses a valid input keys string ("dummy_input & doc") to avoid parsing errors.
    """
    node_config = {"llm_model": DummyLLM(), "verbose": False, "timeout": 1}
    node = GenerateAnswerNode("dummy_input & doc", ["output"], node_config=node_config)
    node.logger = DummyLogger()
    node.get_input_keys = lambda state: ["dummy_input", "doc"]
    return node


def test_process_missing_content_and_user_prompt(dummy_node):
    """
    Test that process() raises a ValueError when either the content or the
    user prompt is missing.
    """
    state_missing_content = {"user_prompt": "What is the answer?"}
    with pytest.raises(ValueError) as excinfo1:
        dummy_node.process(state_missing_content)
    assert "No content found in state" in str(excinfo1.value)

    state_missing_prompt = {"content": "Some valid context content"}
    with pytest.raises(ValueError) as excinfo2:
        dummy_node.process(state_missing_prompt)
    assert "No user prompt found in state" in str(excinfo2.value)


class DummyLLMWithPipe:
    """DummyLLM that supports the pipe '|' operator.

    When used in a chain with a PromptTemplate, the pipe operator returns
    self, simulating chain composition.
    """

    def __or__(self, other):
        return self

    def __call__(self, *args, **kwargs):
        return {"content": "script single-chunk answer"}
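
# Illustrative: `DummyLLMWithPipe() | some_parser` evaluates to the stub itself
# (its __or__ returns self), so chain-building code that pipes this LLM into a
# parser gets the callable stub back in place of a real composed chain.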


@pytest.fixture
def dummy_node_with_pipe():
    """
    Fixture for a GenerateAnswerNode instance using DummyLLMWithPipe.
    Uses a valid input keys string ("dummy_input & doc") to avoid parsing errors.
    """
    node_config = {"llm_model": DummyLLMWithPipe(), "verbose": False, "timeout": 480}
    node = GenerateAnswerNode("dummy_input & doc", ["output"], node_config=node_config)
    node.logger = DummyLogger()
    node.get_input_keys = lambda state: ["dummy_input", "doc"]
    return node


def test_execute_multiple_chunks(dummy_node_with_pipe):
    """
    Test the execute() method for a scenario with multiple document chunks.
    It simulates parallel processing of chunks and then merges them.
    """
    state = {
        "dummy_input": "What is the final answer?",
        "doc": ["Chunk text 1", "Chunk text 2"],
    }

    def fake_invoke_with_timeout(chain, inputs, timeout):
        if isinstance(chain, RunnableParallel):
            return {
                "chunk1": {"content": "answer for chunk 1"},
                "chunk2": {"content": "answer for chunk 2"},
            }
        if "context" in inputs and "question" in inputs:
            return {"content": "merged final answer"}
        return {"content": "single answer"}

    dummy_node_with_pipe.invoke_with_timeout = fake_invoke_with_timeout
    output_state = dummy_node_with_pipe.execute(state)
    assert output_state["output"] == {"content": "merged final answer"}
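
# Note: the stub above tells the fan-out step (a RunnableParallel over chunks)
# apart from the merge step (inputs carrying both "context" and "question"),
# mirroring the map/reduce flow execute() is expected to build for multi-chunk docs.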


def test_execute_single_chunk(dummy_node_with_pipe):
    """
    Test the execute() method for a single document chunk.
    """
    state = {"dummy_input": "What is the answer?", "doc": ["Only one chunk text"]}

    def fake_invoke_with_timeout(chain, inputs, timeout):
        if "question" in inputs:
            return {"content": "single-chunk answer"}
        return {"content": "unexpected result"}

    dummy_node_with_pipe.invoke_with_timeout = fake_invoke_with_timeout
    output_state = dummy_node_with_pipe.execute(state)
    assert output_state["output"] == {"content": "single-chunk answer"}


def test_execute_merge_json_decode_error(dummy_node_with_pipe):
    """
    Test that execute() handles a JSONDecodeError in the merge chain properly.
    """
    state = {
        "dummy_input": "What is the final answer?",
        "doc": ["Chunk 1 text", "Chunk 2 text"],
    }

    def fake_invoke_with_timeout(chain, inputs, timeout):
        if isinstance(chain, RunnableParallel):
            return {
                "chunk1": {"content": "answer for chunk 1"},
                "chunk2": {"content": "answer for chunk 2"},
            }
        if "context" in inputs and "question" in inputs:
            raise json.JSONDecodeError("Invalid JSON", "", 0)
        return {"content": "unexpected response"}

    dummy_node_with_pipe.invoke_with_timeout = fake_invoke_with_timeout
    output_state = dummy_node_with_pipe.execute(state)
    assert "error" in output_state["output"]
    assert (
        "Invalid JSON response format during merge" in output_state["output"]["error"]
    )
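
# Note: rather than letting json.JSONDecodeError escape, execute() is expected
# to surface it as {"error": "Invalid JSON response format during merge ...", ...}.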


class DummyChain:
    """A dummy chain for simulating a chain's invoke behavior.

    Returns a successful answer in the expected format.
    """

    def invoke(self, inputs):
        return {"content": "successful answer"}


@pytest.fixture
def dummy_node_for_process():
    """
    Fixture for creating a GenerateAnswerNode instance for testing the
    process() method success case.
    """
    node_config = {"llm_model": DummyChain(), "verbose": False, "timeout": 1}
    node = GenerateAnswerNode(
        "user_prompt & content", ["output"], node_config=node_config
    )
    node.logger = DummyLogger()
    node.get_input_keys = lambda state: ["user_prompt", "content"]
    return node


def test_process_success(dummy_node_for_process):
    """
    Test that process() successfully generates an answer when both user
    prompt and content are provided.
    """
    state = {
        "user_prompt": "What is the answer?",
        "content": "This is some valid context.",
    }
    dummy_node_for_process.chain = DummyChain()
    dummy_node_for_process.invoke_with_timeout = (
        lambda chain, inputs, timeout: chain.invoke(inputs)
    )
    new_state = dummy_node_for_process.process(state)
    assert new_state["output"] == {"content": "successful answer"}


def test_execute_timeout_single_chunk(dummy_node_with_pipe):
    """
    Test that execute() properly handles a Timeout exception in the single
    chunk branch.
    """
    state = {"dummy_input": "What is the answer?", "doc": ["Only one chunk text"]}

    def fake_invoke_timeout(chain, inputs, timeout):
        raise Timeout("Simulated timeout error")

    dummy_node_with_pipe.invoke_with_timeout = fake_invoke_timeout
    output_state = dummy_node_with_pipe.execute(state)
    assert "error" in output_state["output"]
    assert "Response timeout exceeded" in output_state["output"]["error"]
    assert "Simulated timeout error" in output_state["output"]["raw_response"]
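
# Note: execute() is expected to catch requests' Timeout and turn it into an
# error payload ({"error": ..., "raw_response": ...}) instead of re-raising.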


def test_execute_script_creator_single_chunk():
    """
    Test the execute() method for the scenario when script_creator mode is
    enabled. This verifies that the non-markdown prompt templates branch is
    executed and the expected answer is generated.
    """
    node_config = {
        "llm_model": DummyLLMWithPipe(),
        "verbose": False,
        "timeout": 480,
        "script_creator": True,
        "force": False,
        "is_md_scraper": False,
        "additional_info": "TEST INFO: ",
    }
    node = GenerateAnswerNode("dummy_input & doc", ["output"], node_config=node_config)
    node.logger = DummyLogger()
    node.get_input_keys = lambda state: ["dummy_input", "doc"]
    state = {
        "dummy_input": "What is the script answer?",
        "doc": ["Only one chunk script"],
    }

    def fake_invoke_with_timeout(chain, inputs, timeout):
        if "question" in inputs:
            return {"content": "script single-chunk answer"}
        return {"content": "unexpected response"}

    node.invoke_with_timeout = fake_invoke_with_timeout
    output_state = node.execute(state)
    assert output_state["output"] == {"content": "script single-chunk answer"}


class DummyChatOllama(ChatOllama):
    """A dummy ChatOllama class to simulate ChatOllama behavior."""


class DummySchema:
    """A dummy schema class with a model_json_schema method."""

    def model_json_schema(self):
        return "dummy_schema_json"


def test_init_chat_ollama_format():
    """
    Test that the __init__ method of GenerateAnswerNode sets the format
    attribute of a ChatOllama LLM correctly.
    """
    dummy_llm = DummyChatOllama()
    node_config = {"llm_model": dummy_llm, "verbose": False, "timeout": 1}
    node = GenerateAnswerNode("dummy_input", ["output"], node_config=node_config)
    assert node.llm_model.format == "json"

    dummy_llm_with_schema = DummyChatOllama()
    node_config_with_schema = {
        "llm_model": dummy_llm_with_schema,
        "verbose": False,
        "timeout": 1,
        "schema": DummySchema(),
    }
    node2 = GenerateAnswerNode(
        "dummy_input", ["output"], node_config=node_config_with_schema
    )
    assert node2.llm_model.format == "dummy_schema_json"
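
Every test above swaps out GenerateAnswerNode.invoke_with_timeout, so its contract is worth making explicit. The following is a minimal sketch, not the node's actual implementation: it assumes the helper runs chain.invoke(inputs) under a deadline and signals expiry with requests' Timeout, which is the exception the timeout test raises and execute() is shown to convert into an error payload.

    from concurrent.futures import ThreadPoolExecutor
    from concurrent.futures import TimeoutError as FuturesTimeout

    from requests.exceptions import Timeout


    def invoke_with_timeout(chain, inputs, timeout):
        """Run chain.invoke(inputs), raising Timeout after `timeout` seconds.

        Sketch only: the worker thread is not forcibly cancelled on expiry.
        """
        with ThreadPoolExecutor(max_workers=1) as executor:
            future = executor.submit(chain.invoke, inputs)
            try:
                return future.result(timeout=timeout)
            except FuturesTimeout as exc:
                raise Timeout(f"Response timeout exceeded after {timeout}s") from exc

With scrapegraphai and its langchain dependencies installed, the suite runs standalone (no network access is needed, since every LLM and chain is stubbed), e.g. via pytest tests/test_generate_answer_node.py.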
