mirror of
https://github.com/VinciGit00/Scrapegraph-ai.git
synced 2026-06-28 21:01:55 +08:00
fix: correctly parse output when using structured_output
This commit is contained in:
parent
5e990719cf
commit
8e74ac55a1
@ -6,11 +6,13 @@ from typing import List, Optional
|
||||
from langchain.prompts import PromptTemplate
|
||||
from langchain_core.output_parsers import JsonOutputParser
|
||||
from langchain_core.runnables import RunnableParallel
|
||||
from langchain_core.utils.pydantic import is_basemodel_subclass
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langchain_mistralai import ChatMistralAI
|
||||
from tqdm import tqdm
|
||||
from ..utils.logging import get_logger
|
||||
from .base_node import BaseNode
|
||||
from ..prompts.generate_answer_node_csv_prompts import (TEMPLATE_CHUKS_CSV,
|
||||
TEMPLATE_NO_CHUKS_CSV, TEMPLATE_MERGE_CSV)
|
||||
from ..prompts import TEMPLATE_CHUKS_CSV, TEMPLATE_NO_CHUKS_CSV, TEMPLATE_MERGE_CSV
|
||||
|
||||
class GenerateAnswerCSVNode(BaseNode):
|
||||
"""
|
||||
@ -92,9 +94,24 @@ class GenerateAnswerCSVNode(BaseNode):
|
||||
|
||||
# Initialize the output parser
|
||||
if self.node_config.get("schema", None) is not None:
|
||||
output_parser = JsonOutputParser(pydantic_object=self.node_config["schema"])
|
||||
|
||||
if isinstance(self.llm_model, (ChatOpenAI, ChatMistralAI)):
|
||||
self.llm_model = self.llm_model.with_structured_output(
|
||||
schema = self.node_config["schema"],
|
||||
method="function_calling") # json schema works only on specific models
|
||||
|
||||
# default parser to empty lambda function
|
||||
output_parser = lambda x: x
|
||||
if is_basemodel_subclass(self.node_config["schema"]):
|
||||
output_parser = dict
|
||||
format_instructions = "NA"
|
||||
else:
|
||||
output_parser = JsonOutputParser(pydantic_object=self.node_config["schema"])
|
||||
format_instructions = output_parser.get_format_instructions()
|
||||
|
||||
else:
|
||||
output_parser = JsonOutputParser()
|
||||
format_instructions = output_parser.get_format_instructions()
|
||||
|
||||
TEMPLATE_NO_CHUKS_CSV_PROMPT = TEMPLATE_NO_CHUKS_CSV
|
||||
TEMPLATE_CHUKS_CSV_PROMPT = TEMPLATE_CHUKS_CSV
|
||||
@ -105,8 +122,6 @@ class GenerateAnswerCSVNode(BaseNode):
|
||||
TEMPLATE_CHUKS_CSV_PROMPT = self.additional_info + TEMPLATE_CHUKS_CSV
|
||||
TEMPLATE_MERGE_CSV_PROMPT = self.additional_info + TEMPLATE_MERGE_CSV
|
||||
|
||||
format_instructions = output_parser.get_format_instructions()
|
||||
|
||||
chains_dict = {}
|
||||
|
||||
if len(doc) == 1:
|
||||
|
||||
@ -1,16 +1,15 @@
|
||||
"""
|
||||
GenerateAnswerNode Module
|
||||
"""
|
||||
from sys import modules
|
||||
from typing import List, Optional
|
||||
from langchain.prompts import PromptTemplate
|
||||
from langchain_core.output_parsers import JsonOutputParser
|
||||
from langchain_core.runnables import RunnableParallel
|
||||
from langchain_core.utils.pydantic import is_basemodel_subclass
|
||||
from langchain_openai import ChatOpenAI, AzureChatOpenAI
|
||||
from langchain_mistralai import ChatMistralAI
|
||||
from langchain_community.chat_models import ChatOllama
|
||||
from tqdm import tqdm
|
||||
from ..utils.logging import get_logger
|
||||
from .base_node import BaseNode
|
||||
from ..prompts import TEMPLATE_CHUNKS, TEMPLATE_NO_CHUNKS, TEMPLATE_MERGE, TEMPLATE_CHUNKS_MD, TEMPLATE_NO_CHUNKS_MD, TEMPLATE_MERGE_MD
|
||||
|
||||
@ -91,14 +90,20 @@ class GenerateAnswerNode(BaseNode):
|
||||
if isinstance(self.llm_model, (ChatOpenAI, ChatMistralAI)):
|
||||
self.llm_model = self.llm_model.with_structured_output(
|
||||
schema = self.node_config["schema"],
|
||||
method="json_schema")
|
||||
method="function_calling") # json schema works only on specific models
|
||||
|
||||
# default parser to empty lambda function
|
||||
output_parser = lambda x: x
|
||||
if is_basemodel_subclass(self.node_config["schema"]):
|
||||
output_parser = dict
|
||||
format_instructions = "NA"
|
||||
else:
|
||||
output_parser = JsonOutputParser(pydantic_object=self.node_config["schema"])
|
||||
format_instructions = output_parser.get_format_instructions()
|
||||
|
||||
else:
|
||||
output_parser = JsonOutputParser()
|
||||
|
||||
format_instructions = output_parser.get_format_instructions()
|
||||
format_instructions = output_parser.get_format_instructions()
|
||||
|
||||
if isinstance(self.llm_model, (ChatOpenAI, AzureChatOpenAI)) and not self.script_creator or self.force and not self.script_creator or self.is_md_scraper:
|
||||
template_no_chunks_prompt = TEMPLATE_NO_CHUNKS_MD
|
||||
|
||||
@ -5,6 +5,9 @@ from typing import List, Optional
|
||||
from langchain.prompts import PromptTemplate
|
||||
from langchain_core.output_parsers import JsonOutputParser
|
||||
from langchain_core.runnables import RunnableParallel
|
||||
from langchain_core.utils.pydantic import is_basemodel_subclass
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langchain_mistralai import ChatMistralAI
|
||||
from tqdm import tqdm
|
||||
from langchain_community.chat_models import ChatOllama
|
||||
from .base_node import BaseNode
|
||||
@ -78,9 +81,25 @@ class GenerateAnswerOmniNode(BaseNode):
|
||||
|
||||
# Initialize the output parser
|
||||
if self.node_config.get("schema", None) is not None:
|
||||
output_parser = JsonOutputParser(pydantic_object=self.node_config["schema"])
|
||||
|
||||
if isinstance(self.llm_model, (ChatOpenAI, ChatMistralAI)):
|
||||
self.llm_model = self.llm_model.with_structured_output(
|
||||
schema = self.node_config["schema"],
|
||||
method="function_calling") # json schema works only on specific models
|
||||
|
||||
# default parser to empty lambda function
|
||||
output_parser = lambda x: x
|
||||
if is_basemodel_subclass(self.node_config["schema"]):
|
||||
output_parser = dict
|
||||
format_instructions = "NA"
|
||||
else:
|
||||
output_parser = JsonOutputParser(pydantic_object=self.node_config["schema"])
|
||||
format_instructions = output_parser.get_format_instructions()
|
||||
|
||||
else:
|
||||
output_parser = JsonOutputParser()
|
||||
format_instructions = output_parser.get_format_instructions()
|
||||
|
||||
TEMPLATE_NO_CHUNKS_OMNI_prompt = TEMPLATE_NO_CHUNKS_OMNI
|
||||
TEMPLATE_CHUNKS_OMNI_prompt = TEMPLATE_CHUNKS_OMNI
|
||||
TEMPLATE_MERGE_OMNI_prompt= TEMPLATE_MERGE_OMNI
|
||||
@ -90,7 +109,6 @@ class GenerateAnswerOmniNode(BaseNode):
|
||||
TEMPLATE_CHUNKS_OMNI_prompt = self.additional_info + TEMPLATE_CHUNKS_OMNI_prompt
|
||||
TEMPLATE_MERGE_OMNI_prompt = self.additional_info + TEMPLATE_MERGE_OMNI_prompt
|
||||
|
||||
format_instructions = output_parser.get_format_instructions()
|
||||
|
||||
|
||||
chains_dict = {}
|
||||
|
||||
@ -5,6 +5,9 @@ from typing import List, Optional
|
||||
from langchain.prompts import PromptTemplate
|
||||
from langchain_core.output_parsers import JsonOutputParser
|
||||
from langchain_core.runnables import RunnableParallel
|
||||
from langchain_core.utils.pydantic import is_basemodel_subclass
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langchain_mistralai import ChatMistralAI
|
||||
from tqdm import tqdm
|
||||
from langchain_community.chat_models import ChatOllama
|
||||
from ..utils.logging import get_logger
|
||||
@ -93,9 +96,25 @@ class GenerateAnswerPDFNode(BaseNode):
|
||||
|
||||
# Initialize the output parser
|
||||
if self.node_config.get("schema", None) is not None:
|
||||
output_parser = JsonOutputParser(pydantic_object=self.node_config["schema"])
|
||||
|
||||
if isinstance(self.llm_model, (ChatOpenAI, ChatMistralAI)):
|
||||
self.llm_model = self.llm_model.with_structured_output(
|
||||
schema = self.node_config["schema"],
|
||||
method="function_calling") # json schema works only on specific models
|
||||
|
||||
# default parser to empty lambda function
|
||||
output_parser = lambda x: x
|
||||
if is_basemodel_subclass(self.node_config["schema"]):
|
||||
output_parser = dict
|
||||
format_instructions = "NA"
|
||||
else:
|
||||
output_parser = JsonOutputParser(pydantic_object=self.node_config["schema"])
|
||||
format_instructions = output_parser.get_format_instructions()
|
||||
|
||||
else:
|
||||
output_parser = JsonOutputParser()
|
||||
format_instructions = output_parser.get_format_instructions()
|
||||
|
||||
TEMPLATE_NO_CHUNKS_PDF_prompt = TEMPLATE_NO_CHUNKS_PDF
|
||||
TEMPLATE_CHUNKS_PDF_prompt = TEMPLATE_CHUNKS_PDF
|
||||
TEMPLATE_MERGE_PDF_prompt = TEMPLATE_MERGE_PDF
|
||||
@ -105,8 +124,6 @@ class GenerateAnswerPDFNode(BaseNode):
|
||||
TEMPLATE_CHUNKS_PDF_prompt = self.additional_info + TEMPLATE_CHUNKS_PDF_prompt
|
||||
TEMPLATE_MERGE_PDF_prompt = self.additional_info + TEMPLATE_MERGE_PDF_prompt
|
||||
|
||||
format_instructions = output_parser.get_format_instructions()
|
||||
|
||||
if len(doc) == 1:
|
||||
prompt = PromptTemplate(
|
||||
template=TEMPLATE_NO_CHUNKS_PDF_prompt,
|
||||
|
||||
@ -4,6 +4,9 @@ MergeAnswersNode Module
|
||||
from typing import List, Optional
|
||||
from langchain.prompts import PromptTemplate
|
||||
from langchain_core.output_parsers import JsonOutputParser
|
||||
from langchain_core.utils.pydantic import is_basemodel_subclass
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langchain_mistralai import ChatMistralAI
|
||||
from ..utils.logging import get_logger
|
||||
from .base_node import BaseNode
|
||||
from ..prompts import TEMPLATE_COMBINED
|
||||
@ -68,11 +71,24 @@ class MergeAnswersNode(BaseNode):
|
||||
answers_str += f"CONTENT WEBSITE {i+1}: {answer}\n"
|
||||
|
||||
if self.node_config.get("schema", None) is not None:
|
||||
output_parser = JsonOutputParser(pydantic_object=self.node_config["schema"])
|
||||
|
||||
if isinstance(self.llm_model, (ChatOpenAI, ChatMistralAI)):
|
||||
self.llm_model = self.llm_model.with_structured_output(
|
||||
schema = self.node_config["schema"],
|
||||
method="function_calling") # json schema works only on specific models
|
||||
|
||||
# default parser to empty lambda function
|
||||
output_parser = lambda x: x
|
||||
if is_basemodel_subclass(self.node_config["schema"]):
|
||||
output_parser = dict
|
||||
format_instructions = "NA"
|
||||
else:
|
||||
output_parser = JsonOutputParser(pydantic_object=self.node_config["schema"])
|
||||
format_instructions = output_parser.get_format_instructions()
|
||||
|
||||
else:
|
||||
output_parser = JsonOutputParser()
|
||||
|
||||
format_instructions = output_parser.get_format_instructions()
|
||||
format_instructions = output_parser.get_format_instructions()
|
||||
|
||||
prompt_template = PromptTemplate(
|
||||
template=TEMPLATE_COMBINED,
|
||||
|
||||
Loading…
Reference in New Issue
Block a user