mirror of
https://github.com/VinciGit00/Scrapegraph-ai.git
synced 2026-06-28 21:01:55 +08:00
fix: try to infer possible provider from the model name, resolves #805
This commit is contained in:
parent
777a68554e
commit
d2d0312dc6
@ -144,9 +144,18 @@ class AbstractGraph(ABC):
|
||||
"ollama", "oneapi", "nvidia", "groq", "anthropic", "bedrock", "mistralai",
|
||||
"hugging_face", "deepseek", "ernie", "fireworks", "togetherai"}
|
||||
|
||||
split_model_provider = llm_params["model"].split("/", 1)
|
||||
llm_params["model_provider"] = split_model_provider[0]
|
||||
llm_params["model"] = split_model_provider[1]
|
||||
if '/' in llm_params["model"]:
|
||||
split_model_provider = llm_params["model"].split("/", 1)
|
||||
llm_params["model_provider"] = split_model_provider[0]
|
||||
llm_params["model"] = split_model_provider[1]
|
||||
else:
|
||||
possible_providers = [provider for provider, models_d in models_tokens.items() if llm_params["model"] in models_d]
|
||||
if len(possible_providers) <= 0:
|
||||
raise ValueError(f"""Provider {llm_params['model_provider']} is not supported.
|
||||
If possible, try to use a model instance instead.""")
|
||||
llm_params["model_provider"] = possible_providers[0]
|
||||
print((f"Found providers {possible_providers} for model {llm_params['model']}, using {llm_params['model_provider']}.\n"
|
||||
"If it was not intended please specify the model provider in the graph configuration"))
|
||||
|
||||
if llm_params["model_provider"] not in known_providers:
|
||||
raise ValueError(f"""Provider {llm_params['model_provider']} is not supported.
|
||||
|
||||
Loading…
Reference in New Issue
Block a user