fix: try to infer possible provider from the model name, resolves #805

This commit is contained in:
Michele_Zenoni 2024-11-17 20:04:30 +01:00
parent 777a68554e
commit d2d0312dc6

View File

@ -144,9 +144,18 @@ class AbstractGraph(ABC):
"ollama", "oneapi", "nvidia", "groq", "anthropic", "bedrock", "mistralai",
"hugging_face", "deepseek", "ernie", "fireworks", "togetherai"}
split_model_provider = llm_params["model"].split("/", 1)
llm_params["model_provider"] = split_model_provider[0]
llm_params["model"] = split_model_provider[1]
if '/' in llm_params["model"]:
split_model_provider = llm_params["model"].split("/", 1)
llm_params["model_provider"] = split_model_provider[0]
llm_params["model"] = split_model_provider[1]
else:
possible_providers = [provider for provider, models_d in models_tokens.items() if llm_params["model"] in models_d]
if len(possible_providers) <= 0:
raise ValueError(f"""Provider {llm_params['model_provider']} is not supported.
If possible, try to use a model instance instead.""")
llm_params["model_provider"] = possible_providers[0]
print((f"Found providers {possible_providers} for model {llm_params['model']}, using {llm_params['model_provider']}.\n"
"If it was not intended please specify the model provider in the graph configuration"))
if llm_params["model_provider"] not in known_providers:
raise ValueError(f"""Provider {llm_params['model_provider']} is not supported.