Fixed accidental reformatting.

This commit is contained in:
Cem Uzunoglu 2024-05-06 15:09:56 +03:00
parent e264e92e72
commit 2ac9e16dd9

View File

@ -47,8 +47,8 @@ class AbstractGraph(ABC):
self.source = source
self.config = config
self.llm_model = self._create_llm(config["llm"], chat=True)
self.embedder_model = self._create_default_embedder(
) if "embeddings" not in config else self._create_embedder(
self.embedder_model = self._create_default_embedder(
) if "embeddings" not in config else self._create_embedder(
config["embeddings"])
# Set common configuration parameters
@ -61,6 +61,7 @@ class AbstractGraph(ABC):
self.final_state = None
self.execution_info = None
def _set_model_token(self, llm):
if 'Azure' in str(type(llm)):
@ -68,7 +69,7 @@ class AbstractGraph(ABC):
self.model_token = models_tokens["azure"][llm.model_name]
except KeyError:
raise KeyError("Model not supported")
elif 'HuggingFaceEndpoint' in str(type(llm)):
if 'mistral' in llm.repo_id:
try:
@ -76,6 +77,7 @@ class AbstractGraph(ABC):
except KeyError:
raise KeyError("Model not supported")
def _create_llm(self, llm_config: dict, chat=False) -> object:
"""
Create a large language model instance based on the configuration provided.
@ -101,7 +103,7 @@ class AbstractGraph(ABC):
if chat:
self._set_model_token(llm_params['model_instance'])
return llm_params['model_instance']
# Instantiate the language model based on the model name
if "gpt-" in llm_params["model"]:
try:
@ -178,7 +180,7 @@ class AbstractGraph(ABC):
else:
raise ValueError(
"Model provided by the configuration not supported")
def _create_default_embedder(self) -> object:
"""
Create an embedding model instance based on the chosen llm model.
@ -209,7 +211,7 @@ class AbstractGraph(ABC):
return BedrockEmbeddings(client=None, model_id=self.llm_model.model_id)
else:
raise ValueError("Embedding Model missing or not supported")
def _create_embedder(self, embedder_config: dict) -> object:
"""
Create an embedding model instance based on the configuration provided.
@ -226,7 +228,7 @@ class AbstractGraph(ABC):
if 'model_instance' in embedder_config:
return embedder_config['model_instance']
# Instantiate the embedding model based on the model name
if "openai" in embedder_config["model"]:
return OpenAIEmbeddings(api_key=embedder_config["api_key"])
@ -241,14 +243,14 @@ class AbstractGraph(ABC):
except KeyError:
raise KeyError("Model not supported")
return OllamaEmbeddings(**embedder_config)
elif "hugging_face" in embedder_config["model"]:
try:
models_tokens["hugging_face"][embedder_config["model"]]
except KeyError:
raise KeyError("Model not supported")
return HuggingFaceHubEmbeddings(model=embedder_config["model"])
elif "bedrock" in embedder_config["model"]:
embedder_config["model"] = embedder_config["model"].split("/")[-1]
try:
@ -258,7 +260,7 @@ class AbstractGraph(ABC):
return BedrockEmbeddings(client=None, model_id=embedder_config["model"])
else:
raise ValueError(
"Model provided by the configuration not supported")
"Model provided by the configuration not supported")
def get_state(self, key=None) -> dict:
"""""
@ -282,7 +284,7 @@ class AbstractGraph(ABC):
Returns:
dict: The execution information of the graph.
"""
return self.execution_info
@abstractmethod
@ -298,3 +300,4 @@ class AbstractGraph(ABC):
Abstract method to execute the graph and return the result.
"""
pass