mirror of
https://github.com/VinciGit00/Scrapegraph-ai.git
synced 2026-07-04 21:00:36 +08:00
refactor(Ollama): integrate new LangChain chat init
This commit is contained in:
parent
ecc5e35d5f
commit
d177afb68b
@ -12,6 +12,7 @@ aiofiles==24.1.0
|
||||
# via burr
|
||||
aiohttp==3.9.5
|
||||
# via langchain
|
||||
# via langchain-community
|
||||
# via langchain-fireworks
|
||||
# via langchain-nvidia-ai-endpoints
|
||||
aiosignal==1.3.1
|
||||
@ -179,6 +180,7 @@ graphviz==0.20.3
|
||||
# via scrapegraphai
|
||||
greenlet==3.0.3
|
||||
# via playwright
|
||||
# via sqlalchemy
|
||||
groq==0.9.0
|
||||
# via langchain-groq
|
||||
grpc-google-iam-v1==0.13.1
|
||||
@ -262,6 +264,7 @@ langchain-core==0.2.22
|
||||
# via langchain
|
||||
# via langchain-anthropic
|
||||
# via langchain-aws
|
||||
# via langchain-community
|
||||
# via langchain-fireworks
|
||||
# via langchain-google-genai
|
||||
# via langchain-google-vertexai
|
||||
@ -285,6 +288,7 @@ langchain-text-splitters==0.2.2
|
||||
# via langchain
|
||||
langsmith==0.1.93
|
||||
# via langchain
|
||||
# via langchain-community
|
||||
# via langchain-core
|
||||
loguru==0.7.2
|
||||
# via burr
|
||||
@ -319,6 +323,7 @@ numpy==1.26.4
|
||||
# via faiss-cpu
|
||||
# via langchain
|
||||
# via langchain-aws
|
||||
# via langchain-community
|
||||
# via matplotlib
|
||||
# via pandas
|
||||
# via pyarrow
|
||||
@ -339,6 +344,7 @@ packaging==24.1
|
||||
# via google-cloud-bigquery
|
||||
# via huggingface-hub
|
||||
# via langchain-core
|
||||
# via marshmallow
|
||||
# via matplotlib
|
||||
# via pytest
|
||||
# via sphinx
|
||||
@ -429,6 +435,7 @@ pytz==2024.1
|
||||
pyyaml==6.0.1
|
||||
# via huggingface-hub
|
||||
# via langchain
|
||||
# via langchain-community
|
||||
# via langchain-core
|
||||
# via uvicorn
|
||||
referencing==0.35.1
|
||||
@ -444,6 +451,7 @@ requests==2.32.3
|
||||
# via google-cloud-storage
|
||||
# via huggingface-hub
|
||||
# via langchain
|
||||
# via langchain-community
|
||||
# via langchain-fireworks
|
||||
# via langsmith
|
||||
# via sphinx
|
||||
@ -501,12 +509,14 @@ sphinxcontrib-serializinghtml==1.1.10
|
||||
# via sphinx
|
||||
sqlalchemy==2.0.31
|
||||
# via langchain
|
||||
# via langchain-community
|
||||
starlette==0.37.2
|
||||
# via fastapi
|
||||
streamlit==1.36.0
|
||||
# via burr
|
||||
tenacity==8.5.0
|
||||
# via langchain
|
||||
# via langchain-community
|
||||
# via langchain-core
|
||||
# via streamlit
|
||||
tiktoken==0.7.0
|
||||
@ -557,6 +567,7 @@ typing-extensions==4.12.2
|
||||
# via typing-inspect
|
||||
# via uvicorn
|
||||
typing-inspect==0.9.0
|
||||
# via dataclasses-json
|
||||
# via sf-hamilton
|
||||
tzdata==2024.1
|
||||
# via pandas
|
||||
|
||||
@ -10,6 +10,7 @@
|
||||
-e file:.
|
||||
aiohttp==3.9.5
|
||||
# via langchain
|
||||
# via langchain-community
|
||||
# via langchain-fireworks
|
||||
# via langchain-nvidia-ai-endpoints
|
||||
aiosignal==1.3.1
|
||||
@ -127,6 +128,7 @@ graphviz==0.20.3
|
||||
# via scrapegraphai
|
||||
greenlet==3.0.3
|
||||
# via playwright
|
||||
# via sqlalchemy
|
||||
groq==0.9.0
|
||||
# via langchain-groq
|
||||
grpc-google-iam-v1==0.13.1
|
||||
@ -183,6 +185,7 @@ langchain-core==0.2.22
|
||||
# via langchain
|
||||
# via langchain-anthropic
|
||||
# via langchain-aws
|
||||
# via langchain-community
|
||||
# via langchain-fireworks
|
||||
# via langchain-google-genai
|
||||
# via langchain-google-vertexai
|
||||
@ -206,6 +209,7 @@ langchain-text-splitters==0.2.2
|
||||
# via langchain
|
||||
langsmith==0.1.93
|
||||
# via langchain
|
||||
# via langchain-community
|
||||
# via langchain-core
|
||||
lxml==5.2.2
|
||||
# via free-proxy
|
||||
@ -226,6 +230,7 @@ numpy==1.26.4
|
||||
# via faiss-cpu
|
||||
# via langchain
|
||||
# via langchain-aws
|
||||
# via langchain-community
|
||||
# via pandas
|
||||
# via shapely
|
||||
openai==1.37.0
|
||||
@ -239,6 +244,7 @@ packaging==24.1
|
||||
# via google-cloud-bigquery
|
||||
# via huggingface-hub
|
||||
# via langchain-core
|
||||
# via marshmallow
|
||||
pandas==2.2.2
|
||||
# via scrapegraphai
|
||||
pillow==10.4.0
|
||||
@ -296,6 +302,7 @@ pytz==2024.1
|
||||
pyyaml==6.0.1
|
||||
# via huggingface-hub
|
||||
# via langchain
|
||||
# via langchain-community
|
||||
# via langchain-core
|
||||
regex==2024.5.15
|
||||
# via tiktoken
|
||||
@ -306,6 +313,7 @@ requests==2.32.3
|
||||
# via google-cloud-storage
|
||||
# via huggingface-hub
|
||||
# via langchain
|
||||
# via langchain-community
|
||||
# via langchain-fireworks
|
||||
# via langsmith
|
||||
# via tiktoken
|
||||
@ -332,6 +340,7 @@ sqlalchemy==2.0.31
|
||||
# via langchain-community
|
||||
tenacity==8.5.0
|
||||
# via langchain
|
||||
# via langchain-community
|
||||
# via langchain-core
|
||||
tiktoken==0.7.0
|
||||
# via langchain-openai
|
||||
@ -356,6 +365,9 @@ typing-extensions==4.12.2
|
||||
# via pydantic-core
|
||||
# via pyee
|
||||
# via sqlalchemy
|
||||
# via typing-inspect
|
||||
typing-inspect==0.9.0
|
||||
# via dataclasses-json
|
||||
tzdata==2024.1
|
||||
# via pandas
|
||||
undetected-playwright==0.3.0
|
||||
|
||||
@ -7,6 +7,8 @@ from typing import Optional, Union
|
||||
import uuid
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langchain_community.chat_models import ChatOllama
|
||||
|
||||
from langchain_aws import BedrockEmbeddings
|
||||
from langchain_community.embeddings import HuggingFaceHubEmbeddings, OllamaEmbeddings
|
||||
from langchain_google_genai import GoogleGenerativeAIEmbeddings
|
||||
@ -19,22 +21,23 @@ from ..helpers import models_tokens
|
||||
from ..models import (
|
||||
Anthropic,
|
||||
AzureOpenAI,
|
||||
OpenAI,
|
||||
Bedrock,
|
||||
Gemini,
|
||||
Groq,
|
||||
HuggingFace,
|
||||
Ollama,
|
||||
OpenAI,
|
||||
OneApi,
|
||||
Fireworks,
|
||||
VertexAI,
|
||||
Nvidia
|
||||
)
|
||||
from ..models.ernie import Ernie
|
||||
from langchain.chat_models import init_chat_model
|
||||
|
||||
from ..utils.logging import set_verbosity_debug, set_verbosity_warning, set_verbosity_info
|
||||
|
||||
from ..helpers import models_tokens
|
||||
from ..models import AzureOpenAI, Bedrock, Gemini, Groq, HuggingFace, Ollama, OpenAI, Anthropic, DeepSeek
|
||||
from ..models import AzureOpenAI, OpenAI, Bedrock, Gemini, Groq, HuggingFace, Anthropic, DeepSeek
|
||||
|
||||
|
||||
class AbstractGraph(ABC):
|
||||
@ -213,8 +216,10 @@ class AbstractGraph(ABC):
|
||||
except KeyError as exc:
|
||||
raise KeyError("Model not supported") from exc
|
||||
return VertexAI(llm_params)
|
||||
|
||||
elif "ollama" in llm_params["model"]:
|
||||
llm_params["model"] = llm_params["model"].split("ollama/")[-1]
|
||||
llm_params["model_provider"] = "ollama"
|
||||
|
||||
# allow user to set model_tokens in config
|
||||
try:
|
||||
@ -231,7 +236,8 @@ class AbstractGraph(ABC):
|
||||
except AttributeError:
|
||||
self.model_token = 8192
|
||||
|
||||
return Ollama(llm_params)
|
||||
return init_chat_model(**llm_params)
|
||||
|
||||
elif "hugging_face" in llm_params["model"]:
|
||||
llm_params["model"] = llm_params["model"].split("/")[-1]
|
||||
try:
|
||||
@ -320,7 +326,7 @@ class AbstractGraph(ABC):
|
||||
return FireworksEmbeddings(model=self.llm_model.model_name)
|
||||
elif isinstance(self.llm_model, Nvidia):
|
||||
return NVIDIAEmbeddings(model=self.llm_model.model_name)
|
||||
elif isinstance(self.llm_model, Ollama):
|
||||
elif isinstance(self.llm_model, ChatOllama):
|
||||
# unwrap the kwargs from the model whihc is a dict
|
||||
params = self.llm_model._lc_kwargs
|
||||
# remove streaming and temperature
|
||||
|
||||
@ -1,17 +0,0 @@
|
||||
"""
|
||||
Ollama Module
|
||||
"""
|
||||
from langchain_community.chat_models import ChatOllama
|
||||
|
||||
|
||||
class Ollama(ChatOllama):
|
||||
"""
|
||||
A wrapper for the ChatOllama class that provides default configuration
|
||||
and could be extended with additional methods if needed.
|
||||
|
||||
Args:
|
||||
llm_config (dict): Configuration parameters for the language model.
|
||||
"""
|
||||
|
||||
def __init__(self, llm_config: dict):
|
||||
super().__init__(**llm_config)
|
||||
Loading…
Reference in New Issue
Block a user