From 28b85a3b16e0f07fce41b0ed27f8e337a5537c3c Mon Sep 17 00:00:00 2001 From: Lorenzo Paleari <100212108+LorenzoPaleari@users.noreply.github.com> Date: Tue, 17 Sep 2024 23:07:16 +0200 Subject: [PATCH] refactor: Output parser code --- .../nodes/generate_answer_csv_node.py | 13 +-- scrapegraphai/nodes/generate_answer_node.py | 14 +-- .../nodes/generate_answer_omni_node.py | 12 +-- .../nodes/generate_answer_pdf_node.py | 13 +-- scrapegraphai/nodes/merge_answers_node.py | 13 +-- scrapegraphai/utils/llm_output_parser.py | 53 ------------ scrapegraphai/utils/output_parser.py | 85 +++++++++++++++++++ 7 files changed, 101 insertions(+), 102 deletions(-) delete mode 100644 scrapegraphai/utils/llm_output_parser.py create mode 100644 scrapegraphai/utils/output_parser.py diff --git a/scrapegraphai/nodes/generate_answer_csv_node.py b/scrapegraphai/nodes/generate_answer_csv_node.py index 160be9ce..85593cfa 100644 --- a/scrapegraphai/nodes/generate_answer_csv_node.py +++ b/scrapegraphai/nodes/generate_answer_csv_node.py @@ -3,17 +3,14 @@ Module for generating the answer node """ from typing import List, Optional -from pydantic.v1 import BaseModel as BaseModelV1 from langchain.prompts import PromptTemplate from langchain_core.output_parsers import JsonOutputParser from langchain_core.runnables import RunnableParallel -from langchain_core.utils.pydantic import is_basemodel_subclass from langchain_openai import ChatOpenAI from langchain_mistralai import ChatMistralAI from tqdm import tqdm -from ..utils.logging import get_logger from .base_node import BaseNode -from ..utils.llm_output_parser import typed_dict_output_parser, base_model_v2_output_parser, base_model_v1_output_parser +from ..utils.output_parser import get_structured_output_parser, get_pydantic_output_parser from ..prompts import TEMPLATE_CHUKS_CSV, TEMPLATE_NO_CHUKS_CSV, TEMPLATE_MERGE_CSV class GenerateAnswerCSVNode(BaseNode): @@ -101,14 +98,10 @@ class GenerateAnswerCSVNode(BaseNode): self.llm_model = self.llm_model.with_structured_output( schema = self.node_config["schema"]) # json schema works only on specific models - output_parser = typed_dict_output_parser - if is_basemodel_subclass(self.node_config["schema"]): - output_parser = base_model_v2_output_parser - if issubclass(self.node_config["schema"], BaseModelV1): - output_parser = base_model_v1_output_parser + output_parser = get_structured_output_parser(self.node_config["schema"]) format_instructions = "NA" else: - output_parser = JsonOutputParser(pydantic_object=self.node_config["schema"]) + output_parser = get_pydantic_output_parser(self.node_config["schema"]) format_instructions = output_parser.get_format_instructions() else: diff --git a/scrapegraphai/nodes/generate_answer_node.py b/scrapegraphai/nodes/generate_answer_node.py index 6e32d1b5..b0c102e1 100644 --- a/scrapegraphai/nodes/generate_answer_node.py +++ b/scrapegraphai/nodes/generate_answer_node.py @@ -2,17 +2,15 @@ GenerateAnswerNode Module """ from typing import List, Optional -from pydantic.v1 import BaseModel as BaseModelV1 from langchain.prompts import PromptTemplate from langchain_core.output_parsers import JsonOutputParser from langchain_core.runnables import RunnableParallel -from langchain_core.utils.pydantic import is_basemodel_subclass from langchain_openai import ChatOpenAI, AzureChatOpenAI from langchain_mistralai import ChatMistralAI from langchain_community.chat_models import ChatOllama from tqdm import tqdm from .base_node import BaseNode -from ..utils.llm_output_parser import base_model_v1_output_parser, base_model_v2_output_parser, typed_dict_output_parser +from ..utils.output_parser import get_structured_output_parser, get_pydantic_output_parser from ..prompts import (TEMPLATE_CHUNKS, TEMPLATE_NO_CHUNKS, TEMPLATE_MERGE, TEMPLATE_CHUNKS_MD, TEMPLATE_NO_CHUNKS_MD, @@ -95,15 +93,11 @@ class GenerateAnswerNode(BaseNode): if isinstance(self.llm_model, (ChatOpenAI, ChatMistralAI)): self.llm_model = self.llm_model.with_structured_output( schema = self.node_config["schema"]) # json schema works only on specific models - - output_parser = typed_dict_output_parser - if is_basemodel_subclass(self.node_config["schema"]): - output_parser = base_model_v2_output_parser - if issubclass(self.node_config["schema"], BaseModelV1): - output_parser = base_model_v1_output_parser + + output_parser = get_structured_output_parser(self.node_config["schema"]) format_instructions = "NA" else: - output_parser = JsonOutputParser(pydantic_object=self.node_config["schema"]) + output_parser = get_pydantic_output_parser(self.node_config["schema"]) format_instructions = output_parser.get_format_instructions() else: diff --git a/scrapegraphai/nodes/generate_answer_omni_node.py b/scrapegraphai/nodes/generate_answer_omni_node.py index 871a26e5..2824a573 100644 --- a/scrapegraphai/nodes/generate_answer_omni_node.py +++ b/scrapegraphai/nodes/generate_answer_omni_node.py @@ -2,17 +2,15 @@ GenerateAnswerNode Module """ from typing import List, Optional -from pydantic.v1 import BaseModel as BaseModelV1 from langchain.prompts import PromptTemplate from langchain_core.output_parsers import JsonOutputParser from langchain_core.runnables import RunnableParallel -from langchain_core.utils.pydantic import is_basemodel_subclass from langchain_openai import ChatOpenAI from langchain_mistralai import ChatMistralAI from tqdm import tqdm from langchain_community.chat_models import ChatOllama from .base_node import BaseNode -from ..utils.llm_output_parser import typed_dict_output_parser, base_model_v2_output_parser, base_model_v1_output_parser +from ..utils.output_parser import get_structured_output_parser, get_pydantic_output_parser from ..prompts.generate_answer_node_omni_prompts import (TEMPLATE_NO_CHUNKS_OMNI, TEMPLATE_CHUNKS_OMNI, TEMPLATE_MERGE_OMNI) @@ -90,14 +88,10 @@ class GenerateAnswerOmniNode(BaseNode): self.llm_model = self.llm_model.with_structured_output( schema = self.node_config["schema"]) # json schema works only on specific models - output_parser = typed_dict_output_parser - if is_basemodel_subclass(self.node_config["schema"]): - output_parser = base_model_v2_output_parser - if issubclass(self.node_config["schema"], BaseModelV1): - output_parser = base_model_v1_output_parser + output_parser = get_structured_output_parser(self.node_config["schema"]) format_instructions = "NA" else: - output_parser = JsonOutputParser(pydantic_object=self.node_config["schema"]) + output_parser = get_pydantic_output_parser(self.node_config["schema"]) format_instructions = output_parser.get_format_instructions() else: diff --git a/scrapegraphai/nodes/generate_answer_pdf_node.py b/scrapegraphai/nodes/generate_answer_pdf_node.py index 4832d45e..544184b4 100644 --- a/scrapegraphai/nodes/generate_answer_pdf_node.py +++ b/scrapegraphai/nodes/generate_answer_pdf_node.py @@ -2,18 +2,15 @@ Module for generating the answer node """ from typing import List, Optional -from pydantic.v1 import BaseModel as BaseModelV1 from langchain.prompts import PromptTemplate from langchain_core.output_parsers import JsonOutputParser from langchain_core.runnables import RunnableParallel -from langchain_core.utils.pydantic import is_basemodel_subclass from langchain_openai import ChatOpenAI from langchain_mistralai import ChatMistralAI from tqdm import tqdm from langchain_community.chat_models import ChatOllama -from ..utils.logging import get_logger from .base_node import BaseNode -from ..utils.llm_output_parser import typed_dict_output_parser, base_model_v2_output_parser, base_model_v1_output_parser +from ..utils.output_parser import get_structured_output_parser, get_pydantic_output_parser from ..prompts.generate_answer_node_pdf_prompts import (TEMPLATE_CHUNKS_PDF, TEMPLATE_NO_CHUNKS_PDF, TEMPLATE_MERGE_PDF) @@ -102,14 +99,10 @@ class GenerateAnswerPDFNode(BaseNode): self.llm_model = self.llm_model.with_structured_output( schema = self.node_config["schema"]) # json schema works only on specific models - output_parser = typed_dict_output_parser - if is_basemodel_subclass(self.node_config["schema"]): - output_parser = base_model_v2_output_parser - if issubclass(self.node_config["schema"], BaseModelV1): - output_parser = base_model_v1_output_parser + output_parser = get_structured_output_parser(self.node_config["schema"]) format_instructions = "NA" else: - output_parser = JsonOutputParser(pydantic_object=self.node_config["schema"]) + output_parser = get_pydantic_output_parser(self.node_config["schema"]) format_instructions = output_parser.get_format_instructions() else: diff --git a/scrapegraphai/nodes/merge_answers_node.py b/scrapegraphai/nodes/merge_answers_node.py index 64e1f149..9f9a356c 100644 --- a/scrapegraphai/nodes/merge_answers_node.py +++ b/scrapegraphai/nodes/merge_answers_node.py @@ -2,16 +2,13 @@ MergeAnswersNode Module """ from typing import List, Optional -from pydantic.v1 import BaseModel as BaseModelV1 from langchain.prompts import PromptTemplate from langchain_core.output_parsers import JsonOutputParser -from langchain_core.utils.pydantic import is_basemodel_subclass from langchain_openai import ChatOpenAI from langchain_mistralai import ChatMistralAI -from ..utils.logging import get_logger from .base_node import BaseNode from ..prompts import TEMPLATE_COMBINED -from ..utils.llm_output_parser import base_model_v1_output_parser, base_model_v2_output_parser, typed_dict_output_parser +from ..utils.output_parser import get_structured_output_parser, get_pydantic_output_parser class MergeAnswersNode(BaseNode): """ @@ -78,14 +75,10 @@ class MergeAnswersNode(BaseNode): self.llm_model = self.llm_model.with_structured_output( schema = self.node_config["schema"]) # json schema works only on specific models - output_parser = typed_dict_output_parser - if is_basemodel_subclass(self.node_config["schema"]): - output_parser = base_model_v2_output_parser - if issubclass(self.node_config["schema"], BaseModelV1): - output_parser = base_model_v1_output_parser + output_parser = get_structured_output_parser(self.node_config["schema"]) format_instructions = "NA" else: - output_parser = JsonOutputParser(pydantic_object=self.node_config["schema"]) + output_parser = get_pydantic_output_parser(self.node_config["schema"]) format_instructions = output_parser.get_format_instructions() else: diff --git a/scrapegraphai/utils/llm_output_parser.py b/scrapegraphai/utils/llm_output_parser.py deleted file mode 100644 index e6ac6e2d..00000000 --- a/scrapegraphai/utils/llm_output_parser.py +++ /dev/null @@ -1,53 +0,0 @@ -""" -Custom output parser for the LLM model. -""" -from pydantic import BaseModel as BaseModelV2 -from pydantic.v1 import BaseModel as BaseModelV1 - -def base_model_v1_output_parser(x: BaseModelV1) -> dict: - """ - Parse the output of an LLM when the schema is a BaseModelv1 and `with_structured_output` is used. - - Args: - x (BaseModelV2 | BaseModelV1): The output from the LLM model. - - Returns: - dict: The parsed output. - """ - work_dict = x.dict() - - # recursive dict parser - def recursive_dict_parser(work_dict: dict) -> dict: - dict_keys = work_dict.keys() - for key in dict_keys: - if isinstance(work_dict[key], BaseModelV1): - work_dict[key] = work_dict[key].dict() - recursive_dict_parser(work_dict[key]) - return work_dict - - return recursive_dict_parser(work_dict) - - -def base_model_v2_output_parser(x: BaseModelV2) -> dict: - """ - Parse the output of an LLM when the schema is a BaseModelv2 and `with_structured_output` is used. - - Args: - x (BaseModelV2): The output from the LLM model. - - Returns: - dict: The parsed output. - """ - return x.model_dump() - -def typed_dict_output_parser(x: dict) -> dict: - """ - Parse the output of an LLM when the schema is a TypedDict and `with_structured_output` is used. - - Args: - x (dict): The output from the LLM model. - - Returns: - dict: The parsed output. - """ - return x diff --git a/scrapegraphai/utils/output_parser.py b/scrapegraphai/utils/output_parser.py new file mode 100644 index 00000000..39ae092e --- /dev/null +++ b/scrapegraphai/utils/output_parser.py @@ -0,0 +1,85 @@ +""" +Functions to retrieve the correct output parser and format instructions for the LLM model. +""" +from pydantic import BaseModel as BaseModelV2 +from pydantic.v1 import BaseModel as BaseModelV1 +from typing import Union, Dict, Any, Type, Callable +from langchain_core.output_parsers import JsonOutputParser + +def get_structured_output_parser(schema: Union[Dict[str, Any], Type[BaseModelV1 | BaseModelV2], Type]) -> Callable: + """ + Get the correct output parser for the LLM model. + + Returns: + Callable: The output parser function. + """ + if issubclass(schema, BaseModelV1): + return _base_model_v1_output_parser + + if issubclass(schema, BaseModelV2): + return _base_model_v2_output_parser + + return _dict_output_parser + +def get_pydantic_output_parser(schema: Union[Dict[str, Any], Type[BaseModelV1 | BaseModelV2], Type]) -> JsonOutputParser: + """ + Get the correct output parser for the LLM model. + + Returns: + JsonOutputParser: The output parser object. + """ + if issubclass(schema, BaseModelV1): + raise ValueError("pydantic.v1 and langchain_core.pydantic_v1 are not supported with this LLM model. Please use pydantic v2 instead.") + + if issubclass(schema, BaseModelV2): + return JsonOutputParser(pydantic_object=schema) + + raise ValueError("The schema is not a pydantic subclass. With this LLM model you must use a pydantic schemas.") + +def _base_model_v1_output_parser(x: BaseModelV1) -> dict: + """ + Parse the output of an LLM when the schema is BaseModelv1. + + Args: + x (BaseModelV1): The output from the LLM model. + + Returns: + dict: The parsed output. + """ + work_dict = x.dict() + + # recursive dict parser + def recursive_dict_parser(work_dict: dict) -> dict: + dict_keys = work_dict.keys() + for key in dict_keys: + if isinstance(work_dict[key], BaseModelV1): + work_dict[key] = work_dict[key].dict() + recursive_dict_parser(work_dict[key]) + return work_dict + + return recursive_dict_parser(work_dict) + + +def _base_model_v2_output_parser(x: BaseModelV2) -> dict: + """ + Parse the output of an LLM when the schema is BaseModelv2. + + Args: + x (BaseModelV2): The output from the LLM model. + + Returns: + dict: The parsed output. + """ + return x.model_dump() + +def _dict_output_parser(x: dict) -> dict: + """ + Parse the output of an LLM when the schema is TypedDict or JsonSchema. + + Args: + x (dict): The output from the LLM model. + + Returns: + dict: The parsed output. + """ + return x