mirror of
https://github.com/VinciGit00/Scrapegraph-ai.git
synced 2026-06-25 21:11:11 +08:00
Merge pull request #48 from VinciGit00/llama_new_models
Llama new models
This commit is contained in:
commit
43401c80fa
@ -6,6 +6,7 @@ from typing import Optional
|
||||
from ..models import OpenAI, Gemini, Ollama, AzureOpenAI
|
||||
from ..helpers import models_tokens
|
||||
|
||||
|
||||
class AbstractGraph(ABC):
|
||||
"""
|
||||
Abstract class representing a generic graph-based tool.
|
||||
@ -19,7 +20,8 @@ class AbstractGraph(ABC):
|
||||
self.source = source
|
||||
self.config = config
|
||||
self.llm_model = self._create_llm(config["llm"])
|
||||
self.embedder_model = None if "embeddings" not in config else self._create_llm(config["embeddings"])
|
||||
self.embedder_model = None if "embeddings" not in config else self._create_llm(
|
||||
config["embeddings"])
|
||||
self.graph = self._create_graph()
|
||||
|
||||
def _create_llm(self, llm_config: dict):
|
||||
@ -39,7 +41,7 @@ class AbstractGraph(ABC):
|
||||
except KeyError:
|
||||
raise ValueError("Model not supported")
|
||||
return OpenAI(llm_params)
|
||||
|
||||
|
||||
elif "azure" in llm_params["model"]:
|
||||
# take the model name after the last slash
|
||||
llm_params["model"] = llm_params["model"].split("/")[-1]
|
||||
@ -48,23 +50,30 @@ class AbstractGraph(ABC):
|
||||
except KeyError:
|
||||
raise ValueError("Model not supported")
|
||||
return AzureOpenAI(llm_params)
|
||||
|
||||
|
||||
elif "gemini" in llm_params["model"]:
|
||||
try:
|
||||
self.model_token = models_tokens["gemini"][llm_params["model"]]
|
||||
except KeyError:
|
||||
raise ValueError("Model not supported")
|
||||
return Gemini(llm_params)
|
||||
|
||||
|
||||
elif "ollama" in llm_params["model"]:
|
||||
# take the model name after the last slash
|
||||
"""
|
||||
Available models:
|
||||
- llama2
|
||||
- mistral
|
||||
- codellama
|
||||
- dolphin-mixtral
|
||||
- mistral-openorca
|
||||
"""
|
||||
llm_params["model"] = llm_params["model"].split("/")[-1]
|
||||
try:
|
||||
self.model_token = models_tokens["ollama"][llm_params["model"]]
|
||||
except KeyError:
|
||||
raise ValueError("Model not supported")
|
||||
return Ollama(llm_params)
|
||||
|
||||
|
||||
else:
|
||||
raise ValueError("Model not supported")
|
||||
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
"""
|
||||
Module for creating the smart scraper
|
||||
"""
|
||||
from ..models import OpenAI, Gemini
|
||||
from .base_graph import BaseGraph
|
||||
from ..nodes import (
|
||||
FetchNode,
|
||||
|
||||
@ -2,7 +2,7 @@
|
||||
Module for converting text to speech
|
||||
"""
|
||||
from scrapegraphai.utils.save_audio_from_bytes import save_audio_from_bytes
|
||||
from ..models import OpenAI, Gemini, OpenAITextToSpeech
|
||||
from ..models import OpenAITextToSpeech
|
||||
from .base_graph import BaseGraph
|
||||
from ..nodes import (
|
||||
FetchNode,
|
||||
@ -27,7 +27,7 @@ class SpeechGraph(AbstractGraph):
|
||||
super().__init__(prompt, config, source)
|
||||
|
||||
self.input_key = "url" if source.startswith("http") else "local_dir"
|
||||
|
||||
|
||||
def _create_graph(self):
|
||||
"""
|
||||
Creates the graph of nodes representing the workflow for web scraping and summarization.
|
||||
|
||||
@ -21,9 +21,11 @@ models_tokens = {
|
||||
"gemini-pro": 128000,
|
||||
},
|
||||
|
||||
"ollama":{
|
||||
"ollama": {
|
||||
"llama2": 4096,
|
||||
"mistral": 8192,
|
||||
"codellama": 16000,
|
||||
"dolphin-mixtral": 32000,
|
||||
"mistral-openorca": 32000,
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@ -67,8 +67,8 @@ class FetchNode(BaseNode):
|
||||
# Fetching data from the state based on the input keys
|
||||
input_data = [state[key] for key in input_keys]
|
||||
|
||||
source = input_data[0]
|
||||
|
||||
source = input_data[0]
|
||||
|
||||
# if it is a local directory
|
||||
if not source.startswith("http"):
|
||||
document = [Document(page_content=source, metadata={
|
||||
|
||||
@ -8,8 +8,8 @@ from langchain.retrievers import ContextualCompressionRetriever
|
||||
from langchain.retrievers.document_compressors import EmbeddingsFilter, DocumentCompressorPipeline
|
||||
from langchain_community.document_transformers import EmbeddingsRedundantFilter
|
||||
from langchain_community.vectorstores import FAISS
|
||||
from ..models import OpenAI, Gemini, Ollama, AzureOpenAI
|
||||
from langchain_openai import OpenAIEmbeddings, AzureOpenAIEmbeddings
|
||||
from ..models import OpenAI, Ollama, AzureOpenAI
|
||||
from langchain_community.embeddings import OllamaEmbeddings
|
||||
from .base_node import BaseNode
|
||||
|
||||
@ -86,16 +86,18 @@ class RAGNode(BaseNode):
|
||||
embedding_model = self.embedder_model if self.embedder_model else self.llm_model
|
||||
|
||||
if isinstance(embedding_model, OpenAI):
|
||||
embeddings = OpenAIEmbeddings(api_key=embedding_model.openai_api_key)
|
||||
embeddings = OpenAIEmbeddings(
|
||||
api_key=embedding_model.openai_api_key)
|
||||
elif isinstance(embedding_model, AzureOpenAI):
|
||||
embeddings = AzureOpenAIEmbeddings()
|
||||
elif isinstance(embedding_model, Ollama):
|
||||
embeddings = OllamaEmbeddings()
|
||||
else:
|
||||
raise ValueError("Embedding Model missing or not supported")
|
||||
|
||||
retriever = FAISS.from_documents(chunked_docs, embeddings).as_retriever()
|
||||
|
||||
|
||||
retriever = FAISS.from_documents(
|
||||
chunked_docs, embeddings).as_retriever()
|
||||
|
||||
redundant_filter = EmbeddingsRedundantFilter(embeddings=embeddings)
|
||||
# similarity_threshold could be set, now k=20
|
||||
relevant_filter = EmbeddingsFilter(embeddings=embeddings)
|
||||
|
||||
@ -94,7 +94,7 @@ class SearchInternetNode(BaseNode):
|
||||
# Execute the chain to get the search query
|
||||
search_answer = search_prompt | self.llm_model | output_parser
|
||||
search_query = search_answer.invoke({"user_prompt": user_prompt})[0]
|
||||
|
||||
|
||||
print(f"Search Query: {search_query}")
|
||||
# TODO: handle multiple URLs
|
||||
answer = search_on_web(query=search_query, max_results=1)[0]
|
||||
|
||||
Loading…
Reference in New Issue
Block a user