Update abstract_graph.py

This commit is contained in:
Marco Vinciguerra 2024-08-27 18:08:53 +02:00
parent cf73883451
commit df70b4f75d

View File

@ -125,19 +125,19 @@ class AbstractGraph(ABC):
self.model_token = llm_params["model_tokens"]
except KeyError as exc:
raise KeyError("model_tokens not specified") from exc
return llm_params["model_instance"]
return llm_params["model_instance"]
known_providers = {"openai", "azure_openai", "google_genai", "google_vertexai",
"ollama", "oneapi", "nvidia", "groq", "anthropic" "bedrock", "mistralai",
"hugging_face", "deepseek", "ernie", "fireworks"}
split_model_provider = llm_params["model"].split("/")
llm_params["model_provider"] = split_model_provider[0]
llm_params["model"] = split_model_provider[1:]
if llm_params["model_provider"] not in known_providers:
raise ValueError(f"Provider {llm_params['model_provider']} is not supported. If possible, try to use a model instance instead.")
try:
self.model_token = models_tokens[llm_params["model_provider"]].get(llm_params["model"][0])
except KeyError: