Merge pull request #39 from VinciGit00/fix-bug-merge

Fix bug merge
This commit is contained in:
Marco Vinciguerra 2024-03-12 11:57:47 +01:00 committed by GitHub
commit 55702b28a0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 95 additions and 156 deletions

View File

@ -6,7 +6,7 @@ import os
from dotenv import load_dotenv
from scrapegraphai.models import OpenAI
from scrapegraphai.graphs import BaseGraph
from scrapegraphai.nodes import FetchHTMLNode, ParseHTMLNode, GenerateAnswerNode
from scrapegraphai.nodes import FetchHTMLNode, ParseNode, RAGNode, GenerateAnswerNode
load_dotenv()
@ -22,7 +22,8 @@ model = OpenAI(llm_config)
# define the nodes for the graph
fetch_html_node = FetchHTMLNode("fetch_html")
parse_document_node = ParseHTMLNode("parse_document")
parse_document_node = ParseNode(doc_type="html", chunks_size=4000, node_name="parse_document")
rag_node = RAGNode(model, "rag")
generate_answer_node = GenerateAnswerNode(model, "generate_answer")
# create the graph
@ -30,18 +31,20 @@ graph = BaseGraph(
nodes={
fetch_html_node,
parse_document_node,
rag_node,
generate_answer_node
},
edges={
(fetch_html_node, parse_document_node),
(parse_document_node, generate_answer_node)
(parse_document_node, rag_node),
(rag_node, generate_answer_node)
},
entry_point=fetch_html_node
)
# execute the graph
inputs = {"user_input": "Give me the news",
"url": "https://www.ansa.it/sito/notizie/topnews/index.shtml"}
inputs = {"user_input": "List me the projects with their description",
"url": "https://perinim.github.io/projects/"}
result = graph.execute(inputs)
# get the answer from the result

View File

@ -16,8 +16,8 @@ llm_config = {
}
# Define URL and PROMPT
URL = "https://www.google.com/search?client=safari&rls=en&q=ristoranti+trento&ie=UTF-8&oe=UTF-8"
PROMPT = "List me all the https inside the page"
URL = "https://www.ansa.it/veneto/"
PROMPT = "List me all the news with their description."
# Create the SmartScraperGraph instance
smart_scraper_graph = SmartScraperGraph(PROMPT, URL, llm_config)

View File

@ -5,6 +5,7 @@ from ..models import OpenAI
from .base_graph import BaseGraph
from ..nodes import (
FetchHTMLNode,
ParseNode,
RAGNode,
GenerateAnswerNode
)
@ -73,18 +74,22 @@ class SmartScraperGraph:
Returns:
BaseGraph: An instance of the BaseGraph class.
"""
# define the nodes for the graph
fetch_html_node = FetchHTMLNode("fetch_html")
parse_document_node = ParseNode(doc_type="html", chunks_size=4000, node_name="parse_document")
rag_node = RAGNode(self.llm, "rag")
generate_answer_node = GenerateAnswerNode(self.llm, "generate_answer")
return BaseGraph(
nodes={
fetch_html_node,
parse_document_node,
rag_node,
generate_answer_node,
},
edges={
(fetch_html_node, rag_node),
(fetch_html_node, parse_document_node),
(parse_document_node, rag_node),
(rag_node, generate_answer_node)
},
entry_point=fetch_html_node

View File

@ -6,6 +6,7 @@ from ..models import OpenAI, OpenAITextToSpeech
from .base_graph import BaseGraph
from ..nodes import (
FetchHTMLNode,
ParseNode,
RAGNode,
GenerateAnswerNode,
TextToSpeechNode,
@ -79,6 +80,7 @@ class SpeechSummaryGraph:
BaseGraph: An instance of the BaseGraph class.
"""
fetch_html_node = FetchHTMLNode("fetch_html")
parse_document_node = ParseNode(doc_type="html", chunks_size=4000, node_name="parse_document")
rag_node = RAGNode(self.llm, "rag")
generate_answer_node = GenerateAnswerNode(self.llm, "generate_answer")
text_to_speech_node = TextToSpeechNode(
@ -87,12 +89,14 @@ class SpeechSummaryGraph:
return BaseGraph(
nodes={
fetch_html_node,
parse_document_node,
rag_node,
generate_answer_node,
text_to_speech_node
},
edges={
(fetch_html_node, rag_node),
(fetch_html_node, parse_document_node),
(parse_document_node, rag_node),
(rag_node, generate_answer_node),
(generate_answer_node, text_to_speech_node)
},

View File

@ -20,12 +20,12 @@ nodes_metadata = {
},
"returns": "Updated state with probable HTML tags under 'tags' key."
},
"ParseHTMLNode": {
"description": "Parses HTML content to extract specific data.",
"ParseNode": {
"description": "Parses document content to extract specific data.",
"type": "node",
"args": {
"document": "HTML content as a string.",
"tags": "List of HTML tags to focus on during parsing."
"doc_type": "Type of the input document. Default is 'html'.",
"document": "The document content to be parsed.",
},
"returns": "Updated state with extracted data under 'parsed_document' key."
},
@ -38,7 +38,7 @@ nodes_metadata = {
"type": "node",
"args": {
"user_input": "The user's query or question guiding the retrieval.",
"document": "The HTML content to be processed and compressed."
"document": "The document content to be processed and compressed."
},
"returns": """Updated state with 'relevant_chunks' key containing
the most relevant text chunks."""
@ -48,7 +48,7 @@ nodes_metadata = {
"type": "node",
"args": {
"user_input": "User's query or question.",
"parsed_document": "Data extracted from the HTML document."
"parsed_document": "Data extracted from the input document."
},
"returns": "Updated state with the answer under 'answer' key."
},

View File

@ -5,9 +5,8 @@ from .fetch_html_node import FetchHTMLNode
from .conditional_node import ConditionalNode
from .get_probable_tags_node import GetProbableTagsNode
from .generate_answer_node import GenerateAnswerNode
from .parse_html_node import ParseHTMLNode
from .parse_node import ParseNode
from .rag_node import RAGNode
from .text_to_speech_node import TextToSpeechNode
from .image_to_text_node import ImageToTextNode
from .fetch_text_node import FetchTextNode
from .parse_text_node import ParseTextNode
from .fetch_text_node import FetchTextNode

View File

@ -81,10 +81,11 @@ class FetchHTMLNode(BaseNode):
loader = AsyncHtmlLoader(url)
document = loader.load()
metadata = document[0].metadata
document = remover(str(document[0]))
# metadata = document[0].metadata
# document = remover(str(document[0]))
state["document"] = [
Document(page_content=document, metadata=metadata)]
# state["document"] = [
# Document(page_content=document, metadata=metadata)]
state["document"] = document
return state

View File

@ -11,7 +11,6 @@ from langchain_core.runnables import RunnableParallel
# Imports from the library
from .base_node import BaseNode
from langchain.text_splitter import RecursiveCharacterTextSplitter
class GenerateAnswerNode(BaseNode):
@ -71,7 +70,7 @@ class GenerateAnswerNode(BaseNode):
print("---GENERATING ANSWER---")
try:
user_input = state["user_input"]
document = state["document_chunks"]
document = state["document"]
except KeyError as e:
print(f"Error: {e} not found in state.")
raise
@ -111,34 +110,28 @@ class GenerateAnswerNode(BaseNode):
prompt = PromptTemplate(
template=template_chunks,
input_variables=["question"],
partial_variables={"context": chunk,
partial_variables={"context": chunk.page_content,
"chunk_id": i + 1, "format_instructions": format_instructions},
)
# Dynamically name the chains based on their index
chains_dict[f"chunk{i+1}"] = prompt | self.llm | output_parser
chain_name = f"chunk{i+1}"
chains_dict[chain_name] = prompt | self.llm | output_parser
text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
chunk_size=4000,
chunk_overlap=0,
)
chunks = text_splitter.split_text(str(chains_dict))
# Use dictionary unpacking to pass the dynamically named chains to RunnableParallel
map_chain = RunnableParallel(**chains_dict)
# Chain
answer_map = map_chain.invoke({"question": user_input})
# Merge the answers from the chunks
merge_prompt = PromptTemplate(
template=template_merge,
input_variables=["context", "question"],
partial_variables={"format_instructions": format_instructions},
)
merge_chain = merge_prompt | self.llm | output_parser
answer = merge_chain.invoke(
{"context": answer_map, "question": user_input})
answer_lines = []
for chunk in chunks:
answer_temp = merge_chain.invoke(
{"context": chunk, "question": user_input})
answer_lines.append(answer_temp)
unique_answer_lines = list(set(answer_lines))
answer = '\n'.join(unique_answer_lines)
# Update the state with the generated answer
state.update({"answer": answer})
return state
return state

View File

@ -6,11 +6,11 @@ from langchain_community.document_transformers import Html2TextTransformer
from .base_node import BaseNode
class ParseHTMLNode(BaseNode):
class ParseNode(BaseNode):
"""
A node responsible for parsing HTML content from a document using specified tags.
A node responsible for parsing HTML content from a document.
It uses BeautifulSoupTransformer for parsing, providing flexibility in extracting
specific parts of an HTML document based on the tags provided in the state.
specific parts of an HTML document.
This node enhances the scraping workflow by allowing for targeted extraction of
content, thereby optimizing the processing of large HTML documents.
@ -28,14 +28,18 @@ class ParseHTMLNode(BaseNode):
the specified tags, if provided, and updates the state with the parsed content.
"""
def __init__(self, node_name: str):
def __init__(self, doc_type: str = "html", chunks_size: int = 4000, node_name: str = "ParseHTMLNode"):
"""
Initializes the ParseNode with a document type, a chunk size and a node name.
Args:
doc_type (str): type of the input document
chunks_size (int): size of the chunks to split the document
node_name (str): name of the node
node_type (str, optional): type of the node
"""
super().__init__(node_name, "node")
self.doc_type = doc_type
self.chunks_size = chunks_size
def execute(self, state):
"""
@ -57,23 +61,27 @@ class ParseHTMLNode(BaseNode):
information for parsing is missing.
"""
print("---PARSING HTML DOCUMENT---")
print("---PARSING DOCUMENT---")
try:
document = state["document"]
except KeyError as e:
print(f"Error: {e} not found in state.")
raise
text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
chunk_size=4000,
chunk_size=self.chunks_size,
chunk_overlap=0,
)
docs_transformed = Html2TextTransformer(
).transform_documents(document)[0]
# Parse the document based on the specified doc_type
if self.doc_type == "html":
docs_transformed = Html2TextTransformer(
).transform_documents(document)[0]
elif self.doc_type == "text":
docs_transformed = document
chunks = text_splitter.split_text(docs_transformed.page_content)
state.update({"document_chunks": chunks})
state.update({"parsed_document": chunks})
return state

View File

@ -1,76 +0,0 @@
"""
Module for the node that splits a text document into chunks
"""
from langchain.text_splitter import RecursiveCharacterTextSplitter
from .base_node import BaseNode
class ParseTextNode(BaseNode):
    """
    A node that splits the text held in the graph state into chunks.

    The node reads the value stored under the 'document' key of the state,
    splits it with a tiktoken-based RecursiveCharacterTextSplitter
    (chunk size 4000, no overlap) and stores the resulting chunks in the
    state under the 'document_chunks' key. No tag-based filtering is done.

    Args:
        node_name (str, optional): Name forwarded to BaseNode together with
            the node type "node". Defaults to "ParseHTMLNode".

    Methods:
        execute(state): Split state['document'] and store the chunks in
            state['document_chunks'].
    """

    def __init__(self, node_name: str = "ParseHTMLNode"):
        """
        Initializes the ParseTextNode.

        Args:
            node_name (str, optional): Name of the node
                (defaults to "ParseHTMLNode").
        """
        super().__init__(node_name, "node")

    def execute(self, state):
        """
        Splits the document found in the state into chunks.

        Args:
            state (dict): Must contain the key 'document' holding the text
                to split.

        Returns:
            dict: The same state object, with the list of chunks added under
                the 'document_chunks' key.

        Raises:
            KeyError: If the 'document' key is missing from the state. The
                missing key is printed before the error is re-raised.
        """
        print("---PARSING TEXT DOCUMENT---")

        try:
            raw_text = state["document"]
        except KeyError as missing_key:
            print(f"Error: {missing_key} not found in state.")
            raise

        # Chunk boundaries are measured in tiktoken tokens, not characters.
        splitter_options = {"chunk_size": 4000, "chunk_overlap": 0}
        splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
            **splitter_options
        )

        pieces = splitter.split_text(raw_text)
        state["document_chunks"] = pieces
        return state

View File

@ -2,11 +2,10 @@
Module for the RAG (Retrieval-Augmented Generation) node
"""
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.docstore.document import Document
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import EmbeddingsFilter, DocumentCompressorPipeline
from langchain_community.document_transformers import Html2TextTransformer, EmbeddingsRedundantFilter
from langchain_community.document_transformers import EmbeddingsRedundantFilter
from langchain_community.vectorstores import FAISS
from langchain_openai import OpenAIEmbeddings
@ -16,12 +15,10 @@ from .base_node import BaseNode
class RAGNode(BaseNode):
"""
A node responsible for parsing HTML content from a document using specified tags.
It uses BeautifulSoupTransformer for parsing, providing flexibility in extracting
specific parts of an HTML document based on the tags provided in the state.
A node responsible for compressing the input tokens and storing the document
in a vector database for retrieval.
This node enhances the scraping workflow by allowing for targeted extraction of
content, thereby optimizing the processing of large HTML documents.
It allows scraping of big documents without exceeding the token limit of the language model.
Attributes:
node_name (str): The unique identifier name for the node, defaulting to "ParseHTMLNode".
@ -36,7 +33,7 @@ class RAGNode(BaseNode):
the specified tags, if provided, and updates the state with the parsed content.
"""
def __init__(self, llm, node_name="TestRagNode"):
def __init__(self, llm, node_name="RagNode"):
"""
Initializes the RAGNode with a language model and a node name.
"""
@ -45,25 +42,21 @@ class RAGNode(BaseNode):
def execute(self, state):
"""
Executes the node's logic to parse the HTML document based on specified tags.
If tags are provided in the state, the document is parsed accordingly; otherwise,
the document remains unchanged. The method updates the state with either the original
or parsed document under the 'parsed_document' key.
Executes the node's logic to implement RAG (Retrieval-Augmented Generation)
The method updates the state with relevant chunks of the document.
Args:
state (dict): The current state of the graph, expected to contain
'document' within 'keys', and optionally 'tags' for targeted parsing.
state (dict): The state containing the 'document' key with the HTML content
Returns:
dict: The updated state with the 'parsed_document' key containing the parsed content,
if tags were provided, or the original document otherwise.
dict: The updated state containing the 'relevant_chunks' key with the relevant chunks.
Raises:
KeyError: If 'document' is not found in the state, indicating that the necessary
information for parsing is missing.
"""
print("---PARSING HTML DOCUMENT---")
print("---RAG STARTED---")
try:
user_input = state["user_input"]
document = state["document"]
@ -71,15 +64,14 @@ class RAGNode(BaseNode):
print(f"Error: {e} not found in state.")
raise
text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
chunk_size=4000,
chunk_overlap=0,
)
parsed_document = state.get("parsed_document", None)
docs_transformed = Html2TextTransformer(
).transform_documents(document)[0]
if parsed_document:
chunks = parsed_document
else:
print("Parsed document not found. Using original document.")
chunks = document
chunks = text_splitter.split_text(docs_transformed.page_content)
chunked_docs = []
for i, chunk in enumerate(chunks):
@ -91,6 +83,8 @@ class RAGNode(BaseNode):
)
chunked_docs.append(doc)
print("---UPDATED CHUNKS METADATA---")
openai_key = self.llm.openai_api_key
retriever = FAISS.from_documents(chunked_docs,
OpenAIEmbeddings(api_key=openai_key)).as_retriever()
@ -102,13 +96,21 @@ class RAGNode(BaseNode):
pipeline_compressor = DocumentCompressorPipeline(
transformers=[redundant_filter, relevant_filter]
)
# redundant + relevant filter compressor
compression_retriever = ContextualCompressionRetriever(
base_compressor=pipeline_compressor, base_retriever=retriever
)
# relevant filter compressor only
# compression_retriever = ContextualCompressionRetriever(
# base_compressor=relevant_filter, base_retriever=retriever
# )
compressed_docs = compression_retriever.get_relevant_documents(
user_input)
print("Documents compressed and stored in a vector database.")
state.update({"document_chunks": compressed_docs})
print("---TOKENS COMPRESSED AND VECTOR STORED---")
state.update({"relevant_chunks": compressed_docs})
return state

View File

@ -29,4 +29,4 @@ def remover(html_content: str) -> str:
body_content = soup.find('body')
body = str(body_content) if body_content else ""
return title + body
return "Title: " + title + ", Body: " + body