feat: Enable end users to pass model instances of HuggingFaceHub

This commit is contained in:
Shubham Kamboj 2024-05-04 21:42:57 +05:30
parent 98dec36c60
commit 7599234ab9
4 changed files with 77 additions and 0 deletions

View File

@ -0,0 +1,63 @@
"""
Basic example of scraping pipeline using SmartScraper using Azure OpenAI Key
"""
import os
from dotenv import load_dotenv
from scrapegraphai.graphs import SmartScraperGraph
from scrapegraphai.utils import prettify_exec_info
from langchain_community.llms import HuggingFaceEndpoint
from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings
## required environment variable in .env
#HUGGINGFACEHUB_API_TOKEN
load_dotenv()
HUGGINGFACEHUB_API_TOKEN = os.getenv('HUGGINGFACEHUB_API_TOKEN')
# ************************************************
# Initialize the model instances
# ************************************************
repo_id = "mistralai/Mistral-7B-Instruct-v0.2"
llm_model_instance = HuggingFaceEndpoint(
repo_id=repo_id, max_length=128, temperature=0.5, token=HUGGINGFACEHUB_API_TOKEN
)
embedder_model_instance = HuggingFaceInferenceAPIEmbeddings(
api_key=HUGGINGFACEHUB_API_TOKEN, model_name="sentence-transformers/all-MiniLM-l6-v2"
)
# ************************************************
# Create the SmartScraperGraph instance and run it
# ************************************************
graph_config = {
"llm": {"model_instance": llm_model_instance},
"embeddings": {"model_instance": embedder_model_instance}
}
smart_scraper_graph = SmartScraperGraph(
prompt="List me all the events, with the following fields: company_name, event_name, event_start_date, event_start_time, event_end_date, event_end_time, location, event_mode, event_category, third_party_redirect, no_of_days, time_in_hours, hosted_or_attending, refreshments_type, registration_available, registration_link",
# also accepts a string with the already downloaded HTML code
source="https://www.hmhco.com/event",
config=graph_config
)
result = smart_scraper_graph.run()
print(result)
# ************************************************
# Get graph execution info
# ************************************************
graph_exec_info = smart_scraper_graph.get_execution_info()
print(prettify_exec_info(graph_exec_info))

View File

@ -64,6 +64,13 @@ class AbstractGraph(ABC):
self.model_token = models_tokens["azure"][llm.model_name]
except KeyError:
raise KeyError("Model not supported")
elif 'HuggingFaceEndpoint' in str(type(llm)):
if 'mistral' in llm.repo_id:
try:
self.model_token = models_tokens['mistral'][llm.repo_id]
except KeyError:
raise KeyError("Model not supported")
def _create_llm(self, llm_config: dict, chat=False) -> object:

View File

@ -65,5 +65,8 @@ models_tokens = {
"mistral.mistral-large-2402-v1:0": 32768,
"cohere.embed-english-v3": 512,
"cohere.embed-multilingual-v3": 512
},
"mistral": {
"mistralai/Mistral-7B-Instruct-v0.2": 32000
}
}

View File

@ -12,6 +12,7 @@ from langchain_community.embeddings import HuggingFaceHubEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import OllamaEmbeddings
from langchain_openai import OpenAIEmbeddings, AzureOpenAIEmbeddings
from langchain_community.embeddings.huggingface import HuggingFaceInferenceAPIEmbeddings
from ..models import OpenAI, Ollama, AzureOpenAI, HuggingFace, Bedrock
from .base_node import BaseNode
@ -95,6 +96,9 @@ class RAGNode(BaseNode):
api_key=embedding_model.openai_api_key)
elif isinstance(embedding_model, AzureOpenAIEmbeddings):
embeddings = embedding_model
elif isinstance(embedding_model, HuggingFaceInferenceAPIEmbeddings):
embeddings = embedding_model
elif isinstance(embedding_model, AzureOpenAI):
embeddings = AzureOpenAIEmbeddings()
elif isinstance(embedding_model, Ollama):