mirror of
https://github.com/VinciGit00/Scrapegraph-ai.git
synced 2026-07-01 21:00:48 +08:00
feat(GenerateAnswerNode): built-in structured output through LangChain
Some checks failed
/ build (push) Has been cancelled
Some checks failed
/ build (push) Has been cancelled
Co-Authored-By: Marco Vinciguerra <88108002+VinciGit00@users.noreply.github.com>
This commit is contained in:
parent
b48ee825ee
commit
d29338b7c2
@ -5,7 +5,12 @@ from typing import List, Optional
|
|||||||
from langchain.prompts import PromptTemplate
|
from langchain.prompts import PromptTemplate
|
||||||
from langchain_core.output_parsers import JsonOutputParser
|
from langchain_core.output_parsers import JsonOutputParser
|
||||||
from langchain_core.runnables import RunnableParallel
|
from langchain_core.runnables import RunnableParallel
|
||||||
from langchain_openai import ChatOpenAI
|
from langchain_openai import ChatOpenAI, AzureChatOpenAI
|
||||||
|
from langchain_mistralai import ChatMistralAI
|
||||||
|
from langchain_anthropic import ChatAnthropic
|
||||||
|
from langchain_groq import ChatGroq
|
||||||
|
from langchain_fireworks import ChatFireworks
|
||||||
|
from langchain_google_vertexai import ChatVertexAI
|
||||||
from langchain_community.chat_models import ChatOllama
|
from langchain_community.chat_models import ChatOllama
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from ..utils.logging import get_logger
|
from ..utils.logging import get_logger
|
||||||
@ -88,7 +93,9 @@ class GenerateAnswerNode(BaseNode):
|
|||||||
# Initialize the output parser
|
# Initialize the output parser
|
||||||
if self.node_config.get("schema", None) is not None:
|
if self.node_config.get("schema", None) is not None:
|
||||||
output_parser = JsonOutputParser(pydantic_object=self.node_config["schema"])
|
output_parser = JsonOutputParser(pydantic_object=self.node_config["schema"])
|
||||||
if isinstance(self.llm_model, ChatOpenAI) and (self.llm_model.model_name=="gpt-4o-mini" or self.llm_model.model_name=="gpt-4o-2024-08-06"):
|
|
||||||
|
# Use built-in structured output for providers that allow it
|
||||||
|
if isinstance(self.llm_model, (ChatOpenAI, ChatMistralAI, ChatAnthropic, ChatFireworks, ChatGroq, ChatVertexAI)):
|
||||||
self.llm_model = self.llm_model.with_structured_output(
|
self.llm_model = self.llm_model.with_structured_output(
|
||||||
schema = self.node_config["schema"],
|
schema = self.node_config["schema"],
|
||||||
method="json_schema")
|
method="json_schema")
|
||||||
@ -98,7 +105,7 @@ class GenerateAnswerNode(BaseNode):
|
|||||||
|
|
||||||
format_instructions = output_parser.get_format_instructions()
|
format_instructions = output_parser.get_format_instructions()
|
||||||
|
|
||||||
if isinstance(self.llm_model, ChatOpenAI) and not self.script_creator or self.force and not self.script_creator or self.is_md_scraper:
|
if isinstance(self.llm_model, (ChatOpenAI, AzureChatOpenAI)) and not self.script_creator or self.force and not self.script_creator or self.is_md_scraper:
|
||||||
template_no_chunks_prompt = TEMPLATE_NO_CHUNKS_MD
|
template_no_chunks_prompt = TEMPLATE_NO_CHUNKS_MD
|
||||||
template_chunks_prompt = TEMPLATE_CHUNKS_MD
|
template_chunks_prompt = TEMPLATE_CHUNKS_MD
|
||||||
template_merge_prompt = TEMPLATE_MERGE_MD
|
template_merge_prompt = TEMPLATE_MERGE_MD
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user