|
3 | 3 | """
|
4 | 4 | from langchain_core.output_parsers import JsonOutputParser
|
5 | 5 | from langchain.prompts import PromptTemplate
|
| 6 | +from langchain_core.runnables import RunnableParallel |
6 | 7 | from .base_node import BaseNode
|
7 | 8 |
|
8 | 9 |
|
@@ -78,22 +79,48 @@ def execute(self, state: dict) -> dict:
|
78 | 79 | output_parser = JsonOutputParser()
|
79 | 80 | format_instructions = output_parser.get_format_instructions()
|
80 | 81 |
|
81 |
| - template = """You are a website scraper and you have just scraped the |
| 82 | + template_chunks = """You are a website scraper and you have just scraped the |
82 | 83 | following content from a website.
|
83 |
| - You are now asked to answer a question about the content you have scraped.\n {format_instructions} \n The content is as follows: {context} |
| 84 | + You are now asked to answer a question about the content you have scraped.\n {format_instructions} \n |
| 85 | + The website is big so I am giving you one chunk at the time to be merged later with the other chunks.\n |
| 86 | + Content of {chunk_id}: {context} |
84 | 87 | Question: {question}
|
85 | 88 | """
|
| 89 | + |
| 90 | + template_merge = """You are a website scraper and you have just scraped the |
| 91 | + following content from a website. |
| 92 | + You are now asked to answer a question about the content you have scraped.\n {format_instructions} \n |
| 93 | + You have scraped many chunks since the website is big and now you are asked to merge them into a single answer without repetitions (if there are any).\n |
| 94 | + Content to merge: {context} |
| 95 | + Question: {question} |
| 96 | + """ |
| 97 | + |
| 98 | + chains_dict = {} |
| 99 | + |
| 100 | + for i, chunk in enumerate(context): |
| 101 | + prompt = PromptTemplate( |
| 102 | + template=template_chunks, |
| 103 | + input_variables=["question"], |
| 104 | + partial_variables={"context": chunk.page_content, "chunk_id": i + 1, "format_instructions": format_instructions}, |
| 105 | + ) |
| 106 | + # Dynamically name the chains based on their index |
| 107 | + chain_name = f"chunk{i+1}" |
| 108 | + chains_dict[chain_name] = prompt | self.llm | output_parser |
86 | 109 |
|
87 |
| - schema_prompt = PromptTemplate( |
88 |
| - template=template, |
| 110 | + # Use dictionary unpacking to pass the dynamically named chains to RunnableParallel |
| 111 | + map_chain = RunnableParallel(**chains_dict) |
| 112 | + # Chain |
| 113 | + answer_map = map_chain.invoke({"question": user_input}) |
| 114 | + |
| 115 | + # Merge the answers from the chunks |
| 116 | + merge_prompt = PromptTemplate( |
| 117 | + template=template_merge, |
89 | 118 | input_variables=["context", "question"],
|
90 | 119 | partial_variables={"format_instructions": format_instructions},
|
91 | 120 | )
|
92 |
| - |
93 |
| - # Chain |
94 |
| - schema_chain = schema_prompt | self.llm | output_parser |
95 |
| - answer = schema_chain.invoke( |
96 |
| - {"context": context, "question": user_input}) |
| 121 | + merge_chain = merge_prompt | self.llm | output_parser |
| 122 | + answer = merge_chain.invoke( |
| 123 | + {"context": answer_map, "question": user_input}) |
97 | 124 |
|
98 | 125 | # Update the state with the generated answer
|
99 | 126 | state.update({"answer": answer})
|
|
0 commit comments