|
3 | 3 | """
|
4 | 4 | from typing import List, Optional
|
5 | 5 | from json.decoder import JSONDecodeError
|
| 6 | +import time |
6 | 7 | from langchain.prompts import PromptTemplate
|
7 | 8 | from langchain_core.output_parsers import JsonOutputParser
|
8 | 9 | from langchain_core.runnables import RunnableParallel
|
|
12 | 13 | from tqdm import tqdm
|
13 | 14 | from .base_node import BaseNode
|
14 | 15 | from ..utils.output_parser import get_structured_output_parser, get_pydantic_output_parser
|
| 16 | +from requests.exceptions import Timeout |
| 17 | +from langchain.callbacks.manager import CallbackManager |
| 18 | +from langchain.callbacks import get_openai_callback |
15 | 19 | from ..prompts import (
|
16 | 20 | TEMPLATE_CHUNKS, TEMPLATE_NO_CHUNKS, TEMPLATE_MERGE,
|
17 | 21 | TEMPLATE_CHUNKS_MD, TEMPLATE_NO_CHUNKS_MD, TEMPLATE_MERGE_MD
|
18 | 22 | )
|
19 |
| -from langchain.callbacks.manager import CallbackManager |
20 |
| -from langchain.callbacks import get_openai_callback |
21 |
| -from requests.exceptions import Timeout |
22 |
| -import time |
23 | 23 |
|
24 | 24 | class GenerateAnswerNode(BaseNode):
|
25 | 25 | """
|
@@ -82,11 +82,8 @@ def execute(self, state: dict) -> dict:
|
82 | 82 |
|
83 | 83 | if self.node_config.get("schema", None) is not None:
|
84 | 84 | if isinstance(self.llm_model, ChatOpenAI):
|
85 |
| - self.llm_model = self.llm_model.with_structured_output( |
86 |
| - schema=self.node_config["schema"] |
87 |
| - ) |
88 |
| - output_parser = get_structured_output_parser(self.node_config["schema"]) |
89 |
| - format_instructions = "NA" |
| 85 | + output_parser = get_pydantic_output_parser(self.node_config["schema"]) |
| 86 | + format_instructions = output_parser.get_format_instructions() |
90 | 87 | else:
|
91 | 88 | if not isinstance(self.llm_model, ChatBedrock):
|
92 | 89 | output_parser = get_pydantic_output_parser(self.node_config["schema"])
|
|
0 commit comments