mirror of
https://github.com/VinciGit00/Scrapegraph-ai.git
synced 2026-07-04 21:00:36 +08:00
This commit is contained in:
parent
5007167af1
commit
9275486240
@ -34,7 +34,8 @@ dependencies = [
|
||||
"undetected-playwright>=0.3.0",
|
||||
"semchunk>=1.0.1",
|
||||
"langchain-fireworks>=0.1.3",
|
||||
"langchain-community>=0.2.9"
|
||||
"langchain-community>=0.2.9",
|
||||
"langchain-huggingface>=0.0.3",
|
||||
]
|
||||
|
||||
license = "MIT"
|
||||
|
||||
@ -106,6 +106,8 @@ fastapi-pagination==0.12.26
|
||||
# via burr
|
||||
filelock==3.15.4
|
||||
# via huggingface-hub
|
||||
# via torch
|
||||
# via transformers
|
||||
fireworks-ai==0.14.0
|
||||
# via langchain-fireworks
|
||||
fonttools==4.53.1
|
||||
@ -117,6 +119,7 @@ frozenlist==1.4.1
|
||||
# via aiosignal
|
||||
fsspec==2024.6.1
|
||||
# via huggingface-hub
|
||||
# via torch
|
||||
furo==2024.5.6
|
||||
# via scrapegraphai
|
||||
gitdb==4.0.11
|
||||
@ -180,6 +183,7 @@ graphviz==0.20.3
|
||||
# via scrapegraphai
|
||||
greenlet==3.0.3
|
||||
# via playwright
|
||||
# via sqlalchemy
|
||||
groq==0.9.0
|
||||
# via langchain-groq
|
||||
grpc-google-iam-v1==0.13.1
|
||||
@ -212,7 +216,10 @@ httpx==0.27.0
|
||||
httpx-sse==0.4.0
|
||||
# via fireworks-ai
|
||||
huggingface-hub==0.24.0
|
||||
# via langchain-huggingface
|
||||
# via sentence-transformers
|
||||
# via tokenizers
|
||||
# via transformers
|
||||
idna==3.7
|
||||
# via anyio
|
||||
# via email-validator
|
||||
@ -235,11 +242,14 @@ jinja2==3.1.4
|
||||
# via fastapi
|
||||
# via pydeck
|
||||
# via sphinx
|
||||
# via torch
|
||||
jiter==0.5.0
|
||||
# via anthropic
|
||||
jmespath==1.0.1
|
||||
# via boto3
|
||||
# via botocore
|
||||
joblib==1.4.2
|
||||
# via scikit-learn
|
||||
jsonpatch==1.33
|
||||
# via langchain-core
|
||||
jsonpointer==3.0.0
|
||||
@ -268,6 +278,7 @@ langchain-core==0.2.22
|
||||
# via langchain-google-genai
|
||||
# via langchain-google-vertexai
|
||||
# via langchain-groq
|
||||
# via langchain-huggingface
|
||||
# via langchain-nvidia-ai-endpoints
|
||||
# via langchain-openai
|
||||
# via langchain-text-splitters
|
||||
@ -279,6 +290,8 @@ langchain-google-vertexai==1.0.7
|
||||
# via scrapegraphai
|
||||
langchain-groq==0.1.6
|
||||
# via scrapegraphai
|
||||
langchain-huggingface==0.0.3
|
||||
# via scrapegraphai
|
||||
langchain-nvidia-ai-endpoints==0.1.6
|
||||
# via scrapegraphai
|
||||
langchain-openai==0.1.17
|
||||
@ -309,6 +322,8 @@ minify-html==0.15.0
|
||||
# via scrapegraphai
|
||||
mpire==2.10.2
|
||||
# via semchunk
|
||||
mpmath==1.3.0
|
||||
# via sympy
|
||||
multidict==6.0.5
|
||||
# via aiohttp
|
||||
# via yarl
|
||||
@ -316,6 +331,8 @@ multiprocess==0.70.16
|
||||
# via mpire
|
||||
mypy-extensions==1.0.0
|
||||
# via typing-inspect
|
||||
networkx==3.2.1
|
||||
# via torch
|
||||
numpy==1.26.4
|
||||
# via altair
|
||||
# via contourpy
|
||||
@ -327,9 +344,13 @@ numpy==1.26.4
|
||||
# via pandas
|
||||
# via pyarrow
|
||||
# via pydeck
|
||||
# via scikit-learn
|
||||
# via scipy
|
||||
# via sentence-transformers
|
||||
# via sf-hamilton
|
||||
# via shapely
|
||||
# via streamlit
|
||||
# via transformers
|
||||
openai==1.37.0
|
||||
# via burr
|
||||
# via langchain-fireworks
|
||||
@ -348,6 +369,7 @@ packaging==24.1
|
||||
# via pytest
|
||||
# via sphinx
|
||||
# via streamlit
|
||||
# via transformers
|
||||
pandas==2.2.2
|
||||
# via altair
|
||||
# via scrapegraphai
|
||||
@ -357,6 +379,7 @@ pillow==10.4.0
|
||||
# via fireworks-ai
|
||||
# via langchain-nvidia-ai-endpoints
|
||||
# via matplotlib
|
||||
# via sentence-transformers
|
||||
# via streamlit
|
||||
platformdirs==4.2.2
|
||||
# via pylint
|
||||
@ -436,12 +459,14 @@ pyyaml==6.0.1
|
||||
# via langchain
|
||||
# via langchain-community
|
||||
# via langchain-core
|
||||
# via transformers
|
||||
# via uvicorn
|
||||
referencing==0.35.1
|
||||
# via jsonschema
|
||||
# via jsonschema-specifications
|
||||
regex==2024.5.15
|
||||
# via tiktoken
|
||||
# via transformers
|
||||
requests==2.32.3
|
||||
# via burr
|
||||
# via free-proxy
|
||||
@ -456,6 +481,7 @@ requests==2.32.3
|
||||
# via sphinx
|
||||
# via streamlit
|
||||
# via tiktoken
|
||||
# via transformers
|
||||
rich==13.7.1
|
||||
# via streamlit
|
||||
# via typer
|
||||
@ -466,8 +492,17 @@ rsa==4.9
|
||||
# via google-auth
|
||||
s3transfer==0.10.2
|
||||
# via boto3
|
||||
safetensors==0.4.3
|
||||
# via transformers
|
||||
scikit-learn==1.5.1
|
||||
# via sentence-transformers
|
||||
scipy==1.13.1
|
||||
# via scikit-learn
|
||||
# via sentence-transformers
|
||||
semchunk==2.2.0
|
||||
# via scrapegraphai
|
||||
sentence-transformers==3.0.1
|
||||
# via langchain-huggingface
|
||||
sf-hamilton==1.72.1
|
||||
# via burr
|
||||
shapely==2.0.5
|
||||
@ -513,16 +548,22 @@ starlette==0.37.2
|
||||
# via fastapi
|
||||
streamlit==1.36.0
|
||||
# via burr
|
||||
sympy==1.13.1
|
||||
# via torch
|
||||
tenacity==8.5.0
|
||||
# via langchain
|
||||
# via langchain-community
|
||||
# via langchain-core
|
||||
# via streamlit
|
||||
threadpoolctl==3.5.0
|
||||
# via scikit-learn
|
||||
tiktoken==0.7.0
|
||||
# via langchain-openai
|
||||
# via scrapegraphai
|
||||
tokenizers==0.19.1
|
||||
# via anthropic
|
||||
# via langchain-huggingface
|
||||
# via transformers
|
||||
toml==0.10.2
|
||||
# via streamlit
|
||||
tomli==2.0.1
|
||||
@ -532,6 +573,8 @@ tomlkit==0.13.0
|
||||
# via pylint
|
||||
toolz==0.12.1
|
||||
# via altair
|
||||
torch==2.2.2
|
||||
# via sentence-transformers
|
||||
tornado==6.4.1
|
||||
# via streamlit
|
||||
tqdm==4.66.4
|
||||
@ -541,6 +584,11 @@ tqdm==4.66.4
|
||||
# via openai
|
||||
# via scrapegraphai
|
||||
# via semchunk
|
||||
# via sentence-transformers
|
||||
# via transformers
|
||||
transformers==4.43.3
|
||||
# via langchain-huggingface
|
||||
# via sentence-transformers
|
||||
typer==0.12.3
|
||||
# via fastapi-cli
|
||||
typing-extensions==4.12.2
|
||||
@ -562,6 +610,7 @@ typing-extensions==4.12.2
|
||||
# via sqlalchemy
|
||||
# via starlette
|
||||
# via streamlit
|
||||
# via torch
|
||||
# via typer
|
||||
# via typing-inspect
|
||||
# via uvicorn
|
||||
|
||||
@ -63,6 +63,8 @@ faiss-cpu==1.8.0.post1
|
||||
# via scrapegraphai
|
||||
filelock==3.15.4
|
||||
# via huggingface-hub
|
||||
# via torch
|
||||
# via transformers
|
||||
fireworks-ai==0.14.0
|
||||
# via langchain-fireworks
|
||||
free-proxy==1.1.1
|
||||
@ -72,6 +74,7 @@ frozenlist==1.4.1
|
||||
# via aiosignal
|
||||
fsspec==2024.6.1
|
||||
# via huggingface-hub
|
||||
# via torch
|
||||
google==3.0.0
|
||||
# via scrapegraphai
|
||||
google-ai-generativelanguage==0.6.6
|
||||
@ -128,6 +131,7 @@ graphviz==0.20.3
|
||||
# via scrapegraphai
|
||||
greenlet==3.0.3
|
||||
# via playwright
|
||||
# via sqlalchemy
|
||||
groq==0.9.0
|
||||
# via langchain-groq
|
||||
grpc-google-iam-v1==0.13.1
|
||||
@ -156,17 +160,24 @@ httpx==0.27.0
|
||||
httpx-sse==0.4.0
|
||||
# via fireworks-ai
|
||||
huggingface-hub==0.24.0
|
||||
# via langchain-huggingface
|
||||
# via sentence-transformers
|
||||
# via tokenizers
|
||||
# via transformers
|
||||
idna==3.7
|
||||
# via anyio
|
||||
# via httpx
|
||||
# via requests
|
||||
# via yarl
|
||||
jinja2==3.1.4
|
||||
# via torch
|
||||
jiter==0.5.0
|
||||
# via anthropic
|
||||
jmespath==1.0.1
|
||||
# via boto3
|
||||
# via botocore
|
||||
joblib==1.4.2
|
||||
# via scikit-learn
|
||||
jsonpatch==1.33
|
||||
# via langchain-core
|
||||
jsonpointer==3.0.0
|
||||
@ -189,6 +200,7 @@ langchain-core==0.2.22
|
||||
# via langchain-google-genai
|
||||
# via langchain-google-vertexai
|
||||
# via langchain-groq
|
||||
# via langchain-huggingface
|
||||
# via langchain-nvidia-ai-endpoints
|
||||
# via langchain-openai
|
||||
# via langchain-text-splitters
|
||||
@ -200,6 +212,8 @@ langchain-google-vertexai==1.0.7
|
||||
# via scrapegraphai
|
||||
langchain-groq==0.1.6
|
||||
# via scrapegraphai
|
||||
langchain-huggingface==0.0.3
|
||||
# via scrapegraphai
|
||||
langchain-nvidia-ai-endpoints==0.1.6
|
||||
# via scrapegraphai
|
||||
langchain-openai==0.1.17
|
||||
@ -212,12 +226,16 @@ langsmith==0.1.93
|
||||
# via langchain-core
|
||||
lxml==5.2.2
|
||||
# via free-proxy
|
||||
markupsafe==2.1.5
|
||||
# via jinja2
|
||||
marshmallow==3.21.3
|
||||
# via dataclasses-json
|
||||
minify-html==0.15.0
|
||||
# via scrapegraphai
|
||||
mpire==2.10.2
|
||||
# via semchunk
|
||||
mpmath==1.3.0
|
||||
# via sympy
|
||||
multidict==6.0.5
|
||||
# via aiohttp
|
||||
# via yarl
|
||||
@ -225,13 +243,19 @@ multiprocess==0.70.16
|
||||
# via mpire
|
||||
mypy-extensions==1.0.0
|
||||
# via typing-inspect
|
||||
networkx==3.2.1
|
||||
# via torch
|
||||
numpy==1.26.4
|
||||
# via faiss-cpu
|
||||
# via langchain
|
||||
# via langchain-aws
|
||||
# via langchain-community
|
||||
# via pandas
|
||||
# via scikit-learn
|
||||
# via scipy
|
||||
# via sentence-transformers
|
||||
# via shapely
|
||||
# via transformers
|
||||
openai==1.37.0
|
||||
# via langchain-fireworks
|
||||
# via langchain-openai
|
||||
@ -244,11 +268,13 @@ packaging==24.1
|
||||
# via huggingface-hub
|
||||
# via langchain-core
|
||||
# via marshmallow
|
||||
# via transformers
|
||||
pandas==2.2.2
|
||||
# via scrapegraphai
|
||||
pillow==10.4.0
|
||||
# via fireworks-ai
|
||||
# via langchain-nvidia-ai-endpoints
|
||||
# via sentence-transformers
|
||||
playwright==1.45.0
|
||||
# via scrapegraphai
|
||||
# via undetected-playwright
|
||||
@ -303,8 +329,10 @@ pyyaml==6.0.1
|
||||
# via langchain
|
||||
# via langchain-community
|
||||
# via langchain-core
|
||||
# via transformers
|
||||
regex==2024.5.15
|
||||
# via tiktoken
|
||||
# via transformers
|
||||
requests==2.32.3
|
||||
# via free-proxy
|
||||
# via google-api-core
|
||||
@ -316,12 +344,22 @@ requests==2.32.3
|
||||
# via langchain-fireworks
|
||||
# via langsmith
|
||||
# via tiktoken
|
||||
# via transformers
|
||||
rsa==4.9
|
||||
# via google-auth
|
||||
s3transfer==0.10.2
|
||||
# via boto3
|
||||
safetensors==0.4.3
|
||||
# via transformers
|
||||
scikit-learn==1.5.1
|
||||
# via sentence-transformers
|
||||
scipy==1.13.1
|
||||
# via scikit-learn
|
||||
# via sentence-transformers
|
||||
semchunk==2.2.0
|
||||
# via scrapegraphai
|
||||
sentence-transformers==3.0.1
|
||||
# via langchain-huggingface
|
||||
shapely==2.0.5
|
||||
# via google-cloud-aiplatform
|
||||
six==1.16.0
|
||||
@ -337,15 +375,23 @@ soupsieve==2.5
|
||||
sqlalchemy==2.0.31
|
||||
# via langchain
|
||||
# via langchain-community
|
||||
sympy==1.13.1
|
||||
# via torch
|
||||
tenacity==8.5.0
|
||||
# via langchain
|
||||
# via langchain-community
|
||||
# via langchain-core
|
||||
threadpoolctl==3.5.0
|
||||
# via scikit-learn
|
||||
tiktoken==0.7.0
|
||||
# via langchain-openai
|
||||
# via scrapegraphai
|
||||
tokenizers==0.19.1
|
||||
# via anthropic
|
||||
# via langchain-huggingface
|
||||
# via transformers
|
||||
torch==2.2.2
|
||||
# via sentence-transformers
|
||||
tqdm==4.66.4
|
||||
# via google-generativeai
|
||||
# via huggingface-hub
|
||||
@ -353,6 +399,11 @@ tqdm==4.66.4
|
||||
# via openai
|
||||
# via scrapegraphai
|
||||
# via semchunk
|
||||
# via sentence-transformers
|
||||
# via transformers
|
||||
transformers==4.43.3
|
||||
# via langchain-huggingface
|
||||
# via sentence-transformers
|
||||
typing-extensions==4.12.2
|
||||
# via anthropic
|
||||
# via anyio
|
||||
@ -364,6 +415,7 @@ typing-extensions==4.12.2
|
||||
# via pydantic-core
|
||||
# via pyee
|
||||
# via sqlalchemy
|
||||
# via torch
|
||||
# via typing-inspect
|
||||
typing-inspect==0.9.0
|
||||
# via dataclasses-json
|
||||
|
||||
@ -22,3 +22,4 @@ undetected-playwright>=0.3.0
|
||||
semchunk>=1.0.1
|
||||
langchain-fireworks>=0.1.3
|
||||
langchain-community>=0.2.9
|
||||
langchain-huggingface>=0.0.3
|
||||
|
||||
@ -3,33 +3,28 @@ AbstractGraph Module
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional, Union
|
||||
from typing import Optional
|
||||
import uuid
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langchain_community.chat_models import ChatOllama
|
||||
from langchain_openai import ChatOpenAI
|
||||
|
||||
from langchain_aws import BedrockEmbeddings
|
||||
from langchain_community.embeddings import HuggingFaceHubEmbeddings, OllamaEmbeddings
|
||||
from langchain_aws import BedrockEmbeddings, ChatBedrock
|
||||
from langchain_huggingface import ChatHuggingFace, HuggingFaceEmbeddings
|
||||
from langchain_community.embeddings import OllamaEmbeddings
|
||||
from langchain_google_genai import GoogleGenerativeAIEmbeddings
|
||||
from langchain_google_vertexai import VertexAIEmbeddings
|
||||
from langchain_google_vertexai import ChatVertexAI, VertexAIEmbeddings
|
||||
from langchain_google_genai import ChatGoogleGenerativeAI
|
||||
from langchain_google_genai.embeddings import GoogleGenerativeAIEmbeddings
|
||||
from langchain_fireworks import FireworksEmbeddings
|
||||
from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings
|
||||
from langchain_fireworks import FireworksEmbeddings, ChatFireworks
|
||||
from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings, ChatOpenAI, AzureChatOpenAI
|
||||
from langchain_nvidia_ai_endpoints import NVIDIAEmbeddings
|
||||
from ..helpers import models_tokens
|
||||
from ..models import (
|
||||
Anthropic,
|
||||
AzureOpenAI,
|
||||
Bedrock,
|
||||
Gemini,
|
||||
Groq,
|
||||
HuggingFace,
|
||||
OneApi,
|
||||
Fireworks,
|
||||
VertexAI,
|
||||
Nvidia
|
||||
Nvidia,
|
||||
DeepSeek
|
||||
)
|
||||
from ..models.ernie import Ernie
|
||||
from langchain.chat_models import init_chat_model
|
||||
@ -37,7 +32,6 @@ from langchain.chat_models import init_chat_model
|
||||
from ..utils.logging import set_verbosity_debug, set_verbosity_warning, set_verbosity_info
|
||||
|
||||
from ..helpers import models_tokens
|
||||
from ..models import AzureOpenAI, Bedrock, Gemini, Groq, HuggingFace, Anthropic, DeepSeek
|
||||
|
||||
|
||||
class AbstractGraph(ABC):
|
||||
@ -181,7 +175,8 @@ class AbstractGraph(ABC):
|
||||
llm_params["model"] = "/".join(llm_params["model"].split("/")[1:])
|
||||
except KeyError as exc:
|
||||
raise KeyError("Model not supported") from exc
|
||||
return Fireworks(llm_params)
|
||||
llm_params["model_provider"] = "fireworks"
|
||||
return init_chat_model(**llm_params)
|
||||
elif "azure" in llm_params["model"]:
|
||||
# take the model after the last dash
|
||||
llm_params["model"] = llm_params["model"].split("/")[-1]
|
||||
@ -189,7 +184,8 @@ class AbstractGraph(ABC):
|
||||
self.model_token = models_tokens["azure"][llm_params["model"]]
|
||||
except KeyError as exc:
|
||||
raise KeyError("Model not supported") from exc
|
||||
return AzureOpenAI(llm_params)
|
||||
llm_params["model_provider"] = "azure_openai"
|
||||
return init_chat_model(**llm_params)
|
||||
elif "nvidia" in llm_params["model"]:
|
||||
try:
|
||||
self.model_token = models_tokens["nvidia"][llm_params["model"].split("/")[-1]]
|
||||
@ -203,20 +199,23 @@ class AbstractGraph(ABC):
|
||||
self.model_token = models_tokens["gemini"][llm_params["model"]]
|
||||
except KeyError as exc:
|
||||
raise KeyError("Model not supported") from exc
|
||||
return Gemini(llm_params)
|
||||
llm_params["model_provider"] = "google_genai "
|
||||
return init_chat_model(**llm_params)
|
||||
elif llm_params["model"].startswith("claude"):
|
||||
llm_params["model"] = llm_params["model"].split("/")[-1]
|
||||
try:
|
||||
self.model_token = models_tokens["claude"][llm_params["model"]]
|
||||
except KeyError as exc:
|
||||
raise KeyError("Model not supported") from exc
|
||||
return Anthropic(llm_params)
|
||||
llm_params["model_provider"] = "anthropic"
|
||||
return init_chat_model(**llm_params)
|
||||
elif llm_params["model"].startswith("vertexai"):
|
||||
try:
|
||||
self.model_token = models_tokens["vertexai"][llm_params["model"]]
|
||||
except KeyError as exc:
|
||||
raise KeyError("Model not supported") from exc
|
||||
return VertexAI(llm_params)
|
||||
llm_params["model_provider"] = "google_vertexai"
|
||||
return init_chat_model(**llm_params)
|
||||
|
||||
elif "ollama" in llm_params["model"]:
|
||||
llm_params["model"] = llm_params["model"].split("ollama/")[-1]
|
||||
@ -246,7 +245,8 @@ class AbstractGraph(ABC):
|
||||
except KeyError:
|
||||
print("model not found, using default token size (8192)")
|
||||
self.model_token = 8192
|
||||
return HuggingFace(llm_params)
|
||||
llm_params["model_provider"] = "hugging_face"
|
||||
return init_chat_model(**llm_params)
|
||||
elif "groq" in llm_params["model"]:
|
||||
llm_params["model"] = llm_params["model"].split("/")[-1]
|
||||
|
||||
@ -255,7 +255,8 @@ class AbstractGraph(ABC):
|
||||
except KeyError:
|
||||
print("model not found, using default token size (8192)")
|
||||
self.model_token = 8192
|
||||
return Groq(llm_params)
|
||||
llm_params["model_provider"] = "groq"
|
||||
return init_chat_model(**llm_params)
|
||||
elif "bedrock" in llm_params["model"]:
|
||||
llm_params["model"] = llm_params["model"].split("/")[-1]
|
||||
model_id = llm_params["model"]
|
||||
@ -265,22 +266,16 @@ class AbstractGraph(ABC):
|
||||
except KeyError:
|
||||
print("model not found, using default token size (8192)")
|
||||
self.model_token = 8192
|
||||
return Bedrock(
|
||||
{
|
||||
"client": client,
|
||||
"model_id": model_id,
|
||||
"model_kwargs": {
|
||||
"temperature": llm_params["temperature"],
|
||||
},
|
||||
}
|
||||
)
|
||||
llm_params["model_provider"] = "bedrock"
|
||||
return init_chat_model(**llm_params)
|
||||
elif "claude-3-" in llm_params["model"]:
|
||||
try:
|
||||
self.model_token = models_tokens["claude"]["claude3"]
|
||||
except KeyError:
|
||||
print("model not found, using default token size (8192)")
|
||||
self.model_token = 8192
|
||||
return Anthropic(llm_params)
|
||||
llm_params["model_provider"] = "anthropic"
|
||||
return init_chat_model(**llm_params)
|
||||
elif "deepseek" in llm_params["model"]:
|
||||
try:
|
||||
self.model_token = models_tokens["deepseek"][llm_params["model"]]
|
||||
@ -308,7 +303,7 @@ class AbstractGraph(ABC):
|
||||
Raises:
|
||||
ValueError: If the model is not supported.
|
||||
"""
|
||||
if isinstance(self.llm_model, Gemini):
|
||||
if isinstance(self.llm_model, ChatGoogleGenerativeAI):
|
||||
return GoogleGenerativeAIEmbeddings(
|
||||
google_api_key=llm_config["api_key"], model="models/embedding-001"
|
||||
)
|
||||
@ -317,13 +312,13 @@ class AbstractGraph(ABC):
|
||||
base_url=self.llm_model.openai_api_base)
|
||||
elif isinstance(self.llm_model, DeepSeek):
|
||||
return OpenAIEmbeddings(api_key=self.llm_model.openai_api_key)
|
||||
elif isinstance(self.llm_model, VertexAI):
|
||||
elif isinstance(self.llm_model, ChatVertexAI):
|
||||
return VertexAIEmbeddings()
|
||||
elif isinstance(self.llm_model, AzureOpenAIEmbeddings):
|
||||
return self.llm_model
|
||||
elif isinstance(self.llm_model, AzureOpenAI):
|
||||
elif isinstance(self.llm_model, AzureChatOpenAI):
|
||||
return AzureOpenAIEmbeddings()
|
||||
elif isinstance(self.llm_model, Fireworks):
|
||||
elif isinstance(self.llm_model, ChatFireworks):
|
||||
return FireworksEmbeddings(model=self.llm_model.model_name)
|
||||
elif isinstance(self.llm_model, Nvidia):
|
||||
return NVIDIAEmbeddings(model=self.llm_model.model_name)
|
||||
@ -335,9 +330,9 @@ class AbstractGraph(ABC):
|
||||
params.pop("temperature", None)
|
||||
|
||||
return OllamaEmbeddings(**params)
|
||||
elif isinstance(self.llm_model, HuggingFace):
|
||||
return HuggingFaceHubEmbeddings(model=self.llm_model.model)
|
||||
elif isinstance(self.llm_model, Bedrock):
|
||||
elif isinstance(self.llm_model, ChatHuggingFace):
|
||||
return HuggingFaceEmbeddings(model=self.llm_model.model)
|
||||
elif isinstance(self.llm_model, ChatBedrock):
|
||||
return BedrockEmbeddings(client=None, model_id=self.llm_model.model_id)
|
||||
else:
|
||||
raise ValueError("Embedding Model missing or not supported")
|
||||
@ -384,7 +379,7 @@ class AbstractGraph(ABC):
|
||||
models_tokens["hugging_face"][embedder_params["model"]]
|
||||
except KeyError as exc:
|
||||
raise KeyError("Model not supported") from exc
|
||||
return HuggingFaceHubEmbeddings(model=embedder_params["model"])
|
||||
return HuggingFaceEmbeddings(model=embedder_params["model"])
|
||||
elif "fireworks" in embedder_params["model"]:
|
||||
embedder_params["model"] = "/".join(embedder_params["model"].split("/")[1:])
|
||||
try:
|
||||
|
||||
Loading…
Reference in New Issue
Block a user