mirror of
https://github.com/VinciGit00/Scrapegraph-ai.git
synced 2026-06-28 21:01:55 +08:00
fix: generate answer node timeout
This commit is contained in:
parent
86bf4f2402
commit
32ef5547f1
@ -52,9 +52,6 @@ class GenerateAnswerNode(BaseNode):
|
|||||||
super().__init__(node_name, "node", input, output, 2, node_config)
|
super().__init__(node_name, "node", input, output, 2, node_config)
|
||||||
self.llm_model = node_config["llm_model"]
|
self.llm_model = node_config["llm_model"]
|
||||||
|
|
||||||
if hasattr(self.llm_model, 'request_timeout'):
|
|
||||||
self.llm_model.request_timeout = node_config.get("timeout", 30)
|
|
||||||
|
|
||||||
if isinstance(node_config["llm_model"], ChatOllama):
|
if isinstance(node_config["llm_model"], ChatOllama):
|
||||||
self.llm_model.format = "json"
|
self.llm_model.format = "json"
|
||||||
|
|
||||||
@ -63,7 +60,22 @@ class GenerateAnswerNode(BaseNode):
|
|||||||
self.script_creator = node_config.get("script_creator", False)
|
self.script_creator = node_config.get("script_creator", False)
|
||||||
self.is_md_scraper = node_config.get("is_md_scraper", False)
|
self.is_md_scraper = node_config.get("is_md_scraper", False)
|
||||||
self.additional_info = node_config.get("additional_info")
|
self.additional_info = node_config.get("additional_info")
|
||||||
self.timeout = node_config.get("timeout", 30)
|
self.timeout = node_config.get("timeout", 120)
|
||||||
|
|
||||||
|
def invoke_with_timeout(self, chain, inputs, timeout):
|
||||||
|
"""Helper method to invoke chain with timeout"""
|
||||||
|
try:
|
||||||
|
start_time = time.time()
|
||||||
|
response = chain.invoke(inputs)
|
||||||
|
if time.time() - start_time > timeout:
|
||||||
|
raise Timeout(f"Response took longer than {timeout} seconds")
|
||||||
|
return response
|
||||||
|
except Timeout as e:
|
||||||
|
self.logger.error(f"Timeout error: {str(e)}")
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Error during chain execution: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
def execute(self, state: dict) -> dict:
|
def execute(self, state: dict) -> dict:
|
||||||
"""
|
"""
|
||||||
@ -119,21 +131,6 @@ class GenerateAnswerNode(BaseNode):
|
|||||||
template_chunks_prompt = self.additional_info + template_chunks_prompt
|
template_chunks_prompt = self.additional_info + template_chunks_prompt
|
||||||
template_merge_prompt = self.additional_info + template_merge_prompt
|
template_merge_prompt = self.additional_info + template_merge_prompt
|
||||||
|
|
||||||
def invoke_with_timeout(chain, inputs, timeout):
|
|
||||||
try:
|
|
||||||
with get_openai_callback() as cb:
|
|
||||||
start_time = time.time()
|
|
||||||
response = chain.invoke(inputs)
|
|
||||||
if time.time() - start_time > timeout:
|
|
||||||
raise Timeout(f"Response took longer than {timeout} seconds")
|
|
||||||
return response
|
|
||||||
except Timeout as e:
|
|
||||||
self.logger.error(f"Timeout error: {str(e)}")
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
|
||||||
self.logger.error(f"Error during chain execution: {str(e)}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
if len(doc) == 1:
|
if len(doc) == 1:
|
||||||
prompt = PromptTemplate(
|
prompt = PromptTemplate(
|
||||||
template=template_no_chunks_prompt,
|
template=template_no_chunks_prompt,
|
||||||
@ -141,17 +138,15 @@ class GenerateAnswerNode(BaseNode):
|
|||||||
partial_variables={"context": doc, "format_instructions": format_instructions}
|
partial_variables={"context": doc, "format_instructions": format_instructions}
|
||||||
)
|
)
|
||||||
chain = prompt | self.llm_model
|
chain = prompt | self.llm_model
|
||||||
|
if output_parser:
|
||||||
|
chain = chain | output_parser
|
||||||
|
|
||||||
try:
|
try:
|
||||||
raw_response = invoke_with_timeout(chain, {"question": user_prompt}, self.timeout)
|
answer = self.invoke_with_timeout(chain, {"question": user_prompt}, self.timeout)
|
||||||
except Timeout:
|
except Timeout:
|
||||||
state.update({self.output[0]: {"error": "Response timeout exceeded"}})
|
state.update({self.output[0]: {"error": "Response timeout exceeded"}})
|
||||||
return state
|
return state
|
||||||
|
|
||||||
if output_parser:
|
|
||||||
chain = chain | output_parser
|
|
||||||
|
|
||||||
answer = chain.invoke({"question": user_prompt})
|
|
||||||
state.update({self.output[0]: answer})
|
state.update({self.output[0]: answer})
|
||||||
return state
|
return state
|
||||||
|
|
||||||
@ -171,9 +166,9 @@ class GenerateAnswerNode(BaseNode):
|
|||||||
|
|
||||||
async_runner = RunnableParallel(**chains_dict)
|
async_runner = RunnableParallel(**chains_dict)
|
||||||
try:
|
try:
|
||||||
batch_results = invoke_with_timeout(
|
batch_results = self.invoke_with_timeout(
|
||||||
async_runner,
|
async_runner,
|
||||||
{"question": user_prompt},
|
{"question": user_prompt},
|
||||||
self.timeout
|
self.timeout
|
||||||
)
|
)
|
||||||
except Timeout:
|
except Timeout:
|
||||||
@ -190,7 +185,7 @@ class GenerateAnswerNode(BaseNode):
|
|||||||
if output_parser:
|
if output_parser:
|
||||||
merge_chain = merge_chain | output_parser
|
merge_chain = merge_chain | output_parser
|
||||||
try:
|
try:
|
||||||
answer = invoke_with_timeout(
|
answer = self.invoke_with_timeout(
|
||||||
merge_chain,
|
merge_chain,
|
||||||
{"context": batch_results, "question": user_prompt},
|
{"context": batch_results, "question": user_prompt},
|
||||||
self.timeout
|
self.timeout
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user