From c77231c983bd6e154eefd26422cd156da4c8b7bb Mon Sep 17 00:00:00 2001 From: Marco Vinciguerra Date: Sun, 11 Aug 2024 19:18:24 +0200 Subject: [PATCH] feat: update abstract graph Co-Authored-By: Matteo Vedovati <68272450+vedovati-matteo@users.noreply.github.com> --- scrapegraphai/graphs/abstract_graph.py | 51 ++++++++++---------------- 1 file changed, 20 insertions(+), 31 deletions(-) diff --git a/scrapegraphai/graphs/abstract_graph.py b/scrapegraphai/graphs/abstract_graph.py index 83b532bc..41e9c9b9 100644 --- a/scrapegraphai/graphs/abstract_graph.py +++ b/scrapegraphai/graphs/abstract_graph.py @@ -136,7 +136,6 @@ class AbstractGraph(ABC): raise KeyError("model_tokens not specified") from exc return llm_params["model_instance"] - # Instantiate the language model based on the model name (models that use the common interface) def handle_model(model_name, provider, token_key, default_token=8192): try: self.model_token = models_tokens[provider][token_key] @@ -153,51 +152,39 @@ class AbstractGraph(ABC): model_name = llm_params["model"].split("/")[-1] return handle_model(model_name, "azure_openai", model_name) - if "gpt-" in llm_params["model"]: - return handle_model(llm_params["model"], "openai", llm_params["model"]) - - if "fireworks" in llm_params["model"]: + elif "fireworks" in llm_params["model"]: model_name = "/".join(llm_params["model"].split("/")[1:]) token_key = llm_params["model"].split("/")[-1] return handle_model(model_name, "fireworks", token_key) - if "gemini" in llm_params["model"]: + elif "gemini" in llm_params["model"]: model_name = llm_params["model"].split("/")[-1] return handle_model(model_name, "google_genai", model_name) - if llm_params["model"].startswith("claude"): + elif llm_params["model"].startswith("claude"): model_name = llm_params["model"].split("/")[-1] return handle_model(model_name, "anthropic", model_name) - if llm_params["model"].startswith("vertexai"): + elif llm_params["model"].startswith("vertexai"): return handle_model(llm_params["model"], "google_vertexai", llm_params["model"]) - if "ollama" in llm_params["model"]: + elif "gpt-" in llm_params["model"]: + return handle_model(llm_params["model"], "openai", llm_params["model"]) + + elif "ollama" in llm_params["model"]: model_name = llm_params["model"].split("ollama/")[-1] token_key = model_name if "model_tokens" not in llm_params else llm_params["model_tokens"] return handle_model(model_name, "ollama", token_key) - if "hugging_face" in llm_params["model"]: - model_name = llm_params["model"].split("/")[-1] - return handle_model(model_name, "hugging_face", model_name) - - if "groq" in llm_params["model"]: - model_name = llm_params["model"].split("/")[-1] - return handle_model(model_name, "groq", model_name) - - if "bedrock" in llm_params["model"]: - model_name = llm_params["model"].split("/")[-1] - return handle_model(model_name, "bedrock", model_name) - - if "claude-3-" in llm_params["model"]: + elif "claude-3-" in llm_params["model"]: return handle_model(llm_params["model"], "anthropic", "claude3") - - if llm_params["model"].startswith("mistral"): + + elif llm_params["model"].startswith("mistral"): model_name = llm_params["model"].split("/")[-1] return handle_model(model_name, "mistralai", model_name) # Instantiate the language model based on the model name (models that do not use the common interface) - if "deepseek" in llm_params["model"]: + elif "deepseek" in llm_params["model"]: try: self.model_token = models_tokens["deepseek"][llm_params["model"]] except KeyError: @@ -205,15 +192,15 @@ class AbstractGraph(ABC): self.model_token = 8192 return DeepSeek(llm_params) - if "ernie" in llm_params["model"]: + elif "ernie" in llm_params["model"]: try: self.model_token = models_tokens["ernie"][llm_params["model"]] except KeyError: print("model not found, using default token size (8192)") self.model_token = 8192 return ErnieBotChat(llm_params) - - if "oneapi" in llm_params["model"]: + + elif "oneapi" in llm_params["model"]: # take the model after the last dash llm_params["model"] = llm_params["model"].split("/")[-1] try: @@ -221,16 +208,18 @@ class AbstractGraph(ABC): except KeyError as exc: raise KeyError("Model not supported") from exc return OneApi(llm_params) - - if "nvidia" in llm_params["model"]: + + elif "nvidia" in llm_params["model"]: try: self.model_token = models_tokens["nvidia"][llm_params["model"].split("/")[-1]] llm_params["model"] = "/".join(llm_params["model"].split("/")[1:]) except KeyError as exc: raise KeyError("Model not supported") from exc return ChatNVIDIA(llm_params) + else: + model_name = llm_params["model"].split("/")[-1] + return handle_model(model_name, llm_params["model"], model_name) - # Raise an error if the model did not match any of the previous cases raise ValueError("Model provided by the configuration not supported")