refctoring of the code

This commit is contained in:
Marco Vinciguerra 2024-09-15 11:20:08 +02:00
parent dcef172e03
commit 438b8127db
8 changed files with 19 additions and 25 deletions

View File

@ -65,12 +65,10 @@ class GraphBuilder:
"temperature": 0,
"streaming": True
}
# Update defaults with any LLM parameters that were provided
llm_params = {**llm_defaults, **llm_config}
if "api_key" not in llm_params:
raise ValueError("LLM configuration must include an 'api_key'.")
# select the model based on the model name
if "gpt-" in llm_params["model"]:
return ChatOpenAI(llm_params)
elif "gemini" in llm_params["model"]:
@ -152,17 +150,13 @@ class GraphBuilder:
edges = graph_config.get('edges', [])
entry_point = graph_config.get('entry_point')
# Add nodes to the graph
for node in nodes:
# If this node is the entry point, use a double circle to denote it
if node['node_name'] == entry_point:
graph.node(node['node_name'], shape='doublecircle')
else:
graph.node(node['node_name'])
# Add edges to the graph
for edge in edges:
# An edge could potentially have multiple 'to' nodes if it's from a conditional node
if isinstance(edge['to'], list):
for to_node in edge['to']:
graph.edge(edge['from'], to_node)

View File

@ -252,8 +252,8 @@ class FetchNode(BaseNode):
if not self.cut:
parsed_content = cleanup_html(response, source)
if ((isinstance(self.llm_model, ChatOpenAI) or isinstance(self.llm_model, AzureChatOpenAI))
and not self.script_creator) or (self.force and not self.script_creator):
if isinstance(self.llm_model, (ChatOpenAI, AzureChatOpenAI)) \
and not self.script_creator) or (self.force and not self.script_creator):
parsed_content = convert_to_md(source, parsed_content)
compressed_document = [Document(page_content=parsed_content)]
@ -271,7 +271,8 @@ class FetchNode(BaseNode):
try:
from ..docloaders.browser_base import browser_base_fetch
except ImportError:
raise ImportError("The browserbase module is not installed. Please install it using `pip install browserbase`.")
raise ImportError("""The browserbase module is not installed.
Please install it using `pip install browserbase`.""")
data = browser_base_fetch(self.browser_base.get("api_key"),
self.browser_base.get("project_id"), [source])
@ -283,7 +284,8 @@ class FetchNode(BaseNode):
document = loader.load()
if not document or not document[0].page_content.strip():
raise ValueError("No HTML body content found in the document fetched by ChromiumLoader.")
raise ValueError("""No HTML body content found in
the document fetched by ChromiumLoader.""")
parsed_content = document[0].page_content
if (isinstance(self.llm_model, ChatOpenAI) or isinstance(self.llm_model, AzureChatOpenAI)) and not self.script_creator or self.force and not self.script_creator and not self.openai_md_enabled:

View File

@ -14,7 +14,6 @@ To disable sending telemetry there are three ways:
or:
export SCRAPEGRAPHAI_TELEMETRY_ENABLED=false
"""
import configparser
import functools
import importlib.metadata
@ -68,14 +67,16 @@ def _check_config_and_environ_for_telemetry_flag(
try:
telemetry_enabled = config_obj.getboolean("DEFAULT", "telemetry_enabled")
except ValueError as e:
logger.debug(f"Unable to parse value for `telemetry_enabled` from config. Encountered {e}")
logger.debug(f"""Unable to parse value for
`telemetry_enabled` from config. Encountered {e}""")
if os.environ.get("SCRAPEGRAPHAI_TELEMETRY_ENABLED") is not None:
env_value = os.environ.get("SCRAPEGRAPHAI_TELEMETRY_ENABLED")
config_obj["DEFAULT"]["telemetry_enabled"] = env_value
try:
telemetry_enabled = config_obj.getboolean("DEFAULT", "telemetry_enabled")
except ValueError as e:
logger.debug(f"Unable to parse value for `SCRAPEGRAPHAI_TELEMETRY_ENABLED` from environment. Encountered {e}")
logger.debug(f"""Unable to parse value for `SCRAPEGRAPHAI_TELEMETRY_ENABLED`
from environment. Encountered {e}""")
return telemetry_enabled
@ -94,7 +95,6 @@ BASE_PROPERTIES = {
"telemetry_version": "0.0.3",
}
def disable_telemetry():
"""
function for disabling the telemetries
@ -102,7 +102,6 @@ def disable_telemetry():
global g_telemetry_enabled
g_telemetry_enabled = False
def is_telemetry_enabled() -> bool:
"""
function for checking if a telemetry is enables
@ -122,7 +121,6 @@ def is_telemetry_enabled() -> bool:
else:
return False
def _send_event_json(event_json: dict):
headers = {
"Content-Type": "application/json",
@ -141,7 +139,6 @@ def _send_event_json(event_json: dict):
else:
logger.debug(f"Telemetry data sent: {data}")
def send_event_json(event_json: dict):
"""
fucntion for sending event json
@ -154,7 +151,6 @@ def send_event_json(event_json: dict):
except Exception as e:
logger.debug(f"Failed to send telemetry data in a thread: {e}")
def log_event(event: str, properties: Dict[str, any]):
"""
function for logging the events
@ -167,7 +163,6 @@ def log_event(event: str, properties: Dict[str, any]):
}
send_event_json(event_json)
def log_graph_execution(graph_name: str, source: str, prompt:str, schema:dict,
llm_model: str, embedder_model: str, source_type: str,
execution_time: float, content: str = None, response: dict = None,
@ -193,8 +188,10 @@ def log_graph_execution(graph_name: str, source: str, prompt:str, schema:dict,
}
log_event("graph_execution", properties)
def capture_function_usage(call_fn: Callable) -> Callable:
"""
function that captures the usage
"""
@functools.wraps(call_fn)
def wrapped_fn(*args, **kwargs):
try:

View File

@ -24,6 +24,7 @@ def convert_to_md(html: str, url: str = None) -> str:
h = html2text.HTML2Text()
h.ignore_links = False
h.body_width = 0
if url is not None:
parsed_url = urlparse(url)
domain = f"{parsed_url.scheme}://{parsed_url.netloc}"

View File

@ -1,3 +1,6 @@
"""
copy module
"""
import copy
from typing import Any

View File

@ -3,7 +3,6 @@ Parse_state_key module
"""
import re
def parse_expression(expression, state: dict) -> list:
"""
Parses a complex boolean expression involving state keys.
@ -22,7 +21,8 @@ def parse_expression(expression, state: dict) -> list:
Example:
>>> parse_expression("user_input & (relevant_chunks | parsed_document | document)",
{"user_input": None, "document": None, "parsed_document": None, "relevant_chunks": None})
{"user_input": None, "document": None,
"parsed_document": None, "relevant_chunks": None})
['user_input', 'relevant_chunks', 'parsed_document', 'document']
This function evaluates the expression to determine the
@ -69,7 +69,6 @@ def parse_expression(expression, state: dict) -> list:
return [elem.strip() for elem in and_segment if elem.strip() in state]
return []
# Helper function to evaluate expressions with parentheses
def evaluate_expression(expression):
while '(' in expression:
start = expression.rfind('(')

View File

@ -4,7 +4,6 @@ This utility function saves the byte response as an audio file.
from pathlib import Path
from typing import Union
def save_audio_from_bytes(byte_response: bytes, output_path: Union[str, Path]) -> None:
"""
Saves the byte response as an audio file to the specified path.

View File

@ -5,7 +5,6 @@ from typing import List
import tiktoken
from ..helpers.models_tokens import models_tokens
def truncate_text_tokens(text: str, model: str, encoding_name: str) -> List[str]:
"""
Truncates text into chunks that are small enough to be processed by specified llm models.