Scrapegraph-ai/scrapegraphai/nodes/base_node.py
2025-04-14 07:50:46 +00:00

236 lines
8.0 KiB
Python

"""
This module defines the base node class for the ScrapeGraphAI application.
"""
import re
from abc import ABC, abstractmethod
from typing import List, Optional
from ..utils import get_logger
class BaseNode(ABC):
    """
    An abstract base class for nodes in a graph-based workflow,
    designed to perform specific actions when executed.

    Attributes:
        node_name (str): The unique identifier name for the node.
        node_type (str): Type of the node; 'node' or 'conditional_node'.
        input (str): Boolean expression defining the input keys needed from the state.
        output (List[str]): List of output keys to be updated in the state.
        min_input_len (int): Minimum required number of input keys.
        node_config (Optional[dict]): Additional configuration for the node.
        logger: The logger object returned by ``get_logger()``.

    Args:
        node_name (str): Name for identifying the node.
        node_type (str): Type of the node; must be 'node' or 'conditional_node'.
        input (str): Expression defining the input keys needed from the state.
        output (List[str]): List of output keys to be updated in the state.
        min_input_len (int, optional): Minimum required number of input keys; defaults to 1.
        node_config (Optional[dict], optional): Additional configuration
            for the node; defaults to None.

    Raises:
        ValueError: If `node_type` is not one of the allowed types.

    Example:
        >>> class MyNode(BaseNode):
        ...     def execute(self, state):
        ...         # Implementation of node logic here
        ...         return state
        ...
        >>> my_node = MyNode("ExampleNode", "node", "input_spec", ["output_spec"])
        >>> my_node.execute({'key': 'value'})
        {'key': 'value'}
    """

    def __init__(
        self,
        node_name: str,
        node_type: str,
        input: str,
        output: List[str],
        min_input_len: int = 1,
        node_config: Optional[dict] = None,
    ):
        # Reject an unsupported node type before any attribute is set.
        if node_type not in ("node", "conditional_node"):
            raise ValueError(
                f"node_type must be 'node' or 'conditional_node', got '{node_type}'"
            )

        self.node_name = node_name
        self.node_type = node_type
        self.input = input
        self.output = output
        self.min_input_len = min_input_len
        self.node_config = node_config
        self.logger = get_logger()

    @abstractmethod
    def execute(self, state: dict) -> dict:
        """
        Execute the node's logic based on the current state and update it accordingly.

        Args:
            state (dict): The current state of the graph.

        Returns:
            dict: The updated state after executing the node's logic.
        """
        pass

    def update_config(self, params: dict, overwrite: bool = False):
        """
        Sets each entry of `params` as an attribute on this node.

        Note that only instance attributes are written; the `node_config`
        dictionary itself is not modified by this method.

        Args:
            params (dict): Mapping of attribute name to the value to assign.
            overwrite (bool): When False (the default), attributes the node
                already has are left untouched and only missing ones are added.
                When True, every entry of `params` is assigned.
        """
        for key, val in params.items():
            if hasattr(self, key) and not overwrite:
                continue
            setattr(self, key, val)

    def get_input_keys(self, state: dict) -> List[str]:
        """
        Determines the necessary state keys based on the input specification.

        Args:
            state (dict): The current state of the graph used to parse input keys.

        Returns:
            List[str]: A list of input keys required for node operation.

        Raises:
            ValueError: If the input expression cannot be parsed, matches no
                state keys, or yields fewer than `min_input_len` keys. The
                underlying error is attached as the exception's cause.
        """
        try:
            input_keys = self._parse_input_keys(state, self.input)
            self._validate_input_keys(input_keys)
            return input_keys
        except ValueError as e:
            raise ValueError(f"Error parsing input keys for {self.node_name}") from e

    def _validate_input_keys(self, input_keys: List[str]) -> None:
        """
        Validates if the provided input keys meet the minimum length requirement.

        Args:
            input_keys (List[str]): The list of input keys to validate.

        Raises:
            ValueError: If the number of input keys is less than the minimum required.
        """
        if len(input_keys) < self.min_input_len:
            raise ValueError(
                f"{self.node_name} requires at least {self.min_input_len} input keys,\n"
                f"got {len(input_keys)}."
            )

    def _parse_input_keys(self, state: dict, expression: str) -> List[str]:
        """
        Parses the input keys expression to extract
        relevant keys from the state based on logical conditions.
        The expression can contain AND (&), OR (|), and parentheses to group conditions.

        An AND group is satisfied only when every one of its terms is a key of
        `state`; an OR picks the first alternative (left to right) that is
        satisfied. Parenthesised groups are resolved innermost-first.

        Args:
            state (dict): The current state of the graph.
            expression (str): The input keys expression to parse.

        Returns:
            List[str]: A list of key names that match the input keys expression
                logic, without duplicates and in order of first appearance.

        Raises:
            ValueError: If the expression is invalid or if no state keys match the expression.
        """
        if not expression:
            raise ValueError("Empty expression.")

        # Detect two state keys written next to each other with no operator.
        # With an empty state the alternation would be empty and match at every
        # word boundary, so the check only makes sense when there are keys.
        if state:
            keys_alternation = "|".join(re.escape(key) for key in state)
            pattern = (
                r"\b(" + keys_alternation + r")(\b\s*\b)(" + keys_alternation + r")\b"
            )
            if re.search(pattern, expression):
                raise ValueError(
                    "Adjacent state keys found without an operator between them."
                )

        expression = expression.replace(" ", "")
        # An expression made only of spaces is empty once they are removed;
        # report it instead of failing on the index lookups below.
        if not expression:
            raise ValueError("Empty expression.")

        # Operators may not start or end the expression, nor sit side by side.
        if (
            expression[0] in "&|"
            or expression[-1] in "&|"
            or any(pair in expression for pair in ("&&", "||", "&|", "|&"))
        ):
            raise ValueError("Invalid operator usage.")

        # Track nesting depth so that a ")" appearing before its "(" is
        # rejected too; equal counts alone would let ")a(" through and the
        # group-resolution loop below would never terminate on it.
        depth = 0
        for char in expression:
            if char == "(":
                depth += 1
            elif char == ")":
                depth -= 1
                if depth < 0:
                    raise ValueError(
                        "Missing or unbalanced parentheses in expression."
                    )
        if depth != 0:
            raise ValueError("Missing or unbalanced parentheses in expression.")

        def evaluate_simple_expression(exp: str) -> List[str]:
            """Evaluate an expression without parentheses."""
            for or_segment in exp.split("|"):
                terms = [elem.strip() for elem in or_segment.split("&")]
                # An empty term is what an unsatisfied parenthesised group
                # collapses to, so it can never count as a present key.
                if all(term and term in state for term in terms):
                    return terms
            return []

        def evaluate_expression(exp: str) -> List[str]:
            """Evaluate an expression with parentheses."""
            while "(" in exp:
                start = exp.rfind("(")
                end = exp.find(")", start)
                sub_result = evaluate_simple_expression(exp[start + 1 : end])
                # Every key of a satisfied group is present in the state, so
                # the group is re-inserted as an AND of its keys. Joining with
                # "|" here would turn "(a&b)&c" into "a|b&c" and silently drop
                # the requirement on "c".
                exp = exp[:start] + "&".join(sub_result) + exp[end + 1 :]
            return evaluate_simple_expression(exp)

        result = evaluate_expression(expression)
        if not result:
            available_keys = ", ".join(state.keys())
            raise ValueError(
                "No state keys matched the expression.\n"
                f"Expression was {expression}.\n"
                f"State contains keys: {available_keys}"
            )

        # Drop duplicates while keeping the order of first appearance.
        return list(dict.fromkeys(result))