feat: update abstract graph

Co-Authored-By: Matteo Vedovati <68272450+vedovati-matteo@users.noreply.github.com>
This commit is contained in:
Marco Vinciguerra 2024-08-11 19:18:24 +02:00
parent 2333b513aa
commit c77231c983

View File

@ -136,7 +136,6 @@ class AbstractGraph(ABC):
raise KeyError("model_tokens not specified") from exc
return llm_params["model_instance"]
# Instantiate the language model based on the model name (models that use the common interface)
def handle_model(model_name, provider, token_key, default_token=8192):
try:
self.model_token = models_tokens[provider][token_key]
@ -153,51 +152,39 @@ class AbstractGraph(ABC):
model_name = llm_params["model"].split("/")[-1]
return handle_model(model_name, "azure_openai", model_name)
if "gpt-" in llm_params["model"]:
return handle_model(llm_params["model"], "openai", llm_params["model"])
if "fireworks" in llm_params["model"]:
elif "fireworks" in llm_params["model"]:
model_name = "/".join(llm_params["model"].split("/")[1:])
token_key = llm_params["model"].split("/")[-1]
return handle_model(model_name, "fireworks", token_key)
if "gemini" in llm_params["model"]:
elif "gemini" in llm_params["model"]:
model_name = llm_params["model"].split("/")[-1]
return handle_model(model_name, "google_genai", model_name)
if llm_params["model"].startswith("claude"):
elif llm_params["model"].startswith("claude"):
model_name = llm_params["model"].split("/")[-1]
return handle_model(model_name, "anthropic", model_name)
if llm_params["model"].startswith("vertexai"):
elif llm_params["model"].startswith("vertexai"):
return handle_model(llm_params["model"], "google_vertexai", llm_params["model"])
if "ollama" in llm_params["model"]:
elif "gpt-" in llm_params["model"]:
return handle_model(llm_params["model"], "openai", llm_params["model"])
elif "ollama" in llm_params["model"]:
model_name = llm_params["model"].split("ollama/")[-1]
token_key = model_name if "model_tokens" not in llm_params else llm_params["model_tokens"]
return handle_model(model_name, "ollama", token_key)
if "hugging_face" in llm_params["model"]:
model_name = llm_params["model"].split("/")[-1]
return handle_model(model_name, "hugging_face", model_name)
if "groq" in llm_params["model"]:
model_name = llm_params["model"].split("/")[-1]
return handle_model(model_name, "groq", model_name)
if "bedrock" in llm_params["model"]:
model_name = llm_params["model"].split("/")[-1]
return handle_model(model_name, "bedrock", model_name)
if "claude-3-" in llm_params["model"]:
elif "claude-3-" in llm_params["model"]:
return handle_model(llm_params["model"], "anthropic", "claude3")
if llm_params["model"].startswith("mistral"):
elif llm_params["model"].startswith("mistral"):
model_name = llm_params["model"].split("/")[-1]
return handle_model(model_name, "mistralai", model_name)
# Instantiate the language model based on the model name (models that do not use the common interface)
if "deepseek" in llm_params["model"]:
elif "deepseek" in llm_params["model"]:
try:
self.model_token = models_tokens["deepseek"][llm_params["model"]]
except KeyError:
@ -205,15 +192,15 @@ class AbstractGraph(ABC):
self.model_token = 8192
return DeepSeek(llm_params)
if "ernie" in llm_params["model"]:
elif "ernie" in llm_params["model"]:
try:
self.model_token = models_tokens["ernie"][llm_params["model"]]
except KeyError:
print("model not found, using default token size (8192)")
self.model_token = 8192
return ErnieBotChat(llm_params)
if "oneapi" in llm_params["model"]:
elif "oneapi" in llm_params["model"]:
# take the model after the last dash
llm_params["model"] = llm_params["model"].split("/")[-1]
try:
@ -221,16 +208,18 @@ class AbstractGraph(ABC):
except KeyError as exc:
raise KeyError("Model not supported") from exc
return OneApi(llm_params)
if "nvidia" in llm_params["model"]:
elif "nvidia" in llm_params["model"]:
try:
self.model_token = models_tokens["nvidia"][llm_params["model"].split("/")[-1]]
llm_params["model"] = "/".join(llm_params["model"].split("/")[1:])
except KeyError as exc:
raise KeyError("Model not supported") from exc
return ChatNVIDIA(llm_params)
else:
model_name = llm_params["model"].split("/")[-1]
return handle_model(model_name, llm_params["model"], model_name)
# Raise an error if the model did not match any of the previous cases
raise ValueError("Model provided by the configuration not supported")