mirror of
https://github.com/VinciGit00/Scrapegraph-ai.git
synced 2026-07-04 21:00:36 +08:00
refactor: reuse code for common interface models
This commit is contained in:
parent
b17756d934
commit
bb73d916a1
@ -146,138 +146,61 @@ class AbstractGraph(ABC):
|
|||||||
raise KeyError("model_tokens not specified") from exc
|
raise KeyError("model_tokens not specified") from exc
|
||||||
return llm_params["model_instance"]
|
return llm_params["model_instance"]
|
||||||
|
|
||||||
# Instantiate the language model based on the model name
|
# Instantiate the language model based on the model name (models that use the common interface)
|
||||||
if "gpt-" in llm_params["model"]:
|
def handle_model(model_name, provider, token_key, default_token=8192):
|
||||||
try:
|
try:
|
||||||
self.model_token = models_tokens["openai"][llm_params["model"]]
|
self.model_token = models_tokens[provider][token_key]
|
||||||
llm_params["model_provider"] = "openai"
|
except KeyError:
|
||||||
except KeyError as exc:
|
print(f"Model not found, using default token size ({default_token})")
|
||||||
raise KeyError("Model not supported") from exc
|
self.model_token = default_token
|
||||||
|
llm_params["model_provider"] = provider
|
||||||
|
llm_params["model"] = model_name
|
||||||
return init_chat_model(**llm_params)
|
return init_chat_model(**llm_params)
|
||||||
|
|
||||||
if "oneapi" in llm_params["model"]:
|
if "gpt-" in llm_params["model"]:
|
||||||
# take the model after the last dash
|
return handle_model(llm_params["model"], "openai", llm_params["model"])
|
||||||
llm_params["model"] = llm_params["model"].split("/")[-1]
|
|
||||||
try:
|
|
||||||
self.model_token = models_tokens["oneapi"][llm_params["model"]]
|
|
||||||
except KeyError as exc:
|
|
||||||
raise KeyError("Model not supported") from exc
|
|
||||||
return OneApi(llm_params)
|
|
||||||
|
|
||||||
if "fireworks" in llm_params["model"]:
|
if "fireworks" in llm_params["model"]:
|
||||||
try:
|
model_name = "/".join(llm_params["model"].split("/")[1:])
|
||||||
self.model_token = models_tokens["fireworks"][llm_params["model"].split("/")[-1]]
|
token_key = llm_params["model"].split("/")[-1]
|
||||||
llm_params["model"] = "/".join(llm_params["model"].split("/")[1:])
|
return handle_model(model_name, "fireworks", token_key)
|
||||||
except KeyError as exc:
|
|
||||||
raise KeyError("Model not supported") from exc
|
|
||||||
llm_params["model_provider"] = "fireworks"
|
|
||||||
return init_chat_model(**llm_params)
|
|
||||||
|
|
||||||
if "azure" in llm_params["model"]:
|
if "azure" in llm_params["model"]:
|
||||||
# take the model after the last dash
|
model_name = llm_params["model"].split("/")[-1]
|
||||||
llm_params["model"] = llm_params["model"].split("/")[-1]
|
return handle_model(model_name, "azure_openai", model_name)
|
||||||
try:
|
|
||||||
self.model_token = models_tokens["azure"][llm_params["model"]]
|
|
||||||
except KeyError as exc:
|
|
||||||
raise KeyError("Model not supported") from exc
|
|
||||||
llm_params["model_provider"] = "azure_openai"
|
|
||||||
return init_chat_model(**llm_params)
|
|
||||||
|
|
||||||
if "nvidia" in llm_params["model"]:
|
|
||||||
try:
|
|
||||||
self.model_token = models_tokens["nvidia"][llm_params["model"].split("/")[-1]]
|
|
||||||
llm_params["model"] = "/".join(llm_params["model"].split("/")[1:])
|
|
||||||
except KeyError as exc:
|
|
||||||
raise KeyError("Model not supported") from exc
|
|
||||||
return ChatNVIDIA(llm_params)
|
|
||||||
|
|
||||||
if "gemini" in llm_params["model"]:
|
if "gemini" in llm_params["model"]:
|
||||||
llm_params["model"] = llm_params["model"].split("/")[-1]
|
model_name = llm_params["model"].split("/")[-1]
|
||||||
try:
|
return handle_model(model_name, "google_genai", model_name)
|
||||||
self.model_token = models_tokens["gemini"][llm_params["model"]]
|
|
||||||
except KeyError as exc:
|
|
||||||
raise KeyError("Model not supported") from exc
|
|
||||||
llm_params["model_provider"] = "google_genai "
|
|
||||||
return init_chat_model(**llm_params)
|
|
||||||
|
|
||||||
if llm_params["model"].startswith("claude"):
|
if llm_params["model"].startswith("claude"):
|
||||||
llm_params["model"] = llm_params["model"].split("/")[-1]
|
model_name = llm_params["model"].split("/")[-1]
|
||||||
try:
|
return handle_model(model_name, "anthropic", model_name)
|
||||||
self.model_token = models_tokens["claude"][llm_params["model"]]
|
|
||||||
except KeyError as exc:
|
|
||||||
raise KeyError("Model not supported") from exc
|
|
||||||
llm_params["model_provider"] = "anthropic"
|
|
||||||
return init_chat_model(**llm_params)
|
|
||||||
|
|
||||||
if llm_params["model"].startswith("vertexai"):
|
if llm_params["model"].startswith("vertexai"):
|
||||||
try:
|
return handle_model(llm_params["model"], "google_vertexai", llm_params["model"])
|
||||||
self.model_token = models_tokens["vertexai"][llm_params["model"]]
|
|
||||||
except KeyError as exc:
|
|
||||||
raise KeyError("Model not supported") from exc
|
|
||||||
llm_params["model_provider"] = "google_vertexai"
|
|
||||||
return init_chat_model(**llm_params)
|
|
||||||
|
|
||||||
if "ollama" in llm_params["model"]:
|
if "ollama" in llm_params["model"]:
|
||||||
llm_params["model"] = llm_params["model"].split("ollama/")[-1]
|
model_name = llm_params["model"].split("ollama/")[-1]
|
||||||
llm_params["model_provider"] = "ollama"
|
token_key = model_name if "model_tokens" not in llm_params else llm_params["model_tokens"]
|
||||||
|
return handle_model(model_name, "ollama", token_key)
|
||||||
# allow user to set model_tokens in config
|
|
||||||
try:
|
|
||||||
if "model_tokens" in llm_params:
|
|
||||||
self.model_token = llm_params["model_tokens"]
|
|
||||||
elif llm_params["model"] in models_tokens["ollama"]:
|
|
||||||
try:
|
|
||||||
self.model_token = models_tokens["ollama"][llm_params["model"]]
|
|
||||||
except KeyError as exc:
|
|
||||||
print("model not found, using default token size (8192)")
|
|
||||||
self.model_token = 8192
|
|
||||||
else:
|
|
||||||
self.model_token = 8192
|
|
||||||
except AttributeError:
|
|
||||||
self.model_token = 8192
|
|
||||||
|
|
||||||
return init_chat_model(**llm_params)
|
|
||||||
|
|
||||||
if "hugging_face" in llm_params["model"]:
|
if "hugging_face" in llm_params["model"]:
|
||||||
llm_params["model"] = llm_params["model"].split("/")[-1]
|
model_name = llm_params["model"].split("/")[-1]
|
||||||
try:
|
return handle_model(model_name, "hugging_face", model_name)
|
||||||
self.model_token = models_tokens["hugging_face"][llm_params["model"]]
|
|
||||||
except KeyError:
|
|
||||||
print("model not found, using default token size (8192)")
|
|
||||||
self.model_token = 8192
|
|
||||||
llm_params["model_provider"] = "hugging_face"
|
|
||||||
return init_chat_model(**llm_params)
|
|
||||||
|
|
||||||
if "groq" in llm_params["model"]:
|
if "groq" in llm_params["model"]:
|
||||||
llm_params["model"] = llm_params["model"].split("/")[-1]
|
model_name = llm_params["model"].split("/")[-1]
|
||||||
|
return handle_model(model_name, "groq", model_name)
|
||||||
try:
|
|
||||||
self.model_token = models_tokens["groq"][llm_params["model"]]
|
|
||||||
except KeyError:
|
|
||||||
print("model not found, using default token size (8192)")
|
|
||||||
self.model_token = 8192
|
|
||||||
llm_params["model_provider"] = "groq"
|
|
||||||
return init_chat_model(**llm_params)
|
|
||||||
|
|
||||||
if "bedrock" in llm_params["model"]:
|
if "bedrock" in llm_params["model"]:
|
||||||
llm_params["model"] = llm_params["model"].split("/")[-1]
|
model_name = llm_params["model"].split("/")[-1]
|
||||||
try:
|
return handle_model(model_name, "bedrock", model_name)
|
||||||
self.model_token = models_tokens["bedrock"][llm_params["model"]]
|
|
||||||
except KeyError:
|
|
||||||
print("model not found, using default token size (8192)")
|
|
||||||
self.model_token = 8192
|
|
||||||
llm_params["model_provider"] = "bedrock"
|
|
||||||
return init_chat_model(**llm_params)
|
|
||||||
|
|
||||||
if "claude-3-" in llm_params["model"]:
|
if "claude-3-" in llm_params["model"]:
|
||||||
try:
|
return handle_model(llm_params["model"], "anthropic", "claude3")
|
||||||
self.model_token = models_tokens["claude"]["claude3"]
|
|
||||||
except KeyError:
|
|
||||||
print("model not found, using default token size (8192)")
|
|
||||||
self.model_token = 8192
|
|
||||||
llm_params["model_provider"] = "anthropic"
|
|
||||||
return init_chat_model(**llm_params)
|
|
||||||
|
|
||||||
|
# Instantiate the language model based on the model name (models that do not use the common interface)
|
||||||
if "deepseek" in llm_params["model"]:
|
if "deepseek" in llm_params["model"]:
|
||||||
try:
|
try:
|
||||||
self.model_token = models_tokens["deepseek"][llm_params["model"]]
|
self.model_token = models_tokens["deepseek"][llm_params["model"]]
|
||||||
@ -294,6 +217,24 @@ class AbstractGraph(ABC):
|
|||||||
self.model_token = 8192
|
self.model_token = 8192
|
||||||
return ErnieBotChat(llm_params)
|
return ErnieBotChat(llm_params)
|
||||||
|
|
||||||
|
if "oneapi" in llm_params["model"]:
|
||||||
|
# take the model after the last dash
|
||||||
|
llm_params["model"] = llm_params["model"].split("/")[-1]
|
||||||
|
try:
|
||||||
|
self.model_token = models_tokens["oneapi"][llm_params["model"]]
|
||||||
|
except KeyError as exc:
|
||||||
|
raise KeyError("Model not supported") from exc
|
||||||
|
return OneApi(llm_params)
|
||||||
|
|
||||||
|
if "nvidia" in llm_params["model"]:
|
||||||
|
try:
|
||||||
|
self.model_token = models_tokens["nvidia"][llm_params["model"].split("/")[-1]]
|
||||||
|
llm_params["model"] = "/".join(llm_params["model"].split("/")[1:])
|
||||||
|
except KeyError as exc:
|
||||||
|
raise KeyError("Model not supported") from exc
|
||||||
|
return ChatNVIDIA(llm_params)
|
||||||
|
|
||||||
|
# Raise an error if the model did not match any of the previous cases
|
||||||
raise ValueError("Model provided by the configuration not supported")
|
raise ValueError("Model provided by the configuration not supported")
|
||||||
|
|
||||||
def _create_default_embedder(self, llm_config=None) -> object:
|
def _create_default_embedder(self, llm_config=None) -> object:
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user