From f7ba1f30de72824c38cc4b4cc8e22020d4ad7a7e Mon Sep 17 00:00:00 2001 From: Marco Vinciguerra Date: Fri, 23 Aug 2024 11:33:22 +0200 Subject: [PATCH] refactoring of the code --- scrapegraphai/graphs/abstract_graph.py | 9 ++++---- .../nodes/generate_answer_csv_node.py | 23 ++++++++++--------- scrapegraphai/nodes/generate_scraper_node.py | 2 -- scrapegraphai/nodes/get_probable_tags_node.py | 4 ---- scrapegraphai/nodes/graph_iterator_node.py | 3 --- scrapegraphai/nodes/merge_answers_node.py | 5 ---- scrapegraphai/nodes/parse_node.py | 5 +--- scrapegraphai/nodes/rag_node.py | 12 +--------- scrapegraphai/nodes/robots_node.py | 2 -- scrapegraphai/nodes/search_internet_node.py | 7 +----- scrapegraphai/nodes/search_link_node.py | 20 ++++++---------- .../nodes/search_node_with_context.py | 3 --- scrapegraphai/nodes/text_to_speech_node.py | 7 ++---- scrapegraphai/utils/cleanup_html.py | 9 +------- scrapegraphai/utils/convert_to_csv.py | 7 +++--- scrapegraphai/utils/convert_to_json.py | 6 ++--- scrapegraphai/utils/convert_to_md.py | 3 ++- scrapegraphai/utils/logging.py | 7 +++--- scrapegraphai/utils/parse_state_keys.py | 13 ----------- scrapegraphai/utils/sys_dynamic_import.py | 5 ++-- 20 files changed, 44 insertions(+), 108 deletions(-) diff --git a/scrapegraphai/graphs/abstract_graph.py b/scrapegraphai/graphs/abstract_graph.py index 8434a0b6..27b0e3e6 100644 --- a/scrapegraphai/graphs/abstract_graph.py +++ b/scrapegraphai/graphs/abstract_graph.py @@ -7,8 +7,6 @@ from typing import Optional import uuid import warnings from pydantic import BaseModel -from langchain_community.chat_models import ErnieBotChat -from langchain_nvidia_ai_endpoints import ChatNVIDIA from langchain.chat_models import init_chat_model from ..helpers import models_tokens from ..models import ( @@ -147,8 +145,7 @@ class AbstractGraph(ABC): warnings.simplefilter("ignore") return init_chat_model(**llm_params) - known_models = ["chatgpt","gpt","openai", "azure_openai", "google_genai", "ollama", "oneapi", "nvidia", "groq", "google_vertexai", "bedrock", "mistralai", "hugging_face", "deepseek", "ernie", "fireworks"] - + known_models = {"chatgpt","gpt","openai", "azure_openai", "google_genai", "ollama", "oneapi", "nvidia", "groq", "google_vertexai", "bedrock", "mistralai", "hugging_face", "deepseek", "ernie", "fireworks"} if llm_params["model"].split("/")[0] not in known_models and llm_params["model"].split("-")[0] not in known_models: raise ValueError(f"Model '{llm_params['model']}' is not supported") @@ -198,6 +195,8 @@ class AbstractGraph(ABC): return DeepSeek(llm_params) elif "ernie" in llm_params["model"]: + from langchain_community.chat_models import ErnieBotChat + try: self.model_token = models_tokens["ernie"][llm_params["model"]] except KeyError: @@ -215,6 +214,8 @@ class AbstractGraph(ABC): return OneApi(llm_params) elif "nvidia" in llm_params["model"]: + from langchain_nvidia_ai_endpoints import ChatNVIDIA + try: self.model_token = models_tokens["nvidia"][llm_params["model"].split("/")[-1]] llm_params["model"] = "/".join(llm_params["model"].split("/")[1:]) diff --git a/scrapegraphai/nodes/generate_answer_csv_node.py b/scrapegraphai/nodes/generate_answer_csv_node.py index b7d7471a..0907dfb9 100644 --- a/scrapegraphai/nodes/generate_answer_csv_node.py +++ b/scrapegraphai/nodes/generate_answer_csv_node.py @@ -9,7 +9,8 @@ from langchain_core.runnables import RunnableParallel from tqdm import tqdm from ..utils.logging import get_logger from .base_node import BaseNode -from ..prompts.generate_answer_node_csv_prompts import TEMPLATE_CHUKS_CSV, TEMPLATE_NO_CHUKS_CSV, TEMPLATE_MERGE_CSV +from ..prompts.generate_answer_node_csv_prompts import (TEMPLATE_CHUKS_CSV, + TEMPLATE_NO_CHUKS_CSV, TEMPLATE_MERGE_CSV) class GenerateAnswerCSVNode(BaseNode): """ @@ -95,14 +96,14 @@ class GenerateAnswerCSVNode(BaseNode): else: output_parser = JsonOutputParser() - TEMPLATE_NO_CHUKS_CSV_prompt = TEMPLATE_NO_CHUKS_CSV - TEMPLATE_CHUKS_CSV_prompt = TEMPLATE_CHUKS_CSV - TEMPLATE_MERGE_CSV_prompt = TEMPLATE_MERGE_CSV + TEMPLATE_NO_CHUKS_CSV_PROMPT = TEMPLATE_NO_CHUKS_CSV + TEMPLATE_CHUKS_CSV_PROMPT = TEMPLATE_CHUKS_CSV + TEMPLATE_MERGE_CSV_PROMPT = TEMPLATE_MERGE_CSV if self.additional_info is not None: - TEMPLATE_NO_CHUKS_CSV_prompt = self.additional_info + TEMPLATE_NO_CHUKS_CSV - TEMPLATE_CHUKS_CSV_prompt = self.additional_info + TEMPLATE_CHUKS_CSV - TEMPLATE_MERGE_CSV_prompt = self.additional_info + TEMPLATE_MERGE_CSV + TEMPLATE_NO_CHUKS_CSV_PROMPT = self.additional_info + TEMPLATE_NO_CHUKS_CSV + TEMPLATE_CHUKS_CSV_PROMPT = self.additional_info + TEMPLATE_CHUKS_CSV + TEMPLATE_MERGE_CSV_PROMPT = self.additional_info + TEMPLATE_MERGE_CSV format_instructions = output_parser.get_format_instructions() @@ -110,7 +111,7 @@ class GenerateAnswerCSVNode(BaseNode): if len(doc) == 1: prompt = PromptTemplate( - template=TEMPLATE_NO_CHUKS_CSV_prompt, + template=TEMPLATE_NO_CHUKS_CSV_PROMPT, input_variables=["question"], partial_variables={ "context": doc, @@ -127,7 +128,7 @@ class GenerateAnswerCSVNode(BaseNode): tqdm(doc, desc="Processing chunks", disable=not self.verbose) ): prompt = PromptTemplate( - template=TEMPLATE_CHUKS_CSV_prompt, + template=TEMPLATE_CHUKS_CSV_PROMPT, input_variables=["question"], partial_variables={ "context": chunk, @@ -144,7 +145,7 @@ class GenerateAnswerCSVNode(BaseNode): batch_results = async_runner.invoke({"question": user_prompt}) merge_prompt = PromptTemplate( - template = TEMPLATE_MERGE_CSV_prompt, + template = TEMPLATE_MERGE_CSV_PROMPT, input_variables=["context", "question"], partial_variables={"format_instructions": format_instructions}, ) @@ -153,4 +154,4 @@ class GenerateAnswerCSVNode(BaseNode): answer = merge_chain.invoke({"context": batch_results, "question": user_prompt}) state.update({self.output[0]: answer}) - return state \ No newline at end of file + return state diff --git a/scrapegraphai/nodes/generate_scraper_node.py b/scrapegraphai/nodes/generate_scraper_node.py index 4f077091..a7c5e5bb 100644 --- a/scrapegraphai/nodes/generate_scraper_node.py +++ b/scrapegraphai/nodes/generate_scraper_node.py @@ -67,10 +67,8 @@ class GenerateScraperNode(BaseNode): self.logger.info(f"--- Executing {self.node_name} Node ---") - # Interpret input keys based on the provided input expression input_keys = self.get_input_keys(state) - # Fetching data from the state based on the input keys input_data = [state[key] for key in input_keys] user_prompt = input_data[0] diff --git a/scrapegraphai/nodes/get_probable_tags_node.py b/scrapegraphai/nodes/get_probable_tags_node.py index 4d12b985..9ba38283 100644 --- a/scrapegraphai/nodes/get_probable_tags_node.py +++ b/scrapegraphai/nodes/get_probable_tags_node.py @@ -58,10 +58,8 @@ class GetProbableTagsNode(BaseNode): self.logger.info(f"--- Executing {self.node_name} Node ---") - # Interpret input keys based on the provided input expression input_keys = self.get_input_keys(state) - # Fetching data from the state based on the input keys input_data = [state[key] for key in input_keys] user_prompt = input_data[0] @@ -88,10 +86,8 @@ class GetProbableTagsNode(BaseNode): }, ) - # Execute the chain to get probable tags tag_answer = tag_prompt | self.llm_model | output_parser probable_tags = tag_answer.invoke({"question": user_prompt}) - # Update the dictionary with probable tags state.update({self.output[0]: probable_tags}) return state diff --git a/scrapegraphai/nodes/graph_iterator_node.py b/scrapegraphai/nodes/graph_iterator_node.py index cd71b8e1..a765da28 100644 --- a/scrapegraphai/nodes/graph_iterator_node.py +++ b/scrapegraphai/nodes/graph_iterator_node.py @@ -103,7 +103,6 @@ class GraphIteratorNode(BaseNode): if graph_instance is None: raise ValueError("graph instance is required for concurrent execution") - # Assign depth level to the graph if "graph_depth" in graph_instance.config: graph_instance.config["graph_depth"] += 1 else: @@ -113,14 +112,12 @@ class GraphIteratorNode(BaseNode): participants = [] - # semaphore to limit the number of concurrent tasks semaphore = asyncio.Semaphore(batchsize) async def _async_run(graph): async with semaphore: return await asyncio.to_thread(graph.run) - # creates a deepcopy of the graph instance for each endpoint for url in urls: instance = copy.copy(graph_instance) instance.source = url diff --git a/scrapegraphai/nodes/merge_answers_node.py b/scrapegraphai/nodes/merge_answers_node.py index f00880e9..f2559a09 100644 --- a/scrapegraphai/nodes/merge_answers_node.py +++ b/scrapegraphai/nodes/merge_answers_node.py @@ -56,21 +56,17 @@ class MergeAnswersNode(BaseNode): self.logger.info(f"--- Executing {self.node_name} Node ---") - # Interpret input keys based on the provided input expression input_keys = self.get_input_keys(state) - # Fetching data from the state based on the input keys input_data = [state[key] for key in input_keys] user_prompt = input_data[0] answers = input_data[1] - # merge the answers in one string answers_str = "" for i, answer in enumerate(answers): answers_str += f"CONTENT WEBSITE {i+1}: {answer}\n" - # Initialize the output parser if self.node_config.get("schema", None) is not None: output_parser = JsonOutputParser(pydantic_object=self.node_config["schema"]) else: @@ -90,6 +86,5 @@ class MergeAnswersNode(BaseNode): merge_chain = prompt_template | self.llm_model | output_parser answer = merge_chain.invoke({"user_prompt": user_prompt}) - # Update the state with the generated answer state.update({self.output[0]: answer}) return state diff --git a/scrapegraphai/nodes/parse_node.py b/scrapegraphai/nodes/parse_node.py index dbecdbf9..3e8ed5ac 100644 --- a/scrapegraphai/nodes/parse_node.py +++ b/scrapegraphai/nodes/parse_node.py @@ -59,13 +59,11 @@ class ParseNode(BaseNode): self.logger.info(f"--- Executing {self.node_name} Node ---") - # Interpret input keys based on the provided input expression input_keys = self.get_input_keys(state) - # Fetching data from the state based on the input keys input_data = [state[key] for key in input_keys] - # Parse the document docs_transformed = input_data[0] + if self.parse_html: docs_transformed = Html2TextTransformer().transform_documents(input_data[0]) docs_transformed = docs_transformed[0] @@ -77,7 +75,6 @@ class ParseNode(BaseNode): else: docs_transformed = docs_transformed[0] - # Adapt the chunk size, leaving room for the reply, the prompt and the schema chunk_size = self.node_config.get("chunk_size", 4096) chunk_size = min(chunk_size - 500, int(chunk_size * 0.9)) diff --git a/scrapegraphai/nodes/rag_node.py b/scrapegraphai/nodes/rag_node.py index ea5efe7a..868044a0 100644 --- a/scrapegraphai/nodes/rag_node.py +++ b/scrapegraphai/nodes/rag_node.py @@ -80,10 +80,8 @@ class RAGNode(BaseNode): self.logger.info(f"--- Executing {self.node_name} Node ---") - # Interpret input keys based on the provided input expression input_keys = self.get_input_keys(state) - # Fetching data from the state based on the input keys input_data = [state[key] for key in input_keys] user_prompt = input_data[0] @@ -102,7 +100,6 @@ class RAGNode(BaseNode): self.logger.info("--- (updated chunks metadata) ---") - # check if embedder_model is provided, if not use llm_model if self.embedder_model is not None: embeddings = self.embedder_model elif 'embeddings' in self.node_config: @@ -144,23 +141,17 @@ class RAGNode(BaseNode): pipeline_compressor = DocumentCompressorPipeline( transformers=[redundant_filter, relevant_filter] ) - # redundant + relevant filter compressor compression_retriever = ContextualCompressionRetriever( base_compressor=pipeline_compressor, base_retriever=retriever ) - # relevant filter compressor only - # compression_retriever = ContextualCompressionRetriever( - # base_compressor=relevant_filter, base_retriever=retriever - # ) - compressed_docs = compression_retriever.invoke(user_prompt) self.logger.info("--- (tokens compressed and vector stored) ---") state.update({self.output[0]: compressed_docs}) return state - + def _create_default_embedder(self, llm_config=None) -> object: """ @@ -223,7 +214,6 @@ class RAGNode(BaseNode): embedder_params = {**embedder_config} if "model_instance" in embedder_config: return embedder_params["model_instance"] - # Instantiate the embedding model based on the model name if "openai" in embedder_params["model"]: return OpenAIEmbeddings(api_key=embedder_params["api_key"]) if "azure" in embedder_params["model"]: diff --git a/scrapegraphai/nodes/robots_node.py b/scrapegraphai/nodes/robots_node.py index b33d49c1..6f9bc352 100644 --- a/scrapegraphai/nodes/robots_node.py +++ b/scrapegraphai/nodes/robots_node.py @@ -75,10 +75,8 @@ class RobotsNode(BaseNode): self.logger.info(f"--- Executing {self.node_name} Node ---") - # Interpret input keys based on the provided input expression input_keys = self.get_input_keys(state) - # Fetching data from the state based on the input keys input_data = [state[key] for key in input_keys] source = input_data[0] diff --git a/scrapegraphai/nodes/search_internet_node.py b/scrapegraphai/nodes/search_internet_node.py index b23e8e8b..df1b6277 100644 --- a/scrapegraphai/nodes/search_internet_node.py +++ b/scrapegraphai/nodes/search_internet_node.py @@ -67,7 +67,6 @@ class SearchInternetNode(BaseNode): input_keys = self.get_input_keys(state) - # Fetching data from the state based on the input keys input_data = [state[key] for key in input_keys] user_prompt = input_data[0] @@ -79,10 +78,8 @@ class SearchInternetNode(BaseNode): input_variables=["user_prompt"], ) - # Execute the chain to get the search query search_answer = search_prompt | self.llm_model | output_parser - - # Ollama: Use no json format when creating the search query + if isinstance(self.llm_model, ChatOllama) and self.llm_model.format == 'json': self.llm_model.format = None search_query = search_answer.invoke({"user_prompt": user_prompt})[0] @@ -96,9 +93,7 @@ class SearchInternetNode(BaseNode): search_engine=self.search_engine) if len(answer) == 0: - # raise an exception if no answer is found raise ValueError("Zero results found for the search query.") - # Update the state with the generated answer state.update({self.output[0]: answer}) return state diff --git a/scrapegraphai/nodes/search_link_node.py b/scrapegraphai/nodes/search_link_node.py index c39c469d..60c3e1aa 100644 --- a/scrapegraphai/nodes/search_link_node.py +++ b/scrapegraphai/nodes/search_link_node.py @@ -49,7 +49,6 @@ class SearchLinkNode(BaseNode): self.filter_config = {**default_filters.filter_dict, **provided_filter_config} self.filter_links = True else: - # Skip filtering if not enabled self.filter_config = None self.filter_links = False @@ -58,29 +57,26 @@ class SearchLinkNode(BaseNode): def _is_same_domain(self, url, domain): if not self.filter_links or not self.filter_config.get("diff_domain_filter", True): - return True # Skip the domain filter if not enabled + return True parsed_url = urlparse(url) parsed_domain = urlparse(domain) return parsed_url.netloc == parsed_domain.netloc def _is_image_url(self, url): if not self.filter_links: - return False # Skip image filtering if filtering is not enabled - + return False image_extensions = self.filter_config.get("img_exts", []) return any(url.lower().endswith(ext) for ext in image_extensions) def _is_language_url(self, url): if not self.filter_links: - return False # Skip language filtering if filtering is not enabled + return False lang_indicators = self.filter_config.get("lang_indicators", []) parsed_url = urlparse(url) query_params = parse_qs(parsed_url.query) - # Check if the URL path or query string indicates a language-specific version return any(indicator in parsed_url.path.lower() or indicator in query_params for indicator in lang_indicators) - def _is_potentially_irrelevant(self, url): if not self.filter_links: return False # Skip irrelevant URL filtering if filtering is not enabled @@ -88,12 +84,11 @@ class SearchLinkNode(BaseNode): irrelevant_keywords = self.filter_config.get("irrelevant_keywords", []) return any(keyword in url.lower() for keyword in irrelevant_keywords) - + def execute(self, state: dict) -> dict: """ - Filter out relevant links from the webpage that are relavant to prompt. Out of the filtered links, also - ensure that all links are navigable. - + Filter out relevant links from the webpage that are relavant to prompt. + Out of the filtered links, also ensure that all links are navigable. Args: state (dict): The current state of the graph. The input keys will be used to fetch the correct data types from the state. @@ -108,7 +103,6 @@ class SearchLinkNode(BaseNode): self.logger.info(f"--- Executing {self.node_name} Node ---") - parsed_content_chunks = state.get("doc") source_url = state.get("url") or state.get("local_dir") output_parser = JsonOutputParser() @@ -148,7 +142,7 @@ class SearchLinkNode(BaseNode): except Exception as e: # Fallback approach: Using the LLM to extract links self.logger.error(f"Error extracting links: {e}. Falling back to LLM.") - + merge_prompt = PromptTemplate( template=TEMPLATE_RELEVANT_LINKS, input_variables=["content", "user_prompt"], diff --git a/scrapegraphai/nodes/search_node_with_context.py b/scrapegraphai/nodes/search_node_with_context.py index 37a05d0f..7343b64c 100644 --- a/scrapegraphai/nodes/search_node_with_context.py +++ b/scrapegraphai/nodes/search_node_with_context.py @@ -58,10 +58,8 @@ class SearchLinksWithContext(BaseNode): self.logger.info(f"--- Executing {self.node_name} Node ---") - # Interpret input keys based on the provided input expression input_keys = self.get_input_keys(state) - # Fetching data from the state based on the input keys input_data = [state[key] for key in input_keys] doc = input_data[1] @@ -71,7 +69,6 @@ class SearchLinksWithContext(BaseNode): result = [] - # Use tqdm to add progress bar for i, chunk in enumerate( tqdm(doc, desc="Processing chunks", disable=not self.verbose) ): diff --git a/scrapegraphai/nodes/text_to_speech_node.py b/scrapegraphai/nodes/text_to_speech_node.py index e8e43cb5..dfa3a64e 100644 --- a/scrapegraphai/nodes/text_to_speech_node.py +++ b/scrapegraphai/nodes/text_to_speech_node.py @@ -43,7 +43,8 @@ class TextToSpeechNode(BaseNode): correct data types from the state. Returns: - dict: The updated state with the output key containing the audio generated from the text. + dict: The updated state with the output + key containing the audio generated from the text. Raises: KeyError: If the input keys are not found in the state, indicating that the @@ -52,15 +53,11 @@ class TextToSpeechNode(BaseNode): self.logger.info(f"--- Executing {self.node_name} Node ---") - # Interpret input keys based on the provided input expression input_keys = self.get_input_keys(state) - # Fetching data from the state based on the input keys input_data = [state[key] for key in input_keys] - # get the text to translate text2translate = str(next(iter(input_data[0].values()))) - # text2translate = str(input_data[0]) audio = self.tts_model.run(text2translate) diff --git a/scrapegraphai/utils/cleanup_html.py b/scrapegraphai/utils/cleanup_html.py index 23c9f803..6c7c3c4c 100644 --- a/scrapegraphai/utils/cleanup_html.py +++ b/scrapegraphai/utils/cleanup_html.py @@ -28,35 +28,28 @@ def cleanup_html(html_content: str, base_url: str) -> str: soup = BeautifulSoup(html_content, 'html.parser') - # Title Extraction title_tag = soup.find('title') title = title_tag.get_text() if title_tag else "" - # Script and Style Tag Removal for tag in soup.find_all(['script', 'style']): tag.extract() - # Links extraction link_urls = [urljoin(base_url, link['href']) for link in soup.find_all('a', href=True)] - # Images extraction images = soup.find_all('img') image_urls = [] for image in images: if 'src' in image.attrs: - # if http or https is not present in the image url, join it with the base url if 'http' not in image['src']: image_urls.append(urljoin(base_url, image['src'])) else: image_urls.append(image['src']) - # Body Extraction (if it exists) body_content = soup.find('body') if body_content: - # Minify the HTML within the body tag minimized_body = minify(str(body_content)) return title, minimized_body, link_urls, image_urls else: - raise ValueError(f"""No HTML body content found, please try setting the 'headless' + raise ValueError(f"""No HTML body content found, please try setting the 'headless' flag to False in the graph configuration. HTML content: {html_content}""") diff --git a/scrapegraphai/utils/convert_to_csv.py b/scrapegraphai/utils/convert_to_csv.py index 44897c7c..850f9416 100644 --- a/scrapegraphai/utils/convert_to_csv.py +++ b/scrapegraphai/utils/convert_to_csv.py @@ -29,9 +29,8 @@ def convert_to_csv(data: dict, filename: str, position: str = None) -> None: """ if ".csv" in filename: - filename = filename.replace(".csv", "") # Remove .csv extension + filename = filename.replace(".csv", "") - # Get the directory of the caller script if position is not provided if position is None: caller_dir = os.path.dirname(os.path.abspath(sys.argv[0])) position = caller_dir @@ -40,7 +39,7 @@ def convert_to_csv(data: dict, filename: str, position: str = None) -> None: if not isinstance(data, dict): raise TypeError("Input data must be a dictionary") - os.makedirs(position, exist_ok=True) # Create directory if needed + os.makedirs(position, exist_ok=True) df = pd.DataFrame.from_dict(data, orient='index') df.to_csv(os.path.join(position, f"{filename}.csv"), index=False) @@ -52,4 +51,4 @@ def convert_to_csv(data: dict, filename: str, position: str = None) -> None: raise PermissionError( f"You don't have permission to write to '{position}'.") from pe except Exception as e: - raise e # Re-raise other potential errors + raise e diff --git a/scrapegraphai/utils/convert_to_json.py b/scrapegraphai/utils/convert_to_json.py index 45b1ea55..4e1711f1 100644 --- a/scrapegraphai/utils/convert_to_json.py +++ b/scrapegraphai/utils/convert_to_json.py @@ -28,15 +28,15 @@ def convert_to_json(data: dict, filename: str, position: str = None) -> None: Saves a JSON file named 'output.json' at '/path/to/save'. Notes: - This function automatically ensures the directory exists before attempting to write the file. If the directory does not exist, it will attempt to create it. + This function automatically ensures the directory exists before + attempting to write the file. + If the directory does not exist, it will attempt to create it. """ if ".json" in filename: filename = filename.replace(".json", "") # Remove .json extension - # Get the directory of the caller script if position is None: - # Get directory of the main script caller_dir = os.path.dirname(os.path.abspath(sys.argv[0])) position = caller_dir diff --git a/scrapegraphai/utils/convert_to_md.py b/scrapegraphai/utils/convert_to_md.py index 123f3457..ff0bbbd7 100644 --- a/scrapegraphai/utils/convert_to_md.py +++ b/scrapegraphai/utils/convert_to_md.py @@ -18,7 +18,8 @@ def convert_to_md(html: str, url: str = None) -> str:

This is a heading.

") 'This is a paragraph.\n\n# This is a heading.' - Note: All the styles and links are ignored during the conversion. """ + Note: All the styles and links are ignored during the conversion. + """ h = html2text.HTML2Text() h.ignore_links = False diff --git a/scrapegraphai/utils/logging.py b/scrapegraphai/utils/logging.py index 335bcbf1..44f40aff 100644 --- a/scrapegraphai/utils/logging.py +++ b/scrapegraphai/utils/logging.py @@ -48,7 +48,6 @@ def _set_library_root_logger() -> None: DEFAULT_HANDLER = logging.StreamHandler() # sys.stderr as stream - # https://github.com/pyinstaller/pyinstaller/issues/7334#issuecomment-1357447176 if sys.stderr is None: sys.stderr = open(os.devnull, "w", encoding="utf-8") @@ -66,7 +65,8 @@ def get_logger(name: Optional[str] = None) -> logging.Logger: If no name is provided, the root logger for the library is returned. Args: - name (Optional[str]): The name of the logger. If None, the root logger for the library is returned. + name (Optional[str]): The name of the logger. + If None, the root logger for the library is returned. Returns: logging.Logger: The logger with the specified name. @@ -199,7 +199,8 @@ def warning_once(self, *args, **kwargs): """ Emit a warning log with the same message only once. - This function is added as a method to the logging.Logger class. It emits a warning log with the same message only once, + This function is added as a method to the logging.Logger class. + It emits a warning log with the same message only once, even if it is called multiple times with the same message. Args: diff --git a/scrapegraphai/utils/parse_state_keys.py b/scrapegraphai/utils/parse_state_keys.py index 107397e9..f4bd2ea5 100644 --- a/scrapegraphai/utils/parse_state_keys.py +++ b/scrapegraphai/utils/parse_state_keys.py @@ -31,11 +31,9 @@ def parse_expression(expression, state: dict) -> list: incorrect adjacency of operators, and empty expressions. """ - # Check for empty expression if not expression: raise ValueError("Empty expression.") - # Check for adjacent state keys without an operator between them pattern = r'\b(' + '|'.join(re.escape(key) for key in state.keys()) + \ r')(\b\s*\b)(' + '|'.join(re.escape(key) for key in state.keys()) + r')\b' @@ -43,37 +41,29 @@ def parse_expression(expression, state: dict) -> list: raise ValueError( "Adjacent state keys found without an operator between them.") - # Remove spaces expression = expression.replace(" ", "") - # Check for operators with empty adjacent tokens or at the start/end if expression[0] in '&|' or expression[-1] in '&|' or \ '&&' in expression or '||' in expression or \ '&|' in expression or '|&' in expression: raise ValueError("Invalid operator usage.") - # Check for balanced parentheses and valid operator placement open_parentheses = close_parentheses = 0 for i, char in enumerate(expression): if char == '(': open_parentheses += 1 elif char == ')': close_parentheses += 1 - # Check for invalid operator sequences if char in "&|" and i + 1 < len(expression) and expression[i + 1] in "&|": raise ValueError( "Invalid operator placement: operators cannot be adjacent.") - # Check for missing or balanced parentheses if open_parentheses != close_parentheses: raise ValueError("Missing or unbalanced parentheses in expression.") - # Helper function to evaluate an expression without parentheses def evaluate_simple_expression(exp): - # Split the expression by the OR operator and process each segment for or_segment in exp.split('|'): - # Check if all elements in an AND segment are in state and_segment = or_segment.split('&') if all(elem.strip() in state for elem in and_segment): return [elem.strip() for elem in and_segment if elem.strip() in state] @@ -85,9 +75,7 @@ def parse_expression(expression, state: dict) -> list: start = expression.rfind('(') end = expression.find(')', start) sub_exp = expression[start + 1:end] - # Replace the evaluated part with a placeholder and then evaluate it sub_result = evaluate_simple_expression(sub_exp) - # For simplicity in handling, join sub-results with OR to reprocess them later expression = expression[:start] + \ '|'.join(sub_result) + expression[end+1:] return evaluate_simple_expression(expression) @@ -97,7 +85,6 @@ def parse_expression(expression, state: dict) -> list: if not temp_result: raise ValueError("No state keys matched the expression.") - # Remove redundant state keys from the result, without changing their order final_result = [] for key in temp_result: if key not in final_result: diff --git a/scrapegraphai/utils/sys_dynamic_import.py b/scrapegraphai/utils/sys_dynamic_import.py index 8905ed5f..14910b3f 100644 --- a/scrapegraphai/utils/sys_dynamic_import.py +++ b/scrapegraphai/utils/sys_dynamic_import.py @@ -6,7 +6,7 @@ source code inspired by https://gist.github.com/DiTo97/46f4b733396b8d7a8f1d4d22d import sys import typing -import importlib.util # noqa: F401 +import importlib.util if typing.TYPE_CHECKING: import types @@ -36,7 +36,6 @@ def srcfile_import(modpath: str, modname: str) -> "types.ModuleType": module = importlib.util.module_from_spec(spec) - # adds the module to the global scope sys.modules[modname] = module spec.loader.exec_module(module) @@ -56,7 +55,7 @@ def dynamic_import(modname: str, message: str = "") -> None: """ if modname not in sys.modules: try: - import importlib # noqa: F401 + import importlib module = importlib.import_module(modname) sys.modules[modname] = module