From dee1a4262936464e6683dfc2f09bc8e640bd7bf2 Mon Sep 17 00:00:00 2001 From: "EURAC\\marperini" Date: Mon, 8 Apr 2024 15:21:06 +0200 Subject: [PATCH] fixed token models, added mistral support --- examples/custom_graph_gemini.py | 2 ++ examples/custom_graph_openai.py | 2 ++ examples/search_graph_example.py | 2 +- scrapegraphai/graphs/abstract_graph.py | 32 +++++++++++++++--- scrapegraphai/graphs/search_graph.py | 1 + scrapegraphai/graphs/smart_scraper_graph.py | 9 ++++++ scrapegraphai/graphs/speech_graph.py | 9 ++++++ scrapegraphai/helpers/models_tokens.py | 36 ++++++++++++++------- scrapegraphai/nodes/parse_node.py | 6 ++-- 9 files changed, 78 insertions(+), 21 deletions(-) diff --git a/examples/custom_graph_gemini.py b/examples/custom_graph_gemini.py index 9ecae73b..7d663773 100644 --- a/examples/custom_graph_gemini.py +++ b/examples/custom_graph_gemini.py @@ -7,6 +7,7 @@ from dotenv import load_dotenv from scrapegraphai.models import Gemini from scrapegraphai.graphs import BaseGraph from scrapegraphai.nodes import FetchNode, ParseNode, RAGNode, GenerateAnswerNode +from scrapegraphai.helpers import models_tokens load_dotenv() # ************************************************ @@ -38,6 +39,7 @@ fetch_node = FetchNode( parse_node = ParseNode( input="doc", output=["parsed_doc"], + node_config={"chunk_size": 4096} ) rag_node = RAGNode( input="user_prompt & (parsed_doc | doc)", diff --git a/examples/custom_graph_openai.py b/examples/custom_graph_openai.py index 943838fc..d92cf3dd 100644 --- a/examples/custom_graph_openai.py +++ b/examples/custom_graph_openai.py @@ -7,6 +7,7 @@ from dotenv import load_dotenv from scrapegraphai.models import OpenAI from scrapegraphai.graphs import BaseGraph from scrapegraphai.nodes import FetchNode, ParseNode, RAGNode, GenerateAnswerNode +from scrapegraphai.helpers import models_tokens load_dotenv() # ************************************************ @@ -38,6 +39,7 @@ fetch_node = FetchNode( parse_node = ParseNode( input="doc", output=["parsed_doc"], + node_config={"chunk_size": 4096} ) rag_node = RAGNode( input="user_prompt & (parsed_doc | doc)", diff --git a/examples/search_graph_example.py b/examples/search_graph_example.py index 67fdee13..8624024f 100644 --- a/examples/search_graph_example.py +++ b/examples/search_graph_example.py @@ -16,7 +16,7 @@ openai_key = os.getenv("OPENAI_APIKEY") graph_config = { "llm": { - "model": "ollama/llama2", + "model": "ollama/mistral", "temperature": 0, "format": "json", # Ollama needs the format to be specified explicitly }, diff --git a/scrapegraphai/graphs/abstract_graph.py b/scrapegraphai/graphs/abstract_graph.py index db8bddb7..79eea199 100644 --- a/scrapegraphai/graphs/abstract_graph.py +++ b/scrapegraphai/graphs/abstract_graph.py @@ -4,6 +4,7 @@ Module having abstract class for creating all the graphs from abc import ABC, abstractmethod from typing import Optional from ..models import OpenAI, Gemini, Ollama, AzureOpenAI +from ..helpers import models_tokens class AbstractGraph(ABC): """ @@ -16,8 +17,6 @@ class AbstractGraph(ABC): """ self.prompt = prompt self.source = source - self.input_key = "url" if source.startswith( - "http") else "local_dir" self.config = config self.llm_model = self._create_llm(config["llm"]) self.embedder_model = None if "embeddings" not in config else self._create_llm(config["embeddings"]) @@ -33,16 +32,39 @@ class AbstractGraph(ABC): } llm_params = {**llm_defaults, **llm_config} + # Instantiate the language model based on the model name if "gpt-" in llm_params["model"]: + try: + self.model_token = models_tokens["openai"][llm_params["model"]] + except KeyError: + raise ValueError("Model not supported") return OpenAI(llm_params) + elif "azure" in llm_params["model"]: + # take the model after the last dash + llm_params["model"] = llm_params["model"].split("/")[-1] + try: + self.model_token = models_tokens["openai"][llm_params["model"]] + except KeyError: + raise ValueError("Model not supported") return AzureOpenAI(llm_params) + elif "gemini" in llm_params["model"]: + try: + self.model_token = models_tokens["gemini"][llm_params["model"]] + except KeyError: + raise ValueError("Model not supported") return Gemini(llm_params) - elif "llama2" in llm_params["model"]: - # set model to llama2 if it has a different structure - llm_params["model"] = "llama2" + + elif "ollama" in llm_params["model"]: + # take the model after the last dash + llm_params["model"] = llm_params["model"].split("/")[-1] + try: + self.model_token = models_tokens["ollama"][llm_params["model"]] + except KeyError: + raise ValueError("Model not supported") return Ollama(llm_params) + else: raise ValueError("Model not supported") diff --git a/scrapegraphai/graphs/search_graph.py b/scrapegraphai/graphs/search_graph.py index 57a87868..2f8ef6a8 100644 --- a/scrapegraphai/graphs/search_graph.py +++ b/scrapegraphai/graphs/search_graph.py @@ -32,6 +32,7 @@ class SearchGraph(AbstractGraph): parse_node = ParseNode( input="doc", output=["parsed_doc"], + node_config={"chunk_size": self.model_token} ) rag_node = RAGNode( input="user_prompt & (parsed_doc | doc)", diff --git a/scrapegraphai/graphs/smart_scraper_graph.py b/scrapegraphai/graphs/smart_scraper_graph.py index 56043183..06f7d010 100644 --- a/scrapegraphai/graphs/smart_scraper_graph.py +++ b/scrapegraphai/graphs/smart_scraper_graph.py @@ -18,6 +18,14 @@ class SmartScraperGraph(AbstractGraph): information from web pages using a natural language model to interpret and answer prompts. """ + def __init__(self, prompt: str, source: str, config: dict): + """ + Initializes the SmartScraperGraph with a prompt, source, and configuration. + """ + super().__init__(prompt, config, source) + + self.input_key = "url" if source.startswith("http") else "local_dir" + def _create_graph(self): """ Creates the graph of nodes representing the workflow for web scraping. @@ -29,6 +37,7 @@ class SmartScraperGraph(AbstractGraph): parse_node = ParseNode( input="doc", output=["parsed_doc"], + node_config={"chunk_size": self.model_token} ) rag_node = RAGNode( input="user_prompt & (parsed_doc | doc)", diff --git a/scrapegraphai/graphs/speech_graph.py b/scrapegraphai/graphs/speech_graph.py index d19acea9..28d0cdfa 100644 --- a/scrapegraphai/graphs/speech_graph.py +++ b/scrapegraphai/graphs/speech_graph.py @@ -20,6 +20,14 @@ class SpeechGraph(AbstractGraph): information from web pages, then converting that summary into spoken word via an MP3 file. """ + def __init__(self, prompt: str, source: str, config: dict): + """ + Initializes the SmartScraperGraph with a prompt, source, and configuration. + """ + super().__init__(prompt, config, source) + + self.input_key = "url" if source.startswith("http") else "local_dir" + def _create_graph(self): """ Creates the graph of nodes representing the workflow for web scraping and summarization. @@ -31,6 +39,7 @@ class SpeechGraph(AbstractGraph): parse_node = ParseNode( input="doc", output=["parsed_doc"], + node_config={"chunk_size": self.model_token} ) rag_node = RAGNode( input="user_prompt & (parsed_doc | doc)", diff --git a/scrapegraphai/helpers/models_tokens.py b/scrapegraphai/helpers/models_tokens.py index abdfc4b0..dc286526 100644 --- a/scrapegraphai/helpers/models_tokens.py +++ b/scrapegraphai/helpers/models_tokens.py @@ -2,16 +2,28 @@ Models token """ models_tokens = { - "gpt-3.5-turbo-0125": 16385, - "gpt-3.5-turbo": 4096, - "gpt-3.5-turbo-1106": 16385, - "gpt-3.5-turbo-instruct": 4096, - "gpt-4-0125-preview": 128000, - "gpt-4-turbo-preview": 128000, - "gpt-4-1106-preview": 128000, - "gpt-4-vision-preview": 128000, - "gpt-4": 8192, - "gpt-4-0613": 8192, - "gpt-4-32k": 32768, - "gpt-4-32k-0613": 32768, + "openai": { + "gpt-3.5-turbo-0125": 16385, + "gpt-3.5-turbo": 4096, + "gpt-3.5-turbo-1106": 16385, + "gpt-3.5-turbo-instruct": 4096, + "gpt-4-0125-preview": 128000, + "gpt-4-turbo-preview": 128000, + "gpt-4-1106-preview": 128000, + "gpt-4-vision-preview": 128000, + "gpt-4": 8192, + "gpt-4-0613": 8192, + "gpt-4-32k": 32768, + "gpt-4-32k-0613": 32768, + }, + + "gemini": { + "gemini-pro": 128000, + }, + + "ollama":{ + "llama2": 4096, + "mistral": 8192, + } + } diff --git a/scrapegraphai/nodes/parse_node.py b/scrapegraphai/nodes/parse_node.py index 866487d8..357b39fb 100644 --- a/scrapegraphai/nodes/parse_node.py +++ b/scrapegraphai/nodes/parse_node.py @@ -29,7 +29,7 @@ class ParseNode(BaseNode): the specified tags, if provided, and updates the state with the parsed content. """ - def __init__(self, input: str, output: List[str], node_name: str = "Parse"): + def __init__(self, input: str, output: List[str], node_config: dict, node_name: str = "Parse"): """ Initializes the ParseHTMLNode with a node name. Args: @@ -38,7 +38,7 @@ class ParseNode(BaseNode): node_name (str): name of the node node_type (str, optional): type of the node """ - super().__init__(node_name, "node", input, output, 1) + super().__init__(node_name, "node", input, output, 1, node_config) def execute(self, state): """ @@ -69,7 +69,7 @@ class ParseNode(BaseNode): input_data = [state[key] for key in input_keys] text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder( - chunk_size=4000, + chunk_size=self.node_config.get("chunk_size", 4096), chunk_overlap=0, )