fixed token models, added mistral support

This commit is contained in:
EURAC\marperini 2024-04-08 15:21:06 +02:00
parent 17add20c13
commit dee1a42629
9 changed files with 78 additions and 21 deletions

View File

@ -7,6 +7,7 @@ from dotenv import load_dotenv
from scrapegraphai.models import Gemini
from scrapegraphai.graphs import BaseGraph
from scrapegraphai.nodes import FetchNode, ParseNode, RAGNode, GenerateAnswerNode
from scrapegraphai.helpers import models_tokens
load_dotenv()
# ************************************************
@ -38,6 +39,7 @@ fetch_node = FetchNode(
parse_node = ParseNode(
input="doc",
output=["parsed_doc"],
node_config={"chunk_size": 4096}
)
rag_node = RAGNode(
input="user_prompt & (parsed_doc | doc)",

View File

@ -7,6 +7,7 @@ from dotenv import load_dotenv
from scrapegraphai.models import OpenAI
from scrapegraphai.graphs import BaseGraph
from scrapegraphai.nodes import FetchNode, ParseNode, RAGNode, GenerateAnswerNode
from scrapegraphai.helpers import models_tokens
load_dotenv()
# ************************************************
@ -38,6 +39,7 @@ fetch_node = FetchNode(
parse_node = ParseNode(
input="doc",
output=["parsed_doc"],
node_config={"chunk_size": 4096}
)
rag_node = RAGNode(
input="user_prompt & (parsed_doc | doc)",

View File

@ -16,7 +16,7 @@ openai_key = os.getenv("OPENAI_APIKEY")
graph_config = {
"llm": {
"model": "ollama/llama2",
"model": "ollama/mistral",
"temperature": 0,
"format": "json", # Ollama needs the format to be specified explicitly
},

View File

@ -4,6 +4,7 @@ Module having abstract class for creating all the graphs
from abc import ABC, abstractmethod
from typing import Optional
from ..models import OpenAI, Gemini, Ollama, AzureOpenAI
from ..helpers import models_tokens
class AbstractGraph(ABC):
"""
@ -16,8 +17,6 @@ class AbstractGraph(ABC):
"""
self.prompt = prompt
self.source = source
self.input_key = "url" if source.startswith(
"http") else "local_dir"
self.config = config
self.llm_model = self._create_llm(config["llm"])
self.embedder_model = None if "embeddings" not in config else self._create_llm(config["embeddings"])
@ -33,16 +32,39 @@ class AbstractGraph(ABC):
}
llm_params = {**llm_defaults, **llm_config}
# Instantiate the language model based on the model name
if "gpt-" in llm_params["model"]:
try:
self.model_token = models_tokens["openai"][llm_params["model"]]
except KeyError:
raise ValueError("Model not supported")
return OpenAI(llm_params)
elif "azure" in llm_params["model"]:
# take the model name after the last slash
llm_params["model"] = llm_params["model"].split("/")[-1]
try:
self.model_token = models_tokens["openai"][llm_params["model"]]
except KeyError:
raise ValueError("Model not supported")
return AzureOpenAI(llm_params)
elif "gemini" in llm_params["model"]:
try:
self.model_token = models_tokens["gemini"][llm_params["model"]]
except KeyError:
raise ValueError("Model not supported")
return Gemini(llm_params)
elif "llama2" in llm_params["model"]:
# set model to llama2 if it has a different structure
llm_params["model"] = "llama2"
elif "ollama" in llm_params["model"]:
# take the model name after the last slash (e.g. "ollama/mistral" -> "mistral")
llm_params["model"] = llm_params["model"].split("/")[-1]
try:
self.model_token = models_tokens["ollama"][llm_params["model"]]
except KeyError:
raise ValueError("Model not supported")
return Ollama(llm_params)
else:
raise ValueError("Model not supported")

View File

@ -32,6 +32,7 @@ class SearchGraph(AbstractGraph):
parse_node = ParseNode(
input="doc",
output=["parsed_doc"],
node_config={"chunk_size": self.model_token}
)
rag_node = RAGNode(
input="user_prompt & (parsed_doc | doc)",

View File

@ -18,6 +18,14 @@ class SmartScraperGraph(AbstractGraph):
information from web pages using a natural language model to interpret and answer prompts.
"""
def __init__(self, prompt: str, source: str, config: dict):
"""
Initializes the SmartScraperGraph with a prompt, source, and configuration.
"""
super().__init__(prompt, config, source)
self.input_key = "url" if source.startswith("http") else "local_dir"
def _create_graph(self):
"""
Creates the graph of nodes representing the workflow for web scraping.
@ -29,6 +37,7 @@ class SmartScraperGraph(AbstractGraph):
parse_node = ParseNode(
input="doc",
output=["parsed_doc"],
node_config={"chunk_size": self.model_token}
)
rag_node = RAGNode(
input="user_prompt & (parsed_doc | doc)",

View File

@ -20,6 +20,14 @@ class SpeechGraph(AbstractGraph):
information from web pages, then converting that summary into spoken word via an MP3 file.
"""
def __init__(self, prompt: str, source: str, config: dict):
"""
Initializes the SpeechGraph with a prompt, source, and configuration.
"""
super().__init__(prompt, config, source)
self.input_key = "url" if source.startswith("http") else "local_dir"
def _create_graph(self):
"""
Creates the graph of nodes representing the workflow for web scraping and summarization.
@ -31,6 +39,7 @@ class SpeechGraph(AbstractGraph):
parse_node = ParseNode(
input="doc",
output=["parsed_doc"],
node_config={"chunk_size": self.model_token}
)
rag_node = RAGNode(
input="user_prompt & (parsed_doc | doc)",

View File

@ -2,16 +2,28 @@
Models token
"""
models_tokens = {
"gpt-3.5-turbo-0125": 16385,
"gpt-3.5-turbo": 4096,
"gpt-3.5-turbo-1106": 16385,
"gpt-3.5-turbo-instruct": 4096,
"gpt-4-0125-preview": 128000,
"gpt-4-turbo-preview": 128000,
"gpt-4-1106-preview": 128000,
"gpt-4-vision-preview": 128000,
"gpt-4": 8192,
"gpt-4-0613": 8192,
"gpt-4-32k": 32768,
"gpt-4-32k-0613": 32768,
"openai": {
"gpt-3.5-turbo-0125": 16385,
"gpt-3.5-turbo": 4096,
"gpt-3.5-turbo-1106": 16385,
"gpt-3.5-turbo-instruct": 4096,
"gpt-4-0125-preview": 128000,
"gpt-4-turbo-preview": 128000,
"gpt-4-1106-preview": 128000,
"gpt-4-vision-preview": 128000,
"gpt-4": 8192,
"gpt-4-0613": 8192,
"gpt-4-32k": 32768,
"gpt-4-32k-0613": 32768,
},
"gemini": {
"gemini-pro": 128000,
},
"ollama":{
"llama2": 4096,
"mistral": 8192,
}
}

View File

@ -29,7 +29,7 @@ class ParseNode(BaseNode):
the specified tags, if provided, and updates the state with the parsed content.
"""
def __init__(self, input: str, output: List[str], node_name: str = "Parse"):
def __init__(self, input: str, output: List[str], node_config: dict, node_name: str = "Parse"):
"""
Initializes the ParseNode with a node name.
Args:
@ -38,7 +38,7 @@ class ParseNode(BaseNode):
node_name (str): name of the node
node_type (str, optional): type of the node
"""
super().__init__(node_name, "node", input, output, 1)
super().__init__(node_name, "node", input, output, 1, node_config)
def execute(self, state):
"""
@ -69,7 +69,7 @@ class ParseNode(BaseNode):
input_data = [state[key] for key in input_keys]
text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
chunk_size=4000,
chunk_size=self.node_config.get("chunk_size", 4096),
chunk_overlap=0,
)