mirror of
https://github.com/VinciGit00/Scrapegraph-ai.git
synced 2026-06-23 21:00:30 +08:00
feat: Allow end users to pass model instances for llm and embedding model
This commit is contained in:
parent
40b2a346d5
commit
b86aac2188
63
examples/azure/smart_scraper_azure_openai.py
Normal file
63
examples/azure/smart_scraper_azure_openai.py
Normal file
@ -0,0 +1,63 @@
|
|||||||
|
"""
|
||||||
|
Basic example of a scraping pipeline using SmartScraperGraph with Azure OpenAI model instances
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
from langchain_openai import AzureChatOpenAI
|
||||||
|
from langchain_openai import AzureOpenAIEmbeddings
|
||||||
|
from scrapegraphai.graphs import SmartScraperGraph
|
||||||
|
from scrapegraphai.utils import prettify_exec_info
|
||||||
|
|
||||||
|
|
||||||
|
# Required environment variables in .env:
|
||||||
|
# AZURE_OPENAI_ENDPOINT
|
||||||
|
# AZURE_OPENAI_CHAT_DEPLOYMENT_NAME
|
||||||
|
# MODEL_NAME
|
||||||
|
# AZURE_OPENAI_API_KEY
|
||||||
|
# OPENAI_API_TYPE
|
||||||
|
# AZURE_OPENAI_API_VERSION
|
||||||
|
# AZURE_OPENAI_EMBEDDINGS_DEPLOYMENT_NAME
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
|
||||||
|
# ************************************************
|
||||||
|
# Initialize the model instances
|
||||||
|
# ************************************************
|
||||||
|
|
||||||
|
# Chat model: API version and deployment name are read from the environment.
_chat_model_kwargs = {
    "openai_api_version": os.environ["AZURE_OPENAI_API_VERSION"],
    "azure_deployment": os.environ["AZURE_OPENAI_CHAT_DEPLOYMENT_NAME"],
}
llm_model_instance = AzureChatOpenAI(**_chat_model_kwargs)

# Embeddings model: uses its own deployment name and the same API version variable.
_embedder_model_kwargs = {
    "azure_deployment": os.environ["AZURE_OPENAI_EMBEDDINGS_DEPLOYMENT_NAME"],
    "openai_api_version": os.environ["AZURE_OPENAI_API_VERSION"],
}
embedder_model_instance = AzureOpenAIEmbeddings(**_embedder_model_kwargs)
|
||||||
|
|
||||||
|
# ************************************************
|
||||||
|
# Create the SmartScraperGraph instance and run it
|
||||||
|
# ************************************************
|
||||||
|
|
||||||
|
# Pass the ready-made instances to the graph under the "model_instance" key
# instead of describing the models by name.
graph_config = {}
graph_config["llm"] = {"model_instance": llm_model_instance}
graph_config["embeddings"] = {"model_instance": embedder_model_instance}
|
||||||
|
|
||||||
|
# The prompt is written as adjacent string literals inside parentheses so the
# long sentence can span several source lines while remaining one string.
smart_scraper_graph = SmartScraperGraph(
    prompt=(
        "List me all the events, with the following fields: company_name, "
        "event_name, event_start_date, event_start_time, event_end_date, "
        "event_end_time, location, event_mode, event_category, "
        "third_party_redirect, no_of_days, "
        "time_in_hours, hosted_or_attending, refreshments_type, "
        "registration_available, registration_link"
    ),
    # also accepts a string with the already downloaded HTML code
    source="https://www.hmhco.com/event",
    config=graph_config,
)
|
||||||
|
|
||||||
|
# Run the graph and show the answer it produced.
result = smart_scraper_graph.run()
print(result)

# ************************************************
# Get graph execution info
# ************************************************

# Fetch the execution info, format it, then print the formatted report.
graph_exec_info = smart_scraper_graph.get_execution_info()
exec_info_report = prettify_exec_info(graph_exec_info)
print(exec_info_report)
|
||||||
@ -19,7 +19,7 @@ class AbstractGraph(ABC):
|
|||||||
self.prompt = prompt
|
self.prompt = prompt
|
||||||
self.source = source
|
self.source = source
|
||||||
self.config = config
|
self.config = config
|
||||||
self.llm_model = self._create_llm(config["llm"])
|
self.llm_model = self._create_llm(config["llm"], chat=True)
|
||||||
self.embedder_model = self.llm_model if "embeddings" not in config else self._create_llm(
|
self.embedder_model = self.llm_model if "embeddings" not in config else self._create_llm(
|
||||||
config["embeddings"])
|
config["embeddings"])
|
||||||
|
|
||||||
@ -32,7 +32,16 @@ class AbstractGraph(ABC):
|
|||||||
self.final_state = None
|
self.final_state = None
|
||||||
self.execution_info = None
|
self.execution_info = None
|
||||||
|
|
||||||
def _create_llm(self, llm_config: dict):
|
def _set_model_token(self, llm):
    """
    Sets `self.model_token` from the token table when `llm` is an Azure model.

    The instance is treated as an Azure model when the text "Azure" appears in
    `str(type(llm))`. For any other instance this method does nothing.

    Args:
        llm: model instance supplied by the user; its `model_name` attribute
            is the key looked up in `models_tokens["azure"]`.

    Raises:
        KeyError: if the lookup in `models_tokens["azure"]` fails.
    """
    # Guard clause: only Azure instances are handled here.
    if 'Azure' not in str(type(llm)):
        return

    model_name = llm.model_name
    try:
        # Keep the try body to the single lookup that can raise KeyError.
        self.model_token = models_tokens["azure"][model_name]
    except KeyError as exc:
        # Chain the original error and name the model so the failure is traceable.
        raise KeyError(f"Model not supported: {model_name}") from exc
|
||||||
|
|
||||||
|
|
||||||
|
def _create_llm(self, llm_config: dict, chat=False) -> object:
|
||||||
"""
|
"""
|
||||||
Creates an instance of the language model (OpenAI or Gemini) based on configuration.
|
Creates an instance of the language model (OpenAI or Gemini) based on configuration.
|
||||||
"""
|
"""
|
||||||
@ -42,6 +51,12 @@ class AbstractGraph(ABC):
|
|||||||
}
|
}
|
||||||
llm_params = {**llm_defaults, **llm_config}
|
llm_params = {**llm_defaults, **llm_config}
|
||||||
|
|
||||||
|
# If model instance is passed directly instead of the model details
|
||||||
|
if 'model_instance' in llm_params:
|
||||||
|
if chat:
|
||||||
|
self._set_model_token(llm_params['model_instance'])
|
||||||
|
return llm_params['model_instance']
|
||||||
|
|
||||||
# Instantiate the language model based on the model name
|
# Instantiate the language model based on the model name
|
||||||
if "gpt-" in llm_params["model"]:
|
if "gpt-" in llm_params["model"]:
|
||||||
try:
|
try:
|
||||||
@ -129,3 +144,4 @@ class AbstractGraph(ABC):
|
|||||||
Abstract method to execute the graph and return the result.
|
Abstract method to execute the graph and return the result.
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|||||||
@ -18,7 +18,9 @@ models_tokens = {
|
|||||||
"gpt-4-32k": 32768,
|
"gpt-4-32k": 32768,
|
||||||
"gpt-4-32k-0613": 32768,
|
"gpt-4-32k-0613": 32768,
|
||||||
},
|
},
|
||||||
|
"azure": {
|
||||||
|
"gpt-3.5-turbo": 4096
|
||||||
|
},
|
||||||
"gemini": {
|
"gemini": {
|
||||||
"gemini-pro": 128000,
|
"gemini-pro": 128000,
|
||||||
},
|
},
|
||||||
@ -45,3 +47,4 @@ models_tokens = {
|
|||||||
"claude3": 200000
|
"claude3": 200000
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -92,6 +92,8 @@ class RAGNode(BaseNode):
|
|||||||
if isinstance(embedding_model, OpenAI):
|
if isinstance(embedding_model, OpenAI):
|
||||||
embeddings = OpenAIEmbeddings(
|
embeddings = OpenAIEmbeddings(
|
||||||
api_key=embedding_model.openai_api_key)
|
api_key=embedding_model.openai_api_key)
|
||||||
|
elif isinstance(embedding_model, AzureOpenAIEmbeddings):
|
||||||
|
embeddings = embedding_model
|
||||||
elif isinstance(embedding_model, AzureOpenAI):
|
elif isinstance(embedding_model, AzureOpenAI):
|
||||||
embeddings = AzureOpenAIEmbeddings()
|
embeddings = AzureOpenAIEmbeddings()
|
||||||
elif isinstance(embedding_model, Ollama):
|
elif isinstance(embedding_model, Ollama):
|
||||||
@ -133,3 +135,4 @@ class RAGNode(BaseNode):
|
|||||||
|
|
||||||
state.update({self.output[0]: compressed_docs})
|
state.update({self.output[0]: compressed_docs})
|
||||||
return state
|
return state
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user