diff --git a/scrapegraphai/graphs/abstract_graph.py b/scrapegraphai/graphs/abstract_graph.py index 29d0532f..bfc1848f 100644 --- a/scrapegraphai/graphs/abstract_graph.py +++ b/scrapegraphai/graphs/abstract_graph.py @@ -63,13 +63,10 @@ class AbstractGraph(ABC): self.cache_path = self.config.get("cache_path", False) self.browser_base = self.config.get("browser_base") - # Create the graph self.graph = self._create_graph() self.final_state = None self.execution_info = None - # Set common configuration parameters - verbose = bool(config and config.get("verbose")) if verbose: @@ -87,12 +84,10 @@ class AbstractGraph(ABC): self.set_common_params(common_params, overwrite=True) - # set burr config self.burr_kwargs = config.get("burr_kwargs", None) if self.burr_kwargs is not None: self.graph.use_burr = True if "app_instance_id" not in self.burr_kwargs: - # set a random uuid for the app_instance_id to avoid conflicts self.burr_kwargs["app_instance_id"] = str(uuid.uuid4()) self.graph.burr_config = self.burr_kwargs @@ -125,7 +120,6 @@ class AbstractGraph(ABC): llm_defaults = {"temperature": 0, "streaming": False} llm_params = {**llm_defaults, **llm_config} - # If model instance is passed directly instead of the model details if "model_instance" in llm_params: try: self.model_token = llm_params["model_tokens"] @@ -145,18 +139,14 @@ class AbstractGraph(ABC): warnings.simplefilter("ignore") return init_chat_model(**llm_params) - known_models = ["chatgpt","gpt","openai", "azure_openai", "google_genai", - "ollama", "oneapi", "nvidia", "groq", "google_vertexai", - "bedrock", "mistralai", "hugging_face", "deepseek", "ernie", "fireworks"] - + known_models = {"chatgpt","gpt","openai", "azure_openai", "google_genai", + "ollama", "oneapi", "nvidia", "groq", "google_vertexai", + "bedrock", "mistralai", "hugging_face", "deepseek", "ernie", "fireworks"} if llm_params["model"].split("/")[0] not in known_models and llm_params["model"].split("-")[0] not in known_models: raise ValueError(f"Model '{llm_params['model']}' is not supported") try: - if "azure" in llm_params["model"]: - model_name = llm_params["model"].split("/")[-1] - return handle_model(model_name, "azure_openai", model_name) if "fireworks" in llm_params["model"]: model_name = "/".join(llm_params["model"].split("/")[1:]) token_key = llm_params["model"].split("/")[-1] @@ -207,7 +197,6 @@ class AbstractGraph(ABC): return ErnieBotChat(llm_params) elif "oneapi" in llm_params["model"]: - # take the model after the last dash llm_params["model"] = llm_params["model"].split("/")[-1] try: self.model_token = models_tokens["oneapi"][llm_params["model"]] diff --git a/scrapegraphai/helpers/models_tokens.py b/scrapegraphai/helpers/models_tokens.py index 852a9ce2..62d46e7b 100644 --- a/scrapegraphai/helpers/models_tokens.py +++ b/scrapegraphai/helpers/models_tokens.py @@ -50,6 +50,11 @@ models_tokens = { "gemini-1.5-flash-latest": 128000, "gemini-1.5-pro-latest": 128000, "models/embedding-001": 2048 + }, + "google_vertexai": { + "gemini-1.5-flash": 128000, + "gemini-1.5-pro": 128000, + "gemini-1.0-pro": 128000, }, "ollama": { "command-r": 12800, @@ -96,7 +101,7 @@ models_tokens = { "oneapi": { "qwen-turbo": 6000, }, - "nv dia": { + "nvdia": { "meta/llama3-70b-instruct": 419, "meta/llama3-8b-instruct": 419, "nemotron-4-340b-instruct": 1024, @@ -132,11 +137,6 @@ models_tokens = { "claude-3-haiku-20240307": 200000, "claude-3-5-sonnet-20240620": 200000, }, - "google_vertexai": { - "gemini-1.5-flash": 128000, - "gemini-1.5-pro": 128000, - "gemini-1.0-pro": 128000, - }, "bedrock": { "anthropic.claude-3-haiku-20240307-v1:0": 200000, "anthropic.claude-3-sonnet-20240229-v1:0": 200000,