feat(verbose): centralized graph logging on debug or warning depending on verbose

This commit is contained in:
Federico Minutoli 2024-05-24 01:09:03 +02:00
parent 0790ecd208
commit c807695720

View File

@ -1,16 +1,29 @@
"""
AbstractGraph Module
"""
from abc import ABC, abstractmethod
from typing import Optional
from langchain_aws import BedrockEmbeddings
from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings
from langchain_community.embeddings import HuggingFaceHubEmbeddings, OllamaEmbeddings
from langchain_google_genai import GoogleGenerativeAIEmbeddings
from ..helpers import models_tokens
from ..utils.logging import set_verbosity
from ..models import AzureOpenAI, Bedrock, Gemini, Groq, HuggingFace, Ollama, OpenAI, Anthropic
from langchain_google_genai.embeddings import GoogleGenerativeAIEmbeddings
from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings
from ..helpers import models_tokens
from ..models import (
Anthropic,
AzureOpenAI,
Bedrock,
Gemini,
Groq,
HuggingFace,
Ollama,
OpenAI,
)
from ..utils.logging import set_verbosity_debug, set_verbosity_warning
class AbstractGraph(ABC):
"""
@ -46,9 +59,11 @@ class AbstractGraph(ABC):
self.source = source
self.config = config
self.llm_model = self._create_llm(config["llm"], chat=True)
self.embedder_model = self._create_default_embedder(llm_config=config["llm"]
) if "embeddings" not in config else self._create_embedder(
config["embeddings"])
self.embedder_model = (
self._create_default_embedder(llm_config=config["llm"])
if "embeddings" not in config
else self._create_embedder(config["embeddings"])
)
# Create the graph
self.graph = self._create_graph()
@ -56,19 +71,23 @@ class AbstractGraph(ABC):
self.execution_info = None
# Set common configuration parameters
verbose = False if config is None else config.get(
"verbose", False)
set_verbosity(config.get("verbose", "info"))
self.headless = True if config is None else config.get(
"headless", True)
verbose = bool(config and config.get("verbose"))
if verbose:
set_verbosity_debug()
else:
set_verbosity_warning()
self.headless = True if config is None else config.get("headless", True)
self.loader_kwargs = config.get("loader_kwargs", {})
common_params = {"headless": self.headless,
"loader_kwargs": self.loader_kwargs,
"llm_model": self.llm_model,
"embedder_model": self.embedder_model}
common_params = {
"headless": self.headless,
"loader_kwargs": self.loader_kwargs,
"llm_model": self.llm_model,
"embedder_model": self.embedder_model,
}
self.set_common_params(common_params, overwrite=False)
def set_common_params(self, params: dict, overwrite=False):
@ -81,25 +100,25 @@ class AbstractGraph(ABC):
for node in self.graph.nodes:
node.update_config(params, overwrite)
def _set_model_token(self, llm):
if 'Azure' in str(type(llm)):
if "Azure" in str(type(llm)):
try:
self.model_token = models_tokens["azure"][llm.model_name]
except KeyError:
raise KeyError("Model not supported")
elif 'HuggingFaceEndpoint' in str(type(llm)):
if 'mistral' in llm.repo_id:
elif "HuggingFaceEndpoint" in str(type(llm)):
if "mistral" in llm.repo_id:
try:
self.model_token = models_tokens['mistral'][llm.repo_id]
self.model_token = models_tokens["mistral"][llm.repo_id]
except KeyError:
raise KeyError("Model not supported")
elif 'Google' in str(type(llm)):
elif "Google" in str(type(llm)):
try:
if 'gemini' in llm.model:
self.model_token = models_tokens['gemini'][llm.model]
if "gemini" in llm.model:
self.model_token = models_tokens["gemini"][llm.model]
except KeyError:
raise KeyError("Model not supported")
@ -117,17 +136,14 @@ class AbstractGraph(ABC):
KeyError: If the model is not supported.
"""
llm_defaults = {
"temperature": 0,
"streaming": False
}
llm_defaults = {"temperature": 0, "streaming": False}
llm_params = {**llm_defaults, **llm_config}
# If model instance is passed directly instead of the model details
if 'model_instance' in llm_params:
if "model_instance" in llm_params:
if chat:
self._set_model_token(llm_params['model_instance'])
return llm_params['model_instance']
self._set_model_token(llm_params["model_instance"])
return llm_params["model_instance"]
# Instantiate the language model based on the model name
if "gpt-" in llm_params["model"]:
@ -193,18 +209,20 @@ class AbstractGraph(ABC):
elif "bedrock" in llm_params["model"]:
llm_params["model"] = llm_params["model"].split("/")[-1]
model_id = llm_params["model"]
client = llm_params.get('client', None)
client = llm_params.get("client", None)
try:
self.model_token = models_tokens["bedrock"][llm_params["model"]]
except KeyError as exc:
raise KeyError("Model not supported") from exc
return Bedrock({
"client": client,
"model_id": model_id,
"model_kwargs": {
"temperature": llm_params["temperature"],
return Bedrock(
{
"client": client,
"model_id": model_id,
"model_kwargs": {
"temperature": llm_params["temperature"],
},
}
})
)
elif "claude-3-" in llm_params["model"]:
self.model_token = models_tokens["claude"]["claude3"]
return Anthropic(llm_params)
@ -215,8 +233,7 @@ class AbstractGraph(ABC):
raise KeyError("Model not supported") from exc
return DeepSeek(llm_params)
else:
raise ValueError(
"Model provided by the configuration not supported")
raise ValueError("Model provided by the configuration not supported")
def _create_default_embedder(self, llm_config=None) -> object:
"""
@ -229,8 +246,9 @@ class AbstractGraph(ABC):
ValueError: If the model is not supported.
"""
if isinstance(self.llm_model, Gemini):
return GoogleGenerativeAIEmbeddings(google_api_key=llm_config['api_key'],
model="models/embedding-001")
return GoogleGenerativeAIEmbeddings(
google_api_key=llm_config["api_key"], model="models/embedding-001"
)
if isinstance(self.llm_model, OpenAI):
return OpenAIEmbeddings(api_key=self.llm_model.openai_api_key)
elif isinstance(self.llm_model, AzureOpenAIEmbeddings):
@ -265,8 +283,8 @@ class AbstractGraph(ABC):
Raises:
KeyError: If the model is not supported.
"""
if 'model_instance' in embedder_config:
return embedder_config['model_instance']
if "model_instance" in embedder_config:
return embedder_config["model_instance"]
# Instantiate the embedding model based on the model name
if "openai" in embedder_config["model"]:
return OpenAIEmbeddings(api_key=embedder_config["api_key"])
@ -283,28 +301,27 @@ class AbstractGraph(ABC):
try:
models_tokens["hugging_face"][embedder_config["model"]]
except KeyError as exc:
raise KeyError("Model not supported")from exc
raise KeyError("Model not supported") from exc
return HuggingFaceHubEmbeddings(model=embedder_config["model"])
elif "gemini" in embedder_config["model"]:
try:
models_tokens["gemini"][embedder_config["model"]]
except KeyError as exc:
raise KeyError("Model not supported")from exc
raise KeyError("Model not supported") from exc
return GoogleGenerativeAIEmbeddings(model=embedder_config["model"])
elif "bedrock" in embedder_config["model"]:
embedder_config["model"] = embedder_config["model"].split("/")[-1]
client = embedder_config.get('client', None)
client = embedder_config.get("client", None)
try:
models_tokens["bedrock"][embedder_config["model"]]
except KeyError as exc:
raise KeyError("Model not supported") from exc
return BedrockEmbeddings(client=client, model_id=embedder_config["model"])
else:
raise ValueError(
"Model provided by the configuration not supported")
raise ValueError("Model provided by the configuration not supported")
def get_state(self, key=None) -> dict:
"""""
""" ""
Get the final state of the graph.
Args: