mirror of
https://github.com/VinciGit00/Scrapegraph-ai.git
synced 2026-06-23 21:00:30 +08:00
fixed token models, added mistral support
This commit is contained in:
parent
17add20c13
commit
dee1a42629
@ -7,6 +7,7 @@ from dotenv import load_dotenv
|
||||
from scrapegraphai.models import Gemini
|
||||
from scrapegraphai.graphs import BaseGraph
|
||||
from scrapegraphai.nodes import FetchNode, ParseNode, RAGNode, GenerateAnswerNode
|
||||
from scrapegraphai.helpers import models_tokens
|
||||
load_dotenv()
|
||||
|
||||
# ************************************************
|
||||
@ -38,6 +39,7 @@ fetch_node = FetchNode(
|
||||
parse_node = ParseNode(
|
||||
input="doc",
|
||||
output=["parsed_doc"],
|
||||
node_config={"chunk_size": 4096}
|
||||
)
|
||||
rag_node = RAGNode(
|
||||
input="user_prompt & (parsed_doc | doc)",
|
||||
|
||||
@ -7,6 +7,7 @@ from dotenv import load_dotenv
|
||||
from scrapegraphai.models import OpenAI
|
||||
from scrapegraphai.graphs import BaseGraph
|
||||
from scrapegraphai.nodes import FetchNode, ParseNode, RAGNode, GenerateAnswerNode
|
||||
from scrapegraphai.helpers import models_tokens
|
||||
load_dotenv()
|
||||
|
||||
# ************************************************
|
||||
@ -38,6 +39,7 @@ fetch_node = FetchNode(
|
||||
parse_node = ParseNode(
|
||||
input="doc",
|
||||
output=["parsed_doc"],
|
||||
node_config={"chunk_size": 4096}
|
||||
)
|
||||
rag_node = RAGNode(
|
||||
input="user_prompt & (parsed_doc | doc)",
|
||||
|
||||
@ -16,7 +16,7 @@ openai_key = os.getenv("OPENAI_APIKEY")
|
||||
|
||||
graph_config = {
|
||||
"llm": {
|
||||
"model": "ollama/llama2",
|
||||
"model": "ollama/mistral",
|
||||
"temperature": 0,
|
||||
"format": "json", # Ollama needs the format to be specified explicitly
|
||||
},
|
||||
|
||||
@ -4,6 +4,7 @@ Module having abstract class for creating all the graphs
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
from ..models import OpenAI, Gemini, Ollama, AzureOpenAI
|
||||
from ..helpers import models_tokens
|
||||
|
||||
class AbstractGraph(ABC):
|
||||
"""
|
||||
@ -16,8 +17,6 @@ class AbstractGraph(ABC):
|
||||
"""
|
||||
self.prompt = prompt
|
||||
self.source = source
|
||||
self.input_key = "url" if source.startswith(
|
||||
"http") else "local_dir"
|
||||
self.config = config
|
||||
self.llm_model = self._create_llm(config["llm"])
|
||||
self.embedder_model = None if "embeddings" not in config else self._create_llm(config["embeddings"])
|
||||
@ -33,16 +32,39 @@ class AbstractGraph(ABC):
|
||||
}
|
||||
llm_params = {**llm_defaults, **llm_config}
|
||||
|
||||
# Instantiate the language model based on the model name
|
||||
if "gpt-" in llm_params["model"]:
|
||||
try:
|
||||
self.model_token = models_tokens["openai"][llm_params["model"]]
|
||||
except KeyError:
|
||||
raise ValueError("Model not supported")
|
||||
return OpenAI(llm_params)
|
||||
|
||||
elif "azure" in llm_params["model"]:
|
||||
# take the model name after the last slash (e.g. "azure/gpt-4" -> "gpt-4")
|
||||
llm_params["model"] = llm_params["model"].split("/")[-1]
|
||||
try:
|
||||
self.model_token = models_tokens["openai"][llm_params["model"]]
|
||||
except KeyError:
|
||||
raise ValueError("Model not supported")
|
||||
return AzureOpenAI(llm_params)
|
||||
|
||||
elif "gemini" in llm_params["model"]:
|
||||
try:
|
||||
self.model_token = models_tokens["gemini"][llm_params["model"]]
|
||||
except KeyError:
|
||||
raise ValueError("Model not supported")
|
||||
return Gemini(llm_params)
|
||||
elif "llama2" in llm_params["model"]:
|
||||
# set model to llama2 if it has a different structure
|
||||
llm_params["model"] = "llama2"
|
||||
|
||||
elif "ollama" in llm_params["model"]:
|
||||
# take the model name after the last slash (e.g. "ollama/mistral" -> "mistral")
|
||||
llm_params["model"] = llm_params["model"].split("/")[-1]
|
||||
try:
|
||||
self.model_token = models_tokens["ollama"][llm_params["model"]]
|
||||
except KeyError:
|
||||
raise ValueError("Model not supported")
|
||||
return Ollama(llm_params)
|
||||
|
||||
else:
|
||||
raise ValueError("Model not supported")
|
||||
|
||||
|
||||
@ -32,6 +32,7 @@ class SearchGraph(AbstractGraph):
|
||||
parse_node = ParseNode(
|
||||
input="doc",
|
||||
output=["parsed_doc"],
|
||||
node_config={"chunk_size": self.model_token}
|
||||
)
|
||||
rag_node = RAGNode(
|
||||
input="user_prompt & (parsed_doc | doc)",
|
||||
|
||||
@ -18,6 +18,14 @@ class SmartScraperGraph(AbstractGraph):
|
||||
information from web pages using a natural language model to interpret and answer prompts.
|
||||
"""
|
||||
|
||||
def __init__(self, prompt: str, source: str, config: dict):
|
||||
"""
|
||||
Initializes the SmartScraperGraph with a prompt, source, and configuration.
|
||||
"""
|
||||
super().__init__(prompt, config, source)
|
||||
|
||||
self.input_key = "url" if source.startswith("http") else "local_dir"
|
||||
|
||||
def _create_graph(self):
|
||||
"""
|
||||
Creates the graph of nodes representing the workflow for web scraping.
|
||||
@ -29,6 +37,7 @@ class SmartScraperGraph(AbstractGraph):
|
||||
parse_node = ParseNode(
|
||||
input="doc",
|
||||
output=["parsed_doc"],
|
||||
node_config={"chunk_size": self.model_token}
|
||||
)
|
||||
rag_node = RAGNode(
|
||||
input="user_prompt & (parsed_doc | doc)",
|
||||
|
||||
@ -20,6 +20,14 @@ class SpeechGraph(AbstractGraph):
|
||||
information from web pages, then converting that summary into spoken word via an MP3 file.
|
||||
"""
|
||||
|
||||
def __init__(self, prompt: str, source: str, config: dict):
|
||||
"""
|
||||
Initializes the SpeechGraph with a prompt, source, and configuration.
|
||||
"""
|
||||
super().__init__(prompt, config, source)
|
||||
|
||||
self.input_key = "url" if source.startswith("http") else "local_dir"
|
||||
|
||||
def _create_graph(self):
|
||||
"""
|
||||
Creates the graph of nodes representing the workflow for web scraping and summarization.
|
||||
@ -31,6 +39,7 @@ class SpeechGraph(AbstractGraph):
|
||||
parse_node = ParseNode(
|
||||
input="doc",
|
||||
output=["parsed_doc"],
|
||||
node_config={"chunk_size": self.model_token}
|
||||
)
|
||||
rag_node = RAGNode(
|
||||
input="user_prompt & (parsed_doc | doc)",
|
||||
|
||||
@ -2,16 +2,28 @@
|
||||
Token limits for the supported models, grouped by provider
|
||||
"""
|
||||
models_tokens = {
|
||||
"gpt-3.5-turbo-0125": 16385,
|
||||
"gpt-3.5-turbo": 4096,
|
||||
"gpt-3.5-turbo-1106": 16385,
|
||||
"gpt-3.5-turbo-instruct": 4096,
|
||||
"gpt-4-0125-preview": 128000,
|
||||
"gpt-4-turbo-preview": 128000,
|
||||
"gpt-4-1106-preview": 128000,
|
||||
"gpt-4-vision-preview": 128000,
|
||||
"gpt-4": 8192,
|
||||
"gpt-4-0613": 8192,
|
||||
"gpt-4-32k": 32768,
|
||||
"gpt-4-32k-0613": 32768,
|
||||
"openai": {
|
||||
"gpt-3.5-turbo-0125": 16385,
|
||||
"gpt-3.5-turbo": 4096,
|
||||
"gpt-3.5-turbo-1106": 16385,
|
||||
"gpt-3.5-turbo-instruct": 4096,
|
||||
"gpt-4-0125-preview": 128000,
|
||||
"gpt-4-turbo-preview": 128000,
|
||||
"gpt-4-1106-preview": 128000,
|
||||
"gpt-4-vision-preview": 128000,
|
||||
"gpt-4": 8192,
|
||||
"gpt-4-0613": 8192,
|
||||
"gpt-4-32k": 32768,
|
||||
"gpt-4-32k-0613": 32768,
|
||||
},
|
||||
|
||||
"gemini": {
|
||||
"gemini-pro": 128000,
|
||||
},
|
||||
|
||||
"ollama":{
|
||||
"llama2": 4096,
|
||||
"mistral": 8192,
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@ -29,7 +29,7 @@ class ParseNode(BaseNode):
|
||||
the specified tags, if provided, and updates the state with the parsed content.
|
||||
"""
|
||||
|
||||
def __init__(self, input: str, output: List[str], node_name: str = "Parse"):
|
||||
def __init__(self, input: str, output: List[str], node_config: dict, node_name: str = "Parse"):
|
||||
"""
|
||||
Initializes the ParseNode with a node name.
|
||||
Args:
|
||||
@ -38,7 +38,7 @@ class ParseNode(BaseNode):
|
||||
node_name (str): name of the node
|
||||
node_type (str, optional): type of the node
|
||||
"""
|
||||
super().__init__(node_name, "node", input, output, 1)
|
||||
super().__init__(node_name, "node", input, output, 1, node_config)
|
||||
|
||||
def execute(self, state):
|
||||
"""
|
||||
@ -69,7 +69,7 @@ class ParseNode(BaseNode):
|
||||
input_data = [state[key] for key in input_keys]
|
||||
|
||||
text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
|
||||
chunk_size=4000,
|
||||
chunk_size=self.node_config.get("chunk_size", 4096),
|
||||
chunk_overlap=0,
|
||||
)
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user