refactor: Output parser code

This commit is contained in:
Lorenzo Paleari 2024-09-17 23:07:16 +02:00
parent 4f8b55d747
commit 28b85a3b16
No known key found for this signature in database
GPG Key ID: 010F47E3CB681DED
7 changed files with 101 additions and 102 deletions

View File

@ -3,17 +3,14 @@ Module for generating the answer node
"""
from typing import List, Optional
from pydantic.v1 import BaseModel as BaseModelV1
from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.runnables import RunnableParallel
from langchain_core.utils.pydantic import is_basemodel_subclass
from langchain_openai import ChatOpenAI
from langchain_mistralai import ChatMistralAI
from tqdm import tqdm
from ..utils.logging import get_logger
from .base_node import BaseNode
from ..utils.llm_output_parser import typed_dict_output_parser, base_model_v2_output_parser, base_model_v1_output_parser
from ..utils.output_parser import get_structured_output_parser, get_pydantic_output_parser
from ..prompts import TEMPLATE_CHUKS_CSV, TEMPLATE_NO_CHUKS_CSV, TEMPLATE_MERGE_CSV
class GenerateAnswerCSVNode(BaseNode):
@ -101,14 +98,10 @@ class GenerateAnswerCSVNode(BaseNode):
self.llm_model = self.llm_model.with_structured_output(
schema = self.node_config["schema"]) # json schema works only on specific models
output_parser = typed_dict_output_parser
if is_basemodel_subclass(self.node_config["schema"]):
output_parser = base_model_v2_output_parser
if issubclass(self.node_config["schema"], BaseModelV1):
output_parser = base_model_v1_output_parser
output_parser = get_structured_output_parser(self.node_config["schema"])
format_instructions = "NA"
else:
output_parser = JsonOutputParser(pydantic_object=self.node_config["schema"])
output_parser = get_pydantic_output_parser(self.node_config["schema"])
format_instructions = output_parser.get_format_instructions()
else:

View File

@ -2,17 +2,15 @@
GenerateAnswerNode Module
"""
from typing import List, Optional
from pydantic.v1 import BaseModel as BaseModelV1
from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.runnables import RunnableParallel
from langchain_core.utils.pydantic import is_basemodel_subclass
from langchain_openai import ChatOpenAI, AzureChatOpenAI
from langchain_mistralai import ChatMistralAI
from langchain_community.chat_models import ChatOllama
from tqdm import tqdm
from .base_node import BaseNode
from ..utils.llm_output_parser import base_model_v1_output_parser, base_model_v2_output_parser, typed_dict_output_parser
from ..utils.output_parser import get_structured_output_parser, get_pydantic_output_parser
from ..prompts import (TEMPLATE_CHUNKS,
TEMPLATE_NO_CHUNKS, TEMPLATE_MERGE,
TEMPLATE_CHUNKS_MD, TEMPLATE_NO_CHUNKS_MD,
@ -95,15 +93,11 @@ class GenerateAnswerNode(BaseNode):
if isinstance(self.llm_model, (ChatOpenAI, ChatMistralAI)):
self.llm_model = self.llm_model.with_structured_output(
schema = self.node_config["schema"]) # json schema works only on specific models
output_parser = typed_dict_output_parser
if is_basemodel_subclass(self.node_config["schema"]):
output_parser = base_model_v2_output_parser
if issubclass(self.node_config["schema"], BaseModelV1):
output_parser = base_model_v1_output_parser
output_parser = get_structured_output_parser(self.node_config["schema"])
format_instructions = "NA"
else:
output_parser = JsonOutputParser(pydantic_object=self.node_config["schema"])
output_parser = get_pydantic_output_parser(self.node_config["schema"])
format_instructions = output_parser.get_format_instructions()
else:

View File

@ -2,17 +2,15 @@
GenerateAnswerNode Module
"""
from typing import List, Optional
from pydantic.v1 import BaseModel as BaseModelV1
from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.runnables import RunnableParallel
from langchain_core.utils.pydantic import is_basemodel_subclass
from langchain_openai import ChatOpenAI
from langchain_mistralai import ChatMistralAI
from tqdm import tqdm
from langchain_community.chat_models import ChatOllama
from .base_node import BaseNode
from ..utils.llm_output_parser import typed_dict_output_parser, base_model_v2_output_parser, base_model_v1_output_parser
from ..utils.output_parser import get_structured_output_parser, get_pydantic_output_parser
from ..prompts.generate_answer_node_omni_prompts import (TEMPLATE_NO_CHUNKS_OMNI,
TEMPLATE_CHUNKS_OMNI,
TEMPLATE_MERGE_OMNI)
@ -90,14 +88,10 @@ class GenerateAnswerOmniNode(BaseNode):
self.llm_model = self.llm_model.with_structured_output(
schema = self.node_config["schema"]) # json schema works only on specific models
output_parser = typed_dict_output_parser
if is_basemodel_subclass(self.node_config["schema"]):
output_parser = base_model_v2_output_parser
if issubclass(self.node_config["schema"], BaseModelV1):
output_parser = base_model_v1_output_parser
output_parser = get_structured_output_parser(self.node_config["schema"])
format_instructions = "NA"
else:
output_parser = JsonOutputParser(pydantic_object=self.node_config["schema"])
output_parser = get_pydantic_output_parser(self.node_config["schema"])
format_instructions = output_parser.get_format_instructions()
else:

View File

@ -2,18 +2,15 @@
Module for generating the answer node
"""
from typing import List, Optional
from pydantic.v1 import BaseModel as BaseModelV1
from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.runnables import RunnableParallel
from langchain_core.utils.pydantic import is_basemodel_subclass
from langchain_openai import ChatOpenAI
from langchain_mistralai import ChatMistralAI
from tqdm import tqdm
from langchain_community.chat_models import ChatOllama
from ..utils.logging import get_logger
from .base_node import BaseNode
from ..utils.llm_output_parser import typed_dict_output_parser, base_model_v2_output_parser, base_model_v1_output_parser
from ..utils.output_parser import get_structured_output_parser, get_pydantic_output_parser
from ..prompts.generate_answer_node_pdf_prompts import (TEMPLATE_CHUNKS_PDF,
TEMPLATE_NO_CHUNKS_PDF,
TEMPLATE_MERGE_PDF)
@ -102,14 +99,10 @@ class GenerateAnswerPDFNode(BaseNode):
self.llm_model = self.llm_model.with_structured_output(
schema = self.node_config["schema"]) # json schema works only on specific models
output_parser = typed_dict_output_parser
if is_basemodel_subclass(self.node_config["schema"]):
output_parser = base_model_v2_output_parser
if issubclass(self.node_config["schema"], BaseModelV1):
output_parser = base_model_v1_output_parser
output_parser = get_structured_output_parser(self.node_config["schema"])
format_instructions = "NA"
else:
output_parser = JsonOutputParser(pydantic_object=self.node_config["schema"])
output_parser = get_pydantic_output_parser(self.node_config["schema"])
format_instructions = output_parser.get_format_instructions()
else:

View File

@ -2,16 +2,13 @@
MergeAnswersNode Module
"""
from typing import List, Optional
from pydantic.v1 import BaseModel as BaseModelV1
from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.utils.pydantic import is_basemodel_subclass
from langchain_openai import ChatOpenAI
from langchain_mistralai import ChatMistralAI
from ..utils.logging import get_logger
from .base_node import BaseNode
from ..prompts import TEMPLATE_COMBINED
from ..utils.llm_output_parser import base_model_v1_output_parser, base_model_v2_output_parser, typed_dict_output_parser
from ..utils.output_parser import get_structured_output_parser, get_pydantic_output_parser
class MergeAnswersNode(BaseNode):
"""
@ -78,14 +75,10 @@ class MergeAnswersNode(BaseNode):
self.llm_model = self.llm_model.with_structured_output(
schema = self.node_config["schema"]) # json schema works only on specific models
output_parser = typed_dict_output_parser
if is_basemodel_subclass(self.node_config["schema"]):
output_parser = base_model_v2_output_parser
if issubclass(self.node_config["schema"], BaseModelV1):
output_parser = base_model_v1_output_parser
output_parser = get_structured_output_parser(self.node_config["schema"])
format_instructions = "NA"
else:
output_parser = JsonOutputParser(pydantic_object=self.node_config["schema"])
output_parser = get_pydantic_output_parser(self.node_config["schema"])
format_instructions = output_parser.get_format_instructions()
else:

View File

@ -1,53 +0,0 @@
"""
Custom output parser for the LLM model.
"""
from pydantic import BaseModel as BaseModelV2
from pydantic.v1 import BaseModel as BaseModelV1
def base_model_v1_output_parser(x: BaseModelV1) -> dict:
"""
Parse the output of an LLM when the schema is a BaseModelv1 and `with_structured_output` is used.
Args:
x (BaseModelV2 | BaseModelV1): The output from the LLM model.
Returns:
dict: The parsed output.
"""
work_dict = x.dict()
# recursive dict parser
def recursive_dict_parser(work_dict: dict) -> dict:
dict_keys = work_dict.keys()
for key in dict_keys:
if isinstance(work_dict[key], BaseModelV1):
work_dict[key] = work_dict[key].dict()
recursive_dict_parser(work_dict[key])
return work_dict
return recursive_dict_parser(work_dict)
def base_model_v2_output_parser(x: BaseModelV2) -> dict:
"""
Parse the output of an LLM when the schema is a BaseModelv2 and `with_structured_output` is used.
Args:
x (BaseModelV2): The output from the LLM model.
Returns:
dict: The parsed output.
"""
return x.model_dump()
def typed_dict_output_parser(x: dict) -> dict:
"""
Parse the output of an LLM when the schema is a TypedDict and `with_structured_output` is used.
Args:
x (dict): The output from the LLM model.
Returns:
dict: The parsed output.
"""
return x

View File

@ -0,0 +1,85 @@
"""
Functions to retrieve the correct output parser and format instructions for the LLM model.
"""
from pydantic import BaseModel as BaseModelV2
from pydantic.v1 import BaseModel as BaseModelV1
from typing import Union, Dict, Any, Type, Callable
from langchain_core.output_parsers import JsonOutputParser
def get_structured_output_parser(schema: Union[Dict[str, Any], Type[BaseModelV1 | BaseModelV2], Type]) -> Callable:
"""
Get the correct output parser for the LLM model.
Returns:
Callable: The output parser function.
"""
if issubclass(schema, BaseModelV1):
return _base_model_v1_output_parser
if issubclass(schema, BaseModelV2):
return _base_model_v2_output_parser
return _dict_output_parser
def get_pydantic_output_parser(schema: Union[Dict[str, Any], Type[BaseModelV1 | BaseModelV2], Type]) -> JsonOutputParser:
"""
Get the correct output parser for the LLM model.
Returns:
JsonOutputParser: The output parser object.
"""
if issubclass(schema, BaseModelV1):
raise ValueError("pydantic.v1 and langchain_core.pydantic_v1 are not supported with this LLM model. Please use pydantic v2 instead.")
if issubclass(schema, BaseModelV2):
return JsonOutputParser(pydantic_object=schema)
raise ValueError("The schema is not a pydantic subclass. With this LLM model you must use a pydantic schemas.")
def _base_model_v1_output_parser(x: BaseModelV1) -> dict:
"""
Parse the output of an LLM when the schema is BaseModelv1.
Args:
x (BaseModelV1): The output from the LLM model.
Returns:
dict: The parsed output.
"""
work_dict = x.dict()
# recursive dict parser
def recursive_dict_parser(work_dict: dict) -> dict:
dict_keys = work_dict.keys()
for key in dict_keys:
if isinstance(work_dict[key], BaseModelV1):
work_dict[key] = work_dict[key].dict()
recursive_dict_parser(work_dict[key])
return work_dict
return recursive_dict_parser(work_dict)
def _base_model_v2_output_parser(x: BaseModelV2) -> dict:
"""
Parse the output of an LLM when the schema is BaseModelv2.
Args:
x (BaseModelV2): The output from the LLM model.
Returns:
dict: The parsed output.
"""
return x.model_dump()
def _dict_output_parser(x: dict) -> dict:
"""
Parse the output of an LLM when the schema is TypedDict or JsonSchema.
Args:
x (dict): The output from the LLM model.
Returns:
dict: The parsed output.
"""
return x