feat(GenerateAnswerNode): built-in structured output through LangChain
Some checks failed
/ build (push) Has been cancelled

Co-Authored-By: Marco Vinciguerra <88108002+VinciGit00@users.noreply.github.com>
This commit is contained in:
Federico Aguzzi 2024-08-19 13:45:37 +02:00
parent b48ee825ee
commit d29338b7c2

View File

@ -5,7 +5,12 @@ from typing import List, Optional
from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.runnables import RunnableParallel
from langchain_openai import ChatOpenAI
from langchain_openai import ChatOpenAI, AzureChatOpenAI
from langchain_mistralai import ChatMistralAI
from langchain_anthropic import ChatAnthropic
from langchain_groq import ChatGroq
from langchain_fireworks import ChatFireworks
from langchain_google_vertexai import ChatVertexAI
from langchain_community.chat_models import ChatOllama
from tqdm import tqdm
from ..utils.logging import get_logger
@ -88,7 +93,9 @@ class GenerateAnswerNode(BaseNode):
# Initialize the output parser
if self.node_config.get("schema", None) is not None:
output_parser = JsonOutputParser(pydantic_object=self.node_config["schema"])
if isinstance(self.llm_model, ChatOpenAI) and (self.llm_model.model_name=="gpt-4o-mini" or self.llm_model.model_name=="gpt-4o-2024-08-06"):
# Use built-in structured output for providers that allow it
if isinstance(self.llm_model, (ChatOpenAI, ChatMistralAI, ChatAnthropic, ChatFireworks, ChatGroq, ChatVertexAI)):
self.llm_model = self.llm_model.with_structured_output(
schema = self.node_config["schema"],
method="json_schema")
@ -98,7 +105,7 @@ class GenerateAnswerNode(BaseNode):
format_instructions = output_parser.get_format_instructions()
if isinstance(self.llm_model, ChatOpenAI) and not self.script_creator or self.force and not self.script_creator or self.is_md_scraper:
if isinstance(self.llm_model, (ChatOpenAI, AzureChatOpenAI)) and not self.script_creator or self.force and not self.script_creator or self.is_md_scraper:
template_no_chunks_prompt = TEMPLATE_NO_CHUNKS_MD
template_chunks_prompt = TEMPLATE_CHUNKS_MD
template_merge_prompt = TEMPLATE_MERGE_MD