fix(AbstractGraph): model selection bug

This commit is contained in:
Federico Aguzzi 2024-08-28 12:36:03 +02:00
parent 4eccc76442
commit 4f120e29c5
2 changed files with 78 additions and 24 deletions

View File

@ -131,15 +131,15 @@ class AbstractGraph(ABC):
"ollama", "oneapi", "nvidia", "groq", "anthropic" "bedrock", "mistralai", "ollama", "oneapi", "nvidia", "groq", "anthropic" "bedrock", "mistralai",
"hugging_face", "deepseek", "ernie", "fireworks"} "hugging_face", "deepseek", "ernie", "fireworks"}
split_model_provider = llm_params["model"].split("/") split_model_provider = llm_params["model"].split("/", 1)
llm_params["model_provider"] = split_model_provider[0] llm_params["model_provider"] = split_model_provider[0]
llm_params["model"] = split_model_provider[1:] llm_params["model"] = split_model_provider[1]
if llm_params["model_provider"] not in known_providers: if llm_params["model_provider"] not in known_providers:
raise ValueError(f"Provider {llm_params['model_provider']} is not supported. If possible, try to use a model instance instead.") raise ValueError(f"Provider {llm_params['model_provider']} is not supported. If possible, try to use a model instance instead.")
try: try:
self.model_token = models_tokens[llm_params["model_provider"]].get(llm_params["model"][0]) self.model_token = models_tokens[llm_params["model_provider"]][llm_params["model"]]
except KeyError: except KeyError:
print("Model not found, using default token size (8192)") print("Model not found, using default token size (8192)")
self.model_token = 8192 self.model_token = 8192
@ -150,18 +150,21 @@ class AbstractGraph(ABC):
warnings.simplefilter("ignore") warnings.simplefilter("ignore")
return init_chat_model(**llm_params) return init_chat_model(**llm_params)
else: else:
if "deepseek" in llm_params["model"]: if llm_params["model_provider"] == "deepseek":
return DeepSeek(**llm_params) return DeepSeek(**llm_params)
if "ernie" in llm_params["model"]: if llm_params["model_provider"] == "ernie":
from langchain_community.chat_models import ErnieBotChat from langchain_community.chat_models import ErnieBotChat
return ErnieBotChat(**llm_params) return ErnieBotChat(**llm_params)
if "oneapi" in llm_params["model"]: if llm_params["model_provider"] == "oneapi":
return OneApi(**llm_params) return OneApi(**llm_params)
if "nvidia" in llm_params["model"]: if llm_params["model_provider"] == "nvidia":
from langchain_nvidia_ai_endpoints import ChatNVIDIA try:
from langchain_nvidia_ai_endpoints import ChatNVIDIA
except ImportError:
raise ImportError("The langchain_nvidia_ai_endpoints module is not installed. Please install it using `pip install langchain_nvidia_ai_endpoints`.")
return ChatNVIDIA(**llm_params) return ChatNVIDIA(**llm_params)
except Exception as e: except Exception as e:

View File

@ -3,29 +3,80 @@ Tests for the AbstractGraph.
""" """
import pytest import pytest
from unittest.mock import patch from unittest.mock import patch
from scrapegraphai.graphs import AbstractGraph from scrapegraphai.graphs import AbstractGraph, BaseGraph
from scrapegraphai.nodes import (
FetchNode,
ParseNode
)
from scrapegraphai.models import OneApi, DeepSeek
from langchain_openai import ChatOpenAI, AzureChatOpenAI
from langchain_community.chat_models import ChatOllama
from langchain_google_genai import ChatGoogleGenerativeAI
class TestGraph(AbstractGraph):
    """Minimal concrete ``AbstractGraph`` used as a helper by the tests below.

    It wires a ``FetchNode`` into a ``ParseNode`` so that the base class can
    be instantiated and its LLM-creation logic exercised through
    ``graph.llm_model``.

    NOTE(review): ``self.llm_model``, ``self.model_token``, ``self.config``,
    ``self.prompt``, ``self.source``, ``self.input_key`` and ``self.graph``
    are presumably populated by ``AbstractGraph.__init__`` -- confirm against
    the base class.
    """

    # pytest treats every ``Test*``-named class in a test module as a test
    # container. This helper defines ``__init__``, so pytest cannot collect
    # it and reports a PytestCollectionWarning on each run. Opt out explicitly.
    __test__ = False

    def __init__(self, prompt: str, config: dict):
        """Forward ``prompt`` and ``config`` unchanged to ``AbstractGraph``."""
        super().__init__(prompt, config)

    def _create_graph(self) -> BaseGraph:
        """Build a two-node graph: ``FetchNode`` followed by ``ParseNode``."""
        fetch_node = FetchNode(
            input="url| local_dir",
            output=["doc", "link_urls", "img_urls"],
            node_config={
                "llm_model": self.llm_model,
                "force": self.config.get("force", False),
                "cut": self.config.get("cut", True),
                "loader_kwargs": self.config.get("loader_kwargs", {}),
                "browser_base": self.config.get("browser_base"),
            },
        )
        parse_node = ParseNode(
            input="doc",
            output=["parsed_doc"],
            node_config={
                # The chunk size is taken from the token limit stored on
                # the graph instance.
                "chunk_size": self.model_token,
            },
        )

        return BaseGraph(
            nodes=[
                fetch_node,
                parse_node,
            ],
            edges=[
                (fetch_node, parse_node),
            ],
            entry_point=fetch_node,
            graph_name=self.__class__.__name__,
        )

    def run(self) -> str:
        """Execute the graph and return its ``"answer"`` entry.

        The final state and execution info returned by ``graph.execute`` are
        stored on the instance. When the final state has no ``"answer"`` key
        the literal ``"No answer found."`` is returned instead.
        """
        inputs = {"user_prompt": self.prompt, self.input_key: self.source}
        self.final_state, self.execution_info = self.graph.execute(inputs)
        return self.final_state.get("answer", "No answer found.")
class TestAbstractGraph: class TestAbstractGraph:
@pytest.mark.parametrize("llm_config, expected_model", [ @pytest.mark.parametrize("llm_config, expected_model", [
({"model": "openai/gpt-3.5-turbo"}, "ChatOpenAI"), ({"model": "openai/gpt-3.5-turbo", "openai_api_key": "sk-randomtest001"}, ChatOpenAI),
({"model": "azure_openai/gpt-3.5-turbo"}, "AzureChatOpenAI"), ({
({"model": "google_genai/gemini-pro"}, "ChatGoogleGenerativeAI"), "model": "azure_openai/gpt-3.5-turbo",
({"model": "google_vertexai/chat-bison"}, "ChatVertexAI"), "api_key": "random-api-key",
({"model": "ollama/llama2"}, "Ollama"), "api_version": "no version",
({"model": "oneapi/text-davinci-003"}, "OneApi"), "azure_endpoint": "https://www.example.com/"},
({"model": "nvidia/clara-instant-1-base"}, "ChatNVIDIA"), AzureChatOpenAI),
({"model": "deepseek/deepseek-coder-6.7b-instruct"}, "DeepSeek"), ({"model": "google_genai/gemini-pro", "google_api_key": "google-key-test"}, ChatGoogleGenerativeAI),
({"model": "ernie/ernie-bot"}, "ErnieBotChat"), ({"model": "ollama/llama2"}, ChatOllama),
({"model": "oneapi/qwen-turbo"}, OneApi),
({"model": "deepseek/deepseek-coder"}, DeepSeek),
]) ])
def test_create_llm(self, llm_config, expected_model): def test_create_llm(self, llm_config, expected_model):
graph = AbstractGraph("Test prompt", {"llm": llm_config}) graph = TestGraph("Test prompt", {"llm": llm_config})
assert isinstance(graph.llm_model, expected_model) assert isinstance(graph.llm_model, expected_model)
def test_create_llm_unknown_provider(self): def test_create_llm_unknown_provider(self):
with pytest.raises(ValueError): with pytest.raises(ValueError):
AbstractGraph("Test prompt", {"llm": {"model": "unknown_provider/model"}}) TestGraph("Test prompt", {"llm": {"model": "unknown_provider/model"}})
def test_create_llm_error(self):
with patch("your_module.init_chat_model", side_effect=Exception("Test error")):
with pytest.raises(Exception):
AbstractGraph("Test prompt", {"llm": {"model": "openai/gpt-3.5-turbo"}})