feat: Allow end users to pass model instances for llm and embedding model

This commit is contained in:
Shubham Kamboj 2024-05-02 20:09:23 +05:30
parent 40b2a346d5
commit b86aac2188
4 changed files with 88 additions and 3 deletions

View File

@ -0,0 +1,63 @@
"""
Basic example of scraping pipeline using SmartScraper using Azure OpenAI Key
"""
import os
from dotenv import load_dotenv
from langchain_openai import AzureChatOpenAI
from langchain_openai import AzureOpenAIEmbeddings
from scrapegraphai.graphs import SmartScraperGraph
from scrapegraphai.utils import prettify_exec_info
## required environment variable in .env
# AZURE_OPENAI_ENDPOINT
# AZURE_OPENAI_CHAT_DEPLOYMENT_NAME
# MODEL_NAME
# AZURE_OPENAI_API_KEY
# OPENAI_API_TYPE
# AZURE_OPENAI_API_VERSION
# AZURE_OPENAI_EMBEDDINGS_DEPLOYMENT_NAME
load_dotenv()
# ************************************************
# Initialize the model instances
# ************************************************
llm_model_instance = AzureChatOpenAI(
openai_api_version=os.environ["AZURE_OPENAI_API_VERSION"],
azure_deployment=os.environ["AZURE_OPENAI_CHAT_DEPLOYMENT_NAME"]
)
embedder_model_instance = AzureOpenAIEmbeddings(
azure_deployment=os.environ["AZURE_OPENAI_EMBEDDINGS_DEPLOYMENT_NAME"],
openai_api_version=os.environ["AZURE_OPENAI_API_VERSION"],
)
# ************************************************
# Create the SmartScraperGraph instance and run it
# ************************************************
graph_config = {
"llm": {"model_instance": llm_model_instance},
"embeddings": {"model_instance": embedder_model_instance}
}
smart_scraper_graph = SmartScraperGraph(
prompt="List me all the events, with the following fields: company_name, event_name, event_start_date, event_start_time, event_end_date, event_end_time, location, event_mode, event_category, third_party_redirect, no_of_days,
time_in_hours, hosted_or_attending, refreshments_type, registration_available, registration_link",
# also accepts a string with the already downloaded HTML code
source="https://www.hmhco.com/event",
config=graph_config
)
result = smart_scraper_graph.run()
print(result)
# ************************************************
# Get graph execution info
# ************************************************
graph_exec_info = smart_scraper_graph.get_execution_info()
print(prettify_exec_info(graph_exec_info))

View File

@ -19,7 +19,7 @@ class AbstractGraph(ABC):
self.prompt = prompt
self.source = source
self.config = config
self.llm_model = self._create_llm(config["llm"])
self.llm_model = self._create_llm(config["llm"], chat=True)
self.embedder_model = self.llm_model if "embeddings" not in config else self._create_llm(
config["embeddings"])
@ -32,7 +32,16 @@ class AbstractGraph(ABC):
self.final_state = None
self.execution_info = None
def _create_llm(self, llm_config: dict):
def _set_model_token(self, llm):
if 'Azure' in str(type(llm)):
try:
self.model_token = models_tokens["azure"][llm.model_name]
except KeyError:
raise KeyError("Model not supported")
def _create_llm(self, llm_config: dict, chat=False) -> object:
"""
Creates an instance of the language model (OpenAI or Gemini) based on configuration.
"""
@ -42,6 +51,12 @@ class AbstractGraph(ABC):
}
llm_params = {**llm_defaults, **llm_config}
# If model instance is passed directly instead of the model details
if 'model_instance' in llm_params:
if chat:
self._set_model_token(llm_params['model_instance'])
return llm_params['model_instance']
# Instantiate the language model based on the model name
if "gpt-" in llm_params["model"]:
try:
@ -129,3 +144,4 @@ class AbstractGraph(ABC):
Abstract method to execute the graph and return the result.
"""
pass

View File

@ -18,7 +18,9 @@ models_tokens = {
"gpt-4-32k": 32768,
"gpt-4-32k-0613": 32768,
},
"azure": {
"gpt-3.5-turbo": 4096
},
"gemini": {
"gemini-pro": 128000,
},
@ -45,3 +47,4 @@ models_tokens = {
"claude3": 200000
}
}

View File

@ -92,6 +92,8 @@ class RAGNode(BaseNode):
if isinstance(embedding_model, OpenAI):
embeddings = OpenAIEmbeddings(
api_key=embedding_model.openai_api_key)
elif isinstance(embedding_model, AzureOpenAIEmbeddings):
embeddings = embedding_model
elif isinstance(embedding_model, AzureOpenAI):
embeddings = AzureOpenAIEmbeddings()
elif isinstance(embedding_model, Ollama):
@ -133,3 +135,4 @@ class RAGNode(BaseNode):
state.update({self.output[0]: compressed_docs})
return state