Merge pull request #612 from ScrapeGraphAI/598-1140+-pydantic-validationerror-with-smartscrapergraph

fix: update generate answernode
This commit is contained in:
Federico Aguzzi 2024-08-31 14:51:04 +02:00 committed by GitHub
commit 9c2aefaa34
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -80,30 +80,20 @@ class GenerateAnswerNode(BaseNode):
self.logger.info(f"--- Executing {self.node_name} Node ---")
# Interpret input keys based on the provided input expression
input_keys = self.get_input_keys(state)
# Fetching data from the state based on the input keys
input_data = [state[key] for key in input_keys]
user_prompt = input_data[0]
doc = input_data[1]
# Initialize the output parser
if self.node_config.get("schema", None) is not None:
output_parser = JsonOutputParser(pydantic_object=self.node_config["schema"])
# Use built-in structured output for providers that allow it
optional_modules = {"langchain_anthropic", "langchain_fireworks", "langchain_groq", "langchain_google_vertexai"}
if all(key in modules for key in optional_modules):
if isinstance(self.llm_model, (ChatOpenAI, ChatMistralAI, ChatAnthropic, ChatFireworks, ChatGroq, ChatVertexAI)):
self.llm_model = self.llm_model.with_structured_output(
schema = self.node_config["schema"],
method="json_schema")
if isinstance(self.llm_model, (ChatOpenAI, ChatMistralAI)):
self.llm_model = self.llm_model.with_structured_output(
schema = self.node_config["schema"],
method="json_schema")
else:
if isinstance(self.llm_model, (ChatOpenAI, ChatMistralAI)):
self.llm_model = self.llm_model.with_structured_output(
schema = self.node_config["schema"],
method="json_schema")
output_parser = JsonOutputParser(pydantic_object=self.node_config["schema"])
else:
output_parser = JsonOutputParser()