Merge branch 'pre/beta' into ligthweigthing_library

This commit is contained in:
Marco Vinciguerra 2024-08-23 12:22:31 +02:00 committed by GitHub
commit 26de5dd623
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
20 changed files with 40 additions and 104 deletions

View File

@ -149,6 +149,7 @@ class AbstractGraph(ABC):
"ollama", "oneapi", "nvidia", "groq", "google_vertexai",
"bedrock", "mistralai", "hugging_face", "deepseek", "ernie", "fireworks"]
if llm_params["model"].split("/")[0] not in known_models and llm_params["model"].split("-")[0] not in known_models:
raise ValueError(f"Model '{llm_params['model']}' is not supported")

View File

@ -9,7 +9,8 @@ from langchain_core.runnables import RunnableParallel
from tqdm import tqdm
from ..utils.logging import get_logger
from .base_node import BaseNode
from ..prompts.generate_answer_node_csv_prompts import TEMPLATE_CHUKS_CSV, TEMPLATE_NO_CHUKS_CSV, TEMPLATE_MERGE_CSV
from ..prompts.generate_answer_node_csv_prompts import (TEMPLATE_CHUKS_CSV,
TEMPLATE_NO_CHUKS_CSV, TEMPLATE_MERGE_CSV)
class GenerateAnswerCSVNode(BaseNode):
"""
@ -95,14 +96,14 @@ class GenerateAnswerCSVNode(BaseNode):
else:
output_parser = JsonOutputParser()
TEMPLATE_NO_CHUKS_CSV_prompt = TEMPLATE_NO_CHUKS_CSV
TEMPLATE_CHUKS_CSV_prompt = TEMPLATE_CHUKS_CSV
TEMPLATE_MERGE_CSV_prompt = TEMPLATE_MERGE_CSV
TEMPLATE_NO_CHUKS_CSV_PROMPT = TEMPLATE_NO_CHUKS_CSV
TEMPLATE_CHUKS_CSV_PROMPT = TEMPLATE_CHUKS_CSV
TEMPLATE_MERGE_CSV_PROMPT = TEMPLATE_MERGE_CSV
if self.additional_info is not None:
TEMPLATE_NO_CHUKS_CSV_prompt = self.additional_info + TEMPLATE_NO_CHUKS_CSV
TEMPLATE_CHUKS_CSV_prompt = self.additional_info + TEMPLATE_CHUKS_CSV
TEMPLATE_MERGE_CSV_prompt = self.additional_info + TEMPLATE_MERGE_CSV
TEMPLATE_NO_CHUKS_CSV_PROMPT = self.additional_info + TEMPLATE_NO_CHUKS_CSV
TEMPLATE_CHUKS_CSV_PROMPT = self.additional_info + TEMPLATE_CHUKS_CSV
TEMPLATE_MERGE_CSV_PROMPT = self.additional_info + TEMPLATE_MERGE_CSV
format_instructions = output_parser.get_format_instructions()
@ -110,7 +111,7 @@ class GenerateAnswerCSVNode(BaseNode):
if len(doc) == 1:
prompt = PromptTemplate(
template=TEMPLATE_NO_CHUKS_CSV_prompt,
template=TEMPLATE_NO_CHUKS_CSV_PROMPT,
input_variables=["question"],
partial_variables={
"context": doc,
@ -127,7 +128,7 @@ class GenerateAnswerCSVNode(BaseNode):
tqdm(doc, desc="Processing chunks", disable=not self.verbose)
):
prompt = PromptTemplate(
template=TEMPLATE_CHUKS_CSV_prompt,
template=TEMPLATE_CHUKS_CSV_PROMPT,
input_variables=["question"],
partial_variables={
"context": chunk,
@ -144,7 +145,7 @@ class GenerateAnswerCSVNode(BaseNode):
batch_results = async_runner.invoke({"question": user_prompt})
merge_prompt = PromptTemplate(
template = TEMPLATE_MERGE_CSV_prompt,
template = TEMPLATE_MERGE_CSV_PROMPT,
input_variables=["context", "question"],
partial_variables={"format_instructions": format_instructions},
)
@ -153,4 +154,4 @@ class GenerateAnswerCSVNode(BaseNode):
answer = merge_chain.invoke({"context": batch_results, "question": user_prompt})
state.update({self.output[0]: answer})
return state
return state

View File

@ -67,10 +67,8 @@ class GenerateScraperNode(BaseNode):
self.logger.info(f"--- Executing {self.node_name} Node ---")
# Interpret input keys based on the provided input expression
input_keys = self.get_input_keys(state)
# Fetching data from the state based on the input keys
input_data = [state[key] for key in input_keys]
user_prompt = input_data[0]

View File

@ -58,10 +58,8 @@ class GetProbableTagsNode(BaseNode):
self.logger.info(f"--- Executing {self.node_name} Node ---")
# Interpret input keys based on the provided input expression
input_keys = self.get_input_keys(state)
# Fetching data from the state based on the input keys
input_data = [state[key] for key in input_keys]
user_prompt = input_data[0]
@ -88,10 +86,8 @@ class GetProbableTagsNode(BaseNode):
},
)
# Execute the chain to get probable tags
tag_answer = tag_prompt | self.llm_model | output_parser
probable_tags = tag_answer.invoke({"question": user_prompt})
# Update the dictionary with probable tags
state.update({self.output[0]: probable_tags})
return state

View File

@ -103,7 +103,6 @@ class GraphIteratorNode(BaseNode):
if graph_instance is None:
raise ValueError("graph instance is required for concurrent execution")
# Assign depth level to the graph
if "graph_depth" in graph_instance.config:
graph_instance.config["graph_depth"] += 1
else:
@ -113,14 +112,12 @@ class GraphIteratorNode(BaseNode):
participants = []
# semaphore to limit the number of concurrent tasks
semaphore = asyncio.Semaphore(batchsize)
async def _async_run(graph):
async with semaphore:
return await asyncio.to_thread(graph.run)
# creates a deepcopy of the graph instance for each endpoint
for url in urls:
instance = copy.copy(graph_instance)
instance.source = url

View File

@ -56,21 +56,17 @@ class MergeAnswersNode(BaseNode):
self.logger.info(f"--- Executing {self.node_name} Node ---")
# Interpret input keys based on the provided input expression
input_keys = self.get_input_keys(state)
# Fetching data from the state based on the input keys
input_data = [state[key] for key in input_keys]
user_prompt = input_data[0]
answers = input_data[1]
# merge the answers in one string
answers_str = ""
for i, answer in enumerate(answers):
answers_str += f"CONTENT WEBSITE {i+1}: {answer}\n"
# Initialize the output parser
if self.node_config.get("schema", None) is not None:
output_parser = JsonOutputParser(pydantic_object=self.node_config["schema"])
else:
@ -90,6 +86,5 @@ class MergeAnswersNode(BaseNode):
merge_chain = prompt_template | self.llm_model | output_parser
answer = merge_chain.invoke({"user_prompt": user_prompt})
# Update the state with the generated answer
state.update({self.output[0]: answer})
return state

View File

@ -59,13 +59,11 @@ class ParseNode(BaseNode):
self.logger.info(f"--- Executing {self.node_name} Node ---")
# Interpret input keys based on the provided input expression
input_keys = self.get_input_keys(state)
# Fetching data from the state based on the input keys
input_data = [state[key] for key in input_keys]
# Parse the document
docs_transformed = input_data[0]
if self.parse_html:
docs_transformed = Html2TextTransformer().transform_documents(input_data[0])
docs_transformed = docs_transformed[0]
@ -77,7 +75,6 @@ class ParseNode(BaseNode):
else:
docs_transformed = docs_transformed[0]
# Adapt the chunk size, leaving room for the reply, the prompt and the schema
chunk_size = self.node_config.get("chunk_size", 4096)
chunk_size = min(chunk_size - 500, int(chunk_size * 0.9))

View File

@ -80,10 +80,8 @@ class RAGNode(BaseNode):
self.logger.info(f"--- Executing {self.node_name} Node ---")
# Interpret input keys based on the provided input expression
input_keys = self.get_input_keys(state)
# Fetching data from the state based on the input keys
input_data = [state[key] for key in input_keys]
user_prompt = input_data[0]
@ -102,7 +100,6 @@ class RAGNode(BaseNode):
self.logger.info("--- (updated chunks metadata) ---")
# check if embedder_model is provided, if not use llm_model
if self.embedder_model is not None:
embeddings = self.embedder_model
elif 'embeddings' in self.node_config:
@ -144,23 +141,17 @@ class RAGNode(BaseNode):
pipeline_compressor = DocumentCompressorPipeline(
transformers=[redundant_filter, relevant_filter]
)
# redundant + relevant filter compressor
compression_retriever = ContextualCompressionRetriever(
base_compressor=pipeline_compressor, base_retriever=retriever
)
# relevant filter compressor only
# compression_retriever = ContextualCompressionRetriever(
# base_compressor=relevant_filter, base_retriever=retriever
# )
compressed_docs = compression_retriever.invoke(user_prompt)
self.logger.info("--- (tokens compressed and vector stored) ---")
state.update({self.output[0]: compressed_docs})
return state
def _create_default_embedder(self, llm_config=None) -> object:
"""
@ -223,7 +214,6 @@ class RAGNode(BaseNode):
embedder_params = {**embedder_config}
if "model_instance" in embedder_config:
return embedder_params["model_instance"]
# Instantiate the embedding model based on the model name
if "openai" in embedder_params["model"]:
return OpenAIEmbeddings(api_key=embedder_params["api_key"])
if "azure" in embedder_params["model"]:

View File

@ -75,10 +75,8 @@ class RobotsNode(BaseNode):
self.logger.info(f"--- Executing {self.node_name} Node ---")
# Interpret input keys based on the provided input expression
input_keys = self.get_input_keys(state)
# Fetching data from the state based on the input keys
input_data = [state[key] for key in input_keys]
source = input_data[0]

View File

@ -67,7 +67,6 @@ class SearchInternetNode(BaseNode):
input_keys = self.get_input_keys(state)
# Fetching data from the state based on the input keys
input_data = [state[key] for key in input_keys]
user_prompt = input_data[0]
@ -79,10 +78,8 @@ class SearchInternetNode(BaseNode):
input_variables=["user_prompt"],
)
# Execute the chain to get the search query
search_answer = search_prompt | self.llm_model | output_parser
# Ollama: Use no json format when creating the search query
if isinstance(self.llm_model, ChatOllama) and self.llm_model.format == 'json':
self.llm_model.format = None
search_query = search_answer.invoke({"user_prompt": user_prompt})[0]
@ -96,9 +93,7 @@ class SearchInternetNode(BaseNode):
search_engine=self.search_engine)
if len(answer) == 0:
# raise an exception if no answer is found
raise ValueError("Zero results found for the search query.")
# Update the state with the generated answer
state.update({self.output[0]: answer})
return state

View File

@ -49,7 +49,6 @@ class SearchLinkNode(BaseNode):
self.filter_config = {**default_filters.filter_dict, **provided_filter_config}
self.filter_links = True
else:
# Skip filtering if not enabled
self.filter_config = None
self.filter_links = False
@ -58,29 +57,26 @@ class SearchLinkNode(BaseNode):
def _is_same_domain(self, url, domain):
if not self.filter_links or not self.filter_config.get("diff_domain_filter", True):
return True # Skip the domain filter if not enabled
return True
parsed_url = urlparse(url)
parsed_domain = urlparse(domain)
return parsed_url.netloc == parsed_domain.netloc
def _is_image_url(self, url):
if not self.filter_links:
return False # Skip image filtering if filtering is not enabled
return False
image_extensions = self.filter_config.get("img_exts", [])
return any(url.lower().endswith(ext) for ext in image_extensions)
def _is_language_url(self, url):
if not self.filter_links:
return False # Skip language filtering if filtering is not enabled
return False
lang_indicators = self.filter_config.get("lang_indicators", [])
parsed_url = urlparse(url)
query_params = parse_qs(parsed_url.query)
# Check if the URL path or query string indicates a language-specific version
return any(indicator in parsed_url.path.lower() or indicator in query_params for indicator in lang_indicators)
def _is_potentially_irrelevant(self, url):
if not self.filter_links:
return False # Skip irrelevant URL filtering if filtering is not enabled
@ -88,12 +84,11 @@ class SearchLinkNode(BaseNode):
irrelevant_keywords = self.filter_config.get("irrelevant_keywords", [])
return any(keyword in url.lower() for keyword in irrelevant_keywords)
def execute(self, state: dict) -> dict:
"""
Filter out relevant links from the webpage that are relavant to prompt. Out of the filtered links, also
ensure that all links are navigable.
Filter out relevant links from the webpage that are relavant to prompt.
Out of the filtered links, also ensure that all links are navigable.
Args:
state (dict): The current state of the graph. The input keys will be used to fetch the
correct data types from the state.
@ -108,7 +103,6 @@ class SearchLinkNode(BaseNode):
self.logger.info(f"--- Executing {self.node_name} Node ---")
parsed_content_chunks = state.get("doc")
source_url = state.get("url") or state.get("local_dir")
output_parser = JsonOutputParser()
@ -148,7 +142,7 @@ class SearchLinkNode(BaseNode):
except Exception as e:
# Fallback approach: Using the LLM to extract links
self.logger.error(f"Error extracting links: {e}. Falling back to LLM.")
merge_prompt = PromptTemplate(
template=TEMPLATE_RELEVANT_LINKS,
input_variables=["content", "user_prompt"],

View File

@ -58,10 +58,8 @@ class SearchLinksWithContext(BaseNode):
self.logger.info(f"--- Executing {self.node_name} Node ---")
# Interpret input keys based on the provided input expression
input_keys = self.get_input_keys(state)
# Fetching data from the state based on the input keys
input_data = [state[key] for key in input_keys]
doc = input_data[1]
@ -71,7 +69,6 @@ class SearchLinksWithContext(BaseNode):
result = []
# Use tqdm to add progress bar
for i, chunk in enumerate(
tqdm(doc, desc="Processing chunks", disable=not self.verbose)
):

View File

@ -43,7 +43,8 @@ class TextToSpeechNode(BaseNode):
correct data types from the state.
Returns:
dict: The updated state with the output key containing the audio generated from the text.
dict: The updated state with the output
key containing the audio generated from the text.
Raises:
KeyError: If the input keys are not found in the state, indicating that the
@ -52,15 +53,11 @@ class TextToSpeechNode(BaseNode):
self.logger.info(f"--- Executing {self.node_name} Node ---")
# Interpret input keys based on the provided input expression
input_keys = self.get_input_keys(state)
# Fetching data from the state based on the input keys
input_data = [state[key] for key in input_keys]
# get the text to translate
text2translate = str(next(iter(input_data[0].values())))
# text2translate = str(input_data[0])
audio = self.tts_model.run(text2translate)

View File

@ -28,35 +28,28 @@ def cleanup_html(html_content: str, base_url: str) -> str:
soup = BeautifulSoup(html_content, 'html.parser')
# Title Extraction
title_tag = soup.find('title')
title = title_tag.get_text() if title_tag else ""
# Script and Style Tag Removal
for tag in soup.find_all(['script', 'style']):
tag.extract()
# Links extraction
link_urls = [urljoin(base_url, link['href']) for link in soup.find_all('a', href=True)]
# Images extraction
images = soup.find_all('img')
image_urls = []
for image in images:
if 'src' in image.attrs:
# if http or https is not present in the image url, join it with the base url
if 'http' not in image['src']:
image_urls.append(urljoin(base_url, image['src']))
else:
image_urls.append(image['src'])
# Body Extraction (if it exists)
body_content = soup.find('body')
if body_content:
# Minify the HTML within the body tag
minimized_body = minify(str(body_content))
return title, minimized_body, link_urls, image_urls
else:
raise ValueError(f"""No HTML body content found, please try setting the 'headless'
raise ValueError(f"""No HTML body content found, please try setting the 'headless'
flag to False in the graph configuration. HTML content: {html_content}""")

View File

@ -29,9 +29,8 @@ def convert_to_csv(data: dict, filename: str, position: str = None) -> None:
"""
if ".csv" in filename:
filename = filename.replace(".csv", "") # Remove .csv extension
filename = filename.replace(".csv", "")
# Get the directory of the caller script if position is not provided
if position is None:
caller_dir = os.path.dirname(os.path.abspath(sys.argv[0]))
position = caller_dir
@ -40,7 +39,7 @@ def convert_to_csv(data: dict, filename: str, position: str = None) -> None:
if not isinstance(data, dict):
raise TypeError("Input data must be a dictionary")
os.makedirs(position, exist_ok=True) # Create directory if needed
os.makedirs(position, exist_ok=True)
df = pd.DataFrame.from_dict(data, orient='index')
df.to_csv(os.path.join(position, f"{filename}.csv"), index=False)
@ -52,4 +51,4 @@ def convert_to_csv(data: dict, filename: str, position: str = None) -> None:
raise PermissionError(
f"You don't have permission to write to '{position}'.") from pe
except Exception as e:
raise e # Re-raise other potential errors
raise e

View File

@ -28,15 +28,15 @@ def convert_to_json(data: dict, filename: str, position: str = None) -> None:
Saves a JSON file named 'output.json' at '/path/to/save'.
Notes:
This function automatically ensures the directory exists before attempting to write the file. If the directory does not exist, it will attempt to create it.
This function automatically ensures the directory exists before
attempting to write the file.
If the directory does not exist, it will attempt to create it.
"""
if ".json" in filename:
filename = filename.replace(".json", "") # Remove .json extension
# Get the directory of the caller script
if position is None:
# Get directory of the main script
caller_dir = os.path.dirname(os.path.abspath(sys.argv[0]))
position = caller_dir

View File

@ -18,7 +18,8 @@ def convert_to_md(html: str, url: str = None) -> str:
<h1>This is a heading.</h1></body></html>")
'This is a paragraph.\n\n# This is a heading.'
Note: All the styles and links are ignored during the conversion. """
Note: All the styles and links are ignored during the conversion.
"""
h = html2text.HTML2Text()
h.ignore_links = False

View File

@ -48,7 +48,6 @@ def _set_library_root_logger() -> None:
DEFAULT_HANDLER = logging.StreamHandler() # sys.stderr as stream
# https://github.com/pyinstaller/pyinstaller/issues/7334#issuecomment-1357447176
if sys.stderr is None:
sys.stderr = open(os.devnull, "w", encoding="utf-8")
@ -66,7 +65,8 @@ def get_logger(name: Optional[str] = None) -> logging.Logger:
If no name is provided, the root logger for the library is returned.
Args:
name (Optional[str]): The name of the logger. If None, the root logger for the library is returned.
name (Optional[str]): The name of the logger.
If None, the root logger for the library is returned.
Returns:
logging.Logger: The logger with the specified name.
@ -199,7 +199,8 @@ def warning_once(self, *args, **kwargs):
"""
Emit a warning log with the same message only once.
This function is added as a method to the logging.Logger class. It emits a warning log with the same message only once,
This function is added as a method to the logging.Logger class.
It emits a warning log with the same message only once,
even if it is called multiple times with the same message.
Args:

View File

@ -31,11 +31,9 @@ def parse_expression(expression, state: dict) -> list:
incorrect adjacency of operators, and empty expressions.
"""
# Check for empty expression
if not expression:
raise ValueError("Empty expression.")
# Check for adjacent state keys without an operator between them
pattern = r'\b(' + '|'.join(re.escape(key) for key in state.keys()) + \
r')(\b\s*\b)(' + '|'.join(re.escape(key)
for key in state.keys()) + r')\b'
@ -43,37 +41,29 @@ def parse_expression(expression, state: dict) -> list:
raise ValueError(
"Adjacent state keys found without an operator between them.")
# Remove spaces
expression = expression.replace(" ", "")
# Check for operators with empty adjacent tokens or at the start/end
if expression[0] in '&|' or expression[-1] in '&|' or \
'&&' in expression or '||' in expression or \
'&|' in expression or '|&' in expression:
raise ValueError("Invalid operator usage.")
# Check for balanced parentheses and valid operator placement
open_parentheses = close_parentheses = 0
for i, char in enumerate(expression):
if char == '(':
open_parentheses += 1
elif char == ')':
close_parentheses += 1
# Check for invalid operator sequences
if char in "&|" and i + 1 < len(expression) and expression[i + 1] in "&|":
raise ValueError(
"Invalid operator placement: operators cannot be adjacent.")
# Check for missing or balanced parentheses
if open_parentheses != close_parentheses:
raise ValueError("Missing or unbalanced parentheses in expression.")
# Helper function to evaluate an expression without parentheses
def evaluate_simple_expression(exp):
# Split the expression by the OR operator and process each segment
for or_segment in exp.split('|'):
# Check if all elements in an AND segment are in state
and_segment = or_segment.split('&')
if all(elem.strip() in state for elem in and_segment):
return [elem.strip() for elem in and_segment if elem.strip() in state]
@ -85,9 +75,7 @@ def parse_expression(expression, state: dict) -> list:
start = expression.rfind('(')
end = expression.find(')', start)
sub_exp = expression[start + 1:end]
# Replace the evaluated part with a placeholder and then evaluate it
sub_result = evaluate_simple_expression(sub_exp)
# For simplicity in handling, join sub-results with OR to reprocess them later
expression = expression[:start] + \
'|'.join(sub_result) + expression[end+1:]
return evaluate_simple_expression(expression)
@ -97,7 +85,6 @@ def parse_expression(expression, state: dict) -> list:
if not temp_result:
raise ValueError("No state keys matched the expression.")
# Remove redundant state keys from the result, without changing their order
final_result = []
for key in temp_result:
if key not in final_result:

View File

@ -6,7 +6,7 @@ source code inspired by https://gist.github.com/DiTo97/46f4b733396b8d7a8f1d4d22d
import sys
import typing
import importlib.util # noqa: F401
import importlib.util
if typing.TYPE_CHECKING:
import types
@ -36,7 +36,6 @@ def srcfile_import(modpath: str, modname: str) -> "types.ModuleType":
module = importlib.util.module_from_spec(spec)
# adds the module to the global scope
sys.modules[modname] = module
spec.loader.exec_module(module)
@ -56,7 +55,7 @@ def dynamic_import(modname: str, message: str = "") -> None:
"""
if modname not in sys.modules:
try:
import importlib # noqa: F401
import importlib
module = importlib.import_module(modname)
sys.modules[modname] = module