diff --git a/scrapegraphai/nodes/generate_answer_node.py b/scrapegraphai/nodes/generate_answer_node.py index fdaacbfe..970a6790 100644 --- a/scrapegraphai/nodes/generate_answer_node.py +++ b/scrapegraphai/nodes/generate_answer_node.py @@ -5,7 +5,12 @@ from typing import List, Optional from langchain.prompts import PromptTemplate from langchain_core.output_parsers import JsonOutputParser from langchain_core.runnables import RunnableParallel -from langchain_openai import ChatOpenAI +from langchain_openai import ChatOpenAI, AzureChatOpenAI +from langchain_mistralai import ChatMistralAI +from langchain_anthropic import ChatAnthropic +from langchain_groq import ChatGroq +from langchain_fireworks import ChatFireworks +from langchain_google_vertexai import ChatVertexAI from langchain_community.chat_models import ChatOllama from tqdm import tqdm from ..utils.logging import get_logger @@ -88,7 +93,9 @@ class GenerateAnswerNode(BaseNode): # Initialize the output parser if self.node_config.get("schema", None) is not None: output_parser = JsonOutputParser(pydantic_object=self.node_config["schema"]) - if isinstance(self.llm_model, ChatOpenAI) and (self.llm_model.model_name=="gpt-4o-mini" or self.llm_model.model_name=="gpt-4o-2024-08-06"): + + # Use built-in structured output for providers that allow it + if isinstance(self.llm_model, (ChatOpenAI, ChatMistralAI, ChatAnthropic, ChatFireworks, ChatGroq, ChatVertexAI)): self.llm_model = self.llm_model.with_structured_output( schema = self.node_config["schema"], method="json_schema") @@ -98,7 +105,7 @@ class GenerateAnswerNode(BaseNode): format_instructions = output_parser.get_format_instructions() - if isinstance(self.llm_model, ChatOpenAI) and not self.script_creator or self.force and not self.script_creator or self.is_md_scraper: + if isinstance(self.llm_model, (ChatOpenAI, AzureChatOpenAI)) and not self.script_creator or self.force and not self.script_creator or self.is_md_scraper: template_no_chunks_prompt = TEMPLATE_NO_CHUNKS_MD template_chunks_prompt = TEMPLATE_CHUNKS_MD template_merge_prompt = TEMPLATE_MERGE_MD