From 927548624034b3c30eca60011d216720102d1815 Mon Sep 17 00:00:00 2001 From: Federico Aguzzi <62149513+f-aguzzi@users.noreply.github.com> Date: Mon, 29 Jul 2024 21:57:37 +0200 Subject: [PATCH] refactor: remove redundant LangChain wrappers --- pyproject.toml | 3 +- requirements-dev.lock | 49 ++++++++++++++++ requirements.lock | 52 +++++++++++++++++ requirements.txt | 1 + scrapegraphai/graphs/abstract_graph.py | 77 ++++++++++++-------------- 5 files changed, 140 insertions(+), 42 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index b7b0d55d..bee7b61d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,7 +34,8 @@ dependencies = [ "undetected-playwright>=0.3.0", "semchunk>=1.0.1", "langchain-fireworks>=0.1.3", - "langchain-community>=0.2.9" + "langchain-community>=0.2.9", + "langchain-huggingface>=0.0.3", ] license = "MIT" diff --git a/requirements-dev.lock b/requirements-dev.lock index 2c56f3db..0b3ef491 100644 --- a/requirements-dev.lock +++ b/requirements-dev.lock @@ -106,6 +106,8 @@ fastapi-pagination==0.12.26 # via burr filelock==3.15.4 # via huggingface-hub + # via torch + # via transformers fireworks-ai==0.14.0 # via langchain-fireworks fonttools==4.53.1 @@ -117,6 +119,7 @@ frozenlist==1.4.1 # via aiosignal fsspec==2024.6.1 # via huggingface-hub + # via torch furo==2024.5.6 # via scrapegraphai gitdb==4.0.11 @@ -180,6 +183,7 @@ graphviz==0.20.3 # via scrapegraphai greenlet==3.0.3 # via playwright + # via sqlalchemy groq==0.9.0 # via langchain-groq grpc-google-iam-v1==0.13.1 @@ -212,7 +216,10 @@ httpx==0.27.0 httpx-sse==0.4.0 # via fireworks-ai huggingface-hub==0.24.0 + # via langchain-huggingface + # via sentence-transformers # via tokenizers + # via transformers idna==3.7 # via anyio # via email-validator @@ -235,11 +242,14 @@ jinja2==3.1.4 # via fastapi # via pydeck # via sphinx + # via torch jiter==0.5.0 # via anthropic jmespath==1.0.1 # via boto3 # via botocore +joblib==1.4.2 + # via scikit-learn jsonpatch==1.33 # via langchain-core 
jsonpointer==3.0.0 @@ -268,6 +278,7 @@ langchain-core==0.2.22 # via langchain-google-genai # via langchain-google-vertexai # via langchain-groq + # via langchain-huggingface # via langchain-nvidia-ai-endpoints # via langchain-openai # via langchain-text-splitters @@ -279,6 +290,8 @@ langchain-google-vertexai==1.0.7 # via scrapegraphai langchain-groq==0.1.6 # via scrapegraphai +langchain-huggingface==0.0.3 + # via scrapegraphai langchain-nvidia-ai-endpoints==0.1.6 # via scrapegraphai langchain-openai==0.1.17 @@ -309,6 +322,8 @@ minify-html==0.15.0 # via scrapegraphai mpire==2.10.2 # via semchunk +mpmath==1.3.0 + # via sympy multidict==6.0.5 # via aiohttp # via yarl @@ -316,6 +331,8 @@ multiprocess==0.70.16 # via mpire mypy-extensions==1.0.0 # via typing-inspect +networkx==3.2.1 + # via torch numpy==1.26.4 # via altair # via contourpy @@ -327,9 +344,13 @@ numpy==1.26.4 # via pandas # via pyarrow # via pydeck + # via scikit-learn + # via scipy + # via sentence-transformers # via sf-hamilton # via shapely # via streamlit + # via transformers openai==1.37.0 # via burr # via langchain-fireworks @@ -348,6 +369,7 @@ packaging==24.1 # via pytest # via sphinx # via streamlit + # via transformers pandas==2.2.2 # via altair # via scrapegraphai @@ -357,6 +379,7 @@ pillow==10.4.0 # via fireworks-ai # via langchain-nvidia-ai-endpoints # via matplotlib + # via sentence-transformers # via streamlit platformdirs==4.2.2 # via pylint @@ -436,12 +459,14 @@ pyyaml==6.0.1 # via langchain # via langchain-community # via langchain-core + # via transformers # via uvicorn referencing==0.35.1 # via jsonschema # via jsonschema-specifications regex==2024.5.15 # via tiktoken + # via transformers requests==2.32.3 # via burr # via free-proxy @@ -456,6 +481,7 @@ requests==2.32.3 # via sphinx # via streamlit # via tiktoken + # via transformers rich==13.7.1 # via streamlit # via typer @@ -466,8 +492,17 @@ rsa==4.9 # via google-auth s3transfer==0.10.2 # via boto3 +safetensors==0.4.3 + # via 
transformers +scikit-learn==1.5.1 + # via sentence-transformers +scipy==1.13.1 + # via scikit-learn + # via sentence-transformers semchunk==2.2.0 # via scrapegraphai +sentence-transformers==3.0.1 + # via langchain-huggingface sf-hamilton==1.72.1 # via burr shapely==2.0.5 @@ -513,16 +548,22 @@ starlette==0.37.2 # via fastapi streamlit==1.36.0 # via burr +sympy==1.13.1 + # via torch tenacity==8.5.0 # via langchain # via langchain-community # via langchain-core # via streamlit +threadpoolctl==3.5.0 + # via scikit-learn tiktoken==0.7.0 # via langchain-openai # via scrapegraphai tokenizers==0.19.1 # via anthropic + # via langchain-huggingface + # via transformers toml==0.10.2 # via streamlit tomli==2.0.1 @@ -532,6 +573,8 @@ tomlkit==0.13.0 # via pylint toolz==0.12.1 # via altair +torch==2.2.2 + # via sentence-transformers tornado==6.4.1 # via streamlit tqdm==4.66.4 @@ -541,6 +584,11 @@ tqdm==4.66.4 # via openai # via scrapegraphai # via semchunk + # via sentence-transformers + # via transformers +transformers==4.43.3 + # via langchain-huggingface + # via sentence-transformers typer==0.12.3 # via fastapi-cli typing-extensions==4.12.2 @@ -562,6 +610,7 @@ typing-extensions==4.12.2 # via sqlalchemy # via starlette # via streamlit + # via torch # via typer # via typing-inspect # via uvicorn diff --git a/requirements.lock b/requirements.lock index a943dff1..a9df041e 100644 --- a/requirements.lock +++ b/requirements.lock @@ -63,6 +63,8 @@ faiss-cpu==1.8.0.post1 # via scrapegraphai filelock==3.15.4 # via huggingface-hub + # via torch + # via transformers fireworks-ai==0.14.0 # via langchain-fireworks free-proxy==1.1.1 @@ -72,6 +74,7 @@ frozenlist==1.4.1 # via aiosignal fsspec==2024.6.1 # via huggingface-hub + # via torch google==3.0.0 # via scrapegraphai google-ai-generativelanguage==0.6.6 @@ -128,6 +131,7 @@ graphviz==0.20.3 # via scrapegraphai greenlet==3.0.3 # via playwright + # via sqlalchemy groq==0.9.0 # via langchain-groq grpc-google-iam-v1==0.13.1 @@ -156,17 +160,24 @@ 
httpx==0.27.0 httpx-sse==0.4.0 # via fireworks-ai huggingface-hub==0.24.0 + # via langchain-huggingface + # via sentence-transformers # via tokenizers + # via transformers idna==3.7 # via anyio # via httpx # via requests # via yarl +jinja2==3.1.4 + # via torch jiter==0.5.0 # via anthropic jmespath==1.0.1 # via boto3 # via botocore +joblib==1.4.2 + # via scikit-learn jsonpatch==1.33 # via langchain-core jsonpointer==3.0.0 @@ -189,6 +200,7 @@ langchain-core==0.2.22 # via langchain-google-genai # via langchain-google-vertexai # via langchain-groq + # via langchain-huggingface # via langchain-nvidia-ai-endpoints # via langchain-openai # via langchain-text-splitters @@ -200,6 +212,8 @@ langchain-google-vertexai==1.0.7 # via scrapegraphai langchain-groq==0.1.6 # via scrapegraphai +langchain-huggingface==0.0.3 + # via scrapegraphai langchain-nvidia-ai-endpoints==0.1.6 # via scrapegraphai langchain-openai==0.1.17 @@ -212,12 +226,16 @@ langsmith==0.1.93 # via langchain-core lxml==5.2.2 # via free-proxy +markupsafe==2.1.5 + # via jinja2 marshmallow==3.21.3 # via dataclasses-json minify-html==0.15.0 # via scrapegraphai mpire==2.10.2 # via semchunk +mpmath==1.3.0 + # via sympy multidict==6.0.5 # via aiohttp # via yarl @@ -225,13 +243,19 @@ multiprocess==0.70.16 # via mpire mypy-extensions==1.0.0 # via typing-inspect +networkx==3.2.1 + # via torch numpy==1.26.4 # via faiss-cpu # via langchain # via langchain-aws # via langchain-community # via pandas + # via scikit-learn + # via scipy + # via sentence-transformers # via shapely + # via transformers openai==1.37.0 # via langchain-fireworks # via langchain-openai @@ -244,11 +268,13 @@ packaging==24.1 # via huggingface-hub # via langchain-core # via marshmallow + # via transformers pandas==2.2.2 # via scrapegraphai pillow==10.4.0 # via fireworks-ai # via langchain-nvidia-ai-endpoints + # via sentence-transformers playwright==1.45.0 # via scrapegraphai # via undetected-playwright @@ -303,8 +329,10 @@ pyyaml==6.0.1 # via langchain # 
via langchain-community # via langchain-core + # via transformers regex==2024.5.15 # via tiktoken + # via transformers requests==2.32.3 # via free-proxy # via google-api-core @@ -316,12 +344,22 @@ requests==2.32.3 # via langchain-fireworks # via langsmith # via tiktoken + # via transformers rsa==4.9 # via google-auth s3transfer==0.10.2 # via boto3 +safetensors==0.4.3 + # via transformers +scikit-learn==1.5.1 + # via sentence-transformers +scipy==1.13.1 + # via scikit-learn + # via sentence-transformers semchunk==2.2.0 # via scrapegraphai +sentence-transformers==3.0.1 + # via langchain-huggingface shapely==2.0.5 # via google-cloud-aiplatform six==1.16.0 @@ -337,15 +375,23 @@ soupsieve==2.5 sqlalchemy==2.0.31 # via langchain # via langchain-community +sympy==1.13.1 + # via torch tenacity==8.5.0 # via langchain # via langchain-community # via langchain-core +threadpoolctl==3.5.0 + # via scikit-learn tiktoken==0.7.0 # via langchain-openai # via scrapegraphai tokenizers==0.19.1 # via anthropic + # via langchain-huggingface + # via transformers +torch==2.2.2 + # via sentence-transformers tqdm==4.66.4 # via google-generativeai # via huggingface-hub @@ -353,6 +399,11 @@ tqdm==4.66.4 # via openai # via scrapegraphai # via semchunk + # via sentence-transformers + # via transformers +transformers==4.43.3 + # via langchain-huggingface + # via sentence-transformers typing-extensions==4.12.2 # via anthropic # via anyio @@ -364,6 +415,7 @@ typing-extensions==4.12.2 # via pydantic-core # via pyee # via sqlalchemy + # via torch # via typing-inspect typing-inspect==0.9.0 # via dataclasses-json diff --git a/requirements.txt b/requirements.txt index 440bf78a..8f3f5da5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -22,3 +22,4 @@ undetected-playwright>=0.3.0 semchunk>=1.0.1 langchain-fireworks>=0.1.3 langchain-community>=0.2.9 +langchain-huggingface>=0.0.3 diff --git a/scrapegraphai/graphs/abstract_graph.py b/scrapegraphai/graphs/abstract_graph.py index e1ce18f0..f27d1aee 
100644 --- a/scrapegraphai/graphs/abstract_graph.py +++ b/scrapegraphai/graphs/abstract_graph.py @@ -3,33 +3,28 @@ AbstractGraph Module """ from abc import ABC, abstractmethod -from typing import Optional, Union +from typing import Optional import uuid from pydantic import BaseModel from langchain_community.chat_models import ChatOllama from langchain_openai import ChatOpenAI -from langchain_aws import BedrockEmbeddings -from langchain_community.embeddings import HuggingFaceHubEmbeddings, OllamaEmbeddings +from langchain_aws import BedrockEmbeddings, ChatBedrock +from langchain_huggingface import ChatHuggingFace, HuggingFaceEmbeddings +from langchain_community.embeddings import OllamaEmbeddings from langchain_google_genai import GoogleGenerativeAIEmbeddings -from langchain_google_vertexai import VertexAIEmbeddings +from langchain_google_vertexai import ChatVertexAI, VertexAIEmbeddings +from langchain_google_genai import ChatGoogleGenerativeAI from langchain_google_genai.embeddings import GoogleGenerativeAIEmbeddings -from langchain_fireworks import FireworksEmbeddings -from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings +from langchain_fireworks import FireworksEmbeddings, ChatFireworks +from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings, ChatOpenAI, AzureChatOpenAI from langchain_nvidia_ai_endpoints import NVIDIAEmbeddings from ..helpers import models_tokens from ..models import ( - Anthropic, - AzureOpenAI, - Bedrock, - Gemini, - Groq, - HuggingFace, OneApi, - Fireworks, - VertexAI, - Nvidia + Nvidia, + DeepSeek ) from ..models.ernie import Ernie from langchain.chat_models import init_chat_model @@ -37,7 +32,6 @@ from langchain.chat_models import init_chat_model from ..utils.logging import set_verbosity_debug, set_verbosity_warning, set_verbosity_info from ..helpers import models_tokens -from ..models import AzureOpenAI, Bedrock, Gemini, Groq, HuggingFace, Anthropic, DeepSeek class AbstractGraph(ABC): @@ -181,7 +175,8 @@ class 
AbstractGraph(ABC): llm_params["model"] = "/".join(llm_params["model"].split("/")[1:]) except KeyError as exc: raise KeyError("Model not supported") from exc - return Fireworks(llm_params) + llm_params["model_provider"] = "fireworks" + return init_chat_model(**llm_params) elif "azure" in llm_params["model"]: # take the model after the last dash llm_params["model"] = llm_params["model"].split("/")[-1] @@ -189,7 +184,8 @@ class AbstractGraph(ABC): self.model_token = models_tokens["azure"][llm_params["model"]] except KeyError as exc: raise KeyError("Model not supported") from exc - return AzureOpenAI(llm_params) + llm_params["model_provider"] = "azure_openai" + return init_chat_model(**llm_params) elif "nvidia" in llm_params["model"]: try: self.model_token = models_tokens["nvidia"][llm_params["model"].split("/")[-1]] @@ -203,20 +199,23 @@ class AbstractGraph(ABC): self.model_token = models_tokens["gemini"][llm_params["model"]] except KeyError as exc: raise KeyError("Model not supported") from exc - return Gemini(llm_params) + llm_params["model_provider"] = "google_genai" + return init_chat_model(**llm_params) elif llm_params["model"].startswith("claude"): llm_params["model"] = llm_params["model"].split("/")[-1] try: self.model_token = models_tokens["claude"][llm_params["model"]] except KeyError as exc: raise KeyError("Model not supported") from exc - return Anthropic(llm_params) + llm_params["model_provider"] = "anthropic" + return init_chat_model(**llm_params) elif llm_params["model"].startswith("vertexai"): try: self.model_token = models_tokens["vertexai"][llm_params["model"]] except KeyError as exc: raise KeyError("Model not supported") from exc - return VertexAI(llm_params) + llm_params["model_provider"] = "google_vertexai" + return init_chat_model(**llm_params) elif "ollama" in llm_params["model"]: llm_params["model"] = llm_params["model"].split("ollama/")[-1] @@ -246,7 +245,8 @@ class AbstractGraph(ABC): except KeyError: print("model not found, using default 
token size (8192)") self.model_token = 8192 - return HuggingFace(llm_params) + llm_params["model_provider"] = "huggingface" + return init_chat_model(**llm_params) elif "groq" in llm_params["model"]: llm_params["model"] = llm_params["model"].split("/")[-1] @@ -255,7 +255,8 @@ class AbstractGraph(ABC): except KeyError: print("model not found, using default token size (8192)") self.model_token = 8192 - return Groq(llm_params) + llm_params["model_provider"] = "groq" + return init_chat_model(**llm_params) elif "bedrock" in llm_params["model"]: llm_params["model"] = llm_params["model"].split("/")[-1] model_id = llm_params["model"] @@ -265,22 +266,16 @@ class AbstractGraph(ABC): except KeyError: print("model not found, using default token size (8192)") self.model_token = 8192 - return Bedrock( - { - "client": client, - "model_id": model_id, - "model_kwargs": { - "temperature": llm_params["temperature"], - }, - } - ) + llm_params["model_provider"] = "bedrock" + return init_chat_model(**llm_params) elif "claude-3-" in llm_params["model"]: try: self.model_token = models_tokens["claude"]["claude3"] except KeyError: print("model not found, using default token size (8192)") self.model_token = 8192 - return Anthropic(llm_params) + llm_params["model_provider"] = "anthropic" + return init_chat_model(**llm_params) elif "deepseek" in llm_params["model"]: try: self.model_token = models_tokens["deepseek"][llm_params["model"]] @@ -308,7 +303,7 @@ class AbstractGraph(ABC): Raises: ValueError: If the model is not supported. 
""" - if isinstance(self.llm_model, Gemini): + if isinstance(self.llm_model, ChatGoogleGenerativeAI): return GoogleGenerativeAIEmbeddings( google_api_key=llm_config["api_key"], model="models/embedding-001" ) @@ -317,13 +312,13 @@ class AbstractGraph(ABC): base_url=self.llm_model.openai_api_base) elif isinstance(self.llm_model, DeepSeek): return OpenAIEmbeddings(api_key=self.llm_model.openai_api_key) - elif isinstance(self.llm_model, VertexAI): + elif isinstance(self.llm_model, ChatVertexAI): return VertexAIEmbeddings() elif isinstance(self.llm_model, AzureOpenAIEmbeddings): return self.llm_model - elif isinstance(self.llm_model, AzureOpenAI): + elif isinstance(self.llm_model, AzureChatOpenAI): return AzureOpenAIEmbeddings() - elif isinstance(self.llm_model, Fireworks): + elif isinstance(self.llm_model, ChatFireworks): return FireworksEmbeddings(model=self.llm_model.model_name) elif isinstance(self.llm_model, Nvidia): return NVIDIAEmbeddings(model=self.llm_model.model_name) @@ -335,9 +330,9 @@ class AbstractGraph(ABC): params.pop("temperature", None) return OllamaEmbeddings(**params) - elif isinstance(self.llm_model, HuggingFace): - return HuggingFaceHubEmbeddings(model=self.llm_model.model) - elif isinstance(self.llm_model, Bedrock): + elif isinstance(self.llm_model, ChatHuggingFace): + return HuggingFaceEmbeddings(model=self.llm_model.model) + elif isinstance(self.llm_model, ChatBedrock): return BedrockEmbeddings(client=None, model_id=self.llm_model.model_id) else: raise ValueError("Embedding Model missing or not supported") @@ -384,7 +379,7 @@ class AbstractGraph(ABC): models_tokens["hugging_face"][embedder_params["model"]] except KeyError as exc: raise KeyError("Model not supported") from exc - return HuggingFaceHubEmbeddings(model=embedder_params["model"]) + return HuggingFaceEmbeddings(model=embedder_params["model"]) elif "fireworks" in embedder_params["model"]: embedder_params["model"] = "/".join(embedder_params["model"].split("/")[1:]) try: