Merge pull request #48 from VinciGit00/llama_new_models

Llama new models
This commit is contained in:
Marco Perini 2024-04-09 11:10:13 +02:00 committed by GitHub
commit 43401c80fa
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 31 additions and 19 deletions

View File

@ -6,6 +6,7 @@ from typing import Optional
from ..models import OpenAI, Gemini, Ollama, AzureOpenAI
from ..helpers import models_tokens
class AbstractGraph(ABC):
"""
Abstract class representing a generic graph-based tool.
@ -19,7 +20,8 @@ class AbstractGraph(ABC):
self.source = source
self.config = config
self.llm_model = self._create_llm(config["llm"])
self.embedder_model = None if "embeddings" not in config else self._create_llm(config["embeddings"])
self.embedder_model = None if "embeddings" not in config else self._create_llm(
config["embeddings"])
self.graph = self._create_graph()
def _create_llm(self, llm_config: dict):
@ -39,7 +41,7 @@ class AbstractGraph(ABC):
except KeyError:
raise ValueError("Model not supported")
return OpenAI(llm_params)
elif "azure" in llm_params["model"]:
# take the model after the last dash
llm_params["model"] = llm_params["model"].split("/")[-1]
@ -48,23 +50,30 @@ class AbstractGraph(ABC):
except KeyError:
raise ValueError("Model not supported")
return AzureOpenAI(llm_params)
elif "gemini" in llm_params["model"]:
try:
self.model_token = models_tokens["gemini"][llm_params["model"]]
except KeyError:
raise ValueError("Model not supported")
return Gemini(llm_params)
elif "ollama" in llm_params["model"]:
# take the model after the last dash
"""
Avaiable models:
- llama2
- mistral
- codellama
- dolphin-mixtral
- mistral-openorca
"""
llm_params["model"] = llm_params["model"].split("/")[-1]
try:
self.model_token = models_tokens["ollama"][llm_params["model"]]
except KeyError:
raise ValueError("Model not supported")
return Ollama(llm_params)
else:
raise ValueError("Model not supported")

View File

@ -1,7 +1,6 @@
"""
Module for creating the smart scraper
"""
from ..models import OpenAI, Gemini
from .base_graph import BaseGraph
from ..nodes import (
FetchNode,

View File

@ -2,7 +2,7 @@
Module for converting text to speach
"""
from scrapegraphai.utils.save_audio_from_bytes import save_audio_from_bytes
from ..models import OpenAI, Gemini, OpenAITextToSpeech
from ..models import OpenAITextToSpeech
from .base_graph import BaseGraph
from ..nodes import (
FetchNode,
@ -27,7 +27,7 @@ class SpeechGraph(AbstractGraph):
super().__init__(prompt, config, source)
self.input_key = "url" if source.startswith("http") else "local_dir"
def _create_graph(self):
"""
Creates the graph of nodes representing the workflow for web scraping and summarization.

View File

@ -21,9 +21,11 @@ models_tokens = {
"gemini-pro": 128000,
},
"ollama":{
"ollama": {
"llama2": 4096,
"mistral": 8192,
"codellama": 16000,
"dolphin-mixtral": 32000,
"mistral-openorca": 32000,
}
}

View File

@ -67,8 +67,8 @@ class FetchNode(BaseNode):
# Fetching data from the state based on the input keys
input_data = [state[key] for key in input_keys]
source = input_data[0]
source = input_data[0]
# if it is a local directory
if not source.startswith("http"):
document = [Document(page_content=source, metadata={

View File

@ -8,8 +8,8 @@ from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import EmbeddingsFilter, DocumentCompressorPipeline
from langchain_community.document_transformers import EmbeddingsRedundantFilter
from langchain_community.vectorstores import FAISS
from ..models import OpenAI, Gemini, Ollama, AzureOpenAI
from langchain_openai import OpenAIEmbeddings, AzureOpenAIEmbeddings
from ..models import OpenAI, Ollama, AzureOpenAI
from langchain_community.embeddings import OllamaEmbeddings
from .base_node import BaseNode
@ -86,16 +86,18 @@ class RAGNode(BaseNode):
embedding_model = self.embedder_model if self.embedder_model else self.llm_model
if isinstance(embedding_model, OpenAI):
embeddings = OpenAIEmbeddings(api_key=embedding_model.openai_api_key)
embeddings = OpenAIEmbeddings(
api_key=embedding_model.openai_api_key)
elif isinstance(embedding_model, AzureOpenAI):
embeddings = AzureOpenAIEmbeddings()
elif isinstance(embedding_model, Ollama):
embeddings = OllamaEmbeddings()
else:
raise ValueError("Embedding Model missing or not supported")
retriever = FAISS.from_documents(chunked_docs, embeddings).as_retriever()
retriever = FAISS.from_documents(
chunked_docs, embeddings).as_retriever()
redundant_filter = EmbeddingsRedundantFilter(embeddings=embeddings)
# similarity_threshold could be set, now k=20
relevant_filter = EmbeddingsFilter(embeddings=embeddings)

View File

@ -94,7 +94,7 @@ class SearchInternetNode(BaseNode):
# Execute the chain to get the search query
search_answer = search_prompt | self.llm_model | output_parser
search_query = search_answer.invoke({"user_prompt": user_prompt})[0]
print(f"Search Query: {search_query}")
# TODO: handle multiple URLs
answer = search_on_web(query=search_query, max_results=1)[0]