fix(logger): set up centralized root logger in base node

This commit is contained in:
Federico Minutoli 2024-05-24 01:09:26 +02:00
parent c807695720
commit 4348d4f4db

View File

@ -2,9 +2,11 @@
BaseNode Module BaseNode Module
""" """
from abc import ABC, abstractmethod
from typing import Optional, List
import re import re
from abc import ABC, abstractmethod
from typing import List, Optional
from ..utils import get_logger
class BaseNode(ABC): class BaseNode(ABC):
@ -14,10 +16,11 @@ class BaseNode(ABC):
Attributes: Attributes:
node_name (str): The unique identifier name for the node. node_name (str): The unique identifier name for the node.
input (str): Boolean expression defining the input keys needed from the state. input (str): Boolean expression defining the input keys needed from the state.
output (List[str]): List of output (List[str]): List of
min_input_len (int): Minimum required number of input keys. min_input_len (int): Minimum required number of input keys.
node_config (Optional[dict]): Additional configuration for the node. node_config (Optional[dict]): Additional configuration for the node.
logger (logging.Logger): The centralized root logger
Args: Args:
node_name (str): Name for identifying the node. node_name (str): Name for identifying the node.
node_type (str): Type of the node; must be 'node' or 'conditional_node'. node_type (str): Type of the node; must be 'node' or 'conditional_node'.
@ -28,7 +31,7 @@ class BaseNode(ABC):
Raises: Raises:
ValueError: If `node_type` is not one of the allowed types. ValueError: If `node_type` is not one of the allowed types.
Example: Example:
>>> class MyNode(BaseNode): >>> class MyNode(BaseNode):
... def execute(self, state): ... def execute(self, state):
@ -40,18 +43,27 @@ class BaseNode(ABC):
{'key': 'value'} {'key': 'value'}
""" """
def __init__(self, node_name: str, node_type: str, input: str, output: List[str], def __init__(
min_input_len: int = 1, node_config: Optional[dict] = None): self,
node_name: str,
node_type: str,
input: str,
output: List[str],
min_input_len: int = 1,
node_config: Optional[dict] = None,
):
self.node_name = node_name self.node_name = node_name
self.input = input self.input = input
self.output = output self.output = output
self.min_input_len = min_input_len self.min_input_len = min_input_len
self.node_config = node_config self.node_config = node_config
self.logger = get_logger()
if node_type not in ["node", "conditional_node"]: if node_type not in ["node", "conditional_node"]:
raise ValueError( raise ValueError(
f"node_type must be 'node' or 'conditional_node', got '{node_type}'") f"node_type must be 'node' or 'conditional_node', got '{node_type}'"
)
self.node_type = node_type self.node_type = node_type
@abstractmethod @abstractmethod
@ -102,8 +114,7 @@ class BaseNode(ABC):
self._validate_input_keys(input_keys) self._validate_input_keys(input_keys)
return input_keys return input_keys
except ValueError as e: except ValueError as e:
raise ValueError( raise ValueError(f"Error parsing input keys for {self.node_name}: {str(e)}")
f"Error parsing input keys for {self.node_name}: {str(e)}")
def _validate_input_keys(self, input_keys): def _validate_input_keys(self, input_keys):
""" """
@ -119,7 +130,8 @@ class BaseNode(ABC):
if len(input_keys) < self.min_input_len: if len(input_keys) < self.min_input_len:
raise ValueError( raise ValueError(
f"""{self.node_name} requires at least {self.min_input_len} input keys, f"""{self.node_name} requires at least {self.min_input_len} input keys,
got {len(input_keys)}.""") got {len(input_keys)}."""
)
def _parse_input_keys(self, state: dict, expression: str) -> List[str]: def _parse_input_keys(self, state: dict, expression: str) -> List[str]:
""" """
@ -142,67 +154,80 @@ class BaseNode(ABC):
raise ValueError("Empty expression.") raise ValueError("Empty expression.")
# Check for adjacent state keys without an operator between them # Check for adjacent state keys without an operator between them
pattern = r'\b(' + '|'.join(re.escape(key) for key in state.keys()) + \ pattern = (
r')(\b\s*\b)(' + '|'.join(re.escape(key) r"\b("
for key in state.keys()) + r')\b' + "|".join(re.escape(key) for key in state.keys())
+ r")(\b\s*\b)("
+ "|".join(re.escape(key) for key in state.keys())
+ r")\b"
)
if re.search(pattern, expression): if re.search(pattern, expression):
raise ValueError( raise ValueError(
"Adjacent state keys found without an operator between them.") "Adjacent state keys found without an operator between them."
)
# Remove spaces # Remove spaces
expression = expression.replace(" ", "") expression = expression.replace(" ", "")
# Check for operators with empty adjacent tokens or at the start/end # Check for operators with empty adjacent tokens or at the start/end
if expression[0] in '&|' or expression[-1] in '&|' \ if (
or '&&' in expression or '||' in expression or \ expression[0] in "&|"
'&|' in expression or '|&' in expression: or expression[-1] in "&|"
or "&&" in expression
or "||" in expression
or "&|" in expression
or "|&" in expression
):
raise ValueError("Invalid operator usage.") raise ValueError("Invalid operator usage.")
# Check for balanced parentheses and valid operator placement # Check for balanced parentheses and valid operator placement
open_parentheses = close_parentheses = 0 open_parentheses = close_parentheses = 0
for i, char in enumerate(expression): for i, char in enumerate(expression):
if char == '(': if char == "(":
open_parentheses += 1 open_parentheses += 1
elif char == ')': elif char == ")":
close_parentheses += 1 close_parentheses += 1
# Check for invalid operator sequences # Check for invalid operator sequences
if char in "&|" and i + 1 < len(expression) and expression[i + 1] in "&|": if char in "&|" and i + 1 < len(expression) and expression[i + 1] in "&|":
raise ValueError( raise ValueError(
"Invalid operator placement: operators cannot be adjacent.") "Invalid operator placement: operators cannot be adjacent."
)
# Check for missing or balanced parentheses # Check for missing or balanced parentheses
if open_parentheses != close_parentheses: if open_parentheses != close_parentheses:
raise ValueError( raise ValueError("Missing or unbalanced parentheses in expression.")
"Missing or unbalanced parentheses in expression.")
# Helper function to evaluate an expression without parentheses # Helper function to evaluate an expression without parentheses
def evaluate_simple_expression(exp: str) -> List[str]: def evaluate_simple_expression(exp: str) -> List[str]:
"""Evaluate an expression without parentheses.""" """Evaluate an expression without parentheses."""
# Split the expression by the OR operator and process each segment # Split the expression by the OR operator and process each segment
for or_segment in exp.split('|'): for or_segment in exp.split("|"):
# Check if all elements in an AND segment are in state # Check if all elements in an AND segment are in state
and_segment = or_segment.split('&') and_segment = or_segment.split("&")
if all(elem.strip() in state for elem in and_segment): if all(elem.strip() in state for elem in and_segment):
return [elem.strip() for elem in and_segment if elem.strip() in state] return [
elem.strip() for elem in and_segment if elem.strip() in state
]
return [] return []
# Helper function to evaluate expressions with parentheses # Helper function to evaluate expressions with parentheses
def evaluate_expression(expression: str) -> List[str]: def evaluate_expression(expression: str) -> List[str]:
"""Evaluate an expression with parentheses.""" """Evaluate an expression with parentheses."""
while '(' in expression: while "(" in expression:
start = expression.rfind('(') start = expression.rfind("(")
end = expression.find(')', start) end = expression.find(")", start)
sub_exp = expression[start + 1:end] sub_exp = expression[start + 1 : end]
# Replace the evaluated part with a placeholder and then evaluate it # Replace the evaluated part with a placeholder and then evaluate it
sub_result = evaluate_simple_expression(sub_exp) sub_result = evaluate_simple_expression(sub_exp)
# For simplicity in handling, join sub-results with OR to reprocess them later # For simplicity in handling, join sub-results with OR to reprocess them later
expression = expression[:start] + \ expression = (
'|'.join(sub_result) + expression[end+1:] expression[:start] + "|".join(sub_result) + expression[end + 1 :]
)
return evaluate_simple_expression(expression) return evaluate_simple_expression(expression)
result = evaluate_expression(expression) result = evaluate_expression(expression)