diff --git a/scrapegraphai/graphs/abstract_graph.py b/scrapegraphai/graphs/abstract_graph.py index eb4fee7f..2ae47c41 100644 --- a/scrapegraphai/graphs/abstract_graph.py +++ b/scrapegraphai/graphs/abstract_graph.py @@ -47,8 +47,8 @@ class AbstractGraph(ABC): self.source = source self.config = config self.llm_model = self._create_llm(config["llm"], chat=True) - self.embedder_model = self._create_default_embedder( - ) if "embeddings" not in config else self._create_embedder( + self.embedder_model = self._create_default_embedder( + ) if "embeddings" not in config else self._create_embedder( config["embeddings"]) # Set common configuration parameters @@ -61,6 +61,7 @@ class AbstractGraph(ABC): self.final_state = None self.execution_info = None + def _set_model_token(self, llm): if 'Azure' in str(type(llm)): @@ -68,7 +69,7 @@ class AbstractGraph(ABC): self.model_token = models_tokens["azure"][llm.model_name] except KeyError: raise KeyError("Model not supported") - + elif 'HuggingFaceEndpoint' in str(type(llm)): if 'mistral' in llm.repo_id: try: @@ -76,6 +77,7 @@ class AbstractGraph(ABC): except KeyError: raise KeyError("Model not supported") + def _create_llm(self, llm_config: dict, chat=False) -> object: """ Create a large language model instance based on the configuration provided. @@ -101,7 +103,7 @@ class AbstractGraph(ABC): if chat: self._set_model_token(llm_params['model_instance']) return llm_params['model_instance'] - + # Instantiate the language model based on the model name if "gpt-" in llm_params["model"]: try: @@ -178,7 +180,7 @@ class AbstractGraph(ABC): else: raise ValueError( "Model provided by the configuration not supported") - + def _create_default_embedder(self) -> object: """ Create an embedding model instance based on the chosen llm model. @@ -209,7 +211,7 @@ class AbstractGraph(ABC): return BedrockEmbeddings(client=None, model_id=self.llm_model.model_id) else: raise ValueError("Embedding Model missing or not supported") - + def _create_embedder(self, embedder_config: dict) -> object: """ Create an embedding model instance based on the configuration provided. @@ -226,7 +228,7 @@ class AbstractGraph(ABC): if 'model_instance' in embedder_config: return embedder_config['model_instance'] - + # Instantiate the embedding model based on the model name if "openai" in embedder_config["model"]: return OpenAIEmbeddings(api_key=embedder_config["api_key"]) @@ -241,14 +243,14 @@ class AbstractGraph(ABC): except KeyError: raise KeyError("Model not supported") return OllamaEmbeddings(**embedder_config) - + elif "hugging_face" in embedder_config["model"]: try: models_tokens["hugging_face"][embedder_config["model"]] except KeyError: raise KeyError("Model not supported") return HuggingFaceHubEmbeddings(model=embedder_config["model"]) - + elif "bedrock" in embedder_config["model"]: embedder_config["model"] = embedder_config["model"].split("/")[-1] try: @@ -258,7 +260,7 @@ class AbstractGraph(ABC): return BedrockEmbeddings(client=None, model_id=embedder_config["model"]) else: raise ValueError( - "Model provided by the configuration not supported") + "Model provided by the configuration not supported") def get_state(self, key=None) -> dict: """"" @@ -282,7 +284,7 @@ class AbstractGraph(ABC): Returns: dict: The execution information of the graph. """ - + return self.execution_info @abstractmethod @@ -298,3 +300,4 @@ class AbstractGraph(ABC): Abstract method to execute the graph and return the result. """ pass +