mirror of
https://github.com/VinciGit00/Scrapegraph-ai.git
synced 2026-07-01 21:00:48 +08:00
fix(AbstractGraph): model selection bug
This commit is contained in:
parent
4eccc76442
commit
4f120e29c5
@ -131,15 +131,15 @@ class AbstractGraph(ABC):
|
||||
"ollama", "oneapi", "nvidia", "groq", "anthropic" "bedrock", "mistralai",
|
||||
"hugging_face", "deepseek", "ernie", "fireworks"}
|
||||
|
||||
split_model_provider = llm_params["model"].split("/")
|
||||
split_model_provider = llm_params["model"].split("/", 1)
|
||||
llm_params["model_provider"] = split_model_provider[0]
|
||||
llm_params["model"] = split_model_provider[1:]
|
||||
llm_params["model"] = split_model_provider[1]
|
||||
|
||||
if llm_params["model_provider"] not in known_providers:
|
||||
raise ValueError(f"Provider {llm_params['model_provider']} is not supported. If possible, try to use a model instance instead.")
|
||||
|
||||
try:
|
||||
self.model_token = models_tokens[llm_params["model_provider"]].get(llm_params["model"][0])
|
||||
self.model_token = models_tokens[llm_params["model_provider"]][llm_params["model"]]
|
||||
except KeyError:
|
||||
print("Model not found, using default token size (8192)")
|
||||
self.model_token = 8192
|
||||
@ -150,18 +150,21 @@ class AbstractGraph(ABC):
|
||||
warnings.simplefilter("ignore")
|
||||
return init_chat_model(**llm_params)
|
||||
else:
|
||||
if "deepseek" in llm_params["model"]:
|
||||
if llm_params["model_provider"] == "deepseek":
|
||||
return DeepSeek(**llm_params)
|
||||
|
||||
if "ernie" in llm_params["model"]:
|
||||
if llm_params["model_provider"] == "ernie":
|
||||
from langchain_community.chat_models import ErnieBotChat
|
||||
return ErnieBotChat(**llm_params)
|
||||
|
||||
if "oneapi" in llm_params["model"]:
|
||||
if llm_params["model_provider"] == "oneapi":
|
||||
return OneApi(**llm_params)
|
||||
|
||||
if "nvidia" in llm_params["model"]:
|
||||
from langchain_nvidia_ai_endpoints import ChatNVIDIA
|
||||
if llm_params["model_provider"] == "nvidia":
|
||||
try:
|
||||
from langchain_nvidia_ai_endpoints import ChatNVIDIA
|
||||
except ImportError:
|
||||
raise ImportError("The langchain_nvidia_ai_endpoints module is not installed. Please install it using `pip install langchain_nvidia_ai_endpoints`.")
|
||||
return ChatNVIDIA(**llm_params)
|
||||
|
||||
except Exception as e:
|
||||
|
||||
@ -3,29 +3,80 @@ Tests for the AbstractGraph.
|
||||
"""
|
||||
import pytest
|
||||
from unittest.mock import patch
|
||||
from scrapegraphai.graphs import AbstractGraph
|
||||
from scrapegraphai.graphs import AbstractGraph, BaseGraph
|
||||
from scrapegraphai.nodes import (
|
||||
FetchNode,
|
||||
ParseNode
|
||||
)
|
||||
from scrapegraphai.models import OneApi, DeepSeek
|
||||
from langchain_openai import ChatOpenAI, AzureChatOpenAI
|
||||
from langchain_community.chat_models import ChatOllama
|
||||
from langchain_google_genai import ChatGoogleGenerativeAI
|
||||
|
||||
|
||||
|
||||
class TestGraph(AbstractGraph):
|
||||
def __init__(self, prompt: str, config: dict):
|
||||
super().__init__(prompt, config)
|
||||
|
||||
def _create_graph(self) -> BaseGraph:
|
||||
fetch_node = FetchNode(
|
||||
input="url| local_dir",
|
||||
output=["doc", "link_urls", "img_urls"],
|
||||
node_config={
|
||||
"llm_model": self.llm_model,
|
||||
"force": self.config.get("force", False),
|
||||
"cut": self.config.get("cut", True),
|
||||
"loader_kwargs": self.config.get("loader_kwargs", {}),
|
||||
"browser_base": self.config.get("browser_base")
|
||||
}
|
||||
)
|
||||
parse_node = ParseNode(
|
||||
input="doc",
|
||||
output=["parsed_doc"],
|
||||
node_config={
|
||||
"chunk_size": self.model_token
|
||||
}
|
||||
)
|
||||
return BaseGraph(
|
||||
nodes=[
|
||||
fetch_node,
|
||||
parse_node
|
||||
],
|
||||
edges=[
|
||||
(fetch_node, parse_node),
|
||||
],
|
||||
entry_point=fetch_node,
|
||||
graph_name=self.__class__.__name__
|
||||
)
|
||||
|
||||
def run(self) -> str:
|
||||
inputs = {"user_prompt": self.prompt, self.input_key: self.source}
|
||||
self.final_state, self.execution_info = self.graph.execute(inputs)
|
||||
|
||||
return self.final_state.get("answer", "No answer found.")
|
||||
|
||||
|
||||
class TestAbstractGraph:
|
||||
@pytest.mark.parametrize("llm_config, expected_model", [
|
||||
({"model": "openai/gpt-3.5-turbo"}, "ChatOpenAI"),
|
||||
({"model": "azure_openai/gpt-3.5-turbo"}, "AzureChatOpenAI"),
|
||||
({"model": "google_genai/gemini-pro"}, "ChatGoogleGenerativeAI"),
|
||||
({"model": "google_vertexai/chat-bison"}, "ChatVertexAI"),
|
||||
({"model": "ollama/llama2"}, "Ollama"),
|
||||
({"model": "oneapi/text-davinci-003"}, "OneApi"),
|
||||
({"model": "nvidia/clara-instant-1-base"}, "ChatNVIDIA"),
|
||||
({"model": "deepseek/deepseek-coder-6.7b-instruct"}, "DeepSeek"),
|
||||
({"model": "ernie/ernie-bot"}, "ErnieBotChat"),
|
||||
({"model": "openai/gpt-3.5-turbo", "openai_api_key": "sk-randomtest001"}, ChatOpenAI),
|
||||
({
|
||||
"model": "azure_openai/gpt-3.5-turbo",
|
||||
"api_key": "random-api-key",
|
||||
"api_version": "no version",
|
||||
"azure_endpoint": "https://www.example.com/"},
|
||||
AzureChatOpenAI),
|
||||
({"model": "google_genai/gemini-pro", "google_api_key": "google-key-test"}, ChatGoogleGenerativeAI),
|
||||
({"model": "ollama/llama2"}, ChatOllama),
|
||||
({"model": "oneapi/qwen-turbo"}, OneApi),
|
||||
({"model": "deepseek/deepseek-coder"}, DeepSeek),
|
||||
])
|
||||
|
||||
def test_create_llm(self, llm_config, expected_model):
|
||||
graph = AbstractGraph("Test prompt", {"llm": llm_config})
|
||||
graph = TestGraph("Test prompt", {"llm": llm_config})
|
||||
assert isinstance(graph.llm_model, expected_model)
|
||||
|
||||
def test_create_llm_unknown_provider(self):
|
||||
with pytest.raises(ValueError):
|
||||
AbstractGraph("Test prompt", {"llm": {"model": "unknown_provider/model"}})
|
||||
TestGraph("Test prompt", {"llm": {"model": "unknown_provider/model"}})
|
||||
|
||||
def test_create_llm_error(self):
|
||||
with patch("your_module.init_chat_model", side_effect=Exception("Test error")):
|
||||
with pytest.raises(Exception):
|
||||
AbstractGraph("Test prompt", {"llm": {"model": "openai/gpt-3.5-turbo"}})
|
||||
|
||||
Loading…
Reference in New Issue
Block a user