Merge pull request #46 from VinciGit00/research_branch

Research branch
This commit is contained in:
Marco Vinciguerra 2024-04-06 14:45:29 +02:00 committed by GitHub
commit e17ba55057
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
16 changed files with 334 additions and 123 deletions

1
.gitignore vendored
View File

@ -26,5 +26,6 @@ venv/
*.pdf
*.mp3
*.sqlite
*.google-cookie
examples/graph_examples/ScrapeGraphAI_generated_graph
main.py

View File

@ -7,7 +7,6 @@ from dotenv import load_dotenv
from scrapegraphai.models import OpenAI
from scrapegraphai.graphs import BaseGraph
from scrapegraphai.nodes import FetchNode, ParseNode, RAGNode, GenerateAnswerNode
from scrapegraphai.utils import convert_to_csv, convert_to_json
load_dotenv()
openai_key = os.getenv("OPENAI_APIKEY")
@ -68,8 +67,4 @@ result = graph.execute({
# get the answer from the result
result = result.get("answer", "No answer found.")
print(result)
# Save to json and csv
convert_to_csv(result, "result")
convert_to_json(result, "result")
print(result)

View File

@ -0,0 +1,33 @@
"""
Example of Search Graph
"""
import os
from dotenv import load_dotenv
from scrapegraphai.graphs import SearchGraph
from scrapegraphai.utils import convert_to_csv, convert_to_json
load_dotenv()
openai_key = os.getenv("OPENAI_APIKEY")
# Define the configuration for the graph
graph_config = {
"llm": {
"api_key": openai_key,
"model": "gpt-3.5-turbo",
"temperature": 0,
},
}
# Create the SmartScraperGraph instance
smart_scraper_graph = SearchGraph(
prompt="List me all the regions of Italy.",
config=graph_config
)
result = smart_scraper_graph.run()
print(result)
# Save to json and csv
convert_to_csv(result, "result")
convert_to_json(result, "result")

View File

@ -5,7 +5,6 @@ Basic example of scraping pipeline using SmartScraper
import os
from dotenv import load_dotenv
from scrapegraphai.graphs import SmartScraperGraph
from scrapegraphai.utils import convert_to_csv, convert_to_json
load_dotenv()
openai_key = os.getenv("OPENAI_APIKEY")
@ -28,7 +27,3 @@ smart_scraper_graph = SmartScraperGraph(
result = smart_scraper_graph.run()
print(result)
# Save to json and csv
convert_to_csv(result, "result")
convert_to_json(result, "result")

View File

@ -5,7 +5,6 @@ Basic example of scraping pipeline using SpeechSummaryGraph
import os
from dotenv import load_dotenv
from scrapegraphai.graphs import SpeechGraph
from scrapegraphai.utils import convert_to_csv, convert_to_json
load_dotenv()
openai_key = os.getenv("OPENAI_APIKEY")
@ -37,7 +36,3 @@ speech_graph = SpeechGraph(
result = speech_graph.run()
print(result.get("answer", "No answer found"))
# Save to json and csv
convert_to_csv(result, "result")
convert_to_json(result, "result")

View File

@ -36,6 +36,7 @@ python-dotenv = "1.0.1"
tiktoken = {version = ">=0.5.2,<0.6.0"}
tqdm = "4.66.1"
graphviz = "0.20.1"
google = "3.0.0"
[tool.poetry.dev-dependencies]
pytest = "8.0.0"

View File

@ -4,3 +4,4 @@ __init__.py file for graphs folder
from .base_graph import BaseGraph
from .smart_scraper_graph import SmartScraperGraph
from .speech_graph import SpeechGraph
from .search_graph import SearchGraph

View File

@ -0,0 +1,43 @@
"""
Module having abstract class for creating all the graphs
"""
from abc import ABC, abstractmethod
from typing import Optional
class AbstractGraph(ABC):
"""
Abstract class representing a generic graph-based tool.
"""
def __init__(self, prompt: str, config: dict, file_source: Optional[str] = "url"):
"""
Initializes the AbstractGraph with a prompt, file source, and configuration.
"""
self.prompt = prompt
self.file_source = file_source
self.input_key = "url" if file_source.startswith(
"http") else "local_dir"
self.config = config
self.llm_model = self._create_llm(config["llm"])
self.graph = self._create_graph()
@abstractmethod
def _create_llm(self, llm_config: dict):
"""
Abstract method to create a language model instance.
"""
pass
@abstractmethod
def _create_graph(self):
"""
Abstract method to create a graph representation.
"""
pass
@abstractmethod
def run(self) -> str:
"""
Abstract method to execute the graph and return the result.
"""
pass

View File

@ -0,0 +1,91 @@
"""
Module for making the search on the intenet
"""
from ..models import OpenAI, Gemini
from .base_graph import BaseGraph
from ..nodes import (
SearchInternetNode,
FetchNode,
ParseNode,
RAGNode,
GenerateAnswerNode
)
from .abstract_graph import AbstractGraph
class SearchGraph(AbstractGraph):
"""
Module for searching info on the internet
"""
def _create_llm(self, llm_config: dict):
"""
Creates an instance of the language model (OpenAI or Gemini) based on configuration.
"""
llm_defaults = {
"temperature": 0,
"streaming": True
}
llm_params = {**llm_defaults, **llm_config}
if "api_key" not in llm_params:
raise ValueError("LLM configuration must include an 'api_key'.")
if "gpt-" in llm_params["model"]:
return OpenAI(llm_params)
elif "gemini" in llm_params["model"]:
return Gemini(llm_params)
else:
raise ValueError("Model not supported")
def _create_graph(self):
"""
Creates the graph of nodes representing the workflow for web scraping and searching.
"""
search_internet_node = SearchInternetNode(
input="user_prompt",
output=["url"],
model_config={"llm_model": self.llm_model}
)
fetch_node = FetchNode(
input="url | local_dir",
output=["doc"],
)
parse_node = ParseNode(
input="doc",
output=["parsed_doc"],
)
rag_node = RAGNode(
input="user_prompt & (parsed_doc | doc)",
output=["relevant_chunks"],
model_config={"llm_model": self.llm_model},
)
generate_answer_node = GenerateAnswerNode(
input="user_prompt & (relevant_chunks | parsed_doc | doc)",
output=["answer"],
model_config={"llm_model": self.llm_model},
)
return BaseGraph(
nodes={
search_internet_node,
fetch_node,
parse_node,
rag_node,
generate_answer_node,
},
edges={
(search_internet_node, fetch_node),
(fetch_node, parse_node),
(parse_node, rag_node),
(rag_node, generate_answer_node)
},
entry_point=search_internet_node
)
def run(self) -> str:
"""
Executes the web scraping and searching process.
"""
inputs = {"user_prompt": self.prompt}
final_state = self.graph.execute(inputs)
return final_state.get("answer", "No answer found.")

View File

@ -9,64 +9,26 @@ from ..nodes import (
RAGNode,
GenerateAnswerNode
)
from .abstract_graph import AbstractGraph
class SmartScraperGraph:
class SmartScraperGraph(AbstractGraph):
"""
SmartScraper is a comprehensive web scraping tool that automates the process of extracting
information from web pages using a natural language model to interpret and answer prompts.
Attributes:
prompt (str): The user's natural language prompt for the information to be extracted.
url (str): The URL of the web page to scrape.
llm_config (dict): Configuration parameters for the language model, with
'api_key' being mandatory.
llm (ChatOpenAI): An instance of the ChatOpenAI class configured with llm_config.
graph (BaseGraph): An instance of the BaseGraph class representing the scraping workflow.
Methods:
run(): Executes the web scraping process and returns the answer to the prompt.
Args:
prompt (str): The user's natural language prompt for the information to be extracted.
url (str): The URL of the web page to scrape.
llm_config (dict): A dictionary containing configuration options for the language model.
Must include 'api_key', may also specify 'model_name',
'temperature', and 'streaming'.
"""
def __init__(self, prompt: str, file_source: str, config: dict):
"""
Initializes the SmartScraper with a prompt, URL, and language model configuration.
"""
self.prompt = prompt
self.file_source = file_source
self.input_key = "url" if file_source.startswith(
"http") else "local_dir"
self.config = config
self.llm_model = self._create_llm(config["llm"])
self.graph = self._create_graph()
def _create_llm(self, llm_config: dict):
"""
Creates an instance of the ChatOpenAI class with the provided language model configuration.
Returns:
ChatOpenAI: An instance of the ChatOpenAI class.
Raises:
ValueError: If 'api_key' is not provided in llm_config.
Creates an instance of the language model (OpenAI or Gemini) based on configuration.
"""
llm_defaults = {
"temperature": 0,
"streaming": True
}
# Update defaults with any LLM parameters that were provided
llm_params = {**llm_defaults, **llm_config}
# Ensure the api_key is set, raise an error if it's not
if "api_key" not in llm_params:
raise ValueError("LLM configuration must include an 'api_key'.")
# select the model based on the model name
if "gpt-" in llm_params["model"]:
return OpenAI(llm_params)
elif "gemini" in llm_params["model"]:
@ -76,11 +38,7 @@ class SmartScraperGraph:
def _create_graph(self):
"""
Creates the graph of nodes representing the workflow for web scraping.
Returns:
BaseGraph: An instance of the BaseGraph class.
"""
# define the nodes for the graph
fetch_node = FetchNode(
input="url | local_dir",
output=["doc"],
@ -117,12 +75,8 @@ class SmartScraperGraph:
def run(self) -> str:
"""
Executes the scraping process by running the graph and returns the extracted information.
Returns:
str: The answer extracted from the web page, corresponding to the given prompt.
Executes the web scraping process and returns the answer to the prompt.
"""
inputs = {"user_prompt": self.prompt, self.input_key: self.file_source}
final_state = self.graph.execute(inputs)

View File

@ -1,5 +1,5 @@
"""
Module for extracting the summary from the speach
Module for converting text to speach
"""
from scrapegraphai.utils.save_audio_from_bytes import save_audio_from_bytes
from ..models import OpenAI, Gemini, OpenAITextToSpeech
@ -11,62 +11,26 @@ from ..nodes import (
GenerateAnswerNode,
TextToSpeechNode,
)
from .abstract_graph import AbstractGraph
class SpeechGraph:
class SpeechGraph(AbstractGraph):
"""
SpeechSummaryGraph is a tool that automates the process of extracting and summarizing
information from web pages, then converting that summary into spoken word via an MP3 file.
Attributes:
url (str): The URL of the web page to scrape and summarize.
llm_config (dict): Configuration parameters for the language model,
with 'api_key' mandatory.
summary_prompt (str): The prompt used to guide the summarization process.
output_path (Path): The path where the generated MP3 file will be saved.
Methods:
run(): Executes the web scraping, summarization, and text-to-speech process.
Args:
url (str): The URL of the web page to scrape and summarize.
llm_config (dict): A dictionary containing configuration options for the language model.
summary_prompt (str): The prompt used to guide the summarization process.
output_path (str): The file path where the generated MP3 should be saved.
"""
def __init__(self, prompt: str, file_source: str, config: dict):
"""
Initializes the SmartScraper with a prompt, URL, and language model configuration.
"""
self.prompt = prompt
self.file_source = file_source
self.input_key = "url" if "http" in file_source else "local_dir"
self.llm_model = self._create_llm(config["llm"])
self.output_path = config.get("output_path", "output.mp3")
self.text_to_speech_model = OpenAITextToSpeech(config["tts_model"])
self.graph = self._create_graph()
def _create_llm(self, llm_config: dict):
"""
Creates an instance of the ChatOpenAI class with the provided language model configuration.
Returns:
ChatOpenAI: An instance of the ChatOpenAI class.
Raises:
ValueError: If 'api_key' is not provided in llm_config.
Creates an instance of the language model (OpenAI or Gemini) based on configuration.
"""
llm_defaults = {
"temperature": 0,
"streaming": True
}
# Update defaults with any LLM parameters that were provided
llm_params = {**llm_defaults, **llm_config}
# Ensure the api_key is set, raise an error if it's not
if "api_key" not in llm_params:
raise ValueError("LLM configuration must include an 'api_key'.")
# select the model based on the model name
if "gpt-" in llm_params["model"]:
return OpenAI(llm_params)
elif "gemini" in llm_params["model"]:
@ -76,12 +40,8 @@ class SpeechGraph:
def _create_graph(self):
"""
Creates the graph of nodes representing the workflow for web scraping.
Returns:
BaseGraph: An instance of the BaseGraph class.
Creates the graph of nodes representing the workflow for web scraping and summarization.
"""
# define the nodes for the graph
fetch_node = FetchNode(
input="url | local_dir",
output=["doc"],
@ -103,7 +63,8 @@ class SpeechGraph:
text_to_speech_node = TextToSpeechNode(
input="answer",
output=["audio"],
model_config={"tts_model": self.text_to_speech_model},
model_config={"tts_model": OpenAITextToSpeech(
self.config["tts_model"])},
)
return BaseGraph(
@ -125,10 +86,7 @@ class SpeechGraph:
def run(self) -> str:
"""
Executes the scraping process by running the graph and returns the extracted information.
Returns:
str: The answer extracted from the web page, corresponding to the given prompt.
Executes the web scraping, summarization, and text-to-speech process.
"""
inputs = {"user_prompt": self.prompt, self.input_key: self.file_source}
final_state = self.graph.execute(inputs)
@ -136,7 +94,8 @@ class SpeechGraph:
audio = final_state.get("audio", None)
if not audio:
raise ValueError("No audio generated from the text.")
save_audio_from_bytes(audio, self.output_path)
print(f"Audio saved to {self.output_path}")
save_audio_from_bytes(audio, self.config.get(
"output_path", "output.mp3"))
print(f"Audio saved to {self.config.get('output_path', 'output.mp3')}")
return final_state

View File

@ -10,3 +10,4 @@ from .parse_node import ParseNode
from .rag_node import RAGNode
from .text_to_speech_node import TextToSpeechNode
from .image_to_text_node import ImageToTextNode
from .search_internet_node import SearchInternetNode

View File

@ -88,7 +88,8 @@ class BaseNode(ABC):
def _validate_input_keys(self, input_keys):
if len(input_keys) < self.min_input_len:
raise ValueError(
f"{self.node_name} requires at least {self.min_input_len} input keys, got {len(input_keys)}.")
f"""{self.node_name} requires at least {self.min_input_len} input keys,
got {len(input_keys)}.""")
def _parse_input_keys(self, state: dict, expression: str) -> List[str]:
"""

View File

@ -67,10 +67,11 @@ class FetchNode(BaseNode):
# Fetching data from the state based on the input keys
input_data = [state[key] for key in input_keys]
source = input_data[0]
if not source.startswith(
"http"):
source = input_data[0]
print(f"Fetching content from: {source}")
# if it is a local directory
if not source.startswith("http"):
document = [Document(page_content=source, metadata={
"source": "local_dir"
})]

View File

@ -0,0 +1,103 @@
"""
Module for generating the answer node
"""
from typing import List
from langchain.output_parsers import CommaSeparatedListOutputParser
from langchain.prompts import PromptTemplate
from ..utils.research_web import search_on_web
from .base_node import BaseNode
class SearchInternetNode(BaseNode):
"""
A node that generates an answer by querying a language model (LLM) based on the user's input
and the content extracted from a webpage. It constructs a prompt from the user's input
and the scraped content, feeds it to the LLM, and parses the LLM's response to produce
an answer.
Attributes:
node_name (str): The unique identifier name for the node.
node_type (str): The type of the node, set to "node" indicating a standard operational node.
input (str): The user input used to construct the prompt.
output (List[str]): The keys in the state dictionary
where the generated answer will be stored.
model_config (dict): Configuration parameters for the language model client.
Args:
input (str): The user input used to construct the prompt.
output (List[str]): The keys in the state dictionary where the
generated answer will be stored.
model_config (dict): Configuration parameters for the language model client.
node_name (str, optional): The unique identifier name for the node.
Defaults to "GenerateAnswer".
Methods:
execute(state): Processes the input and document from the state to generate an answer,
updating the state with the generated answer under the 'answer' key.
"""
def __init__(self, input: str, output: List[str], model_config: dict,
node_name: str = "SearchInternet"):
"""
Initializes the SearchInternetNode with input, output, model configuration, and a node name.
Args:
input (str): The user input used to construct the prompt.
output (List[str]): The keys in the state dictionary where the
generated answer will be stored.
model_config (dict): Configuration parameters for the language model client.
node_name (str): The unique identifier name for the node.
"""
super().__init__(node_name, "node", input, output, 1, model_config)
self.llm_model = model_config["llm_model"]
def execute(self, state):
"""
Generates an answer by constructing a prompt from the user's input and the scraped
content, querying the language model, and parsing its response.
The method updates the state with the generated answer under the 'answer' key.
Args:
state (dict): The current state of the graph, expected to contain 'user_input',
and optionally 'parsed_document' or 'relevant_chunks' within 'keys'.
Returns:
dict: The updated state with the 'answer' key containing the generated answer.
Raises:
KeyError: If 'user_input' or 'document' is not found in the state, indicating
that the necessary information for generating an answer is missing.
"""
print(f"--- Executing {self.node_name} Node ---")
input_keys = self.get_input_keys(state)
# Fetching data from the state based on the input keys
input_data = [state[key] for key in input_keys]
user_prompt = input_data[0]
output_parser = CommaSeparatedListOutputParser()
search_template = """Given the following user prompt, return a query that can be
used to search the internet for relevant information. \n
You should return only the query string. \n
User Prompt: {user_prompt}"""
search_prompt = PromptTemplate(
template=search_template,
input_variables=["user_prompt"],
)
# Execute the chain to get the search query
search_answer = search_prompt | self.llm_model | output_parser
search_query = search_answer.invoke({"user_prompt": user_prompt})[0]
print(f"Search Query: {search_query}")
# TODO: handle multiple URLs
answer = search_on_web(query=search_query, max_results=1)[0]
# Update the state with the generated answer
state.update({self.output[0]: answer})
return state

View File

@ -0,0 +1,37 @@
"""
Module for making the request on the web
"""
import re
from typing import List
from langchain_community.tools import DuckDuckGoSearchResults
from googlesearch import search
def search_on_web(query: str, search_engine: str = "Google", max_results: int = 10) -> List[str]:
"""
Function that given a query it finds it on the intenet
Args:
query (str): query to search on internet
search_engine (str, optional): type of browser, it could be DuckDuckGo or Google,
default: Google
max_results (int, optional): maximum number of results
Returns:
List[str]: List of strings of web link
"""
if search_engine == "Google":
res = []
for url in search(query, stop=max_results):
res.append(url)
return res
elif search_engine == "DuckDuckGo":
research = DuckDuckGoSearchResults(max_results=max_results)
res = research.run(query)
links = re.findall(r'https?://[^\s,\]]+', res)
return links
raise ValueError(
"The only search engines avaiable are DuckDuckGo or Google")