diff --git a/scrapegraphai/nodes/generate_answer_node.py b/scrapegraphai/nodes/generate_answer_node.py index eb440a75..9cd5dce5 100644 --- a/scrapegraphai/nodes/generate_answer_node.py +++ b/scrapegraphai/nodes/generate_answer_node.py @@ -1,13 +1,12 @@ """ GenerateAnswerNode Module """ - +import asyncio from typing import List, Optional from langchain.prompts import PromptTemplate from langchain_core.output_parsers import JsonOutputParser from langchain_core.runnables import RunnableParallel from tqdm import tqdm -import asyncio from ..utils.merge_results import merge_results from ..utils.logging import get_logger from ..models import Ollama, OpenAI @@ -136,21 +135,18 @@ class GenerateAnswerNode(BaseNode): chain_name = f"chunk{i+1}" chains_dict[chain_name] = prompt | self.llm_model | output_parser + async_runner = RunnableParallel(**chains_dict) - async def process_chains(): - async_runner = RunnableParallel() - for chain_name, chain in chains_dict.items(): - async_runner.add(chain.ainvoke([{"question": user_prompt}] * len(doc))) - - batch_results = await async_runner.run() - return batch_results + batch_results = async_runner.invoke({"question": user_prompt}) - loop = asyncio.get_event_loop() - batch_answers = loop.run_until_complete(process_chains()) + merge_prompt = PromptTemplate( + template = template_merge_prompt, + input_variables=["context", "question"], + partial_variables={"format_instructions": format_instructions}, + ) - # Merge batch results (assuming same structure) - merged_answer = merge_results(batch_answers) - answers = merged_answer + merge_chain = merge_prompt | self.llm_model | output_parser + answer = merge_chain.invoke({"context": batch_results, "question": user_prompt}) - state.update({self.output[0]: answers}) + state.update({self.output[0]: answer}) return state