From cf73883451729b19034005ee7ebe618c1e256a11 Mon Sep 17 00:00:00 2001 From: Marco Vinciguerra Date: Tue, 27 Aug 2024 18:05:34 +0200 Subject: [PATCH] fix: bug for abstract graph --- examples/local_models/smart_scraper_ollama.py | 2 +- requirements-dev.lock | 1 - requirements.lock | 1 - scrapegraphai/graphs/abstract_graph.py | 2 +- tests/graphs/scrape_json_ollama_test.py | 50 ------------------- 5 files changed, 2 insertions(+), 54 deletions(-) delete mode 100644 tests/graphs/scrape_json_ollama_test.py diff --git a/examples/local_models/smart_scraper_ollama.py b/examples/local_models/smart_scraper_ollama.py index 3f6c0967..d5585ff7 100644 --- a/examples/local_models/smart_scraper_ollama.py +++ b/examples/local_models/smart_scraper_ollama.py @@ -9,7 +9,7 @@ from scrapegraphai.utils import prettify_exec_info graph_config = { "llm": { - "model": "ollama/mistral", + "model": "ollama/llama3.1", "temperature": 0, "format": "json", # Ollama needs the format to be specified explicitly # "base_url": "http://localhost:11434", # set ollama URL arbitrarily diff --git a/requirements-dev.lock b/requirements-dev.lock index 04ca69d9..b816db3d 100644 --- a/requirements-dev.lock +++ b/requirements-dev.lock @@ -130,7 +130,6 @@ graphviz==0.20.3 # via burr greenlet==3.0.3 # via playwright - # via sqlalchemy grpcio==1.65.4 # via google-api-core # via grpcio-status diff --git a/requirements.lock b/requirements.lock index f3cb5626..30d89366 100644 --- a/requirements.lock +++ b/requirements.lock @@ -83,7 +83,6 @@ googleapis-common-protos==1.63.2 # via grpcio-status greenlet==3.0.3 # via playwright - # via sqlalchemy grpcio==1.65.1 # via google-api-core # via grpcio-status diff --git a/scrapegraphai/graphs/abstract_graph.py b/scrapegraphai/graphs/abstract_graph.py index 03fd30e2..58eb30f4 100644 --- a/scrapegraphai/graphs/abstract_graph.py +++ b/scrapegraphai/graphs/abstract_graph.py @@ -139,7 +139,7 @@ class AbstractGraph(ABC): raise ValueError(f"Provider {llm_params['model_provider']} is not supported. If possible, try to use a model instance instead.") try: - self.model_token = models_tokens[llm_params["model"]][llm_params["model"]] + self.model_token = models_tokens[llm_params["model_provider"]].get(llm_params["model"][0]) except KeyError: print("Model not found, using default token size (8192)") self.model_token = 8192 diff --git a/tests/graphs/scrape_json_ollama_test.py b/tests/graphs/scrape_json_ollama_test.py deleted file mode 100644 index 17ef80b1..00000000 --- a/tests/graphs/scrape_json_ollama_test.py +++ /dev/null @@ -1,50 +0,0 @@ -""" -Module for scraping JSON documents -""" -import os -import json -import pytest - -from scrapegraphai.graphs import JSONScraperGraph - -# Load configuration from a JSON file -CONFIG_FILE = "config.json" -with open(CONFIG_FILE, "r") as f: - CONFIG = json.load(f) - -# Fixture to read the sample JSON file -@pytest.fixture -def sample_json(): - """ - Read the sample JSON file - """ - file_path = os.path.join(os.path.dirname(__file__), "inputs", "example.json") - with open(file_path, "r", encoding="utf-8") as file: - text = file.read() - return text - -# Parametrized fixture to load graph configurations -@pytest.fixture(params=CONFIG["graph_configs"]) -def graph_config(request): - """ - Load graph configuration - """ - return request.param - -# Test function for the scraping pipeline -def test_scraping_pipeline(sample_json, graph_config): - """ - Test the scraping pipeline - """ - expected_titles = ["Title 1", "Title 2", "Title 3"] # Replace with expected titles - - smart_scraper_graph = JSONScraperGraph( - prompt="List me all the titles", - source=sample_json, - config=graph_config - ) - result = smart_scraper_graph.run() - - assert result is not None - assert isinstance(result, list) - assert sorted(result) == sorted(expected_titles)