refactor: reuse code for common interface models

This commit is contained in:
Federico Aguzzi 2024-07-31 13:41:09 +02:00
parent b17756d934
commit bb73d916a1

View File

@ -146,138 +146,61 @@ class AbstractGraph(ABC):
raise KeyError("model_tokens not specified") from exc raise KeyError("model_tokens not specified") from exc
return llm_params["model_instance"] return llm_params["model_instance"]
# Instantiate the language model based on the model name # Instantiate the language model based on the model name (models that use the common interface)
if "gpt-" in llm_params["model"]: def handle_model(model_name, provider, token_key, default_token=8192):
try: try:
self.model_token = models_tokens["openai"][llm_params["model"]] self.model_token = models_tokens[provider][token_key]
llm_params["model_provider"] = "openai" except KeyError:
except KeyError as exc: print(f"Model not found, using default token size ({default_token})")
raise KeyError("Model not supported") from exc self.model_token = default_token
llm_params["model_provider"] = provider
llm_params["model"] = model_name
return init_chat_model(**llm_params) return init_chat_model(**llm_params)
if "oneapi" in llm_params["model"]: if "gpt-" in llm_params["model"]:
# take the model after the last dash return handle_model(llm_params["model"], "openai", llm_params["model"])
llm_params["model"] = llm_params["model"].split("/")[-1]
try:
self.model_token = models_tokens["oneapi"][llm_params["model"]]
except KeyError as exc:
raise KeyError("Model not supported") from exc
return OneApi(llm_params)
if "fireworks" in llm_params["model"]: if "fireworks" in llm_params["model"]:
try: model_name = "/".join(llm_params["model"].split("/")[1:])
self.model_token = models_tokens["fireworks"][llm_params["model"].split("/")[-1]] token_key = llm_params["model"].split("/")[-1]
llm_params["model"] = "/".join(llm_params["model"].split("/")[1:]) return handle_model(model_name, "fireworks", token_key)
except KeyError as exc:
raise KeyError("Model not supported") from exc
llm_params["model_provider"] = "fireworks"
return init_chat_model(**llm_params)
if "azure" in llm_params["model"]: if "azure" in llm_params["model"]:
# take the model after the last dash model_name = llm_params["model"].split("/")[-1]
llm_params["model"] = llm_params["model"].split("/")[-1] return handle_model(model_name, "azure_openai", model_name)
try:
self.model_token = models_tokens["azure"][llm_params["model"]]
except KeyError as exc:
raise KeyError("Model not supported") from exc
llm_params["model_provider"] = "azure_openai"
return init_chat_model(**llm_params)
if "nvidia" in llm_params["model"]:
try:
self.model_token = models_tokens["nvidia"][llm_params["model"].split("/")[-1]]
llm_params["model"] = "/".join(llm_params["model"].split("/")[1:])
except KeyError as exc:
raise KeyError("Model not supported") from exc
return ChatNVIDIA(llm_params)
if "gemini" in llm_params["model"]: if "gemini" in llm_params["model"]:
llm_params["model"] = llm_params["model"].split("/")[-1] model_name = llm_params["model"].split("/")[-1]
try: return handle_model(model_name, "google_genai", model_name)
self.model_token = models_tokens["gemini"][llm_params["model"]]
except KeyError as exc:
raise KeyError("Model not supported") from exc
llm_params["model_provider"] = "google_genai "
return init_chat_model(**llm_params)
if llm_params["model"].startswith("claude"): if llm_params["model"].startswith("claude"):
llm_params["model"] = llm_params["model"].split("/")[-1] model_name = llm_params["model"].split("/")[-1]
try: return handle_model(model_name, "anthropic", model_name)
self.model_token = models_tokens["claude"][llm_params["model"]]
except KeyError as exc:
raise KeyError("Model not supported") from exc
llm_params["model_provider"] = "anthropic"
return init_chat_model(**llm_params)
if llm_params["model"].startswith("vertexai"): if llm_params["model"].startswith("vertexai"):
try: return handle_model(llm_params["model"], "google_vertexai", llm_params["model"])
self.model_token = models_tokens["vertexai"][llm_params["model"]]
except KeyError as exc:
raise KeyError("Model not supported") from exc
llm_params["model_provider"] = "google_vertexai"
return init_chat_model(**llm_params)
if "ollama" in llm_params["model"]: if "ollama" in llm_params["model"]:
llm_params["model"] = llm_params["model"].split("ollama/")[-1] model_name = llm_params["model"].split("ollama/")[-1]
llm_params["model_provider"] = "ollama" token_key = model_name if "model_tokens" not in llm_params else llm_params["model_tokens"]
return handle_model(model_name, "ollama", token_key)
# allow user to set model_tokens in config
try:
if "model_tokens" in llm_params:
self.model_token = llm_params["model_tokens"]
elif llm_params["model"] in models_tokens["ollama"]:
try:
self.model_token = models_tokens["ollama"][llm_params["model"]]
except KeyError as exc:
print("model not found, using default token size (8192)")
self.model_token = 8192
else:
self.model_token = 8192
except AttributeError:
self.model_token = 8192
return init_chat_model(**llm_params)
if "hugging_face" in llm_params["model"]: if "hugging_face" in llm_params["model"]:
llm_params["model"] = llm_params["model"].split("/")[-1] model_name = llm_params["model"].split("/")[-1]
try: return handle_model(model_name, "hugging_face", model_name)
self.model_token = models_tokens["hugging_face"][llm_params["model"]]
except KeyError:
print("model not found, using default token size (8192)")
self.model_token = 8192
llm_params["model_provider"] = "hugging_face"
return init_chat_model(**llm_params)
if "groq" in llm_params["model"]: if "groq" in llm_params["model"]:
llm_params["model"] = llm_params["model"].split("/")[-1] model_name = llm_params["model"].split("/")[-1]
return handle_model(model_name, "groq", model_name)
try:
self.model_token = models_tokens["groq"][llm_params["model"]]
except KeyError:
print("model not found, using default token size (8192)")
self.model_token = 8192
llm_params["model_provider"] = "groq"
return init_chat_model(**llm_params)
if "bedrock" in llm_params["model"]: if "bedrock" in llm_params["model"]:
llm_params["model"] = llm_params["model"].split("/")[-1] model_name = llm_params["model"].split("/")[-1]
try: return handle_model(model_name, "bedrock", model_name)
self.model_token = models_tokens["bedrock"][llm_params["model"]]
except KeyError:
print("model not found, using default token size (8192)")
self.model_token = 8192
llm_params["model_provider"] = "bedrock"
return init_chat_model(**llm_params)
if "claude-3-" in llm_params["model"]: if "claude-3-" in llm_params["model"]:
try: return handle_model(llm_params["model"], "anthropic", "claude3")
self.model_token = models_tokens["claude"]["claude3"]
except KeyError:
print("model not found, using default token size (8192)")
self.model_token = 8192
llm_params["model_provider"] = "anthropic"
return init_chat_model(**llm_params)
# Instantiate the language model based on the model name (models that do not use the common interface)
if "deepseek" in llm_params["model"]: if "deepseek" in llm_params["model"]:
try: try:
self.model_token = models_tokens["deepseek"][llm_params["model"]] self.model_token = models_tokens["deepseek"][llm_params["model"]]
@ -293,7 +216,25 @@ class AbstractGraph(ABC):
print("model not found, using default token size (8192)") print("model not found, using default token size (8192)")
self.model_token = 8192 self.model_token = 8192
return ErnieBotChat(llm_params) return ErnieBotChat(llm_params)
if "oneapi" in llm_params["model"]:
# take the model after the last dash
llm_params["model"] = llm_params["model"].split("/")[-1]
try:
self.model_token = models_tokens["oneapi"][llm_params["model"]]
except KeyError as exc:
raise KeyError("Model not supported") from exc
return OneApi(llm_params)
if "nvidia" in llm_params["model"]:
try:
self.model_token = models_tokens["nvidia"][llm_params["model"].split("/")[-1]]
llm_params["model"] = "/".join(llm_params["model"].split("/")[1:])
except KeyError as exc:
raise KeyError("Model not supported") from exc
return ChatNVIDIA(llm_params)
# Raise an error if the model did not match any of the previous cases
raise ValueError("Model provided by the configuration not supported") raise ValueError("Model provided by the configuration not supported")
def _create_default_embedder(self, llm_config=None) -> object: def _create_default_embedder(self, llm_config=None) -> object: