mirror of
https://github.com/VinciGit00/Scrapegraph-ai.git
synced 2026-07-01 21:00:48 +08:00
fix: Added support for nested structure
This commit is contained in:
parent
039ba2e95a
commit
66ea166438
@ -3,6 +3,7 @@ Module for generating the answer node
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
from pydantic.v1 import BaseModel as BaseModelV1
|
||||||
from langchain.prompts import PromptTemplate
|
from langchain.prompts import PromptTemplate
|
||||||
from langchain_core.output_parsers import JsonOutputParser
|
from langchain_core.output_parsers import JsonOutputParser
|
||||||
from langchain_core.runnables import RunnableParallel
|
from langchain_core.runnables import RunnableParallel
|
||||||
@ -12,6 +13,7 @@ from langchain_mistralai import ChatMistralAI
|
|||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from ..utils.logging import get_logger
|
from ..utils.logging import get_logger
|
||||||
from .base_node import BaseNode
|
from .base_node import BaseNode
|
||||||
|
from ..utils.llm_output_parser import typed_dict_output_parser, base_model_v2_output_parser, base_model_v1_output_parser
|
||||||
from ..prompts import TEMPLATE_CHUKS_CSV, TEMPLATE_NO_CHUKS_CSV, TEMPLATE_MERGE_CSV
|
from ..prompts import TEMPLATE_CHUKS_CSV, TEMPLATE_NO_CHUKS_CSV, TEMPLATE_MERGE_CSV
|
||||||
|
|
||||||
class GenerateAnswerCSVNode(BaseNode):
|
class GenerateAnswerCSVNode(BaseNode):
|
||||||
@ -97,13 +99,13 @@ class GenerateAnswerCSVNode(BaseNode):
|
|||||||
|
|
||||||
if isinstance(self.llm_model, (ChatOpenAI, ChatMistralAI)):
|
if isinstance(self.llm_model, (ChatOpenAI, ChatMistralAI)):
|
||||||
self.llm_model = self.llm_model.with_structured_output(
|
self.llm_model = self.llm_model.with_structured_output(
|
||||||
schema = self.node_config["schema"],
|
schema = self.node_config["schema"]) # json schema works only on specific models
|
||||||
method="function_calling") # json schema works only on specific models
|
|
||||||
|
output_parser = typed_dict_output_parser
|
||||||
# default parser to empty lambda function
|
|
||||||
output_parser = lambda x: x
|
|
||||||
if is_basemodel_subclass(self.node_config["schema"]):
|
if is_basemodel_subclass(self.node_config["schema"]):
|
||||||
output_parser = dict
|
output_parser = base_model_v2_output_parser
|
||||||
|
if issubclass(self.node_config["schema"], BaseModelV1):
|
||||||
|
output_parser = base_model_v1_output_parser
|
||||||
format_instructions = "NA"
|
format_instructions = "NA"
|
||||||
else:
|
else:
|
||||||
output_parser = JsonOutputParser(pydantic_object=self.node_config["schema"])
|
output_parser = JsonOutputParser(pydantic_object=self.node_config["schema"])
|
||||||
|
|||||||
@ -2,6 +2,7 @@
|
|||||||
GenerateAnswerNode Module
|
GenerateAnswerNode Module
|
||||||
"""
|
"""
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
from pydantic.v1 import BaseModel as BaseModelV1
|
||||||
from langchain.prompts import PromptTemplate
|
from langchain.prompts import PromptTemplate
|
||||||
from langchain_core.output_parsers import JsonOutputParser
|
from langchain_core.output_parsers import JsonOutputParser
|
||||||
from langchain_core.runnables import RunnableParallel
|
from langchain_core.runnables import RunnableParallel
|
||||||
@ -11,6 +12,7 @@ from langchain_mistralai import ChatMistralAI
|
|||||||
from langchain_community.chat_models import ChatOllama
|
from langchain_community.chat_models import ChatOllama
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from .base_node import BaseNode
|
from .base_node import BaseNode
|
||||||
|
from ..utils.llm_output_parser import base_model_v1_output_parser, base_model_v2_output_parser, typed_dict_output_parser
|
||||||
from ..prompts import (TEMPLATE_CHUNKS,
|
from ..prompts import (TEMPLATE_CHUNKS,
|
||||||
TEMPLATE_NO_CHUNKS, TEMPLATE_MERGE,
|
TEMPLATE_NO_CHUNKS, TEMPLATE_MERGE,
|
||||||
TEMPLATE_CHUNKS_MD, TEMPLATE_NO_CHUNKS_MD,
|
TEMPLATE_CHUNKS_MD, TEMPLATE_NO_CHUNKS_MD,
|
||||||
@ -93,12 +95,12 @@ class GenerateAnswerNode(BaseNode):
|
|||||||
if isinstance(self.llm_model, (ChatOpenAI, ChatMistralAI)):
|
if isinstance(self.llm_model, (ChatOpenAI, ChatMistralAI)):
|
||||||
self.llm_model = self.llm_model.with_structured_output(
|
self.llm_model = self.llm_model.with_structured_output(
|
||||||
schema = self.node_config["schema"]) # json schema works only on specific models
|
schema = self.node_config["schema"]) # json schema works only on specific models
|
||||||
|
|
||||||
# default parser to empty lambda function
|
output_parser = typed_dict_output_parser
|
||||||
def output_parser(x):
|
|
||||||
return x
|
|
||||||
if is_basemodel_subclass(self.node_config["schema"]):
|
if is_basemodel_subclass(self.node_config["schema"]):
|
||||||
output_parser = dict
|
output_parser = base_model_v2_output_parser
|
||||||
|
if issubclass(self.node_config["schema"], BaseModelV1):
|
||||||
|
output_parser = base_model_v1_output_parser
|
||||||
format_instructions = "NA"
|
format_instructions = "NA"
|
||||||
else:
|
else:
|
||||||
output_parser = JsonOutputParser(pydantic_object=self.node_config["schema"])
|
output_parser = JsonOutputParser(pydantic_object=self.node_config["schema"])
|
||||||
|
|||||||
@ -2,6 +2,7 @@
|
|||||||
GenerateAnswerNode Module
|
GenerateAnswerNode Module
|
||||||
"""
|
"""
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
from pydantic.v1 import BaseModel as BaseModelV1
|
||||||
from langchain.prompts import PromptTemplate
|
from langchain.prompts import PromptTemplate
|
||||||
from langchain_core.output_parsers import JsonOutputParser
|
from langchain_core.output_parsers import JsonOutputParser
|
||||||
from langchain_core.runnables import RunnableParallel
|
from langchain_core.runnables import RunnableParallel
|
||||||
@ -11,6 +12,7 @@ from langchain_mistralai import ChatMistralAI
|
|||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from langchain_community.chat_models import ChatOllama
|
from langchain_community.chat_models import ChatOllama
|
||||||
from .base_node import BaseNode
|
from .base_node import BaseNode
|
||||||
|
from ..utils.llm_output_parser import typed_dict_output_parser, base_model_v2_output_parser, base_model_v1_output_parser
|
||||||
from ..prompts.generate_answer_node_omni_prompts import (TEMPLATE_NO_CHUNKS_OMNI,
|
from ..prompts.generate_answer_node_omni_prompts import (TEMPLATE_NO_CHUNKS_OMNI,
|
||||||
TEMPLATE_CHUNKS_OMNI,
|
TEMPLATE_CHUNKS_OMNI,
|
||||||
TEMPLATE_MERGE_OMNI)
|
TEMPLATE_MERGE_OMNI)
|
||||||
@ -86,13 +88,13 @@ class GenerateAnswerOmniNode(BaseNode):
|
|||||||
|
|
||||||
if isinstance(self.llm_model, (ChatOpenAI, ChatMistralAI)):
|
if isinstance(self.llm_model, (ChatOpenAI, ChatMistralAI)):
|
||||||
self.llm_model = self.llm_model.with_structured_output(
|
self.llm_model = self.llm_model.with_structured_output(
|
||||||
schema = self.node_config["schema"],
|
schema = self.node_config["schema"]) # json schema works only on specific models
|
||||||
method="function_calling") # json schema works only on specific models
|
|
||||||
|
output_parser = typed_dict_output_parser
|
||||||
# default parser to empty lambda function
|
|
||||||
output_parser = lambda x: x
|
|
||||||
if is_basemodel_subclass(self.node_config["schema"]):
|
if is_basemodel_subclass(self.node_config["schema"]):
|
||||||
output_parser = dict
|
output_parser = base_model_v2_output_parser
|
||||||
|
if issubclass(self.node_config["schema"], BaseModelV1):
|
||||||
|
output_parser = base_model_v1_output_parser
|
||||||
format_instructions = "NA"
|
format_instructions = "NA"
|
||||||
else:
|
else:
|
||||||
output_parser = JsonOutputParser(pydantic_object=self.node_config["schema"])
|
output_parser = JsonOutputParser(pydantic_object=self.node_config["schema"])
|
||||||
|
|||||||
@ -2,6 +2,7 @@
|
|||||||
Module for generating the answer node
|
Module for generating the answer node
|
||||||
"""
|
"""
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
from pydantic.v1 import BaseModel as BaseModelV1
|
||||||
from langchain.prompts import PromptTemplate
|
from langchain.prompts import PromptTemplate
|
||||||
from langchain_core.output_parsers import JsonOutputParser
|
from langchain_core.output_parsers import JsonOutputParser
|
||||||
from langchain_core.runnables import RunnableParallel
|
from langchain_core.runnables import RunnableParallel
|
||||||
@ -12,6 +13,7 @@ from tqdm import tqdm
|
|||||||
from langchain_community.chat_models import ChatOllama
|
from langchain_community.chat_models import ChatOllama
|
||||||
from ..utils.logging import get_logger
|
from ..utils.logging import get_logger
|
||||||
from .base_node import BaseNode
|
from .base_node import BaseNode
|
||||||
|
from ..utils.llm_output_parser import typed_dict_output_parser, base_model_v2_output_parser, base_model_v1_output_parser
|
||||||
from ..prompts.generate_answer_node_pdf_prompts import (TEMPLATE_CHUNKS_PDF,
|
from ..prompts.generate_answer_node_pdf_prompts import (TEMPLATE_CHUNKS_PDF,
|
||||||
TEMPLATE_NO_CHUNKS_PDF,
|
TEMPLATE_NO_CHUNKS_PDF,
|
||||||
TEMPLATE_MERGE_PDF)
|
TEMPLATE_MERGE_PDF)
|
||||||
@ -98,12 +100,13 @@ class GenerateAnswerPDFNode(BaseNode):
|
|||||||
|
|
||||||
if isinstance(self.llm_model, (ChatOpenAI, ChatMistralAI)):
|
if isinstance(self.llm_model, (ChatOpenAI, ChatMistralAI)):
|
||||||
self.llm_model = self.llm_model.with_structured_output(
|
self.llm_model = self.llm_model.with_structured_output(
|
||||||
schema = self.node_config["schema"],
|
schema = self.node_config["schema"]) # json schema works only on specific models
|
||||||
method="function_calling") # json schema works only on specific models
|
|
||||||
|
output_parser = typed_dict_output_parser
|
||||||
output_parser = lambda x: x
|
|
||||||
if is_basemodel_subclass(self.node_config["schema"]):
|
if is_basemodel_subclass(self.node_config["schema"]):
|
||||||
output_parser = dict
|
output_parser = base_model_v2_output_parser
|
||||||
|
if issubclass(self.node_config["schema"], BaseModelV1):
|
||||||
|
output_parser = base_model_v1_output_parser
|
||||||
format_instructions = "NA"
|
format_instructions = "NA"
|
||||||
else:
|
else:
|
||||||
output_parser = JsonOutputParser(pydantic_object=self.node_config["schema"])
|
output_parser = JsonOutputParser(pydantic_object=self.node_config["schema"])
|
||||||
|
|||||||
@ -2,6 +2,7 @@
|
|||||||
MergeAnswersNode Module
|
MergeAnswersNode Module
|
||||||
"""
|
"""
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
from pydantic.v1 import BaseModel as BaseModelV1
|
||||||
from langchain.prompts import PromptTemplate
|
from langchain.prompts import PromptTemplate
|
||||||
from langchain_core.output_parsers import JsonOutputParser
|
from langchain_core.output_parsers import JsonOutputParser
|
||||||
from langchain_core.utils.pydantic import is_basemodel_subclass
|
from langchain_core.utils.pydantic import is_basemodel_subclass
|
||||||
@ -10,6 +11,7 @@ from langchain_mistralai import ChatMistralAI
|
|||||||
from ..utils.logging import get_logger
|
from ..utils.logging import get_logger
|
||||||
from .base_node import BaseNode
|
from .base_node import BaseNode
|
||||||
from ..prompts import TEMPLATE_COMBINED
|
from ..prompts import TEMPLATE_COMBINED
|
||||||
|
from ..utils.llm_output_parser import base_model_v1_output_parser, base_model_v2_output_parser, typed_dict_output_parser
|
||||||
|
|
||||||
class MergeAnswersNode(BaseNode):
|
class MergeAnswersNode(BaseNode):
|
||||||
"""
|
"""
|
||||||
@ -74,12 +76,13 @@ class MergeAnswersNode(BaseNode):
|
|||||||
|
|
||||||
if isinstance(self.llm_model, (ChatOpenAI, ChatMistralAI)):
|
if isinstance(self.llm_model, (ChatOpenAI, ChatMistralAI)):
|
||||||
self.llm_model = self.llm_model.with_structured_output(
|
self.llm_model = self.llm_model.with_structured_output(
|
||||||
schema = self.node_config["schema"],
|
schema = self.node_config["schema"]) # json schema works only on specific models
|
||||||
method="function_calling") # json schema works only on specific models
|
|
||||||
# default parser to empty lambda function
|
output_parser = typed_dict_output_parser
|
||||||
output_parser = lambda x: x
|
|
||||||
if is_basemodel_subclass(self.node_config["schema"]):
|
if is_basemodel_subclass(self.node_config["schema"]):
|
||||||
output_parser = dict
|
output_parser = base_model_v2_output_parser
|
||||||
|
if issubclass(self.node_config["schema"], BaseModelV1):
|
||||||
|
output_parser = base_model_v1_output_parser
|
||||||
format_instructions = "NA"
|
format_instructions = "NA"
|
||||||
else:
|
else:
|
||||||
output_parser = JsonOutputParser(pydantic_object=self.node_config["schema"])
|
output_parser = JsonOutputParser(pydantic_object=self.node_config["schema"])
|
||||||
@ -100,7 +103,7 @@ class MergeAnswersNode(BaseNode):
|
|||||||
|
|
||||||
merge_chain = prompt_template | self.llm_model | output_parser
|
merge_chain = prompt_template | self.llm_model | output_parser
|
||||||
answer = merge_chain.invoke({"user_prompt": user_prompt})
|
answer = merge_chain.invoke({"user_prompt": user_prompt})
|
||||||
answer["sources"] = state.get("urls")
|
answer["sources"] = state.get("urls", [])
|
||||||
|
|
||||||
state.update({self.output[0]: answer})
|
state.update({self.output[0]: answer})
|
||||||
return state
|
return state
|
||||||
|
|||||||
53
scrapegraphai/utils/llm_output_parser.py
Normal file
53
scrapegraphai/utils/llm_output_parser.py
Normal file
@ -0,0 +1,53 @@
|
|||||||
|
"""
|
||||||
|
Custom output parser for the LLM model.
|
||||||
|
"""
|
||||||
|
from pydantic import BaseModel as BaseModelV2
|
||||||
|
from pydantic.v1 import BaseModel as BaseModelV1
|
||||||
|
|
||||||
|
def base_model_v1_output_parser(x: BaseModelV1) -> dict:
    """
    Parse the output of an LLM when the schema is a pydantic v1 BaseModel
    and `with_structured_output` is used.

    Args:
        x (BaseModelV1): The structured output returned by the LLM model.

    Returns:
        dict: The parsed output, i.e. the result of ``x.dict()``.
    """
    # pydantic v1's ``.dict()`` already converts nested sub-models (including
    # those held inside lists and dicts) to plain dictionaries, so a manual
    # recursive pass over the result would never find a model left to convert.
    return x.dict()
|
||||||
|
|
||||||
|
|
||||||
|
def base_model_v2_output_parser(x: BaseModelV2) -> dict:
    """
    Convert a pydantic v2 model returned by an LLM configured through
    `with_structured_output` into a plain dictionary.

    Args:
        x (BaseModelV2): The structured output returned by the LLM model.

    Returns:
        dict: The result of calling ``model_dump()`` on ``x``.
    """
    dumped = x.model_dump()
    return dumped
|
||||||
|
|
||||||
|
def typed_dict_output_parser(x: dict) -> dict:
    """
    Pass-through parser used when the schema is a TypedDict and
    `with_structured_output` is used.

    Args:
        x (dict): The output from the LLM model.

    Returns:
        dict: The very same object that was passed in, unmodified.
    """
    # The output is returned as-is: no copy and no conversion is performed.
    parsed = x
    return parsed
|
||||||
Loading…
Reference in New Issue
Block a user