feat: refactoring of the code

This commit is contained in:
Marco Vinciguerra 2024-08-02 12:00:00 +02:00
parent 3e07f6273f
commit 9355507a2d
25 changed files with 65 additions and 109 deletions

View File

@ -86,7 +86,8 @@ class BaseNode(ABC):
Args:
param (dict): The dictionary to update node_config with.
overwrite (bool): Flag indicating if the values of node_config should be overwritten if their value is not None.
overwrite (bool): Flag indicating if the values of node_config
should be overwritten if their value is not None.
"""
for key, val in params.items():
if hasattr(self, key) and not overwrite:
@ -133,7 +134,8 @@ class BaseNode(ABC):
def _parse_input_keys(self, state: dict, expression: str) -> List[str]:
"""
Parses the input keys expression to extract relevant keys from the state based on logical conditions.
Parses the input keys expression to extract
relevant keys from the state based on logical conditions.
The expression can contain AND (&), OR (|), and parentheses to group conditions.
Args:

View File

@ -133,7 +133,7 @@ class FetchNode(BaseNode):
state.update({self.output[0]: compressed_document})
return state
elif input_keys[0] == "json":
f = open(source)
f = open(source, encoding="utf-8")
compressed_document = [
Document(page_content=str(json.load(f)), metadata={"source": "json"})
]
@ -181,12 +181,11 @@ class FetchNode(BaseNode):
if not response.text.strip():
raise ValueError("No HTML body content found in the response.")
parsed_content = response
if not self.cut:
parsed_content = cleanup_html(response, source)
if (isinstance(self.llm_model, ChatOpenAI) and not self.script_creator) or (self.force and not self.script_creator):
if (isinstance(self.llm_model, ChatOpenAI)
and not self.script_creator) or (self.force and not self.script_creator):
parsed_content = convert_to_md(source, input_data[0])
compressed_document = [Document(page_content=parsed_content)]
else:
@ -205,7 +204,8 @@ class FetchNode(BaseNode):
data = browser_base_fetch(self.browser_base.get("api_key"),
self.browser_base.get("project_id"), [source])
document = [Document(page_content=content, metadata={"source": source}) for content in data]
document = [Document(page_content=content,
metadata={"source": source}) for content in data]
else:
loader = ChromiumLoader([source], headless=self.headless, **loader_kwargs)
document = loader.load()
@ -215,10 +215,8 @@ class FetchNode(BaseNode):
parsed_content = document[0].page_content
if isinstance(self.llm_model, ChatOpenAI) and not self.script_creator or self.force and not self.script_creator and not self.openai_md_enabled:
parsed_content = convert_to_md(document[0].page_content, input_data[0])
compressed_document = [
Document(page_content=parsed_content, metadata={"source": "html file"})
]

View File

@ -3,18 +3,12 @@ gg
Module for generating the answer node
"""
# Imports from standard library
from typing import List, Optional
# Imports from Langchain
from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.runnables import RunnableParallel
from tqdm import tqdm
from ..utils.logging import get_logger
# Imports from the library
from .base_node import BaseNode
from ..helpers.generate_answer_node_csv_prompts import template_chunks_csv, template_no_chunks_csv, template_merge_csv

View File

@ -1,7 +1,6 @@
"""
GenerateAnswerNode Module
"""
import asyncio
from typing import List, Optional
from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import JsonOutputParser
@ -9,7 +8,6 @@ from langchain_core.runnables import RunnableParallel
from langchain_openai import ChatOpenAI
from langchain_community.chat_models import ChatOllama
from tqdm import tqdm
from langchain_openai import ChatOpenAI
from ..utils.logging import get_logger
from .base_node import BaseNode
from ..helpers import template_chunks, template_no_chunks, template_merge, template_chunks_md, template_no_chunks_md, template_merge_md
@ -130,7 +128,6 @@ class GenerateAnswerNode(BaseNode):
partial_variables={"context": chunk,
"chunk_id": i + 1,
"format_instructions": format_instructions})
# Add chain to dictionary with dynamic name
chain_name = f"chunk{i+1}"
chains_dict[chain_name] = prompt | self.llm_model | output_parser

View File

@ -113,7 +113,7 @@ class GenerateAnswerOmniNode(BaseNode):
chain = prompt | self.llm_model | output_parser
answer = chain.invoke({"question": user_prompt})
state.update({self.output[0]: answer})
return state
@ -148,4 +148,4 @@ class GenerateAnswerOmniNode(BaseNode):
answer = merge_chain.invoke({"context": batch_results, "question": user_prompt})
state.update({self.output[0]: answer})
return state
return state

View File

@ -2,18 +2,13 @@
Module for generating the answer node
"""
# Imports from standard library
from typing import List, Optional
# Imports from Langchain
from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.runnables import RunnableParallel
from tqdm import tqdm
from langchain_community.chat_models import ChatOllama
from ..utils.logging import get_logger
# Imports from the library
from .base_node import BaseNode
from ..helpers.generate_answer_node_pdf_prompts import template_chunks_pdf, template_no_chunks_pdf, template_merge_pdf

View File

@ -83,7 +83,6 @@ class GenerateScraperNode(BaseNode):
user_prompt = input_data[0]
doc = input_data[1]
# schema to be used for output parsing
if self.node_config.get("schema", None) is not None:
output_schema = JsonOutputParser(pydantic_object=self.node_config["schema"])
else:
@ -130,7 +129,6 @@ class GenerateScraperNode(BaseNode):
)
map_chain = prompt | self.llm_model | StrOutputParser()
# Chain
answer = map_chain.invoke({"question": user_prompt})
state.update({self.output[0]: answer})

View File

@ -1,7 +1,6 @@
"""
GetProbableTagsNode Module
"""
from typing import List, Optional
from langchain.output_parsers import CommaSeparatedListOutputParser
from langchain.prompts import PromptTemplate

View File

@ -5,13 +5,11 @@ GraphIterator Module
import asyncio
import copy
from typing import List, Optional
from tqdm.asyncio import tqdm
from ..utils.logging import get_logger
from .base_node import BaseNode
_default_batchsize = 16
DEFAULT_BATCHSIZE = 16
class GraphIteratorNode(BaseNode):
@ -51,13 +49,15 @@ class GraphIteratorNode(BaseNode):
the correct data from the state.
Returns:
dict: The updated state with the output key containing the results of the graph instances.
dict: The updated state with the output key c
ontaining the results of the graph instances.
Raises:
KeyError: If the input keys are not found in the state, indicating that the
necessary information for running the graph instances is missing.
KeyError: If the input keys are not found in the state,
indicating that thenecessary information for running
the graph instances is missing.
"""
batchsize = self.node_config.get("batchsize", _default_batchsize)
batchsize = self.node_config.get("batchsize", DEFAULT_BATCHSIZE)
self.logger.info(
f"--- Executing {self.node_name} Node with batchsize {batchsize} ---"

View File

@ -3,14 +3,14 @@ ImageToTextNode Module
"""
from typing import List, Optional
from ..utils.logging import get_logger
from .base_node import BaseNode
class ImageToTextNode(BaseNode):
"""
Retrieve images from a list of URLs and return a description of the images using an image-to-text model.
Retrieve images from a list of URLs and return a description of
the images using an image-to-text model.
Attributes:
llm_model: An instance of the language model client used for image-to-text conversion.

View File

@ -2,18 +2,10 @@
MergeAnswersNode Module
"""
# Imports from standard library
from typing import List, Optional
from tqdm import tqdm
# Imports from Langchain
from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import JsonOutputParser
from tqdm import tqdm
from ..utils.logging import get_logger
# Imports from the library
from .base_node import BaseNode

View File

@ -5,15 +5,9 @@ MergeAnswersNode Module
# Imports from standard library
from typing import List, Optional
from tqdm import tqdm
# Imports from Langchain
from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import JsonOutputParser, StrOutputParser
from tqdm import tqdm
from ..utils.logging import get_logger
# Imports from the library
from .base_node import BaseNode

View File

@ -75,23 +75,23 @@ class ParseNode(BaseNode):
chunks = chunk(text=docs_transformed.page_content,
chunk_size= self.node_config.get("chunk_size", 4096)-250,
token_counter=lambda x: len(x),
token_counter= lambda x: len(x),
memoize=False)
else:
docs_transformed = docs_transformed[0]
if type(docs_transformed) == Document:
if isinstance(docs_transformed, Document):
chunks = chunk(text=docs_transformed.page_content,
chunk_size= self.node_config.get("chunk_size", 4096)-250,
token_counter=lambda x: len(x),
token_counter= lambda x: len(x),
memoize=False)
else:
chunks = chunk(text=docs_transformed,
chunk_size= self.node_config.get("chunk_size", 4096)-250,
token_counter=lambda x: len(x),
token_counter= lambda x: len(x),
memoize=False)
state.update({self.output[0]: chunks})
return state

View File

@ -4,15 +4,9 @@ RobotsNode Module
from typing import List, Optional
from urllib.parse import urlparse
from langchain_community.document_loaders import AsyncChromiumLoader
from langchain.prompts import PromptTemplate
from langchain.output_parsers import CommaSeparatedListOutputParser
from langchain.output_parsers import CommaSeparatedListOutputParser
from langchain.prompts import PromptTemplate
from langchain_community.document_loaders import AsyncChromiumLoader
from ..helpers import robots_dictionary
from ..utils.logging import get_logger
from .base_node import BaseNode
@ -146,4 +140,4 @@ class RobotsNode(BaseNode):
self.logger.warning("\033[32m(Scraping this website is allowed)\033[0m")
state.update({self.output[0]: is_scrapable})
return state
return state

View File

@ -1,9 +1,7 @@
"""
SearchInternetNode Module
"""
from typing import List, Optional
from langchain.output_parsers import CommaSeparatedListOutputParser
from langchain.prompts import PromptTemplate
from langchain_community.chat_models import ChatOllama

View File

@ -2,19 +2,13 @@
SearchLinkNode Module
"""
# Imports from standard library
from typing import List, Optional
import re
from tqdm import tqdm
# Imports from Langchain
from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.runnables import RunnableParallel
from ..utils.logging import get_logger
# Imports from the library
from .base_node import BaseNode

View File

@ -67,7 +67,6 @@ class SearchLinksWithContext(BaseNode):
# Fetching data from the state based on the input keys
input_data = [state[key] for key in input_keys]
user_prompt = input_data[0]
doc = input_data[1]
output_parser = CommaSeparatedListOutputParser()

View File

@ -1,13 +1,10 @@
"""
TextToSpeechNode Module
"""
from typing import List, Optional
from ..utils.logging import get_logger
from .base_node import BaseNode
class TextToSpeechNode(BaseNode):
"""
Converts text to speech using the specified text-to-speech model.

View File

@ -1,8 +1,8 @@
"""
convert_to_md modul
"""
import html2text
from urllib.parse import urlparse
import html2text
def convert_to_md(html: str, url: str = None) -> str:
""" Convert HTML to Markdown.

View File

@ -12,7 +12,7 @@ from typing import Optional
_library_name = __name__.split(".", maxsplit=1)[0]
_default_handler = None
DEFAULT_HANDLER = None
_default_logging_level = logging.WARNING
_semaphore = threading.Lock()
@ -23,22 +23,22 @@ def _get_library_root_logger() -> logging.Logger:
def _set_library_root_logger() -> None:
global _default_handler
global DEFAULT_HANDLER
with _semaphore:
if _default_handler:
if DEFAULT_HANDLER:
return
_default_handler = logging.StreamHandler() # sys.stderr as stream
DEFAULT_HANDLER = logging.StreamHandler() # sys.stderr as stream
# https://github.com/pyinstaller/pyinstaller/issues/7334#issuecomment-1357447176
if sys.stderr is None:
sys.stderr = open(os.devnull, "w")
sys.stderr = open(os.devnull, "w", encoding="utf-8")
_default_handler.flush = sys.stderr.flush
DEFAULT_HANDLER.flush = sys.stderr.flush
library_root_logger = _get_library_root_logger()
library_root_logger.addHandler(_default_handler)
library_root_logger.addHandler(DEFAULT_HANDLER)
library_root_logger.setLevel(_default_logging_level)
library_root_logger.propagate = False
@ -86,8 +86,8 @@ def set_handler(handler: logging.Handler) -> None:
_get_library_root_logger().addHandler(handler)
def set_default_handler() -> None:
set_handler(_default_handler)
def setDEFAULT_HANDLER() -> None:
set_handler(DEFAULT_HANDLER)
def unset_handler(handler: logging.Handler) -> None:
@ -98,8 +98,8 @@ def unset_handler(handler: logging.Handler) -> None:
_get_library_root_logger().removeHandler(handler)
def unset_default_handler() -> None:
unset_handler(_default_handler)
def unsetDEFAULT_HANDLER() -> None:
unset_handler(DEFAULT_HANDLER)
def set_propagation() -> None:

View File

@ -13,19 +13,22 @@ def parse_expression(expression, state: dict) -> list:
state (dict): Dictionary of state keys used to evaluate the expression.
Raises:
ValueError: If the expression is empty, has adjacent state keys without operators, invalid operator usage,
unbalanced parentheses, or if no state keys match the expression.
ValueError: If the expression is empty, has adjacent state keys without operators,
invalid operator usage, unbalanced parentheses, or if no state keys match the expression.
Returns:
list: A list of state keys that match the boolean expression, ensuring each key appears only once.
list: A list of state keys that match the boolean expression,
ensuring each key appears only once.
Example:
>>> parse_expression("user_input & (relevant_chunks | parsed_document | document)",
{"user_input": None, "document": None, "parsed_document": None, "relevant_chunks": None})
['user_input', 'relevant_chunks', 'parsed_document', 'document']
This function evaluates the expression to determine the logical inclusion of state keys based on provided boolean logic.
It checks for syntax errors such as unbalanced parentheses, incorrect adjacency of operators, and empty expressions.
This function evaluates the expression to determine the
logical inclusion of state keys based on provided boolean logic.
It checks for syntax errors such as unbalanced parentheses,
incorrect adjacency of operators, and empty expressions.
"""
# Check for empty expression

View File

@ -6,7 +6,6 @@ import ipaddress
import random
import re
from typing import List, Optional, Set, TypedDict
import requests
from fp.errors import FreeProxyException
from fp.fp import FreeProxy

View File

@ -1,3 +1,6 @@
"""
Research_web module
"""
import re
from typing import List
from langchain_community.tools import DuckDuckGoSearchResults
@ -5,13 +8,15 @@ from googlesearch import search as google_search
import requests
from bs4 import BeautifulSoup
def search_on_web(query: str, search_engine: str = "Google", max_results: int = 10, port: int = 8080) -> List[str]:
def search_on_web(query: str, search_engine: str = "Google",
max_results: int = 10, port: int = 8080) -> List[str]:
"""
Searches the web for a given query using specified search engine options.
Args:
query (str): The search query to find on the internet.
search_engine (str, optional): Specifies the search engine to use, options include 'Google', 'DuckDuckGo', 'Bing', or 'SearXNG'. Default is 'Google'.
search_engine (str, optional): Specifies the search engine to use,
options include 'Google', 'DuckDuckGo', 'Bing', or 'SearXNG'. Default is 'Google'.
max_results (int, optional): The maximum number of search results to return.
port (int, optional): The port number to use when searching with 'SearXNG'. Default is 8080.
@ -25,19 +30,19 @@ def search_on_web(query: str, search_engine: str = "Google", max_results: int =
>>> search_on_web("example query", search_engine="Google", max_results=5)
['http://example.com', 'http://example.org', ...]
"""
if search_engine.lower() == "google":
res = []
for url in google_search(query, stop=max_results):
res.append(url)
return res
elif search_engine.lower() == "duckduckgo":
research = DuckDuckGoSearchResults(max_results=max_results)
res = research.run(query)
links = re.findall(r'https?://[^\s,\]]+', res)
return links
elif search_engine.lower() == "bing":
headers = {
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
@ -46,24 +51,24 @@ def search_on_web(query: str, search_engine: str = "Google", max_results: int =
response = requests.get(search_url, headers=headers)
response.raise_for_status()
soup = BeautifulSoup(response.text, "html.parser")
search_results = []
for result in soup.find_all('li', class_='b_algo', limit=max_results):
link = result.find('a')['href']
search_results.append(link)
return search_results
elif search_engine.lower() == "searxng":
url = f"http://localhost:{port}"
params = {"q": query, "format": "json"}
# Send the GET request to the server
response = requests.get(url, params=params)
# Parse the response and limit to the specified max_results
data = response.json()
limited_results = data["results"][:max_results]
return limited_results
else:
raise ValueError("The only search engines available are DuckDuckGo, Google, Bing, or SearXNG")

View File

@ -5,7 +5,7 @@ source code inspired by https://gist.github.com/DiTo97/46f4b733396b8d7a8f1d4d22d
import sys
import typing
import importlib.util # noqa: F401
if typing.TYPE_CHECKING:
import types
@ -24,9 +24,6 @@ def srcfile_import(modpath: str, modname: str) -> "types.ModuleType":
Raises:
ImportError: If the module cannot be imported from the srcfile
"""
import importlib.util # noqa: F401
#
spec = importlib.util.spec_from_file_location(modname, modpath)
if spec is None:

View File

@ -22,7 +22,8 @@ def truncate_text_tokens(text: str, model: str, encoding_name: str) -> List[str]
>>> truncate_text_tokens("This is a sample text for truncation.", "GPT-3", "EMBEDDING_ENCODING")
["This is a sample text", "for truncation."]
This function ensures that each chunk of text can be tokenized by the specified model without exceeding the model's token limit.
This function ensures that each chunk of text can be tokenized
by the specified model without exceeding the model's token limit.
"""
encoding = tiktoken.get_encoding(encoding_name)