mirror of
https://github.com/VinciGit00/Scrapegraph-ai.git
synced 2026-07-04 21:00:36 +08:00
feat: add generate_answer node paralellization
Some checks are pending
/ build (push) Waiting to run
Some checks are pending
/ build (push) Waiting to run
Co-Authored-By: Federico Minutoli <40361744+DiTo97@users.noreply.github.com>
This commit is contained in:
parent
2ae19aee56
commit
0c4b2908d9
@ -1,13 +1,12 @@
|
|||||||
"""
|
"""
|
||||||
GenerateAnswerNode Module
|
GenerateAnswerNode Module
|
||||||
"""
|
"""
|
||||||
|
import asyncio
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
from langchain.prompts import PromptTemplate
|
from langchain.prompts import PromptTemplate
|
||||||
from langchain_core.output_parsers import JsonOutputParser
|
from langchain_core.output_parsers import JsonOutputParser
|
||||||
from langchain_core.runnables import RunnableParallel
|
from langchain_core.runnables import RunnableParallel
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
import asyncio
|
|
||||||
from ..utils.merge_results import merge_results
|
from ..utils.merge_results import merge_results
|
||||||
from ..utils.logging import get_logger
|
from ..utils.logging import get_logger
|
||||||
from ..models import Ollama, OpenAI
|
from ..models import Ollama, OpenAI
|
||||||
@ -136,21 +135,18 @@ class GenerateAnswerNode(BaseNode):
|
|||||||
chain_name = f"chunk{i+1}"
|
chain_name = f"chunk{i+1}"
|
||||||
chains_dict[chain_name] = prompt | self.llm_model | output_parser
|
chains_dict[chain_name] = prompt | self.llm_model | output_parser
|
||||||
|
|
||||||
|
async_runner = RunnableParallel(**chains_dict)
|
||||||
|
|
||||||
async def process_chains():
|
batch_results = async_runner.invoke({"question": user_prompt})
|
||||||
async_runner = RunnableParallel()
|
|
||||||
for chain_name, chain in chains_dict.items():
|
|
||||||
async_runner.add(chain.ainvoke([{"question": user_prompt}] * len(doc)))
|
|
||||||
|
|
||||||
batch_results = await async_runner.run()
|
|
||||||
return batch_results
|
|
||||||
|
|
||||||
loop = asyncio.get_event_loop()
|
merge_prompt = PromptTemplate(
|
||||||
batch_answers = loop.run_until_complete(process_chains())
|
template = template_merge_prompt,
|
||||||
|
input_variables=["context", "question"],
|
||||||
|
partial_variables={"format_instructions": format_instructions},
|
||||||
|
)
|
||||||
|
|
||||||
# Merge batch results (assuming same structure)
|
merge_chain = merge_prompt | self.llm_model | output_parser
|
||||||
merged_answer = merge_results(batch_answers)
|
answer = merge_chain.invoke({"context": batch_results, "question": user_prompt})
|
||||||
answers = merged_answer
|
|
||||||
|
|
||||||
state.update({self.output[0]: answers})
|
state.update({self.output[0]: answer})
|
||||||
return state
|
return state
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user