fix(AbstractGraph): correct and simplify instancing logic

This commit is contained in:
Federico Aguzzi 2024-08-27 12:09:58 +02:00
parent 22ab45f6bd
commit f73343f193
2 changed files with 37 additions and 93 deletions

View File

@ -125,103 +125,47 @@ class AbstractGraph(ABC):
self.model_token = llm_params["model_tokens"]
except KeyError as exc:
raise KeyError("model_tokens not specified") from exc
return llm_params["model_instance"]
return llm_params["model_instance"]
def handle_model(model_name, provider, token_key, default_token=8192):
    """Record the token limit for a model and build its chat model.

    Nested helper: it reads ``self``, ``llm_params``, ``models_tokens``
    and ``init_chat_model`` from the enclosing scope.

    Args:
        model_name: Value stored under ``llm_params["model"]``.
        provider: Value stored under ``llm_params["model_provider"]`` and
            used as the first-level key into ``models_tokens``.
        token_key: Second-level key into ``models_tokens[provider]``.
        default_token: Token limit used when the lookup raises ``KeyError``.

    Returns:
        Whatever ``init_chat_model(**llm_params)`` returns.
    """
    try:
        token_limit = models_tokens[provider][token_key]
    except KeyError:
        print(f"Model not found, using default token size ({default_token})")
        token_limit = default_token
    self.model_token = token_limit

    # Overwrite the caller-supplied values before forwarding the whole mapping.
    llm_params.update(model_provider=provider, model=model_name)

    # Any warning raised while the model is being built is suppressed.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        chat_model = init_chat_model(**llm_params)
    return chat_model
# Provider prefixes accepted in a "<provider>/<model>" model string.
# One entry per line with trailing commas: a missing comma between two
# adjacent string literals silently concatenates them into a single entry.
known_providers = {
    "openai",
    "azure_openai",
    "google_genai",
    "google_vertexai",
    "ollama",
    "oneapi",
    "nvidia",
    "groq",
    "anthropic",
    "bedrock",
    "mistralai",
    "hugging_face",
    "deepseek",
    "ernie",
    "fireworks",
}
# Split "<provider>/<model>" into its two parts. The model name may itself
# contain "/" (e.g. "meta/llama3-70b-instruct"), so everything after the
# first segment is re-joined into a single string rather than left as a list.
split_model_provider = llm_params["model"].split("/")
llm_params["model_provider"] = split_model_provider[0]
llm_params["model"] = "/".join(split_model_provider[1:])
known_models = {"chatgpt","gpt","openai", "azure_openai", "google_genai",
"ollama", "oneapi", "nvidia", "groq", "google_vertexai",
"bedrock", "mistralai", "hugging_face", "deepseek", "ernie",
"fireworks", "claude-3-"}
if llm_params["model"].split("/")[0] not in known_models and llm_params["model"].split("-")[0] not in known_models:
raise ValueError(f"Model '{llm_params['model']}' is not supported")
if llm_params["model_provider"] not in known_providers:
raise ValueError(f"Provider {llm_params['model_provider']} is not supported. If possible, try to use a model instance instead.")
# models_tokens is keyed by provider first, then by model name.
try:
    self.model_token = models_tokens[llm_params["model_provider"]][llm_params["model"]]
except KeyError:
    print("Model not found, using default token size (8192)")
    self.model_token = 8192
try:
if "fireworks" in llm_params["model"]:
model_name = "/".join(llm_params["model"].split("/")[1:])
token_key = llm_params["model"].split("/")[-1]
return handle_model(model_name, "fireworks", token_key)
elif "gemini" in llm_params["model"]:
model_name = llm_params["model"].split("/")[-1]
return handle_model(model_name, "google_genai", model_name)
elif llm_params["model"].startswith("claude"):
model_name = llm_params["model"].split("/")[-1]
return handle_model(model_name, "anthropic", model_name)
elif llm_params["model"].startswith("vertexai"):
return handle_model(llm_params["model"], "google_vertexai", llm_params["model"])
elif "gpt-" in llm_params["model"]:
return handle_model(llm_params["model"], "openai", llm_params["model"])
elif "ollama" in llm_params["model"]:
model_name = llm_params["model"].split("ollama/")[-1]
token_key = model_name if "model_tokens" not in llm_params else None
model_tokens = 8192 if "model_tokens" not in llm_params else llm_params["model_tokens"]
return handle_model(model_name, "ollama", token_key, model_tokens)
elif "claude-3-" in llm_params["model"]:
return handle_model(llm_params["model"], "anthropic", "claude3")
elif llm_params["model"].startswith("mistral"):
model_name = llm_params["model"].split("/")[-1]
return handle_model(model_name, "mistralai", model_name)
elif "deepseek" in llm_params["model"]:
try:
self.model_token = models_tokens["deepseek"][llm_params["model"]]
except KeyError:
print("model not found, using default token size (8192)")
self.model_token = 8192
return DeepSeek(llm_params)
elif "ernie" in llm_params["model"]:
from langchain_community.chat_models import ErnieBotChat
try:
self.model_token = models_tokens["ernie"][llm_params["model"]]
except KeyError:
print("model not found, using default token size (8192)")
self.model_token = 8192
return ErnieBotChat(llm_params)
elif "oneapi" in llm_params["model"]:
llm_params["model"] = llm_params["model"].split("/")[-1]
try:
self.model_token = models_tokens["oneapi"][llm_params["model"]]
except KeyError:
raise KeyError("Model not supported")
return OneApi(llm_params)
elif "nvidia" in llm_params["model"]:
from langchain_nvidia_ai_endpoints import ChatNVIDIA
try:
self.model_token = models_tokens["nvidia"][llm_params["model"].split("/")[-1]]
llm_params["model"] = "/".join(llm_params["model"].split("/")[1:])
except KeyError:
raise KeyError("Model not supported")
return ChatNVIDIA(llm_params)
if llm_params["model_provider"] not in {"oneapi", "nvidia", "ernie", "deepseek"}:
with warnings.catch_warnings():
warnings.simplefilter("ignore")
return init_chat_model(**llm_params)
else:
model_name = llm_params["model"].split("/")[-1]
return handle_model(model_name, llm_params["model"], model_name)
if "deepseek" in llm_params["model"]:
return DeepSeek(**llm_params)
except KeyError as e:
print(f"Model not supported: {e}")
if "ernie" in llm_params["model"]:
from langchain_community.chat_models import ErnieBotChat
return ErnieBotChat(**llm_params)
if "oneapi" in llm_params["model"]:
return OneApi(**llm_params)
if "nvidia" in llm_params["model"]:
from langchain_nvidia_ai_endpoints import ChatNVIDIA
return ChatNVIDIA(**llm_params)
except Exception as e:
print(f"Error instancing model: {e}")
def get_state(self, key=None) -> dict:

View File

@ -102,7 +102,7 @@ models_tokens = {
"oneapi": {
"qwen-turbo": 6000,
},
"nvdia": {
"nvidia": {
"meta/llama3-70b-instruct": 419,
"meta/llama3-8b-instruct": 419,
"nemotron-4-340b-instruct": 1024,
@ -127,7 +127,7 @@ models_tokens = {
"gemma-7b-it": 8192,
"claude-3-haiku-20240307'": 8192,
},
"claude": {
"anthropic": {
"claude_instant": 100000,
"claude2": 9000,
"claude2.1": 200000,