mirror of
https://github.com/VinciGit00/Scrapegraph-ai.git
synced 2026-06-23 21:00:30 +08:00
feat: Allow end users to pass model instances for llm and embedding model
This commit is contained in:
parent
40b2a346d5
commit
b86aac2188
63
examples/azure/smart_scraper_azure_openai.py
Normal file
63
examples/azure/smart_scraper_azure_openai.py
Normal file
@ -0,0 +1,63 @@
|
||||
"""
|
||||
Basic example of scraping pipeline using SmartScraper using Azure OpenAI Key
|
||||
"""
|
||||
|
||||
import os
|
||||
from dotenv import load_dotenv
|
||||
from langchain_openai import AzureChatOpenAI
|
||||
from langchain_openai import AzureOpenAIEmbeddings
|
||||
from scrapegraphai.graphs import SmartScraperGraph
|
||||
from scrapegraphai.utils import prettify_exec_info
|
||||
|
||||
|
||||
## required environment variable in .env
|
||||
# AZURE_OPENAI_ENDPOINT
|
||||
# AZURE_OPENAI_CHAT_DEPLOYMENT_NAME
|
||||
# MODEL_NAME
|
||||
# AZURE_OPENAI_API_KEY
|
||||
# OPENAI_API_TYPE
|
||||
# AZURE_OPENAI_API_VERSION
|
||||
# AZURE_OPENAI_EMBEDDINGS_DEPLOYMENT_NAME
|
||||
load_dotenv()
|
||||
|
||||
|
||||
# ************************************************
|
||||
# Initialize the model instances
|
||||
# ************************************************
|
||||
|
||||
llm_model_instance = AzureChatOpenAI(
|
||||
openai_api_version=os.environ["AZURE_OPENAI_API_VERSION"],
|
||||
azure_deployment=os.environ["AZURE_OPENAI_CHAT_DEPLOYMENT_NAME"]
|
||||
)
|
||||
|
||||
embedder_model_instance = AzureOpenAIEmbeddings(
|
||||
azure_deployment=os.environ["AZURE_OPENAI_EMBEDDINGS_DEPLOYMENT_NAME"],
|
||||
openai_api_version=os.environ["AZURE_OPENAI_API_VERSION"],
|
||||
)
|
||||
|
||||
# ************************************************
|
||||
# Create the SmartScraperGraph instance and run it
|
||||
# ************************************************
|
||||
|
||||
graph_config = {
|
||||
"llm": {"model_instance": llm_model_instance},
|
||||
"embeddings": {"model_instance": embedder_model_instance}
|
||||
}
|
||||
|
||||
smart_scraper_graph = SmartScraperGraph(
|
||||
prompt="List me all the events, with the following fields: company_name, event_name, event_start_date, event_start_time, event_end_date, event_end_time, location, event_mode, event_category, third_party_redirect, no_of_days,
|
||||
time_in_hours, hosted_or_attending, refreshments_type, registration_available, registration_link",
|
||||
# also accepts a string with the already downloaded HTML code
|
||||
source="https://www.hmhco.com/event",
|
||||
config=graph_config
|
||||
)
|
||||
|
||||
result = smart_scraper_graph.run()
|
||||
print(result)
|
||||
|
||||
# ************************************************
|
||||
# Get graph execution info
|
||||
# ************************************************
|
||||
|
||||
graph_exec_info = smart_scraper_graph.get_execution_info()
|
||||
print(prettify_exec_info(graph_exec_info))
|
||||
@ -19,7 +19,7 @@ class AbstractGraph(ABC):
|
||||
self.prompt = prompt
|
||||
self.source = source
|
||||
self.config = config
|
||||
self.llm_model = self._create_llm(config["llm"])
|
||||
self.llm_model = self._create_llm(config["llm"], chat=True)
|
||||
self.embedder_model = self.llm_model if "embeddings" not in config else self._create_llm(
|
||||
config["embeddings"])
|
||||
|
||||
@ -32,7 +32,16 @@ class AbstractGraph(ABC):
|
||||
self.final_state = None
|
||||
self.execution_info = None
|
||||
|
||||
def _create_llm(self, llm_config: dict):
|
||||
def _set_model_token(self, llm):
|
||||
|
||||
if 'Azure' in str(type(llm)):
|
||||
try:
|
||||
self.model_token = models_tokens["azure"][llm.model_name]
|
||||
except KeyError:
|
||||
raise KeyError("Model not supported")
|
||||
|
||||
|
||||
def _create_llm(self, llm_config: dict, chat=False) -> object:
|
||||
"""
|
||||
Creates an instance of the language model (OpenAI or Gemini) based on configuration.
|
||||
"""
|
||||
@ -42,6 +51,12 @@ class AbstractGraph(ABC):
|
||||
}
|
||||
llm_params = {**llm_defaults, **llm_config}
|
||||
|
||||
# If model instance is passed directly instead of the model details
|
||||
if 'model_instance' in llm_params:
|
||||
if chat:
|
||||
self._set_model_token(llm_params['model_instance'])
|
||||
return llm_params['model_instance']
|
||||
|
||||
# Instantiate the language model based on the model name
|
||||
if "gpt-" in llm_params["model"]:
|
||||
try:
|
||||
@ -129,3 +144,4 @@ class AbstractGraph(ABC):
|
||||
Abstract method to execute the graph and return the result.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
@ -18,7 +18,9 @@ models_tokens = {
|
||||
"gpt-4-32k": 32768,
|
||||
"gpt-4-32k-0613": 32768,
|
||||
},
|
||||
|
||||
"azure": {
|
||||
"gpt-3.5-turbo": 4096
|
||||
},
|
||||
"gemini": {
|
||||
"gemini-pro": 128000,
|
||||
},
|
||||
@ -45,3 +47,4 @@ models_tokens = {
|
||||
"claude3": 200000
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -92,6 +92,8 @@ class RAGNode(BaseNode):
|
||||
if isinstance(embedding_model, OpenAI):
|
||||
embeddings = OpenAIEmbeddings(
|
||||
api_key=embedding_model.openai_api_key)
|
||||
elif isinstance(embedding_model, AzureOpenAIEmbeddings):
|
||||
embeddings = embedding_model
|
||||
elif isinstance(embedding_model, AzureOpenAI):
|
||||
embeddings = AzureOpenAIEmbeddings()
|
||||
elif isinstance(embedding_model, Ollama):
|
||||
@ -133,3 +135,4 @@ class RAGNode(BaseNode):
|
||||
|
||||
state.update({self.output[0]: compressed_docs})
|
||||
return state
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user