mirror of
https://github.com/VinciGit00/Scrapegraph-ai.git
synced 2026-07-01 21:00:48 +08:00
Merge branch 'pre/beta' into ligthweigthing_library
This commit is contained in:
commit
26de5dd623
@ -149,6 +149,7 @@ class AbstractGraph(ABC):
|
||||
"ollama", "oneapi", "nvidia", "groq", "google_vertexai",
|
||||
"bedrock", "mistralai", "hugging_face", "deepseek", "ernie", "fireworks"]
|
||||
|
||||
|
||||
if llm_params["model"].split("/")[0] not in known_models and llm_params["model"].split("-")[0] not in known_models:
|
||||
raise ValueError(f"Model '{llm_params['model']}' is not supported")
|
||||
|
||||
|
||||
@ -9,7 +9,8 @@ from langchain_core.runnables import RunnableParallel
|
||||
from tqdm import tqdm
|
||||
from ..utils.logging import get_logger
|
||||
from .base_node import BaseNode
|
||||
from ..prompts.generate_answer_node_csv_prompts import TEMPLATE_CHUKS_CSV, TEMPLATE_NO_CHUKS_CSV, TEMPLATE_MERGE_CSV
|
||||
from ..prompts.generate_answer_node_csv_prompts import (TEMPLATE_CHUKS_CSV,
|
||||
TEMPLATE_NO_CHUKS_CSV, TEMPLATE_MERGE_CSV)
|
||||
|
||||
class GenerateAnswerCSVNode(BaseNode):
|
||||
"""
|
||||
@ -95,14 +96,14 @@ class GenerateAnswerCSVNode(BaseNode):
|
||||
else:
|
||||
output_parser = JsonOutputParser()
|
||||
|
||||
TEMPLATE_NO_CHUKS_CSV_prompt = TEMPLATE_NO_CHUKS_CSV
|
||||
TEMPLATE_CHUKS_CSV_prompt = TEMPLATE_CHUKS_CSV
|
||||
TEMPLATE_MERGE_CSV_prompt = TEMPLATE_MERGE_CSV
|
||||
TEMPLATE_NO_CHUKS_CSV_PROMPT = TEMPLATE_NO_CHUKS_CSV
|
||||
TEMPLATE_CHUKS_CSV_PROMPT = TEMPLATE_CHUKS_CSV
|
||||
TEMPLATE_MERGE_CSV_PROMPT = TEMPLATE_MERGE_CSV
|
||||
|
||||
if self.additional_info is not None:
|
||||
TEMPLATE_NO_CHUKS_CSV_prompt = self.additional_info + TEMPLATE_NO_CHUKS_CSV
|
||||
TEMPLATE_CHUKS_CSV_prompt = self.additional_info + TEMPLATE_CHUKS_CSV
|
||||
TEMPLATE_MERGE_CSV_prompt = self.additional_info + TEMPLATE_MERGE_CSV
|
||||
TEMPLATE_NO_CHUKS_CSV_PROMPT = self.additional_info + TEMPLATE_NO_CHUKS_CSV
|
||||
TEMPLATE_CHUKS_CSV_PROMPT = self.additional_info + TEMPLATE_CHUKS_CSV
|
||||
TEMPLATE_MERGE_CSV_PROMPT = self.additional_info + TEMPLATE_MERGE_CSV
|
||||
|
||||
format_instructions = output_parser.get_format_instructions()
|
||||
|
||||
@ -110,7 +111,7 @@ class GenerateAnswerCSVNode(BaseNode):
|
||||
|
||||
if len(doc) == 1:
|
||||
prompt = PromptTemplate(
|
||||
template=TEMPLATE_NO_CHUKS_CSV_prompt,
|
||||
template=TEMPLATE_NO_CHUKS_CSV_PROMPT,
|
||||
input_variables=["question"],
|
||||
partial_variables={
|
||||
"context": doc,
|
||||
@ -127,7 +128,7 @@ class GenerateAnswerCSVNode(BaseNode):
|
||||
tqdm(doc, desc="Processing chunks", disable=not self.verbose)
|
||||
):
|
||||
prompt = PromptTemplate(
|
||||
template=TEMPLATE_CHUKS_CSV_prompt,
|
||||
template=TEMPLATE_CHUKS_CSV_PROMPT,
|
||||
input_variables=["question"],
|
||||
partial_variables={
|
||||
"context": chunk,
|
||||
@ -144,7 +145,7 @@ class GenerateAnswerCSVNode(BaseNode):
|
||||
batch_results = async_runner.invoke({"question": user_prompt})
|
||||
|
||||
merge_prompt = PromptTemplate(
|
||||
template = TEMPLATE_MERGE_CSV_prompt,
|
||||
template = TEMPLATE_MERGE_CSV_PROMPT,
|
||||
input_variables=["context", "question"],
|
||||
partial_variables={"format_instructions": format_instructions},
|
||||
)
|
||||
@ -153,4 +154,4 @@ class GenerateAnswerCSVNode(BaseNode):
|
||||
answer = merge_chain.invoke({"context": batch_results, "question": user_prompt})
|
||||
|
||||
state.update({self.output[0]: answer})
|
||||
return state
|
||||
return state
|
||||
|
||||
@ -67,10 +67,8 @@ class GenerateScraperNode(BaseNode):
|
||||
|
||||
self.logger.info(f"--- Executing {self.node_name} Node ---")
|
||||
|
||||
# Interpret input keys based on the provided input expression
|
||||
input_keys = self.get_input_keys(state)
|
||||
|
||||
# Fetching data from the state based on the input keys
|
||||
input_data = [state[key] for key in input_keys]
|
||||
|
||||
user_prompt = input_data[0]
|
||||
|
||||
@ -58,10 +58,8 @@ class GetProbableTagsNode(BaseNode):
|
||||
|
||||
self.logger.info(f"--- Executing {self.node_name} Node ---")
|
||||
|
||||
# Interpret input keys based on the provided input expression
|
||||
input_keys = self.get_input_keys(state)
|
||||
|
||||
# Fetching data from the state based on the input keys
|
||||
input_data = [state[key] for key in input_keys]
|
||||
|
||||
user_prompt = input_data[0]
|
||||
@ -88,10 +86,8 @@ class GetProbableTagsNode(BaseNode):
|
||||
},
|
||||
)
|
||||
|
||||
# Execute the chain to get probable tags
|
||||
tag_answer = tag_prompt | self.llm_model | output_parser
|
||||
probable_tags = tag_answer.invoke({"question": user_prompt})
|
||||
|
||||
# Update the dictionary with probable tags
|
||||
state.update({self.output[0]: probable_tags})
|
||||
return state
|
||||
|
||||
@ -103,7 +103,6 @@ class GraphIteratorNode(BaseNode):
|
||||
if graph_instance is None:
|
||||
raise ValueError("graph instance is required for concurrent execution")
|
||||
|
||||
# Assign depth level to the graph
|
||||
if "graph_depth" in graph_instance.config:
|
||||
graph_instance.config["graph_depth"] += 1
|
||||
else:
|
||||
@ -113,14 +112,12 @@ class GraphIteratorNode(BaseNode):
|
||||
|
||||
participants = []
|
||||
|
||||
# semaphore to limit the number of concurrent tasks
|
||||
semaphore = asyncio.Semaphore(batchsize)
|
||||
|
||||
async def _async_run(graph):
|
||||
async with semaphore:
|
||||
return await asyncio.to_thread(graph.run)
|
||||
|
||||
# creates a deepcopy of the graph instance for each endpoint
|
||||
for url in urls:
|
||||
instance = copy.copy(graph_instance)
|
||||
instance.source = url
|
||||
|
||||
@ -56,21 +56,17 @@ class MergeAnswersNode(BaseNode):
|
||||
|
||||
self.logger.info(f"--- Executing {self.node_name} Node ---")
|
||||
|
||||
# Interpret input keys based on the provided input expression
|
||||
input_keys = self.get_input_keys(state)
|
||||
|
||||
# Fetching data from the state based on the input keys
|
||||
input_data = [state[key] for key in input_keys]
|
||||
|
||||
user_prompt = input_data[0]
|
||||
answers = input_data[1]
|
||||
|
||||
# merge the answers in one string
|
||||
answers_str = ""
|
||||
for i, answer in enumerate(answers):
|
||||
answers_str += f"CONTENT WEBSITE {i+1}: {answer}\n"
|
||||
|
||||
# Initialize the output parser
|
||||
if self.node_config.get("schema", None) is not None:
|
||||
output_parser = JsonOutputParser(pydantic_object=self.node_config["schema"])
|
||||
else:
|
||||
@ -90,6 +86,5 @@ class MergeAnswersNode(BaseNode):
|
||||
merge_chain = prompt_template | self.llm_model | output_parser
|
||||
answer = merge_chain.invoke({"user_prompt": user_prompt})
|
||||
|
||||
# Update the state with the generated answer
|
||||
state.update({self.output[0]: answer})
|
||||
return state
|
||||
|
||||
@ -59,13 +59,11 @@ class ParseNode(BaseNode):
|
||||
|
||||
self.logger.info(f"--- Executing {self.node_name} Node ---")
|
||||
|
||||
# Interpret input keys based on the provided input expression
|
||||
input_keys = self.get_input_keys(state)
|
||||
|
||||
# Fetching data from the state based on the input keys
|
||||
input_data = [state[key] for key in input_keys]
|
||||
# Parse the document
|
||||
docs_transformed = input_data[0]
|
||||
|
||||
if self.parse_html:
|
||||
docs_transformed = Html2TextTransformer().transform_documents(input_data[0])
|
||||
docs_transformed = docs_transformed[0]
|
||||
@ -77,7 +75,6 @@ class ParseNode(BaseNode):
|
||||
else:
|
||||
docs_transformed = docs_transformed[0]
|
||||
|
||||
# Adapt the chunk size, leaving room for the reply, the prompt and the schema
|
||||
chunk_size = self.node_config.get("chunk_size", 4096)
|
||||
chunk_size = min(chunk_size - 500, int(chunk_size * 0.9))
|
||||
|
||||
|
||||
@ -80,10 +80,8 @@ class RAGNode(BaseNode):
|
||||
|
||||
self.logger.info(f"--- Executing {self.node_name} Node ---")
|
||||
|
||||
# Interpret input keys based on the provided input expression
|
||||
input_keys = self.get_input_keys(state)
|
||||
|
||||
# Fetching data from the state based on the input keys
|
||||
input_data = [state[key] for key in input_keys]
|
||||
|
||||
user_prompt = input_data[0]
|
||||
@ -102,7 +100,6 @@ class RAGNode(BaseNode):
|
||||
|
||||
self.logger.info("--- (updated chunks metadata) ---")
|
||||
|
||||
# check if embedder_model is provided, if not use llm_model
|
||||
if self.embedder_model is not None:
|
||||
embeddings = self.embedder_model
|
||||
elif 'embeddings' in self.node_config:
|
||||
@ -144,23 +141,17 @@ class RAGNode(BaseNode):
|
||||
pipeline_compressor = DocumentCompressorPipeline(
|
||||
transformers=[redundant_filter, relevant_filter]
|
||||
)
|
||||
# redundant + relevant filter compressor
|
||||
compression_retriever = ContextualCompressionRetriever(
|
||||
base_compressor=pipeline_compressor, base_retriever=retriever
|
||||
)
|
||||
|
||||
# relevant filter compressor only
|
||||
# compression_retriever = ContextualCompressionRetriever(
|
||||
# base_compressor=relevant_filter, base_retriever=retriever
|
||||
# )
|
||||
|
||||
compressed_docs = compression_retriever.invoke(user_prompt)
|
||||
|
||||
self.logger.info("--- (tokens compressed and vector stored) ---")
|
||||
|
||||
state.update({self.output[0]: compressed_docs})
|
||||
return state
|
||||
|
||||
|
||||
|
||||
def _create_default_embedder(self, llm_config=None) -> object:
|
||||
"""
|
||||
@ -223,7 +214,6 @@ class RAGNode(BaseNode):
|
||||
embedder_params = {**embedder_config}
|
||||
if "model_instance" in embedder_config:
|
||||
return embedder_params["model_instance"]
|
||||
# Instantiate the embedding model based on the model name
|
||||
if "openai" in embedder_params["model"]:
|
||||
return OpenAIEmbeddings(api_key=embedder_params["api_key"])
|
||||
if "azure" in embedder_params["model"]:
|
||||
|
||||
@ -75,10 +75,8 @@ class RobotsNode(BaseNode):
|
||||
|
||||
self.logger.info(f"--- Executing {self.node_name} Node ---")
|
||||
|
||||
# Interpret input keys based on the provided input expression
|
||||
input_keys = self.get_input_keys(state)
|
||||
|
||||
# Fetching data from the state based on the input keys
|
||||
input_data = [state[key] for key in input_keys]
|
||||
|
||||
source = input_data[0]
|
||||
|
||||
@ -67,7 +67,6 @@ class SearchInternetNode(BaseNode):
|
||||
|
||||
input_keys = self.get_input_keys(state)
|
||||
|
||||
# Fetching data from the state based on the input keys
|
||||
input_data = [state[key] for key in input_keys]
|
||||
|
||||
user_prompt = input_data[0]
|
||||
@ -79,10 +78,8 @@ class SearchInternetNode(BaseNode):
|
||||
input_variables=["user_prompt"],
|
||||
)
|
||||
|
||||
# Execute the chain to get the search query
|
||||
search_answer = search_prompt | self.llm_model | output_parser
|
||||
|
||||
# Ollama: Use no json format when creating the search query
|
||||
|
||||
if isinstance(self.llm_model, ChatOllama) and self.llm_model.format == 'json':
|
||||
self.llm_model.format = None
|
||||
search_query = search_answer.invoke({"user_prompt": user_prompt})[0]
|
||||
@ -96,9 +93,7 @@ class SearchInternetNode(BaseNode):
|
||||
search_engine=self.search_engine)
|
||||
|
||||
if len(answer) == 0:
|
||||
# raise an exception if no answer is found
|
||||
raise ValueError("Zero results found for the search query.")
|
||||
|
||||
# Update the state with the generated answer
|
||||
state.update({self.output[0]: answer})
|
||||
return state
|
||||
|
||||
@ -49,7 +49,6 @@ class SearchLinkNode(BaseNode):
|
||||
self.filter_config = {**default_filters.filter_dict, **provided_filter_config}
|
||||
self.filter_links = True
|
||||
else:
|
||||
# Skip filtering if not enabled
|
||||
self.filter_config = None
|
||||
self.filter_links = False
|
||||
|
||||
@ -58,29 +57,26 @@ class SearchLinkNode(BaseNode):
|
||||
|
||||
def _is_same_domain(self, url, domain):
|
||||
if not self.filter_links or not self.filter_config.get("diff_domain_filter", True):
|
||||
return True # Skip the domain filter if not enabled
|
||||
return True
|
||||
parsed_url = urlparse(url)
|
||||
parsed_domain = urlparse(domain)
|
||||
return parsed_url.netloc == parsed_domain.netloc
|
||||
|
||||
def _is_image_url(self, url):
|
||||
if not self.filter_links:
|
||||
return False # Skip image filtering if filtering is not enabled
|
||||
|
||||
return False
|
||||
image_extensions = self.filter_config.get("img_exts", [])
|
||||
return any(url.lower().endswith(ext) for ext in image_extensions)
|
||||
|
||||
def _is_language_url(self, url):
|
||||
if not self.filter_links:
|
||||
return False # Skip language filtering if filtering is not enabled
|
||||
return False
|
||||
|
||||
lang_indicators = self.filter_config.get("lang_indicators", [])
|
||||
parsed_url = urlparse(url)
|
||||
query_params = parse_qs(parsed_url.query)
|
||||
|
||||
# Check if the URL path or query string indicates a language-specific version
|
||||
return any(indicator in parsed_url.path.lower() or indicator in query_params for indicator in lang_indicators)
|
||||
|
||||
def _is_potentially_irrelevant(self, url):
|
||||
if not self.filter_links:
|
||||
return False # Skip irrelevant URL filtering if filtering is not enabled
|
||||
@ -88,12 +84,11 @@ class SearchLinkNode(BaseNode):
|
||||
irrelevant_keywords = self.filter_config.get("irrelevant_keywords", [])
|
||||
return any(keyword in url.lower() for keyword in irrelevant_keywords)
|
||||
|
||||
|
||||
|
||||
def execute(self, state: dict) -> dict:
|
||||
"""
|
||||
Filter out relevant links from the webpage that are relavant to prompt. Out of the filtered links, also
|
||||
ensure that all links are navigable.
|
||||
|
||||
Filter out relevant links from the webpage that are relavant to prompt.
|
||||
Out of the filtered links, also ensure that all links are navigable.
|
||||
Args:
|
||||
state (dict): The current state of the graph. The input keys will be used to fetch the
|
||||
correct data types from the state.
|
||||
@ -108,7 +103,6 @@ class SearchLinkNode(BaseNode):
|
||||
|
||||
self.logger.info(f"--- Executing {self.node_name} Node ---")
|
||||
|
||||
|
||||
parsed_content_chunks = state.get("doc")
|
||||
source_url = state.get("url") or state.get("local_dir")
|
||||
output_parser = JsonOutputParser()
|
||||
@ -148,7 +142,7 @@ class SearchLinkNode(BaseNode):
|
||||
except Exception as e:
|
||||
# Fallback approach: Using the LLM to extract links
|
||||
self.logger.error(f"Error extracting links: {e}. Falling back to LLM.")
|
||||
|
||||
|
||||
merge_prompt = PromptTemplate(
|
||||
template=TEMPLATE_RELEVANT_LINKS,
|
||||
input_variables=["content", "user_prompt"],
|
||||
|
||||
@ -58,10 +58,8 @@ class SearchLinksWithContext(BaseNode):
|
||||
|
||||
self.logger.info(f"--- Executing {self.node_name} Node ---")
|
||||
|
||||
# Interpret input keys based on the provided input expression
|
||||
input_keys = self.get_input_keys(state)
|
||||
|
||||
# Fetching data from the state based on the input keys
|
||||
input_data = [state[key] for key in input_keys]
|
||||
|
||||
doc = input_data[1]
|
||||
@ -71,7 +69,6 @@ class SearchLinksWithContext(BaseNode):
|
||||
|
||||
result = []
|
||||
|
||||
# Use tqdm to add progress bar
|
||||
for i, chunk in enumerate(
|
||||
tqdm(doc, desc="Processing chunks", disable=not self.verbose)
|
||||
):
|
||||
|
||||
@ -43,7 +43,8 @@ class TextToSpeechNode(BaseNode):
|
||||
correct data types from the state.
|
||||
|
||||
Returns:
|
||||
dict: The updated state with the output key containing the audio generated from the text.
|
||||
dict: The updated state with the output
|
||||
key containing the audio generated from the text.
|
||||
|
||||
Raises:
|
||||
KeyError: If the input keys are not found in the state, indicating that the
|
||||
@ -52,15 +53,11 @@ class TextToSpeechNode(BaseNode):
|
||||
|
||||
self.logger.info(f"--- Executing {self.node_name} Node ---")
|
||||
|
||||
# Interpret input keys based on the provided input expression
|
||||
input_keys = self.get_input_keys(state)
|
||||
|
||||
# Fetching data from the state based on the input keys
|
||||
input_data = [state[key] for key in input_keys]
|
||||
|
||||
# get the text to translate
|
||||
text2translate = str(next(iter(input_data[0].values())))
|
||||
# text2translate = str(input_data[0])
|
||||
|
||||
audio = self.tts_model.run(text2translate)
|
||||
|
||||
|
||||
@ -28,35 +28,28 @@ def cleanup_html(html_content: str, base_url: str) -> str:
|
||||
|
||||
soup = BeautifulSoup(html_content, 'html.parser')
|
||||
|
||||
# Title Extraction
|
||||
title_tag = soup.find('title')
|
||||
title = title_tag.get_text() if title_tag else ""
|
||||
|
||||
# Script and Style Tag Removal
|
||||
for tag in soup.find_all(['script', 'style']):
|
||||
tag.extract()
|
||||
|
||||
# Links extraction
|
||||
link_urls = [urljoin(base_url, link['href']) for link in soup.find_all('a', href=True)]
|
||||
|
||||
# Images extraction
|
||||
images = soup.find_all('img')
|
||||
image_urls = []
|
||||
for image in images:
|
||||
if 'src' in image.attrs:
|
||||
# if http or https is not present in the image url, join it with the base url
|
||||
if 'http' not in image['src']:
|
||||
image_urls.append(urljoin(base_url, image['src']))
|
||||
else:
|
||||
image_urls.append(image['src'])
|
||||
|
||||
# Body Extraction (if it exists)
|
||||
body_content = soup.find('body')
|
||||
if body_content:
|
||||
# Minify the HTML within the body tag
|
||||
minimized_body = minify(str(body_content))
|
||||
return title, minimized_body, link_urls, image_urls
|
||||
|
||||
else:
|
||||
raise ValueError(f"""No HTML body content found, please try setting the 'headless'
|
||||
raise ValueError(f"""No HTML body content found, please try setting the 'headless'
|
||||
flag to False in the graph configuration. HTML content: {html_content}""")
|
||||
|
||||
@ -29,9 +29,8 @@ def convert_to_csv(data: dict, filename: str, position: str = None) -> None:
|
||||
"""
|
||||
|
||||
if ".csv" in filename:
|
||||
filename = filename.replace(".csv", "") # Remove .csv extension
|
||||
filename = filename.replace(".csv", "")
|
||||
|
||||
# Get the directory of the caller script if position is not provided
|
||||
if position is None:
|
||||
caller_dir = os.path.dirname(os.path.abspath(sys.argv[0]))
|
||||
position = caller_dir
|
||||
@ -40,7 +39,7 @@ def convert_to_csv(data: dict, filename: str, position: str = None) -> None:
|
||||
if not isinstance(data, dict):
|
||||
raise TypeError("Input data must be a dictionary")
|
||||
|
||||
os.makedirs(position, exist_ok=True) # Create directory if needed
|
||||
os.makedirs(position, exist_ok=True)
|
||||
|
||||
df = pd.DataFrame.from_dict(data, orient='index')
|
||||
df.to_csv(os.path.join(position, f"{filename}.csv"), index=False)
|
||||
@ -52,4 +51,4 @@ def convert_to_csv(data: dict, filename: str, position: str = None) -> None:
|
||||
raise PermissionError(
|
||||
f"You don't have permission to write to '{position}'.") from pe
|
||||
except Exception as e:
|
||||
raise e # Re-raise other potential errors
|
||||
raise e
|
||||
|
||||
@ -28,15 +28,15 @@ def convert_to_json(data: dict, filename: str, position: str = None) -> None:
|
||||
Saves a JSON file named 'output.json' at '/path/to/save'.
|
||||
|
||||
Notes:
|
||||
This function automatically ensures the directory exists before attempting to write the file. If the directory does not exist, it will attempt to create it.
|
||||
This function automatically ensures the directory exists before
|
||||
attempting to write the file.
|
||||
If the directory does not exist, it will attempt to create it.
|
||||
"""
|
||||
|
||||
if ".json" in filename:
|
||||
filename = filename.replace(".json", "") # Remove .json extension
|
||||
|
||||
# Get the directory of the caller script
|
||||
if position is None:
|
||||
# Get directory of the main script
|
||||
caller_dir = os.path.dirname(os.path.abspath(sys.argv[0]))
|
||||
position = caller_dir
|
||||
|
||||
|
||||
@ -18,7 +18,8 @@ def convert_to_md(html: str, url: str = None) -> str:
|
||||
<h1>This is a heading.</h1></body></html>")
|
||||
'This is a paragraph.\n\n# This is a heading.'
|
||||
|
||||
Note: All the styles and links are ignored during the conversion. """
|
||||
Note: All the styles and links are ignored during the conversion.
|
||||
"""
|
||||
|
||||
h = html2text.HTML2Text()
|
||||
h.ignore_links = False
|
||||
|
||||
@ -48,7 +48,6 @@ def _set_library_root_logger() -> None:
|
||||
|
||||
DEFAULT_HANDLER = logging.StreamHandler() # sys.stderr as stream
|
||||
|
||||
# https://github.com/pyinstaller/pyinstaller/issues/7334#issuecomment-1357447176
|
||||
if sys.stderr is None:
|
||||
sys.stderr = open(os.devnull, "w", encoding="utf-8")
|
||||
|
||||
@ -66,7 +65,8 @@ def get_logger(name: Optional[str] = None) -> logging.Logger:
|
||||
If no name is provided, the root logger for the library is returned.
|
||||
|
||||
Args:
|
||||
name (Optional[str]): The name of the logger. If None, the root logger for the library is returned.
|
||||
name (Optional[str]): The name of the logger.
|
||||
If None, the root logger for the library is returned.
|
||||
|
||||
Returns:
|
||||
logging.Logger: The logger with the specified name.
|
||||
@ -199,7 +199,8 @@ def warning_once(self, *args, **kwargs):
|
||||
"""
|
||||
Emit a warning log with the same message only once.
|
||||
|
||||
This function is added as a method to the logging.Logger class. It emits a warning log with the same message only once,
|
||||
This function is added as a method to the logging.Logger class.
|
||||
It emits a warning log with the same message only once,
|
||||
even if it is called multiple times with the same message.
|
||||
|
||||
Args:
|
||||
|
||||
@ -31,11 +31,9 @@ def parse_expression(expression, state: dict) -> list:
|
||||
incorrect adjacency of operators, and empty expressions.
|
||||
"""
|
||||
|
||||
# Check for empty expression
|
||||
if not expression:
|
||||
raise ValueError("Empty expression.")
|
||||
|
||||
# Check for adjacent state keys without an operator between them
|
||||
pattern = r'\b(' + '|'.join(re.escape(key) for key in state.keys()) + \
|
||||
r')(\b\s*\b)(' + '|'.join(re.escape(key)
|
||||
for key in state.keys()) + r')\b'
|
||||
@ -43,37 +41,29 @@ def parse_expression(expression, state: dict) -> list:
|
||||
raise ValueError(
|
||||
"Adjacent state keys found without an operator between them.")
|
||||
|
||||
# Remove spaces
|
||||
expression = expression.replace(" ", "")
|
||||
|
||||
# Check for operators with empty adjacent tokens or at the start/end
|
||||
if expression[0] in '&|' or expression[-1] in '&|' or \
|
||||
'&&' in expression or '||' in expression or \
|
||||
'&|' in expression or '|&' in expression:
|
||||
|
||||
raise ValueError("Invalid operator usage.")
|
||||
|
||||
# Check for balanced parentheses and valid operator placement
|
||||
open_parentheses = close_parentheses = 0
|
||||
for i, char in enumerate(expression):
|
||||
if char == '(':
|
||||
open_parentheses += 1
|
||||
elif char == ')':
|
||||
close_parentheses += 1
|
||||
# Check for invalid operator sequences
|
||||
if char in "&|" and i + 1 < len(expression) and expression[i + 1] in "&|":
|
||||
raise ValueError(
|
||||
"Invalid operator placement: operators cannot be adjacent.")
|
||||
|
||||
# Check for missing or balanced parentheses
|
||||
if open_parentheses != close_parentheses:
|
||||
raise ValueError("Missing or unbalanced parentheses in expression.")
|
||||
|
||||
# Helper function to evaluate an expression without parentheses
|
||||
def evaluate_simple_expression(exp):
|
||||
# Split the expression by the OR operator and process each segment
|
||||
for or_segment in exp.split('|'):
|
||||
# Check if all elements in an AND segment are in state
|
||||
and_segment = or_segment.split('&')
|
||||
if all(elem.strip() in state for elem in and_segment):
|
||||
return [elem.strip() for elem in and_segment if elem.strip() in state]
|
||||
@ -85,9 +75,7 @@ def parse_expression(expression, state: dict) -> list:
|
||||
start = expression.rfind('(')
|
||||
end = expression.find(')', start)
|
||||
sub_exp = expression[start + 1:end]
|
||||
# Replace the evaluated part with a placeholder and then evaluate it
|
||||
sub_result = evaluate_simple_expression(sub_exp)
|
||||
# For simplicity in handling, join sub-results with OR to reprocess them later
|
||||
expression = expression[:start] + \
|
||||
'|'.join(sub_result) + expression[end+1:]
|
||||
return evaluate_simple_expression(expression)
|
||||
@ -97,7 +85,6 @@ def parse_expression(expression, state: dict) -> list:
|
||||
if not temp_result:
|
||||
raise ValueError("No state keys matched the expression.")
|
||||
|
||||
# Remove redundant state keys from the result, without changing their order
|
||||
final_result = []
|
||||
for key in temp_result:
|
||||
if key not in final_result:
|
||||
|
||||
@ -6,7 +6,7 @@ source code inspired by https://gist.github.com/DiTo97/46f4b733396b8d7a8f1d4d22d
|
||||
|
||||
import sys
|
||||
import typing
|
||||
import importlib.util # noqa: F401
|
||||
import importlib.util
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
import types
|
||||
@ -36,7 +36,6 @@ def srcfile_import(modpath: str, modname: str) -> "types.ModuleType":
|
||||
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
|
||||
# adds the module to the global scope
|
||||
sys.modules[modname] = module
|
||||
|
||||
spec.loader.exec_module(module)
|
||||
@ -56,7 +55,7 @@ def dynamic_import(modname: str, message: str = "") -> None:
|
||||
"""
|
||||
if modname not in sys.modules:
|
||||
try:
|
||||
import importlib # noqa: F401
|
||||
import importlib
|
||||
|
||||
module = importlib.import_module(modname)
|
||||
sys.modules[modname] = module
|
||||
|
||||
Loading…
Reference in New Issue
Block a user