diff --git a/scrapegraphai/nodes/generate_answer_node.py b/scrapegraphai/nodes/generate_answer_node.py index 81812598..5022b16f 100644 --- a/scrapegraphai/nodes/generate_answer_node.py +++ b/scrapegraphai/nodes/generate_answer_node.py @@ -6,9 +6,10 @@ from typing import List, Optional from langchain.prompts import PromptTemplate from langchain_core.output_parsers import JsonOutputParser from langchain_core.runnables import RunnableParallel +from langchain_openai import ChatOpenAI +from langchain_community.chat_models import ChatOllama from tqdm import tqdm from ..utils.logging import get_logger -from ..models import Ollama, OpenAI from .base_node import BaseNode from ..helpers import template_chunks, template_no_chunks, template_merge, template_chunks_md, template_no_chunks_md, template_merge_md @@ -41,7 +42,7 @@ class GenerateAnswerNode(BaseNode): self.llm_model = node_config["llm_model"] - if isinstance(node_config["llm_model"], Ollama): + if isinstance(node_config["llm_model"], ChatOllama): self.llm_model.format="json" self.verbose = ( @@ -93,7 +94,7 @@ class GenerateAnswerNode(BaseNode): format_instructions = output_parser.get_format_instructions() - if isinstance(self.llm_model, OpenAI) and not self.script_creator or self.force and not self.script_creator or self.is_md_scraper: + if isinstance(self.llm_model, ChatOpenAI) and not self.script_creator or self.force and not self.script_creator or self.is_md_scraper: template_no_chunks_prompt = template_no_chunks_md template_chunks_prompt = template_chunks_md template_merge_prompt = template_merge_md