feat: add vertexai integration

This commit is contained in:
Marco Vinciguerra 2024-07-01 12:21:47 +02:00
parent 79a2f51c34
commit 119514bdfc
6 changed files with 41 additions and 9 deletions

View File

@ -16,6 +16,7 @@ dependencies = [
"langchain==0.1.15",
"langchain-openai==0.1.6",
"langchain-google-genai==1.0.3",
"langchain-google-vertexai==1.0.6",
"langchain-groq==0.1.3",
"langchain-aws==0.1.3",
"langchain-anthropic==0.1.11",

View File

@ -1,6 +1,7 @@
langchain==0.1.14
langchain-openai==0.1.1
langchain-google-genai==1.0.1
langchain-google-vertexai==1.0.6
langchain-anthropic==0.1.11
html2text==2020.1.16
faiss-cpu==1.8.0

View File

@ -10,9 +10,9 @@ from pydantic import BaseModel
from langchain_aws import BedrockEmbeddings
from langchain_community.embeddings import HuggingFaceHubEmbeddings, OllamaEmbeddings
from langchain_google_genai import GoogleGenerativeAIEmbeddings
from langchain_google_vertexai import VertexAIEmbeddings
from langchain_google_genai.embeddings import GoogleGenerativeAIEmbeddings
from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings
from ..helpers import models_tokens
from ..models import (
Anthropic,
@ -23,7 +23,8 @@ from ..models import (
HuggingFace,
Ollama,
OpenAI,
OneApi
OneApi,
VertexAI
)
from ..models.ernie import Ernie
from ..utils.logging import set_verbosity_debug, set_verbosity_warning, set_verbosity_info
@ -71,7 +72,7 @@ class AbstractGraph(ABC):
self.config = config
self.schema = schema
self.llm_model = self._create_llm(config["llm"], chat=True)
self.embedder_model = self._create_default_embedder(llm_config=config["llm"] ) if "embeddings" not in config else self._create_embedder(
self.embedder_model = self._create_default_embedder(llm_config=config["llm"]) if "embeddings" not in config else self._create_embedder(
config["embeddings"])
self.verbose = False if config is None else config.get(
"verbose", False)
@ -102,7 +103,7 @@ class AbstractGraph(ABC):
"embedder_model": self.embedder_model,
"cache_path": self.cache_path,
}
self.set_common_params(common_params, overwrite=True)
# set burr config
@ -125,7 +126,7 @@ class AbstractGraph(ABC):
for node in self.graph.nodes:
node.update_config(params, overwrite)
def _create_llm(self, llm_config: dict, chat=False) -> object:
"""
Create a large language model instance based on the configuration provided.
@ -170,7 +171,6 @@ class AbstractGraph(ABC):
except KeyError as exc:
raise KeyError("Model not supported") from exc
return AzureOpenAI(llm_params)
elif "gemini" in llm_params["model"]:
try:
self.model_token = models_tokens["gemini"][llm_params["model"]]
@ -183,6 +183,12 @@ class AbstractGraph(ABC):
except KeyError as exc:
raise KeyError("Model not supported") from exc
return Anthropic(llm_params)
elif llm_params["model"].startswith("vertexai"):
try:
self.model_token = models_tokens["vertexai"][llm_params["model"]]
except KeyError as exc:
raise KeyError("Model not supported") from exc
return VertexAI(llm_params)
elif "ollama" in llm_params["model"]:
llm_params["model"] = llm_params["model"].split("ollama/")[-1]
@ -275,10 +281,12 @@ class AbstractGraph(ABC):
google_api_key=llm_config["api_key"], model="models/embedding-001"
)
if isinstance(self.llm_model, OpenAI):
return OpenAIEmbeddings(api_key=self.llm_model.openai_api_key, base_url=self.llm_model.openai_api_base)
return OpenAIEmbeddings(api_key=self.llm_model.openai_api_key,
base_url=self.llm_model.openai_api_base)
elif isinstance(self.llm_model, DeepSeek):
return OpenAIEmbeddings(api_key=self.llm_model.openai_api_key)
return OpenAIEmbeddings(api_key=self.llm_model.openai_api_key)
elif isinstance(self.llm_model, VertexAI):
return VertexAIEmbeddings()
elif isinstance(self.llm_model, AzureOpenAIEmbeddings):
return self.llm_model
elif isinstance(self.llm_model, AzureOpenAI):

View File

@ -75,6 +75,11 @@ models_tokens = {
"claude2.1": 200000,
"claude3": 200000
},
"vertexai": {
"gemini-1.5-flash": 128000,
"gemini-1.5-pro": 128000,
"gemini-1.0-pro": 128000
},
"bedrock": {
"anthropic.claude-3-haiku-20240307-v1:0": 200000,
"anthropic.claude-3-sonnet-20240229-v1:0": 200000,

View File

@ -14,3 +14,4 @@ from .bedrock import Bedrock
from .anthropic import Anthropic
from .deepseek import DeepSeek
from .oneapi import OneApi
from .vertex import VertexAI

View File

@ -0,0 +1,16 @@
"""
VertexAI Module
"""
from langchain_google_vertexai import ChatVertexAI
class VertexAI(ChatVertexAI):
"""
A wrapper for the ChatVertexAI class that provides default configuration
and could be extended with additional methods if needed.
Args:
llm_config (dict): Configuration parameters for the language model.
"""
def __init__(self, llm_config: dict):
super().__init__(**llm_config)