
Commit 55702b2

Merge pull request #39 from VinciGit00/fix-bug-merge
Fix bug merge
2 parents f264a27 + aeff434 commit 55702b2

13 files changed: +95 -156 lines changed
File renamed without changes.

examples/graph_examples/custom_graph_example.py (+8 -5)

@@ -6,7 +6,7 @@
 from dotenv import load_dotenv
 from scrapegraphai.models import OpenAI
 from scrapegraphai.graphs import BaseGraph
-from scrapegraphai.nodes import FetchHTMLNode, ParseHTMLNode, GenerateAnswerNode
+from scrapegraphai.nodes import FetchHTMLNode, ParseNode, RAGNode, GenerateAnswerNode

 load_dotenv()

@@ -22,26 +22,29 @@

 # define the nodes for the graph
 fetch_html_node = FetchHTMLNode("fetch_html")
-parse_document_node = ParseHTMLNode("parse_document")
+parse_document_node = ParseNode(doc_type="html", chunks_size=4000, node_name="parse_document")
+rag_node = RAGNode(model, "rag")
 generate_answer_node = GenerateAnswerNode(model, "generate_answer")

 # create the graph
 graph = BaseGraph(
     nodes={
         fetch_html_node,
         parse_document_node,
+        rag_node,
         generate_answer_node
     },
     edges={
         (fetch_html_node, parse_document_node),
-        (parse_document_node, generate_answer_node)
+        (parse_document_node, rag_node),
+        (rag_node, generate_answer_node)
     },
     entry_point=fetch_html_node
 )

 # execute the graph
-inputs = {"user_input": "Give me the news",
-          "url": "https://www.ansa.it/sito/notizie/topnews/index.shtml"}
+inputs = {"user_input": "List me the projects with their description",
+          "url": "https://perinim.github.io/projects/"}
 result = graph.execute(inputs)

 # get the answer from the result
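
The example ends by reading the answer out of the final state. Per nodes_metadata.py in this same commit, GenerateAnswerNode stores its output under the 'answer' key, so a minimal sketch of that last step would be:

# Sketch only: "answer" is the key documented in nodes_metadata.py
answer = result.get("answer", "No answer found.")
print(answer)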

examples/graph_examples/smart_scraper_example.py (+2 -2)

@@ -16,8 +16,8 @@
 }

 # Define URL and PROMPT
-URL = "https://www.google.com/search?client=safari&rls=en&q=ristoranti+trento&ie=UTF-8&oe=UTF-8"
-PROMPT = "List me all the https inside the page"
+URL = "https://www.ansa.it/veneto/"
+PROMPT = "List me all the news with their description."

 # Create the SmartScraperGraph instance
 smart_scraper_graph = SmartScraperGraph(PROMPT, URL, llm_config)
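
The example file then executes the graph; a sketch of that step, assuming the instance exposes a run() entry point as in the repo's other examples (not shown in this diff):

result = smart_scraper_graph.run()   # assumed method; returns the extracted answer
print(result)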

scrapegraphai/graphs/smart_scraper_graph.py (+6 -1)

@@ -5,6 +5,7 @@
 from .base_graph import BaseGraph
 from ..nodes import (
     FetchHTMLNode,
+    ParseNode,
     RAGNode,
     GenerateAnswerNode
 )
@@ -73,18 +74,22 @@ def _create_graph(self):
         Returns:
             BaseGraph: An instance of the BaseGraph class.
         """
+        # define the nodes for the graph
         fetch_html_node = FetchHTMLNode("fetch_html")
+        parse_document_node = ParseNode(doc_type="html", chunks_size=4000, node_name="parse_document")
         rag_node = RAGNode(self.llm, "rag")
         generate_answer_node = GenerateAnswerNode(self.llm, "generate_answer")

         return BaseGraph(
             nodes={
                 fetch_html_node,
+                parse_document_node,
                 rag_node,
                 generate_answer_node,
             },
             edges={
-                (fetch_html_node, rag_node),
+                (fetch_html_node, parse_document_node),
+                (parse_document_node, rag_node),
                 (rag_node, generate_answer_node)
             },
             entry_point=fetch_html_node
scrapegraphai/graphs/speech_summary_graph.py (+5 -1)

@@ -6,6 +6,7 @@
 from .base_graph import BaseGraph
 from ..nodes import (
     FetchHTMLNode,
+    ParseNode,
     RAGNode,
     GenerateAnswerNode,
     TextToSpeechNode,
@@ -79,6 +80,7 @@ def _create_graph(self):
             BaseGraph: An instance of the BaseGraph class.
         """
         fetch_html_node = FetchHTMLNode("fetch_html")
+        parse_document_node = ParseNode(doc_type="html", chunks_size=4000, node_name="parse_document")
         rag_node = RAGNode(self.llm, "rag")
         generate_answer_node = GenerateAnswerNode(self.llm, "generate_answer")
         text_to_speech_node = TextToSpeechNode(
@@ -87,12 +89,14 @@ def _create_graph(self):
         return BaseGraph(
             nodes={
                 fetch_html_node,
+                parse_document_node,
                 rag_node,
                 generate_answer_node,
                 text_to_speech_node
             },
             edges={
-                (fetch_html_node, rag_node),
+                (fetch_html_node, parse_document_node),
+                (parse_document_node, rag_node),
                 (rag_node, generate_answer_node),
                 (generate_answer_node, text_to_speech_node)
             },

scrapegraphai/helpers/nodes_metadata.py (+6 -6)

@@ -20,12 +20,12 @@
         },
         "returns": "Updated state with probable HTML tags under 'tags' key."
     },
-    "ParseHTMLNode": {
-        "description": "Parses HTML content to extract specific data.",
+    "ParseNode": {
+        "description": "Parses document content to extract specific data.",
         "type": "node",
         "args": {
-            "document": "HTML content as a string.",
-            "tags": "List of HTML tags to focus on during parsing."
+            "doc_type": "Type of the input document. Default is 'html'.",
+            "document": "The document content to be parsed.",
         },
         "returns": "Updated state with extracted data under 'parsed_document' key."
     },
@@ -38,7 +38,7 @@
         "type": "node",
         "args": {
             "user_input": "The user's query or question guiding the retrieval.",
-            "document": "The HTML content to be processed and compressed."
+            "document": "The document content to be processed and compressed."
         },
         "returns": """Updated state with 'relevant_chunks' key containing
                       the most relevant text chunks."""
@@ -48,7 +48,7 @@
         "type": "node",
         "args": {
             "user_input": "User's query or question.",
-            "parsed_document": "Data extracted from the HTML document."
+            "parsed_document": "Data extracted from the input document."
         },
         "returns": "Updated state with the answer under 'answer' key."
     },
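
Assuming the helper exposes this mapping as a module-level dict (the variable name below is inferred from the filename, not confirmed by the diff), the renamed entry can be inspected directly:

from scrapegraphai.helpers.nodes_metadata import nodes_metadata  # assumed name

# Shows the new args of the renamed node: doc_type and document
print(nodes_metadata["ParseNode"]["args"])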

scrapegraphai/nodes/__init__.py (+2 -3)

@@ -5,9 +5,8 @@
 from .conditional_node import ConditionalNode
 from .get_probable_tags_node import GetProbableTagsNode
 from .generate_answer_node import GenerateAnswerNode
-from .parse_html_node import ParseHTMLNode
+from .parse_node import ParseNode
 from .rag_node import RAGNode
 from .text_to_speech_node import TextToSpeechNode
 from .image_to_text_node import ImageToTextNode
-from .fetch_text_node import FetchTextNode
-from .parse_text_node import ParseTextNode
+from .fetch_text_node import FetchTextNode

scrapegraphai/nodes/fetch_html_node.py (+5 -4)

@@ -81,10 +81,11 @@ def execute(self, state: dict) -> dict:

         loader = AsyncHtmlLoader(url)
         document = loader.load()
-        metadata = document[0].metadata
-        document = remover(str(document[0]))
+        # metadata = document[0].metadata
+        # document = remover(str(document[0]))

-        state["document"] = [
-            Document(page_content=document, metadata=metadata)]
+        # state["document"] = [
+        #     Document(page_content=document, metadata=metadata)]
+        state["document"] = document

         return state
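
For context, AsyncHtmlLoader.load() in LangChain returns a list of Document objects, so the node now stores that list under "document" as-is and leaves HTML-to-text conversion to the new ParseNode. A minimal sketch of the state this node produces (the exact import path varies across LangChain versions):

from langchain.document_loaders import AsyncHtmlLoader

loader = AsyncHtmlLoader("https://example.com")
document = loader.load()          # list of Documents, raw HTML in page_content
state = {"document": document}    # what FetchHTMLNode now hands downstream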

scrapegraphai/nodes/generate_answer_node.py (+13 -20)

@@ -11,7 +11,6 @@

 # Imports from the library
 from .base_node import BaseNode
-from langchain.text_splitter import RecursiveCharacterTextSplitter


 class GenerateAnswerNode(BaseNode):
@@ -71,7 +70,7 @@ def execute(self, state: dict) -> dict:
         print("---GENERATING ANSWER---")
         try:
             user_input = state["user_input"]
-            document = state["document_chunks"]
+            document = state["document"]
         except KeyError as e:
             print(f"Error: {e} not found in state.")
             raise
@@ -111,34 +110,28 @@ def execute(self, state: dict) -> dict:
             prompt = PromptTemplate(
                 template=template_chunks,
                 input_variables=["question"],
-                partial_variables={"context": chunk,
+                partial_variables={"context": chunk.page_content,
                                    "chunk_id": i + 1, "format_instructions": format_instructions},
             )
             # Dynamically name the chains based on their index
-            chains_dict[f"chunk{i+1}"] = prompt | self.llm | output_parser
+            chain_name = f"chunk{i+1}"
+            chains_dict[chain_name] = prompt | self.llm | output_parser

-        text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
-            chunk_size=4000,
-            chunk_overlap=0,
-        )
-
-        chunks = text_splitter.split_text(str(chains_dict))
+        # Use dictionary unpacking to pass the dynamically named chains to RunnableParallel
+        map_chain = RunnableParallel(**chains_dict)
+        # Chain
+        answer_map = map_chain.invoke({"question": user_input})

+        # Merge the answers from the chunks
         merge_prompt = PromptTemplate(
             template=template_merge,
             input_variables=["context", "question"],
             partial_variables={"format_instructions": format_instructions},
         )
         merge_chain = merge_prompt | self.llm | output_parser
+        answer = merge_chain.invoke(
+            {"context": answer_map, "question": user_input})

-        answer_lines = []
-        for chunk in chunks:
-            answer_temp = merge_chain.invoke(
-                {"context": chunk, "question": user_input})
-            answer_lines.append(answer_temp)
-
-        unique_answer_lines = list(set(answer_lines))
-        answer = '\n'.join(unique_answer_lines)
-
+        # Update the state with the generated answer
         state.update({"answer": answer})
-        return state
+        return state
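
The old code text-split str(chains_dict), the string representation of the chain objects themselves, apparently the merge bug this PR targets; the rewrite fans the per-chunk chains out through RunnableParallel and merges once. A minimal standalone sketch of that map/merge pattern, with toy runnables standing in for the real prompt | self.llm | output_parser chains:

from langchain_core.runnables import RunnableLambda, RunnableParallel

# One toy "chain" per chunk, named chunk1..chunk3 as in the node
chains_dict = {
    f"chunk{i+1}": RunnableLambda(lambda x, i=i: f"partial answer {i+1} to {x['question']!r}")
    for i in range(3)
}

# Fan out: every chain receives the same input; results come back keyed by name
map_chain = RunnableParallel(**chains_dict)
answer_map = map_chain.invoke({"question": "What is on the page?"})
# {'chunk1': 'partial answer 1 ...', 'chunk2': ..., 'chunk3': ...}

# The real node then feeds answer_map as the "context" of the merge chain.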

scrapegraphai/nodes/parse_html_node.py → scrapegraphai/nodes/parse_node.py (+18 -10)

@@ -6,11 +6,11 @@
 from .base_node import BaseNode


-class ParseHTMLNode(BaseNode):
+class ParseNode(BaseNode):
     """
-    A node responsible for parsing HTML content from a document using specified tags.
+    A node responsible for parsing HTML content from a document.
     It uses BeautifulSoupTransformer for parsing, providing flexibility in extracting
-    specific parts of an HTML document based on the tags provided in the state.
+    specific parts of an HTML document.

     This node enhances the scraping workflow by allowing for targeted extraction of
     content, thereby optimizing the processing of large HTML documents.
@@ -28,14 +28,18 @@ class ParseHTMLNode(BaseNode):
     the specified tags, if provided, and updates the state with the parsed content.
     """

-    def __init__(self, node_name: str):
+    def __init__(self, doc_type: str = "html", chunks_size: int = 4000, node_name: str = "ParseHTMLNode"):
         """
         Initializes the ParseHTMLNode with a node name.
         Args:
+            doc_type (str): type of the input document
+            chunks_size (int): size of the chunks to split the document
             node_name (str): name of the node
             node_type (str, optional): type of the node
         """
         super().__init__(node_name, "node")
+        self.doc_type = doc_type
+        self.chunks_size = chunks_size

     def execute(self, state):
         """
@@ -57,23 +61,27 @@ def execute(self, state):
         information for parsing is missing.
         """

-        print("---PARSING HTML DOCUMENT---")
+        print("---PARSING DOCUMENT---")
         try:
             document = state["document"]
         except KeyError as e:
             print(f"Error: {e} not found in state.")
             raise
-
+
         text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
-            chunk_size=4000,
+            chunk_size=self.chunks_size,
             chunk_overlap=0,
         )

-        docs_transformed = Html2TextTransformer(
-        ).transform_documents(document)[0]
+        # Parse the document based on the specified doc_type
+        if self.doc_type == "html":
+            docs_transformed = Html2TextTransformer(
+            ).transform_documents(document)[0]
+        elif self.doc_type == "text":
+            docs_transformed = document

         chunks = text_splitter.split_text(docs_transformed.page_content)

-        state.update({"document_chunks": chunks})
+        state.update({"parsed_document": chunks})

         return state
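
Exercised in isolation, the renamed node needs only a state dict with a "document" entry; a minimal sketch, assuming documents is the Document list produced upstream by FetchHTMLNode:

from scrapegraphai.nodes import ParseNode

parse_node = ParseNode(doc_type="html", chunks_size=4000, node_name="parse_document")
state = parse_node.execute({"document": documents})
chunks = state["parsed_document"]   # text chunks of at most ~4000 tokens each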

scrapegraphai/nodes/parse_text_node.py (-76)

This file was deleted.
