feat: add finalize_node()

This commit is contained in:
EURAC\marperini 2024-04-29 10:05:04 +02:00
parent e778d27169
commit 6e7283ed8f
9 changed files with 241 additions and 30434 deletions

View File

@ -0,0 +1,155 @@
"""
Example of custom graph using existing nodes
"""
import os
from dotenv import load_dotenv
from scrapegraphai.models import OpenAI
from scrapegraphai.graphs import BaseGraph
from scrapegraphai.nodes import FetchNode, GenerateAnswerNode
load_dotenv()
# ************************************************
# Define the configuration for the graph
# ************************************************
openai_key = os.getenv("OPENAI_APIKEY")
graph_config = {
"llm": {
"api_key": openai_key,
"model": "gpt-3.5-turbo",
"temperature": 0,
"streaming": True
},
}
# ************************************************
# Define the graph nodes
# ************************************************
llm_model = OpenAI(graph_config["llm"])
# define the nodes for the graph
fetch_node = FetchNode(
input="url | local_dir",
output=["doc"],
)
generate_answer_node = GenerateAnswerNode(
input="user_prompt & (relevant_chunks | parsed_doc | doc)",
output=["answer"],
node_config={"llm": llm_model},
)
# ************************************************
# Create the graph by defining the connections
# ************************************************
graph = BaseGraph(
nodes={
fetch_node,
generate_answer_node,
},
edges={
(fetch_node, generate_answer_node)
},
entry_point=fetch_node
)
# ************************************************
# Execute the graph
# ************************************************
subtree_text = '''
div>div -> "This is a paragraph" \n
div>ul>li>a>span -> "This is a list item 1" \n
div>ul>li>a>span -> "This is a list item 2" \n
div>ul>li>a>span -> "This is a list item 3"
'''
subtree_simplified_html = '''
<div>
<div>This is a paragraph</div>
<ul>
<li>
<span>This is a list item 1</span>
</li>
<li>
<span>This is a list item 2</span>
</li>
<li>
<span>This is a list item 3</span>
</li>
</ul>
</div>
'''
subtree_dict_simple = {
"div": {
"text": {
"content": "This is a paragraph",
"path_to_fork": "div>div",
},
"ul": {
"path_to_fork": "div>ul",
"texts": [
{
"content": "This is a list item 1",
"path_to_fork": "ul>li>a>span",
},
{
"content": "This is a list item 2",
"path_to_fork": "ul>li>a>span",
},
{
"content": "This is a list item 3",
"path_to_fork": "ul>li>a>span",
}
]
}
}
}
subtree_dict_complex = {
"div": {
"text": {
"content": "This is a paragraph",
"path_to_fork": "div>div",
"attributes": {
"classes": ["paragraph"],
"ids": ["paragraph"],
"hrefs": ["https://www.example.com"]
}
},
"ul": {
"text1":{
"content": "This is a list item 1",
"path_to_fork": "ul>li>a>span",
"attributes": {
"classes": ["list-item", "item-1"],
"ids": ["item-1"],
"hrefs": ["https://www.example.com"]
}
},
"text2":{
"content": "This is a list item 2",
"path_to_fork": "ul>li>a>span",
"attributes": {
"classes": ["list-item", "item-2"],
"ids": ["item-2"],
"hrefs": ["https://www.example.com"]
}
}
}
}
}
result, execution_info = graph.execute({
"user_prompt": "How many list items are there in the document?",
"local_dir": str(subtree_dict_simple)
})
# get the answer from the result
result = result.get("answer", "No answer found.")
print(result)

View File

@ -46,21 +46,34 @@ def print_matches_side_by_side(matches):
# Usage example:
# *********************************************************************************************************************
loader = AsyncHtmlLoader('https://www.wired.com/category/science/')
loader = AsyncHtmlLoader('https://perinim.github.io/projects/')
document = loader.load()
html_content = document[0].page_content
curr_time = time.time()
# Instantiate a DOMTree with HTML content
dom_tree = DOMTree(html_content)
nodes, metadatas = dom_tree.collect_text_nodes() # Collect text nodes for analysis
for node, metadata in zip(nodes, metadatas):
print("Text:", node)
print("Metadata:", metadata)
# nodes, metadatas = dom_tree.collect_text_nodes() # Collect text nodes for analysis
# for node, metadata in zip(nodes, metadatas):
# print("Text:", node)
# print("Metadata:", metadata)
# sub_list = dom_tree.generate_subtree_dicts() # Generate subtree dictionaries for analysis
# print(sub_list)
# graph = dom_tree.visualize(exclude_tags=['script', 'style', 'meta', 'link'])
# subtrees = dom_tree.get_subtrees() # Retrieve subtrees rooted at fork nodes
subtrees = dom_tree.get_subtrees() # Retrieve subtrees rooted at fork nodes
print("Number of subtrees found:", len(subtrees))
# remove trees whos root node does not lead to any text
text_subtrees = [subtree for subtree in subtrees if subtree.root.leads_to_text]
print("Number of subtrees that lead to text:", len(text_subtrees))
direct_leaf_subtrees = [subtree for subtree in text_subtrees if subtree.root.has_direct_leaves]
print("Number of subtrees with direct leaves beneath fork nodes:", len(direct_leaf_subtrees))
for subtree in direct_leaf_subtrees:
print("Subtree rooted at:", subtree.root.value)
subtree.traverse(lambda node: print(node))
# Index subtrees by structure and content
# structure_index, content_index = index_subtrees(subtrees)
@ -83,4 +96,4 @@ print(f"Time taken to build DOM tree: {time.time() - curr_time:.2f} seconds")
# print("Subtree rooted at:", subtree.root.value)
# subtree.traverse(lambda node: print(node))
# Traverse the DOMTree and print each node
# dom_tree.traverse(lambda node: print(node))
# dom_tree.traverse(lambda node: print(node))

View File

@ -15,7 +15,9 @@ class DOMTree(Tree):
elif isinstance(child, NavigableString):
text = child.strip()
if text:
tree_node.add_child(TreeNode(value='text', attributes={'content': text}))
new_node = TreeNode(value='text', attributes={'content': text})
tree_node.add_child(new_node)
new_node.finalize_node()
elif isinstance(child, Tag):
new_node = TreeNode(value=child.name, attributes=child.attrs)
tree_node.add_child(new_node)

View File

@ -16,6 +16,42 @@ class Tree:
# Retrieves all subtrees rooted at fork nodes
return self.root.get_subtrees() if self.root else []
def generate_subtree_dicts(self):
subtree_dicts = []
def aggregate_text_under_fork(fork_node):
text_aggregate = {
"content": [],
"path_to_fork": ""
}
for child in fork_node.children:
if child.value == 'text':
text_aggregate["content"].append(child.attributes['content'])
elif child.is_fork:
continue
else:
for sub_child in child.children:
text_aggregate["content"].append(sub_child.attributes)
text_aggregate["path_to_fork"] = fork_node.closest_fork_path
return text_aggregate
def process_node(node):
if node.is_fork:
texts = aggregate_text_under_fork(node)
if texts["content"]: # Only add if there's text content
subtree_dicts.append({
node.value: {
"text": texts,
"path_to_fork": texts["path_to_fork"],
}
})
for child in node.children:
process_node(child)
process_node(self.root)
return subtree_dicts
def visualize(self, exclude_tags = ['script']):
def add_nodes_edges(tree_node, graph):
if tree_node:
@ -49,7 +85,7 @@ class Tree:
# Initialize Digraph, set graph and node attributes
graph = Digraph()
graph.attr(size='10,10', dpi='300') # Set higher DPI for better image resolution
# graph.attr(size='10,10', dpi='300') # Set higher DPI for better image resolution
graph.attr('node', style='filled', fontname='Helvetica')
graph.attr('edge', fontname='Helvetica')

View File

@ -7,7 +7,10 @@ class TreeNode:
self.children = children if children is not None else []
self.parent = parent
self.depth = depth
# Flag to track if the subtree leads to text
self.leads_to_text = False
# Flags to track if the subtree has a direct leaf node
self.has_direct_leaves = False
self.root_path = self._compute_root_path()
self.closest_fork_path = self._compute_fork_path()
self.structure_hash = None
@ -54,14 +57,26 @@ class TreeNode:
current = current.parent
path.append(current.value) # Add the fork or root node
return '>'.join(reversed(path))
def get_subtrees(self):
def finalize_node(self):
if self.is_text and self.is_leaf:
self.update_direct_leaves_flag()
def update_direct_leaves_flag(self):
ancestor = self.parent
while ancestor and len(ancestor.children) == 1:
ancestor = ancestor.parent
if ancestor and ancestor.is_fork:
ancestor.has_direct_leaves = True
def get_subtrees(self, direct_leaves=False):
# This method finds and returns subtrees rooted at this node and all descendant forks
# Optionally filters to include only those with direct leaves beneath fork nodes
subtrees = []
if self.is_fork:
if self.is_fork and (not direct_leaves or self.has_direct_leaves):
subtrees.append(Tree(root=self))
for child in self.children:
subtrees.extend(child.get_subtrees())
subtrees.extend(child.get_subtrees(direct_leaves=direct_leaves))
return subtrees
def hash_subtree_structure(self, node):
@ -84,7 +99,7 @@ class TreeNode:
return text
def __repr__(self):
return f"TreeNode(value={self.value}, leads_to_text={self.leads_to_text}, depth={self.depth}, root_path={self.root_path}, closest_fork_path={self.closest_fork_path})"
return f"TreeNode(value={self.value}, leads_to_text={self.leads_to_text}, is_fork={self.is_fork})"
@property
def is_fork(self):
@ -92,4 +107,8 @@ class TreeNode:
@property
def is_leaf(self):
return len(self.children) == 0
return len(self.children) == 0
@property
def is_text(self):
return self.value == 'text'

View File

@ -72,7 +72,7 @@ class FetchNode(BaseNode):
# if it is a local directory
if not source.startswith("http"):
compressedDocument = [Document(page_content=remover(source), metadata={
compressedDocument = [Document(page_content=source, metadata={
"source": "local_dir"
})]

File diff suppressed because it is too large Load Diff

Binary file not shown.

Before

Width:  |  Height:  |  Size: 79 KiB

File diff suppressed because it is too large Load Diff

Before

Width:  |  Height:  |  Size: 1.3 MiB