mirror of
https://github.com/VinciGit00/Scrapegraph-ai.git
synced 2026-07-01 21:00:48 +08:00
fix(AbstractGraph): model selection bug
This commit is contained in:
parent
4eccc76442
commit
4f120e29c5
@ -131,15 +131,15 @@ class AbstractGraph(ABC):
|
|||||||
"ollama", "oneapi", "nvidia", "groq", "anthropic" "bedrock", "mistralai",
|
"ollama", "oneapi", "nvidia", "groq", "anthropic" "bedrock", "mistralai",
|
||||||
"hugging_face", "deepseek", "ernie", "fireworks"}
|
"hugging_face", "deepseek", "ernie", "fireworks"}
|
||||||
|
|
||||||
split_model_provider = llm_params["model"].split("/")
|
split_model_provider = llm_params["model"].split("/", 1)
|
||||||
llm_params["model_provider"] = split_model_provider[0]
|
llm_params["model_provider"] = split_model_provider[0]
|
||||||
llm_params["model"] = split_model_provider[1:]
|
llm_params["model"] = split_model_provider[1]
|
||||||
|
|
||||||
if llm_params["model_provider"] not in known_providers:
|
if llm_params["model_provider"] not in known_providers:
|
||||||
raise ValueError(f"Provider {llm_params['model_provider']} is not supported. If possible, try to use a model instance instead.")
|
raise ValueError(f"Provider {llm_params['model_provider']} is not supported. If possible, try to use a model instance instead.")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self.model_token = models_tokens[llm_params["model_provider"]].get(llm_params["model"][0])
|
self.model_token = models_tokens[llm_params["model_provider"]][llm_params["model"]]
|
||||||
except KeyError:
|
except KeyError:
|
||||||
print("Model not found, using default token size (8192)")
|
print("Model not found, using default token size (8192)")
|
||||||
self.model_token = 8192
|
self.model_token = 8192
|
||||||
@ -150,18 +150,21 @@ class AbstractGraph(ABC):
|
|||||||
warnings.simplefilter("ignore")
|
warnings.simplefilter("ignore")
|
||||||
return init_chat_model(**llm_params)
|
return init_chat_model(**llm_params)
|
||||||
else:
|
else:
|
||||||
if "deepseek" in llm_params["model"]:
|
if llm_params["model_provider"] == "deepseek":
|
||||||
return DeepSeek(**llm_params)
|
return DeepSeek(**llm_params)
|
||||||
|
|
||||||
if "ernie" in llm_params["model"]:
|
if llm_params["model_provider"] == "ernie":
|
||||||
from langchain_community.chat_models import ErnieBotChat
|
from langchain_community.chat_models import ErnieBotChat
|
||||||
return ErnieBotChat(**llm_params)
|
return ErnieBotChat(**llm_params)
|
||||||
|
|
||||||
if "oneapi" in llm_params["model"]:
|
if llm_params["model_provider"] == "oneapi":
|
||||||
return OneApi(**llm_params)
|
return OneApi(**llm_params)
|
||||||
|
|
||||||
if "nvidia" in llm_params["model"]:
|
if llm_params["model_provider"] == "nvidia":
|
||||||
from langchain_nvidia_ai_endpoints import ChatNVIDIA
|
try:
|
||||||
|
from langchain_nvidia_ai_endpoints import ChatNVIDIA
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError("The langchain_nvidia_ai_endpoints module is not installed. Please install it using `pip install langchain_nvidia_ai_endpoints`.")
|
||||||
return ChatNVIDIA(**llm_params)
|
return ChatNVIDIA(**llm_params)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@ -3,29 +3,80 @@ Tests for the AbstractGraph.
|
|||||||
"""
|
"""
|
||||||
import pytest
|
import pytest
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
from scrapegraphai.graphs import AbstractGraph
|
from scrapegraphai.graphs import AbstractGraph, BaseGraph
|
||||||
|
from scrapegraphai.nodes import (
|
||||||
|
FetchNode,
|
||||||
|
ParseNode
|
||||||
|
)
|
||||||
|
from scrapegraphai.models import OneApi, DeepSeek
|
||||||
|
from langchain_openai import ChatOpenAI, AzureChatOpenAI
|
||||||
|
from langchain_community.chat_models import ChatOllama
|
||||||
|
from langchain_google_genai import ChatGoogleGenerativeAI
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class TestGraph(AbstractGraph):
    """Minimal concrete subclass of ``AbstractGraph`` used by the tests below.

    It supplies the ``_create_graph`` and ``run`` implementations so the
    base-class constructor (and therefore its LLM creation logic) can be
    exercised through a real instance.

    Args:
        prompt: The user prompt forwarded to ``AbstractGraph.__init__``.
        config: The graph configuration forwarded to
            ``AbstractGraph.__init__``; the tests pass the LLM settings
            under the ``"llm"`` key.
    """

    # The class name matches pytest's default ``Test*`` collection pattern,
    # but this is a helper with a constructor, not a test case.  Opting out
    # of collection avoids the PytestCollectionWarning pytest would emit.
    __test__ = False

    def __init__(self, prompt: str, config: dict):
        super().__init__(prompt, config)

    def _create_graph(self) -> BaseGraph:
        """Build a two-node fetch -> parse graph.

        Returns:
            A ``BaseGraph`` whose entry point is the fetch node and whose
            single edge leads to the parse node.
        """
        fetch_node = FetchNode(
            input="url| local_dir",
            output=["doc", "link_urls", "img_urls"],
            node_config={
                "llm_model": self.llm_model,
                "force": self.config.get("force", False),
                "cut": self.config.get("cut", True),
                "loader_kwargs": self.config.get("loader_kwargs", {}),
                "browser_base": self.config.get("browser_base"),
            },
        )
        parse_node = ParseNode(
            input="doc",
            output=["parsed_doc"],
            node_config={
                # Chunk size follows the token limit stored on the instance.
                "chunk_size": self.model_token,
            },
        )
        return BaseGraph(
            nodes=[
                fetch_node,
                parse_node,
            ],
            edges=[
                (fetch_node, parse_node),
            ],
            entry_point=fetch_node,
            graph_name=self.__class__.__name__,
        )

    def run(self) -> str:
        """Execute the graph and return its answer.

        Stores the execution results on ``self.final_state`` and
        ``self.execution_info``.

        Returns:
            The ``"answer"`` entry of the final state, or
            ``"No answer found."`` when that key is absent.
        """
        inputs = {"user_prompt": self.prompt, self.input_key: self.source}
        self.final_state, self.execution_info = self.graph.execute(inputs)

        return self.final_state.get("answer", "No answer found.")
|
||||||
|
|
||||||
|
|
||||||
class TestAbstractGraph:
|
class TestAbstractGraph:
|
||||||
@pytest.mark.parametrize("llm_config, expected_model", [
|
@pytest.mark.parametrize("llm_config, expected_model", [
|
||||||
({"model": "openai/gpt-3.5-turbo"}, "ChatOpenAI"),
|
({"model": "openai/gpt-3.5-turbo", "openai_api_key": "sk-randomtest001"}, ChatOpenAI),
|
||||||
({"model": "azure_openai/gpt-3.5-turbo"}, "AzureChatOpenAI"),
|
({
|
||||||
({"model": "google_genai/gemini-pro"}, "ChatGoogleGenerativeAI"),
|
"model": "azure_openai/gpt-3.5-turbo",
|
||||||
({"model": "google_vertexai/chat-bison"}, "ChatVertexAI"),
|
"api_key": "random-api-key",
|
||||||
({"model": "ollama/llama2"}, "Ollama"),
|
"api_version": "no version",
|
||||||
({"model": "oneapi/text-davinci-003"}, "OneApi"),
|
"azure_endpoint": "https://www.example.com/"},
|
||||||
({"model": "nvidia/clara-instant-1-base"}, "ChatNVIDIA"),
|
AzureChatOpenAI),
|
||||||
({"model": "deepseek/deepseek-coder-6.7b-instruct"}, "DeepSeek"),
|
({"model": "google_genai/gemini-pro", "google_api_key": "google-key-test"}, ChatGoogleGenerativeAI),
|
||||||
({"model": "ernie/ernie-bot"}, "ErnieBotChat"),
|
({"model": "ollama/llama2"}, ChatOllama),
|
||||||
|
({"model": "oneapi/qwen-turbo"}, OneApi),
|
||||||
|
({"model": "deepseek/deepseek-coder"}, DeepSeek),
|
||||||
])
|
])
|
||||||
|
|
||||||
def test_create_llm(self, llm_config, expected_model):
|
def test_create_llm(self, llm_config, expected_model):
|
||||||
graph = AbstractGraph("Test prompt", {"llm": llm_config})
|
graph = TestGraph("Test prompt", {"llm": llm_config})
|
||||||
assert isinstance(graph.llm_model, expected_model)
|
assert isinstance(graph.llm_model, expected_model)
|
||||||
|
|
||||||
def test_create_llm_unknown_provider(self):
|
def test_create_llm_unknown_provider(self):
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
AbstractGraph("Test prompt", {"llm": {"model": "unknown_provider/model"}})
|
TestGraph("Test prompt", {"llm": {"model": "unknown_provider/model"}})
|
||||||
|
|
||||||
def test_create_llm_error(self):
|
|
||||||
with patch("your_module.init_chat_model", side_effect=Exception("Test error")):
|
|
||||||
with pytest.raises(Exception):
|
|
||||||
AbstractGraph("Test prompt", {"llm": {"model": "openai/gpt-3.5-turbo"}})
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user