mirror of
https://github.com/VinciGit00/Scrapegraph-ai.git
synced 2026-06-25 21:11:11 +08:00
feat: Enable end users to pass model instances of HuggingFaceHub
This commit is contained in:
parent
98dec36c60
commit
7599234ab9
63
examples/huggingfacehub/smart_scraper_huggingfacehub.py
Normal file
63
examples/huggingfacehub/smart_scraper_huggingfacehub.py
Normal file
@ -0,0 +1,63 @@
|
||||
"""
|
||||
Basic example of a scraping pipeline using SmartScraperGraph with HuggingFace Hub model instances
|
||||
"""
|
||||
|
||||
import os
|
||||
from dotenv import load_dotenv
|
||||
from scrapegraphai.graphs import SmartScraperGraph
|
||||
from scrapegraphai.utils import prettify_exec_info
|
||||
from langchain_community.llms import HuggingFaceEndpoint
|
||||
from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings
|
||||
|
||||
|
||||
|
||||
|
||||
# Required environment variable in .env:
|
||||
# HUGGINGFACEHUB_API_TOKEN
|
||||
load_dotenv()
|
||||
|
||||
HUGGINGFACEHUB_API_TOKEN = os.getenv('HUGGINGFACEHUB_API_TOKEN')
|
||||
# ************************************************
|
||||
# Initialize the model instances
|
||||
# ************************************************
|
||||
|
||||
repo_id = "mistralai/Mistral-7B-Instruct-v0.2"
|
||||
|
||||
llm_model_instance = HuggingFaceEndpoint(
|
||||
repo_id=repo_id, max_length=128, temperature=0.5, token=HUGGINGFACEHUB_API_TOKEN
|
||||
)
|
||||
|
||||
|
||||
|
||||
|
||||
embedder_model_instance = HuggingFaceInferenceAPIEmbeddings(
|
||||
api_key=HUGGINGFACEHUB_API_TOKEN, model_name="sentence-transformers/all-MiniLM-l6-v2"
|
||||
)
|
||||
|
||||
# ************************************************
|
||||
# Create the SmartScraperGraph instance and run it
|
||||
# ************************************************
|
||||
|
||||
graph_config = {
|
||||
"llm": {"model_instance": llm_model_instance},
|
||||
"embeddings": {"model_instance": embedder_model_instance}
|
||||
}
|
||||
|
||||
smart_scraper_graph = SmartScraperGraph(
|
||||
prompt="List me all the events, with the following fields: company_name, event_name, event_start_date, event_start_time, event_end_date, event_end_time, location, event_mode, event_category, third_party_redirect, no_of_days, time_in_hours, hosted_or_attending, refreshments_type, registration_available, registration_link",
|
||||
# also accepts a string with the already downloaded HTML code
|
||||
source="https://www.hmhco.com/event",
|
||||
config=graph_config
|
||||
)
|
||||
|
||||
result = smart_scraper_graph.run()
|
||||
print(result)
|
||||
|
||||
# ************************************************
|
||||
# Get graph execution info
|
||||
# ************************************************
|
||||
|
||||
graph_exec_info = smart_scraper_graph.get_execution_info()
|
||||
print(prettify_exec_info(graph_exec_info))
|
||||
|
||||
|
||||
@ -64,6 +64,13 @@ class AbstractGraph(ABC):
|
||||
self.model_token = models_tokens["azure"][llm.model_name]
|
||||
except KeyError:
|
||||
raise KeyError("Model not supported")
|
||||
|
||||
elif 'HuggingFaceEndpoint' in str(type(llm)):
|
||||
if 'mistral' in llm.repo_id:
|
||||
try:
|
||||
self.model_token = models_tokens['mistral'][llm.repo_id]
|
||||
except KeyError:
|
||||
raise KeyError("Model not supported")
|
||||
|
||||
|
||||
def _create_llm(self, llm_config: dict, chat=False) -> object:
|
||||
|
||||
@ -65,5 +65,8 @@ models_tokens = {
|
||||
"mistral.mistral-large-2402-v1:0": 32768,
|
||||
"cohere.embed-english-v3": 512,
|
||||
"cohere.embed-multilingual-v3": 512
|
||||
},
|
||||
"mistral": {
|
||||
"mistralai/Mistral-7B-Instruct-v0.2": 32000
|
||||
}
|
||||
}
|
||||
|
||||
@ -12,6 +12,7 @@ from langchain_community.embeddings import HuggingFaceHubEmbeddings
|
||||
from langchain_community.vectorstores import FAISS
|
||||
from langchain_community.embeddings import OllamaEmbeddings
|
||||
from langchain_openai import OpenAIEmbeddings, AzureOpenAIEmbeddings
|
||||
from langchain_community.embeddings.huggingface import HuggingFaceInferenceAPIEmbeddings
|
||||
|
||||
from ..models import OpenAI, Ollama, AzureOpenAI, HuggingFace, Bedrock
|
||||
from .base_node import BaseNode
|
||||
@ -95,6 +96,9 @@ class RAGNode(BaseNode):
|
||||
api_key=embedding_model.openai_api_key)
|
||||
elif isinstance(embedding_model, AzureOpenAIEmbeddings):
|
||||
embeddings = embedding_model
|
||||
elif isinstance(embedding_model, HuggingFaceInferenceAPIEmbeddings):
|
||||
embeddings = embedding_model
|
||||
|
||||
elif isinstance(embedding_model, AzureOpenAI):
|
||||
embeddings = AzureOpenAIEmbeddings()
|
||||
elif isinstance(embedding_model, Ollama):
|
||||
|
||||
Loading…
Reference in New Issue
Block a user