mirror of
https://github.com/VinciGit00/Scrapegraph-ai.git
synced 2026-07-01 21:00:48 +08:00
fix(AbstractGraph): correct and simplify instancing logic
This commit is contained in:
parent
22ab45f6bd
commit
f73343f193
@ -125,103 +125,47 @@ class AbstractGraph(ABC):
|
|||||||
self.model_token = llm_params["model_tokens"]
|
self.model_token = llm_params["model_tokens"]
|
||||||
except KeyError as exc:
|
except KeyError as exc:
|
||||||
raise KeyError("model_tokens not specified") from exc
|
raise KeyError("model_tokens not specified") from exc
|
||||||
return llm_params["model_instance"]
|
return llm_params["model_instance"]
|
||||||
|
|
||||||
def handle_model(model_name, provider, token_key, default_token=8192):
|
known_providers = {"openai", "azure_openai", "google_genai", "google_vertexai",
|
||||||
try:
|
"ollama", "oneapi", "nvidia", "groq", "anthropic" "bedrock", "mistralai",
|
||||||
self.model_token = models_tokens[provider][token_key]
|
"hugging_face", "deepseek", "ernie", "fireworks"}
|
||||||
except KeyError:
|
|
||||||
print(f"Model not found, using default token size ({default_token})")
|
split_model_provider = llm_params["model"].split("/")
|
||||||
self.model_token = default_token
|
llm_params["model_provider"] = split_model_provider[0]
|
||||||
llm_params["model_provider"] = provider
|
llm_params["model"] = split_model_provider[1:]
|
||||||
llm_params["model"] = model_name
|
|
||||||
with warnings.catch_warnings():
|
|
||||||
warnings.simplefilter("ignore")
|
|
||||||
return init_chat_model(**llm_params)
|
|
||||||
|
|
||||||
known_models = {"chatgpt","gpt","openai", "azure_openai", "google_genai",
|
if llm_params["model_provider"] not in known_providers:
|
||||||
"ollama", "oneapi", "nvidia", "groq", "google_vertexai",
|
raise ValueError(f"Provider {llm_params['model_provider']} is not supported. If possible, try to use a model instance instead.")
|
||||||
"bedrock", "mistralai", "hugging_face", "deepseek", "ernie",
|
|
||||||
"fireworks", "claude-3-"}
|
try:
|
||||||
|
self.model_token = models_tokens[llm_params["model"]][llm_params["model"]]
|
||||||
if llm_params["model"].split("/")[0] not in known_models and llm_params["model"].split("-")[0] not in known_models:
|
except KeyError:
|
||||||
raise ValueError(f"Model '{llm_params['model']}' is not supported")
|
print("Model not found, using default token size (8192)")
|
||||||
|
self.model_token = 8192
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if "fireworks" in llm_params["model"]:
|
if llm_params["model_provider"] not in {"oneapi", "nvidia", "ernie", "deepseek"}:
|
||||||
model_name = "/".join(llm_params["model"].split("/")[1:])
|
with warnings.catch_warnings():
|
||||||
token_key = llm_params["model"].split("/")[-1]
|
warnings.simplefilter("ignore")
|
||||||
return handle_model(model_name, "fireworks", token_key)
|
return init_chat_model(**llm_params)
|
||||||
|
|
||||||
elif "gemini" in llm_params["model"]:
|
|
||||||
model_name = llm_params["model"].split("/")[-1]
|
|
||||||
return handle_model(model_name, "google_genai", model_name)
|
|
||||||
|
|
||||||
elif llm_params["model"].startswith("claude"):
|
|
||||||
model_name = llm_params["model"].split("/")[-1]
|
|
||||||
return handle_model(model_name, "anthropic", model_name)
|
|
||||||
|
|
||||||
elif llm_params["model"].startswith("vertexai"):
|
|
||||||
return handle_model(llm_params["model"], "google_vertexai", llm_params["model"])
|
|
||||||
|
|
||||||
elif "gpt-" in llm_params["model"]:
|
|
||||||
return handle_model(llm_params["model"], "openai", llm_params["model"])
|
|
||||||
|
|
||||||
elif "ollama" in llm_params["model"]:
|
|
||||||
model_name = llm_params["model"].split("ollama/")[-1]
|
|
||||||
token_key = model_name if "model_tokens" not in llm_params else None
|
|
||||||
model_tokens = 8192 if "model_tokens" not in llm_params else llm_params["model_tokens"]
|
|
||||||
return handle_model(model_name, "ollama", token_key, model_tokens)
|
|
||||||
|
|
||||||
elif "claude-3-" in llm_params["model"]:
|
|
||||||
return handle_model(llm_params["model"], "anthropic", "claude3")
|
|
||||||
|
|
||||||
elif llm_params["model"].startswith("mistral"):
|
|
||||||
model_name = llm_params["model"].split("/")[-1]
|
|
||||||
return handle_model(model_name, "mistralai", model_name)
|
|
||||||
|
|
||||||
elif "deepseek" in llm_params["model"]:
|
|
||||||
try:
|
|
||||||
self.model_token = models_tokens["deepseek"][llm_params["model"]]
|
|
||||||
except KeyError:
|
|
||||||
print("model not found, using default token size (8192)")
|
|
||||||
self.model_token = 8192
|
|
||||||
return DeepSeek(llm_params)
|
|
||||||
|
|
||||||
elif "ernie" in llm_params["model"]:
|
|
||||||
from langchain_community.chat_models import ErnieBotChat
|
|
||||||
|
|
||||||
try:
|
|
||||||
self.model_token = models_tokens["ernie"][llm_params["model"]]
|
|
||||||
except KeyError:
|
|
||||||
print("model not found, using default token size (8192)")
|
|
||||||
self.model_token = 8192
|
|
||||||
return ErnieBotChat(llm_params)
|
|
||||||
|
|
||||||
elif "oneapi" in llm_params["model"]:
|
|
||||||
llm_params["model"] = llm_params["model"].split("/")[-1]
|
|
||||||
try:
|
|
||||||
self.model_token = models_tokens["oneapi"][llm_params["model"]]
|
|
||||||
except KeyError:
|
|
||||||
raise KeyError("Model not supported")
|
|
||||||
return OneApi(llm_params)
|
|
||||||
|
|
||||||
elif "nvidia" in llm_params["model"]:
|
|
||||||
from langchain_nvidia_ai_endpoints import ChatNVIDIA
|
|
||||||
|
|
||||||
try:
|
|
||||||
self.model_token = models_tokens["nvidia"][llm_params["model"].split("/")[-1]]
|
|
||||||
llm_params["model"] = "/".join(llm_params["model"].split("/")[1:])
|
|
||||||
except KeyError:
|
|
||||||
raise KeyError("Model not supported")
|
|
||||||
return ChatNVIDIA(llm_params)
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
model_name = llm_params["model"].split("/")[-1]
|
if "deepseek" in llm_params["model"]:
|
||||||
return handle_model(model_name, llm_params["model"], model_name)
|
return DeepSeek(**llm_params)
|
||||||
|
|
||||||
except KeyError as e:
|
if "ernie" in llm_params["model"]:
|
||||||
print(f"Model not supported: {e}")
|
from langchain_community.chat_models import ErnieBotChat
|
||||||
|
return ErnieBotChat(**llm_params)
|
||||||
|
|
||||||
|
if "oneapi" in llm_params["model"]:
|
||||||
|
return OneApi(**llm_params)
|
||||||
|
|
||||||
|
if "nvidia" in llm_params["model"]:
|
||||||
|
from langchain_nvidia_ai_endpoints import ChatNVIDIA
|
||||||
|
return ChatNVIDIA(**llm_params)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error instancing model: {e}")
|
||||||
|
|
||||||
|
|
||||||
def get_state(self, key=None) -> dict:
|
def get_state(self, key=None) -> dict:
|
||||||
|
|||||||
@ -102,7 +102,7 @@ models_tokens = {
|
|||||||
"oneapi": {
|
"oneapi": {
|
||||||
"qwen-turbo": 6000,
|
"qwen-turbo": 6000,
|
||||||
},
|
},
|
||||||
"nvdia": {
|
"nvidia": {
|
||||||
"meta/llama3-70b-instruct": 419,
|
"meta/llama3-70b-instruct": 419,
|
||||||
"meta/llama3-8b-instruct": 419,
|
"meta/llama3-8b-instruct": 419,
|
||||||
"nemotron-4-340b-instruct": 1024,
|
"nemotron-4-340b-instruct": 1024,
|
||||||
@ -127,7 +127,7 @@ models_tokens = {
|
|||||||
"gemma-7b-it": 8192,
|
"gemma-7b-it": 8192,
|
||||||
"claude-3-haiku-20240307'": 8192,
|
"claude-3-haiku-20240307'": 8192,
|
||||||
},
|
},
|
||||||
"claude": {
|
"anthropic": {
|
||||||
"claude_instant": 100000,
|
"claude_instant": 100000,
|
||||||
"claude2": 9000,
|
"claude2": 9000,
|
||||||
"claude2.1": 200000,
|
"claude2.1": 200000,
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user