feat(multiple_search): working multiple example

This commit is contained in:
Marco Perini 2024-05-18 01:51:12 +02:00
parent 05e511e36f
commit bed3eed50c
6 changed files with 53 additions and 39 deletions

View File

@ -10,6 +10,36 @@ from scrapegraphai.utils import prettify_exec_info
load_dotenv() load_dotenv()
schema= """{
"Job Postings": {
"Company x": [
{
"title": "...",
"description": "...",
"location": "...",
"date_posted": "..",
"requirements": ["...", "...", "..."]
},
{
"title": "...",
"description": "...",
"location": "...",
"date_posted": "..",
"requirements": ["...", "...", "..."]
}
],
"Company y": [
{
"title": "...",
"description": "...",
"location": "...",
"date_posted": "..",
"requirements": ["...", "...", "..."]
}
]
}
}"""
# ************************************************ # ************************************************
# Define the configuration for the graph # Define the configuration for the graph
# ************************************************ # ************************************************
@ -19,47 +49,23 @@ openai_key = os.getenv("OPENAI_APIKEY")
graph_config = { graph_config = {
"llm": { "llm": {
"api_key": openai_key, "api_key": openai_key,
"model": "gpt-4o", "model": "gpt-3.5-turbo",
}, },
"verbose": True, "verbose": True,
"headless": False, "headless": False,
"schema": schema,
} }
schema= """{
"Job Postings": {
"Company A": [
{
"title": "Software Engineer",
"description": "Develop and maintain software applications.",
"location": "New York, NY",
"date_posted": "2024-05-01",
"requirements": ["Python", "Django", "REST APIs"]
},
{
"title": "Data Scientist",
"description": "Analyze and interpret complex data.",
"location": "San Francisco, CA",
"date_posted": "2024-05-05",
"requirements": ["Python", "Machine Learning", "SQL"]
}
],
"Company B": [
{
"title": "Project Manager",
"description": "Manage software development projects.",
"location": "Boston, MA",
"date_posted": "2024-04-20",
"requirements": ["Project Management", "Agile", "Scrum"]
}
]
}
}"""
multiple_search_graph = MultipleSearchGraph( multiple_search_graph = MultipleSearchGraph(
prompt="List me all the projects with their description", prompt="List me all the projects with their description",
source= ["https://perinim.github.io/projects/", "https://perinim.github.io/projects/"], source= [
"https://www.linkedin.com/jobs/machine-learning-engineer-offerte-di-lavoro/?currentJobId=3889037104&originalSubdomain=it",
"https://www.glassdoor.com/Job/italy-machine-learning-engineer-jobs-SRCH_IL.0,5_IN120_KO6,31.html",
"https://it.indeed.com/jobs?q=ML+engineer&vjk=3c2e6d27601ffaaa"
],
config=graph_config, config=graph_config,
schema = schema
) )
result = multiple_search_graph.run() result = multiple_search_graph.run()

View File

@ -40,12 +40,11 @@ class AbstractGraph(ABC):
>>> result = my_graph.run() >>> result = my_graph.run()
""" """
def __init__(self, prompt: str, config: dict, source: Optional[str] = None, schema: Optional[dict]=None): def __init__(self, prompt: str, config: dict, source: Optional[str] = None):
self.prompt = prompt self.prompt = prompt
self.source = source self.source = source
self.config = config self.config = config
self.schema = schema
self.llm_model = self._create_llm(config["llm"], chat=True) self.llm_model = self._create_llm(config["llm"], chat=True)
self.embedder_model = self._create_default_embedder(llm_config=config["llm"] self.embedder_model = self._create_default_embedder(llm_config=config["llm"]
) if "embeddings" not in config else self._create_embedder( ) if "embeddings" not in config else self._create_embedder(
@ -62,6 +61,7 @@ class AbstractGraph(ABC):
self.headless = True if config is None else config.get( self.headless = True if config is None else config.get(
"headless", True) "headless", True)
self.loader_kwargs = config.get("loader_kwargs", {}) self.loader_kwargs = config.get("loader_kwargs", {})
self.schema = config.get("schema", None)
common_params = {"headless": self.headless, common_params = {"headless": self.headless,
"verbose": self.verbose, "verbose": self.verbose,
@ -69,6 +69,7 @@ class AbstractGraph(ABC):
"llm_model": self.llm_model, "llm_model": self.llm_model,
"embedder_model": self.embedder_model, "embedder_model": self.embedder_model,
"schema": self.schema} "schema": self.schema}
self.set_common_params(common_params, overwrite=False) self.set_common_params(common_params, overwrite=False)
def set_common_params(self, params: dict, overwrite=False): def set_common_params(self, params: dict, overwrite=False):

View File

@ -14,6 +14,8 @@ from .abstract_graph import AbstractGraph
from .smart_scraper_graph import SmartScraperGraph from .smart_scraper_graph import SmartScraperGraph
from typing import List, Optional from typing import List, Optional
class MultipleSearchGraph(AbstractGraph): class MultipleSearchGraph(AbstractGraph):
""" """
MultipleSearchGraph is a scraping pipeline that searches the internet for answers to a given prompt. MultipleSearchGraph is a scraping pipeline that searches the internet for answers to a given prompt.
@ -39,7 +41,7 @@ class MultipleSearchGraph(AbstractGraph):
>>> result = search_graph.run() >>> result = search_graph.run()
""" """
def __init__(self, prompt: str, source: List[str], config: dict, schema:Optional[dict]= None): def __init__(self, prompt: str, source: List[str], config: dict):
self.max_results = config.get("max_results", 3) self.max_results = config.get("max_results", 3)
@ -48,7 +50,7 @@ class MultipleSearchGraph(AbstractGraph):
else: else:
self.copy_config = deepcopy(config) self.copy_config = deepcopy(config)
super().__init__(prompt, config) super().__init__(prompt, config, source)
def _create_graph(self) -> BaseGraph: def _create_graph(self) -> BaseGraph:
""" """
@ -65,7 +67,7 @@ class MultipleSearchGraph(AbstractGraph):
smart_scraper_instance = SmartScraperGraph( smart_scraper_instance = SmartScraperGraph(
prompt="", prompt="",
source="", source="",
config=self.copy_config config=self.copy_config,
) )
# ************************************************ # ************************************************
@ -85,6 +87,7 @@ class MultipleSearchGraph(AbstractGraph):
output=["answer"], output=["answer"],
node_config={ node_config={
"llm_model": self.llm_model, "llm_model": self.llm_model,
"schema": self.config.get("schema", None),
} }
) )

View File

@ -81,7 +81,8 @@ class SmartScraperGraph(AbstractGraph):
input="user_prompt & (relevant_chunks | parsed_doc | doc)", input="user_prompt & (relevant_chunks | parsed_doc | doc)",
output=["answer"], output=["answer"],
node_config={ node_config={
"llm_model": self.llm_model "llm_model": self.llm_model,
"schema": self.config.get("schema", None),
} }
) )

View File

@ -35,7 +35,7 @@ class GenerateAnswerNode(BaseNode):
def __init__(self, input: str, output: List[str], node_config: Optional[dict] = None, def __init__(self, input: str, output: List[str], node_config: Optional[dict] = None,
node_name: str = "GenerateAnswer"): node_name: str = "GenerateAnswer"):
print(node_config)
super().__init__(node_name, "node", input, output, 2, node_config) super().__init__(node_name, "node", input, output, 2, node_config)
self.llm_model = node_config["llm_model"] self.llm_model = node_config["llm_model"]

View File

@ -79,6 +79,8 @@ class MergeAnswersNode(BaseNode):
You need to merge the content from the different websites into a single answer without repetitions (if there are any). \n You need to merge the content from the different websites into a single answer without repetitions (if there are any). \n
The scraped contents are in a JSON format and you need to merge them based on the context and providing a correct JSON structure.\n The scraped contents are in a JSON format and you need to merge them based on the context and providing a correct JSON structure.\n
OUTPUT INSTRUCTIONS: {format_instructions}\n OUTPUT INSTRUCTIONS: {format_instructions}\n
You must format the output with the following schema, if not None:\n
SCHEMA: {schema}\n
USER PROMPT: {user_prompt}\n USER PROMPT: {user_prompt}\n
WEBSITE CONTENT: {website_content} WEBSITE CONTENT: {website_content}
""" """
@ -89,6 +91,7 @@ class MergeAnswersNode(BaseNode):
partial_variables={ partial_variables={
"format_instructions": format_instructions, "format_instructions": format_instructions,
"website_content": answers_str, "website_content": answers_str,
"schema": self.node_config.get("schema", None),
}, },
) )