feat: add generate_answer node paralellization
Some checks are pending
/ build (push) Waiting to run

Co-Authored-By: Federico Minutoli <40361744+DiTo97@users.noreply.github.com>
This commit is contained in:
Marco Vinciguerra 2024-07-22 19:58:33 +02:00
parent 2ae19aee56
commit 0c4b2908d9

View File

@ -1,13 +1,12 @@
"""
GenerateAnswerNode Module
"""
import asyncio
from typing import List, Optional
from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.runnables import RunnableParallel
from tqdm import tqdm
import asyncio
from ..utils.merge_results import merge_results
from ..utils.logging import get_logger
from ..models import Ollama, OpenAI
@ -136,21 +135,18 @@ class GenerateAnswerNode(BaseNode):
chain_name = f"chunk{i+1}"
chains_dict[chain_name] = prompt | self.llm_model | output_parser
async_runner = RunnableParallel(**chains_dict)
async def process_chains():
async_runner = RunnableParallel()
for chain_name, chain in chains_dict.items():
async_runner.add(chain.ainvoke([{"question": user_prompt}] * len(doc)))
batch_results = await async_runner.run()
return batch_results
batch_results = async_runner.invoke({"question": user_prompt})
loop = asyncio.get_event_loop()
batch_answers = loop.run_until_complete(process_chains())
merge_prompt = PromptTemplate(
template = template_merge_prompt,
input_variables=["context", "question"],
partial_variables={"format_instructions": format_instructions},
)
# Merge batch results (assuming same structure)
merged_answer = merge_results(batch_answers)
answers = merged_answer
merge_chain = merge_prompt | self.llm_model | output_parser
answer = merge_chain.invoke({"context": batch_results, "question": user_prompt})
state.update({self.output[0]: answers})
state.update({self.output[0]: answer})
return state