diff --git a/.gitignore b/.gitignore
index 26f73e8c..b09506ab 100644
--- a/.gitignore
+++ b/.gitignore
@@ -26,5 +26,6 @@ venv/
 *.pdf
 *.mp3
 *.sqlite
+*.google-cookie
 examples/graph_examples/ScrapeGraphAI_generated_graph
 main.py
diff --git a/examples/custom_graph_openai.py b/examples/custom_graph_openai.py
index 4f982ace..48fbc9d6 100644
--- a/examples/custom_graph_openai.py
+++ b/examples/custom_graph_openai.py
@@ -7,7 +7,6 @@ from dotenv import load_dotenv
 from scrapegraphai.models import OpenAI
 from scrapegraphai.graphs import BaseGraph
 from scrapegraphai.nodes import FetchNode, ParseNode, RAGNode, GenerateAnswerNode
-from scrapegraphai.utils import convert_to_csv, convert_to_json
 
 load_dotenv()
 openai_key = os.getenv("OPENAI_APIKEY")
@@ -68,8 +67,4 @@ result = graph.execute({
 
 # get the answer from the result
 result = result.get("answer", "No answer found.")
 print(result)
-
-# Save to json and csv
-convert_to_csv(result, "result")
-convert_to_json(result, "result")
diff --git a/examples/search_graph_example.py b/examples/search_graph_example.py
new file mode 100644
index 00000000..5d3a2270
--- /dev/null
+++ b/examples/search_graph_example.py
@@ -0,0 +1,33 @@
+"""
+Example of Search Graph
+"""
+
+import os
+from dotenv import load_dotenv
+from scrapegraphai.graphs import SearchGraph
+from scrapegraphai.utils import convert_to_csv, convert_to_json
+
+load_dotenv()
+openai_key = os.getenv("OPENAI_APIKEY")
+
+# Define the configuration for the graph
+graph_config = {
+    "llm": {
+        "api_key": openai_key,
+        "model": "gpt-3.5-turbo",
+        "temperature": 0,
+    },
+}
+
+# Create the SearchGraph instance
+search_graph = SearchGraph(
+    prompt="List me all the regions of Italy.",
+    config=graph_config
+)
+
+result = search_graph.run()
+print(result)
+
+# Save to json and csv
+convert_to_csv(result, "result")
+convert_to_json(result, "result")
diff --git a/examples/smart_scraper_example.py b/examples/smart_scraper_example.py
index cc9ba530..a8bc2d92 100644
--- a/examples/smart_scraper_example.py
+++ b/examples/smart_scraper_example.py
@@ -5,7 +5,6 @@ Basic example of scraping pipeline using SmartScraper
 import os
 from dotenv import load_dotenv
 from scrapegraphai.graphs import SmartScraperGraph
-from scrapegraphai.utils import convert_to_csv, convert_to_json
 
 load_dotenv()
 openai_key = os.getenv("OPENAI_APIKEY")
@@ -28,7 +27,3 @@ smart_scraper_graph = SmartScraperGraph(
 
 result = smart_scraper_graph.run()
 print(result)
-
-# Save to json and csv
-convert_to_csv(result, "result")
-convert_to_json(result, "result")
diff --git a/examples/speech_graph_example.py b/examples/speech_graph_example.py
index 6c2d9785..7afac269 100644
--- a/examples/speech_graph_example.py
+++ b/examples/speech_graph_example.py
@@ -5,7 +5,6 @@ Basic example of scraping pipeline using SpeechSummaryGraph
 import os
 from dotenv import load_dotenv
 from scrapegraphai.graphs import SpeechGraph
-from scrapegraphai.utils import convert_to_csv, convert_to_json
 
 load_dotenv()
 openai_key = os.getenv("OPENAI_APIKEY")
@@ -37,7 +36,3 @@ speech_graph = SpeechGraph(
 
 result = speech_graph.run()
 print(result.get("answer", "No answer found"))
-
-# Save to json and csv
-convert_to_csv(result, "result")
-convert_to_json(result, "result")
diff --git a/pyproject.toml b/pyproject.toml
index 401cdff3..a211f58c 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -36,6 +36,7 @@ python-dotenv = "1.0.1"
 tiktoken = {version = ">=0.5.2,<0.6.0"}
 tqdm = "4.66.1"
 graphviz = "0.20.1"
+google = "3.0.0"
 
 [tool.poetry.dev-dependencies]
 pytest = "8.0.0"
diff --git a/scrapegraphai/graphs/__init__.py b/scrapegraphai/graphs/__init__.py
index 67792b28..a7d2897b 100644
--- a/scrapegraphai/graphs/__init__.py
+++ b/scrapegraphai/graphs/__init__.py
@@ -4,3 +4,4 @@ __init__.py file for graphs folder
 from .base_graph import BaseGraph
 from .smart_scraper_graph import SmartScraperGraph
 from .speech_graph import SpeechGraph
+from .search_graph import SearchGraph
diff --git a/scrapegraphai/graphs/abstract_graph.py b/scrapegraphai/graphs/abstract_graph.py
new file mode 100644
index 00000000..31c0bfce
--- /dev/null
+++ b/scrapegraphai/graphs/abstract_graph.py
@@ -0,0 +1,51 @@
+"""
+Module having abstract class for creating all the graphs
+"""
+from abc import ABC, abstractmethod
+from typing import Optional
+
+
+class AbstractGraph(ABC):
+    """
+    Abstract class representing a generic graph-based tool.
+
+    The constructor stores the prompt, source and configuration, then builds the
+    language model and the graph through the hooks implemented by subclasses.
+    """
+
+    def __init__(self, prompt: str, config: dict, file_source: Optional[str] = "url"):
+        """
+        Initializes the AbstractGraph with a prompt, file source, and configuration.
+
+        Args:
+            prompt (str): The user's natural language prompt.
+            config (dict): Graph configuration; the "llm" entry is passed to _create_llm.
+            file_source (str, optional): Source to process. input_key is "url" when it
+                starts with "http", otherwise "local_dir".
+        """
+        self.prompt = prompt
+        self.file_source = file_source
+        # file_source is Optional, so guard against None before calling startswith
+        is_url = file_source is not None and file_source.startswith("http")
+        self.input_key = "url" if is_url else "local_dir"
+        self.config = config
+        self.llm_model = self._create_llm(config["llm"])
+        self.graph = self._create_graph()
+
+    @abstractmethod
+    def _create_llm(self, llm_config: dict):
+        """
+        Abstract method to create a language model instance.
+        """
+
+    @abstractmethod
+    def _create_graph(self):
+        """
+        Abstract method to create a graph representation.
+        """
+
+    @abstractmethod
+    def run(self) -> str:
+        """
+        Abstract method to execute the graph and return the result.
+        """
diff --git a/scrapegraphai/graphs/search_graph.py b/scrapegraphai/graphs/search_graph.py
new file mode 100644
index 00000000..c5f3a6a5
--- /dev/null
+++ b/scrapegraphai/graphs/search_graph.py
@@ -0,0 +1,91 @@
+"""
+Module for making the search on the internet
+"""
+from ..models import OpenAI, Gemini
+from .base_graph import BaseGraph
+from ..nodes import (
+    SearchInternetNode,
+    FetchNode,
+    ParseNode,
+    RAGNode,
+    GenerateAnswerNode
+)
+from .abstract_graph import AbstractGraph
+
+
+class SearchGraph(AbstractGraph):
+    """
+    Graph that searches the internet for a page, scrapes it and answers the prompt.
+    """
+
+    def _create_llm(self, llm_config: dict):
+        """
+        Creates an instance of the language model (OpenAI or Gemini) based on configuration.
+        """
+        llm_defaults = {
+            "temperature": 0,
+            "streaming": True
+        }
+        llm_params = {**llm_defaults, **llm_config}
+        if "api_key" not in llm_params:
+            raise ValueError("LLM configuration must include an 'api_key'.")
+        if "gpt-" in llm_params["model"]:
+            return OpenAI(llm_params)
+        elif "gemini" in llm_params["model"]:
+            return Gemini(llm_params)
+        else:
+            raise ValueError("Model not supported")
+
+    def _create_graph(self):
+        """
+        Creates the graph of nodes representing the workflow for web scraping and searching.
+        """
+        search_internet_node = SearchInternetNode(
+            input="user_prompt",
+            output=["url"],
+            model_config={"llm_model": self.llm_model}
+        )
+        fetch_node = FetchNode(
+            input="url | local_dir",
+            output=["doc"],
+        )
+        parse_node = ParseNode(
+            input="doc",
+            output=["parsed_doc"],
+        )
+        rag_node = RAGNode(
+            input="user_prompt & (parsed_doc | doc)",
+            output=["relevant_chunks"],
+            model_config={"llm_model": self.llm_model},
+        )
+        generate_answer_node = GenerateAnswerNode(
+            input="user_prompt & (relevant_chunks | parsed_doc | doc)",
+            output=["answer"],
+            model_config={"llm_model": self.llm_model},
+        )
+
+        return BaseGraph(
+            nodes={
+                search_internet_node,
+                fetch_node,
+                parse_node,
+                rag_node,
+                generate_answer_node,
+            },
+            edges={
+                (search_internet_node, fetch_node),
+                (fetch_node, parse_node),
+                (parse_node, rag_node),
+                (rag_node, generate_answer_node)
+            },
+            entry_point=search_internet_node
+        )
+
+    def run(self) -> str:
+        """
+        Executes the web scraping and searching process.
+        """
+        inputs = {"user_prompt": self.prompt}
+        final_state = self.graph.execute(inputs)
+
+        return final_state.get("answer", "No answer found.")
diff --git a/scrapegraphai/graphs/smart_scraper_graph.py b/scrapegraphai/graphs/smart_scraper_graph.py
index 16e42b81..4f946435 100644
--- a/scrapegraphai/graphs/smart_scraper_graph.py
+++ b/scrapegraphai/graphs/smart_scraper_graph.py
@@ -9,64 +9,26 @@ from ..nodes import (
     RAGNode,
     GenerateAnswerNode
 )
+from .abstract_graph import AbstractGraph
 
 
-class SmartScraperGraph:
+class SmartScraperGraph(AbstractGraph):
     """
     SmartScraper is a comprehensive web scraping tool that automates the process of extracting
     information from web pages using a natural language model to interpret and answer prompts.
-
-    Attributes:
-        prompt (str): The user's natural language prompt for the information to be extracted.
-        url (str): The URL of the web page to scrape.
-        llm_config (dict): Configuration parameters for the language model, with
-                           'api_key' being mandatory.
-        llm (ChatOpenAI): An instance of the ChatOpenAI class configured with llm_config.
-        graph (BaseGraph): An instance of the BaseGraph class representing the scraping workflow.
-
-    Methods:
-        run(): Executes the web scraping process and returns the answer to the prompt.
-
-    Args:
-        prompt (str): The user's natural language prompt for the information to be extracted.
-        url (str): The URL of the web page to scrape.
-        llm_config (dict): A dictionary containing configuration options for the language model.
-                           Must include 'api_key', may also specify 'model_name',
-                           'temperature', and 'streaming'.
     """
 
-    def __init__(self, prompt: str, file_source: str, config: dict):
-        """
-        Initializes the SmartScraper with a prompt, URL, and language model configuration.
-        """
-        self.prompt = prompt
-        self.file_source = file_source
-        self.input_key = "url" if file_source.startswith(
-            "http") else "local_dir"
-        self.config = config
-        self.llm_model = self._create_llm(config["llm"])
-        self.graph = self._create_graph()
-
     def _create_llm(self, llm_config: dict):
         """
-        Creates an instance of the ChatOpenAI class with the provided language model configuration.
-
-        Returns:
-            ChatOpenAI: An instance of the ChatOpenAI class.
-
-        Raises:
-            ValueError: If 'api_key' is not provided in llm_config.
+        Creates an instance of the language model (OpenAI or Gemini) based on configuration.
         """
         llm_defaults = {
             "temperature": 0,
             "streaming": True
         }
-        # Update defaults with any LLM parameters that were provided
         llm_params = {**llm_defaults, **llm_config}
-        # Ensure the api_key is set, raise an error if it's not
         if "api_key" not in llm_params:
             raise ValueError("LLM configuration must include an 'api_key'.")
-        # select the model based on the model name
         if "gpt-" in llm_params["model"]:
             return OpenAI(llm_params)
         elif "gemini" in llm_params["model"]:
@@ -76,11 +38,7 @@ class SmartScraperGraph:
     def _create_graph(self):
         """
         Creates the graph of nodes representing the workflow for web scraping.
-
-        Returns:
-            BaseGraph: An instance of the BaseGraph class.
         """
-        # define the nodes for the graph
         fetch_node = FetchNode(
             input="url | local_dir",
             output=["doc"],
@@ -117,12 +75,8 @@ class SmartScraperGraph:
 
     def run(self) -> str:
         """
-        Executes the scraping process by running the graph and returns the extracted information.
-
-        Returns:
-            str: The answer extracted from the web page, corresponding to the given prompt.
+        Executes the web scraping process and returns the answer to the prompt.
         """
-
         inputs = {"user_prompt": self.prompt, self.input_key: self.file_source}
         final_state = self.graph.execute(inputs)
 
diff --git a/scrapegraphai/graphs/speech_graph.py b/scrapegraphai/graphs/speech_graph.py
index 1e70e01d..da019698 100644
--- a/scrapegraphai/graphs/speech_graph.py
+++ b/scrapegraphai/graphs/speech_graph.py
@@ -1,5 +1,5 @@
 """
-Module for extracting the summary from the speach
+Module for converting text to speech
 """
 from scrapegraphai.utils.save_audio_from_bytes import save_audio_from_bytes
 from ..models import OpenAI, Gemini, OpenAITextToSpeech
@@ -11,62 +11,26 @@ from ..nodes import (
     GenerateAnswerNode,
     TextToSpeechNode,
 )
+from .abstract_graph import AbstractGraph
 
 
-class SpeechGraph:
+class SpeechGraph(AbstractGraph):
     """
     SpeechSummaryGraph is a tool that automates the process of extracting and summarizing
     information from web pages, then converting that summary into spoken word via an MP3 file.
-
-    Attributes:
-        url (str): The URL of the web page to scrape and summarize.
-        llm_config (dict): Configuration parameters for the language model,
-                           with 'api_key' mandatory.
-        summary_prompt (str): The prompt used to guide the summarization process.
-        output_path (Path): The path where the generated MP3 file will be saved.
-
-    Methods:
-        run(): Executes the web scraping, summarization, and text-to-speech process.
-
-    Args:
-        url (str): The URL of the web page to scrape and summarize.
-        llm_config (dict): A dictionary containing configuration options for the language model.
-        summary_prompt (str): The prompt used to guide the summarization process.
-        output_path (str): The file path where the generated MP3 should be saved.
     """
 
-    def __init__(self, prompt: str, file_source: str, config: dict):
-        """
-        Initializes the SmartScraper with a prompt, URL, and language model configuration.
-        """
-        self.prompt = prompt
-        self.file_source = file_source
-        self.input_key = "url" if "http" in file_source else "local_dir"
-        self.llm_model = self._create_llm(config["llm"])
-        self.output_path = config.get("output_path", "output.mp3")
-        self.text_to_speech_model = OpenAITextToSpeech(config["tts_model"])
-        self.graph = self._create_graph()
-
     def _create_llm(self, llm_config: dict):
         """
-        Creates an instance of the ChatOpenAI class with the provided language model configuration.
-
-        Returns:
-            ChatOpenAI: An instance of the ChatOpenAI class.
-
-        Raises:
-            ValueError: If 'api_key' is not provided in llm_config.
+        Creates an instance of the language model (OpenAI or Gemini) based on configuration.
         """
         llm_defaults = {
             "temperature": 0,
             "streaming": True
         }
-        # Update defaults with any LLM parameters that were provided
         llm_params = {**llm_defaults, **llm_config}
-        # Ensure the api_key is set, raise an error if it's not
         if "api_key" not in llm_params:
             raise ValueError("LLM configuration must include an 'api_key'.")
-        # select the model based on the model name
         if "gpt-" in llm_params["model"]:
             return OpenAI(llm_params)
         elif "gemini" in llm_params["model"]:
@@ -76,12 +40,8 @@ class SpeechGraph:
 
     def _create_graph(self):
         """
-        Creates the graph of nodes representing the workflow for web scraping.
-
-        Returns:
-            BaseGraph: An instance of the BaseGraph class.
+        Creates the graph of nodes representing the workflow for web scraping and summarization.
         """
-        # define the nodes for the graph
         fetch_node = FetchNode(
             input="url | local_dir",
             output=["doc"],
@@ -103,7 +63,8 @@ class SpeechGraph:
         text_to_speech_node = TextToSpeechNode(
             input="answer",
             output=["audio"],
-            model_config={"tts_model": self.text_to_speech_model},
+            model_config={"tts_model": OpenAITextToSpeech(
+                self.config["tts_model"])},
         )
 
         return BaseGraph(
@@ -125,10 +86,7 @@ class SpeechGraph:
 
     def run(self) -> str:
         """
-        Executes the scraping process by running the graph and returns the extracted information.
-
-        Returns:
-            str: The answer extracted from the web page, corresponding to the given prompt.
+        Executes the web scraping, summarization, and text-to-speech process.
         """
         inputs = {"user_prompt": self.prompt, self.input_key: self.file_source}
         final_state = self.graph.execute(inputs)
@@ -136,7 +94,8 @@ class SpeechGraph:
         audio = final_state.get("audio", None)
         if not audio:
             raise ValueError("No audio generated from the text.")
-        save_audio_from_bytes(audio, self.output_path)
-        print(f"Audio saved to {self.output_path}")
+        output_path = self.config.get("output_path", "output.mp3")
+        save_audio_from_bytes(audio, output_path)
+        print(f"Audio saved to {output_path}")
 
         return final_state
diff --git a/scrapegraphai/nodes/__init__.py b/scrapegraphai/nodes/__init__.py
index 71223add..e66aec9d 100644
--- a/scrapegraphai/nodes/__init__.py
+++ b/scrapegraphai/nodes/__init__.py
@@ -10,3 +10,4 @@ from .parse_node import ParseNode
 from .rag_node import RAGNode
 from .text_to_speech_node import TextToSpeechNode
 from .image_to_text_node import ImageToTextNode
+from .search_internet_node import SearchInternetNode
diff --git a/scrapegraphai/nodes/base_node.py b/scrapegraphai/nodes/base_node.py
index 6a85f2d3..781e35d7 100644
--- a/scrapegraphai/nodes/base_node.py
+++ b/scrapegraphai/nodes/base_node.py
@@ -88,7 +88,8 @@ class BaseNode(ABC):
     def _validate_input_keys(self, input_keys):
         if len(input_keys) < self.min_input_len:
             raise ValueError(
-                f"{self.node_name} requires at least {self.min_input_len} input keys, got {len(input_keys)}.")
+                f"{self.node_name} requires at least {self.min_input_len} input keys, "
+                f"got {len(input_keys)}.")
 
     def _parse_input_keys(self, state: dict, expression: str) -> List[str]:
         """
diff --git a/scrapegraphai/nodes/fetch_node.py b/scrapegraphai/nodes/fetch_node.py
index bb638caa..892fb551 100644
--- a/scrapegraphai/nodes/fetch_node.py
+++ b/scrapegraphai/nodes/fetch_node.py
@@ -67,10 +67,11 @@ class FetchNode(BaseNode):
         # Fetching data from the state based on the input keys
         input_data = [state[key] for key in input_keys]
 
-        source = input_data[0]
-
-        if not source.startswith(
-            "http"):
+        source = input_data[0]
+
+        print(f"Fetching content from: {source}")
+        # if it is a local directory
+        if not source.startswith("http"):
             document = [Document(page_content=source, metadata={
                 "source": "local_dir"
             })]
diff --git a/scrapegraphai/nodes/search_internet_node.py b/scrapegraphai/nodes/search_internet_node.py
new file mode 100644
index 00000000..88297c1d
--- /dev/null
+++ b/scrapegraphai/nodes/search_internet_node.py
@@ -0,0 +1,91 @@
+"""
+Module for the node that searches the internet
+"""
+from typing import List
+from langchain.output_parsers import CommaSeparatedListOutputParser
+from langchain.prompts import PromptTemplate
+from ..utils.research_web import search_on_web
+from .base_node import BaseNode
+
+
+class SearchInternetNode(BaseNode):
+    """
+    A node that asks a language model (LLM) to turn the user's prompt into a search
+    query, runs that query through search_on_web, and stores the first link returned
+    in the state.
+
+    Attributes:
+        llm_model: The language model taken from model_config["llm_model"].
+
+    Args:
+        input (str): Expression selecting the state key that holds the user prompt.
+        output (List[str]): State keys to write; the link is stored under output[0].
+        model_config (dict): Configuration that must contain the "llm_model" entry.
+        node_name (str, optional): The unique identifier name for the node.
+            Defaults to "SearchInternet".
+    """
+
+    def __init__(self, input: str, output: List[str], model_config: dict,
+                 node_name: str = "SearchInternet"):
+        """
+        Initializes the SearchInternetNode with input, output, model configuration, and a node name.
+        """
+        super().__init__(node_name, "node", input, output, 1, model_config)
+        self.llm_model = model_config["llm_model"]
+
+    def execute(self, state):
+        """
+        Builds a search query from the user's prompt with the language model, searches
+        the web with it, and stores the first link found in the state.
+
+        Args:
+            state (dict): The current state of the graph; it must contain the key
+                selected by this node's input expression.
+
+        Returns:
+            dict: The updated state, with the link stored under output[0].
+
+        Raises:
+            ValueError: If the model returns no query or the search returns no link.
+        """
+
+        print(f"--- Executing {self.node_name} Node ---")
+
+        input_keys = self.get_input_keys(state)
+
+        # Fetching data from the state based on the input keys
+        input_data = [state[key] for key in input_keys]
+
+        user_prompt = input_data[0]
+
+        output_parser = CommaSeparatedListOutputParser()
+
+        search_template = """Given the following user prompt, return a query that can be
+        used to search the internet for relevant information. \n
+        You should return only the query string. \n
+        User Prompt: {user_prompt}"""
+
+        search_prompt = PromptTemplate(
+            template=search_template,
+            input_variables=["user_prompt"],
+        )
+
+        # Execute the chain to get the search query
+        search_answer = search_prompt | self.llm_model | output_parser
+        queries = search_answer.invoke({"user_prompt": user_prompt})
+        # Indexing an empty list would raise a bare IndexError; fail with context instead
+        if not queries:
+            raise ValueError("The language model did not return a search query.")
+        search_query = queries[0]
+
+        print(f"Search Query: {search_query}")
+        # TODO: handle multiple URLs
+        links = search_on_web(query=search_query, max_results=1)
+        if not links:
+            raise ValueError(
+                f"No search results found for query: {search_query}")
+        answer = links[0]
+
+        # Update the state with the link found
+        state.update({self.output[0]: answer})
+        return state
diff --git a/scrapegraphai/utils/research_web.py b/scrapegraphai/utils/research_web.py
new file mode 100644
index 00000000..8f48adcd
--- /dev/null
+++ b/scrapegraphai/utils/research_web.py
@@ -0,0 +1,39 @@
+"""
+Module for making the request on the web
+"""
+import re
+from typing import List
+from langchain_community.tools import DuckDuckGoSearchResults
+from googlesearch import search
+
+
+def search_on_web(query: str, search_engine: str = "Google", max_results: int = 10) -> List[str]:
+    """
+    Searches the web for the given query and returns the links found.
+
+    Args:
+        query (str): query to search on internet
+        search_engine (str, optional): search engine to use, either "DuckDuckGo" or
+            "Google" (case-insensitive), default: Google
+        max_results (int, optional): maximum number of results
+
+    Returns:
+        List[str]: List of strings of web link
+
+    Raises:
+        ValueError: If search_engine is neither "Google" nor "DuckDuckGo".
+    """
+    engine = search_engine.lower()
+
+    if engine == "google":
+        return list(search(query, stop=max_results))
+
+    if engine == "duckduckgo":
+        research = DuckDuckGoSearchResults(max_results=max_results)
+        res = research.run(query)
+
+        # run() gives back text, so the links are pulled out with a regex
+        return re.findall(r'https?://[^\s,\]]+', res)
+
+    raise ValueError(
+        "The only search engines available are DuckDuckGo or Google")