fix: generate answer node

This commit is contained in:
Marco Vinciguerra 2024-10-13 10:25:36 +02:00
parent 54b37bbda5
commit 431b2093be

View File

@ -57,7 +57,7 @@ class GenerateAnswerNode(BaseNode):
self.is_md_scraper = node_config.get("is_md_scraper", False)
self.additional_info = node_config.get("additional_info")
async def execute(self, state: dict) -> dict:
def execute(self, state: dict) -> dict:
"""
Executes the GenerateAnswerNode.
@ -123,7 +123,7 @@ class GenerateAnswerNode(BaseNode):
chain = prompt | self.llm_model
if output_parser:
chain = chain | output_parser
answer = await chain.ainvoke({"question": user_prompt})
answer = chain.invoke({"question": user_prompt})
state.update({self.output[0]: answer})
return state
@ -143,7 +143,7 @@ class GenerateAnswerNode(BaseNode):
chains_dict[chain_name] = chains_dict[chain_name] | output_parser
async_runner = RunnableParallel(**chains_dict)
batch_results = await async_runner.ainvoke({"question": user_prompt})
batch_results = async_runner.invoke({"question": user_prompt})
merge_prompt = PromptTemplate(
template=template_merge_prompt,
@ -154,7 +154,7 @@ class GenerateAnswerNode(BaseNode):
merge_chain = merge_prompt | self.llm_model
if output_parser:
merge_chain = merge_chain | output_parser
answer = await merge_chain.ainvoke({"context": batch_results, "question": user_prompt})
answer = merge_chain.invoke({"context": batch_results, "question": user_prompt})
state.update({self.output[0]: answer})
return state