refactor: remove redundant LangChain wrappers
Some checks are pending
/ build (push) Waiting to run

This commit is contained in:
Federico Aguzzi 2024-07-29 21:57:37 +02:00
parent 5007167af1
commit 9275486240
5 changed files with 140 additions and 42 deletions

View File

@ -34,7 +34,8 @@ dependencies = [
"undetected-playwright>=0.3.0",
"semchunk>=1.0.1",
"langchain-fireworks>=0.1.3",
"langchain-community>=0.2.9"
"langchain-community>=0.2.9",
"langchain-huggingface>=0.0.3",
]
license = "MIT"

View File

@ -106,6 +106,8 @@ fastapi-pagination==0.12.26
# via burr
filelock==3.15.4
# via huggingface-hub
# via torch
# via transformers
fireworks-ai==0.14.0
# via langchain-fireworks
fonttools==4.53.1
@ -117,6 +119,7 @@ frozenlist==1.4.1
# via aiosignal
fsspec==2024.6.1
# via huggingface-hub
# via torch
furo==2024.5.6
# via scrapegraphai
gitdb==4.0.11
@ -180,6 +183,7 @@ graphviz==0.20.3
# via scrapegraphai
greenlet==3.0.3
# via playwright
# via sqlalchemy
groq==0.9.0
# via langchain-groq
grpc-google-iam-v1==0.13.1
@ -212,7 +216,10 @@ httpx==0.27.0
httpx-sse==0.4.0
# via fireworks-ai
huggingface-hub==0.24.0
# via langchain-huggingface
# via sentence-transformers
# via tokenizers
# via transformers
idna==3.7
# via anyio
# via email-validator
@ -235,11 +242,14 @@ jinja2==3.1.4
# via fastapi
# via pydeck
# via sphinx
# via torch
jiter==0.5.0
# via anthropic
jmespath==1.0.1
# via boto3
# via botocore
joblib==1.4.2
# via scikit-learn
jsonpatch==1.33
# via langchain-core
jsonpointer==3.0.0
@ -268,6 +278,7 @@ langchain-core==0.2.22
# via langchain-google-genai
# via langchain-google-vertexai
# via langchain-groq
# via langchain-huggingface
# via langchain-nvidia-ai-endpoints
# via langchain-openai
# via langchain-text-splitters
@ -279,6 +290,8 @@ langchain-google-vertexai==1.0.7
# via scrapegraphai
langchain-groq==0.1.6
# via scrapegraphai
langchain-huggingface==0.0.3
# via scrapegraphai
langchain-nvidia-ai-endpoints==0.1.6
# via scrapegraphai
langchain-openai==0.1.17
@ -309,6 +322,8 @@ minify-html==0.15.0
# via scrapegraphai
mpire==2.10.2
# via semchunk
mpmath==1.3.0
# via sympy
multidict==6.0.5
# via aiohttp
# via yarl
@ -316,6 +331,8 @@ multiprocess==0.70.16
# via mpire
mypy-extensions==1.0.0
# via typing-inspect
networkx==3.2.1
# via torch
numpy==1.26.4
# via altair
# via contourpy
@ -327,9 +344,13 @@ numpy==1.26.4
# via pandas
# via pyarrow
# via pydeck
# via scikit-learn
# via scipy
# via sentence-transformers
# via sf-hamilton
# via shapely
# via streamlit
# via transformers
openai==1.37.0
# via burr
# via langchain-fireworks
@ -348,6 +369,7 @@ packaging==24.1
# via pytest
# via sphinx
# via streamlit
# via transformers
pandas==2.2.2
# via altair
# via scrapegraphai
@ -357,6 +379,7 @@ pillow==10.4.0
# via fireworks-ai
# via langchain-nvidia-ai-endpoints
# via matplotlib
# via sentence-transformers
# via streamlit
platformdirs==4.2.2
# via pylint
@ -436,12 +459,14 @@ pyyaml==6.0.1
# via langchain
# via langchain-community
# via langchain-core
# via transformers
# via uvicorn
referencing==0.35.1
# via jsonschema
# via jsonschema-specifications
regex==2024.5.15
# via tiktoken
# via transformers
requests==2.32.3
# via burr
# via free-proxy
@ -456,6 +481,7 @@ requests==2.32.3
# via sphinx
# via streamlit
# via tiktoken
# via transformers
rich==13.7.1
# via streamlit
# via typer
@ -466,8 +492,17 @@ rsa==4.9
# via google-auth
s3transfer==0.10.2
# via boto3
safetensors==0.4.3
# via transformers
scikit-learn==1.5.1
# via sentence-transformers
scipy==1.13.1
# via scikit-learn
# via sentence-transformers
semchunk==2.2.0
# via scrapegraphai
sentence-transformers==3.0.1
# via langchain-huggingface
sf-hamilton==1.72.1
# via burr
shapely==2.0.5
@ -513,16 +548,22 @@ starlette==0.37.2
# via fastapi
streamlit==1.36.0
# via burr
sympy==1.13.1
# via torch
tenacity==8.5.0
# via langchain
# via langchain-community
# via langchain-core
# via streamlit
threadpoolctl==3.5.0
# via scikit-learn
tiktoken==0.7.0
# via langchain-openai
# via scrapegraphai
tokenizers==0.19.1
# via anthropic
# via langchain-huggingface
# via transformers
toml==0.10.2
# via streamlit
tomli==2.0.1
@ -532,6 +573,8 @@ tomlkit==0.13.0
# via pylint
toolz==0.12.1
# via altair
torch==2.2.2
# via sentence-transformers
tornado==6.4.1
# via streamlit
tqdm==4.66.4
@ -541,6 +584,11 @@ tqdm==4.66.4
# via openai
# via scrapegraphai
# via semchunk
# via sentence-transformers
# via transformers
transformers==4.43.3
# via langchain-huggingface
# via sentence-transformers
typer==0.12.3
# via fastapi-cli
typing-extensions==4.12.2
@ -562,6 +610,7 @@ typing-extensions==4.12.2
# via sqlalchemy
# via starlette
# via streamlit
# via torch
# via typer
# via typing-inspect
# via uvicorn

View File

@ -63,6 +63,8 @@ faiss-cpu==1.8.0.post1
# via scrapegraphai
filelock==3.15.4
# via huggingface-hub
# via torch
# via transformers
fireworks-ai==0.14.0
# via langchain-fireworks
free-proxy==1.1.1
@ -72,6 +74,7 @@ frozenlist==1.4.1
# via aiosignal
fsspec==2024.6.1
# via huggingface-hub
# via torch
google==3.0.0
# via scrapegraphai
google-ai-generativelanguage==0.6.6
@ -128,6 +131,7 @@ graphviz==0.20.3
# via scrapegraphai
greenlet==3.0.3
# via playwright
# via sqlalchemy
groq==0.9.0
# via langchain-groq
grpc-google-iam-v1==0.13.1
@ -156,17 +160,24 @@ httpx==0.27.0
httpx-sse==0.4.0
# via fireworks-ai
huggingface-hub==0.24.0
# via langchain-huggingface
# via sentence-transformers
# via tokenizers
# via transformers
idna==3.7
# via anyio
# via httpx
# via requests
# via yarl
jinja2==3.1.4
# via torch
jiter==0.5.0
# via anthropic
jmespath==1.0.1
# via boto3
# via botocore
joblib==1.4.2
# via scikit-learn
jsonpatch==1.33
# via langchain-core
jsonpointer==3.0.0
@ -189,6 +200,7 @@ langchain-core==0.2.22
# via langchain-google-genai
# via langchain-google-vertexai
# via langchain-groq
# via langchain-huggingface
# via langchain-nvidia-ai-endpoints
# via langchain-openai
# via langchain-text-splitters
@ -200,6 +212,8 @@ langchain-google-vertexai==1.0.7
# via scrapegraphai
langchain-groq==0.1.6
# via scrapegraphai
langchain-huggingface==0.0.3
# via scrapegraphai
langchain-nvidia-ai-endpoints==0.1.6
# via scrapegraphai
langchain-openai==0.1.17
@ -212,12 +226,16 @@ langsmith==0.1.93
# via langchain-core
lxml==5.2.2
# via free-proxy
markupsafe==2.1.5
# via jinja2
marshmallow==3.21.3
# via dataclasses-json
minify-html==0.15.0
# via scrapegraphai
mpire==2.10.2
# via semchunk
mpmath==1.3.0
# via sympy
multidict==6.0.5
# via aiohttp
# via yarl
@ -225,13 +243,19 @@ multiprocess==0.70.16
# via mpire
mypy-extensions==1.0.0
# via typing-inspect
networkx==3.2.1
# via torch
numpy==1.26.4
# via faiss-cpu
# via langchain
# via langchain-aws
# via langchain-community
# via pandas
# via scikit-learn
# via scipy
# via sentence-transformers
# via shapely
# via transformers
openai==1.37.0
# via langchain-fireworks
# via langchain-openai
@ -244,11 +268,13 @@ packaging==24.1
# via huggingface-hub
# via langchain-core
# via marshmallow
# via transformers
pandas==2.2.2
# via scrapegraphai
pillow==10.4.0
# via fireworks-ai
# via langchain-nvidia-ai-endpoints
# via sentence-transformers
playwright==1.45.0
# via scrapegraphai
# via undetected-playwright
@ -303,8 +329,10 @@ pyyaml==6.0.1
# via langchain
# via langchain-community
# via langchain-core
# via transformers
regex==2024.5.15
# via tiktoken
# via transformers
requests==2.32.3
# via free-proxy
# via google-api-core
@ -316,12 +344,22 @@ requests==2.32.3
# via langchain-fireworks
# via langsmith
# via tiktoken
# via transformers
rsa==4.9
# via google-auth
s3transfer==0.10.2
# via boto3
safetensors==0.4.3
# via transformers
scikit-learn==1.5.1
# via sentence-transformers
scipy==1.13.1
# via scikit-learn
# via sentence-transformers
semchunk==2.2.0
# via scrapegraphai
sentence-transformers==3.0.1
# via langchain-huggingface
shapely==2.0.5
# via google-cloud-aiplatform
six==1.16.0
@ -337,15 +375,23 @@ soupsieve==2.5
sqlalchemy==2.0.31
# via langchain
# via langchain-community
sympy==1.13.1
# via torch
tenacity==8.5.0
# via langchain
# via langchain-community
# via langchain-core
threadpoolctl==3.5.0
# via scikit-learn
tiktoken==0.7.0
# via langchain-openai
# via scrapegraphai
tokenizers==0.19.1
# via anthropic
# via langchain-huggingface
# via transformers
torch==2.2.2
# via sentence-transformers
tqdm==4.66.4
# via google-generativeai
# via huggingface-hub
@ -353,6 +399,11 @@ tqdm==4.66.4
# via openai
# via scrapegraphai
# via semchunk
# via sentence-transformers
# via transformers
transformers==4.43.3
# via langchain-huggingface
# via sentence-transformers
typing-extensions==4.12.2
# via anthropic
# via anyio
@ -364,6 +415,7 @@ typing-extensions==4.12.2
# via pydantic-core
# via pyee
# via sqlalchemy
# via torch
# via typing-inspect
typing-inspect==0.9.0
# via dataclasses-json

View File

@ -22,3 +22,4 @@ undetected-playwright>=0.3.0
semchunk>=1.0.1
langchain-fireworks>=0.1.3
langchain-community>=0.2.9
langchain-huggingface>=0.0.3

View File

@ -3,33 +3,28 @@ AbstractGraph Module
"""
from abc import ABC, abstractmethod
from typing import Optional, Union
from typing import Optional
import uuid
from pydantic import BaseModel
from langchain_community.chat_models import ChatOllama
from langchain_openai import ChatOpenAI
from langchain_aws import BedrockEmbeddings
from langchain_community.embeddings import HuggingFaceHubEmbeddings, OllamaEmbeddings
from langchain_aws import BedrockEmbeddings, ChatBedrock
from langchain_huggingface import ChatHuggingFace, HuggingFaceEmbeddings
from langchain_community.embeddings import OllamaEmbeddings
from langchain_google_genai import GoogleGenerativeAIEmbeddings
from langchain_google_vertexai import VertexAIEmbeddings
from langchain_google_vertexai import ChatVertexAI, VertexAIEmbeddings
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_google_genai.embeddings import GoogleGenerativeAIEmbeddings
from langchain_fireworks import FireworksEmbeddings
from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings
from langchain_fireworks import FireworksEmbeddings, ChatFireworks
from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings, ChatOpenAI, AzureChatOpenAI
from langchain_nvidia_ai_endpoints import NVIDIAEmbeddings
from ..helpers import models_tokens
from ..models import (
Anthropic,
AzureOpenAI,
Bedrock,
Gemini,
Groq,
HuggingFace,
OneApi,
Fireworks,
VertexAI,
Nvidia
Nvidia,
DeepSeek
)
from ..models.ernie import Ernie
from langchain.chat_models import init_chat_model
@ -37,7 +32,6 @@ from langchain.chat_models import init_chat_model
from ..utils.logging import set_verbosity_debug, set_verbosity_warning, set_verbosity_info
from ..helpers import models_tokens
from ..models import AzureOpenAI, Bedrock, Gemini, Groq, HuggingFace, Anthropic, DeepSeek
class AbstractGraph(ABC):
@ -181,7 +175,8 @@ class AbstractGraph(ABC):
llm_params["model"] = "/".join(llm_params["model"].split("/")[1:])
except KeyError as exc:
raise KeyError("Model not supported") from exc
return Fireworks(llm_params)
llm_params["model_provider"] = "fireworks"
return init_chat_model(**llm_params)
elif "azure" in llm_params["model"]:
# take the model after the last dash
llm_params["model"] = llm_params["model"].split("/")[-1]
@ -189,7 +184,8 @@ class AbstractGraph(ABC):
self.model_token = models_tokens["azure"][llm_params["model"]]
except KeyError as exc:
raise KeyError("Model not supported") from exc
return AzureOpenAI(llm_params)
llm_params["model_provider"] = "azure_openai"
return init_chat_model(**llm_params)
elif "nvidia" in llm_params["model"]:
try:
self.model_token = models_tokens["nvidia"][llm_params["model"].split("/")[-1]]
@ -203,20 +199,23 @@ class AbstractGraph(ABC):
self.model_token = models_tokens["gemini"][llm_params["model"]]
except KeyError as exc:
raise KeyError("Model not supported") from exc
return Gemini(llm_params)
llm_params["model_provider"] = "google_genai "
return init_chat_model(**llm_params)
elif llm_params["model"].startswith("claude"):
llm_params["model"] = llm_params["model"].split("/")[-1]
try:
self.model_token = models_tokens["claude"][llm_params["model"]]
except KeyError as exc:
raise KeyError("Model not supported") from exc
return Anthropic(llm_params)
llm_params["model_provider"] = "anthropic"
return init_chat_model(**llm_params)
elif llm_params["model"].startswith("vertexai"):
try:
self.model_token = models_tokens["vertexai"][llm_params["model"]]
except KeyError as exc:
raise KeyError("Model not supported") from exc
return VertexAI(llm_params)
llm_params["model_provider"] = "google_vertexai"
return init_chat_model(**llm_params)
elif "ollama" in llm_params["model"]:
llm_params["model"] = llm_params["model"].split("ollama/")[-1]
@ -246,7 +245,8 @@ class AbstractGraph(ABC):
except KeyError:
print("model not found, using default token size (8192)")
self.model_token = 8192
return HuggingFace(llm_params)
llm_params["model_provider"] = "hugging_face"
return init_chat_model(**llm_params)
elif "groq" in llm_params["model"]:
llm_params["model"] = llm_params["model"].split("/")[-1]
@ -255,7 +255,8 @@ class AbstractGraph(ABC):
except KeyError:
print("model not found, using default token size (8192)")
self.model_token = 8192
return Groq(llm_params)
llm_params["model_provider"] = "groq"
return init_chat_model(**llm_params)
elif "bedrock" in llm_params["model"]:
llm_params["model"] = llm_params["model"].split("/")[-1]
model_id = llm_params["model"]
@ -265,22 +266,16 @@ class AbstractGraph(ABC):
except KeyError:
print("model not found, using default token size (8192)")
self.model_token = 8192
return Bedrock(
{
"client": client,
"model_id": model_id,
"model_kwargs": {
"temperature": llm_params["temperature"],
},
}
)
llm_params["model_provider"] = "bedrock"
return init_chat_model(**llm_params)
elif "claude-3-" in llm_params["model"]:
try:
self.model_token = models_tokens["claude"]["claude3"]
except KeyError:
print("model not found, using default token size (8192)")
self.model_token = 8192
return Anthropic(llm_params)
llm_params["model_provider"] = "anthropic"
return init_chat_model(**llm_params)
elif "deepseek" in llm_params["model"]:
try:
self.model_token = models_tokens["deepseek"][llm_params["model"]]
@ -308,7 +303,7 @@ class AbstractGraph(ABC):
Raises:
ValueError: If the model is not supported.
"""
if isinstance(self.llm_model, Gemini):
if isinstance(self.llm_model, ChatGoogleGenerativeAI):
return GoogleGenerativeAIEmbeddings(
google_api_key=llm_config["api_key"], model="models/embedding-001"
)
@ -317,13 +312,13 @@ class AbstractGraph(ABC):
base_url=self.llm_model.openai_api_base)
elif isinstance(self.llm_model, DeepSeek):
return OpenAIEmbeddings(api_key=self.llm_model.openai_api_key)
elif isinstance(self.llm_model, VertexAI):
elif isinstance(self.llm_model, ChatVertexAI):
return VertexAIEmbeddings()
elif isinstance(self.llm_model, AzureOpenAIEmbeddings):
return self.llm_model
elif isinstance(self.llm_model, AzureOpenAI):
elif isinstance(self.llm_model, AzureChatOpenAI):
return AzureOpenAIEmbeddings()
elif isinstance(self.llm_model, Fireworks):
elif isinstance(self.llm_model, ChatFireworks):
return FireworksEmbeddings(model=self.llm_model.model_name)
elif isinstance(self.llm_model, Nvidia):
return NVIDIAEmbeddings(model=self.llm_model.model_name)
@ -335,9 +330,9 @@ class AbstractGraph(ABC):
params.pop("temperature", None)
return OllamaEmbeddings(**params)
elif isinstance(self.llm_model, HuggingFace):
return HuggingFaceHubEmbeddings(model=self.llm_model.model)
elif isinstance(self.llm_model, Bedrock):
elif isinstance(self.llm_model, ChatHuggingFace):
return HuggingFaceEmbeddings(model=self.llm_model.model)
elif isinstance(self.llm_model, ChatBedrock):
return BedrockEmbeddings(client=None, model_id=self.llm_model.model_id)
else:
raise ValueError("Embedding Model missing or not supported")
@ -384,7 +379,7 @@ class AbstractGraph(ABC):
models_tokens["hugging_face"][embedder_params["model"]]
except KeyError as exc:
raise KeyError("Model not supported") from exc
return HuggingFaceHubEmbeddings(model=embedder_params["model"])
return HuggingFaceEmbeddings(model=embedder_params["model"])
elif "fireworks" in embedder_params["model"]:
embedder_params["model"] = "/".join(embedder_params["model"].split("/")[1:])
try: