implemented graph_config, fixed smart_scraper and speech graph

This commit is contained in:
Perinim 2024-03-17 20:35:04 +01:00
parent f27e0b4c5f
commit 52934bf007
13 changed files with 158 additions and 116 deletions

View File

@ -7,20 +7,26 @@ from dotenv import load_dotenv
from scrapegraphai.graphs import SmartScraperGraph
load_dotenv()
# Define the configuration for the language model
openai_key = os.getenv("OPENAI_APIKEY")
llm_config = {
"api_key": openai_key,
"model_name": "gpt-3.5-turbo",
# Define the configuration for the graph
graph_config = {
"llm": {
"api_key": openai_key,
"model": "gpt-3.5-turbo",
},
# "embedding_model": {
# "api_key": openai_key,
# "model": "gpt-3.5-turbo",
# },
}
# Define URL and PROMPT
URL = "https://www.ansa.it/veneto/"
PROMPT = "List me all the news with their description."
# Create the SmartScraperGraph instance
smart_scraper_graph = SmartScraperGraph(PROMPT, URL, llm_config)
smart_scraper_graph = SmartScraperGraph(
prompt = "List me all the news with their description.",
url = "https://www.ansa.it/veneto/",
config = graph_config
)
answer = smart_scraper_graph.run()
print(answer)

View File

@ -0,0 +1,38 @@
"""
Basic example of scraping pipeline using SpeechSummaryGraph
"""
import os
from dotenv import load_dotenv
from scrapegraphai.graphs import SpeechGraph
load_dotenv()
openai_key = os.getenv("OPENAI_APIKEY")
# Save the audio to a file
file_name = "website_summary.mp3"
curr_dir = os.path.dirname(os.path.realpath(__file__))
output_path = os.path.join(curr_dir, file_name)
# Define the configuration for the graph
graph_config = {
"llm": {
"api_key": openai_key,
"model": "gpt-3.5-turbo",
},
"tts_model": {
"api_key": openai_key,
"model": "tts-1",
"voice": "alloy"
},
"output_path": output_path,
}
speech_graph = SpeechGraph(
prompt = "List me all the projects and generate and audio for me to listen to.",
url = "https://perinim.github.io/projects/",
config = graph_config,
)
final_state = speech_graph.run()
print(final_state.get("answer", "No answer found."))

View File

@ -1,27 +0,0 @@
"""
Basic example of scraping pipeline using SpeechSummaryGraph
"""
import os
from dotenv import load_dotenv
from scrapegraphai.graphs import SpeechSummaryGraph
load_dotenv()
# Define the configuration for the language model
openai_key = os.getenv("OPENAI_APIKEY")
llm_config = {
"api_key": openai_key,
}
# Save the audio to a file
curr_dir = os.path.dirname(os.path.realpath(__file__))
output_file_path = os.path.join(curr_dir, "website_summary.mp3")
speech_summary_graph = SpeechSummaryGraph("""Make a summary of the news to be
converted to audio for blind people.""",
"https://www.wired.com/category/science/", llm_config,
output_file_path)
final_state = speech_summary_graph.run()
print(final_state.get("answer", "No answer found."))

View File

@ -3,4 +3,4 @@ __init__.py file for graphs folder
"""
from .base_graph import BaseGraph
from .smart_scraper_graph import SmartScraperGraph
from .speech_summary_graph import SpeechSummaryGraph
from .speech_graph import SpeechGraph

View File

@ -4,7 +4,7 @@ Module for creating the smart scraper
from ..models import OpenAI
from .base_graph import BaseGraph
from ..nodes import (
FetchHTMLNode,
FetchNode,
ParseNode,
RAGNode,
GenerateAnswerNode
@ -34,17 +34,17 @@ class SmartScraperGraph:
'temperature', and 'streaming'.
"""
def __init__(self, prompt: str, url: str, llm_config: dict):
def __init__(self, prompt: str, url: str, config: dict):
"""
Initializes the SmartScraper with a prompt, URL, and language model configuration.
"""
self.prompt = prompt
self.url = url
self.llm_config = llm_config
self.llm = self._create_llm()
self.config = config
self.llm_model = self._create_llm(config["llm"])
self.graph = self._create_graph()
def _create_llm(self):
def _create_llm(self, llm_config: dict):
"""
Creates an instance of the ChatOpenAI class with the provided language model configuration.
@ -55,12 +55,11 @@ class SmartScraperGraph:
ValueError: If 'api_key' is not provided in llm_config.
"""
llm_defaults = {
"model_name": "gpt-3.5-turbo",
"temperature": 0,
"streaming": True
}
# Update defaults with any LLM parameters that were provided
llm_params = {**llm_defaults, **self.llm_config}
llm_params = {**llm_defaults, **llm_config}
# Ensure the api_key is set, raise an error if it's not
if "api_key" not in llm_params:
raise ValueError("LLM configuration must include an 'api_key'.")
@ -75,24 +74,38 @@ class SmartScraperGraph:
BaseGraph: An instance of the BaseGraph class.
"""
# define the nodes for the graph
fetch_html_node = FetchHTMLNode("fetch_html")
parse_document_node = ParseNode(doc_type="html", chunks_size=4000, node_name="parse_document")
rag_node = RAGNode(self.llm, "rag")
generate_answer_node = GenerateAnswerNode(self.llm, "generate_answer")
fetch_node = FetchNode(
input="url | local_dir",
output=["doc"],
)
parse_node = ParseNode(
input="doc",
output=["parsed_doc"],
)
rag_node = RAGNode(
input="user_prompt & (parsed_doc | doc)",
output=["relevant_chunks"],
model_config={"llm_model": self.llm_model},
)
generate_answer_node = GenerateAnswerNode(
input="user_prompt & (relevant_chunks | parsed_doc | doc)",
output=["answer"],
model_config={"llm_model": self.llm_model},
)
return BaseGraph(
nodes={
fetch_html_node,
parse_document_node,
fetch_node,
parse_node,
rag_node,
generate_answer_node,
},
edges={
(fetch_html_node, parse_document_node),
(parse_document_node, rag_node),
(fetch_node, parse_node),
(parse_node, rag_node),
(rag_node, generate_answer_node)
},
entry_point=fetch_html_node
entry_point=fetch_node
)
def run(self) -> str:
@ -102,7 +115,7 @@ class SmartScraperGraph:
Returns:
str: The answer extracted from the web page, corresponding to the given prompt.
"""
inputs = {"user_input": self.prompt, "url": self.url}
inputs = {"user_prompt": self.prompt, "url": self.url}
final_state = self.graph.execute(inputs)
return final_state.get("answer", "No answer found.")

View File

@ -5,7 +5,7 @@ from scrapegraphai.utils.save_audio_from_bytes import save_audio_from_bytes
from ..models import OpenAI, OpenAITextToSpeech
from .base_graph import BaseGraph
from ..nodes import (
FetchHTMLNode,
FetchNode,
ParseNode,
RAGNode,
GenerateAnswerNode,
@ -13,7 +13,7 @@ from ..nodes import (
)
class SpeechSummaryGraph:
class SpeechGraph:
"""
SpeechSummaryGraph is a tool that automates the process of extracting and summarizing
information from web pages, then converting that summary into spoken word via an MP3 file.
@ -35,21 +35,18 @@ class SpeechSummaryGraph:
output_path (str): The file path where the generated MP3 should be saved.
"""
def __init__(self, prompt: str, url: str, llm_config: dict,
output_path: str = "website_summary.mp3"):
def __init__(self, prompt: str, url: str, config: dict):
"""
Initializes the SmartScraper with a prompt, URL, and language model configuration.
"""
self.prompt = f"{prompt} - Save the summary in a key called 'summary'."
self.prompt = prompt
self.url = url
self.llm_config = llm_config
self.llm = self._create_llm()
self.output_path = output_path
self.text_to_speech_model = OpenAITextToSpeech(
llm_config, model="tts-1", voice="alloy")
self.llm_model = self._create_llm(config["llm"])
self.output_path = config.get("output_path", "output.mp3")
self.text_to_speech_model = OpenAITextToSpeech(config["tts_model"])
self.graph = self._create_graph()
def _create_llm(self):
def _create_llm(self, llm_config: dict):
"""
Creates an instance of the ChatOpenAI class with the provided language model configuration.
@ -60,12 +57,11 @@ class SpeechSummaryGraph:
ValueError: If 'api_key' is not provided in llm_config.
"""
llm_defaults = {
"model_name": "gpt-3.5-turbo",
"temperature": 0,
"streaming": True
}
# Update defaults with any LLM parameters that were provided
llm_params = {**llm_defaults, **self.llm_config}
llm_params = {**llm_defaults, **llm_config}
# Ensure the api_key is set, raise an error if it's not
if "api_key" not in llm_params:
raise ValueError("LLM configuration must include an 'api_key'.")
@ -79,28 +75,46 @@ class SpeechSummaryGraph:
Returns:
BaseGraph: An instance of the BaseGraph class.
"""
fetch_html_node = FetchHTMLNode("fetch_html")
parse_document_node = ParseNode(doc_type="html", chunks_size=4000, node_name="parse_document")
rag_node = RAGNode(self.llm, "rag")
generate_answer_node = GenerateAnswerNode(self.llm, "generate_answer")
# define the nodes for the graph
fetch_node = FetchNode(
input="url | local_dir",
output=["doc"],
)
parse_node = ParseNode(
input="doc",
output=["parsed_doc"],
)
rag_node = RAGNode(
input="user_prompt & (parsed_doc | doc)",
output=["relevant_chunks"],
model_config={"llm_model": self.llm_model},
)
generate_answer_node = GenerateAnswerNode(
input="user_prompt & (relevant_chunks | parsed_doc | doc)",
output=["answer"],
model_config={"llm_model": self.llm_model},
)
text_to_speech_node = TextToSpeechNode(
self.text_to_speech_model, "text_to_speech")
input="answer",
output=["audio"],
model_config={"tts_model": self.text_to_speech_model},
)
return BaseGraph(
nodes={
fetch_html_node,
parse_document_node,
fetch_node,
parse_node,
rag_node,
generate_answer_node,
text_to_speech_node
},
edges={
(fetch_html_node, parse_document_node),
(parse_document_node, rag_node),
(fetch_node, parse_node),
(parse_node, rag_node),
(rag_node, generate_answer_node),
(generate_answer_node, text_to_speech_node)
},
entry_point=fetch_html_node
entry_point=fetch_node
)
def run(self) -> str:
@ -110,7 +124,7 @@ class SpeechSummaryGraph:
Returns:
str: The answer extracted from the web page, corresponding to the given prompt.
"""
inputs = {"user_input": self.prompt, "url": self.url}
inputs = {"user_prompt": self.prompt, "url": self.url}
final_state = self.graph.execute(inputs)
audio = final_state.get("audio", None)

View File

@ -22,7 +22,7 @@ class OpenAITextToSpeech:
bytes of the generated speech.
"""
def __init__(self, llm_config: dict, model: str = "tts-1", voice: str = "alloy"):
def __init__(self, tts_config: dict):
"""
Initializes an instance of the OpenAITextToSpeech class.
@ -35,9 +35,9 @@ class OpenAITextToSpeech:
"""
# convert model_name to model
self.client = OpenAI(api_key=llm_config.get("api_key"))
self.model = model
self.voice = voice
self.client = OpenAI(api_key=tts_config.get("api_key"))
self.model = tts_config.get("model", "tts-1")
self.voice = tts_config.get("voice", "alloy")
def run(self, text):
"""

View File

@ -35,7 +35,7 @@ class FetchNode(BaseNode):
to succeed.
"""
def __init__(self, input: str, output: List[str], node_name: str = "FetchNode"):
def __init__(self, input: str, output: List[str], node_name: str = "Fetch"):
"""
Initializes the FetchHTMLNode with a node name and node type.
Arguments:

View File

@ -38,7 +38,7 @@ class GenerateAnswerNode(BaseNode):
updating the state with the generated answer under the 'answer' key.
"""
def __init__(self, input: str, output: List[str], model_config: dict, node_name: str = "GenerateAnswerNode"):
def __init__(self, input: str, output: List[str], model_config: dict, node_name: str = "GenerateAnswer"):
"""
Initializes the GenerateAnswerNode with a language model client and a node name.
Args:

View File

@ -4,7 +4,7 @@ Module for proobable tags
from langchain.output_parsers import CommaSeparatedListOutputParser
from langchain.prompts import PromptTemplate
from .base_node import BaseNode
from typing import List
class GetProbableTagsNode(BaseNode):
"""
@ -29,17 +29,17 @@ class GetProbableTagsNode(BaseNode):
probable HTML tags, updating the state with these tags under the 'tags' key.
"""
def __init__(self, llm, node_name: str):
def __init__(self, input: str, output: List[str], model_config: dict, node_name: str = "GetProbableTags"):
"""
Initializes the GetProbableTagsNode with a language model client and a node name.
Args:
llm (OpenAIImageToText): An instance of the OpenAIImageToText class.
node_name (str): name of the node
"""
super().__init__(node_name, "node")
self.llm = llm
super().__init__(node_name, "node", input, output, 2, model_config)
self.llm_model = model_config["llm_model"]
def execute(self, state: dict):
def execute(self, state):
"""
Generates a list of probable HTML tags based on the user's input and updates the state
with this list. The method constructs a prompt for the language model, submits it, and
@ -57,13 +57,16 @@ class GetProbableTagsNode(BaseNode):
necessary information for generating tag predictions is missing.
"""
print("---GETTING PROBABLE TAGS---")
try:
user_input = state["user_input"]
url = state["url"]
except KeyError as e:
print(f"Error: {e} not found in state.")
raise
print(f"--- Executing {self.node_name} Node ---")
# Interpret input keys based on the provided input expression
input_keys = self.get_input_keys(state)
# Fetching data from the state based on the input keys
input_data = [state[key] for key in input_keys]
user_prompt = input_data[0]
url = input_data[1]
output_parser = CommaSeparatedListOutputParser()
format_instructions = output_parser.get_format_instructions()
@ -81,11 +84,9 @@ class GetProbableTagsNode(BaseNode):
)
# Execute the chain to get probable tags
tag_answer = tag_prompt | self.llm | output_parser
probable_tags = tag_answer.invoke({"question": user_input})
print("Possible tags: ", *probable_tags)
tag_answer = tag_prompt | self.llm_model | output_parser
probable_tags = tag_answer.invoke({"question": user_prompt})
# Update the dictionary with probable tags
state.update({"tags": probable_tags})
state.update({self.output[0]: probable_tags})
return state

View File

@ -29,7 +29,7 @@ class ParseNode(BaseNode):
the specified tags, if provided, and updates the state with the parsed content.
"""
def __init__(self, input: str, output: List[str], node_name: str = "ParseNode"):
def __init__(self, input: str, output: List[str], node_name: str = "Parse"):
"""
Initializes the ParseHTMLNode with a node name.
Args:

View File

@ -12,7 +12,6 @@ from typing import List
from .base_node import BaseNode
class RAGNode(BaseNode):
"""
A node responsible for compressing the input tokens and storing the document
@ -33,7 +32,7 @@ class RAGNode(BaseNode):
the specified tags, if provided, and updates the state with the parsed content.
"""
def __init__(self, input: str, output: List[str], model_config: dict, node_name: str = "RAGNode"):
def __init__(self, input: str, output: List[str], model_config: dict, node_name: str = "RAG"):
"""
Initializes the ParseHTMLNode with a node name.
"""
@ -78,7 +77,7 @@ class RAGNode(BaseNode):
)
chunked_docs.append(doc)
print("---UPDATED CHUNKS METADATA---")
print("--- (updated chunks metadata) ---")
openai_key = self.llm_model.openai_api_key
retriever = FAISS.from_documents(chunked_docs,
@ -105,7 +104,7 @@ class RAGNode(BaseNode):
compressed_docs = compression_retriever.get_relevant_documents(
user_prompt)
print("---TOKENS COMPRESSED AND VECTOR STORED---")
print("--- (tokens compressed and vector stored) ---")
state.update({self.output[0]: compressed_docs})
return state

View File

@ -17,12 +17,12 @@ class TextToSpeechNode(BaseNode):
execute(state, text): Execute the node's logic and return the updated state.
"""
def __init__(self, input: str, output: List[str], model_config: dict, node_name: str = "TextToSpeechNode"):
def __init__(self, input: str, output: List[str], model_config: dict, node_name: str = "TextToSpeech"):
"""
Initializes an instance of the TextToSpeechNode class.
"""
super().__init__(node_name, "node", input, output, 1, model_config)
self.text2speech_model = model_config["text2speech_model"]
self.tts_model = model_config["tts_model"]
def execute(self, state):
"""
@ -42,13 +42,11 @@ class TextToSpeechNode(BaseNode):
# Fetching data from the state based on the input keys
input_data = [state[key] for key in input_keys]
text2translate = input_data[0]
# if not a string, raise an error
if not isinstance(text2translate, str):
raise ValueError("No text to translate to speech.")
print("---TRANSLATING TEXT TO SPEECH---")
audio = self.text2speech_model.run(text2translate["summary"])
# get the text to translate
text2translate = str(next(iter(input_data[0].values())))
# text2translate = str(input_data[0])
audio = self.tts_model.run(text2translate)
state.update({self.output[0]: audio})
return state