Scrapegraph-ai/scrapegraphai/nodes/base_node.py
2024-05-02 00:23:38 +02:00

205 lines
7.7 KiB
Python

"""
BaseNode Module
"""
from abc import ABC, abstractmethod
from typing import Optional, List
import re
class BaseNode(ABC):
    """
    An abstract base class for nodes in a graph-based workflow, designed to
    perform specific actions when executed.

    Attributes:
        node_name (str): The unique identifier name for the node.
        input (str): Boolean expression defining the input keys needed from the state.
        output (List[str]): List of output keys to be updated in the state.
        min_input_len (int): Minimum required number of input keys.
        node_config (Optional[dict]): Additional configuration for the node.
        node_type (str): Either 'node' or 'conditional_node'.

    Args:
        node_name (str): Name for identifying the node.
        node_type (str): Type of the node; must be 'node' or 'conditional_node'.
        input (str): Expression defining the input keys needed from the state.
        output (List[str]): List of output keys to be updated in the state.
        min_input_len (int, optional): Minimum required number of input keys; defaults to 1.
        node_config (Optional[dict], optional): Additional configuration for the node;
            defaults to None.

    Raises:
        ValueError: If `node_type` is not one of the allowed types.

    Example:
        >>> class MyNode(BaseNode):
        ...     def execute(self, state):
        ...         # Implementation of node logic here
        ...         return state
        ...
        >>> my_node = MyNode("ExampleNode", "node", "input_spec", ["output_spec"])
        >>> my_node.execute({'key': 'value'})
        {'key': 'value'}
    """

    # NOTE: the `input` parameter shadows the builtin, but it is part of the
    # public constructor signature (callers may pass it by keyword), so the
    # name is kept.
    def __init__(self, node_name: str, node_type: str, input: str, output: List[str],
                 min_input_len: int = 1, node_config: Optional[dict] = None):
        self.node_name = node_name
        self.input = input
        self.output = output
        self.min_input_len = min_input_len
        self.node_config = node_config

        if node_type not in ["node", "conditional_node"]:
            raise ValueError(
                f"node_type must be 'node' or 'conditional_node', got '{node_type}'")
        self.node_type = node_type

    @abstractmethod
    def execute(self, state: dict) -> dict:
        """
        Execute the node's logic based on the current state and update it accordingly.

        Args:
            state (dict): The current state of the graph.

        Returns:
            dict: The updated state after executing the node's logic.
        """

    def get_input_keys(self, state: dict) -> List[str]:
        """
        Determines the necessary state keys based on the input specification.

        Args:
            state (dict): The current state of the graph used to parse input keys.

        Returns:
            List[str]: A list of input keys required for node operation.

        Raises:
            ValueError: If the input expression cannot be parsed, matches no
                state keys, or yields fewer keys than `min_input_len`.
        """
        try:
            input_keys = self._parse_input_keys(state, self.input)
            self._validate_input_keys(input_keys)
            return input_keys
        except ValueError as e:
            # Chain explicitly so the underlying parsing error stays attached
            # to the node-specific one.
            raise ValueError(
                f"Error parsing input keys for {self.node_name}: {str(e)}") from e

    def _validate_input_keys(self, input_keys):
        """
        Validates if the provided input keys meet the minimum length requirement.

        Args:
            input_keys (List[str]): The list of input keys to validate.

        Raises:
            ValueError: If the number of input keys is less than the minimum required.
        """
        if len(input_keys) < self.min_input_len:
            # Built as a single line so the message carries no embedded
            # newline or source indentation.
            raise ValueError(
                f"{self.node_name} requires at least {self.min_input_len} input keys, "
                f"got {len(input_keys)}.")

    def _parse_input_keys(self, state: dict, expression: str) -> List[str]:
        """
        Parses the input keys expression to extract relevant keys from the
        state based on logical conditions.

        The expression can contain AND (&), OR (|), and parentheses to group
        conditions. AND binds tighter than OR. An AND group is satisfied only
        when every operand is satisfied and contributes all of its operands'
        keys; an OR group contributes the keys of its first satisfied
        alternative. A bare name is satisfied when it is a key of `state`.

        Args:
            state (dict): The current state of the graph.
            expression (str): The input keys expression to parse.

        Returns:
            List[str]: The matching key names, in order of appearance and
            without duplicates.

        Raises:
            ValueError: If the expression is empty or malformed, or if no
                state keys match the expression.
        """
        if not expression:
            raise ValueError("Empty expression.")

        # Check for adjacent state keys without an operator between them.
        # Only non-empty string keys take part: with no alternatives (or an
        # empty one) the pattern would match at any word boundary and reject
        # every expression with a misleading message.
        known_keys = [key for key in state if isinstance(key, str) and key]
        if known_keys:
            alternatives = '|'.join(re.escape(key) for key in known_keys)
            adjacent_keys = (r'\b(' + alternatives + r')(\b\s*\b)('
                             + alternatives + r')\b')
            if re.search(adjacent_keys, expression):
                raise ValueError(
                    "Adjacent state keys found without an operator between them.")

        # Drop all whitespace; a whitespace-only expression is empty too and
        # must not reach the index-based checks below.
        expression = "".join(expression.split())
        if not expression:
            raise ValueError("Empty expression.")

        # Operators may not start or end the expression, nor sit next to
        # each other.
        if expression[0] in '&|' or expression[-1] in '&|' \
                or any(pair in expression for pair in ('&&', '||', '&|', '|&')):
            raise ValueError("Invalid operator usage.")

        # Parentheses must be balanced *and* correctly ordered: a ')' may
        # never appear before its matching '('.
        depth = 0
        for char in expression:
            if char == '(':
                depth += 1
            elif char == ')':
                depth -= 1
                if depth < 0:
                    raise ValueError(
                        "Missing or unbalanced parentheses in expression.")
        if depth != 0:
            raise ValueError(
                "Missing or unbalanced parentheses in expression.")

        # Split into operators, parentheses and key names, then evaluate with
        # a small recursive-descent parser. Every helper returns the list of
        # matched keys, or an empty list when its sub-expression is not
        # satisfied by the state.
        tokens = re.findall(r'[&|()]|[^&|()]+', expression)
        position = 0

        def parse_or() -> List[str]:
            """Evaluate `and_group ('|' and_group)*`; first satisfied group wins."""
            nonlocal position
            chosen = parse_and()
            while position < len(tokens) and tokens[position] == '|':
                position += 1
                # Always parse the alternative so the whole expression is
                # consumed and syntax-checked, even if one already matched.
                alternative = parse_and()
                if not chosen:
                    chosen = alternative
            return chosen

        def parse_and() -> List[str]:
            """Evaluate `atom ('&' atom)*`; satisfied only if every atom is."""
            nonlocal position
            keys = parse_atom()
            satisfied = bool(keys)
            while position < len(tokens) and tokens[position] == '&':
                position += 1
                operand = parse_atom()
                if operand:
                    keys.extend(operand)
                else:
                    satisfied = False
            return keys if satisfied else []

        def parse_atom() -> List[str]:
            """Evaluate a key name or a parenthesised sub-expression."""
            nonlocal position
            if position >= len(tokens):
                raise ValueError("Invalid operator usage.")
            token = tokens[position]
            if token == '(':
                position += 1
                inner = parse_or()
                if position >= len(tokens) or tokens[position] != ')':
                    raise ValueError(
                        "Missing or unbalanced parentheses in expression.")
                position += 1
                return inner
            if token in ('&', '|', ')'):
                # An operand was expected here, e.g. "()" or "(&a)".
                raise ValueError("Invalid operator usage.")
            position += 1
            return [token] if token in state else []

        result = parse_or()
        if position != len(tokens):
            # Leftover tokens mean an operator is missing, e.g. "a(b)".
            raise ValueError("Invalid operator usage.")

        if not result:
            raise ValueError("No state keys matched the expression.")

        # Remove redundant state keys from the result, without changing their order
        return list(dict.fromkeys(result))