From 52934bf007e82a6d2ebf6ca54fac24698aad033e Mon Sep 17 00:00:00 2001 From: Perinim Date: Sun, 17 Mar 2024 20:35:04 +0100 Subject: [PATCH] implemented graph_config, fixed smart_scraper and speech graph --- .../graph_examples/smart_scraper_example.py | 26 +++++--- .../graph_examples/speech_graph_example.py | 38 ++++++++++++ .../speech_summary_graph_example.py | 27 -------- scrapegraphai/graphs/__init__.py | 2 +- scrapegraphai/graphs/smart_scraper_graph.py | 47 +++++++++----- ...peech_summary_graph.py => speech_graph.py} | 62 ++++++++++++------- scrapegraphai/models/openai_tts.py | 8 +-- scrapegraphai/nodes/fetch_node.py | 2 +- scrapegraphai/nodes/generate_answer_node.py | 2 +- scrapegraphai/nodes/get_probable_tags_node.py | 35 ++++++----- scrapegraphai/nodes/parse_node.py | 2 +- scrapegraphai/nodes/rag_node.py | 7 +-- scrapegraphai/nodes/text_to_speech_node.py | 16 +++-- 13 files changed, 158 insertions(+), 116 deletions(-) create mode 100644 examples/graph_examples/speech_graph_example.py delete mode 100644 examples/graph_examples/speech_summary_graph_example.py rename scrapegraphai/graphs/{speech_summary_graph.py => speech_graph.py} (69%) diff --git a/examples/graph_examples/smart_scraper_example.py b/examples/graph_examples/smart_scraper_example.py index 4ab0fa70..15a55a63 100644 --- a/examples/graph_examples/smart_scraper_example.py +++ b/examples/graph_examples/smart_scraper_example.py @@ -7,20 +7,26 @@ from dotenv import load_dotenv from scrapegraphai.graphs import SmartScraperGraph load_dotenv() - -# Define the configuration for the language model openai_key = os.getenv("OPENAI_APIKEY") -llm_config = { - "api_key": openai_key, - "model_name": "gpt-3.5-turbo", + +# Define the configuration for the graph +graph_config = { + "llm": { + "api_key": openai_key, + "model": "gpt-3.5-turbo", + }, + # "embedding_model": { + # "api_key": openai_key, + # "model": "gpt-3.5-turbo", + # }, } -# Define URL and PROMPT -URL = "https://www.ansa.it/veneto/" -PROMPT = "List me all the news with their description." - # Create the SmartScraperGraph instance -smart_scraper_graph = SmartScraperGraph(PROMPT, URL, llm_config) +smart_scraper_graph = SmartScraperGraph( + prompt = "List me all the news with their description.", + url = "https://www.ansa.it/veneto/", + config = graph_config +) answer = smart_scraper_graph.run() print(answer) diff --git a/examples/graph_examples/speech_graph_example.py b/examples/graph_examples/speech_graph_example.py new file mode 100644 index 00000000..41f2d6c8 --- /dev/null +++ b/examples/graph_examples/speech_graph_example.py @@ -0,0 +1,38 @@ +""" +Basic example of scraping pipeline using SpeechSummaryGraph +""" + +import os +from dotenv import load_dotenv +from scrapegraphai.graphs import SpeechGraph + +load_dotenv() +openai_key = os.getenv("OPENAI_APIKEY") + +# Save the audio to a file +file_name = "website_summary.mp3" +curr_dir = os.path.dirname(os.path.realpath(__file__)) +output_path = os.path.join(curr_dir, file_name) + +# Define the configuration for the graph +graph_config = { + "llm": { + "api_key": openai_key, + "model": "gpt-3.5-turbo", + }, + "tts_model": { + "api_key": openai_key, + "model": "tts-1", + "voice": "alloy" + }, + "output_path": output_path, +} + +speech_graph = SpeechGraph( + prompt = "List me all the projects and generate and audio for me to listen to.", + url = "https://perinim.github.io/projects/", + config = graph_config, +) + +final_state = speech_graph.run() +print(final_state.get("answer", "No answer found.")) diff --git a/examples/graph_examples/speech_summary_graph_example.py b/examples/graph_examples/speech_summary_graph_example.py deleted file mode 100644 index eff068ea..00000000 --- a/examples/graph_examples/speech_summary_graph_example.py +++ /dev/null @@ -1,27 +0,0 @@ -""" -Basic example of scraping pipeline using SpeechSummaryGraph -""" - -import os -from dotenv import load_dotenv -from scrapegraphai.graphs import SpeechSummaryGraph - -load_dotenv() - -# Define the configuration for the language model -openai_key = os.getenv("OPENAI_APIKEY") -llm_config = { - "api_key": openai_key, -} - -# Save the audio to a file -curr_dir = os.path.dirname(os.path.realpath(__file__)) -output_file_path = os.path.join(curr_dir, "website_summary.mp3") - -speech_summary_graph = SpeechSummaryGraph("""Make a summary of the news to be -converted to audio for blind people.""", - "https://www.wired.com/category/science/", llm_config, - output_file_path) - -final_state = speech_summary_graph.run() -print(final_state.get("answer", "No answer found.")) diff --git a/scrapegraphai/graphs/__init__.py b/scrapegraphai/graphs/__init__.py index fde11ca8..67792b28 100644 --- a/scrapegraphai/graphs/__init__.py +++ b/scrapegraphai/graphs/__init__.py @@ -3,4 +3,4 @@ __init__.py file for graphs folder """ from .base_graph import BaseGraph from .smart_scraper_graph import SmartScraperGraph -from .speech_summary_graph import SpeechSummaryGraph +from .speech_graph import SpeechGraph diff --git a/scrapegraphai/graphs/smart_scraper_graph.py b/scrapegraphai/graphs/smart_scraper_graph.py index fc25ad5a..a847b28b 100644 --- a/scrapegraphai/graphs/smart_scraper_graph.py +++ b/scrapegraphai/graphs/smart_scraper_graph.py @@ -4,7 +4,7 @@ Module for creating the smart scraper from ..models import OpenAI from .base_graph import BaseGraph from ..nodes import ( - FetchHTMLNode, + FetchNode, ParseNode, RAGNode, GenerateAnswerNode @@ -34,17 +34,17 @@ class SmartScraperGraph: 'temperature', and 'streaming'. """ - def __init__(self, prompt: str, url: str, llm_config: dict): + def __init__(self, prompt: str, url: str, config: dict): """ Initializes the SmartScraper with a prompt, URL, and language model configuration. """ self.prompt = prompt self.url = url - self.llm_config = llm_config - self.llm = self._create_llm() + self.config = config + self.llm_model = self._create_llm(config["llm"]) self.graph = self._create_graph() - def _create_llm(self): + def _create_llm(self, llm_config: dict): """ Creates an instance of the ChatOpenAI class with the provided language model configuration. @@ -55,12 +55,11 @@ class SmartScraperGraph: ValueError: If 'api_key' is not provided in llm_config. """ llm_defaults = { - "model_name": "gpt-3.5-turbo", "temperature": 0, "streaming": True } # Update defaults with any LLM parameters that were provided - llm_params = {**llm_defaults, **self.llm_config} + llm_params = {**llm_defaults, **llm_config} # Ensure the api_key is set, raise an error if it's not if "api_key" not in llm_params: raise ValueError("LLM configuration must include an 'api_key'.") @@ -75,24 +74,38 @@ class SmartScraperGraph: BaseGraph: An instance of the BaseGraph class. """ # define the nodes for the graph - fetch_html_node = FetchHTMLNode("fetch_html") - parse_document_node = ParseNode(doc_type="html", chunks_size=4000, node_name="parse_document") - rag_node = RAGNode(self.llm, "rag") - generate_answer_node = GenerateAnswerNode(self.llm, "generate_answer") + fetch_node = FetchNode( + input="url | local_dir", + output=["doc"], + ) + parse_node = ParseNode( + input="doc", + output=["parsed_doc"], + ) + rag_node = RAGNode( + input="user_prompt & (parsed_doc | doc)", + output=["relevant_chunks"], + model_config={"llm_model": self.llm_model}, + ) + generate_answer_node = GenerateAnswerNode( + input="user_prompt & (relevant_chunks | parsed_doc | doc)", + output=["answer"], + model_config={"llm_model": self.llm_model}, + ) return BaseGraph( nodes={ - fetch_html_node, - parse_document_node, + fetch_node, + parse_node, rag_node, generate_answer_node, }, edges={ - (fetch_html_node, parse_document_node), - (parse_document_node, rag_node), + (fetch_node, parse_node), + (parse_node, rag_node), (rag_node, generate_answer_node) }, - entry_point=fetch_html_node + entry_point=fetch_node ) def run(self) -> str: @@ -102,7 +115,7 @@ class SmartScraperGraph: Returns: str: The answer extracted from the web page, corresponding to the given prompt. """ - inputs = {"user_input": self.prompt, "url": self.url} + inputs = {"user_prompt": self.prompt, "url": self.url} final_state = self.graph.execute(inputs) return final_state.get("answer", "No answer found.") diff --git a/scrapegraphai/graphs/speech_summary_graph.py b/scrapegraphai/graphs/speech_graph.py similarity index 69% rename from scrapegraphai/graphs/speech_summary_graph.py rename to scrapegraphai/graphs/speech_graph.py index a23af88f..b97d18f5 100644 --- a/scrapegraphai/graphs/speech_summary_graph.py +++ b/scrapegraphai/graphs/speech_graph.py @@ -5,7 +5,7 @@ from scrapegraphai.utils.save_audio_from_bytes import save_audio_from_bytes from ..models import OpenAI, OpenAITextToSpeech from .base_graph import BaseGraph from ..nodes import ( - FetchHTMLNode, + FetchNode, ParseNode, RAGNode, GenerateAnswerNode, @@ -13,7 +13,7 @@ from ..nodes import ( ) -class SpeechSummaryGraph: +class SpeechGraph: """ SpeechSummaryGraph is a tool that automates the process of extracting and summarizing information from web pages, then converting that summary into spoken word via an MP3 file. @@ -35,21 +35,18 @@ class SpeechSummaryGraph: output_path (str): The file path where the generated MP3 should be saved. """ - def __init__(self, prompt: str, url: str, llm_config: dict, - output_path: str = "website_summary.mp3"): + def __init__(self, prompt: str, url: str, config: dict): """ Initializes the SmartScraper with a prompt, URL, and language model configuration. """ - self.prompt = f"{prompt} - Save the summary in a key called 'summary'." + self.prompt = prompt self.url = url - self.llm_config = llm_config - self.llm = self._create_llm() - self.output_path = output_path - self.text_to_speech_model = OpenAITextToSpeech( - llm_config, model="tts-1", voice="alloy") + self.llm_model = self._create_llm(config["llm"]) + self.output_path = config.get("output_path", "output.mp3") + self.text_to_speech_model = OpenAITextToSpeech(config["tts_model"]) self.graph = self._create_graph() - def _create_llm(self): + def _create_llm(self, llm_config: dict): """ Creates an instance of the ChatOpenAI class with the provided language model configuration. @@ -60,12 +57,11 @@ class SpeechSummaryGraph: ValueError: If 'api_key' is not provided in llm_config. """ llm_defaults = { - "model_name": "gpt-3.5-turbo", "temperature": 0, "streaming": True } # Update defaults with any LLM parameters that were provided - llm_params = {**llm_defaults, **self.llm_config} + llm_params = {**llm_defaults, **llm_config} # Ensure the api_key is set, raise an error if it's not if "api_key" not in llm_params: raise ValueError("LLM configuration must include an 'api_key'.") @@ -79,28 +75,46 @@ class SpeechSummaryGraph: Returns: BaseGraph: An instance of the BaseGraph class. """ - fetch_html_node = FetchHTMLNode("fetch_html") - parse_document_node = ParseNode(doc_type="html", chunks_size=4000, node_name="parse_document") - rag_node = RAGNode(self.llm, "rag") - generate_answer_node = GenerateAnswerNode(self.llm, "generate_answer") + # define the nodes for the graph + fetch_node = FetchNode( + input="url | local_dir", + output=["doc"], + ) + parse_node = ParseNode( + input="doc", + output=["parsed_doc"], + ) + rag_node = RAGNode( + input="user_prompt & (parsed_doc | doc)", + output=["relevant_chunks"], + model_config={"llm_model": self.llm_model}, + ) + generate_answer_node = GenerateAnswerNode( + input="user_prompt & (relevant_chunks | parsed_doc | doc)", + output=["answer"], + model_config={"llm_model": self.llm_model}, + ) text_to_speech_node = TextToSpeechNode( - self.text_to_speech_model, "text_to_speech") + input="answer", + output=["audio"], + model_config={"tts_model": self.text_to_speech_model}, + ) return BaseGraph( nodes={ - fetch_html_node, - parse_document_node, + fetch_node, + parse_node, rag_node, generate_answer_node, text_to_speech_node }, edges={ - (fetch_html_node, parse_document_node), - (parse_document_node, rag_node), + (fetch_node, parse_node), + (parse_node, rag_node), (rag_node, generate_answer_node), (generate_answer_node, text_to_speech_node) }, - entry_point=fetch_html_node + entry_point=fetch_node ) def run(self) -> str: @@ -110,7 +124,7 @@ class SpeechSummaryGraph: Returns: str: The answer extracted from the web page, corresponding to the given prompt. """ - inputs = {"user_input": self.prompt, "url": self.url} + inputs = {"user_prompt": self.prompt, "url": self.url} final_state = self.graph.execute(inputs) audio = final_state.get("audio", None) diff --git a/scrapegraphai/models/openai_tts.py b/scrapegraphai/models/openai_tts.py index 7286a4ca..f2227f8c 100644 --- a/scrapegraphai/models/openai_tts.py +++ b/scrapegraphai/models/openai_tts.py @@ -22,7 +22,7 @@ class OpenAITextToSpeech: bytes of the generated speech. """ - def __init__(self, llm_config: dict, model: str = "tts-1", voice: str = "alloy"): + def __init__(self, tts_config: dict): """ Initializes an instance of the OpenAITextToSpeech class. @@ -35,9 +35,9 @@ class OpenAITextToSpeech: """ # convert model_name to model - self.client = OpenAI(api_key=llm_config.get("api_key")) - self.model = model - self.voice = voice + self.client = OpenAI(api_key=tts_config.get("api_key")) + self.model = tts_config.get("model", "tts-1") + self.voice = tts_config.get("voice", "alloy") def run(self, text): """ diff --git a/scrapegraphai/nodes/fetch_node.py b/scrapegraphai/nodes/fetch_node.py index e28e239d..8a539dd7 100644 --- a/scrapegraphai/nodes/fetch_node.py +++ b/scrapegraphai/nodes/fetch_node.py @@ -35,7 +35,7 @@ class FetchNode(BaseNode): to succeed. """ - def __init__(self, input: str, output: List[str], node_name: str = "FetchNode"): + def __init__(self, input: str, output: List[str], node_name: str = "Fetch"): """ Initializes the FetchHTMLNode with a node name and node type. Arguments: diff --git a/scrapegraphai/nodes/generate_answer_node.py b/scrapegraphai/nodes/generate_answer_node.py index c1f8d291..2caabe82 100644 --- a/scrapegraphai/nodes/generate_answer_node.py +++ b/scrapegraphai/nodes/generate_answer_node.py @@ -38,7 +38,7 @@ class GenerateAnswerNode(BaseNode): updating the state with the generated answer under the 'answer' key. """ - def __init__(self, input: str, output: List[str], model_config: dict, node_name: str = "GenerateAnswerNode"): + def __init__(self, input: str, output: List[str], model_config: dict, node_name: str = "GenerateAnswer"): """ Initializes the GenerateAnswerNode with a language model client and a node name. Args: diff --git a/scrapegraphai/nodes/get_probable_tags_node.py b/scrapegraphai/nodes/get_probable_tags_node.py index fd5eff67..66b3d37b 100644 --- a/scrapegraphai/nodes/get_probable_tags_node.py +++ b/scrapegraphai/nodes/get_probable_tags_node.py @@ -4,7 +4,7 @@ Module for proobable tags from langchain.output_parsers import CommaSeparatedListOutputParser from langchain.prompts import PromptTemplate from .base_node import BaseNode - +from typing import List class GetProbableTagsNode(BaseNode): """ @@ -29,17 +29,17 @@ class GetProbableTagsNode(BaseNode): probable HTML tags, updating the state with these tags under the 'tags' key. """ - def __init__(self, llm, node_name: str): + def __init__(self, input: str, output: List[str], model_config: dict, node_name: str = "GetProbableTags"): """ Initializes the GetProbableTagsNode with a language model client and a node name. Args: llm (OpenAIImageToText): An instance of the OpenAIImageToText class. node_name (str): name of the node """ - super().__init__(node_name, "node") - self.llm = llm + super().__init__(node_name, "node", input, output, 2, model_config) + self.llm_model = model_config["llm_model"] - def execute(self, state: dict): + def execute(self, state): """ Generates a list of probable HTML tags based on the user's input and updates the state with this list. The method constructs a prompt for the language model, submits it, and @@ -57,13 +57,16 @@ class GetProbableTagsNode(BaseNode): necessary information for generating tag predictions is missing. """ - print("---GETTING PROBABLE TAGS---") - try: - user_input = state["user_input"] - url = state["url"] - except KeyError as e: - print(f"Error: {e} not found in state.") - raise + print(f"--- Executing {self.node_name} Node ---") + + # Interpret input keys based on the provided input expression + input_keys = self.get_input_keys(state) + + # Fetching data from the state based on the input keys + input_data = [state[key] for key in input_keys] + + user_prompt = input_data[0] + url = input_data[1] output_parser = CommaSeparatedListOutputParser() format_instructions = output_parser.get_format_instructions() @@ -81,11 +84,9 @@ class GetProbableTagsNode(BaseNode): ) # Execute the chain to get probable tags - tag_answer = tag_prompt | self.llm | output_parser - probable_tags = tag_answer.invoke({"question": user_input}) - - print("Possible tags: ", *probable_tags) + tag_answer = tag_prompt | self.llm_model | output_parser + probable_tags = tag_answer.invoke({"question": user_prompt}) # Update the dictionary with probable tags - state.update({"tags": probable_tags}) + state.update({self.output[0]: probable_tags}) return state diff --git a/scrapegraphai/nodes/parse_node.py b/scrapegraphai/nodes/parse_node.py index 62e0f0f9..e2b417dc 100644 --- a/scrapegraphai/nodes/parse_node.py +++ b/scrapegraphai/nodes/parse_node.py @@ -29,7 +29,7 @@ class ParseNode(BaseNode): the specified tags, if provided, and updates the state with the parsed content. """ - def __init__(self, input: str, output: List[str], node_name: str = "ParseNode"): + def __init__(self, input: str, output: List[str], node_name: str = "Parse"): """ Initializes the ParseHTMLNode with a node name. Args: diff --git a/scrapegraphai/nodes/rag_node.py b/scrapegraphai/nodes/rag_node.py index b37684b9..9432d0b5 100644 --- a/scrapegraphai/nodes/rag_node.py +++ b/scrapegraphai/nodes/rag_node.py @@ -12,7 +12,6 @@ from typing import List from .base_node import BaseNode - class RAGNode(BaseNode): """ A node responsible for compressing the input tokens and storing the document @@ -33,7 +32,7 @@ class RAGNode(BaseNode): the specified tags, if provided, and updates the state with the parsed content. """ - def __init__(self, input: str, output: List[str], model_config: dict, node_name: str = "RAGNode"): + def __init__(self, input: str, output: List[str], model_config: dict, node_name: str = "RAG"): """ Initializes the ParseHTMLNode with a node name. """ @@ -78,7 +77,7 @@ class RAGNode(BaseNode): ) chunked_docs.append(doc) - print("---UPDATED CHUNKS METADATA---") + print("--- (updated chunks metadata) ---") openai_key = self.llm_model.openai_api_key retriever = FAISS.from_documents(chunked_docs, @@ -105,7 +104,7 @@ class RAGNode(BaseNode): compressed_docs = compression_retriever.get_relevant_documents( user_prompt) - print("---TOKENS COMPRESSED AND VECTOR STORED---") + print("--- (tokens compressed and vector stored) ---") state.update({self.output[0]: compressed_docs}) return state diff --git a/scrapegraphai/nodes/text_to_speech_node.py b/scrapegraphai/nodes/text_to_speech_node.py index 9ab79e9e..978548f4 100644 --- a/scrapegraphai/nodes/text_to_speech_node.py +++ b/scrapegraphai/nodes/text_to_speech_node.py @@ -17,12 +17,12 @@ class TextToSpeechNode(BaseNode): execute(state, text): Execute the node's logic and return the updated state. """ - def __init__(self, input: str, output: List[str], model_config: dict, node_name: str = "TextToSpeechNode"): + def __init__(self, input: str, output: List[str], model_config: dict, node_name: str = "TextToSpeech"): """ Initializes an instance of the TextToSpeechNode class. """ super().__init__(node_name, "node", input, output, 1, model_config) - self.text2speech_model = model_config["text2speech_model"] + self.tts_model = model_config["tts_model"] def execute(self, state): """ @@ -42,13 +42,11 @@ class TextToSpeechNode(BaseNode): # Fetching data from the state based on the input keys input_data = [state[key] for key in input_keys] - text2translate = input_data[0] - - # if not a string, raise an error - if not isinstance(text2translate, str): - raise ValueError("No text to translate to speech.") - print("---TRANSLATING TEXT TO SPEECH---") - audio = self.text2speech_model.run(text2translate["summary"]) + # get the text to translate + text2translate = str(next(iter(input_data[0].values()))) + # text2translate = str(input_data[0]) + + audio = self.tts_model.run(text2translate) state.update({self.output[0]: audio}) return state