mirror of
https://github.com/VinciGit00/Scrapegraph-ai.git
synced 2026-07-01 21:00:48 +08:00
Merge pull request #602 from ScrapeGraphAI/593-abstract-graph-fix-round-4
Abstract graph fix
This commit is contained in:
commit
08fa257407
@ -28,8 +28,7 @@ deepseek_key = os.getenv("DEEPSEEK_APIKEY")
|
||||
graph_config = {
|
||||
"llm": {
|
||||
"model": "deepseek/deepseek-chat",
|
||||
"openai_api_key": deepseek_key,
|
||||
"openai_api_base": 'https://api.deepseek.com/v1',
|
||||
"api_key": deepseek_key,
|
||||
},
|
||||
"verbose": True,
|
||||
}
|
||||
|
||||
@ -28,8 +28,7 @@ deepseek_key = os.getenv("DEEPSEEK_APIKEY")
|
||||
graph_config = {
|
||||
"llm": {
|
||||
"model": "deepseek/deepseek-chat",
|
||||
"openai_api_key": deepseek_key,
|
||||
"openai_api_base": 'https://api.deepseek.com/v1',
|
||||
"api_key": deepseek_key,
|
||||
},
|
||||
"verbose": True,
|
||||
}
|
||||
|
||||
@ -27,8 +27,7 @@ with open(file_path, 'r', encoding="utf-8") as file:
|
||||
graph_config = {
|
||||
"llm": {
|
||||
"model": "deepseek/deepseek-chat",
|
||||
"openai_api_key": deepseek_key,
|
||||
"openai_api_base": 'https://api.deepseek.com/v1',
|
||||
"api_key": deepseek_key,
|
||||
},
|
||||
"verbose": True,
|
||||
}
|
||||
|
||||
@ -13,8 +13,7 @@ deepseek_key = os.getenv("DEEPSEEK_APIKEY")
|
||||
graph_config = {
|
||||
"llm": {
|
||||
"model": "deepseek/deepseek-chat",
|
||||
"openai_api_key": deepseek_key,
|
||||
"openai_api_base": 'https://api.deepseek.com/v1',
|
||||
"api_key": deepseek_key,
|
||||
},
|
||||
"verbose": True,
|
||||
}
|
||||
|
||||
@ -18,8 +18,7 @@ deepseek_key = os.getenv("DEEPSEEK_APIKEY")
|
||||
graph_config = {
|
||||
"llm": {
|
||||
"model": "deepseek/deepseek-chat",
|
||||
"openai_api_key": deepseek_key,
|
||||
"openai_api_base": 'https://api.deepseek.com/v1',
|
||||
"api_key": deepseek_key,
|
||||
},
|
||||
"verbose": True,
|
||||
}
|
||||
|
||||
@ -13,8 +13,7 @@ deepseek_key = os.getenv("DEEPSEEK_APIKEY")
|
||||
graph_config = {
|
||||
"llm": {
|
||||
"model": "deepseek/deepseek-chat",
|
||||
"openai_api_key": deepseek_key,
|
||||
"openai_api_base": 'https://api.deepseek.com/v1',
|
||||
"api_key": deepseek_key,
|
||||
},
|
||||
"verbose": True,
|
||||
}
|
||||
|
||||
@ -29,8 +29,7 @@ deepseek_key = os.getenv("DEEPSEEK_APIKEY")
|
||||
graph_config = {
|
||||
"llm": {
|
||||
"model": "deepseek/deepseek-chat",
|
||||
"openai_api_key": deepseek_key,
|
||||
"openai_api_base": 'https://api.deepseek.com/v1',
|
||||
"api_key": deepseek_key,
|
||||
},
|
||||
"verbose": True,
|
||||
}
|
||||
|
||||
@ -18,8 +18,7 @@ deepseek_key = os.getenv("DEEPSEEK_APIKEY")
|
||||
graph_config = {
|
||||
"llm": {
|
||||
"model": "deepseek/deepseek-chat",
|
||||
"openai_api_key": deepseek_key,
|
||||
"openai_api_base": 'https://api.deepseek.com/v1',
|
||||
"api_key": deepseek_key,
|
||||
},
|
||||
"library": "beautifulsoup"
|
||||
}
|
||||
|
||||
@ -18,8 +18,7 @@ deepseek_key = os.getenv("DEEPSEEK_APIKEY")
|
||||
graph_config = {
|
||||
"llm": {
|
||||
"model": "deepseek/deepseek-chat",
|
||||
"openai_api_key": deepseek_key,
|
||||
"openai_api_base": 'https://api.deepseek.com/v1',
|
||||
"api_key": deepseek_key,
|
||||
},
|
||||
"library": "beautifulsoup"
|
||||
}
|
||||
|
||||
@ -16,8 +16,7 @@ deepseek_key = os.getenv("DEEPSEEK_APIKEY")
|
||||
graph_config = {
|
||||
"llm": {
|
||||
"model": "deepseek/deepseek-chat",
|
||||
"openai_api_key": deepseek_key,
|
||||
"openai_api_base": 'https://api.deepseek.com/v1',
|
||||
"api_key": deepseek_key,
|
||||
},
|
||||
"max_results": 2,
|
||||
"verbose": True,
|
||||
|
||||
@ -32,8 +32,7 @@ deepseek_key = os.getenv("DEEPSEEK_APIKEY")
|
||||
graph_config = {
|
||||
"llm": {
|
||||
"model": "deepseek/deepseek-chat",
|
||||
"openai_api_key": deepseek_key,
|
||||
"openai_api_base": 'https://api.deepseek.com/v1',
|
||||
"api_key": deepseek_key,
|
||||
},
|
||||
"verbose": True,
|
||||
}
|
||||
|
||||
@ -17,8 +17,7 @@ deepseek_key = os.getenv("DEEPSEEK_APIKEY")
|
||||
graph_config = {
|
||||
"llm": {
|
||||
"model": "deepseek/deepseek-chat",
|
||||
"openai_api_key": deepseek_key,
|
||||
"openai_api_base": 'https://api.deepseek.com/v1',
|
||||
"api_key": deepseek_key,
|
||||
},
|
||||
"verbose": True,
|
||||
}
|
||||
|
||||
@ -19,8 +19,7 @@ deepseek_key = os.getenv("DEEPSEEK_APIKEY")
|
||||
graph_config = {
|
||||
"llm": {
|
||||
"model": "deepseek/deepseek-chat",
|
||||
"openai_api_key": deepseek_key,
|
||||
"openai_api_base": 'https://api.deepseek.com/v1',
|
||||
"api_key": deepseek_key,
|
||||
},
|
||||
"verbose": True,
|
||||
}
|
||||
|
||||
@ -17,8 +17,7 @@ deepseek_key = os.getenv("DEEPSEEK_APIKEY")
|
||||
graph_config = {
|
||||
"llm": {
|
||||
"model": "deepseek/deepseek-chat",
|
||||
"openai_api_key": deepseek_key,
|
||||
"openai_api_base": 'https://api.deepseek.com/v1',
|
||||
"api_key": deepseek_key,
|
||||
},
|
||||
"verbose": True,
|
||||
}
|
||||
|
||||
@ -31,8 +31,7 @@ deepseek_key = os.getenv("DEEPSEEK_APIKEY")
|
||||
graph_config = {
|
||||
"llm": {
|
||||
"model": "deepseek/deepseek-chat",
|
||||
"openai_api_key": deepseek_key,
|
||||
"openai_api_base": 'https://api.deepseek.com/v1',
|
||||
"api_key": deepseek_key,
|
||||
},
|
||||
"verbose": True,
|
||||
}
|
||||
|
||||
@ -29,8 +29,7 @@ deepseek_key = os.getenv("DEEPSEEK_APIKEY")
|
||||
graph_config = {
|
||||
"llm": {
|
||||
"model": "deepseek/deepseek-chat",
|
||||
"openai_api_key": deepseek_key,
|
||||
"openai_api_base": 'https://api.deepseek.com/v1',
|
||||
"api_key": deepseek_key,
|
||||
},
|
||||
"verbose": True,
|
||||
}
|
||||
|
||||
@ -28,8 +28,7 @@ deepseek_key = os.getenv("DEEPSEEK_APIKEY")
|
||||
graph_config = {
|
||||
"llm": {
|
||||
"model": "deepseek/deepseek-chat",
|
||||
"openai_api_key": deepseek_key,
|
||||
"openai_api_base": 'https://api.deepseek.com/v1',
|
||||
"api_key": deepseek_key,
|
||||
},
|
||||
"verbose": True,
|
||||
}
|
||||
|
||||
@ -22,7 +22,7 @@ graph_config = {
|
||||
# Define the graph nodes
|
||||
# ************************************************
|
||||
|
||||
llm_model = OpenAI(graph_config["llm"])
|
||||
llm_model = ChatOpenAI(graph_config["llm"])
|
||||
embedder = OpenAIEmbeddings(api_key=llm_model.openai_api_key)
|
||||
|
||||
# define the nodes for the graph
|
||||
|
||||
@ -131,15 +131,15 @@ class AbstractGraph(ABC):
|
||||
"ollama", "oneapi", "nvidia", "groq", "anthropic" "bedrock", "mistralai",
|
||||
"hugging_face", "deepseek", "ernie", "fireworks"}
|
||||
|
||||
split_model_provider = llm_params["model"].split("/")
|
||||
split_model_provider = llm_params["model"].split("/", 1)
|
||||
llm_params["model_provider"] = split_model_provider[0]
|
||||
llm_params["model"] = split_model_provider[1:]
|
||||
llm_params["model"] = split_model_provider[1]
|
||||
|
||||
if llm_params["model_provider"] not in known_providers:
|
||||
raise ValueError(f"Provider {llm_params['model_provider']} is not supported. If possible, try to use a model instance instead.")
|
||||
|
||||
try:
|
||||
self.model_token = models_tokens[llm_params["model_provider"]].get(llm_params["model"][0])
|
||||
self.model_token = models_tokens[llm_params["model_provider"]][llm_params["model"]]
|
||||
except KeyError:
|
||||
print("Model not found, using default token size (8192)")
|
||||
self.model_token = 8192
|
||||
@ -150,18 +150,21 @@ class AbstractGraph(ABC):
|
||||
warnings.simplefilter("ignore")
|
||||
return init_chat_model(**llm_params)
|
||||
else:
|
||||
if "deepseek" in llm_params["model"]:
|
||||
if llm_params["model_provider"] == "deepseek":
|
||||
return DeepSeek(**llm_params)
|
||||
|
||||
if "ernie" in llm_params["model"]:
|
||||
if llm_params["model_provider"] == "ernie":
|
||||
from langchain_community.chat_models import ErnieBotChat
|
||||
return ErnieBotChat(**llm_params)
|
||||
|
||||
if "oneapi" in llm_params["model"]:
|
||||
if llm_params["model_provider"] == "oneapi":
|
||||
return OneApi(**llm_params)
|
||||
|
||||
if "nvidia" in llm_params["model"]:
|
||||
from langchain_nvidia_ai_endpoints import ChatNVIDIA
|
||||
if llm_params["model_provider"] == "nvidia":
|
||||
try:
|
||||
from langchain_nvidia_ai_endpoints import ChatNVIDIA
|
||||
except ImportError:
|
||||
raise ImportError("The langchain_nvidia_ai_endpoints module is not installed. Please install it using `pip install langchain_nvidia_ai_endpoints`.")
|
||||
return ChatNVIDIA(**llm_params)
|
||||
|
||||
except Exception as e:
|
||||
|
||||
@ -14,5 +14,9 @@ class DeepSeek(ChatOpenAI):
|
||||
llm_config (dict): Configuration parameters for the language model.
|
||||
"""
|
||||
|
||||
def __init__(self, llm_config: dict):
|
||||
def __init__(self, **llm_config):
|
||||
if 'api_key' in llm_config:
|
||||
llm_config['openai_api_key'] = llm_config.pop('api_key')
|
||||
llm_config['openai_api_base'] = 'https://api.deepseek.com/v1'
|
||||
|
||||
super().__init__(**llm_config)
|
||||
|
||||
@ -13,5 +13,7 @@ class OneApi(ChatOpenAI):
|
||||
llm_config (dict): Configuration parameters for the language model.
|
||||
"""
|
||||
|
||||
def __init__(self, llm_config: dict):
|
||||
def __init__(self, **llm_config):
|
||||
if 'api_key' in llm_config:
|
||||
llm_config['openai_api_key'] = llm_config.pop('api_key')
|
||||
super().__init__(**llm_config)
|
||||
|
||||
@ -3,29 +3,80 @@ Tests for the AbstractGraph.
|
||||
"""
|
||||
import pytest
|
||||
from unittest.mock import patch
|
||||
from scrapegraphai.graphs import AbstractGraph
|
||||
from scrapegraphai.graphs import AbstractGraph, BaseGraph
|
||||
from scrapegraphai.nodes import (
|
||||
FetchNode,
|
||||
ParseNode
|
||||
)
|
||||
from scrapegraphai.models import OneApi, DeepSeek
|
||||
from langchain_openai import ChatOpenAI, AzureChatOpenAI
|
||||
from langchain_community.chat_models import ChatOllama
|
||||
from langchain_google_genai import ChatGoogleGenerativeAI
|
||||
|
||||
|
||||
|
||||
class TestGraph(AbstractGraph):
|
||||
def __init__(self, prompt: str, config: dict):
|
||||
super().__init__(prompt, config)
|
||||
|
||||
def _create_graph(self) -> BaseGraph:
|
||||
fetch_node = FetchNode(
|
||||
input="url| local_dir",
|
||||
output=["doc", "link_urls", "img_urls"],
|
||||
node_config={
|
||||
"llm_model": self.llm_model,
|
||||
"force": self.config.get("force", False),
|
||||
"cut": self.config.get("cut", True),
|
||||
"loader_kwargs": self.config.get("loader_kwargs", {}),
|
||||
"browser_base": self.config.get("browser_base")
|
||||
}
|
||||
)
|
||||
parse_node = ParseNode(
|
||||
input="doc",
|
||||
output=["parsed_doc"],
|
||||
node_config={
|
||||
"chunk_size": self.model_token
|
||||
}
|
||||
)
|
||||
return BaseGraph(
|
||||
nodes=[
|
||||
fetch_node,
|
||||
parse_node
|
||||
],
|
||||
edges=[
|
||||
(fetch_node, parse_node),
|
||||
],
|
||||
entry_point=fetch_node,
|
||||
graph_name=self.__class__.__name__
|
||||
)
|
||||
|
||||
def run(self) -> str:
|
||||
inputs = {"user_prompt": self.prompt, self.input_key: self.source}
|
||||
self.final_state, self.execution_info = self.graph.execute(inputs)
|
||||
|
||||
return self.final_state.get("answer", "No answer found.")
|
||||
|
||||
|
||||
class TestAbstractGraph:
|
||||
@pytest.mark.parametrize("llm_config, expected_model", [
|
||||
({"model": "openai/gpt-3.5-turbo"}, "ChatOpenAI"),
|
||||
({"model": "azure_openai/gpt-3.5-turbo"}, "AzureChatOpenAI"),
|
||||
({"model": "google_genai/gemini-pro"}, "ChatGoogleGenerativeAI"),
|
||||
({"model": "google_vertexai/chat-bison"}, "ChatVertexAI"),
|
||||
({"model": "ollama/llama2"}, "Ollama"),
|
||||
({"model": "oneapi/text-davinci-003"}, "OneApi"),
|
||||
({"model": "nvidia/clara-instant-1-base"}, "ChatNVIDIA"),
|
||||
({"model": "deepseek/deepseek-coder-6.7b-instruct"}, "DeepSeek"),
|
||||
({"model": "ernie/ernie-bot"}, "ErnieBotChat"),
|
||||
({"model": "openai/gpt-3.5-turbo", "openai_api_key": "sk-randomtest001"}, ChatOpenAI),
|
||||
({
|
||||
"model": "azure_openai/gpt-3.5-turbo",
|
||||
"api_key": "random-api-key",
|
||||
"api_version": "no version",
|
||||
"azure_endpoint": "https://www.example.com/"},
|
||||
AzureChatOpenAI),
|
||||
({"model": "google_genai/gemini-pro", "google_api_key": "google-key-test"}, ChatGoogleGenerativeAI),
|
||||
({"model": "ollama/llama2"}, ChatOllama),
|
||||
({"model": "oneapi/qwen-turbo", "api_key": "oneapi-api-key"}, OneApi),
|
||||
({"model": "deepseek/deepseek-coder", "api_key": "deepseek-api-key"}, DeepSeek),
|
||||
])
|
||||
|
||||
def test_create_llm(self, llm_config, expected_model):
|
||||
graph = AbstractGraph("Test prompt", {"llm": llm_config})
|
||||
graph = TestGraph("Test prompt", {"llm": llm_config})
|
||||
assert isinstance(graph.llm_model, expected_model)
|
||||
|
||||
def test_create_llm_unknown_provider(self):
|
||||
with pytest.raises(ValueError):
|
||||
AbstractGraph("Test prompt", {"llm": {"model": "unknown_provider/model"}})
|
||||
TestGraph("Test prompt", {"llm": {"model": "unknown_provider/model"}})
|
||||
|
||||
def test_create_llm_error(self):
|
||||
with patch("your_module.init_chat_model", side_effect=Exception("Test error")):
|
||||
with pytest.raises(Exception):
|
||||
AbstractGraph("Test prompt", {"llm": {"model": "openai/gpt-3.5-turbo"}})
|
||||
|
||||
Loading…
Reference in New Issue
Block a user