diff --git a/scrapegraphai/nodes/generate_answer_csv_node.py b/scrapegraphai/nodes/generate_answer_csv_node.py index de127f47..160be9ce 100644 --- a/scrapegraphai/nodes/generate_answer_csv_node.py +++ b/scrapegraphai/nodes/generate_answer_csv_node.py @@ -3,6 +3,7 @@ Module for generating the answer node """ from typing import List, Optional +from pydantic.v1 import BaseModel as BaseModelV1 from langchain.prompts import PromptTemplate from langchain_core.output_parsers import JsonOutputParser from langchain_core.runnables import RunnableParallel @@ -12,6 +13,7 @@ from langchain_mistralai import ChatMistralAI from tqdm import tqdm from ..utils.logging import get_logger from .base_node import BaseNode +from ..utils.llm_output_parser import typed_dict_output_parser, base_model_v2_output_parser, base_model_v1_output_parser from ..prompts import TEMPLATE_CHUKS_CSV, TEMPLATE_NO_CHUKS_CSV, TEMPLATE_MERGE_CSV class GenerateAnswerCSVNode(BaseNode): @@ -97,13 +99,13 @@ class GenerateAnswerCSVNode(BaseNode): if isinstance(self.llm_model, (ChatOpenAI, ChatMistralAI)): self.llm_model = self.llm_model.with_structured_output( - schema = self.node_config["schema"], - method="function_calling") # json schema works only on specific models - - # default parser to empty lambda function - output_parser = lambda x: x + schema = self.node_config["schema"]) # json schema works only on specific models + + output_parser = typed_dict_output_parser if is_basemodel_subclass(self.node_config["schema"]): - output_parser = dict + output_parser = base_model_v2_output_parser + if issubclass(self.node_config["schema"], BaseModelV1): + output_parser = base_model_v1_output_parser format_instructions = "NA" else: output_parser = JsonOutputParser(pydantic_object=self.node_config["schema"]) diff --git a/scrapegraphai/nodes/generate_answer_node.py b/scrapegraphai/nodes/generate_answer_node.py index ae92f6c5..6e32d1b5 100644 --- a/scrapegraphai/nodes/generate_answer_node.py +++ 
b/scrapegraphai/nodes/generate_answer_node.py @@ -2,6 +2,7 @@ GenerateAnswerNode Module """ from typing import List, Optional +from pydantic.v1 import BaseModel as BaseModelV1 from langchain.prompts import PromptTemplate from langchain_core.output_parsers import JsonOutputParser from langchain_core.runnables import RunnableParallel @@ -11,6 +12,7 @@ from langchain_mistralai import ChatMistralAI from langchain_community.chat_models import ChatOllama from tqdm import tqdm from .base_node import BaseNode +from ..utils.llm_output_parser import base_model_v1_output_parser, base_model_v2_output_parser, typed_dict_output_parser from ..prompts import (TEMPLATE_CHUNKS, TEMPLATE_NO_CHUNKS, TEMPLATE_MERGE, TEMPLATE_CHUNKS_MD, TEMPLATE_NO_CHUNKS_MD, @@ -93,12 +95,12 @@ class GenerateAnswerNode(BaseNode): if isinstance(self.llm_model, (ChatOpenAI, ChatMistralAI)): self.llm_model = self.llm_model.with_structured_output( schema = self.node_config["schema"]) # json schema works only on specific models - - # default parser to empty lambda function - def output_parser(x): - return x + + output_parser = typed_dict_output_parser if is_basemodel_subclass(self.node_config["schema"]): - output_parser = dict + output_parser = base_model_v2_output_parser + if issubclass(self.node_config["schema"], BaseModelV1): + output_parser = base_model_v1_output_parser format_instructions = "NA" else: output_parser = JsonOutputParser(pydantic_object=self.node_config["schema"]) diff --git a/scrapegraphai/nodes/generate_answer_omni_node.py b/scrapegraphai/nodes/generate_answer_omni_node.py index 32dfbff6..871a26e5 100644 --- a/scrapegraphai/nodes/generate_answer_omni_node.py +++ b/scrapegraphai/nodes/generate_answer_omni_node.py @@ -2,6 +2,7 @@ GenerateAnswerNode Module """ from typing import List, Optional +from pydantic.v1 import BaseModel as BaseModelV1 from langchain.prompts import PromptTemplate from langchain_core.output_parsers import JsonOutputParser from langchain_core.runnables import 
RunnableParallel @@ -11,6 +12,7 @@ from langchain_mistralai import ChatMistralAI from tqdm import tqdm from langchain_community.chat_models import ChatOllama from .base_node import BaseNode +from ..utils.llm_output_parser import typed_dict_output_parser, base_model_v2_output_parser, base_model_v1_output_parser from ..prompts.generate_answer_node_omni_prompts import (TEMPLATE_NO_CHUNKS_OMNI, TEMPLATE_CHUNKS_OMNI, TEMPLATE_MERGE_OMNI) @@ -86,13 +88,13 @@ class GenerateAnswerOmniNode(BaseNode): if isinstance(self.llm_model, (ChatOpenAI, ChatMistralAI)): self.llm_model = self.llm_model.with_structured_output( - schema = self.node_config["schema"], - method="function_calling") # json schema works only on specific models - - # default parser to empty lambda function - output_parser = lambda x: x + schema = self.node_config["schema"]) # json schema works only on specific models + + output_parser = typed_dict_output_parser if is_basemodel_subclass(self.node_config["schema"]): - output_parser = dict + output_parser = base_model_v2_output_parser + if issubclass(self.node_config["schema"], BaseModelV1): + output_parser = base_model_v1_output_parser format_instructions = "NA" else: output_parser = JsonOutputParser(pydantic_object=self.node_config["schema"]) diff --git a/scrapegraphai/nodes/generate_answer_pdf_node.py b/scrapegraphai/nodes/generate_answer_pdf_node.py index 3f7daf73..4832d45e 100644 --- a/scrapegraphai/nodes/generate_answer_pdf_node.py +++ b/scrapegraphai/nodes/generate_answer_pdf_node.py @@ -2,6 +2,7 @@ Module for generating the answer node """ from typing import List, Optional +from pydantic.v1 import BaseModel as BaseModelV1 from langchain.prompts import PromptTemplate from langchain_core.output_parsers import JsonOutputParser from langchain_core.runnables import RunnableParallel @@ -12,6 +13,7 @@ from tqdm import tqdm from langchain_community.chat_models import ChatOllama from ..utils.logging import get_logger from .base_node import BaseNode +from 
..utils.llm_output_parser import typed_dict_output_parser, base_model_v2_output_parser, base_model_v1_output_parser from ..prompts.generate_answer_node_pdf_prompts import (TEMPLATE_CHUNKS_PDF, TEMPLATE_NO_CHUNKS_PDF, TEMPLATE_MERGE_PDF) @@ -98,12 +100,13 @@ class GenerateAnswerPDFNode(BaseNode): if isinstance(self.llm_model, (ChatOpenAI, ChatMistralAI)): self.llm_model = self.llm_model.with_structured_output( - schema = self.node_config["schema"], - method="function_calling") # json schema works only on specific models - - output_parser = lambda x: x + schema = self.node_config["schema"]) # json schema works only on specific models + + output_parser = typed_dict_output_parser if is_basemodel_subclass(self.node_config["schema"]): - output_parser = dict + output_parser = base_model_v2_output_parser + if issubclass(self.node_config["schema"], BaseModelV1): + output_parser = base_model_v1_output_parser format_instructions = "NA" else: output_parser = JsonOutputParser(pydantic_object=self.node_config["schema"]) diff --git a/scrapegraphai/nodes/merge_answers_node.py b/scrapegraphai/nodes/merge_answers_node.py index a269425f..64e1f149 100644 --- a/scrapegraphai/nodes/merge_answers_node.py +++ b/scrapegraphai/nodes/merge_answers_node.py @@ -2,6 +2,7 @@ MergeAnswersNode Module """ from typing import List, Optional +from pydantic.v1 import BaseModel as BaseModelV1 from langchain.prompts import PromptTemplate from langchain_core.output_parsers import JsonOutputParser from langchain_core.utils.pydantic import is_basemodel_subclass @@ -10,6 +11,7 @@ from langchain_mistralai import ChatMistralAI from ..utils.logging import get_logger from .base_node import BaseNode from ..prompts import TEMPLATE_COMBINED +from ..utils.llm_output_parser import base_model_v1_output_parser, base_model_v2_output_parser, typed_dict_output_parser class MergeAnswersNode(BaseNode): """ @@ -74,12 +76,13 @@ class MergeAnswersNode(BaseNode): if isinstance(self.llm_model, (ChatOpenAI, ChatMistralAI)): 
self.llm_model = self.llm_model.with_structured_output( - schema = self.node_config["schema"], - method="function_calling") # json schema works only on specific models - # default parser to empty lambda function - output_parser = lambda x: x + schema = self.node_config["schema"]) # json schema works only on specific models + + output_parser = typed_dict_output_parser if is_basemodel_subclass(self.node_config["schema"]): - output_parser = dict + output_parser = base_model_v2_output_parser + if issubclass(self.node_config["schema"], BaseModelV1): + output_parser = base_model_v1_output_parser format_instructions = "NA" else: output_parser = JsonOutputParser(pydantic_object=self.node_config["schema"]) @@ -100,7 +103,7 @@ class MergeAnswersNode(BaseNode): merge_chain = prompt_template | self.llm_model | output_parser answer = merge_chain.invoke({"user_prompt": user_prompt}) - answer["sources"] = state.get("urls") + answer["sources"] = state.get("urls", []) state.update({self.output[0]: answer}) return state diff --git a/scrapegraphai/utils/llm_output_parser.py b/scrapegraphai/utils/llm_output_parser.py new file mode 100644 index 00000000..e6ac6e2d --- /dev/null +++ b/scrapegraphai/utils/llm_output_parser.py @@ -0,0 +1,53 @@ +""" +Custom output parser for the LLM model. +""" +from pydantic import BaseModel as BaseModelV2 +from pydantic.v1 import BaseModel as BaseModelV1 + +def base_model_v1_output_parser(x: BaseModelV1) -> dict: + """ + Parse the output of an LLM when the schema is a BaseModelV1 and `with_structured_output` is used. + + Args: + x (BaseModelV1): The output from the LLM model. + + Returns: + dict: The parsed output. 
+ """ + work_dict = x.dict() + + # recursive dict parser + def recursive_dict_parser(work_dict: dict) -> dict: + dict_keys = work_dict.keys() + for key in dict_keys: + if isinstance(work_dict[key], BaseModelV1): + work_dict[key] = work_dict[key].dict() + recursive_dict_parser(work_dict[key]) + return work_dict + + return recursive_dict_parser(work_dict) + + +def base_model_v2_output_parser(x: BaseModelV2) -> dict: + """ + Parse the output of an LLM when the schema is a BaseModelv2 and `with_structured_output` is used. + + Args: + x (BaseModelV2): The output from the LLM model. + + Returns: + dict: The parsed output. + """ + return x.model_dump() + +def typed_dict_output_parser(x: dict) -> dict: + """ + Parse the output of an LLM when the schema is a TypedDict and `with_structured_output` is used. + + Args: + x (dict): The output from the LLM model. + + Returns: + dict: The parsed output. + """ + return x