mirror of
https://github.com/VinciGit00/Scrapegraph-ai.git
synced 2026-06-25 21:11:11 +08:00
commit
e17ba55057
1
.gitignore
vendored
1
.gitignore
vendored
@ -26,5 +26,6 @@ venv/
|
||||
*.pdf
|
||||
*.mp3
|
||||
*.sqlite
|
||||
*.google-cookie
|
||||
examples/graph_examples/ScrapeGraphAI_generated_graph
|
||||
main.py
|
||||
|
||||
@ -7,7 +7,6 @@ from dotenv import load_dotenv
|
||||
from scrapegraphai.models import OpenAI
|
||||
from scrapegraphai.graphs import BaseGraph
|
||||
from scrapegraphai.nodes import FetchNode, ParseNode, RAGNode, GenerateAnswerNode
|
||||
from scrapegraphai.utils import convert_to_csv, convert_to_json
|
||||
|
||||
load_dotenv()
|
||||
openai_key = os.getenv("OPENAI_APIKEY")
|
||||
@ -68,8 +67,4 @@ result = graph.execute({
|
||||
|
||||
# get the answer from the result
|
||||
result = result.get("answer", "No answer found.")
|
||||
print(result)
|
||||
|
||||
# Save to json and csv
|
||||
convert_to_csv(result, "result")
|
||||
convert_to_json(result, "result")
|
||||
print(result)
|
||||
33
examples/search_graph_example.py
Normal file
33
examples/search_graph_example.py
Normal file
@ -0,0 +1,33 @@
|
||||
"""
Example of Search Graph: answer a prompt from a page found via a web search.
"""

import os
from dotenv import load_dotenv
from scrapegraphai.graphs import SearchGraph
from scrapegraphai.utils import convert_to_csv, convert_to_json

# Read OPENAI_APIKEY from a local .env file, if one is present
load_dotenv()
openai_key = os.getenv("OPENAI_APIKEY")

# Define the configuration for the graph
graph_config = {
    "llm": {
        "api_key": openai_key,
        "model": "gpt-3.5-turbo",
        "temperature": 0,
    },
}

# Create the SearchGraph instance (no source URL is given: only a prompt)
search_graph = SearchGraph(
    prompt="List me all the regions of Italy.",
    config=graph_config
)

result = search_graph.run()
print(result)

# Save to json and csv
convert_to_csv(result, "result")
convert_to_json(result, "result")
|
||||
@ -5,7 +5,6 @@ Basic example of scraping pipeline using SmartScraper
|
||||
import os
|
||||
from dotenv import load_dotenv
|
||||
from scrapegraphai.graphs import SmartScraperGraph
|
||||
from scrapegraphai.utils import convert_to_csv, convert_to_json
|
||||
|
||||
load_dotenv()
|
||||
openai_key = os.getenv("OPENAI_APIKEY")
|
||||
@ -28,7 +27,3 @@ smart_scraper_graph = SmartScraperGraph(
|
||||
|
||||
result = smart_scraper_graph.run()
|
||||
print(result)
|
||||
|
||||
# Save to json and csv
|
||||
convert_to_csv(result, "result")
|
||||
convert_to_json(result, "result")
|
||||
|
||||
@ -5,7 +5,6 @@ Basic example of scraping pipeline using SpeechSummaryGraph
|
||||
import os
|
||||
from dotenv import load_dotenv
|
||||
from scrapegraphai.graphs import SpeechGraph
|
||||
from scrapegraphai.utils import convert_to_csv, convert_to_json
|
||||
|
||||
load_dotenv()
|
||||
openai_key = os.getenv("OPENAI_APIKEY")
|
||||
@ -37,7 +36,3 @@ speech_graph = SpeechGraph(
|
||||
|
||||
result = speech_graph.run()
|
||||
print(result.get("answer", "No answer found"))
|
||||
|
||||
# Save to json and csv
|
||||
convert_to_csv(result, "result")
|
||||
convert_to_json(result, "result")
|
||||
|
||||
@ -36,6 +36,7 @@ python-dotenv = "1.0.1"
|
||||
tiktoken = {version = ">=0.5.2,<0.6.0"}
|
||||
tqdm = "4.66.1"
|
||||
graphviz = "0.20.1"
|
||||
google = "3.0.0"
|
||||
|
||||
[tool.poetry.dev-dependencies]
|
||||
pytest = "8.0.0"
|
||||
|
||||
@ -4,3 +4,4 @@ __init__.py file for graphs folder
|
||||
from .base_graph import BaseGraph
|
||||
from .smart_scraper_graph import SmartScraperGraph
|
||||
from .speech_graph import SpeechGraph
|
||||
from .search_graph import SearchGraph
|
||||
|
||||
43
scrapegraphai/graphs/abstract_graph.py
Normal file
43
scrapegraphai/graphs/abstract_graph.py
Normal file
@ -0,0 +1,43 @@
|
||||
"""
|
||||
Module having abstract class for creating all the graphs
|
||||
"""
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
class AbstractGraph(ABC):
    """
    Abstract class representing a generic graph-based tool.

    Subclasses must implement `_create_llm`, `_create_graph` and `run`.
    Both `_create_llm` and `_create_graph` are invoked from the constructor,
    in that order, so `self.llm_model` is available inside `_create_graph`.

    Attributes:
        prompt (str): The prompt given at construction time.
        file_source (Optional[str]): The source given at construction time.
        input_key (str): "url" when `file_source` starts with "http",
            otherwise "local_dir".
        config (dict): The configuration dictionary given at construction time.
        llm_model: Whatever `_create_llm(config["llm"])` returned.
        graph: Whatever `_create_graph()` returned.
    """

    def __init__(self, prompt: str, config: dict, file_source: Optional[str] = "url"):
        """
        Initializes the AbstractGraph with a prompt, file source, and configuration.

        Args:
            prompt (str): The prompt to store on the instance.
            config (dict): Configuration; must contain an "llm" entry, which is
                passed to `_create_llm`.
            file_source (Optional[str]): Source of the content. Values starting
                with "http" select the "url" input key; anything else,
                including None, selects "local_dir".

        Raises:
            KeyError: If `config` has no "llm" entry.
        """
        self.prompt = prompt
        self.file_source = file_source
        # file_source is declared Optional, so guard against None before
        # calling a str method on it.
        is_remote = file_source is not None and file_source.startswith("http")
        self.input_key = "url" if is_remote else "local_dir"
        self.config = config
        self.llm_model = self._create_llm(config["llm"])
        self.graph = self._create_graph()

    @abstractmethod
    def _create_llm(self, llm_config: dict):
        """
        Abstract method to create a language model instance.

        Args:
            llm_config (dict): The "llm" section of the configuration.
        """

    @abstractmethod
    def _create_graph(self):
        """
        Abstract method to create a graph representation.
        """

    @abstractmethod
    def run(self) -> str:
        """
        Abstract method to execute the graph and return the result.
        """
|
||||
91
scrapegraphai/graphs/search_graph.py
Normal file
91
scrapegraphai/graphs/search_graph.py
Normal file
@ -0,0 +1,91 @@
|
||||
"""
|
||||
Module for making the search on the internet
|
||||
"""
|
||||
from ..models import OpenAI, Gemini
|
||||
from .base_graph import BaseGraph
|
||||
from ..nodes import (
|
||||
SearchInternetNode,
|
||||
FetchNode,
|
||||
ParseNode,
|
||||
RAGNode,
|
||||
GenerateAnswerNode
|
||||
)
|
||||
from .abstract_graph import AbstractGraph
|
||||
|
||||
|
||||
class SearchGraph(AbstractGraph):
    """
    Graph that searches the internet for a page relevant to the prompt and
    then runs the fetch / parse / RAG / answer pipeline on it.
    """

    def _create_llm(self, llm_config: dict):
        """
        Build the language model (OpenAI or Gemini) selected by the configuration.

        Args:
            llm_config (dict): LLM settings; merged over the defaults
                `temperature=0` and `streaming=True`.

        Raises:
            ValueError: If 'api_key' is missing or the model name contains
                neither "gpt-" nor "gemini".
        """
        params = {**{"temperature": 0, "streaming": True}, **llm_config}

        if "api_key" not in params:
            raise ValueError("LLM configuration must include an 'api_key'.")

        model_name = params["model"]
        if "gpt-" in model_name:
            return OpenAI(params)
        if "gemini" in model_name:
            return Gemini(params)
        raise ValueError("Model not supported")

    def _create_graph(self):
        """
        Assemble the node pipeline: search -> fetch -> parse -> RAG -> answer.
        """
        llm_settings = {"llm_model": self.llm_model}

        # Nodes listed in execution order; edges are derived from this order.
        pipeline = [
            SearchInternetNode(
                input="user_prompt",
                output=["url"],
                model_config={"llm_model": self.llm_model}
            ),
            FetchNode(
                input="url | local_dir",
                output=["doc"],
            ),
            ParseNode(
                input="doc",
                output=["parsed_doc"],
            ),
            RAGNode(
                input="user_prompt & (parsed_doc | doc)",
                output=["relevant_chunks"],
                model_config=dict(llm_settings),
            ),
            GenerateAnswerNode(
                input="user_prompt & (relevant_chunks | parsed_doc | doc)",
                output=["answer"],
                model_config=dict(llm_settings),
            ),
        ]

        return BaseGraph(
            nodes=set(pipeline),
            # Each node is connected to the one that follows it.
            edges=set(zip(pipeline, pipeline[1:])),
            entry_point=pipeline[0]
        )

    def run(self) -> str:
        """
        Execute the graph with the stored prompt and return its answer.

        Returns:
            The "answer" entry of the final state, or "No answer found."
            when the state has no such entry.
        """
        final_state = self.graph.execute({"user_prompt": self.prompt})
        return final_state.get("answer", "No answer found.")
|
||||
@ -9,64 +9,26 @@ from ..nodes import (
|
||||
RAGNode,
|
||||
GenerateAnswerNode
|
||||
)
|
||||
from .abstract_graph import AbstractGraph
|
||||
|
||||
|
||||
class SmartScraperGraph:
|
||||
class SmartScraperGraph(AbstractGraph):
|
||||
"""
|
||||
SmartScraper is a comprehensive web scraping tool that automates the process of extracting
|
||||
information from web pages using a natural language model to interpret and answer prompts.
|
||||
|
||||
Attributes:
|
||||
prompt (str): The user's natural language prompt for the information to be extracted.
|
||||
url (str): The URL of the web page to scrape.
|
||||
llm_config (dict): Configuration parameters for the language model, with
|
||||
'api_key' being mandatory.
|
||||
llm (ChatOpenAI): An instance of the ChatOpenAI class configured with llm_config.
|
||||
graph (BaseGraph): An instance of the BaseGraph class representing the scraping workflow.
|
||||
|
||||
Methods:
|
||||
run(): Executes the web scraping process and returns the answer to the prompt.
|
||||
|
||||
Args:
|
||||
prompt (str): The user's natural language prompt for the information to be extracted.
|
||||
url (str): The URL of the web page to scrape.
|
||||
llm_config (dict): A dictionary containing configuration options for the language model.
|
||||
Must include 'api_key', may also specify 'model_name',
|
||||
'temperature', and 'streaming'.
|
||||
"""
|
||||
|
||||
def __init__(self, prompt: str, file_source: str, config: dict):
|
||||
"""
|
||||
Initializes the SmartScraper with a prompt, URL, and language model configuration.
|
||||
"""
|
||||
self.prompt = prompt
|
||||
self.file_source = file_source
|
||||
self.input_key = "url" if file_source.startswith(
|
||||
"http") else "local_dir"
|
||||
self.config = config
|
||||
self.llm_model = self._create_llm(config["llm"])
|
||||
self.graph = self._create_graph()
|
||||
|
||||
def _create_llm(self, llm_config: dict):
|
||||
"""
|
||||
Creates an instance of the ChatOpenAI class with the provided language model configuration.
|
||||
|
||||
Returns:
|
||||
ChatOpenAI: An instance of the ChatOpenAI class.
|
||||
|
||||
Raises:
|
||||
ValueError: If 'api_key' is not provided in llm_config.
|
||||
Creates an instance of the language model (OpenAI or Gemini) based on configuration.
|
||||
"""
|
||||
llm_defaults = {
|
||||
"temperature": 0,
|
||||
"streaming": True
|
||||
}
|
||||
# Update defaults with any LLM parameters that were provided
|
||||
llm_params = {**llm_defaults, **llm_config}
|
||||
# Ensure the api_key is set, raise an error if it's not
|
||||
if "api_key" not in llm_params:
|
||||
raise ValueError("LLM configuration must include an 'api_key'.")
|
||||
# select the model based on the model name
|
||||
if "gpt-" in llm_params["model"]:
|
||||
return OpenAI(llm_params)
|
||||
elif "gemini" in llm_params["model"]:
|
||||
@ -76,11 +38,7 @@ class SmartScraperGraph:
|
||||
def _create_graph(self):
|
||||
"""
|
||||
Creates the graph of nodes representing the workflow for web scraping.
|
||||
|
||||
Returns:
|
||||
BaseGraph: An instance of the BaseGraph class.
|
||||
"""
|
||||
# define the nodes for the graph
|
||||
fetch_node = FetchNode(
|
||||
input="url | local_dir",
|
||||
output=["doc"],
|
||||
@ -117,12 +75,8 @@ class SmartScraperGraph:
|
||||
|
||||
def run(self) -> str:
|
||||
"""
|
||||
Executes the scraping process by running the graph and returns the extracted information.
|
||||
|
||||
Returns:
|
||||
str: The answer extracted from the web page, corresponding to the given prompt.
|
||||
Executes the web scraping process and returns the answer to the prompt.
|
||||
"""
|
||||
|
||||
inputs = {"user_prompt": self.prompt, self.input_key: self.file_source}
|
||||
final_state = self.graph.execute(inputs)
|
||||
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
"""
|
||||
Module for extracting the summary from the speach
|
||||
Module for converting text to speech
|
||||
"""
|
||||
from scrapegraphai.utils.save_audio_from_bytes import save_audio_from_bytes
|
||||
from ..models import OpenAI, Gemini, OpenAITextToSpeech
|
||||
@ -11,62 +11,26 @@ from ..nodes import (
|
||||
GenerateAnswerNode,
|
||||
TextToSpeechNode,
|
||||
)
|
||||
from .abstract_graph import AbstractGraph
|
||||
|
||||
|
||||
class SpeechGraph:
|
||||
class SpeechGraph(AbstractGraph):
|
||||
"""
|
||||
SpeechSummaryGraph is a tool that automates the process of extracting and summarizing
|
||||
information from web pages, then converting that summary into spoken word via an MP3 file.
|
||||
|
||||
Attributes:
|
||||
url (str): The URL of the web page to scrape and summarize.
|
||||
llm_config (dict): Configuration parameters for the language model,
|
||||
with 'api_key' mandatory.
|
||||
summary_prompt (str): The prompt used to guide the summarization process.
|
||||
output_path (Path): The path where the generated MP3 file will be saved.
|
||||
|
||||
Methods:
|
||||
run(): Executes the web scraping, summarization, and text-to-speech process.
|
||||
|
||||
Args:
|
||||
url (str): The URL of the web page to scrape and summarize.
|
||||
llm_config (dict): A dictionary containing configuration options for the language model.
|
||||
summary_prompt (str): The prompt used to guide the summarization process.
|
||||
output_path (str): The file path where the generated MP3 should be saved.
|
||||
"""
|
||||
|
||||
def __init__(self, prompt: str, file_source: str, config: dict):
|
||||
"""
|
||||
Initializes the SmartScraper with a prompt, URL, and language model configuration.
|
||||
"""
|
||||
self.prompt = prompt
|
||||
self.file_source = file_source
|
||||
self.input_key = "url" if "http" in file_source else "local_dir"
|
||||
self.llm_model = self._create_llm(config["llm"])
|
||||
self.output_path = config.get("output_path", "output.mp3")
|
||||
self.text_to_speech_model = OpenAITextToSpeech(config["tts_model"])
|
||||
self.graph = self._create_graph()
|
||||
|
||||
def _create_llm(self, llm_config: dict):
|
||||
"""
|
||||
Creates an instance of the ChatOpenAI class with the provided language model configuration.
|
||||
|
||||
Returns:
|
||||
ChatOpenAI: An instance of the ChatOpenAI class.
|
||||
|
||||
Raises:
|
||||
ValueError: If 'api_key' is not provided in llm_config.
|
||||
Creates an instance of the language model (OpenAI or Gemini) based on configuration.
|
||||
"""
|
||||
llm_defaults = {
|
||||
"temperature": 0,
|
||||
"streaming": True
|
||||
}
|
||||
# Update defaults with any LLM parameters that were provided
|
||||
llm_params = {**llm_defaults, **llm_config}
|
||||
# Ensure the api_key is set, raise an error if it's not
|
||||
if "api_key" not in llm_params:
|
||||
raise ValueError("LLM configuration must include an 'api_key'.")
|
||||
# select the model based on the model name
|
||||
if "gpt-" in llm_params["model"]:
|
||||
return OpenAI(llm_params)
|
||||
elif "gemini" in llm_params["model"]:
|
||||
@ -76,12 +40,8 @@ class SpeechGraph:
|
||||
|
||||
def _create_graph(self):
|
||||
"""
|
||||
Creates the graph of nodes representing the workflow for web scraping.
|
||||
|
||||
Returns:
|
||||
BaseGraph: An instance of the BaseGraph class.
|
||||
Creates the graph of nodes representing the workflow for web scraping and summarization.
|
||||
"""
|
||||
# define the nodes for the graph
|
||||
fetch_node = FetchNode(
|
||||
input="url | local_dir",
|
||||
output=["doc"],
|
||||
@ -103,7 +63,8 @@ class SpeechGraph:
|
||||
text_to_speech_node = TextToSpeechNode(
|
||||
input="answer",
|
||||
output=["audio"],
|
||||
model_config={"tts_model": self.text_to_speech_model},
|
||||
model_config={"tts_model": OpenAITextToSpeech(
|
||||
self.config["tts_model"])},
|
||||
)
|
||||
|
||||
return BaseGraph(
|
||||
@ -125,10 +86,7 @@ class SpeechGraph:
|
||||
|
||||
def run(self) -> str:
|
||||
"""
|
||||
Executes the scraping process by running the graph and returns the extracted information.
|
||||
|
||||
Returns:
|
||||
str: The answer extracted from the web page, corresponding to the given prompt.
|
||||
Executes the web scraping, summarization, and text-to-speech process.
|
||||
"""
|
||||
inputs = {"user_prompt": self.prompt, self.input_key: self.file_source}
|
||||
final_state = self.graph.execute(inputs)
|
||||
@ -136,7 +94,8 @@ class SpeechGraph:
|
||||
audio = final_state.get("audio", None)
|
||||
if not audio:
|
||||
raise ValueError("No audio generated from the text.")
|
||||
save_audio_from_bytes(audio, self.output_path)
|
||||
print(f"Audio saved to {self.output_path}")
|
||||
save_audio_from_bytes(audio, self.config.get(
|
||||
"output_path", "output.mp3"))
|
||||
print(f"Audio saved to {self.config.get('output_path', 'output.mp3')}")
|
||||
|
||||
return final_state
|
||||
|
||||
@ -10,3 +10,4 @@ from .parse_node import ParseNode
|
||||
from .rag_node import RAGNode
|
||||
from .text_to_speech_node import TextToSpeechNode
|
||||
from .image_to_text_node import ImageToTextNode
|
||||
from .search_internet_node import SearchInternetNode
|
||||
|
||||
@ -88,7 +88,8 @@ class BaseNode(ABC):
|
||||
def _validate_input_keys(self, input_keys):
|
||||
if len(input_keys) < self.min_input_len:
|
||||
raise ValueError(
|
||||
f"{self.node_name} requires at least {self.min_input_len} input keys, got {len(input_keys)}.")
|
||||
f"""{self.node_name} requires at least {self.min_input_len} input keys,
|
||||
got {len(input_keys)}.""")
|
||||
|
||||
def _parse_input_keys(self, state: dict, expression: str) -> List[str]:
|
||||
"""
|
||||
|
||||
@ -67,10 +67,11 @@ class FetchNode(BaseNode):
|
||||
# Fetching data from the state based on the input keys
|
||||
input_data = [state[key] for key in input_keys]
|
||||
|
||||
source = input_data[0]
|
||||
|
||||
if not source.startswith(
|
||||
"http"):
|
||||
source = input_data[0]
|
||||
|
||||
print(f"Fetching content from: {source}")
|
||||
# if it is a local directory
|
||||
if not source.startswith("http"):
|
||||
document = [Document(page_content=source, metadata={
|
||||
"source": "local_dir"
|
||||
})]
|
||||
|
||||
103
scrapegraphai/nodes/search_internet_node.py
Normal file
103
scrapegraphai/nodes/search_internet_node.py
Normal file
@ -0,0 +1,103 @@
|
||||
"""
|
||||
Module for the node that searches the internet
|
||||
"""
|
||||
from typing import List
|
||||
from langchain.output_parsers import CommaSeparatedListOutputParser
|
||||
from langchain.prompts import PromptTemplate
|
||||
from ..utils.research_web import search_on_web
|
||||
from .base_node import BaseNode
|
||||
|
||||
|
||||
class SearchInternetNode(BaseNode):
    """
    A node that asks the language model (LLM) to turn the user's prompt into a
    web search query, runs that query through `search_on_web`, and stores the
    first result in the state.

    Attributes:
        llm_model: The language model taken from `model_config["llm_model"]`.

    Args:
        input (str): Expression naming the state key(s) that hold the user prompt.
        output (List[str]): State keys to write to; the search result is stored
            under the first one.
        model_config (dict): Must contain an "llm_model" entry.
        node_name (str, optional): The unique identifier name for the node.
            Defaults to "SearchInternet".

    Methods:
        execute(state): Builds the search query, performs the search and
            updates the state with the first result.
    """

    def __init__(self, input: str, output: List[str], model_config: dict,
                 node_name: str = "SearchInternet"):
        """
        Initializes the SearchInternetNode with input, output, model configuration,
        and a node name.

        Args:
            input (str): Expression naming the state key(s) that hold the user prompt.
            output (List[str]): State keys to write to; the search result is
                stored under the first one.
            model_config (dict): Must contain an "llm_model" entry.
            node_name (str): The unique identifier name for the node.
        """
        super().__init__(node_name, "node", input, output, 1, model_config)
        self.llm_model = model_config["llm_model"]

    def execute(self, state):
        """
        Generates a search query from the user's prompt with the language model,
        searches the web with it, and stores the first result in the state.

        Args:
            state (dict): The current state of the graph; must contain the
                key(s) selected by this node's input expression.

        Returns:
            dict: The same state, updated with the first search result under
                `self.output[0]`.

        Raises:
            KeyError: If a selected input key is not present in the state.
            ValueError: If the web search returns no results.
        """

        print(f"--- Executing {self.node_name} Node ---")

        input_keys = self.get_input_keys(state)

        # Fetching data from the state based on the input keys
        input_data = [state[key] for key in input_keys]

        user_prompt = input_data[0]

        output_parser = CommaSeparatedListOutputParser()

        search_template = """Given the following user prompt, return a query that can be
        used to search the internet for relevant information. \n
        You should return only the query string. \n
        User Prompt: {user_prompt}"""

        search_prompt = PromptTemplate(
            template=search_template,
            input_variables=["user_prompt"],
        )

        # Execute the chain to get the search query.
        # NOTE(review): only the first item of the comma-separated parse is
        # used, so a query containing commas is presumably truncated - confirm.
        search_answer = search_prompt | self.llm_model | output_parser
        search_query = search_answer.invoke({"user_prompt": user_prompt})[0]

        print(f"Search Query: {search_query}")
        # TODO: handle multiple URLs
        results = search_on_web(query=search_query, max_results=1)
        if not results:
            # Fail with a clear message instead of a bare IndexError.
            raise ValueError(
                f"No search results found for query: {search_query}")
        answer = results[0]

        # Update the state with the first search result
        state.update({self.output[0]: answer})
        return state
|
||||
37
scrapegraphai/utils/research_web.py
Normal file
37
scrapegraphai/utils/research_web.py
Normal file
@ -0,0 +1,37 @@
|
||||
"""
|
||||
Module for making the request on the web
|
||||
"""
|
||||
import re
|
||||
from typing import List
|
||||
from langchain_community.tools import DuckDuckGoSearchResults
|
||||
from googlesearch import search
|
||||
|
||||
|
||||
def search_on_web(query: str, search_engine: str = "Google", max_results: int = 10) -> List[str]:
    """
    Search the internet for a query and return the links that were found.

    Args:
        query (str): query to search on the internet
        search_engine (str, optional): search engine to use, either "Google"
            or "DuckDuckGo" (matched case-insensitively). Default: "Google"
        max_results (int, optional): maximum number of results

    Returns:
        List[str]: List of web links

    Raises:
        ValueError: If `search_engine` is neither Google nor DuckDuckGo.
    """
    # Normalise so that e.g. "google" and "Google" select the same engine;
    # str() keeps non-string values on the ValueError path below.
    engine = str(search_engine).lower()

    if engine == "google":
        return list(search(query, stop=max_results))

    if engine == "duckduckgo":
        research = DuckDuckGoSearchResults(max_results=max_results)
        res = research.run(query)

        # The tool result is text; pull every http(s) link out of it.
        return re.findall(r'https?://[^\s,\]]+', res)

    raise ValueError(
        "The only search engines avaiable are DuckDuckGo or Google")
|
||||
Loading…
Reference in New Issue
Block a user