refactor: remove redundant LangChain wrappers
Some checks are pending
/ build (push) Waiting to run

This commit is contained in:
Federico Aguzzi 2024-07-29 21:57:37 +02:00
parent 5007167af1
commit 9275486240
5 changed files with 140 additions and 42 deletions

View File

@ -34,7 +34,8 @@ dependencies = [
"undetected-playwright>=0.3.0", "undetected-playwright>=0.3.0",
"semchunk>=1.0.1", "semchunk>=1.0.1",
"langchain-fireworks>=0.1.3", "langchain-fireworks>=0.1.3",
"langchain-community>=0.2.9" "langchain-community>=0.2.9",
"langchain-huggingface>=0.0.3",
] ]
license = "MIT" license = "MIT"

View File

@ -106,6 +106,8 @@ fastapi-pagination==0.12.26
# via burr # via burr
filelock==3.15.4 filelock==3.15.4
# via huggingface-hub # via huggingface-hub
# via torch
# via transformers
fireworks-ai==0.14.0 fireworks-ai==0.14.0
# via langchain-fireworks # via langchain-fireworks
fonttools==4.53.1 fonttools==4.53.1
@ -117,6 +119,7 @@ frozenlist==1.4.1
# via aiosignal # via aiosignal
fsspec==2024.6.1 fsspec==2024.6.1
# via huggingface-hub # via huggingface-hub
# via torch
furo==2024.5.6 furo==2024.5.6
# via scrapegraphai # via scrapegraphai
gitdb==4.0.11 gitdb==4.0.11
@ -180,6 +183,7 @@ graphviz==0.20.3
# via scrapegraphai # via scrapegraphai
greenlet==3.0.3 greenlet==3.0.3
# via playwright # via playwright
# via sqlalchemy
groq==0.9.0 groq==0.9.0
# via langchain-groq # via langchain-groq
grpc-google-iam-v1==0.13.1 grpc-google-iam-v1==0.13.1
@ -212,7 +216,10 @@ httpx==0.27.0
httpx-sse==0.4.0 httpx-sse==0.4.0
# via fireworks-ai # via fireworks-ai
huggingface-hub==0.24.0 huggingface-hub==0.24.0
# via langchain-huggingface
# via sentence-transformers
# via tokenizers # via tokenizers
# via transformers
idna==3.7 idna==3.7
# via anyio # via anyio
# via email-validator # via email-validator
@ -235,11 +242,14 @@ jinja2==3.1.4
# via fastapi # via fastapi
# via pydeck # via pydeck
# via sphinx # via sphinx
# via torch
jiter==0.5.0 jiter==0.5.0
# via anthropic # via anthropic
jmespath==1.0.1 jmespath==1.0.1
# via boto3 # via boto3
# via botocore # via botocore
joblib==1.4.2
# via scikit-learn
jsonpatch==1.33 jsonpatch==1.33
# via langchain-core # via langchain-core
jsonpointer==3.0.0 jsonpointer==3.0.0
@ -268,6 +278,7 @@ langchain-core==0.2.22
# via langchain-google-genai # via langchain-google-genai
# via langchain-google-vertexai # via langchain-google-vertexai
# via langchain-groq # via langchain-groq
# via langchain-huggingface
# via langchain-nvidia-ai-endpoints # via langchain-nvidia-ai-endpoints
# via langchain-openai # via langchain-openai
# via langchain-text-splitters # via langchain-text-splitters
@ -279,6 +290,8 @@ langchain-google-vertexai==1.0.7
# via scrapegraphai # via scrapegraphai
langchain-groq==0.1.6 langchain-groq==0.1.6
# via scrapegraphai # via scrapegraphai
langchain-huggingface==0.0.3
# via scrapegraphai
langchain-nvidia-ai-endpoints==0.1.6 langchain-nvidia-ai-endpoints==0.1.6
# via scrapegraphai # via scrapegraphai
langchain-openai==0.1.17 langchain-openai==0.1.17
@ -309,6 +322,8 @@ minify-html==0.15.0
# via scrapegraphai # via scrapegraphai
mpire==2.10.2 mpire==2.10.2
# via semchunk # via semchunk
mpmath==1.3.0
# via sympy
multidict==6.0.5 multidict==6.0.5
# via aiohttp # via aiohttp
# via yarl # via yarl
@ -316,6 +331,8 @@ multiprocess==0.70.16
# via mpire # via mpire
mypy-extensions==1.0.0 mypy-extensions==1.0.0
# via typing-inspect # via typing-inspect
networkx==3.2.1
# via torch
numpy==1.26.4 numpy==1.26.4
# via altair # via altair
# via contourpy # via contourpy
@ -327,9 +344,13 @@ numpy==1.26.4
# via pandas # via pandas
# via pyarrow # via pyarrow
# via pydeck # via pydeck
# via scikit-learn
# via scipy
# via sentence-transformers
# via sf-hamilton # via sf-hamilton
# via shapely # via shapely
# via streamlit # via streamlit
# via transformers
openai==1.37.0 openai==1.37.0
# via burr # via burr
# via langchain-fireworks # via langchain-fireworks
@ -348,6 +369,7 @@ packaging==24.1
# via pytest # via pytest
# via sphinx # via sphinx
# via streamlit # via streamlit
# via transformers
pandas==2.2.2 pandas==2.2.2
# via altair # via altair
# via scrapegraphai # via scrapegraphai
@ -357,6 +379,7 @@ pillow==10.4.0
# via fireworks-ai # via fireworks-ai
# via langchain-nvidia-ai-endpoints # via langchain-nvidia-ai-endpoints
# via matplotlib # via matplotlib
# via sentence-transformers
# via streamlit # via streamlit
platformdirs==4.2.2 platformdirs==4.2.2
# via pylint # via pylint
@ -436,12 +459,14 @@ pyyaml==6.0.1
# via langchain # via langchain
# via langchain-community # via langchain-community
# via langchain-core # via langchain-core
# via transformers
# via uvicorn # via uvicorn
referencing==0.35.1 referencing==0.35.1
# via jsonschema # via jsonschema
# via jsonschema-specifications # via jsonschema-specifications
regex==2024.5.15 regex==2024.5.15
# via tiktoken # via tiktoken
# via transformers
requests==2.32.3 requests==2.32.3
# via burr # via burr
# via free-proxy # via free-proxy
@ -456,6 +481,7 @@ requests==2.32.3
# via sphinx # via sphinx
# via streamlit # via streamlit
# via tiktoken # via tiktoken
# via transformers
rich==13.7.1 rich==13.7.1
# via streamlit # via streamlit
# via typer # via typer
@ -466,8 +492,17 @@ rsa==4.9
# via google-auth # via google-auth
s3transfer==0.10.2 s3transfer==0.10.2
# via boto3 # via boto3
safetensors==0.4.3
# via transformers
scikit-learn==1.5.1
# via sentence-transformers
scipy==1.13.1
# via scikit-learn
# via sentence-transformers
semchunk==2.2.0 semchunk==2.2.0
# via scrapegraphai # via scrapegraphai
sentence-transformers==3.0.1
# via langchain-huggingface
sf-hamilton==1.72.1 sf-hamilton==1.72.1
# via burr # via burr
shapely==2.0.5 shapely==2.0.5
@ -513,16 +548,22 @@ starlette==0.37.2
# via fastapi # via fastapi
streamlit==1.36.0 streamlit==1.36.0
# via burr # via burr
sympy==1.13.1
# via torch
tenacity==8.5.0 tenacity==8.5.0
# via langchain # via langchain
# via langchain-community # via langchain-community
# via langchain-core # via langchain-core
# via streamlit # via streamlit
threadpoolctl==3.5.0
# via scikit-learn
tiktoken==0.7.0 tiktoken==0.7.0
# via langchain-openai # via langchain-openai
# via scrapegraphai # via scrapegraphai
tokenizers==0.19.1 tokenizers==0.19.1
# via anthropic # via anthropic
# via langchain-huggingface
# via transformers
toml==0.10.2 toml==0.10.2
# via streamlit # via streamlit
tomli==2.0.1 tomli==2.0.1
@ -532,6 +573,8 @@ tomlkit==0.13.0
# via pylint # via pylint
toolz==0.12.1 toolz==0.12.1
# via altair # via altair
torch==2.2.2
# via sentence-transformers
tornado==6.4.1 tornado==6.4.1
# via streamlit # via streamlit
tqdm==4.66.4 tqdm==4.66.4
@ -541,6 +584,11 @@ tqdm==4.66.4
# via openai # via openai
# via scrapegraphai # via scrapegraphai
# via semchunk # via semchunk
# via sentence-transformers
# via transformers
transformers==4.43.3
# via langchain-huggingface
# via sentence-transformers
typer==0.12.3 typer==0.12.3
# via fastapi-cli # via fastapi-cli
typing-extensions==4.12.2 typing-extensions==4.12.2
@ -562,6 +610,7 @@ typing-extensions==4.12.2
# via sqlalchemy # via sqlalchemy
# via starlette # via starlette
# via streamlit # via streamlit
# via torch
# via typer # via typer
# via typing-inspect # via typing-inspect
# via uvicorn # via uvicorn

View File

@ -63,6 +63,8 @@ faiss-cpu==1.8.0.post1
# via scrapegraphai # via scrapegraphai
filelock==3.15.4 filelock==3.15.4
# via huggingface-hub # via huggingface-hub
# via torch
# via transformers
fireworks-ai==0.14.0 fireworks-ai==0.14.0
# via langchain-fireworks # via langchain-fireworks
free-proxy==1.1.1 free-proxy==1.1.1
@ -72,6 +74,7 @@ frozenlist==1.4.1
# via aiosignal # via aiosignal
fsspec==2024.6.1 fsspec==2024.6.1
# via huggingface-hub # via huggingface-hub
# via torch
google==3.0.0 google==3.0.0
# via scrapegraphai # via scrapegraphai
google-ai-generativelanguage==0.6.6 google-ai-generativelanguage==0.6.6
@ -128,6 +131,7 @@ graphviz==0.20.3
# via scrapegraphai # via scrapegraphai
greenlet==3.0.3 greenlet==3.0.3
# via playwright # via playwright
# via sqlalchemy
groq==0.9.0 groq==0.9.0
# via langchain-groq # via langchain-groq
grpc-google-iam-v1==0.13.1 grpc-google-iam-v1==0.13.1
@ -156,17 +160,24 @@ httpx==0.27.0
httpx-sse==0.4.0 httpx-sse==0.4.0
# via fireworks-ai # via fireworks-ai
huggingface-hub==0.24.0 huggingface-hub==0.24.0
# via langchain-huggingface
# via sentence-transformers
# via tokenizers # via tokenizers
# via transformers
idna==3.7 idna==3.7
# via anyio # via anyio
# via httpx # via httpx
# via requests # via requests
# via yarl # via yarl
jinja2==3.1.4
# via torch
jiter==0.5.0 jiter==0.5.0
# via anthropic # via anthropic
jmespath==1.0.1 jmespath==1.0.1
# via boto3 # via boto3
# via botocore # via botocore
joblib==1.4.2
# via scikit-learn
jsonpatch==1.33 jsonpatch==1.33
# via langchain-core # via langchain-core
jsonpointer==3.0.0 jsonpointer==3.0.0
@ -189,6 +200,7 @@ langchain-core==0.2.22
# via langchain-google-genai # via langchain-google-genai
# via langchain-google-vertexai # via langchain-google-vertexai
# via langchain-groq # via langchain-groq
# via langchain-huggingface
# via langchain-nvidia-ai-endpoints # via langchain-nvidia-ai-endpoints
# via langchain-openai # via langchain-openai
# via langchain-text-splitters # via langchain-text-splitters
@ -200,6 +212,8 @@ langchain-google-vertexai==1.0.7
# via scrapegraphai # via scrapegraphai
langchain-groq==0.1.6 langchain-groq==0.1.6
# via scrapegraphai # via scrapegraphai
langchain-huggingface==0.0.3
# via scrapegraphai
langchain-nvidia-ai-endpoints==0.1.6 langchain-nvidia-ai-endpoints==0.1.6
# via scrapegraphai # via scrapegraphai
langchain-openai==0.1.17 langchain-openai==0.1.17
@ -212,12 +226,16 @@ langsmith==0.1.93
# via langchain-core # via langchain-core
lxml==5.2.2 lxml==5.2.2
# via free-proxy # via free-proxy
markupsafe==2.1.5
# via jinja2
marshmallow==3.21.3 marshmallow==3.21.3
# via dataclasses-json # via dataclasses-json
minify-html==0.15.0 minify-html==0.15.0
# via scrapegraphai # via scrapegraphai
mpire==2.10.2 mpire==2.10.2
# via semchunk # via semchunk
mpmath==1.3.0
# via sympy
multidict==6.0.5 multidict==6.0.5
# via aiohttp # via aiohttp
# via yarl # via yarl
@ -225,13 +243,19 @@ multiprocess==0.70.16
# via mpire # via mpire
mypy-extensions==1.0.0 mypy-extensions==1.0.0
# via typing-inspect # via typing-inspect
networkx==3.2.1
# via torch
numpy==1.26.4 numpy==1.26.4
# via faiss-cpu # via faiss-cpu
# via langchain # via langchain
# via langchain-aws # via langchain-aws
# via langchain-community # via langchain-community
# via pandas # via pandas
# via scikit-learn
# via scipy
# via sentence-transformers
# via shapely # via shapely
# via transformers
openai==1.37.0 openai==1.37.0
# via langchain-fireworks # via langchain-fireworks
# via langchain-openai # via langchain-openai
@ -244,11 +268,13 @@ packaging==24.1
# via huggingface-hub # via huggingface-hub
# via langchain-core # via langchain-core
# via marshmallow # via marshmallow
# via transformers
pandas==2.2.2 pandas==2.2.2
# via scrapegraphai # via scrapegraphai
pillow==10.4.0 pillow==10.4.0
# via fireworks-ai # via fireworks-ai
# via langchain-nvidia-ai-endpoints # via langchain-nvidia-ai-endpoints
# via sentence-transformers
playwright==1.45.0 playwright==1.45.0
# via scrapegraphai # via scrapegraphai
# via undetected-playwright # via undetected-playwright
@ -303,8 +329,10 @@ pyyaml==6.0.1
# via langchain # via langchain
# via langchain-community # via langchain-community
# via langchain-core # via langchain-core
# via transformers
regex==2024.5.15 regex==2024.5.15
# via tiktoken # via tiktoken
# via transformers
requests==2.32.3 requests==2.32.3
# via free-proxy # via free-proxy
# via google-api-core # via google-api-core
@ -316,12 +344,22 @@ requests==2.32.3
# via langchain-fireworks # via langchain-fireworks
# via langsmith # via langsmith
# via tiktoken # via tiktoken
# via transformers
rsa==4.9 rsa==4.9
# via google-auth # via google-auth
s3transfer==0.10.2 s3transfer==0.10.2
# via boto3 # via boto3
safetensors==0.4.3
# via transformers
scikit-learn==1.5.1
# via sentence-transformers
scipy==1.13.1
# via scikit-learn
# via sentence-transformers
semchunk==2.2.0 semchunk==2.2.0
# via scrapegraphai # via scrapegraphai
sentence-transformers==3.0.1
# via langchain-huggingface
shapely==2.0.5 shapely==2.0.5
# via google-cloud-aiplatform # via google-cloud-aiplatform
six==1.16.0 six==1.16.0
@ -337,15 +375,23 @@ soupsieve==2.5
sqlalchemy==2.0.31 sqlalchemy==2.0.31
# via langchain # via langchain
# via langchain-community # via langchain-community
sympy==1.13.1
# via torch
tenacity==8.5.0 tenacity==8.5.0
# via langchain # via langchain
# via langchain-community # via langchain-community
# via langchain-core # via langchain-core
threadpoolctl==3.5.0
# via scikit-learn
tiktoken==0.7.0 tiktoken==0.7.0
# via langchain-openai # via langchain-openai
# via scrapegraphai # via scrapegraphai
tokenizers==0.19.1 tokenizers==0.19.1
# via anthropic # via anthropic
# via langchain-huggingface
# via transformers
torch==2.2.2
# via sentence-transformers
tqdm==4.66.4 tqdm==4.66.4
# via google-generativeai # via google-generativeai
# via huggingface-hub # via huggingface-hub
@ -353,6 +399,11 @@ tqdm==4.66.4
# via openai # via openai
# via scrapegraphai # via scrapegraphai
# via semchunk # via semchunk
# via sentence-transformers
# via transformers
transformers==4.43.3
# via langchain-huggingface
# via sentence-transformers
typing-extensions==4.12.2 typing-extensions==4.12.2
# via anthropic # via anthropic
# via anyio # via anyio
@ -364,6 +415,7 @@ typing-extensions==4.12.2
# via pydantic-core # via pydantic-core
# via pyee # via pyee
# via sqlalchemy # via sqlalchemy
# via torch
# via typing-inspect # via typing-inspect
typing-inspect==0.9.0 typing-inspect==0.9.0
# via dataclasses-json # via dataclasses-json

View File

@ -22,3 +22,4 @@ undetected-playwright>=0.3.0
semchunk>=1.0.1 semchunk>=1.0.1
langchain-fireworks>=0.1.3 langchain-fireworks>=0.1.3
langchain-community>=0.2.9 langchain-community>=0.2.9
langchain-huggingface>=0.0.3

View File

@ -3,33 +3,28 @@ AbstractGraph Module
""" """
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Optional, Union from typing import Optional
import uuid import uuid
from pydantic import BaseModel from pydantic import BaseModel
from langchain_community.chat_models import ChatOllama from langchain_community.chat_models import ChatOllama
from langchain_openai import ChatOpenAI from langchain_openai import ChatOpenAI
from langchain_aws import BedrockEmbeddings from langchain_aws import BedrockEmbeddings, ChatBedrock
from langchain_community.embeddings import HuggingFaceHubEmbeddings, OllamaEmbeddings from langchain_huggingface import ChatHuggingFace, HuggingFaceEmbeddings
from langchain_community.embeddings import OllamaEmbeddings
from langchain_google_genai import GoogleGenerativeAIEmbeddings from langchain_google_genai import GoogleGenerativeAIEmbeddings
from langchain_google_vertexai import VertexAIEmbeddings from langchain_google_vertexai import ChatVertexAI, VertexAIEmbeddings
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_google_genai.embeddings import GoogleGenerativeAIEmbeddings from langchain_google_genai.embeddings import GoogleGenerativeAIEmbeddings
from langchain_fireworks import FireworksEmbeddings from langchain_fireworks import FireworksEmbeddings, ChatFireworks
from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings, ChatOpenAI, AzureChatOpenAI
from langchain_nvidia_ai_endpoints import NVIDIAEmbeddings from langchain_nvidia_ai_endpoints import NVIDIAEmbeddings
from ..helpers import models_tokens from ..helpers import models_tokens
from ..models import ( from ..models import (
Anthropic,
AzureOpenAI,
Bedrock,
Gemini,
Groq,
HuggingFace,
OneApi, OneApi,
Fireworks, Nvidia,
VertexAI, DeepSeek
Nvidia
) )
from ..models.ernie import Ernie from ..models.ernie import Ernie
from langchain.chat_models import init_chat_model from langchain.chat_models import init_chat_model
@ -37,7 +32,6 @@ from langchain.chat_models import init_chat_model
from ..utils.logging import set_verbosity_debug, set_verbosity_warning, set_verbosity_info from ..utils.logging import set_verbosity_debug, set_verbosity_warning, set_verbosity_info
from ..helpers import models_tokens from ..helpers import models_tokens
from ..models import AzureOpenAI, Bedrock, Gemini, Groq, HuggingFace, Anthropic, DeepSeek
class AbstractGraph(ABC): class AbstractGraph(ABC):
@ -181,7 +175,8 @@ class AbstractGraph(ABC):
llm_params["model"] = "/".join(llm_params["model"].split("/")[1:]) llm_params["model"] = "/".join(llm_params["model"].split("/")[1:])
except KeyError as exc: except KeyError as exc:
raise KeyError("Model not supported") from exc raise KeyError("Model not supported") from exc
return Fireworks(llm_params) llm_params["model_provider"] = "fireworks"
return init_chat_model(**llm_params)
elif "azure" in llm_params["model"]: elif "azure" in llm_params["model"]:
# take the model after the last dash # take the model after the last dash
llm_params["model"] = llm_params["model"].split("/")[-1] llm_params["model"] = llm_params["model"].split("/")[-1]
@ -189,7 +184,8 @@ class AbstractGraph(ABC):
self.model_token = models_tokens["azure"][llm_params["model"]] self.model_token = models_tokens["azure"][llm_params["model"]]
except KeyError as exc: except KeyError as exc:
raise KeyError("Model not supported") from exc raise KeyError("Model not supported") from exc
return AzureOpenAI(llm_params) llm_params["model_provider"] = "azure_openai"
return init_chat_model(**llm_params)
elif "nvidia" in llm_params["model"]: elif "nvidia" in llm_params["model"]:
try: try:
self.model_token = models_tokens["nvidia"][llm_params["model"].split("/")[-1]] self.model_token = models_tokens["nvidia"][llm_params["model"].split("/")[-1]]
@ -203,20 +199,23 @@ class AbstractGraph(ABC):
self.model_token = models_tokens["gemini"][llm_params["model"]] self.model_token = models_tokens["gemini"][llm_params["model"]]
except KeyError as exc: except KeyError as exc:
raise KeyError("Model not supported") from exc raise KeyError("Model not supported") from exc
return Gemini(llm_params) llm_params["model_provider"] = "google_genai "
return init_chat_model(**llm_params)
elif llm_params["model"].startswith("claude"): elif llm_params["model"].startswith("claude"):
llm_params["model"] = llm_params["model"].split("/")[-1] llm_params["model"] = llm_params["model"].split("/")[-1]
try: try:
self.model_token = models_tokens["claude"][llm_params["model"]] self.model_token = models_tokens["claude"][llm_params["model"]]
except KeyError as exc: except KeyError as exc:
raise KeyError("Model not supported") from exc raise KeyError("Model not supported") from exc
return Anthropic(llm_params) llm_params["model_provider"] = "anthropic"
return init_chat_model(**llm_params)
elif llm_params["model"].startswith("vertexai"): elif llm_params["model"].startswith("vertexai"):
try: try:
self.model_token = models_tokens["vertexai"][llm_params["model"]] self.model_token = models_tokens["vertexai"][llm_params["model"]]
except KeyError as exc: except KeyError as exc:
raise KeyError("Model not supported") from exc raise KeyError("Model not supported") from exc
return VertexAI(llm_params) llm_params["model_provider"] = "google_vertexai"
return init_chat_model(**llm_params)
elif "ollama" in llm_params["model"]: elif "ollama" in llm_params["model"]:
llm_params["model"] = llm_params["model"].split("ollama/")[-1] llm_params["model"] = llm_params["model"].split("ollama/")[-1]
@ -246,7 +245,8 @@ class AbstractGraph(ABC):
except KeyError: except KeyError:
print("model not found, using default token size (8192)") print("model not found, using default token size (8192)")
self.model_token = 8192 self.model_token = 8192
return HuggingFace(llm_params) llm_params["model_provider"] = "hugging_face"
return init_chat_model(**llm_params)
elif "groq" in llm_params["model"]: elif "groq" in llm_params["model"]:
llm_params["model"] = llm_params["model"].split("/")[-1] llm_params["model"] = llm_params["model"].split("/")[-1]
@ -255,7 +255,8 @@ class AbstractGraph(ABC):
except KeyError: except KeyError:
print("model not found, using default token size (8192)") print("model not found, using default token size (8192)")
self.model_token = 8192 self.model_token = 8192
return Groq(llm_params) llm_params["model_provider"] = "groq"
return init_chat_model(**llm_params)
elif "bedrock" in llm_params["model"]: elif "bedrock" in llm_params["model"]:
llm_params["model"] = llm_params["model"].split("/")[-1] llm_params["model"] = llm_params["model"].split("/")[-1]
model_id = llm_params["model"] model_id = llm_params["model"]
@ -265,22 +266,16 @@ class AbstractGraph(ABC):
except KeyError: except KeyError:
print("model not found, using default token size (8192)") print("model not found, using default token size (8192)")
self.model_token = 8192 self.model_token = 8192
return Bedrock( llm_params["model_provider"] = "bedrock"
{ return init_chat_model(**llm_params)
"client": client,
"model_id": model_id,
"model_kwargs": {
"temperature": llm_params["temperature"],
},
}
)
elif "claude-3-" in llm_params["model"]: elif "claude-3-" in llm_params["model"]:
try: try:
self.model_token = models_tokens["claude"]["claude3"] self.model_token = models_tokens["claude"]["claude3"]
except KeyError: except KeyError:
print("model not found, using default token size (8192)") print("model not found, using default token size (8192)")
self.model_token = 8192 self.model_token = 8192
return Anthropic(llm_params) llm_params["model_provider"] = "anthropic"
return init_chat_model(**llm_params)
elif "deepseek" in llm_params["model"]: elif "deepseek" in llm_params["model"]:
try: try:
self.model_token = models_tokens["deepseek"][llm_params["model"]] self.model_token = models_tokens["deepseek"][llm_params["model"]]
@ -308,7 +303,7 @@ class AbstractGraph(ABC):
Raises: Raises:
ValueError: If the model is not supported. ValueError: If the model is not supported.
""" """
if isinstance(self.llm_model, Gemini): if isinstance(self.llm_model, ChatGoogleGenerativeAI):
return GoogleGenerativeAIEmbeddings( return GoogleGenerativeAIEmbeddings(
google_api_key=llm_config["api_key"], model="models/embedding-001" google_api_key=llm_config["api_key"], model="models/embedding-001"
) )
@ -317,13 +312,13 @@ class AbstractGraph(ABC):
base_url=self.llm_model.openai_api_base) base_url=self.llm_model.openai_api_base)
elif isinstance(self.llm_model, DeepSeek): elif isinstance(self.llm_model, DeepSeek):
return OpenAIEmbeddings(api_key=self.llm_model.openai_api_key) return OpenAIEmbeddings(api_key=self.llm_model.openai_api_key)
elif isinstance(self.llm_model, VertexAI): elif isinstance(self.llm_model, ChatVertexAI):
return VertexAIEmbeddings() return VertexAIEmbeddings()
elif isinstance(self.llm_model, AzureOpenAIEmbeddings): elif isinstance(self.llm_model, AzureOpenAIEmbeddings):
return self.llm_model return self.llm_model
elif isinstance(self.llm_model, AzureOpenAI): elif isinstance(self.llm_model, AzureChatOpenAI):
return AzureOpenAIEmbeddings() return AzureOpenAIEmbeddings()
elif isinstance(self.llm_model, Fireworks): elif isinstance(self.llm_model, ChatFireworks):
return FireworksEmbeddings(model=self.llm_model.model_name) return FireworksEmbeddings(model=self.llm_model.model_name)
elif isinstance(self.llm_model, Nvidia): elif isinstance(self.llm_model, Nvidia):
return NVIDIAEmbeddings(model=self.llm_model.model_name) return NVIDIAEmbeddings(model=self.llm_model.model_name)
@ -335,9 +330,9 @@ class AbstractGraph(ABC):
params.pop("temperature", None) params.pop("temperature", None)
return OllamaEmbeddings(**params) return OllamaEmbeddings(**params)
elif isinstance(self.llm_model, HuggingFace): elif isinstance(self.llm_model, ChatHuggingFace):
return HuggingFaceHubEmbeddings(model=self.llm_model.model) return HuggingFaceEmbeddings(model=self.llm_model.model)
elif isinstance(self.llm_model, Bedrock): elif isinstance(self.llm_model, ChatBedrock):
return BedrockEmbeddings(client=None, model_id=self.llm_model.model_id) return BedrockEmbeddings(client=None, model_id=self.llm_model.model_id)
else: else:
raise ValueError("Embedding Model missing or not supported") raise ValueError("Embedding Model missing or not supported")
@ -384,7 +379,7 @@ class AbstractGraph(ABC):
models_tokens["hugging_face"][embedder_params["model"]] models_tokens["hugging_face"][embedder_params["model"]]
except KeyError as exc: except KeyError as exc:
raise KeyError("Model not supported") from exc raise KeyError("Model not supported") from exc
return HuggingFaceHubEmbeddings(model=embedder_params["model"]) return HuggingFaceEmbeddings(model=embedder_params["model"])
elif "fireworks" in embedder_params["model"]: elif "fireworks" in embedder_params["model"]:
embedder_params["model"] = "/".join(embedder_params["model"].split("/")[1:]) embedder_params["model"] = "/".join(embedder_params["model"].split("/")[1:])
try: try: