mirror of
https://github.com/VinciGit00/Scrapegraph-ai.git
synced 2026-06-23 21:00:30 +08:00
fix(fetch_node): bug in handling local files
This commit is contained in:
parent
fcb3abb01d
commit
a6e1813ddd
@ -1,113 +0,0 @@
|
||||
"""
Example of custom graph using existing nodes
"""

import os
from dotenv import load_dotenv

from langchain_openai import OpenAIEmbeddings
from scrapegraphai.models import OpenAI, OpenAIImageToText
from scrapegraphai.graphs import BaseGraph
from scrapegraphai.nodes import FetchNode, ParseNode, ImageToTextNode, RAGNode, GenerateAnswerOmniNode

# Load variables from a .env file (if any) before the API key is read below.
load_dotenv()

# ************************************************
# Define the configuration for the graph
# ************************************************

llm_settings = {
    "api_key": os.getenv("OPENAI_APIKEY"),
    "model": "gpt-4o",
    "temperature": 0,
    "streaming": False,
}
graph_config = {"llm": llm_settings}

# ************************************************
# Define the graph nodes
# ************************************************

# Both model wrappers are built from the very same "llm" settings dict.
llm_model = OpenAI(graph_config["llm"])
iit_model = OpenAIImageToText(graph_config["llm"])
# The embedder reuses the API key exposed by the LLM wrapper.
embedder = OpenAIEmbeddings(api_key=llm_model.openai_api_key)

# Each node declares the state keys it reads (`input` expression) and the
# state keys it writes (`output` list).
fetch_node = FetchNode(
    input="url | local_dir",
    output=["doc", "link_urls", "img_urls"],
    node_config={
        "verbose": True,
        "headless": True,
    },
)

parse_node = ParseNode(
    input="doc",
    output=["parsed_doc"],
    node_config={
        "chunk_size": 4096,
        "verbose": True,
    },
)

image_to_text_node = ImageToTextNode(
    input="img_urls",
    output=["img_desc"],
    node_config={
        "llm_model": iit_model,
        "max_images": 4,
    },
)

rag_node = RAGNode(
    input="user_prompt & (parsed_doc | doc)",
    output=["relevant_chunks"],
    node_config={
        "llm_model": llm_model,
        "embedder_model": embedder,
        "verbose": True,
    },
)

generate_answer_omni_node = GenerateAnswerOmniNode(
    input="user_prompt & (relevant_chunks | parsed_doc | doc) & img_desc",
    output=["answer"],
    node_config={
        "llm_model": llm_model,
        "verbose": True,
    },
)

# ************************************************
# Create the graph by defining the connections
# ************************************************

# The graph is a straight chain, so every edge simply links a node to the
# one that follows it in this ordered list.
pipeline = [
    fetch_node,
    parse_node,
    image_to_text_node,
    rag_node,
    generate_answer_omni_node,
]

graph = BaseGraph(
    nodes=pipeline,
    edges=list(zip(pipeline, pipeline[1:])),
    entry_point=pipeline[0],
)

# ************************************************
# Execute the graph
# ************************************************

initial_state = {
    "user_prompt": "List me all the projects with their titles and image links and descriptions.",
    "url": "https://perinim.github.io/projects/",
}
final_state, execution_info = graph.execute(initial_state)

# Print the generated answer, or a fallback message when the final state
# carries no "answer" key.
print(final_state.get("answer", "No answer found."))
|
||||
@ -19,7 +19,7 @@ openai_key = os.getenv("OPENAI_APIKEY")
|
||||
graph_config = {
|
||||
"llm": {
|
||||
"api_key": openai_key,
|
||||
"model": "gpt-4o",
|
||||
"model": "gpt-4-turbo",
|
||||
},
|
||||
"verbose": True,
|
||||
"headless": True,
|
||||
|
||||
@ -30,8 +30,8 @@ class CSVScraperGraph(AbstractGraph):
|
||||
Creates the graph of nodes representing the workflow for web scraping.
|
||||
"""
|
||||
fetch_node = FetchNode(
|
||||
input="csv",
|
||||
output=["doc"],
|
||||
input="csv | csv_dir",
|
||||
output=["doc", "link_urls", "img_urls"],
|
||||
)
|
||||
parse_node = ParseNode(
|
||||
input="doc",
|
||||
|
||||
@ -61,7 +61,7 @@ class DeepScraperGraph(AbstractGraph):
|
||||
"""
|
||||
fetch_node = FetchNode(
|
||||
input="url | local_dir",
|
||||
output=["doc"]
|
||||
output=["doc", "link_urls", "img_urls"]
|
||||
)
|
||||
parse_node = ParseNode(
|
||||
input="doc",
|
||||
|
||||
@ -54,8 +54,8 @@ class JSONScraperGraph(AbstractGraph):
|
||||
"""
|
||||
|
||||
fetch_node = FetchNode(
|
||||
input="json",
|
||||
output=["doc"],
|
||||
input="json | json_dir",
|
||||
output=["doc", "link_urls", "img_urls"],
|
||||
)
|
||||
parse_node = ParseNode(
|
||||
input="doc",
|
||||
|
||||
@ -56,8 +56,8 @@ class PDFScraperGraph(AbstractGraph):
|
||||
"""
|
||||
|
||||
fetch_node = FetchNode(
|
||||
input='pdf',
|
||||
output=["doc"],
|
||||
input='pdf | pdf_dir',
|
||||
output=["doc", "link_urls", "img_urls"],
|
||||
)
|
||||
parse_node = ParseNode(
|
||||
input="doc",
|
||||
|
||||
@ -59,7 +59,7 @@ class ScriptCreatorGraph(AbstractGraph):
|
||||
|
||||
fetch_node = FetchNode(
|
||||
input="url | local_dir",
|
||||
output=["doc"],
|
||||
output=["doc", "link_urls", "img_urls"],
|
||||
)
|
||||
parse_node = ParseNode(
|
||||
input="doc",
|
||||
|
||||
@ -57,7 +57,7 @@ class SmartScraperGraph(AbstractGraph):
|
||||
"""
|
||||
fetch_node = FetchNode(
|
||||
input="url | local_dir",
|
||||
output=["doc"],
|
||||
output=["doc", "link_urls", "img_urls"],
|
||||
node_config={
|
||||
"loader_kwargs": self.config.get("loader_kwargs", {}),
|
||||
}
|
||||
|
||||
@ -56,7 +56,7 @@ class SpeechGraph(AbstractGraph):
|
||||
|
||||
fetch_node = FetchNode(
|
||||
input="url | local_dir",
|
||||
output=["doc"]
|
||||
output=["doc", "link_urls", "img_urls"]
|
||||
)
|
||||
parse_node = ParseNode(
|
||||
input="doc",
|
||||
|
||||
@ -56,8 +56,8 @@ class XMLScraperGraph(AbstractGraph):
|
||||
"""
|
||||
|
||||
fetch_node = FetchNode(
|
||||
input="xml",
|
||||
output=["doc"]
|
||||
input="xml | xml_dir",
|
||||
output=["doc", "link_urls", "img_urls"]
|
||||
)
|
||||
parse_node = ParseNode(
|
||||
input="doc",
|
||||
|
||||
@ -83,37 +83,49 @@ class FetchNode(BaseNode):
|
||||
|
||||
source = input_data[0]
|
||||
if (
|
||||
self.input == "json_dir"
|
||||
or self.input == "xml_dir"
|
||||
or self.input == "csv_dir"
|
||||
input_keys[0] == "json_dir"
|
||||
or input_keys[0] == "xml_dir"
|
||||
or input_keys[0] == "csv_dir"
|
||||
):
|
||||
compressed_document = [
|
||||
Document(page_content=source, metadata={"source": "local_dir"})
|
||||
]
|
||||
# if it is a local directory
|
||||
|
||||
state.update({self.output[0]: compressed_document})
|
||||
return state
|
||||
|
||||
# handling for pdf
|
||||
elif self.input == "pdf":
|
||||
elif input_keys[0] == "pdf":
|
||||
loader = PyPDFLoader(source)
|
||||
compressed_document = loader.load()
|
||||
state.update({self.output[0]: compressed_document})
|
||||
return state
|
||||
|
||||
elif self.input == "csv":
|
||||
elif input_keys[0] == "csv":
|
||||
compressed_document = [
|
||||
Document(
|
||||
page_content=str(pd.read_csv(source)), metadata={"source": "csv"}
|
||||
)
|
||||
]
|
||||
elif self.input == "json":
|
||||
state.update({self.output[0]: compressed_document})
|
||||
return state
|
||||
|
||||
elif input_keys[0] == "json":
|
||||
f = open(source)
|
||||
compressed_document = [
|
||||
Document(page_content=str(json.load(f)), metadata={"source": "json"})
|
||||
]
|
||||
elif self.input == "xml":
|
||||
state.update({self.output[0]: compressed_document})
|
||||
return state
|
||||
|
||||
elif input_keys[0] == "xml":
|
||||
with open(source, "r", encoding="utf-8") as f:
|
||||
data = f.read()
|
||||
compressed_document = [
|
||||
Document(page_content=data, metadata={"source": "xml"})
|
||||
]
|
||||
state.update({self.output[0]: compressed_document})
|
||||
return state
|
||||
|
||||
elif self.input == "pdf_dir":
|
||||
pass
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user