mirror of
https://github.com/VinciGit00/Scrapegraph-ai.git
synced 2026-07-01 21:00:48 +08:00
feat: update abstract graph
Co-Authored-By: Matteo Vedovati <68272450+vedovati-matteo@users.noreply.github.com>
This commit is contained in:
parent
2333b513aa
commit
c77231c983
@ -136,7 +136,6 @@ class AbstractGraph(ABC):
|
||||
raise KeyError("model_tokens not specified") from exc
|
||||
return llm_params["model_instance"]
|
||||
|
||||
# Instantiate the language model based on the model name (models that use the common interface)
|
||||
def handle_model(model_name, provider, token_key, default_token=8192):
|
||||
try:
|
||||
self.model_token = models_tokens[provider][token_key]
|
||||
@ -153,51 +152,39 @@ class AbstractGraph(ABC):
|
||||
model_name = llm_params["model"].split("/")[-1]
|
||||
return handle_model(model_name, "azure_openai", model_name)
|
||||
|
||||
if "gpt-" in llm_params["model"]:
|
||||
return handle_model(llm_params["model"], "openai", llm_params["model"])
|
||||
|
||||
if "fireworks" in llm_params["model"]:
|
||||
elif "fireworks" in llm_params["model"]:
|
||||
model_name = "/".join(llm_params["model"].split("/")[1:])
|
||||
token_key = llm_params["model"].split("/")[-1]
|
||||
return handle_model(model_name, "fireworks", token_key)
|
||||
|
||||
if "gemini" in llm_params["model"]:
|
||||
elif "gemini" in llm_params["model"]:
|
||||
model_name = llm_params["model"].split("/")[-1]
|
||||
return handle_model(model_name, "google_genai", model_name)
|
||||
|
||||
if llm_params["model"].startswith("claude"):
|
||||
elif llm_params["model"].startswith("claude"):
|
||||
model_name = llm_params["model"].split("/")[-1]
|
||||
return handle_model(model_name, "anthropic", model_name)
|
||||
|
||||
if llm_params["model"].startswith("vertexai"):
|
||||
elif llm_params["model"].startswith("vertexai"):
|
||||
return handle_model(llm_params["model"], "google_vertexai", llm_params["model"])
|
||||
|
||||
if "ollama" in llm_params["model"]:
|
||||
elif "gpt-" in llm_params["model"]:
|
||||
return handle_model(llm_params["model"], "openai", llm_params["model"])
|
||||
|
||||
elif "ollama" in llm_params["model"]:
|
||||
model_name = llm_params["model"].split("ollama/")[-1]
|
||||
token_key = model_name if "model_tokens" not in llm_params else llm_params["model_tokens"]
|
||||
return handle_model(model_name, "ollama", token_key)
|
||||
|
||||
if "hugging_face" in llm_params["model"]:
|
||||
model_name = llm_params["model"].split("/")[-1]
|
||||
return handle_model(model_name, "hugging_face", model_name)
|
||||
|
||||
if "groq" in llm_params["model"]:
|
||||
model_name = llm_params["model"].split("/")[-1]
|
||||
return handle_model(model_name, "groq", model_name)
|
||||
|
||||
if "bedrock" in llm_params["model"]:
|
||||
model_name = llm_params["model"].split("/")[-1]
|
||||
return handle_model(model_name, "bedrock", model_name)
|
||||
|
||||
if "claude-3-" in llm_params["model"]:
|
||||
elif "claude-3-" in llm_params["model"]:
|
||||
return handle_model(llm_params["model"], "anthropic", "claude3")
|
||||
|
||||
if llm_params["model"].startswith("mistral"):
|
||||
|
||||
elif llm_params["model"].startswith("mistral"):
|
||||
model_name = llm_params["model"].split("/")[-1]
|
||||
return handle_model(model_name, "mistralai", model_name)
|
||||
|
||||
# Instantiate the language model based on the model name (models that do not use the common interface)
|
||||
if "deepseek" in llm_params["model"]:
|
||||
elif "deepseek" in llm_params["model"]:
|
||||
try:
|
||||
self.model_token = models_tokens["deepseek"][llm_params["model"]]
|
||||
except KeyError:
|
||||
@ -205,15 +192,15 @@ class AbstractGraph(ABC):
|
||||
self.model_token = 8192
|
||||
return DeepSeek(llm_params)
|
||||
|
||||
if "ernie" in llm_params["model"]:
|
||||
elif "ernie" in llm_params["model"]:
|
||||
try:
|
||||
self.model_token = models_tokens["ernie"][llm_params["model"]]
|
||||
except KeyError:
|
||||
print("model not found, using default token size (8192)")
|
||||
self.model_token = 8192
|
||||
return ErnieBotChat(llm_params)
|
||||
|
||||
if "oneapi" in llm_params["model"]:
|
||||
|
||||
elif "oneapi" in llm_params["model"]:
|
||||
# take the model after the last dash
|
||||
llm_params["model"] = llm_params["model"].split("/")[-1]
|
||||
try:
|
||||
@ -221,16 +208,18 @@ class AbstractGraph(ABC):
|
||||
except KeyError as exc:
|
||||
raise KeyError("Model not supported") from exc
|
||||
return OneApi(llm_params)
|
||||
|
||||
if "nvidia" in llm_params["model"]:
|
||||
|
||||
elif "nvidia" in llm_params["model"]:
|
||||
try:
|
||||
self.model_token = models_tokens["nvidia"][llm_params["model"].split("/")[-1]]
|
||||
llm_params["model"] = "/".join(llm_params["model"].split("/")[1:])
|
||||
except KeyError as exc:
|
||||
raise KeyError("Model not supported") from exc
|
||||
return ChatNVIDIA(llm_params)
|
||||
else:
|
||||
model_name = llm_params["model"].split("/")[-1]
|
||||
return handle_model(model_name, llm_params["model"], model_name)
|
||||
|
||||
# Raise an error if the model did not match any of the previous cases
|
||||
raise ValueError("Model provided by the configuration not supported")
|
||||
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user