diff --git a/scrapegraphai/graphs/abstract_graph.py b/scrapegraphai/graphs/abstract_graph.py index b022607c..306901e8 100644 --- a/scrapegraphai/graphs/abstract_graph.py +++ b/scrapegraphai/graphs/abstract_graph.py @@ -146,138 +146,61 @@ class AbstractGraph(ABC): raise KeyError("model_tokens not specified") from exc return llm_params["model_instance"] - # Instantiate the language model based on the model name - if "gpt-" in llm_params["model"]: + # Instantiate the language model based on the model name (models that use the common interface) + def handle_model(model_name, provider, token_key, default_token=8192): try: - self.model_token = models_tokens["openai"][llm_params["model"]] - llm_params["model_provider"] = "openai" - except KeyError as exc: - raise KeyError("Model not supported") from exc + self.model_token = models_tokens[provider][token_key] + except KeyError: + print(f"Model not found, using default token size ({default_token})") + self.model_token = default_token + llm_params["model_provider"] = provider + llm_params["model"] = model_name return init_chat_model(**llm_params) - if "oneapi" in llm_params["model"]: - # take the model after the last dash - llm_params["model"] = llm_params["model"].split("/")[-1] - try: - self.model_token = models_tokens["oneapi"][llm_params["model"]] - except KeyError as exc: - raise KeyError("Model not supported") from exc - return OneApi(llm_params) + if "gpt-" in llm_params["model"]: + return handle_model(llm_params["model"], "openai", llm_params["model"]) if "fireworks" in llm_params["model"]: - try: - self.model_token = models_tokens["fireworks"][llm_params["model"].split("/")[-1]] - llm_params["model"] = "/".join(llm_params["model"].split("/")[1:]) - except KeyError as exc: - raise KeyError("Model not supported") from exc - llm_params["model_provider"] = "fireworks" - return init_chat_model(**llm_params) + model_name = "/".join(llm_params["model"].split("/")[1:]) + token_key = llm_params["model"].split("/")[-1] + return handle_model(model_name, "fireworks", token_key) if "azure" in llm_params["model"]: - # take the model after the last dash - llm_params["model"] = llm_params["model"].split("/")[-1] - try: - self.model_token = models_tokens["azure"][llm_params["model"]] - except KeyError as exc: - raise KeyError("Model not supported") from exc - llm_params["model_provider"] = "azure_openai" - return init_chat_model(**llm_params) - - if "nvidia" in llm_params["model"]: - try: - self.model_token = models_tokens["nvidia"][llm_params["model"].split("/")[-1]] - llm_params["model"] = "/".join(llm_params["model"].split("/")[1:]) - except KeyError as exc: - raise KeyError("Model not supported") from exc - return ChatNVIDIA(llm_params) + model_name = llm_params["model"].split("/")[-1] + return handle_model(model_name, "azure_openai", model_name) if "gemini" in llm_params["model"]: - llm_params["model"] = llm_params["model"].split("/")[-1] - try: - self.model_token = models_tokens["gemini"][llm_params["model"]] - except KeyError as exc: - raise KeyError("Model not supported") from exc - llm_params["model_provider"] = "google_genai " - return init_chat_model(**llm_params) + model_name = llm_params["model"].split("/")[-1] + return handle_model(model_name, "google_genai", model_name) if llm_params["model"].startswith("claude"): - llm_params["model"] = llm_params["model"].split("/")[-1] - try: - self.model_token = models_tokens["claude"][llm_params["model"]] - except KeyError as exc: - raise KeyError("Model not supported") from exc - llm_params["model_provider"] = "anthropic" - return init_chat_model(**llm_params) + model_name = llm_params["model"].split("/")[-1] + return handle_model(model_name, "anthropic", model_name) if llm_params["model"].startswith("vertexai"): - try: - self.model_token = models_tokens["vertexai"][llm_params["model"]] - except KeyError as exc: - raise KeyError("Model not supported") from exc - llm_params["model_provider"] = "google_vertexai" - return init_chat_model(**llm_params) + return handle_model(llm_params["model"], "google_vertexai", llm_params["model"]) if "ollama" in llm_params["model"]: - llm_params["model"] = llm_params["model"].split("ollama/")[-1] - llm_params["model_provider"] = "ollama" - - # allow user to set model_tokens in config - try: - if "model_tokens" in llm_params: - self.model_token = llm_params["model_tokens"] - elif llm_params["model"] in models_tokens["ollama"]: - try: - self.model_token = models_tokens["ollama"][llm_params["model"]] - except KeyError as exc: - print("model not found, using default token size (8192)") - self.model_token = 8192 - else: - self.model_token = 8192 - except AttributeError: - self.model_token = 8192 - - return init_chat_model(**llm_params) + model_name = llm_params["model"].split("ollama/")[-1] + token_key = model_name if "model_tokens" not in llm_params else llm_params["model_tokens"] + return handle_model(model_name, "ollama", token_key) if "hugging_face" in llm_params["model"]: - llm_params["model"] = llm_params["model"].split("/")[-1] - try: - self.model_token = models_tokens["hugging_face"][llm_params["model"]] - except KeyError: - print("model not found, using default token size (8192)") - self.model_token = 8192 - llm_params["model_provider"] = "hugging_face" - return init_chat_model(**llm_params) + model_name = llm_params["model"].split("/")[-1] + return handle_model(model_name, "hugging_face", model_name) if "groq" in llm_params["model"]: - llm_params["model"] = llm_params["model"].split("/")[-1] - - try: - self.model_token = models_tokens["groq"][llm_params["model"]] - except KeyError: - print("model not found, using default token size (8192)") - self.model_token = 8192 - llm_params["model_provider"] = "groq" - return init_chat_model(**llm_params) + model_name = llm_params["model"].split("/")[-1] + return handle_model(model_name, "groq", model_name) if "bedrock" in llm_params["model"]: - llm_params["model"] = llm_params["model"].split("/")[-1] - try: - self.model_token = models_tokens["bedrock"][llm_params["model"]] - except KeyError: - print("model not found, using default token size (8192)") - self.model_token = 8192 - llm_params["model_provider"] = "bedrock" - return init_chat_model(**llm_params) + model_name = llm_params["model"].split("/")[-1] + return handle_model(model_name, "bedrock", model_name) if "claude-3-" in llm_params["model"]: - try: - self.model_token = models_tokens["claude"]["claude3"] - except KeyError: - print("model not found, using default token size (8192)") - self.model_token = 8192 - llm_params["model_provider"] = "anthropic" - return init_chat_model(**llm_params) + return handle_model(llm_params["model"], "anthropic", "claude3") + # Instantiate the language model based on the model name (models that do not use the common interface) if "deepseek" in llm_params["model"]: try: self.model_token = models_tokens["deepseek"][llm_params["model"]] @@ -293,7 +216,25 @@ class AbstractGraph(ABC): print("model not found, using default token size (8192)") self.model_token = 8192 return ErnieBotChat(llm_params) + + if "oneapi" in llm_params["model"]: + # take the model after the last dash + llm_params["model"] = llm_params["model"].split("/")[-1] + try: + self.model_token = models_tokens["oneapi"][llm_params["model"]] + except KeyError as exc: + raise KeyError("Model not supported") from exc + return OneApi(llm_params) + + if "nvidia" in llm_params["model"]: + try: + self.model_token = models_tokens["nvidia"][llm_params["model"].split("/")[-1]] + llm_params["model"] = "/".join(llm_params["model"].split("/")[1:]) + except KeyError as exc: + raise KeyError("Model not supported") from exc + return ChatNVIDIA(llm_params) + # Raise an error if the model did not match any of the previous cases raise ValueError("Model provided by the configuration not supported") def _create_default_embedder(self, llm_config=None) -> object: