fix: Added support for nested structure

This commit is contained in:
Lorenzo Paleari 2024-09-13 04:18:53 +02:00
parent 039ba2e95a
commit 66ea166438
No known key found for this signature in database
GPG Key ID: 010F47E3CB681DED
6 changed files with 93 additions and 28 deletions

View File

@ -3,6 +3,7 @@ Module for generating the answer node
""" """
from typing import List, Optional from typing import List, Optional
from pydantic.v1 import BaseModel as BaseModelV1
from langchain.prompts import PromptTemplate from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import JsonOutputParser from langchain_core.output_parsers import JsonOutputParser
from langchain_core.runnables import RunnableParallel from langchain_core.runnables import RunnableParallel
@ -12,6 +13,7 @@ from langchain_mistralai import ChatMistralAI
from tqdm import tqdm from tqdm import tqdm
from ..utils.logging import get_logger from ..utils.logging import get_logger
from .base_node import BaseNode from .base_node import BaseNode
from ..utils.llm_output_parser import typed_dict_output_parser, base_model_v2_output_parser, base_model_v1_output_parser
from ..prompts import TEMPLATE_CHUKS_CSV, TEMPLATE_NO_CHUKS_CSV, TEMPLATE_MERGE_CSV from ..prompts import TEMPLATE_CHUKS_CSV, TEMPLATE_NO_CHUKS_CSV, TEMPLATE_MERGE_CSV
class GenerateAnswerCSVNode(BaseNode): class GenerateAnswerCSVNode(BaseNode):
@ -97,13 +99,13 @@ class GenerateAnswerCSVNode(BaseNode):
if isinstance(self.llm_model, (ChatOpenAI, ChatMistralAI)): if isinstance(self.llm_model, (ChatOpenAI, ChatMistralAI)):
self.llm_model = self.llm_model.with_structured_output( self.llm_model = self.llm_model.with_structured_output(
schema = self.node_config["schema"], schema = self.node_config["schema"]) # json schema works only on specific models
method="function_calling") # json schema works only on specific models
# default parser to empty lambda function output_parser = typed_dict_output_parser
output_parser = lambda x: x
if is_basemodel_subclass(self.node_config["schema"]): if is_basemodel_subclass(self.node_config["schema"]):
output_parser = dict output_parser = base_model_v2_output_parser
if issubclass(self.node_config["schema"], BaseModelV1):
output_parser = base_model_v1_output_parser
format_instructions = "NA" format_instructions = "NA"
else: else:
output_parser = JsonOutputParser(pydantic_object=self.node_config["schema"]) output_parser = JsonOutputParser(pydantic_object=self.node_config["schema"])

View File

@ -2,6 +2,7 @@
GenerateAnswerNode Module GenerateAnswerNode Module
""" """
from typing import List, Optional from typing import List, Optional
from pydantic.v1 import BaseModel as BaseModelV1
from langchain.prompts import PromptTemplate from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import JsonOutputParser from langchain_core.output_parsers import JsonOutputParser
from langchain_core.runnables import RunnableParallel from langchain_core.runnables import RunnableParallel
@ -11,6 +12,7 @@ from langchain_mistralai import ChatMistralAI
from langchain_community.chat_models import ChatOllama from langchain_community.chat_models import ChatOllama
from tqdm import tqdm from tqdm import tqdm
from .base_node import BaseNode from .base_node import BaseNode
from ..utils.llm_output_parser import base_model_v1_output_parser, base_model_v2_output_parser, typed_dict_output_parser
from ..prompts import (TEMPLATE_CHUNKS, from ..prompts import (TEMPLATE_CHUNKS,
TEMPLATE_NO_CHUNKS, TEMPLATE_MERGE, TEMPLATE_NO_CHUNKS, TEMPLATE_MERGE,
TEMPLATE_CHUNKS_MD, TEMPLATE_NO_CHUNKS_MD, TEMPLATE_CHUNKS_MD, TEMPLATE_NO_CHUNKS_MD,
@ -94,11 +96,11 @@ class GenerateAnswerNode(BaseNode):
self.llm_model = self.llm_model.with_structured_output( self.llm_model = self.llm_model.with_structured_output(
schema = self.node_config["schema"]) # json schema works only on specific models schema = self.node_config["schema"]) # json schema works only on specific models
# default parser to empty lambda function output_parser = typed_dict_output_parser
def output_parser(x):
return x
if is_basemodel_subclass(self.node_config["schema"]): if is_basemodel_subclass(self.node_config["schema"]):
output_parser = dict output_parser = base_model_v2_output_parser
if issubclass(self.node_config["schema"], BaseModelV1):
output_parser = base_model_v1_output_parser
format_instructions = "NA" format_instructions = "NA"
else: else:
output_parser = JsonOutputParser(pydantic_object=self.node_config["schema"]) output_parser = JsonOutputParser(pydantic_object=self.node_config["schema"])

View File

@ -2,6 +2,7 @@
GenerateAnswerNode Module GenerateAnswerNode Module
""" """
from typing import List, Optional from typing import List, Optional
from pydantic.v1 import BaseModel as BaseModelV1
from langchain.prompts import PromptTemplate from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import JsonOutputParser from langchain_core.output_parsers import JsonOutputParser
from langchain_core.runnables import RunnableParallel from langchain_core.runnables import RunnableParallel
@ -11,6 +12,7 @@ from langchain_mistralai import ChatMistralAI
from tqdm import tqdm from tqdm import tqdm
from langchain_community.chat_models import ChatOllama from langchain_community.chat_models import ChatOllama
from .base_node import BaseNode from .base_node import BaseNode
from ..utils.llm_output_parser import typed_dict_output_parser, base_model_v2_output_parser, base_model_v1_output_parser
from ..prompts.generate_answer_node_omni_prompts import (TEMPLATE_NO_CHUNKS_OMNI, from ..prompts.generate_answer_node_omni_prompts import (TEMPLATE_NO_CHUNKS_OMNI,
TEMPLATE_CHUNKS_OMNI, TEMPLATE_CHUNKS_OMNI,
TEMPLATE_MERGE_OMNI) TEMPLATE_MERGE_OMNI)
@ -86,13 +88,13 @@ class GenerateAnswerOmniNode(BaseNode):
if isinstance(self.llm_model, (ChatOpenAI, ChatMistralAI)): if isinstance(self.llm_model, (ChatOpenAI, ChatMistralAI)):
self.llm_model = self.llm_model.with_structured_output( self.llm_model = self.llm_model.with_structured_output(
schema = self.node_config["schema"], schema = self.node_config["schema"]) # json schema works only on specific models
method="function_calling") # json schema works only on specific models
# default parser to empty lambda function output_parser = typed_dict_output_parser
output_parser = lambda x: x
if is_basemodel_subclass(self.node_config["schema"]): if is_basemodel_subclass(self.node_config["schema"]):
output_parser = dict output_parser = base_model_v2_output_parser
if issubclass(self.node_config["schema"], BaseModelV1):
output_parser = base_model_v1_output_parser
format_instructions = "NA" format_instructions = "NA"
else: else:
output_parser = JsonOutputParser(pydantic_object=self.node_config["schema"]) output_parser = JsonOutputParser(pydantic_object=self.node_config["schema"])

View File

@ -2,6 +2,7 @@
Module for generating the answer node Module for generating the answer node
""" """
from typing import List, Optional from typing import List, Optional
from pydantic.v1 import BaseModel as BaseModelV1
from langchain.prompts import PromptTemplate from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import JsonOutputParser from langchain_core.output_parsers import JsonOutputParser
from langchain_core.runnables import RunnableParallel from langchain_core.runnables import RunnableParallel
@ -12,6 +13,7 @@ from tqdm import tqdm
from langchain_community.chat_models import ChatOllama from langchain_community.chat_models import ChatOllama
from ..utils.logging import get_logger from ..utils.logging import get_logger
from .base_node import BaseNode from .base_node import BaseNode
from ..utils.llm_output_parser import typed_dict_output_parser, base_model_v2_output_parser, base_model_v1_output_parser
from ..prompts.generate_answer_node_pdf_prompts import (TEMPLATE_CHUNKS_PDF, from ..prompts.generate_answer_node_pdf_prompts import (TEMPLATE_CHUNKS_PDF,
TEMPLATE_NO_CHUNKS_PDF, TEMPLATE_NO_CHUNKS_PDF,
TEMPLATE_MERGE_PDF) TEMPLATE_MERGE_PDF)
@ -98,12 +100,13 @@ class GenerateAnswerPDFNode(BaseNode):
if isinstance(self.llm_model, (ChatOpenAI, ChatMistralAI)): if isinstance(self.llm_model, (ChatOpenAI, ChatMistralAI)):
self.llm_model = self.llm_model.with_structured_output( self.llm_model = self.llm_model.with_structured_output(
schema = self.node_config["schema"], schema = self.node_config["schema"]) # json schema works only on specific models
method="function_calling") # json schema works only on specific models
output_parser = lambda x: x output_parser = typed_dict_output_parser
if is_basemodel_subclass(self.node_config["schema"]): if is_basemodel_subclass(self.node_config["schema"]):
output_parser = dict output_parser = base_model_v2_output_parser
if issubclass(self.node_config["schema"], BaseModelV1):
output_parser = base_model_v1_output_parser
format_instructions = "NA" format_instructions = "NA"
else: else:
output_parser = JsonOutputParser(pydantic_object=self.node_config["schema"]) output_parser = JsonOutputParser(pydantic_object=self.node_config["schema"])

View File

@ -2,6 +2,7 @@
MergeAnswersNode Module MergeAnswersNode Module
""" """
from typing import List, Optional from typing import List, Optional
from pydantic.v1 import BaseModel as BaseModelV1
from langchain.prompts import PromptTemplate from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import JsonOutputParser from langchain_core.output_parsers import JsonOutputParser
from langchain_core.utils.pydantic import is_basemodel_subclass from langchain_core.utils.pydantic import is_basemodel_subclass
@ -10,6 +11,7 @@ from langchain_mistralai import ChatMistralAI
from ..utils.logging import get_logger from ..utils.logging import get_logger
from .base_node import BaseNode from .base_node import BaseNode
from ..prompts import TEMPLATE_COMBINED from ..prompts import TEMPLATE_COMBINED
from ..utils.llm_output_parser import base_model_v1_output_parser, base_model_v2_output_parser, typed_dict_output_parser
class MergeAnswersNode(BaseNode): class MergeAnswersNode(BaseNode):
""" """
@ -74,12 +76,13 @@ class MergeAnswersNode(BaseNode):
if isinstance(self.llm_model, (ChatOpenAI, ChatMistralAI)): if isinstance(self.llm_model, (ChatOpenAI, ChatMistralAI)):
self.llm_model = self.llm_model.with_structured_output( self.llm_model = self.llm_model.with_structured_output(
schema = self.node_config["schema"], schema = self.node_config["schema"]) # json schema works only on specific models
method="function_calling") # json schema works only on specific models
# default parser to empty lambda function output_parser = typed_dict_output_parser
output_parser = lambda x: x
if is_basemodel_subclass(self.node_config["schema"]): if is_basemodel_subclass(self.node_config["schema"]):
output_parser = dict output_parser = base_model_v2_output_parser
if issubclass(self.node_config["schema"], BaseModelV1):
output_parser = base_model_v1_output_parser
format_instructions = "NA" format_instructions = "NA"
else: else:
output_parser = JsonOutputParser(pydantic_object=self.node_config["schema"]) output_parser = JsonOutputParser(pydantic_object=self.node_config["schema"])
@ -100,7 +103,7 @@ class MergeAnswersNode(BaseNode):
merge_chain = prompt_template | self.llm_model | output_parser merge_chain = prompt_template | self.llm_model | output_parser
answer = merge_chain.invoke({"user_prompt": user_prompt}) answer = merge_chain.invoke({"user_prompt": user_prompt})
answer["sources"] = state.get("urls") answer["sources"] = state.get("urls", [])
state.update({self.output[0]: answer}) state.update({self.output[0]: answer})
return state return state

View File

@ -0,0 +1,53 @@
"""
Custom output parser for the LLM model.
"""
from pydantic import BaseModel as BaseModelV2
from pydantic.v1 import BaseModel as BaseModelV1
def base_model_v1_output_parser(x: BaseModelV1) -> dict:
"""
Parse the output of an LLM when the schema is a BaseModelv1 and `with_structured_output` is used.
Args:
x (BaseModelV2 | BaseModelV1): The output from the LLM model.
Returns:
dict: The parsed output.
"""
work_dict = x.dict()
# recursive dict parser
def recursive_dict_parser(work_dict: dict) -> dict:
dict_keys = work_dict.keys()
for key in dict_keys:
if isinstance(work_dict[key], BaseModelV1):
work_dict[key] = work_dict[key].dict()
recursive_dict_parser(work_dict[key])
return work_dict
return recursive_dict_parser(work_dict)
def base_model_v2_output_parser(x: BaseModelV2) -> dict:
"""
Parse the output of an LLM when the schema is a BaseModelv2 and `with_structured_output` is used.
Args:
x (BaseModelV2): The output from the LLM model.
Returns:
dict: The parsed output.
"""
return x.model_dump()
def typed_dict_output_parser(x: dict) -> dict:
"""
Parse the output of an LLM when the schema is a TypedDict and `with_structured_output` is used.
Args:
x (dict): The output from the LLM model.
Returns:
dict: The parsed output.
"""
return x