mirror of
https://github.com/VinciGit00/Scrapegraph-ai.git
synced 2026-06-25 21:11:11 +08:00
feat(multiple_search): working multiple example
This commit is contained in:
parent
05e511e36f
commit
bed3eed50c
@ -10,6 +10,36 @@ from scrapegraphai.utils import prettify_exec_info
|
||||
load_dotenv()
|
||||
|
||||
|
||||
schema= """{
|
||||
"Job Postings": {
|
||||
"Company x": [
|
||||
{
|
||||
"title": "...",
|
||||
"description": "...",
|
||||
"location": "...",
|
||||
"date_posted": "..",
|
||||
"requirements": ["...", "...", "..."]
|
||||
},
|
||||
{
|
||||
"title": "...",
|
||||
"description": "...",
|
||||
"location": "...",
|
||||
"date_posted": "..",
|
||||
"requirements": ["...", "...", "..."]
|
||||
}
|
||||
],
|
||||
"Company y": [
|
||||
{
|
||||
"title": "...",
|
||||
"description": "...",
|
||||
"location": "...",
|
||||
"date_posted": "..",
|
||||
"requirements": ["...", "...", "..."]
|
||||
}
|
||||
]
|
||||
}
|
||||
}"""
|
||||
|
||||
# ************************************************
|
||||
# Define the configuration for the graph
|
||||
# ************************************************
|
||||
@ -19,47 +49,23 @@ openai_key = os.getenv("OPENAI_APIKEY")
|
||||
graph_config = {
|
||||
"llm": {
|
||||
"api_key": openai_key,
|
||||
"model": "gpt-4o",
|
||||
"model": "gpt-3.5-turbo",
|
||||
},
|
||||
"verbose": True,
|
||||
"headless": False,
|
||||
"schema": schema,
|
||||
}
|
||||
|
||||
schema= """{
|
||||
"Job Postings": {
|
||||
"Company A": [
|
||||
{
|
||||
"title": "Software Engineer",
|
||||
"description": "Develop and maintain software applications.",
|
||||
"location": "New York, NY",
|
||||
"date_posted": "2024-05-01",
|
||||
"requirements": ["Python", "Django", "REST APIs"]
|
||||
},
|
||||
{
|
||||
"title": "Data Scientist",
|
||||
"description": "Analyze and interpret complex data.",
|
||||
"location": "San Francisco, CA",
|
||||
"date_posted": "2024-05-05",
|
||||
"requirements": ["Python", "Machine Learning", "SQL"]
|
||||
}
|
||||
],
|
||||
"Company B": [
|
||||
{
|
||||
"title": "Project Manager",
|
||||
"description": "Manage software development projects.",
|
||||
"location": "Boston, MA",
|
||||
"date_posted": "2024-04-20",
|
||||
"requirements": ["Project Management", "Agile", "Scrum"]
|
||||
}
|
||||
]
|
||||
}
|
||||
}"""
|
||||
|
||||
|
||||
multiple_search_graph = MultipleSearchGraph(
|
||||
prompt="List me all the projects with their description",
|
||||
source= ["https://perinim.github.io/projects/", "https://perinim.github.io/projects/"],
|
||||
source= [
|
||||
"https://www.linkedin.com/jobs/machine-learning-engineer-offerte-di-lavoro/?currentJobId=3889037104&originalSubdomain=it",
|
||||
"https://www.glassdoor.com/Job/italy-machine-learning-engineer-jobs-SRCH_IL.0,5_IN120_KO6,31.html",
|
||||
"https://it.indeed.com/jobs?q=ML+engineer&vjk=3c2e6d27601ffaaa"
|
||||
],
|
||||
config=graph_config,
|
||||
schema = schema
|
||||
)
|
||||
|
||||
result = multiple_search_graph.run()
|
||||
|
||||
@ -40,12 +40,11 @@ class AbstractGraph(ABC):
|
||||
>>> result = my_graph.run()
|
||||
"""
|
||||
|
||||
def __init__(self, prompt: str, config: dict, source: Optional[str] = None, schema: Optional[dict]=None):
|
||||
def __init__(self, prompt: str, config: dict, source: Optional[str] = None):
|
||||
|
||||
self.prompt = prompt
|
||||
self.source = source
|
||||
self.config = config
|
||||
self.schema = schema
|
||||
self.llm_model = self._create_llm(config["llm"], chat=True)
|
||||
self.embedder_model = self._create_default_embedder(llm_config=config["llm"]
|
||||
) if "embeddings" not in config else self._create_embedder(
|
||||
@ -62,6 +61,7 @@ class AbstractGraph(ABC):
|
||||
self.headless = True if config is None else config.get(
|
||||
"headless", True)
|
||||
self.loader_kwargs = config.get("loader_kwargs", {})
|
||||
self.schema = config.get("schema", None)
|
||||
|
||||
common_params = {"headless": self.headless,
|
||||
"verbose": self.verbose,
|
||||
@ -69,6 +69,7 @@ class AbstractGraph(ABC):
|
||||
"llm_model": self.llm_model,
|
||||
"embedder_model": self.embedder_model,
|
||||
"schema": self.schema}
|
||||
|
||||
self.set_common_params(common_params, overwrite=False)
|
||||
|
||||
def set_common_params(self, params: dict, overwrite=False):
|
||||
|
||||
@ -14,6 +14,8 @@ from .abstract_graph import AbstractGraph
|
||||
from .smart_scraper_graph import SmartScraperGraph
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
|
||||
class MultipleSearchGraph(AbstractGraph):
|
||||
"""
|
||||
MultipleSearchGraph is a scraping pipeline that searches the internet for answers to a given prompt.
|
||||
@ -39,7 +41,7 @@ class MultipleSearchGraph(AbstractGraph):
|
||||
>>> result = search_graph.run()
|
||||
"""
|
||||
|
||||
def __init__(self, prompt: str, source: List[str], config: dict, schema:Optional[dict]= None):
|
||||
def __init__(self, prompt: str, source: List[str], config: dict):
|
||||
|
||||
self.max_results = config.get("max_results", 3)
|
||||
|
||||
@ -48,7 +50,7 @@ class MultipleSearchGraph(AbstractGraph):
|
||||
else:
|
||||
self.copy_config = deepcopy(config)
|
||||
|
||||
super().__init__(prompt, config)
|
||||
super().__init__(prompt, config, source)
|
||||
|
||||
def _create_graph(self) -> BaseGraph:
|
||||
"""
|
||||
@ -65,7 +67,7 @@ class MultipleSearchGraph(AbstractGraph):
|
||||
smart_scraper_instance = SmartScraperGraph(
|
||||
prompt="",
|
||||
source="",
|
||||
config=self.copy_config
|
||||
config=self.copy_config,
|
||||
)
|
||||
|
||||
# ************************************************
|
||||
@ -85,6 +87,7 @@ class MultipleSearchGraph(AbstractGraph):
|
||||
output=["answer"],
|
||||
node_config={
|
||||
"llm_model": self.llm_model,
|
||||
"schema": self.config.get("schema", None),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@ -81,7 +81,8 @@ class SmartScraperGraph(AbstractGraph):
|
||||
input="user_prompt & (relevant_chunks | parsed_doc | doc)",
|
||||
output=["answer"],
|
||||
node_config={
|
||||
"llm_model": self.llm_model
|
||||
"llm_model": self.llm_model,
|
||||
"schema": self.config.get("schema", None),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@ -35,7 +35,7 @@ class GenerateAnswerNode(BaseNode):
|
||||
|
||||
def __init__(self, input: str, output: List[str], node_config: Optional[dict] = None,
|
||||
node_name: str = "GenerateAnswer"):
|
||||
print(node_config)
|
||||
|
||||
super().__init__(node_name, "node", input, output, 2, node_config)
|
||||
|
||||
self.llm_model = node_config["llm_model"]
|
||||
|
||||
@ -79,6 +79,8 @@ class MergeAnswersNode(BaseNode):
|
||||
You need to merge the content from the different websites into a single answer without repetitions (if there are any). \n
|
||||
The scraped contents are in a JSON format and you need to merge them based on the context and providing a correct JSON structure.\n
|
||||
OUTPUT INSTRUCTIONS: {format_instructions}\n
|
||||
You must format the output with the following schema, if not None:\n
|
||||
SCHEMA: {schema}\n
|
||||
USER PROMPT: {user_prompt}\n
|
||||
WEBSITE CONTENT: {website_content}
|
||||
"""
|
||||
@ -89,6 +91,7 @@ class MergeAnswersNode(BaseNode):
|
||||
partial_variables={
|
||||
"format_instructions": format_instructions,
|
||||
"website_content": answers_str,
|
||||
"schema": self.node_config.get("schema", None),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user