diff --git a/scrapegraphai/graphs/abstract_graph.py b/scrapegraphai/graphs/abstract_graph.py index f5922938..31945ec2 100644 --- a/scrapegraphai/graphs/abstract_graph.py +++ b/scrapegraphai/graphs/abstract_graph.py @@ -8,13 +8,9 @@ from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings from langchain_community.embeddings import HuggingFaceHubEmbeddings, OllamaEmbeddings from langchain_google_genai import GoogleGenerativeAIEmbeddings from ..helpers import models_tokens -from ..models import AzureOpenAI, Bedrock, Gemini, Groq, HuggingFace, Ollama, OpenAI, Anthropic, DeepSeek +from ..models import AzureOpenAI, Bedrock, Gemini, Groq, HuggingFace, Ollama, OpenAI, Anthropic from langchain_google_genai.embeddings import GoogleGenerativeAIEmbeddings -from ..helpers import models_tokens -from ..models import AzureOpenAI, Bedrock, Gemini, Groq, HuggingFace, Ollama, OpenAI, Anthropic, DeepSeek - - class AbstractGraph(ABC): """ Scaffolding class for creating a graph representation and executing it. @@ -22,7 +18,6 @@ class AbstractGraph(ABC): prompt (str): The prompt for the graph. source (str): The source of the graph. config (dict): Configuration parameters for the graph. - schema (str): The schema for the graph output. llm_model: An instance of a language model client, configured for generating answers. embedder_model: An instance of an embedding model client, configured for generating embeddings. @@ -33,7 +28,6 @@ class AbstractGraph(ABC): prompt (str): The prompt for the graph. config (dict): Configuration parameters for the graph. source (str, optional): The source of the graph. - schema (str, optional): The schema for the graph output. Example: >>> class MyGraph(AbstractGraph): @@ -45,21 +39,15 @@ class AbstractGraph(ABC): >>> result = my_graph.run() """ - def __init__(self, prompt: str, config: dict, source: Optional[str] = None, schema: Optional[str] = None): + def __init__(self, prompt: str, config: dict, source: Optional[str] = None): self.prompt = prompt self.source = source self.config = config - self.schema = schema self.llm_model = self._create_llm(config["llm"], chat=True) self.embedder_model = self._create_default_embedder(llm_config=config["llm"] ) if "embeddings" not in config else self._create_embedder( config["embeddings"]) - self.verbose = False if config is None else config.get( - "verbose", False) - self.headless = True if config is None else config.get( - "headless", True) - self.loader_kwargs = config.get("loader_kwargs", {}) # Create the graph self.graph = self._create_graph() @@ -67,20 +55,18 @@ class AbstractGraph(ABC): self.execution_info = None # Set common configuration parameters + self.verbose = False if config is None else config.get( "verbose", False) self.headless = True if config is None else config.get( "headless", True) self.loader_kwargs = config.get("loader_kwargs", {}) - common_params = { - "headless": self.headless, - "verbose": self.verbose, - "loader_kwargs": self.loader_kwargs, - "llm_model": self.llm_model, - "embedder_model": self.embedder_model - } - + common_params = {"headless": self.headless, + + "loader_kwargs": self.loader_kwargs, + "llm_model": self.llm_model, + "embedder_model": self.embedder_model} self.set_common_params(common_params, overwrite=False) def set_common_params(self, params: dict, overwrite=False): @@ -93,7 +79,7 @@ class AbstractGraph(ABC): for node in self.graph.nodes: node.update_config(params, overwrite) - + def _set_model_token(self, llm): if 'Azure' in str(type(llm)): @@ -171,7 +157,7 @@ class AbstractGraph(ABC): raise KeyError("Model not supported") from exc return Anthropic(llm_params) elif "ollama" in llm_params["model"]: - llm_params["model"] = llm_params["model"].split("ollama/")[-1] + llm_params["model"] = llm_params["model"].split("/")[-1] # allow user to set model_tokens in config try: @@ -245,8 +231,6 @@ class AbstractGraph(ABC): model="models/embedding-001") if isinstance(self.llm_model, OpenAI): return OpenAIEmbeddings(api_key=self.llm_model.openai_api_key) - elif isinstance(self.llm_model, DeepSeek): - return OpenAIEmbeddings(api_key=self.llm_model.openai_api_key) elif isinstance(self.llm_model, AzureOpenAIEmbeddings): return self.llm_model elif isinstance(self.llm_model, AzureOpenAI): @@ -282,31 +266,30 @@ class AbstractGraph(ABC): if 'model_instance' in embedder_config: return embedder_config['model_instance'] # Instantiate the embedding model based on the model name - if "openai" in embedder_config["model"].split("/")[0]: + if "openai" in embedder_config["model"]: return OpenAIEmbeddings(api_key=embedder_config["api_key"]) elif "azure" in embedder_config["model"]: return AzureOpenAIEmbeddings() - elif "ollama" in embedder_config["model"].split("/")[0]: - print("ciao") - embedder_config["model"] = embedder_config["model"].split("ollama/")[-1] + elif "ollama" in embedder_config["model"]: + embedder_config["model"] = embedder_config["model"].split("/")[-1] try: models_tokens["ollama"][embedder_config["model"]] except KeyError as exc: raise KeyError("Model not supported") from exc return OllamaEmbeddings(**embedder_config) - elif "hugging_face" in embedder_config["model"].split("/")[0]: + elif "hugging_face" in embedder_config["model"]: try: models_tokens["hugging_face"][embedder_config["model"]] except KeyError as exc: raise KeyError("Model not supported")from exc return HuggingFaceHubEmbeddings(model=embedder_config["model"]) - elif "gemini" in embedder_config["model"].split("/")[0]: + elif "gemini" in embedder_config["model"]: try: models_tokens["gemini"][embedder_config["model"]] except KeyError as exc: raise KeyError("Model not supported")from exc return GoogleGenerativeAIEmbeddings(model=embedder_config["model"]) - elif "bedrock" in embedder_config["model"].split("/")[0]: + elif "bedrock" in embedder_config["model"]: embedder_config["model"] = embedder_config["model"].split("/")[-1] client = embedder_config.get('client', None) try: diff --git a/scrapegraphai/graphs/pdf_scraper_graph.py b/scrapegraphai/graphs/pdf_scraper_graph.py index af9fe7d4..39278ab7 100644 --- a/scrapegraphai/graphs/pdf_scraper_graph.py +++ b/scrapegraphai/graphs/pdf_scraper_graph.py @@ -11,7 +11,7 @@ from ..nodes import ( FetchNode, ParseNode, RAGNode, - GenerateAnswerNode + GenerateAnswerPDFNode ) @@ -48,7 +48,7 @@ class PDFScraperGraph(AbstractGraph): """ def __init__(self, prompt: str, source: str, config: dict, schema: Optional[str] = None): - super().__init__(prompt, config, source, schema) + super().__init__(prompt, config, source) self.input_key = "pdf" if source.endswith("pdf") else "pdf_dir" @@ -64,41 +64,21 @@ class PDFScraperGraph(AbstractGraph): input='pdf | pdf_dir', output=["doc", "link_urls", "img_urls"], ) - parse_node = ParseNode( - input="doc", - output=["parsed_doc"], - node_config={ - "chunk_size": self.model_token, - } - ) - rag_node = RAGNode( - input="user_prompt & (parsed_doc | doc)", - output=["relevant_chunks"], - node_config={ - "llm_model": self.llm_model, - "embedder_model": self.embedder_model, - } - ) - generate_answer_node = GenerateAnswerNode( + generate_answer_node_pdf = GenerateAnswerPDFNode( input="user_prompt & (relevant_chunks | parsed_doc | doc)", output=["answer"], node_config={ "llm_model": self.llm_model, - "schema": self.schema, } ) return BaseGraph( nodes=[ fetch_node, - parse_node, - rag_node, - generate_answer_node, + generate_answer_node_pdf, ], edges=[ - (fetch_node, parse_node), - (parse_node, rag_node), - (rag_node, generate_answer_node) + (fetch_node, generate_answer_node_pdf) ], entry_point=fetch_node ) diff --git a/scrapegraphai/nodes/generate_answer_pdf_node.py b/scrapegraphai/nodes/generate_answer_pdf_node.py index fcad5b5a..b64ca763 100644 --- a/scrapegraphai/nodes/generate_answer_pdf_node.py +++ b/scrapegraphai/nodes/generate_answer_pdf_node.py @@ -49,7 +49,7 @@ class GenerateAnswerPDFNode(BaseNode): node_name (str): name of the node """ super().__init__(node_name, "node", input, output, 2, node_config) - self.llm_model = node_config["llm"] + self.llm_model = node_config["llm_model"] self.verbose = False if node_config is None else node_config.get( "verbose", False)