fix: bug for abstract graph

This commit is contained in:
Marco Vinciguerra 2024-08-27 18:05:34 +02:00
parent 050fa3faa0
commit cf73883451
5 changed files with 2 additions and 54 deletions

View File

@ -9,7 +9,7 @@ from scrapegraphai.utils import prettify_exec_info
graph_config = { graph_config = {
"llm": { "llm": {
"model": "ollama/mistral", "model": "ollama/llama3.1",
"temperature": 0, "temperature": 0,
"format": "json", # Ollama needs the format to be specified explicitly "format": "json", # Ollama needs the format to be specified explicitly
# "base_url": "http://localhost:11434", # set ollama URL arbitrarily # "base_url": "http://localhost:11434", # set ollama URL arbitrarily

View File

@ -130,7 +130,6 @@ graphviz==0.20.3
# via burr # via burr
greenlet==3.0.3 greenlet==3.0.3
# via playwright # via playwright
# via sqlalchemy
grpcio==1.65.4 grpcio==1.65.4
# via google-api-core # via google-api-core
# via grpcio-status # via grpcio-status

View File

@ -83,7 +83,6 @@ googleapis-common-protos==1.63.2
# via grpcio-status # via grpcio-status
greenlet==3.0.3 greenlet==3.0.3
# via playwright # via playwright
# via sqlalchemy
grpcio==1.65.1 grpcio==1.65.1
# via google-api-core # via google-api-core
# via grpcio-status # via grpcio-status

View File

@ -139,7 +139,7 @@ class AbstractGraph(ABC):
raise ValueError(f"Provider {llm_params['model_provider']} is not supported. If possible, try to use a model instance instead.") raise ValueError(f"Provider {llm_params['model_provider']} is not supported. If possible, try to use a model instance instead.")
try: try:
self.model_token = models_tokens[llm_params["model"]][llm_params["model"]] self.model_token = models_tokens[llm_params["model_provider"]].get(llm_params["model"][0])
except KeyError: except KeyError:
print("Model not found, using default token size (8192)") print("Model not found, using default token size (8192)")
self.model_token = 8192 self.model_token = 8192

View File

@ -1,50 +0,0 @@
"""
Module for scraping JSON documents
"""
import os
import json
import pytest
from scrapegraphai.graphs import JSONScraperGraph
# Load configuration from a JSON file
CONFIG_FILE = "config.json"
with open(CONFIG_FILE, "r") as f:
CONFIG = json.load(f)
# Fixture to read the sample JSON file
@pytest.fixture
def sample_json():
"""
Read the sample JSON file
"""
file_path = os.path.join(os.path.dirname(__file__), "inputs", "example.json")
with open(file_path, "r", encoding="utf-8") as file:
text = file.read()
return text
# Parametrized fixture to load graph configurations
@pytest.fixture(params=CONFIG["graph_configs"])
def graph_config(request):
"""
Load graph configuration
"""
return request.param
# Test function for the scraping pipeline
def test_scraping_pipeline(sample_json, graph_config):
"""
Test the scraping pipeline
"""
expected_titles = ["Title 1", "Title 2", "Title 3"] # Replace with expected titles
smart_scraper_graph = JSONScraperGraph(
prompt="List me all the titles",
source=sample_json,
config=graph_config
)
result = smart_scraper_graph.run()
assert result is not None
assert isinstance(result, list)
assert sorted(result) == sorted(expected_titles)