mirror of
https://github.com/VinciGit00/Scrapegraph-ai.git
synced 2026-06-28 21:01:55 +08:00
feat: refactoring of generate answer node
This commit is contained in:
parent
3e8c043473
commit
1f465e636d
@ -16,6 +16,10 @@ from ..prompts import (
|
||||
TEMPLATE_CHUNKS, TEMPLATE_NO_CHUNKS, TEMPLATE_MERGE,
|
||||
TEMPLATE_CHUNKS_MD, TEMPLATE_NO_CHUNKS_MD, TEMPLATE_MERGE_MD
|
||||
)
|
||||
from langchain.callbacks.manager import CallbackManager
|
||||
from langchain.callbacks import get_openai_callback
|
||||
from requests.exceptions import Timeout
|
||||
import time
|
||||
|
||||
class GenerateAnswerNode(BaseNode):
|
||||
"""
|
||||
@ -56,6 +60,7 @@ class GenerateAnswerNode(BaseNode):
|
||||
self.script_creator = node_config.get("script_creator", False)
|
||||
self.is_md_scraper = node_config.get("is_md_scraper", False)
|
||||
self.additional_info = node_config.get("additional_info")
|
||||
self.timeout = node_config.get("timeout", 30)
|
||||
|
||||
def execute(self, state: dict) -> dict:
|
||||
"""
|
||||
@ -114,6 +119,21 @@ class GenerateAnswerNode(BaseNode):
|
||||
template_chunks_prompt = self.additional_info + template_chunks_prompt
|
||||
template_merge_prompt = self.additional_info + template_merge_prompt
|
||||
|
||||
def invoke_with_timeout(chain, inputs, timeout):
|
||||
try:
|
||||
with get_openai_callback() as cb:
|
||||
start_time = time.time()
|
||||
response = chain.invoke(inputs)
|
||||
if time.time() - start_time > timeout:
|
||||
raise Timeout(f"Response took longer than {timeout} seconds")
|
||||
return response
|
||||
except Timeout as e:
|
||||
self.logger.error(f"Timeout error: {str(e)}")
|
||||
raise
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error during chain execution: {str(e)}")
|
||||
raise
|
||||
|
||||
if len(doc) == 1:
|
||||
prompt = PromptTemplate(
|
||||
template=template_no_chunks_prompt,
|
||||
@ -121,7 +141,11 @@ class GenerateAnswerNode(BaseNode):
|
||||
partial_variables={"context": doc, "format_instructions": format_instructions}
|
||||
)
|
||||
chain = prompt | self.llm_model
|
||||
raw_response = chain.invoke({"question": user_prompt})
|
||||
try:
|
||||
raw_response = invoke_with_timeout(chain, {"question": user_prompt}, self.timeout)
|
||||
except Timeout:
|
||||
state.update({self.output[0]: {"error": "Response timeout exceeded"}})
|
||||
return state
|
||||
|
||||
if output_parser:
|
||||
try:
|
||||
@ -155,7 +179,15 @@ class GenerateAnswerNode(BaseNode):
|
||||
chains_dict[chain_name] = chains_dict[chain_name] | output_parser
|
||||
|
||||
async_runner = RunnableParallel(**chains_dict)
|
||||
batch_results = async_runner.invoke({"question": user_prompt})
|
||||
try:
|
||||
batch_results = invoke_with_timeout(
|
||||
async_runner,
|
||||
{"question": user_prompt},
|
||||
self.timeout
|
||||
)
|
||||
except Timeout:
|
||||
state.update({self.output[0]: {"error": "Response timeout exceeded during chunk processing"}})
|
||||
return state
|
||||
|
||||
merge_prompt = PromptTemplate(
|
||||
template=template_merge_prompt,
|
||||
@ -166,7 +198,15 @@ class GenerateAnswerNode(BaseNode):
|
||||
merge_chain = merge_prompt | self.llm_model
|
||||
if output_parser:
|
||||
merge_chain = merge_chain | output_parser
|
||||
answer = merge_chain.invoke({"context": batch_results, "question": user_prompt})
|
||||
try:
|
||||
answer = invoke_with_timeout(
|
||||
merge_chain,
|
||||
{"context": batch_results, "question": user_prompt},
|
||||
self.timeout
|
||||
)
|
||||
except Timeout:
|
||||
state.update({self.output[0]: {"error": "Response timeout exceeded during merge"}})
|
||||
return state
|
||||
|
||||
state.update({self.output[0]: answer})
|
||||
return state
|
||||
|
||||
Loading…
Reference in New Issue
Block a user