mirror of
https://github.com/VinciGit00/Scrapegraph-ai.git
synced 2026-06-23 21:00:30 +08:00
implemented graph_config, fixed smart_scraper and speech graph
This commit is contained in:
parent
f27e0b4c5f
commit
52934bf007
@ -7,20 +7,26 @@ from dotenv import load_dotenv
|
||||
from scrapegraphai.graphs import SmartScraperGraph
|
||||
|
||||
load_dotenv()
|
||||
|
||||
# Define the configuration for the language model
|
||||
openai_key = os.getenv("OPENAI_APIKEY")
|
||||
llm_config = {
|
||||
"api_key": openai_key,
|
||||
"model_name": "gpt-3.5-turbo",
|
||||
|
||||
# Define the configuration for the graph
|
||||
graph_config = {
|
||||
"llm": {
|
||||
"api_key": openai_key,
|
||||
"model": "gpt-3.5-turbo",
|
||||
},
|
||||
# "embedding_model": {
|
||||
# "api_key": openai_key,
|
||||
# "model": "gpt-3.5-turbo",
|
||||
# },
|
||||
}
|
||||
|
||||
# Define URL and PROMPT
|
||||
URL = "https://www.ansa.it/veneto/"
|
||||
PROMPT = "List me all the news with their description."
|
||||
|
||||
# Create the SmartScraperGraph instance
|
||||
smart_scraper_graph = SmartScraperGraph(PROMPT, URL, llm_config)
|
||||
smart_scraper_graph = SmartScraperGraph(
|
||||
prompt = "List me all the news with their description.",
|
||||
url = "https://www.ansa.it/veneto/",
|
||||
config = graph_config
|
||||
)
|
||||
|
||||
answer = smart_scraper_graph.run()
|
||||
print(answer)
|
||||
|
||||
38
examples/graph_examples/speech_graph_example.py
Normal file
38
examples/graph_examples/speech_graph_example.py
Normal file
@ -0,0 +1,38 @@
|
||||
"""
|
||||
Basic example of scraping pipeline using SpeechSummaryGraph
|
||||
"""
|
||||
|
||||
import os
|
||||
from dotenv import load_dotenv
|
||||
from scrapegraphai.graphs import SpeechGraph
|
||||
|
||||
load_dotenv()
|
||||
openai_key = os.getenv("OPENAI_APIKEY")
|
||||
|
||||
# Save the audio to a file
|
||||
file_name = "website_summary.mp3"
|
||||
curr_dir = os.path.dirname(os.path.realpath(__file__))
|
||||
output_path = os.path.join(curr_dir, file_name)
|
||||
|
||||
# Define the configuration for the graph
|
||||
graph_config = {
|
||||
"llm": {
|
||||
"api_key": openai_key,
|
||||
"model": "gpt-3.5-turbo",
|
||||
},
|
||||
"tts_model": {
|
||||
"api_key": openai_key,
|
||||
"model": "tts-1",
|
||||
"voice": "alloy"
|
||||
},
|
||||
"output_path": output_path,
|
||||
}
|
||||
|
||||
speech_graph = SpeechGraph(
|
||||
prompt = "List me all the projects and generate and audio for me to listen to.",
|
||||
url = "https://perinim.github.io/projects/",
|
||||
config = graph_config,
|
||||
)
|
||||
|
||||
final_state = speech_graph.run()
|
||||
print(final_state.get("answer", "No answer found."))
|
||||
@ -1,27 +0,0 @@
|
||||
"""
|
||||
Basic example of scraping pipeline using SpeechSummaryGraph
|
||||
"""
|
||||
|
||||
import os
|
||||
from dotenv import load_dotenv
|
||||
from scrapegraphai.graphs import SpeechSummaryGraph
|
||||
|
||||
load_dotenv()
|
||||
|
||||
# Define the configuration for the language model
|
||||
openai_key = os.getenv("OPENAI_APIKEY")
|
||||
llm_config = {
|
||||
"api_key": openai_key,
|
||||
}
|
||||
|
||||
# Save the audio to a file
|
||||
curr_dir = os.path.dirname(os.path.realpath(__file__))
|
||||
output_file_path = os.path.join(curr_dir, "website_summary.mp3")
|
||||
|
||||
speech_summary_graph = SpeechSummaryGraph("""Make a summary of the news to be
|
||||
converted to audio for blind people.""",
|
||||
"https://www.wired.com/category/science/", llm_config,
|
||||
output_file_path)
|
||||
|
||||
final_state = speech_summary_graph.run()
|
||||
print(final_state.get("answer", "No answer found."))
|
||||
@ -3,4 +3,4 @@ __init__.py file for graphs folder
|
||||
"""
|
||||
from .base_graph import BaseGraph
|
||||
from .smart_scraper_graph import SmartScraperGraph
|
||||
from .speech_summary_graph import SpeechSummaryGraph
|
||||
from .speech_graph import SpeechGraph
|
||||
|
||||
@ -4,7 +4,7 @@ Module for creating the smart scraper
|
||||
from ..models import OpenAI
|
||||
from .base_graph import BaseGraph
|
||||
from ..nodes import (
|
||||
FetchHTMLNode,
|
||||
FetchNode,
|
||||
ParseNode,
|
||||
RAGNode,
|
||||
GenerateAnswerNode
|
||||
@ -34,17 +34,17 @@ class SmartScraperGraph:
|
||||
'temperature', and 'streaming'.
|
||||
"""
|
||||
|
||||
def __init__(self, prompt: str, url: str, llm_config: dict):
|
||||
def __init__(self, prompt: str, url: str, config: dict):
|
||||
"""
|
||||
Initializes the SmartScraper with a prompt, URL, and language model configuration.
|
||||
"""
|
||||
self.prompt = prompt
|
||||
self.url = url
|
||||
self.llm_config = llm_config
|
||||
self.llm = self._create_llm()
|
||||
self.config = config
|
||||
self.llm_model = self._create_llm(config["llm"])
|
||||
self.graph = self._create_graph()
|
||||
|
||||
def _create_llm(self):
|
||||
def _create_llm(self, llm_config: dict):
|
||||
"""
|
||||
Creates an instance of the ChatOpenAI class with the provided language model configuration.
|
||||
|
||||
@ -55,12 +55,11 @@ class SmartScraperGraph:
|
||||
ValueError: If 'api_key' is not provided in llm_config.
|
||||
"""
|
||||
llm_defaults = {
|
||||
"model_name": "gpt-3.5-turbo",
|
||||
"temperature": 0,
|
||||
"streaming": True
|
||||
}
|
||||
# Update defaults with any LLM parameters that were provided
|
||||
llm_params = {**llm_defaults, **self.llm_config}
|
||||
llm_params = {**llm_defaults, **llm_config}
|
||||
# Ensure the api_key is set, raise an error if it's not
|
||||
if "api_key" not in llm_params:
|
||||
raise ValueError("LLM configuration must include an 'api_key'.")
|
||||
@ -75,24 +74,38 @@ class SmartScraperGraph:
|
||||
BaseGraph: An instance of the BaseGraph class.
|
||||
"""
|
||||
# define the nodes for the graph
|
||||
fetch_html_node = FetchHTMLNode("fetch_html")
|
||||
parse_document_node = ParseNode(doc_type="html", chunks_size=4000, node_name="parse_document")
|
||||
rag_node = RAGNode(self.llm, "rag")
|
||||
generate_answer_node = GenerateAnswerNode(self.llm, "generate_answer")
|
||||
fetch_node = FetchNode(
|
||||
input="url | local_dir",
|
||||
output=["doc"],
|
||||
)
|
||||
parse_node = ParseNode(
|
||||
input="doc",
|
||||
output=["parsed_doc"],
|
||||
)
|
||||
rag_node = RAGNode(
|
||||
input="user_prompt & (parsed_doc | doc)",
|
||||
output=["relevant_chunks"],
|
||||
model_config={"llm_model": self.llm_model},
|
||||
)
|
||||
generate_answer_node = GenerateAnswerNode(
|
||||
input="user_prompt & (relevant_chunks | parsed_doc | doc)",
|
||||
output=["answer"],
|
||||
model_config={"llm_model": self.llm_model},
|
||||
)
|
||||
|
||||
return BaseGraph(
|
||||
nodes={
|
||||
fetch_html_node,
|
||||
parse_document_node,
|
||||
fetch_node,
|
||||
parse_node,
|
||||
rag_node,
|
||||
generate_answer_node,
|
||||
},
|
||||
edges={
|
||||
(fetch_html_node, parse_document_node),
|
||||
(parse_document_node, rag_node),
|
||||
(fetch_node, parse_node),
|
||||
(parse_node, rag_node),
|
||||
(rag_node, generate_answer_node)
|
||||
},
|
||||
entry_point=fetch_html_node
|
||||
entry_point=fetch_node
|
||||
)
|
||||
|
||||
def run(self) -> str:
|
||||
@ -102,7 +115,7 @@ class SmartScraperGraph:
|
||||
Returns:
|
||||
str: The answer extracted from the web page, corresponding to the given prompt.
|
||||
"""
|
||||
inputs = {"user_input": self.prompt, "url": self.url}
|
||||
inputs = {"user_prompt": self.prompt, "url": self.url}
|
||||
final_state = self.graph.execute(inputs)
|
||||
|
||||
return final_state.get("answer", "No answer found.")
|
||||
|
||||
@ -5,7 +5,7 @@ from scrapegraphai.utils.save_audio_from_bytes import save_audio_from_bytes
|
||||
from ..models import OpenAI, OpenAITextToSpeech
|
||||
from .base_graph import BaseGraph
|
||||
from ..nodes import (
|
||||
FetchHTMLNode,
|
||||
FetchNode,
|
||||
ParseNode,
|
||||
RAGNode,
|
||||
GenerateAnswerNode,
|
||||
@ -13,7 +13,7 @@ from ..nodes import (
|
||||
)
|
||||
|
||||
|
||||
class SpeechSummaryGraph:
|
||||
class SpeechGraph:
|
||||
"""
|
||||
SpeechSummaryGraph is a tool that automates the process of extracting and summarizing
|
||||
information from web pages, then converting that summary into spoken word via an MP3 file.
|
||||
@ -35,21 +35,18 @@ class SpeechSummaryGraph:
|
||||
output_path (str): The file path where the generated MP3 should be saved.
|
||||
"""
|
||||
|
||||
def __init__(self, prompt: str, url: str, llm_config: dict,
|
||||
output_path: str = "website_summary.mp3"):
|
||||
def __init__(self, prompt: str, url: str, config: dict):
|
||||
"""
|
||||
Initializes the SmartScraper with a prompt, URL, and language model configuration.
|
||||
"""
|
||||
self.prompt = f"{prompt} - Save the summary in a key called 'summary'."
|
||||
self.prompt = prompt
|
||||
self.url = url
|
||||
self.llm_config = llm_config
|
||||
self.llm = self._create_llm()
|
||||
self.output_path = output_path
|
||||
self.text_to_speech_model = OpenAITextToSpeech(
|
||||
llm_config, model="tts-1", voice="alloy")
|
||||
self.llm_model = self._create_llm(config["llm"])
|
||||
self.output_path = config.get("output_path", "output.mp3")
|
||||
self.text_to_speech_model = OpenAITextToSpeech(config["tts_model"])
|
||||
self.graph = self._create_graph()
|
||||
|
||||
def _create_llm(self):
|
||||
def _create_llm(self, llm_config: dict):
|
||||
"""
|
||||
Creates an instance of the ChatOpenAI class with the provided language model configuration.
|
||||
|
||||
@ -60,12 +57,11 @@ class SpeechSummaryGraph:
|
||||
ValueError: If 'api_key' is not provided in llm_config.
|
||||
"""
|
||||
llm_defaults = {
|
||||
"model_name": "gpt-3.5-turbo",
|
||||
"temperature": 0,
|
||||
"streaming": True
|
||||
}
|
||||
# Update defaults with any LLM parameters that were provided
|
||||
llm_params = {**llm_defaults, **self.llm_config}
|
||||
llm_params = {**llm_defaults, **llm_config}
|
||||
# Ensure the api_key is set, raise an error if it's not
|
||||
if "api_key" not in llm_params:
|
||||
raise ValueError("LLM configuration must include an 'api_key'.")
|
||||
@ -79,28 +75,46 @@ class SpeechSummaryGraph:
|
||||
Returns:
|
||||
BaseGraph: An instance of the BaseGraph class.
|
||||
"""
|
||||
fetch_html_node = FetchHTMLNode("fetch_html")
|
||||
parse_document_node = ParseNode(doc_type="html", chunks_size=4000, node_name="parse_document")
|
||||
rag_node = RAGNode(self.llm, "rag")
|
||||
generate_answer_node = GenerateAnswerNode(self.llm, "generate_answer")
|
||||
# define the nodes for the graph
|
||||
fetch_node = FetchNode(
|
||||
input="url | local_dir",
|
||||
output=["doc"],
|
||||
)
|
||||
parse_node = ParseNode(
|
||||
input="doc",
|
||||
output=["parsed_doc"],
|
||||
)
|
||||
rag_node = RAGNode(
|
||||
input="user_prompt & (parsed_doc | doc)",
|
||||
output=["relevant_chunks"],
|
||||
model_config={"llm_model": self.llm_model},
|
||||
)
|
||||
generate_answer_node = GenerateAnswerNode(
|
||||
input="user_prompt & (relevant_chunks | parsed_doc | doc)",
|
||||
output=["answer"],
|
||||
model_config={"llm_model": self.llm_model},
|
||||
)
|
||||
text_to_speech_node = TextToSpeechNode(
|
||||
self.text_to_speech_model, "text_to_speech")
|
||||
input="answer",
|
||||
output=["audio"],
|
||||
model_config={"tts_model": self.text_to_speech_model},
|
||||
)
|
||||
|
||||
return BaseGraph(
|
||||
nodes={
|
||||
fetch_html_node,
|
||||
parse_document_node,
|
||||
fetch_node,
|
||||
parse_node,
|
||||
rag_node,
|
||||
generate_answer_node,
|
||||
text_to_speech_node
|
||||
},
|
||||
edges={
|
||||
(fetch_html_node, parse_document_node),
|
||||
(parse_document_node, rag_node),
|
||||
(fetch_node, parse_node),
|
||||
(parse_node, rag_node),
|
||||
(rag_node, generate_answer_node),
|
||||
(generate_answer_node, text_to_speech_node)
|
||||
},
|
||||
entry_point=fetch_html_node
|
||||
entry_point=fetch_node
|
||||
)
|
||||
|
||||
def run(self) -> str:
|
||||
@ -110,7 +124,7 @@ class SpeechSummaryGraph:
|
||||
Returns:
|
||||
str: The answer extracted from the web page, corresponding to the given prompt.
|
||||
"""
|
||||
inputs = {"user_input": self.prompt, "url": self.url}
|
||||
inputs = {"user_prompt": self.prompt, "url": self.url}
|
||||
final_state = self.graph.execute(inputs)
|
||||
|
||||
audio = final_state.get("audio", None)
|
||||
@ -22,7 +22,7 @@ class OpenAITextToSpeech:
|
||||
bytes of the generated speech.
|
||||
"""
|
||||
|
||||
def __init__(self, llm_config: dict, model: str = "tts-1", voice: str = "alloy"):
|
||||
def __init__(self, tts_config: dict):
|
||||
"""
|
||||
Initializes an instance of the OpenAITextToSpeech class.
|
||||
|
||||
@ -35,9 +35,9 @@ class OpenAITextToSpeech:
|
||||
"""
|
||||
|
||||
# convert model_name to model
|
||||
self.client = OpenAI(api_key=llm_config.get("api_key"))
|
||||
self.model = model
|
||||
self.voice = voice
|
||||
self.client = OpenAI(api_key=tts_config.get("api_key"))
|
||||
self.model = tts_config.get("model", "tts-1")
|
||||
self.voice = tts_config.get("voice", "alloy")
|
||||
|
||||
def run(self, text):
|
||||
"""
|
||||
|
||||
@ -35,7 +35,7 @@ class FetchNode(BaseNode):
|
||||
to succeed.
|
||||
"""
|
||||
|
||||
def __init__(self, input: str, output: List[str], node_name: str = "FetchNode"):
|
||||
def __init__(self, input: str, output: List[str], node_name: str = "Fetch"):
|
||||
"""
|
||||
Initializes the FetchHTMLNode with a node name and node type.
|
||||
Arguments:
|
||||
|
||||
@ -38,7 +38,7 @@ class GenerateAnswerNode(BaseNode):
|
||||
updating the state with the generated answer under the 'answer' key.
|
||||
"""
|
||||
|
||||
def __init__(self, input: str, output: List[str], model_config: dict, node_name: str = "GenerateAnswerNode"):
|
||||
def __init__(self, input: str, output: List[str], model_config: dict, node_name: str = "GenerateAnswer"):
|
||||
"""
|
||||
Initializes the GenerateAnswerNode with a language model client and a node name.
|
||||
Args:
|
||||
|
||||
@ -4,7 +4,7 @@ Module for proobable tags
|
||||
from langchain.output_parsers import CommaSeparatedListOutputParser
|
||||
from langchain.prompts import PromptTemplate
|
||||
from .base_node import BaseNode
|
||||
|
||||
from typing import List
|
||||
|
||||
class GetProbableTagsNode(BaseNode):
|
||||
"""
|
||||
@ -29,17 +29,17 @@ class GetProbableTagsNode(BaseNode):
|
||||
probable HTML tags, updating the state with these tags under the 'tags' key.
|
||||
"""
|
||||
|
||||
def __init__(self, llm, node_name: str):
|
||||
def __init__(self, input: str, output: List[str], model_config: dict, node_name: str = "GetProbableTags"):
|
||||
"""
|
||||
Initializes the GetProbableTagsNode with a language model client and a node name.
|
||||
Args:
|
||||
llm (OpenAIImageToText): An instance of the OpenAIImageToText class.
|
||||
node_name (str): name of the node
|
||||
"""
|
||||
super().__init__(node_name, "node")
|
||||
self.llm = llm
|
||||
super().__init__(node_name, "node", input, output, 2, model_config)
|
||||
self.llm_model = model_config["llm_model"]
|
||||
|
||||
def execute(self, state: dict):
|
||||
def execute(self, state):
|
||||
"""
|
||||
Generates a list of probable HTML tags based on the user's input and updates the state
|
||||
with this list. The method constructs a prompt for the language model, submits it, and
|
||||
@ -57,13 +57,16 @@ class GetProbableTagsNode(BaseNode):
|
||||
necessary information for generating tag predictions is missing.
|
||||
"""
|
||||
|
||||
print("---GETTING PROBABLE TAGS---")
|
||||
try:
|
||||
user_input = state["user_input"]
|
||||
url = state["url"]
|
||||
except KeyError as e:
|
||||
print(f"Error: {e} not found in state.")
|
||||
raise
|
||||
print(f"--- Executing {self.node_name} Node ---")
|
||||
|
||||
# Interpret input keys based on the provided input expression
|
||||
input_keys = self.get_input_keys(state)
|
||||
|
||||
# Fetching data from the state based on the input keys
|
||||
input_data = [state[key] for key in input_keys]
|
||||
|
||||
user_prompt = input_data[0]
|
||||
url = input_data[1]
|
||||
|
||||
output_parser = CommaSeparatedListOutputParser()
|
||||
format_instructions = output_parser.get_format_instructions()
|
||||
@ -81,11 +84,9 @@ class GetProbableTagsNode(BaseNode):
|
||||
)
|
||||
|
||||
# Execute the chain to get probable tags
|
||||
tag_answer = tag_prompt | self.llm | output_parser
|
||||
probable_tags = tag_answer.invoke({"question": user_input})
|
||||
|
||||
print("Possible tags: ", *probable_tags)
|
||||
tag_answer = tag_prompt | self.llm_model | output_parser
|
||||
probable_tags = tag_answer.invoke({"question": user_prompt})
|
||||
|
||||
# Update the dictionary with probable tags
|
||||
state.update({"tags": probable_tags})
|
||||
state.update({self.output[0]: probable_tags})
|
||||
return state
|
||||
|
||||
@ -29,7 +29,7 @@ class ParseNode(BaseNode):
|
||||
the specified tags, if provided, and updates the state with the parsed content.
|
||||
"""
|
||||
|
||||
def __init__(self, input: str, output: List[str], node_name: str = "ParseNode"):
|
||||
def __init__(self, input: str, output: List[str], node_name: str = "Parse"):
|
||||
"""
|
||||
Initializes the ParseHTMLNode with a node name.
|
||||
Args:
|
||||
|
||||
@ -12,7 +12,6 @@ from typing import List
|
||||
|
||||
from .base_node import BaseNode
|
||||
|
||||
|
||||
class RAGNode(BaseNode):
|
||||
"""
|
||||
A node responsible for compressing the input tokens and storing the document
|
||||
@ -33,7 +32,7 @@ class RAGNode(BaseNode):
|
||||
the specified tags, if provided, and updates the state with the parsed content.
|
||||
"""
|
||||
|
||||
def __init__(self, input: str, output: List[str], model_config: dict, node_name: str = "RAGNode"):
|
||||
def __init__(self, input: str, output: List[str], model_config: dict, node_name: str = "RAG"):
|
||||
"""
|
||||
Initializes the ParseHTMLNode with a node name.
|
||||
"""
|
||||
@ -78,7 +77,7 @@ class RAGNode(BaseNode):
|
||||
)
|
||||
chunked_docs.append(doc)
|
||||
|
||||
print("---UPDATED CHUNKS METADATA---")
|
||||
print("--- (updated chunks metadata) ---")
|
||||
|
||||
openai_key = self.llm_model.openai_api_key
|
||||
retriever = FAISS.from_documents(chunked_docs,
|
||||
@ -105,7 +104,7 @@ class RAGNode(BaseNode):
|
||||
compressed_docs = compression_retriever.get_relevant_documents(
|
||||
user_prompt)
|
||||
|
||||
print("---TOKENS COMPRESSED AND VECTOR STORED---")
|
||||
print("--- (tokens compressed and vector stored) ---")
|
||||
|
||||
state.update({self.output[0]: compressed_docs})
|
||||
return state
|
||||
|
||||
@ -17,12 +17,12 @@ class TextToSpeechNode(BaseNode):
|
||||
execute(state, text): Execute the node's logic and return the updated state.
|
||||
"""
|
||||
|
||||
def __init__(self, input: str, output: List[str], model_config: dict, node_name: str = "TextToSpeechNode"):
|
||||
def __init__(self, input: str, output: List[str], model_config: dict, node_name: str = "TextToSpeech"):
|
||||
"""
|
||||
Initializes an instance of the TextToSpeechNode class.
|
||||
"""
|
||||
super().__init__(node_name, "node", input, output, 1, model_config)
|
||||
self.text2speech_model = model_config["text2speech_model"]
|
||||
self.tts_model = model_config["tts_model"]
|
||||
|
||||
def execute(self, state):
|
||||
"""
|
||||
@ -42,13 +42,11 @@ class TextToSpeechNode(BaseNode):
|
||||
# Fetching data from the state based on the input keys
|
||||
input_data = [state[key] for key in input_keys]
|
||||
|
||||
text2translate = input_data[0]
|
||||
|
||||
# if not a string, raise an error
|
||||
if not isinstance(text2translate, str):
|
||||
raise ValueError("No text to translate to speech.")
|
||||
print("---TRANSLATING TEXT TO SPEECH---")
|
||||
audio = self.text2speech_model.run(text2translate["summary"])
|
||||
# get the text to translate
|
||||
text2translate = str(next(iter(input_data[0].values())))
|
||||
# text2translate = str(input_data[0])
|
||||
|
||||
audio = self.tts_model.run(text2translate)
|
||||
|
||||
state.update({self.output[0]: audio})
|
||||
return state
|
||||
|
||||
Loading…
Reference in New Issue
Block a user