fix: temporary fix for parse_node

This commit is contained in:
Marco Vinciguerra 2024-09-09 11:42:33 +02:00
parent fc738cacac
commit f2bb22d8e9
2 changed files with 6 additions and 76 deletions

View File

@ -1,14 +1,11 @@
""" """
ParseNode Module ParseNode Module
""" """
from typing import Tuple, List, Optional from typing import List, Optional
from urllib.parse import urljoin
import re
from semchunk import chunk from semchunk import chunk
from langchain_community.document_transformers import Html2TextTransformer from langchain_community.document_transformers import Html2TextTransformer
from langchain_core.documents import Document from langchain_core.documents import Document
from .base_node import BaseNode from .base_node import BaseNode
from ..helpers import default_filters
class ParseNode(BaseNode): class ParseNode(BaseNode):
""" """
@ -43,60 +40,6 @@ class ParseNode(BaseNode):
self.parse_html = ( self.parse_html = (
True if node_config is None else node_config.get("parse_html", True) True if node_config is None else node_config.get("parse_html", True)
) )
self.llm_model = node_config['llm_model']
self.parse_urls = (
False if node_config is None else node_config.get("parse_urls", False)
)
def _clean_urls(self, urls: List[str]) -> List[str]:
"""
Cleans the URLs extracted from the text.
Args:
urls (List[str]): The list of URLs to clean.
Returns:
List[str]: The cleaned URLs.
"""
cleaned_urls = []
for url in urls:
url = re.sub(r'.*?\]\(', '', url)
url = url.rstrip(').')
cleaned_urls.append(url)
return cleaned_urls
def extract_urls(self, text: str, source: str) -> Tuple[List[str], List[str]]:
"""
Extracts URLs from the given text.
Args:
text (str): The text to extract URLs from.
Returns:
Tuple[List[str], List[str]]: A tuple containing the extracted link URLs and image URLs.
"""
if not self.parse_urls:
return [], []
image_extensions = default_filters.filter_dict["img_exts"]
image_extension_seq = '|'.join(image_extensions).replace('.','')
url_pattern = re.compile(r'(https?://[^\s]+|\S+\.(?:' + image_extension_seq + '))')
all_urls = url_pattern.findall(text)
all_urls = self._clean_urls(all_urls)
if not source.startswith("http"):
all_urls = [url for url in all_urls if url.startswith("http")]
else:
all_urls = [urljoin(source, url) for url in all_urls]
images = [url for url in all_urls if any(url.endswith(ext) for ext in image_extensions)]
links = [url for url in all_urls if url not in images]
return links, images
def execute(self, state: dict) -> dict: def execute(self, state: dict) -> dict:
""" """
@ -119,46 +62,33 @@ class ParseNode(BaseNode):
input_keys = self.get_input_keys(state) input_keys = self.get_input_keys(state)
input_data = [state[key] for key in input_keys] input_data = [state[key] for key in input_keys]
docs_transformed = input_data[0] docs_transformed = input_data[0]
source = input_data[1] if self.parse_urls else None
def count_tokens(text):
from ..utils import token_count
return token_count(text, self.llm_model.model_name)
if self.parse_html: if self.parse_html:
docs_transformed = Html2TextTransformer(ignore_links=False).transform_documents(input_data[0]) docs_transformed = Html2TextTransformer().transform_documents(input_data[0])
docs_transformed = docs_transformed[0] docs_transformed = docs_transformed[0]
link_urls, img_urls = self.extract_urls(docs_transformed.page_content, source)
chunks = chunk(text=docs_transformed.page_content, chunks = chunk(text=docs_transformed.page_content,
chunk_size=self.node_config.get("chunk_size", 4096)-250, chunk_size=self.node_config.get("chunk_size", 4096)-250,
token_counter=count_tokens, token_counter=lambda text: len(text.split()),
memoize=False) memoize=False)
else: else:
docs_transformed = docs_transformed[0] docs_transformed = docs_transformed[0]
link_urls, img_urls = self.extract_urls(docs_transformed.page_content, source)
chunk_size = self.node_config.get("chunk_size", 4096) chunk_size = self.node_config.get("chunk_size", 4096)
chunk_size = min(chunk_size - 500, int(chunk_size * 0.9)) chunk_size = min(chunk_size - 500, int(chunk_size * 0.9))
if isinstance(docs_transformed, Document): if isinstance(docs_transformed, Document):
chunks = chunk(text=docs_transformed.page_content, chunks = chunk(text=docs_transformed.page_content,
chunk_size=chunk_size, chunk_size=chunk_size,
token_counter=count_tokens, token_counter=lambda text: len(text.split()),
memoize=False) memoize=False)
else: else:
chunks = chunk(text=docs_transformed, chunks = chunk(text=docs_transformed,
chunk_size=chunk_size, chunk_size=chunk_size,
token_counter=count_tokens, token_counter=lambda text: len(text.split()),
memoize=False) memoize=False)
state.update({self.output[0]: chunks}) state.update({self.output[0]: chunks})
if self.parse_urls:
state.update({self.output[1]: link_urls})
state.update({self.output[2]: img_urls})
return state return state