From 4f120e29c546373a2cc06c102cc9886cc5270c06 Mon Sep 17 00:00:00 2001 From: Federico Aguzzi <62149513+f-aguzzi@users.noreply.github.com> Date: Wed, 28 Aug 2024 12:36:03 +0200 Subject: [PATCH 1/2] fix(AbstractGraph): model selection bug --- scrapegraphai/graphs/abstract_graph.py | 19 +++--- tests/graphs/abstract_graph_test.py | 83 +++++++++++++++++++++----- 2 files changed, 78 insertions(+), 24 deletions(-) diff --git a/scrapegraphai/graphs/abstract_graph.py b/scrapegraphai/graphs/abstract_graph.py index 067af7d4..0d02b6d4 100644 --- a/scrapegraphai/graphs/abstract_graph.py +++ b/scrapegraphai/graphs/abstract_graph.py @@ -131,15 +131,15 @@ class AbstractGraph(ABC): "ollama", "oneapi", "nvidia", "groq", "anthropic" "bedrock", "mistralai", "hugging_face", "deepseek", "ernie", "fireworks"} - split_model_provider = llm_params["model"].split("/") + split_model_provider = llm_params["model"].split("/", 1) llm_params["model_provider"] = split_model_provider[0] - llm_params["model"] = split_model_provider[1:] + llm_params["model"] = split_model_provider[1] if llm_params["model_provider"] not in known_providers: raise ValueError(f"Provider {llm_params['model_provider']} is not supported. If possible, try to use a model instance instead.") try: - self.model_token = models_tokens[llm_params["model_provider"]].get(llm_params["model"][0]) + self.model_token = models_tokens[llm_params["model_provider"]][llm_params["model"]] except KeyError: print("Model not found, using default token size (8192)") self.model_token = 8192 @@ -150,18 +150,21 @@ class AbstractGraph(ABC): warnings.simplefilter("ignore") return init_chat_model(**llm_params) else: - if "deepseek" in llm_params["model"]: + if llm_params["model_provider"] == "deepseek": return DeepSeek(**llm_params) - if "ernie" in llm_params["model"]: + if llm_params["model_provider"] == "ernie": from langchain_community.chat_models import ErnieBotChat return ErnieBotChat(**llm_params) - if "oneapi" in llm_params["model"]: + if llm_params["model_provider"] == "oneapi": return OneApi(**llm_params) - if "nvidia" in llm_params["model"]: - from langchain_nvidia_ai_endpoints import ChatNVIDIA + if llm_params["model_provider"] == "nvidia": + try: + from langchain_nvidia_ai_endpoints import ChatNVIDIA + except ImportError: + raise ImportError("The langchain_nvidia_ai_endpoints module is not installed. Please install it using `pip install langchain_nvidia_ai_endpoints`.") return ChatNVIDIA(**llm_params) except Exception as e: diff --git a/tests/graphs/abstract_graph_test.py b/tests/graphs/abstract_graph_test.py index 805a1691..f52c9b32 100644 --- a/tests/graphs/abstract_graph_test.py +++ b/tests/graphs/abstract_graph_test.py @@ -3,29 +3,80 @@ Tests for the AbstractGraph. """ import pytest from unittest.mock import patch -from scrapegraphai.graphs import AbstractGraph +from scrapegraphai.graphs import AbstractGraph, BaseGraph +from scrapegraphai.nodes import ( + FetchNode, + ParseNode +) +from scrapegraphai.models import OneApi, DeepSeek +from langchain_openai import ChatOpenAI, AzureChatOpenAI +from langchain_community.chat_models import ChatOllama +from langchain_google_genai import ChatGoogleGenerativeAI + + + +class TestGraph(AbstractGraph): + def __init__(self, prompt: str, config: dict): + super().__init__(prompt, config) + + def _create_graph(self) -> BaseGraph: + fetch_node = FetchNode( + input="url| local_dir", + output=["doc", "link_urls", "img_urls"], + node_config={ + "llm_model": self.llm_model, + "force": self.config.get("force", False), + "cut": self.config.get("cut", True), + "loader_kwargs": self.config.get("loader_kwargs", {}), + "browser_base": self.config.get("browser_base") + } + ) + parse_node = ParseNode( + input="doc", + output=["parsed_doc"], + node_config={ + "chunk_size": self.model_token + } + ) + return BaseGraph( + nodes=[ + fetch_node, + parse_node + ], + edges=[ + (fetch_node, parse_node), + ], + entry_point=fetch_node, + graph_name=self.__class__.__name__ + ) + + def run(self) -> str: + inputs = {"user_prompt": self.prompt, self.input_key: self.source} + self.final_state, self.execution_info = self.graph.execute(inputs) + + return self.final_state.get("answer", "No answer found.") + class TestAbstractGraph: @pytest.mark.parametrize("llm_config, expected_model", [ - ({"model": "openai/gpt-3.5-turbo"}, "ChatOpenAI"), - ({"model": "azure_openai/gpt-3.5-turbo"}, "AzureChatOpenAI"), - ({"model": "google_genai/gemini-pro"}, "ChatGoogleGenerativeAI"), - ({"model": "google_vertexai/chat-bison"}, "ChatVertexAI"), - ({"model": "ollama/llama2"}, "Ollama"), - ({"model": "oneapi/text-davinci-003"}, "OneApi"), - ({"model": "nvidia/clara-instant-1-base"}, "ChatNVIDIA"), - ({"model": "deepseek/deepseek-coder-6.7b-instruct"}, "DeepSeek"), - ({"model": "ernie/ernie-bot"}, "ErnieBotChat"), + ({"model": "openai/gpt-3.5-turbo", "openai_api_key": "sk-randomtest001"}, ChatOpenAI), + ({ + "model": "azure_openai/gpt-3.5-turbo", + "api_key": "random-api-key", + "api_version": "no version", + "azure_endpoint": "https://www.example.com/"}, + AzureChatOpenAI), + ({"model": "google_genai/gemini-pro", "google_api_key": "google-key-test"}, ChatGoogleGenerativeAI), + ({"model": "ollama/llama2"}, ChatOllama), + ({"model": "oneapi/qwen-turbo"}, OneApi), + ({"model": "deepseek/deepseek-coder"}, DeepSeek), ]) + def test_create_llm(self, llm_config, expected_model): - graph = AbstractGraph("Test prompt", {"llm": llm_config}) + graph = TestGraph("Test prompt", {"llm": llm_config}) assert isinstance(graph.llm_model, expected_model) def test_create_llm_unknown_provider(self): with pytest.raises(ValueError): - AbstractGraph("Test prompt", {"llm": {"model": "unknown_provider/model"}}) + TestGraph("Test prompt", {"llm": {"model": "unknown_provider/model"}}) - def test_create_llm_error(self): - with patch("your_module.init_chat_model", side_effect=Exception("Test error")): - with pytest.raises(Exception): - AbstractGraph("Test prompt", {"llm": {"model": "openai/gpt-3.5-turbo"}}) From f7a85c266ae758cc16297ebc5d98f8919a80c523 Mon Sep 17 00:00:00 2001 From: Federico Aguzzi <62149513+f-aguzzi@users.noreply.github.com> Date: Wed, 28 Aug 2024 12:51:50 +0200 Subject: [PATCH 2/2] fix(models): better DeepSeek and OneApi integration --- examples/deepseek/csv_scraper_deepseek.py | 3 +-- examples/deepseek/csv_scraper_graph_multi_deepseek.py | 3 +-- examples/deepseek/json_scraper_deepseek.py | 3 +-- examples/deepseek/json_scraper_multi_deepseek.py | 3 +-- examples/deepseek/pdf_scraper_graph_deepseek.py | 3 +-- examples/deepseek/pdf_scraper_multi_deepseek.py | 3 +-- examples/deepseek/scrape_plain_text_deepseek.py | 3 +-- examples/deepseek/script_generator_deepseek.py | 3 +-- examples/deepseek/script_multi_generator_deepseek.py | 3 +-- examples/deepseek/search_graph_deepseek.py | 3 +-- examples/deepseek/search_graph_schema_deepseek.py | 3 +-- examples/deepseek/search_link_graph_deepseek.py | 3 +-- examples/deepseek/smart_scraper_deepseek.py | 3 +-- examples/deepseek/smart_scraper_multi_deepseek.py | 3 +-- examples/deepseek/smart_scraper_schema_deepseek.py | 3 +-- examples/deepseek/xml_scraper_deepseek.py | 3 +-- examples/deepseek/xml_scraper_graph_multi_deepseek.py | 3 +-- examples/oneapi/custom_graph_oneapi.py | 2 +- scrapegraphai/models/deepseek.py | 6 +++++- scrapegraphai/models/oneapi.py | 4 +++- tests/graphs/abstract_graph_test.py | 4 ++-- 21 files changed, 28 insertions(+), 39 deletions(-) diff --git a/examples/deepseek/csv_scraper_deepseek.py b/examples/deepseek/csv_scraper_deepseek.py index 60b1c394..26ff26ee 100644 --- a/examples/deepseek/csv_scraper_deepseek.py +++ b/examples/deepseek/csv_scraper_deepseek.py @@ -28,8 +28,7 @@ deepseek_key = os.getenv("DEEPSEEK_APIKEY") graph_config = { "llm": { "model": "deepseek/deepseek-chat", - "openai_api_key": deepseek_key, - "openai_api_base": 'https://api.deepseek.com/v1', + "api_key": deepseek_key, }, "verbose": True, } diff --git a/examples/deepseek/csv_scraper_graph_multi_deepseek.py b/examples/deepseek/csv_scraper_graph_multi_deepseek.py index 0a08f83f..88056648 100644 --- a/examples/deepseek/csv_scraper_graph_multi_deepseek.py +++ b/examples/deepseek/csv_scraper_graph_multi_deepseek.py @@ -28,8 +28,7 @@ deepseek_key = os.getenv("DEEPSEEK_APIKEY") graph_config = { "llm": { "model": "deepseek/deepseek-chat", - "openai_api_key": deepseek_key, - "openai_api_base": 'https://api.deepseek.com/v1', + "api_key": deepseek_key, }, "verbose": True, } diff --git a/examples/deepseek/json_scraper_deepseek.py b/examples/deepseek/json_scraper_deepseek.py index 02991c0d..5d8bf152 100644 --- a/examples/deepseek/json_scraper_deepseek.py +++ b/examples/deepseek/json_scraper_deepseek.py @@ -27,8 +27,7 @@ with open(file_path, 'r', encoding="utf-8") as file: graph_config = { "llm": { "model": "deepseek/deepseek-chat", - "openai_api_key": deepseek_key, - "openai_api_base": 'https://api.deepseek.com/v1', + "api_key": deepseek_key, }, "verbose": True, } diff --git a/examples/deepseek/json_scraper_multi_deepseek.py b/examples/deepseek/json_scraper_multi_deepseek.py index 4f9ca32d..893937cd 100644 --- a/examples/deepseek/json_scraper_multi_deepseek.py +++ b/examples/deepseek/json_scraper_multi_deepseek.py @@ -13,8 +13,7 @@ deepseek_key = os.getenv("DEEPSEEK_APIKEY") graph_config = { "llm": { "model": "deepseek/deepseek-chat", - "openai_api_key": deepseek_key, - "openai_api_base": 'https://api.deepseek.com/v1', + "api_key": deepseek_key, }, "verbose": True, } diff --git a/examples/deepseek/pdf_scraper_graph_deepseek.py b/examples/deepseek/pdf_scraper_graph_deepseek.py index c9c5e0b2..990e7369 100644 --- a/examples/deepseek/pdf_scraper_graph_deepseek.py +++ b/examples/deepseek/pdf_scraper_graph_deepseek.py @@ -18,8 +18,7 @@ deepseek_key = os.getenv("DEEPSEEK_APIKEY") graph_config = { "llm": { "model": "deepseek/deepseek-chat", - "openai_api_key": deepseek_key, - "openai_api_base": 'https://api.deepseek.com/v1', + "api_key": deepseek_key, }, "verbose": True, } diff --git a/examples/deepseek/pdf_scraper_multi_deepseek.py b/examples/deepseek/pdf_scraper_multi_deepseek.py index e43dd10a..59727a62 100644 --- a/examples/deepseek/pdf_scraper_multi_deepseek.py +++ b/examples/deepseek/pdf_scraper_multi_deepseek.py @@ -13,8 +13,7 @@ deepseek_key = os.getenv("DEEPSEEK_APIKEY") graph_config = { "llm": { "model": "deepseek/deepseek-chat", - "openai_api_key": deepseek_key, - "openai_api_base": 'https://api.deepseek.com/v1', + "api_key": deepseek_key, }, "verbose": True, } diff --git a/examples/deepseek/scrape_plain_text_deepseek.py b/examples/deepseek/scrape_plain_text_deepseek.py index a7834a8f..52128737 100644 --- a/examples/deepseek/scrape_plain_text_deepseek.py +++ b/examples/deepseek/scrape_plain_text_deepseek.py @@ -29,8 +29,7 @@ deepseek_key = os.getenv("DEEPSEEK_APIKEY") graph_config = { "llm": { "model": "deepseek/deepseek-chat", - "openai_api_key": deepseek_key, - "openai_api_base": 'https://api.deepseek.com/v1', + "api_key": deepseek_key, }, "verbose": True, } diff --git a/examples/deepseek/script_generator_deepseek.py b/examples/deepseek/script_generator_deepseek.py index 3de06f25..eaec5232 100644 --- a/examples/deepseek/script_generator_deepseek.py +++ b/examples/deepseek/script_generator_deepseek.py @@ -18,8 +18,7 @@ deepseek_key = os.getenv("DEEPSEEK_APIKEY") graph_config = { "llm": { "model": "deepseek/deepseek-chat", - "openai_api_key": deepseek_key, - "openai_api_base": 'https://api.deepseek.com/v1', + "api_key": deepseek_key, }, "library": "beautifulsoup" } diff --git a/examples/deepseek/script_multi_generator_deepseek.py b/examples/deepseek/script_multi_generator_deepseek.py index cc577ecd..150298ed 100644 --- a/examples/deepseek/script_multi_generator_deepseek.py +++ b/examples/deepseek/script_multi_generator_deepseek.py @@ -18,8 +18,7 @@ deepseek_key = os.getenv("DEEPSEEK_APIKEY") graph_config = { "llm": { "model": "deepseek/deepseek-chat", - "openai_api_key": deepseek_key, - "openai_api_base": 'https://api.deepseek.com/v1', + "api_key": deepseek_key, }, "library": "beautifulsoup" } diff --git a/examples/deepseek/search_graph_deepseek.py b/examples/deepseek/search_graph_deepseek.py index 54d2e9fa..e7c2483c 100644 --- a/examples/deepseek/search_graph_deepseek.py +++ b/examples/deepseek/search_graph_deepseek.py @@ -16,8 +16,7 @@ deepseek_key = os.getenv("DEEPSEEK_APIKEY") graph_config = { "llm": { "model": "deepseek/deepseek-chat", - "openai_api_key": deepseek_key, - "openai_api_base": 'https://api.deepseek.com/v1', + "api_key": deepseek_key, }, "max_results": 2, "verbose": True, diff --git a/examples/deepseek/search_graph_schema_deepseek.py b/examples/deepseek/search_graph_schema_deepseek.py index bcebe76d..1471ede1 100644 --- a/examples/deepseek/search_graph_schema_deepseek.py +++ b/examples/deepseek/search_graph_schema_deepseek.py @@ -32,8 +32,7 @@ deepseek_key = os.getenv("DEEPSEEK_APIKEY") graph_config = { "llm": { "model": "deepseek/deepseek-chat", - "openai_api_key": deepseek_key, - "openai_api_base": 'https://api.deepseek.com/v1', + "api_key": deepseek_key, }, "verbose": True, } diff --git a/examples/deepseek/search_link_graph_deepseek.py b/examples/deepseek/search_link_graph_deepseek.py index 96f886a9..dac13737 100644 --- a/examples/deepseek/search_link_graph_deepseek.py +++ b/examples/deepseek/search_link_graph_deepseek.py @@ -17,8 +17,7 @@ deepseek_key = os.getenv("DEEPSEEK_APIKEY") graph_config = { "llm": { "model": "deepseek/deepseek-chat", - "openai_api_key": deepseek_key, - "openai_api_base": 'https://api.deepseek.com/v1', + "api_key": deepseek_key, }, "verbose": True, } diff --git a/examples/deepseek/smart_scraper_deepseek.py b/examples/deepseek/smart_scraper_deepseek.py index 50314819..4c49b160 100644 --- a/examples/deepseek/smart_scraper_deepseek.py +++ b/examples/deepseek/smart_scraper_deepseek.py @@ -19,8 +19,7 @@ deepseek_key = os.getenv("DEEPSEEK_APIKEY") graph_config = { "llm": { "model": "deepseek/deepseek-chat", - "openai_api_key": deepseek_key, - "openai_api_base": 'https://api.deepseek.com/v1', + "api_key": deepseek_key, }, "verbose": True, } diff --git a/examples/deepseek/smart_scraper_multi_deepseek.py b/examples/deepseek/smart_scraper_multi_deepseek.py index 374cc6e2..2ef062de 100644 --- a/examples/deepseek/smart_scraper_multi_deepseek.py +++ b/examples/deepseek/smart_scraper_multi_deepseek.py @@ -17,8 +17,7 @@ deepseek_key = os.getenv("DEEPSEEK_APIKEY") graph_config = { "llm": { "model": "deepseek/deepseek-chat", - "openai_api_key": deepseek_key, - "openai_api_base": 'https://api.deepseek.com/v1', + "api_key": deepseek_key, }, "verbose": True, } diff --git a/examples/deepseek/smart_scraper_schema_deepseek.py b/examples/deepseek/smart_scraper_schema_deepseek.py index 6d164eb1..722e02bf 100644 --- a/examples/deepseek/smart_scraper_schema_deepseek.py +++ b/examples/deepseek/smart_scraper_schema_deepseek.py @@ -31,8 +31,7 @@ deepseek_key = os.getenv("DEEPSEEK_APIKEY") graph_config = { "llm": { "model": "deepseek/deepseek-chat", - "openai_api_key": deepseek_key, - "openai_api_base": 'https://api.deepseek.com/v1', + "api_key": deepseek_key, }, "verbose": True, } diff --git a/examples/deepseek/xml_scraper_deepseek.py b/examples/deepseek/xml_scraper_deepseek.py index d69665f4..02178c4b 100644 --- a/examples/deepseek/xml_scraper_deepseek.py +++ b/examples/deepseek/xml_scraper_deepseek.py @@ -29,8 +29,7 @@ deepseek_key = os.getenv("DEEPSEEK_APIKEY") graph_config = { "llm": { "model": "deepseek/deepseek-chat", - "openai_api_key": deepseek_key, - "openai_api_base": 'https://api.deepseek.com/v1', + "api_key": deepseek_key, }, "verbose": True, } diff --git a/examples/deepseek/xml_scraper_graph_multi_deepseek.py b/examples/deepseek/xml_scraper_graph_multi_deepseek.py index 5098c9fd..ae74ba21 100644 --- a/examples/deepseek/xml_scraper_graph_multi_deepseek.py +++ b/examples/deepseek/xml_scraper_graph_multi_deepseek.py @@ -28,8 +28,7 @@ deepseek_key = os.getenv("DEEPSEEK_APIKEY") graph_config = { "llm": { "model": "deepseek/deepseek-chat", - "openai_api_key": deepseek_key, - "openai_api_base": 'https://api.deepseek.com/v1', + "api_key": deepseek_key, }, "verbose": True, } diff --git a/examples/oneapi/custom_graph_oneapi.py b/examples/oneapi/custom_graph_oneapi.py index 5777ab33..be58d1d1 100644 --- a/examples/oneapi/custom_graph_oneapi.py +++ b/examples/oneapi/custom_graph_oneapi.py @@ -22,7 +22,7 @@ graph_config = { # Define the graph nodes # ************************************************ -llm_model = OpenAI(graph_config["llm"]) +llm_model = ChatOpenAI(graph_config["llm"]) embedder = OpenAIEmbeddings(api_key=llm_model.openai_api_key) # define the nodes for the graph diff --git a/scrapegraphai/models/deepseek.py b/scrapegraphai/models/deepseek.py index 523fe667..31b2bd5d 100644 --- a/scrapegraphai/models/deepseek.py +++ b/scrapegraphai/models/deepseek.py @@ -14,5 +14,9 @@ class DeepSeek(ChatOpenAI): llm_config (dict): Configuration parameters for the language model. """ - def __init__(self, llm_config: dict): + def __init__(self, **llm_config): + if 'api_key' in llm_config: + llm_config['openai_api_key'] = llm_config.pop('api_key') + llm_config['openai_api_base'] = 'https://api.deepseek.com/v1' + super().__init__(**llm_config) diff --git a/scrapegraphai/models/oneapi.py b/scrapegraphai/models/oneapi.py index 54e846d9..9b20621b 100644 --- a/scrapegraphai/models/oneapi.py +++ b/scrapegraphai/models/oneapi.py @@ -13,5 +13,7 @@ class OneApi(ChatOpenAI): llm_config (dict): Configuration parameters for the language model. """ - def __init__(self, llm_config: dict): + def __init__(self, **llm_config): + if 'api_key' in llm_config: + llm_config['openai_api_key'] = llm_config.pop('api_key') super().__init__(**llm_config) diff --git a/tests/graphs/abstract_graph_test.py b/tests/graphs/abstract_graph_test.py index f52c9b32..60c8ab4c 100644 --- a/tests/graphs/abstract_graph_test.py +++ b/tests/graphs/abstract_graph_test.py @@ -68,8 +68,8 @@ class TestAbstractGraph: AzureChatOpenAI), ({"model": "google_genai/gemini-pro", "google_api_key": "google-key-test"}, ChatGoogleGenerativeAI), ({"model": "ollama/llama2"}, ChatOllama), - ({"model": "oneapi/qwen-turbo"}, OneApi), - ({"model": "deepseek/deepseek-coder"}, DeepSeek), + ({"model": "oneapi/qwen-turbo", "api_key": "oneapi-api-key"}, OneApi), + ({"model": "deepseek/deepseek-coder", "api_key": "deepseek-api-key"}, DeepSeek), ]) def test_create_llm(self, llm_config, expected_model):