fix: generate answer node

This commit is contained in:
Marco Vinciguerra 2024-11-16 16:30:57 +01:00
parent 02550077f1
commit 49897c4d2e

View File

@ -3,6 +3,7 @@ GenerateAnswerNode Module
"""
from typing import List, Optional
from json.decoder import JSONDecodeError
import time
from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.runnables import RunnableParallel
@ -12,14 +13,13 @@ from langchain_community.chat_models import ChatOllama
from tqdm import tqdm
from .base_node import BaseNode
from ..utils.output_parser import get_structured_output_parser, get_pydantic_output_parser
from requests.exceptions import Timeout
from langchain.callbacks.manager import CallbackManager
from langchain.callbacks import get_openai_callback
from ..prompts import (
TEMPLATE_CHUNKS, TEMPLATE_NO_CHUNKS, TEMPLATE_MERGE,
TEMPLATE_CHUNKS_MD, TEMPLATE_NO_CHUNKS_MD, TEMPLATE_MERGE_MD
)
from langchain.callbacks.manager import CallbackManager
from langchain.callbacks import get_openai_callback
from requests.exceptions import Timeout
import time
class GenerateAnswerNode(BaseNode):
"""
@ -82,11 +82,8 @@ class GenerateAnswerNode(BaseNode):
if self.node_config.get("schema", None) is not None:
if isinstance(self.llm_model, ChatOpenAI):
self.llm_model = self.llm_model.with_structured_output(
schema=self.node_config["schema"]
)
output_parser = get_structured_output_parser(self.node_config["schema"])
format_instructions = "NA"
output_parser = get_pydantic_output_parser(self.node_config["schema"])
format_instructions = output_parser.get_format_instructions()
else:
if not isinstance(self.llm_model, ChatBedrock):
output_parser = get_pydantic_output_parser(self.node_config["schema"])