From 5cf4e4f92f024041c44211aebd2e3bdf73351a00 Mon Sep 17 00:00:00 2001 From: VinciGit00 Date: Thu, 2 May 2024 09:20:46 +0200 Subject: [PATCH] fix: examples and graphs --- CONTRIBUTING.md | 1 + examples/single_node/fetch_node.py | 3 ++ examples/single_node/robot_node.py | 4 +- scrapegraphai/graphs/json_scraper_graph.py | 5 +- scrapegraphai/graphs/script_creator_graph.py | 19 ++++--- scrapegraphai/graphs/smart_scraper_graph.py | 8 +-- scrapegraphai/graphs/xml_scraper_graph.py | 7 +-- tests/graphs/scrape_json_ollama.py | 56 ++++++++++++++++++++ tests/graphs/scrape_xml_ollama_test.py | 4 +- tests/graphs/script_generator_test.py | 2 - tests/nodes/fetch_node_test.py | 3 ++ tests/nodes/robot_node_test.py | 4 +- 12 files changed, 95 insertions(+), 21 deletions(-) create mode 100644 tests/graphs/scrape_json_ollama.py diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 12c032f9..6f9f98f9 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -51,6 +51,7 @@ Please make sure to format your code accordingly before submitting a pull reques - [Style Guide for Python Code](https://www.python.org/dev/peps/pep-0008/) - [Google Python Style Guide](https://google.github.io/styleguide/pyguide.html) - [The Hitchhiker's Guide to Python](https://docs.python-guide.org/writing/style/) +- [Pylint style of code for the documentation](https://pylint.pycqa.org/en/1.6.0/tutorial.html) ## Submitting a Pull Request diff --git a/examples/single_node/fetch_node.py b/examples/single_node/fetch_node.py index 90660996..d03cb495 100644 --- a/examples/single_node/fetch_node.py +++ b/examples/single_node/fetch_node.py @@ -12,6 +12,9 @@ from scrapegraphai.nodes import FetchNode robots_node = FetchNode( input="url | local_dir", output=["doc"], + node_config={ + "headless": False + } ) # ************************************************ diff --git a/examples/single_node/robot_node.py b/examples/single_node/robot_node.py index 8aa26446..0e446262 100644 --- a/examples/single_node/robot_node.py +++ b/examples/single_node/robot_node.py @@ -26,7 +26,9 @@ llm_model = Ollama(graph_config["llm"]) robots_node = RobotsNode( input="url", output=["is_scrapable"], - node_config={"llm": llm_model} + node_config={"llm": llm_model, + "headless": False + } ) # ************************************************ diff --git a/scrapegraphai/graphs/json_scraper_graph.py b/scrapegraphai/graphs/json_scraper_graph.py index 0e231a5c..f7392212 100644 --- a/scrapegraphai/graphs/json_scraper_graph.py +++ b/scrapegraphai/graphs/json_scraper_graph.py @@ -21,7 +21,8 @@ class JSONScraperGraph(AbstractGraph): source (str): The source of the graph. config (dict): Configuration parameters for the graph. llm_model: An instance of a language model client, configured for generating answers. - embedder_model: An instance of an embedding model client, configured for generating embeddings. + embedder_model: An instance of an embedding model client, + configured for generating embeddings. verbose (bool): A flag indicating whether to show print statements during execution. headless (bool): A flag indicating whether to run the graph in headless mode. @@ -47,7 +48,7 @@ class JSONScraperGraph(AbstractGraph): def _create_graph(self) -> BaseGraph: """ Creates the graph of nodes representing the workflow for web scraping. - + Returns: BaseGraph: A graph instance representing the web scraping workflow. """ diff --git a/scrapegraphai/graphs/script_creator_graph.py b/scrapegraphai/graphs/script_creator_graph.py index 20fec447..573be06c 100644 --- a/scrapegraphai/graphs/script_creator_graph.py +++ b/scrapegraphai/graphs/script_creator_graph.py @@ -21,7 +21,8 @@ class ScriptCreatorGraph(AbstractGraph): source (str): The source of the graph. config (dict): Configuration parameters for the graph. llm_model: An instance of a language model client, configured for generating answers. - embedder_model: An instance of an embedding model client, configured for generating embeddings. + embedder_model: An instance of an embedding model client, + configured for generating embeddings. verbose (bool): A flag indicating whether to show print statements during execution. headless (bool): A flag indicating whether to run the graph in headless mode. model_token (int): The token limit for the language model. @@ -44,7 +45,7 @@ class ScriptCreatorGraph(AbstractGraph): def __init__(self, prompt: str, source: str, config: dict): self.library = config['library'] - + super().__init__(prompt, config, source) self.input_key = "url" if source.startswith("http") else "local_dir" @@ -61,25 +62,29 @@ class ScriptCreatorGraph(AbstractGraph): input="url | local_dir", output=["doc"], node_config={ - "headless": True if self.config is None else self.config.get("headless", True)} + "headless": True if self.config is None else self.config.get("headless", True), + "verbose": self.verbose} ) parse_node = ParseNode( input="doc", output=["parsed_doc"], - node_config={"chunk_size": self.model_token} + node_config={"chunk_size": self.model_token, + "verbose": self.verbose} ) rag_node = RAGNode( input="user_prompt & (parsed_doc | doc)", output=["relevant_chunks"], node_config={ "llm": self.llm_model, - "embedder_model": self.embedder_model + "embedder_model": self.embedder_model, + "verbose": self.verbose } ) generate_scraper_node = GenerateScraperNode( input="user_prompt & (relevant_chunks | parsed_doc | doc)", output=["answer"], - node_config={"llm": self.llm_model}, + node_config={"llm": self.llm_model, + "verbose": self.verbose}, library=self.library, website=self.source ) @@ -106,7 +111,7 @@ class ScriptCreatorGraph(AbstractGraph): Returns: str: The answer to the prompt. """ - + inputs = {"user_prompt": self.prompt, self.input_key: self.source} self.final_state, self.execution_info = self.graph.execute(inputs) diff --git a/scrapegraphai/graphs/smart_scraper_graph.py b/scrapegraphai/graphs/smart_scraper_graph.py index ad984c61..4d6b0e93 100644 --- a/scrapegraphai/graphs/smart_scraper_graph.py +++ b/scrapegraphai/graphs/smart_scraper_graph.py @@ -14,7 +14,8 @@ from .abstract_graph import AbstractGraph class SmartScraperGraph(AbstractGraph): """ - SmartScraper is a scraping pipeline that automates the process of extracting information from web pages + SmartScraper is a scraping pipeline that automates the process of + extracting information from web pages using a natural language model to interpret and answer prompts. Attributes: @@ -22,7 +23,8 @@ class SmartScraperGraph(AbstractGraph): source (str): The source of the graph. config (dict): Configuration parameters for the graph. llm_model: An instance of a language model client, configured for generating answers. - embedder_model: An instance of an embedding model client, configured for generating embeddings. + embedder_model: An instance of an embedding model client, + configured for generating embeddings. verbose (bool): A flag indicating whether to show print statements during execution. headless (bool): A flag indicating whether to run the graph in headless mode. @@ -45,7 +47,7 @@ class SmartScraperGraph(AbstractGraph): super().__init__(prompt, config, source) self.input_key = "url" if source.startswith("http") else "local_dir" - + def _create_graph(self) -> BaseGraph: """ Creates the graph of nodes representing the workflow for web scraping. diff --git a/scrapegraphai/graphs/xml_scraper_graph.py b/scrapegraphai/graphs/xml_scraper_graph.py index 83aba049..c84e1506 100644 --- a/scrapegraphai/graphs/xml_scraper_graph.py +++ b/scrapegraphai/graphs/xml_scraper_graph.py @@ -22,7 +22,8 @@ class XMLScraperGraph(AbstractGraph): source (str): The source of the graph. config (dict): Configuration parameters for the graph. llm_model: An instance of a language model client, configured for generating answers. - embedder_model: An instance of an embedding model client, configured for generating embeddings. + embedder_model: An instance of an embedding model client, + configured for generating embeddings. verbose (bool): A flag indicating whether to show print statements during execution. headless (bool): A flag indicating whether to run the graph in headless mode. model_token (int): The token limit for the language model. @@ -49,7 +50,7 @@ class XMLScraperGraph(AbstractGraph): def _create_graph(self) -> BaseGraph: """ Creates the graph of nodes representing the workflow for web scraping. - + Returns: BaseGraph: A graph instance representing the web scraping workflow. """ @@ -110,7 +111,7 @@ class XMLScraperGraph(AbstractGraph): Returns: str: The answer to the prompt. """ - + inputs = {"user_prompt": self.prompt, self.input_key: self.source} self.final_state, self.execution_info = self.graph.execute(inputs) diff --git a/tests/graphs/scrape_json_ollama.py b/tests/graphs/scrape_json_ollama.py new file mode 100644 index 00000000..a1ce8875 --- /dev/null +++ b/tests/graphs/scrape_json_ollama.py @@ -0,0 +1,56 @@ +""" +Module for scraping json documents +""" +import os +import pytest +from scrapegraphai.graphs import JSONScraperGraph + + +@pytest.fixture +def sample_json(): + """ + Example of text + """ + file_name = "inputs/example.json" + curr_dir = os.path.dirname(os.path.realpath(__file__)) + file_path = os.path.join(curr_dir, file_name) + + with open(file_path, 'r', encoding="utf-8") as file: + text = file.read() + + return text + + +@pytest.fixture +def graph_config(): + """ + Configuration of the graph + """ + return { + "llm": { + "model": "ollama/mistral", + "temperature": 0, + "format": "json", + "base_url": "http://localhost:11434", + }, + "embeddings": { + "model": "ollama/nomic-embed-text", + "temperature": 0, + "base_url": "http://localhost:11434", + } + } + + +def test_scraping_pipeline(sample_json: str, graph_config: dict): + """ + Start of the scraping pipeline + """ + smart_scraper_graph = JSONScraperGraph( + prompt="List me all the titles", + source=sample_json, + config=graph_config + ) + + result = smart_scraper_graph.run() + + assert result is not None diff --git a/tests/graphs/scrape_xml_ollama_test.py b/tests/graphs/scrape_xml_ollama_test.py index afa7527f..04494543 100644 --- a/tests/graphs/scrape_xml_ollama_test.py +++ b/tests/graphs/scrape_xml_ollama_test.py @@ -3,7 +3,7 @@ Module for scraping XML documents """ import os import pytest -from scrapegraphai.graphs import SmartScraperGraph +from scrapegraphai.graphs import XMLScraperGraph @pytest.fixture @@ -45,7 +45,7 @@ def test_scraping_pipeline(sample_xml: str, graph_config: dict): """ Start of the scraping pipeline """ - smart_scraper_graph = SmartScraperGraph( + smart_scraper_graph = XMLScraperGraph( prompt="List me all the authors, title and genres of the books", source=sample_xml, config=graph_config diff --git a/tests/graphs/script_generator_test.py b/tests/graphs/script_generator_test.py index 6114bac4..4982184e 100644 --- a/tests/graphs/script_generator_test.py +++ b/tests/graphs/script_generator_test.py @@ -46,6 +46,4 @@ def test_script_creator_graph(graph_config: dict): assert graph_exec_info is not None - assert isinstance(graph_exec_info, dict) - print(prettify_exec_info(graph_exec_info)) diff --git a/tests/nodes/fetch_node_test.py b/tests/nodes/fetch_node_test.py index e0552a05..811c2daf 100644 --- a/tests/nodes/fetch_node_test.py +++ b/tests/nodes/fetch_node_test.py @@ -17,6 +17,9 @@ def setup(): robots_node = FetchNode( input="url | local_dir", output=["doc"], + node_config={ + "headless": False + } ) return robots_node diff --git a/tests/nodes/robot_node_test.py b/tests/nodes/robot_node_test.py index 7808a976..cae3a895 100644 --- a/tests/nodes/robot_node_test.py +++ b/tests/nodes/robot_node_test.py @@ -32,7 +32,9 @@ def setup(): robots_node = RobotsNode( input="url", output=["is_scrapable"], - node_config={"llm": llm_model} + node_config={"llm": llm_model, + "headless": False + } ) return robots_node