feat: add generate_answer node parallelization
Some checks are pending
/ build (push) Waiting to run

Co-Authored-By: Federico Minutoli <40361744+DiTo97@users.noreply.github.com>
This commit is contained in:
Marco Vinciguerra 2024-07-22 19:58:33 +02:00
parent 2ae19aee56
commit 0c4b2908d9

View File

@ -1,13 +1,12 @@
""" """
GenerateAnswerNode Module GenerateAnswerNode Module
""" """
import asyncio
from typing import List, Optional from typing import List, Optional
from langchain.prompts import PromptTemplate from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import JsonOutputParser from langchain_core.output_parsers import JsonOutputParser
from langchain_core.runnables import RunnableParallel from langchain_core.runnables import RunnableParallel
from tqdm import tqdm from tqdm import tqdm
import asyncio
from ..utils.merge_results import merge_results from ..utils.merge_results import merge_results
from ..utils.logging import get_logger from ..utils.logging import get_logger
from ..models import Ollama, OpenAI from ..models import Ollama, OpenAI
@ -136,21 +135,18 @@ class GenerateAnswerNode(BaseNode):
chain_name = f"chunk{i+1}" chain_name = f"chunk{i+1}"
chains_dict[chain_name] = prompt | self.llm_model | output_parser chains_dict[chain_name] = prompt | self.llm_model | output_parser
async_runner = RunnableParallel(**chains_dict)
async def process_chains(): batch_results = async_runner.invoke({"question": user_prompt})
async_runner = RunnableParallel()
for chain_name, chain in chains_dict.items():
async_runner.add(chain.ainvoke([{"question": user_prompt}] * len(doc)))
batch_results = await async_runner.run()
return batch_results
loop = asyncio.get_event_loop() merge_prompt = PromptTemplate(
batch_answers = loop.run_until_complete(process_chains()) template = template_merge_prompt,
input_variables=["context", "question"],
partial_variables={"format_instructions": format_instructions},
)
# Merge batch results (assuming same structure) merge_chain = merge_prompt | self.llm_model | output_parser
merged_answer = merge_results(batch_answers) answer = merge_chain.invoke({"context": batch_results, "question": user_prompt})
answers = merged_answer
state.update({self.output[0]: answers}) state.update({self.output[0]: answer})
return state return state