diff --git a/scrapegraphai/graphs/abstract_graph.py b/scrapegraphai/graphs/abstract_graph.py index 067af7d4..0d02b6d4 100644 --- a/scrapegraphai/graphs/abstract_graph.py +++ b/scrapegraphai/graphs/abstract_graph.py @@ -131,15 +131,15 @@ class AbstractGraph(ABC): "ollama", "oneapi", "nvidia", "groq", "anthropic" "bedrock", "mistralai", "hugging_face", "deepseek", "ernie", "fireworks"} - split_model_provider = llm_params["model"].split("/") + split_model_provider = llm_params["model"].split("/", 1) llm_params["model_provider"] = split_model_provider[0] - llm_params["model"] = split_model_provider[1:] + llm_params["model"] = split_model_provider[1] if llm_params["model_provider"] not in known_providers: raise ValueError(f"Provider {llm_params['model_provider']} is not supported. If possible, try to use a model instance instead.") try: - self.model_token = models_tokens[llm_params["model_provider"]].get(llm_params["model"][0]) + self.model_token = models_tokens[llm_params["model_provider"]][llm_params["model"]] except KeyError: print("Model not found, using default token size (8192)") self.model_token = 8192 @@ -150,18 +150,21 @@ class AbstractGraph(ABC): warnings.simplefilter("ignore") return init_chat_model(**llm_params) else: - if "deepseek" in llm_params["model"]: + if llm_params["model_provider"] == "deepseek": return DeepSeek(**llm_params) - if "ernie" in llm_params["model"]: + if llm_params["model_provider"] == "ernie": from langchain_community.chat_models import ErnieBotChat return ErnieBotChat(**llm_params) - if "oneapi" in llm_params["model"]: + if llm_params["model_provider"] == "oneapi": return OneApi(**llm_params) - if "nvidia" in llm_params["model"]: - from langchain_nvidia_ai_endpoints import ChatNVIDIA + if llm_params["model_provider"] == "nvidia": + try: + from langchain_nvidia_ai_endpoints import ChatNVIDIA + except ImportError: + raise ImportError("The langchain_nvidia_ai_endpoints module is not installed. Please install it using `pip install langchain_nvidia_ai_endpoints`.") return ChatNVIDIA(**llm_params) except Exception as e: diff --git a/tests/graphs/abstract_graph_test.py b/tests/graphs/abstract_graph_test.py index 805a1691..f52c9b32 100644 --- a/tests/graphs/abstract_graph_test.py +++ b/tests/graphs/abstract_graph_test.py @@ -3,29 +3,80 @@ Tests for the AbstractGraph. """ import pytest from unittest.mock import patch -from scrapegraphai.graphs import AbstractGraph +from scrapegraphai.graphs import AbstractGraph, BaseGraph +from scrapegraphai.nodes import ( + FetchNode, + ParseNode +) +from scrapegraphai.models import OneApi, DeepSeek +from langchain_openai import ChatOpenAI, AzureChatOpenAI +from langchain_community.chat_models import ChatOllama +from langchain_google_genai import ChatGoogleGenerativeAI + + + +class TestGraph(AbstractGraph): + def __init__(self, prompt: str, config: dict): + super().__init__(prompt, config) + + def _create_graph(self) -> BaseGraph: + fetch_node = FetchNode( + input="url| local_dir", + output=["doc", "link_urls", "img_urls"], + node_config={ + "llm_model": self.llm_model, + "force": self.config.get("force", False), + "cut": self.config.get("cut", True), + "loader_kwargs": self.config.get("loader_kwargs", {}), + "browser_base": self.config.get("browser_base") + } + ) + parse_node = ParseNode( + input="doc", + output=["parsed_doc"], + node_config={ + "chunk_size": self.model_token + } + ) + return BaseGraph( + nodes=[ + fetch_node, + parse_node + ], + edges=[ + (fetch_node, parse_node), + ], + entry_point=fetch_node, + graph_name=self.__class__.__name__ + ) + + def run(self) -> str: + inputs = {"user_prompt": self.prompt, self.input_key: self.source} + self.final_state, self.execution_info = self.graph.execute(inputs) + + return self.final_state.get("answer", "No answer found.") + class TestAbstractGraph: @pytest.mark.parametrize("llm_config, expected_model", [ - ({"model": "openai/gpt-3.5-turbo"}, "ChatOpenAI"), - ({"model": "azure_openai/gpt-3.5-turbo"}, "AzureChatOpenAI"), - ({"model": "google_genai/gemini-pro"}, "ChatGoogleGenerativeAI"), - ({"model": "google_vertexai/chat-bison"}, "ChatVertexAI"), - ({"model": "ollama/llama2"}, "Ollama"), - ({"model": "oneapi/text-davinci-003"}, "OneApi"), - ({"model": "nvidia/clara-instant-1-base"}, "ChatNVIDIA"), - ({"model": "deepseek/deepseek-coder-6.7b-instruct"}, "DeepSeek"), - ({"model": "ernie/ernie-bot"}, "ErnieBotChat"), + ({"model": "openai/gpt-3.5-turbo", "openai_api_key": "sk-randomtest001"}, ChatOpenAI), + ({ + "model": "azure_openai/gpt-3.5-turbo", + "api_key": "random-api-key", + "api_version": "no version", + "azure_endpoint": "https://www.example.com/"}, + AzureChatOpenAI), + ({"model": "google_genai/gemini-pro", "google_api_key": "google-key-test"}, ChatGoogleGenerativeAI), + ({"model": "ollama/llama2"}, ChatOllama), + ({"model": "oneapi/qwen-turbo"}, OneApi), + ({"model": "deepseek/deepseek-coder"}, DeepSeek), ]) + def test_create_llm(self, llm_config, expected_model): - graph = AbstractGraph("Test prompt", {"llm": llm_config}) + graph = TestGraph("Test prompt", {"llm": llm_config}) assert isinstance(graph.llm_model, expected_model) def test_create_llm_unknown_provider(self): with pytest.raises(ValueError): - AbstractGraph("Test prompt", {"llm": {"model": "unknown_provider/model"}}) + TestGraph("Test prompt", {"llm": {"model": "unknown_provider/model"}}) - def test_create_llm_error(self): - with patch("your_module.init_chat_model", side_effect=Exception("Test error")): - with pytest.raises(Exception): - AbstractGraph("Test prompt", {"llm": {"model": "openai/gpt-3.5-turbo"}})