mirror of
https://github.com/VinciGit00/Scrapegraph-ai.git
synced 2026-07-04 21:00:36 +08:00
feat: add generate_answer node paralellization
Some checks are pending
/ build (push) Waiting to run
Some checks are pending
/ build (push) Waiting to run
Co-Authored-By: Federico Minutoli <40361744+DiTo97@users.noreply.github.com>
This commit is contained in:
parent
2ae19aee56
commit
0c4b2908d9
@ -1,13 +1,12 @@
|
||||
"""
|
||||
GenerateAnswerNode Module
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import List, Optional
|
||||
from langchain.prompts import PromptTemplate
|
||||
from langchain_core.output_parsers import JsonOutputParser
|
||||
from langchain_core.runnables import RunnableParallel
|
||||
from tqdm import tqdm
|
||||
import asyncio
|
||||
from ..utils.merge_results import merge_results
|
||||
from ..utils.logging import get_logger
|
||||
from ..models import Ollama, OpenAI
|
||||
@ -136,21 +135,18 @@ class GenerateAnswerNode(BaseNode):
|
||||
chain_name = f"chunk{i+1}"
|
||||
chains_dict[chain_name] = prompt | self.llm_model | output_parser
|
||||
|
||||
async_runner = RunnableParallel(**chains_dict)
|
||||
|
||||
async def process_chains():
|
||||
async_runner = RunnableParallel()
|
||||
for chain_name, chain in chains_dict.items():
|
||||
async_runner.add(chain.ainvoke([{"question": user_prompt}] * len(doc)))
|
||||
|
||||
batch_results = await async_runner.run()
|
||||
return batch_results
|
||||
batch_results = async_runner.invoke({"question": user_prompt})
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
batch_answers = loop.run_until_complete(process_chains())
|
||||
merge_prompt = PromptTemplate(
|
||||
template = template_merge_prompt,
|
||||
input_variables=["context", "question"],
|
||||
partial_variables={"format_instructions": format_instructions},
|
||||
)
|
||||
|
||||
# Merge batch results (assuming same structure)
|
||||
merged_answer = merge_results(batch_answers)
|
||||
answers = merged_answer
|
||||
merge_chain = merge_prompt | self.llm_model | output_parser
|
||||
answer = merge_chain.invoke({"context": batch_results, "question": user_prompt})
|
||||
|
||||
state.update({self.output[0]: answers})
|
||||
state.update({self.output[0]: answer})
|
||||
return state
|
||||
|
||||
Loading…
Reference in New Issue
Block a user