From c807695720a85c74a0b4365afb397bbbcd7e2889 Mon Sep 17 00:00:00 2001 From: Federico Minutoli Date: Fri, 24 May 2024 01:09:03 +0200 Subject: [PATCH] feat(verbose): centralized graph logging on debug or warning depending on verbose --- scrapegraphai/graphs/abstract_graph.py | 121 ++++++++++++++----------- 1 file changed, 69 insertions(+), 52 deletions(-) diff --git a/scrapegraphai/graphs/abstract_graph.py b/scrapegraphai/graphs/abstract_graph.py index 33942956..839af910 100644 --- a/scrapegraphai/graphs/abstract_graph.py +++ b/scrapegraphai/graphs/abstract_graph.py @@ -1,16 +1,29 @@ """ AbstractGraph Module """ + from abc import ABC, abstractmethod from typing import Optional + from langchain_aws import BedrockEmbeddings -from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings from langchain_community.embeddings import HuggingFaceHubEmbeddings, OllamaEmbeddings from langchain_google_genai import GoogleGenerativeAIEmbeddings -from ..helpers import models_tokens -from ..utils.logging import set_verbosity -from ..models import AzureOpenAI, Bedrock, Gemini, Groq, HuggingFace, Ollama, OpenAI, Anthropic from langchain_google_genai.embeddings import GoogleGenerativeAIEmbeddings +from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings + +from ..helpers import models_tokens +from ..models import ( + Anthropic, + AzureOpenAI, + Bedrock, + Gemini, + Groq, + HuggingFace, + Ollama, + OpenAI, +) +from ..utils.logging import set_verbosity_debug, set_verbosity_warning + class AbstractGraph(ABC): """ @@ -46,9 +59,11 @@ class AbstractGraph(ABC): self.source = source self.config = config self.llm_model = self._create_llm(config["llm"], chat=True) - self.embedder_model = self._create_default_embedder(llm_config=config["llm"] - ) if "embeddings" not in config else self._create_embedder( - config["embeddings"]) + self.embedder_model = ( + self._create_default_embedder(llm_config=config["llm"]) + if "embeddings" not in config + else self._create_embedder(config["embeddings"]) + ) # Create the graph self.graph = self._create_graph() @@ -56,19 +71,23 @@ class AbstractGraph(ABC): self.execution_info = None # Set common configuration parameters - - verbose = False if config is None else config.get( - "verbose", False) - set_verbosity(config.get("verbose", "info")) - self.headless = True if config is None else config.get( - "headless", True) + + verbose = bool(config and config.get("verbose")) + + if verbose: + set_verbosity_debug() + else: + set_verbosity_warning() + + self.headless = True if config is None else config.get("headless", True) self.loader_kwargs = config.get("loader_kwargs", {}) - common_params = {"headless": self.headless, - - "loader_kwargs": self.loader_kwargs, - "llm_model": self.llm_model, - "embedder_model": self.embedder_model} + common_params = { + "headless": self.headless, + "loader_kwargs": self.loader_kwargs, + "llm_model": self.llm_model, + "embedder_model": self.embedder_model, + } self.set_common_params(common_params, overwrite=False) def set_common_params(self, params: dict, overwrite=False): @@ -81,25 +100,25 @@ class AbstractGraph(ABC): for node in self.graph.nodes: node.update_config(params, overwrite) - + def _set_model_token(self, llm): - if 'Azure' in str(type(llm)): + if "Azure" in str(type(llm)): try: self.model_token = models_tokens["azure"][llm.model_name] except KeyError: raise KeyError("Model not supported") - elif 'HuggingFaceEndpoint' in str(type(llm)): - if 'mistral' in llm.repo_id: + elif "HuggingFaceEndpoint" in str(type(llm)): + if "mistral" in llm.repo_id: try: - self.model_token = models_tokens['mistral'][llm.repo_id] + self.model_token = models_tokens["mistral"][llm.repo_id] except KeyError: raise KeyError("Model not supported") - elif 'Google' in str(type(llm)): + elif "Google" in str(type(llm)): try: - if 'gemini' in llm.model: - self.model_token = models_tokens['gemini'][llm.model] + if "gemini" in llm.model: + self.model_token = models_tokens["gemini"][llm.model] except KeyError: raise KeyError("Model not supported") @@ -117,17 +136,14 @@ class AbstractGraph(ABC): KeyError: If the model is not supported. """ - llm_defaults = { - "temperature": 0, - "streaming": False - } + llm_defaults = {"temperature": 0, "streaming": False} llm_params = {**llm_defaults, **llm_config} # If model instance is passed directly instead of the model details - if 'model_instance' in llm_params: + if "model_instance" in llm_params: if chat: - self._set_model_token(llm_params['model_instance']) - return llm_params['model_instance'] + self._set_model_token(llm_params["model_instance"]) + return llm_params["model_instance"] # Instantiate the language model based on the model name if "gpt-" in llm_params["model"]: @@ -193,18 +209,20 @@ class AbstractGraph(ABC): elif "bedrock" in llm_params["model"]: llm_params["model"] = llm_params["model"].split("/")[-1] model_id = llm_params["model"] - client = llm_params.get('client', None) + client = llm_params.get("client", None) try: self.model_token = models_tokens["bedrock"][llm_params["model"]] except KeyError as exc: raise KeyError("Model not supported") from exc - return Bedrock({ - "client": client, - "model_id": model_id, - "model_kwargs": { - "temperature": llm_params["temperature"], + return Bedrock( + { + "client": client, + "model_id": model_id, + "model_kwargs": { + "temperature": llm_params["temperature"], + }, } - }) + ) elif "claude-3-" in llm_params["model"]: self.model_token = models_tokens["claude"]["claude3"] return Anthropic(llm_params) @@ -215,8 +233,7 @@ class AbstractGraph(ABC): raise KeyError("Model not supported") from exc return DeepSeek(llm_params) else: - raise ValueError( - "Model provided by the configuration not supported") + raise ValueError("Model provided by the configuration not supported") def _create_default_embedder(self, llm_config=None) -> object: """ @@ -229,8 +246,9 @@ class AbstractGraph(ABC): ValueError: If the model is not supported. """ if isinstance(self.llm_model, Gemini): - return GoogleGenerativeAIEmbeddings(google_api_key=llm_config['api_key'], - model="models/embedding-001") + return GoogleGenerativeAIEmbeddings( + google_api_key=llm_config["api_key"], model="models/embedding-001" + ) if isinstance(self.llm_model, OpenAI): return OpenAIEmbeddings(api_key=self.llm_model.openai_api_key) elif isinstance(self.llm_model, AzureOpenAIEmbeddings): @@ -265,8 +283,8 @@ class AbstractGraph(ABC): Raises: KeyError: If the model is not supported. """ - if 'model_instance' in embedder_config: - return embedder_config['model_instance'] + if "model_instance" in embedder_config: + return embedder_config["model_instance"] # Instantiate the embedding model based on the model name if "openai" in embedder_config["model"]: return OpenAIEmbeddings(api_key=embedder_config["api_key"]) @@ -283,28 +301,27 @@ class AbstractGraph(ABC): try: models_tokens["hugging_face"][embedder_config["model"]] except KeyError as exc: - raise KeyError("Model not supported")from exc + raise KeyError("Model not supported") from exc return HuggingFaceHubEmbeddings(model=embedder_config["model"]) elif "gemini" in embedder_config["model"]: try: models_tokens["gemini"][embedder_config["model"]] except KeyError as exc: - raise KeyError("Model not supported")from exc + raise KeyError("Model not supported") from exc return GoogleGenerativeAIEmbeddings(model=embedder_config["model"]) elif "bedrock" in embedder_config["model"]: embedder_config["model"] = embedder_config["model"].split("/")[-1] - client = embedder_config.get('client', None) + client = embedder_config.get("client", None) try: models_tokens["bedrock"][embedder_config["model"]] except KeyError as exc: raise KeyError("Model not supported") from exc return BedrockEmbeddings(client=client, model_id=embedder_config["model"]) else: - raise ValueError( - "Model provided by the configuration not supported") + raise ValueError("Model provided by the configuration not supported") def get_state(self, key=None) -> dict: - """"" + """ "" Get the final state of the graph. Args: