refactor: reuse code for common interface models

This commit is contained in:
Federico Aguzzi 2024-07-31 13:41:09 +02:00
parent b17756d934
commit bb73d916a1

View File

@ -146,138 +146,61 @@ class AbstractGraph(ABC):
raise KeyError("model_tokens not specified") from exc
return llm_params["model_instance"]
# Instantiate the language model based on the model name
if "gpt-" in llm_params["model"]:
# Instantiate the language model based on the model name (models that use the common interface)
def handle_model(model_name, provider, token_key, default_token=8192):
try:
self.model_token = models_tokens["openai"][llm_params["model"]]
llm_params["model_provider"] = "openai"
except KeyError as exc:
raise KeyError("Model not supported") from exc
self.model_token = models_tokens[provider][token_key]
except KeyError:
print(f"Model not found, using default token size ({default_token})")
self.model_token = default_token
llm_params["model_provider"] = provider
llm_params["model"] = model_name
return init_chat_model(**llm_params)
if "oneapi" in llm_params["model"]:
# take the model after the last dash
llm_params["model"] = llm_params["model"].split("/")[-1]
try:
self.model_token = models_tokens["oneapi"][llm_params["model"]]
except KeyError as exc:
raise KeyError("Model not supported") from exc
return OneApi(llm_params)
if "gpt-" in llm_params["model"]:
return handle_model(llm_params["model"], "openai", llm_params["model"])
if "fireworks" in llm_params["model"]:
try:
self.model_token = models_tokens["fireworks"][llm_params["model"].split("/")[-1]]
llm_params["model"] = "/".join(llm_params["model"].split("/")[1:])
except KeyError as exc:
raise KeyError("Model not supported") from exc
llm_params["model_provider"] = "fireworks"
return init_chat_model(**llm_params)
model_name = "/".join(llm_params["model"].split("/")[1:])
token_key = llm_params["model"].split("/")[-1]
return handle_model(model_name, "fireworks", token_key)
if "azure" in llm_params["model"]:
# take the model after the last dash
llm_params["model"] = llm_params["model"].split("/")[-1]
try:
self.model_token = models_tokens["azure"][llm_params["model"]]
except KeyError as exc:
raise KeyError("Model not supported") from exc
llm_params["model_provider"] = "azure_openai"
return init_chat_model(**llm_params)
if "nvidia" in llm_params["model"]:
try:
self.model_token = models_tokens["nvidia"][llm_params["model"].split("/")[-1]]
llm_params["model"] = "/".join(llm_params["model"].split("/")[1:])
except KeyError as exc:
raise KeyError("Model not supported") from exc
return ChatNVIDIA(llm_params)
model_name = llm_params["model"].split("/")[-1]
return handle_model(model_name, "azure_openai", model_name)
if "gemini" in llm_params["model"]:
llm_params["model"] = llm_params["model"].split("/")[-1]
try:
self.model_token = models_tokens["gemini"][llm_params["model"]]
except KeyError as exc:
raise KeyError("Model not supported") from exc
llm_params["model_provider"] = "google_genai "
return init_chat_model(**llm_params)
model_name = llm_params["model"].split("/")[-1]
return handle_model(model_name, "google_genai", model_name)
if llm_params["model"].startswith("claude"):
llm_params["model"] = llm_params["model"].split("/")[-1]
try:
self.model_token = models_tokens["claude"][llm_params["model"]]
except KeyError as exc:
raise KeyError("Model not supported") from exc
llm_params["model_provider"] = "anthropic"
return init_chat_model(**llm_params)
model_name = llm_params["model"].split("/")[-1]
return handle_model(model_name, "anthropic", model_name)
if llm_params["model"].startswith("vertexai"):
try:
self.model_token = models_tokens["vertexai"][llm_params["model"]]
except KeyError as exc:
raise KeyError("Model not supported") from exc
llm_params["model_provider"] = "google_vertexai"
return init_chat_model(**llm_params)
return handle_model(llm_params["model"], "google_vertexai", llm_params["model"])
if "ollama" in llm_params["model"]:
llm_params["model"] = llm_params["model"].split("ollama/")[-1]
llm_params["model_provider"] = "ollama"
# allow user to set model_tokens in config
try:
if "model_tokens" in llm_params:
self.model_token = llm_params["model_tokens"]
elif llm_params["model"] in models_tokens["ollama"]:
try:
self.model_token = models_tokens["ollama"][llm_params["model"]]
except KeyError as exc:
print("model not found, using default token size (8192)")
self.model_token = 8192
else:
self.model_token = 8192
except AttributeError:
self.model_token = 8192
return init_chat_model(**llm_params)
model_name = llm_params["model"].split("ollama/")[-1]
token_key = model_name if "model_tokens" not in llm_params else llm_params["model_tokens"]
return handle_model(model_name, "ollama", token_key)
if "hugging_face" in llm_params["model"]:
llm_params["model"] = llm_params["model"].split("/")[-1]
try:
self.model_token = models_tokens["hugging_face"][llm_params["model"]]
except KeyError:
print("model not found, using default token size (8192)")
self.model_token = 8192
llm_params["model_provider"] = "hugging_face"
return init_chat_model(**llm_params)
model_name = llm_params["model"].split("/")[-1]
return handle_model(model_name, "hugging_face", model_name)
if "groq" in llm_params["model"]:
llm_params["model"] = llm_params["model"].split("/")[-1]
try:
self.model_token = models_tokens["groq"][llm_params["model"]]
except KeyError:
print("model not found, using default token size (8192)")
self.model_token = 8192
llm_params["model_provider"] = "groq"
return init_chat_model(**llm_params)
model_name = llm_params["model"].split("/")[-1]
return handle_model(model_name, "groq", model_name)
if "bedrock" in llm_params["model"]:
llm_params["model"] = llm_params["model"].split("/")[-1]
try:
self.model_token = models_tokens["bedrock"][llm_params["model"]]
except KeyError:
print("model not found, using default token size (8192)")
self.model_token = 8192
llm_params["model_provider"] = "bedrock"
return init_chat_model(**llm_params)
model_name = llm_params["model"].split("/")[-1]
return handle_model(model_name, "bedrock", model_name)
if "claude-3-" in llm_params["model"]:
try:
self.model_token = models_tokens["claude"]["claude3"]
except KeyError:
print("model not found, using default token size (8192)")
self.model_token = 8192
llm_params["model_provider"] = "anthropic"
return init_chat_model(**llm_params)
return handle_model(llm_params["model"], "anthropic", "claude3")
# Instantiate the language model based on the model name (models that do not use the common interface)
if "deepseek" in llm_params["model"]:
try:
self.model_token = models_tokens["deepseek"][llm_params["model"]]
@ -293,7 +216,25 @@ class AbstractGraph(ABC):
print("model not found, using default token size (8192)")
self.model_token = 8192
return ErnieBotChat(llm_params)
if "oneapi" in llm_params["model"]:
# take the model after the last dash
llm_params["model"] = llm_params["model"].split("/")[-1]
try:
self.model_token = models_tokens["oneapi"][llm_params["model"]]
except KeyError as exc:
raise KeyError("Model not supported") from exc
return OneApi(llm_params)
if "nvidia" in llm_params["model"]:
try:
self.model_token = models_tokens["nvidia"][llm_params["model"].split("/")[-1]]
llm_params["model"] = "/".join(llm_params["model"].split("/")[1:])
except KeyError as exc:
raise KeyError("Model not supported") from exc
return ChatNVIDIA(llm_params)
# Raise an error if the model did not match any of the previous cases
raise ValueError("Model provided by the configuration not supported")
def _create_default_embedder(self, llm_config=None) -> object: