diff --git a/examples/openai/multiple_search_openai.py b/examples/openai/multiple_search_openai.py index 498010dc..dbeecf77 100644 --- a/examples/openai/multiple_search_openai.py +++ b/examples/openai/multiple_search_openai.py @@ -10,6 +10,36 @@ from scrapegraphai.utils import prettify_exec_info load_dotenv() +schema= """{ + "Job Postings": { + "Company x": [ + { + "title": "...", + "description": "...", + "location": "...", + "date_posted": "..", + "requirements": ["...", "...", "..."] + }, + { + "title": "...", + "description": "...", + "location": "...", + "date_posted": "..", + "requirements": ["...", "...", "..."] + } + ], + "Company y": [ + { + "title": "...", + "description": "...", + "location": "...", + "date_posted": "..", + "requirements": ["...", "...", "..."] + } + ] + } +}""" + # ************************************************ # Define the configuration for the graph # ************************************************ @@ -19,47 +49,23 @@ openai_key = os.getenv("OPENAI_APIKEY") graph_config = { "llm": { "api_key": openai_key, - "model": "gpt-4o", + "model": "gpt-3.5-turbo", }, "verbose": True, "headless": False, + "schema": schema, } -schema= """{ - "Job Postings": { - "Company A": [ - { - "title": "Software Engineer", - "description": "Develop and maintain software applications.", - "location": "New York, NY", - "date_posted": "2024-05-01", - "requirements": ["Python", "Django", "REST APIs"] - }, - { - "title": "Data Scientist", - "description": "Analyze and interpret complex data.", - "location": "San Francisco, CA", - "date_posted": "2024-05-05", - "requirements": ["Python", "Machine Learning", "SQL"] - } - ], - "Company B": [ - { - "title": "Project Manager", - "description": "Manage software development projects.", - "location": "Boston, MA", - "date_posted": "2024-04-20", - "requirements": ["Project Management", "Agile", "Scrum"] - } - ] - } -}""" + multiple_search_graph = MultipleSearchGraph( prompt="List me all the projects with their description", - source= ["https://perinim.github.io/projects/", "https://perinim.github.io/projects/"], + source= [ + "https://www.linkedin.com/jobs/machine-learning-engineer-offerte-di-lavoro/?currentJobId=3889037104&originalSubdomain=it", + "https://www.glassdoor.com/Job/italy-machine-learning-engineer-jobs-SRCH_IL.0,5_IN120_KO6,31.html", + "https://it.indeed.com/jobs?q=ML+engineer&vjk=3c2e6d27601ffaaa" + ], config=graph_config, - schema = schema ) result = multiple_search_graph.run() diff --git a/scrapegraphai/graphs/abstract_graph.py b/scrapegraphai/graphs/abstract_graph.py index 08362f5c..e1cf77f7 100644 --- a/scrapegraphai/graphs/abstract_graph.py +++ b/scrapegraphai/graphs/abstract_graph.py @@ -40,12 +40,11 @@ class AbstractGraph(ABC): >>> result = my_graph.run() """ - def __init__(self, prompt: str, config: dict, source: Optional[str] = None, schema: Optional[dict]=None): + def __init__(self, prompt: str, config: dict, source: Optional[str] = None): self.prompt = prompt self.source = source self.config = config - self.schema = schema self.llm_model = self._create_llm(config["llm"], chat=True) self.embedder_model = self._create_default_embedder(llm_config=config["llm"] ) if "embeddings" not in config else self._create_embedder( @@ -62,6 +61,7 @@ class AbstractGraph(ABC): self.headless = True if config is None else config.get( "headless", True) self.loader_kwargs = config.get("loader_kwargs", {}) + self.schema = config.get("schema", None) common_params = {"headless": self.headless, "verbose": self.verbose, @@ -69,6 +69,7 @@ class AbstractGraph(ABC): "llm_model": self.llm_model, "embedder_model": self.embedder_model, "schema": self.schema} + self.set_common_params(common_params, overwrite=False) def set_common_params(self, params: dict, overwrite=False): diff --git a/scrapegraphai/graphs/multiple_search_graph.py b/scrapegraphai/graphs/multiple_search_graph.py index 0f3ddf7a..95cc1dda 100644 --- a/scrapegraphai/graphs/multiple_search_graph.py +++ b/scrapegraphai/graphs/multiple_search_graph.py @@ -14,6 +14,8 @@ from .abstract_graph import AbstractGraph from .smart_scraper_graph import SmartScraperGraph from typing import List, Optional + + class MultipleSearchGraph(AbstractGraph): """ MultipleSearchGraph is a scraping pipeline that searches the internet for answers to a given prompt. @@ -39,7 +41,7 @@ class MultipleSearchGraph(AbstractGraph): >>> result = search_graph.run() """ - def __init__(self, prompt: str, source: List[str], config: dict, schema:Optional[dict]= None): + def __init__(self, prompt: str, source: List[str], config: dict): self.max_results = config.get("max_results", 3) @@ -48,7 +50,7 @@ class MultipleSearchGraph(AbstractGraph): else: self.copy_config = deepcopy(config) - super().__init__(prompt, config) + super().__init__(prompt, config, source) def _create_graph(self) -> BaseGraph: """ @@ -65,7 +67,7 @@ class MultipleSearchGraph(AbstractGraph): smart_scraper_instance = SmartScraperGraph( prompt="", source="", - config=self.copy_config + config=self.copy_config, ) # ************************************************ @@ -85,6 +87,7 @@ class MultipleSearchGraph(AbstractGraph): output=["answer"], node_config={ "llm_model": self.llm_model, + "schema": self.config.get("schema", None), } ) diff --git a/scrapegraphai/graphs/smart_scraper_graph.py b/scrapegraphai/graphs/smart_scraper_graph.py index 4093e49f..8a6d03e2 100644 --- a/scrapegraphai/graphs/smart_scraper_graph.py +++ b/scrapegraphai/graphs/smart_scraper_graph.py @@ -81,7 +81,8 @@ class SmartScraperGraph(AbstractGraph): input="user_prompt & (relevant_chunks | parsed_doc | doc)", output=["answer"], node_config={ - "llm_model": self.llm_model + "llm_model": self.llm_model, + "schema": self.config.get("schema", None), } ) diff --git a/scrapegraphai/nodes/generate_answer_node.py b/scrapegraphai/nodes/generate_answer_node.py index 9d3a9798..701e23d4 100644 --- a/scrapegraphai/nodes/generate_answer_node.py +++ b/scrapegraphai/nodes/generate_answer_node.py @@ -35,7 +35,7 @@ class GenerateAnswerNode(BaseNode): def __init__(self, input: str, output: List[str], node_config: Optional[dict] = None, node_name: str = "GenerateAnswer"): - print(node_config) + super().__init__(node_name, "node", input, output, 2, node_config) self.llm_model = node_config["llm_model"] diff --git a/scrapegraphai/nodes/merge_answers_node.py b/scrapegraphai/nodes/merge_answers_node.py index 63ed6afa..c2564554 100644 --- a/scrapegraphai/nodes/merge_answers_node.py +++ b/scrapegraphai/nodes/merge_answers_node.py @@ -79,6 +79,8 @@ class MergeAnswersNode(BaseNode): You need to merge the content from the different websites into a single answer without repetitions (if there are any). \n The scraped contents are in a JSON format and you need to merge them based on the context and providing a correct JSON structure.\n OUTPUT INSTRUCTIONS: {format_instructions}\n + You must format the output with the following schema, if not None:\n + SCHEMA: {schema}\n USER PROMPT: {user_prompt}\n WEBSITE CONTENT: {website_content} """ @@ -89,6 +91,7 @@ class MergeAnswersNode(BaseNode): partial_variables={ "format_instructions": format_instructions, "website_content": answers_str, + "schema": self.node_config.get("schema", None), }, )