mirror of
https://github.com/VinciGit00/Scrapegraph-ai.git
synced 2026-07-04 21:00:36 +08:00
refactor: reuse code for common interface models
This commit is contained in:
parent
b17756d934
commit
bb73d916a1
@ -146,138 +146,61 @@ class AbstractGraph(ABC):
|
||||
raise KeyError("model_tokens not specified") from exc
|
||||
return llm_params["model_instance"]
|
||||
|
||||
# Instantiate the language model based on the model name
|
||||
if "gpt-" in llm_params["model"]:
|
||||
# Instantiate the language model based on the model name (models that use the common interface)
|
||||
def handle_model(model_name, provider, token_key, default_token=8192):
|
||||
try:
|
||||
self.model_token = models_tokens["openai"][llm_params["model"]]
|
||||
llm_params["model_provider"] = "openai"
|
||||
except KeyError as exc:
|
||||
raise KeyError("Model not supported") from exc
|
||||
self.model_token = models_tokens[provider][token_key]
|
||||
except KeyError:
|
||||
print(f"Model not found, using default token size ({default_token})")
|
||||
self.model_token = default_token
|
||||
llm_params["model_provider"] = provider
|
||||
llm_params["model"] = model_name
|
||||
return init_chat_model(**llm_params)
|
||||
|
||||
if "oneapi" in llm_params["model"]:
|
||||
# take the model after the last dash
|
||||
llm_params["model"] = llm_params["model"].split("/")[-1]
|
||||
try:
|
||||
self.model_token = models_tokens["oneapi"][llm_params["model"]]
|
||||
except KeyError as exc:
|
||||
raise KeyError("Model not supported") from exc
|
||||
return OneApi(llm_params)
|
||||
if "gpt-" in llm_params["model"]:
|
||||
return handle_model(llm_params["model"], "openai", llm_params["model"])
|
||||
|
||||
if "fireworks" in llm_params["model"]:
|
||||
try:
|
||||
self.model_token = models_tokens["fireworks"][llm_params["model"].split("/")[-1]]
|
||||
llm_params["model"] = "/".join(llm_params["model"].split("/")[1:])
|
||||
except KeyError as exc:
|
||||
raise KeyError("Model not supported") from exc
|
||||
llm_params["model_provider"] = "fireworks"
|
||||
return init_chat_model(**llm_params)
|
||||
model_name = "/".join(llm_params["model"].split("/")[1:])
|
||||
token_key = llm_params["model"].split("/")[-1]
|
||||
return handle_model(model_name, "fireworks", token_key)
|
||||
|
||||
if "azure" in llm_params["model"]:
|
||||
# take the model after the last dash
|
||||
llm_params["model"] = llm_params["model"].split("/")[-1]
|
||||
try:
|
||||
self.model_token = models_tokens["azure"][llm_params["model"]]
|
||||
except KeyError as exc:
|
||||
raise KeyError("Model not supported") from exc
|
||||
llm_params["model_provider"] = "azure_openai"
|
||||
return init_chat_model(**llm_params)
|
||||
|
||||
if "nvidia" in llm_params["model"]:
|
||||
try:
|
||||
self.model_token = models_tokens["nvidia"][llm_params["model"].split("/")[-1]]
|
||||
llm_params["model"] = "/".join(llm_params["model"].split("/")[1:])
|
||||
except KeyError as exc:
|
||||
raise KeyError("Model not supported") from exc
|
||||
return ChatNVIDIA(llm_params)
|
||||
model_name = llm_params["model"].split("/")[-1]
|
||||
return handle_model(model_name, "azure_openai", model_name)
|
||||
|
||||
if "gemini" in llm_params["model"]:
|
||||
llm_params["model"] = llm_params["model"].split("/")[-1]
|
||||
try:
|
||||
self.model_token = models_tokens["gemini"][llm_params["model"]]
|
||||
except KeyError as exc:
|
||||
raise KeyError("Model not supported") from exc
|
||||
llm_params["model_provider"] = "google_genai "
|
||||
return init_chat_model(**llm_params)
|
||||
model_name = llm_params["model"].split("/")[-1]
|
||||
return handle_model(model_name, "google_genai", model_name)
|
||||
|
||||
if llm_params["model"].startswith("claude"):
|
||||
llm_params["model"] = llm_params["model"].split("/")[-1]
|
||||
try:
|
||||
self.model_token = models_tokens["claude"][llm_params["model"]]
|
||||
except KeyError as exc:
|
||||
raise KeyError("Model not supported") from exc
|
||||
llm_params["model_provider"] = "anthropic"
|
||||
return init_chat_model(**llm_params)
|
||||
model_name = llm_params["model"].split("/")[-1]
|
||||
return handle_model(model_name, "anthropic", model_name)
|
||||
|
||||
if llm_params["model"].startswith("vertexai"):
|
||||
try:
|
||||
self.model_token = models_tokens["vertexai"][llm_params["model"]]
|
||||
except KeyError as exc:
|
||||
raise KeyError("Model not supported") from exc
|
||||
llm_params["model_provider"] = "google_vertexai"
|
||||
return init_chat_model(**llm_params)
|
||||
return handle_model(llm_params["model"], "google_vertexai", llm_params["model"])
|
||||
|
||||
if "ollama" in llm_params["model"]:
|
||||
llm_params["model"] = llm_params["model"].split("ollama/")[-1]
|
||||
llm_params["model_provider"] = "ollama"
|
||||
|
||||
# allow user to set model_tokens in config
|
||||
try:
|
||||
if "model_tokens" in llm_params:
|
||||
self.model_token = llm_params["model_tokens"]
|
||||
elif llm_params["model"] in models_tokens["ollama"]:
|
||||
try:
|
||||
self.model_token = models_tokens["ollama"][llm_params["model"]]
|
||||
except KeyError as exc:
|
||||
print("model not found, using default token size (8192)")
|
||||
self.model_token = 8192
|
||||
else:
|
||||
self.model_token = 8192
|
||||
except AttributeError:
|
||||
self.model_token = 8192
|
||||
|
||||
return init_chat_model(**llm_params)
|
||||
model_name = llm_params["model"].split("ollama/")[-1]
|
||||
token_key = model_name if "model_tokens" not in llm_params else llm_params["model_tokens"]
|
||||
return handle_model(model_name, "ollama", token_key)
|
||||
|
||||
if "hugging_face" in llm_params["model"]:
|
||||
llm_params["model"] = llm_params["model"].split("/")[-1]
|
||||
try:
|
||||
self.model_token = models_tokens["hugging_face"][llm_params["model"]]
|
||||
except KeyError:
|
||||
print("model not found, using default token size (8192)")
|
||||
self.model_token = 8192
|
||||
llm_params["model_provider"] = "hugging_face"
|
||||
return init_chat_model(**llm_params)
|
||||
model_name = llm_params["model"].split("/")[-1]
|
||||
return handle_model(model_name, "hugging_face", model_name)
|
||||
|
||||
if "groq" in llm_params["model"]:
|
||||
llm_params["model"] = llm_params["model"].split("/")[-1]
|
||||
|
||||
try:
|
||||
self.model_token = models_tokens["groq"][llm_params["model"]]
|
||||
except KeyError:
|
||||
print("model not found, using default token size (8192)")
|
||||
self.model_token = 8192
|
||||
llm_params["model_provider"] = "groq"
|
||||
return init_chat_model(**llm_params)
|
||||
model_name = llm_params["model"].split("/")[-1]
|
||||
return handle_model(model_name, "groq", model_name)
|
||||
|
||||
if "bedrock" in llm_params["model"]:
|
||||
llm_params["model"] = llm_params["model"].split("/")[-1]
|
||||
try:
|
||||
self.model_token = models_tokens["bedrock"][llm_params["model"]]
|
||||
except KeyError:
|
||||
print("model not found, using default token size (8192)")
|
||||
self.model_token = 8192
|
||||
llm_params["model_provider"] = "bedrock"
|
||||
return init_chat_model(**llm_params)
|
||||
model_name = llm_params["model"].split("/")[-1]
|
||||
return handle_model(model_name, "bedrock", model_name)
|
||||
|
||||
if "claude-3-" in llm_params["model"]:
|
||||
try:
|
||||
self.model_token = models_tokens["claude"]["claude3"]
|
||||
except KeyError:
|
||||
print("model not found, using default token size (8192)")
|
||||
self.model_token = 8192
|
||||
llm_params["model_provider"] = "anthropic"
|
||||
return init_chat_model(**llm_params)
|
||||
return handle_model(llm_params["model"], "anthropic", "claude3")
|
||||
|
||||
# Instantiate the language model based on the model name (models that do not use the common interface)
|
||||
if "deepseek" in llm_params["model"]:
|
||||
try:
|
||||
self.model_token = models_tokens["deepseek"][llm_params["model"]]
|
||||
@ -293,7 +216,25 @@ class AbstractGraph(ABC):
|
||||
print("model not found, using default token size (8192)")
|
||||
self.model_token = 8192
|
||||
return ErnieBotChat(llm_params)
|
||||
|
||||
if "oneapi" in llm_params["model"]:
|
||||
# take the model after the last dash
|
||||
llm_params["model"] = llm_params["model"].split("/")[-1]
|
||||
try:
|
||||
self.model_token = models_tokens["oneapi"][llm_params["model"]]
|
||||
except KeyError as exc:
|
||||
raise KeyError("Model not supported") from exc
|
||||
return OneApi(llm_params)
|
||||
|
||||
if "nvidia" in llm_params["model"]:
|
||||
try:
|
||||
self.model_token = models_tokens["nvidia"][llm_params["model"].split("/")[-1]]
|
||||
llm_params["model"] = "/".join(llm_params["model"].split("/")[1:])
|
||||
except KeyError as exc:
|
||||
raise KeyError("Model not supported") from exc
|
||||
return ChatNVIDIA(llm_params)
|
||||
|
||||
# Raise an error if the model did not match any of the previous cases
|
||||
raise ValueError("Model provided by the configuration not supported")
|
||||
|
||||
def _create_default_embedder(self, llm_config=None) -> object:
|
||||
|
||||
Loading…
Reference in New Issue
Block a user