diff --git a/scrapegraphai/nodes/generate_answer_node.py b/scrapegraphai/nodes/generate_answer_node.py index 1fa38995..d1a42965 100644 --- a/scrapegraphai/nodes/generate_answer_node.py +++ b/scrapegraphai/nodes/generate_answer_node.py @@ -52,9 +52,6 @@ class GenerateAnswerNode(BaseNode): super().__init__(node_name, "node", input, output, 2, node_config) self.llm_model = node_config["llm_model"] - if hasattr(self.llm_model, 'request_timeout'): - self.llm_model.request_timeout = node_config.get("timeout", 30) - if isinstance(node_config["llm_model"], ChatOllama): self.llm_model.format = "json" @@ -63,7 +60,22 @@ class GenerateAnswerNode(BaseNode): self.script_creator = node_config.get("script_creator", False) self.is_md_scraper = node_config.get("is_md_scraper", False) self.additional_info = node_config.get("additional_info") - self.timeout = node_config.get("timeout", 30) + self.timeout = node_config.get("timeout", 120) + + def invoke_with_timeout(self, chain, inputs, timeout): + """Helper method to invoke chain with timeout""" + try: + start_time = time.time() + response = chain.invoke(inputs) + if time.time() - start_time > timeout: + raise Timeout(f"Response took longer than {timeout} seconds") + return response + except Timeout as e: + self.logger.error(f"Timeout error: {str(e)}") + raise + except Exception as e: + self.logger.error(f"Error during chain execution: {str(e)}") + raise def execute(self, state: dict) -> dict: """ @@ -119,21 +131,6 @@ class GenerateAnswerNode(BaseNode): template_chunks_prompt = self.additional_info + template_chunks_prompt template_merge_prompt = self.additional_info + template_merge_prompt - def invoke_with_timeout(chain, inputs, timeout): - try: - with get_openai_callback() as cb: - start_time = time.time() - response = chain.invoke(inputs) - if time.time() - start_time > timeout: - raise Timeout(f"Response took longer than {timeout} seconds") - return response - except Timeout as e: - self.logger.error(f"Timeout error: {str(e)}") - raise - except Exception as e: - self.logger.error(f"Error during chain execution: {str(e)}") - raise - if len(doc) == 1: prompt = PromptTemplate( template=template_no_chunks_prompt, @@ -141,17 +138,15 @@ class GenerateAnswerNode(BaseNode): partial_variables={"context": doc, "format_instructions": format_instructions} ) chain = prompt | self.llm_model + if output_parser: + chain = chain | output_parser try: - raw_response = invoke_with_timeout(chain, {"question": user_prompt}, self.timeout) + answer = self.invoke_with_timeout(chain, {"question": user_prompt}, self.timeout) except Timeout: state.update({self.output[0]: {"error": "Response timeout exceeded"}}) return state - if output_parser: - chain = chain | output_parser - - answer = chain.invoke({"question": user_prompt}) state.update({self.output[0]: answer}) return state @@ -171,9 +166,9 @@ class GenerateAnswerNode(BaseNode): async_runner = RunnableParallel(**chains_dict) try: - batch_results = invoke_with_timeout( - async_runner, - {"question": user_prompt}, + batch_results = self.invoke_with_timeout( + async_runner, + {"question": user_prompt}, self.timeout ) except Timeout: @@ -190,7 +185,7 @@ class GenerateAnswerNode(BaseNode): if output_parser: merge_chain = merge_chain | output_parser try: - answer = invoke_with_timeout( + answer = self.invoke_with_timeout( merge_chain, {"context": batch_results, "question": user_prompt}, self.timeout