refactor(Ollama): integrate new LangChain chat init

This commit is contained in:
Federico Aguzzi 2024-07-29 11:17:47 +02:00
parent ecc5e35d5f
commit d177afb68b
4 changed files with 34 additions and 22 deletions

View File

@ -12,6 +12,7 @@ aiofiles==24.1.0
# via burr # via burr
aiohttp==3.9.5 aiohttp==3.9.5
# via langchain # via langchain
# via langchain-community
# via langchain-fireworks # via langchain-fireworks
# via langchain-nvidia-ai-endpoints # via langchain-nvidia-ai-endpoints
aiosignal==1.3.1 aiosignal==1.3.1
@ -179,6 +180,7 @@ graphviz==0.20.3
# via scrapegraphai # via scrapegraphai
greenlet==3.0.3 greenlet==3.0.3
# via playwright # via playwright
# via sqlalchemy
groq==0.9.0 groq==0.9.0
# via langchain-groq # via langchain-groq
grpc-google-iam-v1==0.13.1 grpc-google-iam-v1==0.13.1
@ -262,6 +264,7 @@ langchain-core==0.2.22
# via langchain # via langchain
# via langchain-anthropic # via langchain-anthropic
# via langchain-aws # via langchain-aws
# via langchain-community
# via langchain-fireworks # via langchain-fireworks
# via langchain-google-genai # via langchain-google-genai
# via langchain-google-vertexai # via langchain-google-vertexai
@ -285,6 +288,7 @@ langchain-text-splitters==0.2.2
# via langchain # via langchain
langsmith==0.1.93 langsmith==0.1.93
# via langchain # via langchain
# via langchain-community
# via langchain-core # via langchain-core
loguru==0.7.2 loguru==0.7.2
# via burr # via burr
@ -319,6 +323,7 @@ numpy==1.26.4
# via faiss-cpu # via faiss-cpu
# via langchain # via langchain
# via langchain-aws # via langchain-aws
# via langchain-community
# via matplotlib # via matplotlib
# via pandas # via pandas
# via pyarrow # via pyarrow
@ -339,6 +344,7 @@ packaging==24.1
# via google-cloud-bigquery # via google-cloud-bigquery
# via huggingface-hub # via huggingface-hub
# via langchain-core # via langchain-core
# via marshmallow
# via matplotlib # via matplotlib
# via pytest # via pytest
# via sphinx # via sphinx
@ -429,6 +435,7 @@ pytz==2024.1
pyyaml==6.0.1 pyyaml==6.0.1
# via huggingface-hub # via huggingface-hub
# via langchain # via langchain
# via langchain-community
# via langchain-core # via langchain-core
# via uvicorn # via uvicorn
referencing==0.35.1 referencing==0.35.1
@ -444,6 +451,7 @@ requests==2.32.3
# via google-cloud-storage # via google-cloud-storage
# via huggingface-hub # via huggingface-hub
# via langchain # via langchain
# via langchain-community
# via langchain-fireworks # via langchain-fireworks
# via langsmith # via langsmith
# via sphinx # via sphinx
@ -501,12 +509,14 @@ sphinxcontrib-serializinghtml==1.1.10
# via sphinx # via sphinx
sqlalchemy==2.0.31 sqlalchemy==2.0.31
# via langchain # via langchain
# via langchain-community
starlette==0.37.2 starlette==0.37.2
# via fastapi # via fastapi
streamlit==1.36.0 streamlit==1.36.0
# via burr # via burr
tenacity==8.5.0 tenacity==8.5.0
# via langchain # via langchain
# via langchain-community
# via langchain-core # via langchain-core
# via streamlit # via streamlit
tiktoken==0.7.0 tiktoken==0.7.0
@ -557,6 +567,7 @@ typing-extensions==4.12.2
# via typing-inspect # via typing-inspect
# via uvicorn # via uvicorn
typing-inspect==0.9.0 typing-inspect==0.9.0
# via dataclasses-json
# via sf-hamilton # via sf-hamilton
tzdata==2024.1 tzdata==2024.1
# via pandas # via pandas

View File

@ -10,6 +10,7 @@
-e file:. -e file:.
aiohttp==3.9.5 aiohttp==3.9.5
# via langchain # via langchain
# via langchain-community
# via langchain-fireworks # via langchain-fireworks
# via langchain-nvidia-ai-endpoints # via langchain-nvidia-ai-endpoints
aiosignal==1.3.1 aiosignal==1.3.1
@ -127,6 +128,7 @@ graphviz==0.20.3
# via scrapegraphai # via scrapegraphai
greenlet==3.0.3 greenlet==3.0.3
# via playwright # via playwright
# via sqlalchemy
groq==0.9.0 groq==0.9.0
# via langchain-groq # via langchain-groq
grpc-google-iam-v1==0.13.1 grpc-google-iam-v1==0.13.1
@ -183,6 +185,7 @@ langchain-core==0.2.22
# via langchain # via langchain
# via langchain-anthropic # via langchain-anthropic
# via langchain-aws # via langchain-aws
# via langchain-community
# via langchain-fireworks # via langchain-fireworks
# via langchain-google-genai # via langchain-google-genai
# via langchain-google-vertexai # via langchain-google-vertexai
@ -206,6 +209,7 @@ langchain-text-splitters==0.2.2
# via langchain # via langchain
langsmith==0.1.93 langsmith==0.1.93
# via langchain # via langchain
# via langchain-community
# via langchain-core # via langchain-core
lxml==5.2.2 lxml==5.2.2
# via free-proxy # via free-proxy
@ -226,6 +230,7 @@ numpy==1.26.4
# via faiss-cpu # via faiss-cpu
# via langchain # via langchain
# via langchain-aws # via langchain-aws
# via langchain-community
# via pandas # via pandas
# via shapely # via shapely
openai==1.37.0 openai==1.37.0
@ -239,6 +244,7 @@ packaging==24.1
# via google-cloud-bigquery # via google-cloud-bigquery
# via huggingface-hub # via huggingface-hub
# via langchain-core # via langchain-core
# via marshmallow
pandas==2.2.2 pandas==2.2.2
# via scrapegraphai # via scrapegraphai
pillow==10.4.0 pillow==10.4.0
@ -296,6 +302,7 @@ pytz==2024.1
pyyaml==6.0.1 pyyaml==6.0.1
# via huggingface-hub # via huggingface-hub
# via langchain # via langchain
# via langchain-community
# via langchain-core # via langchain-core
regex==2024.5.15 regex==2024.5.15
# via tiktoken # via tiktoken
@ -306,6 +313,7 @@ requests==2.32.3
# via google-cloud-storage # via google-cloud-storage
# via huggingface-hub # via huggingface-hub
# via langchain # via langchain
# via langchain-community
# via langchain-fireworks # via langchain-fireworks
# via langsmith # via langsmith
# via tiktoken # via tiktoken
@ -332,6 +340,7 @@ sqlalchemy==2.0.31
# via langchain-community # via langchain-community
tenacity==8.5.0 tenacity==8.5.0
# via langchain # via langchain
# via langchain-community
# via langchain-core # via langchain-core
tiktoken==0.7.0 tiktoken==0.7.0
# via langchain-openai # via langchain-openai
@ -356,6 +365,9 @@ typing-extensions==4.12.2
# via pydantic-core # via pydantic-core
# via pyee # via pyee
# via sqlalchemy # via sqlalchemy
# via typing-inspect
typing-inspect==0.9.0
# via dataclasses-json
tzdata==2024.1 tzdata==2024.1
# via pandas # via pandas
undetected-playwright==0.3.0 undetected-playwright==0.3.0

View File

@ -7,6 +7,8 @@ from typing import Optional, Union
import uuid import uuid
from pydantic import BaseModel from pydantic import BaseModel
from langchain_community.chat_models import ChatOllama
from langchain_aws import BedrockEmbeddings from langchain_aws import BedrockEmbeddings
from langchain_community.embeddings import HuggingFaceHubEmbeddings, OllamaEmbeddings from langchain_community.embeddings import HuggingFaceHubEmbeddings, OllamaEmbeddings
from langchain_google_genai import GoogleGenerativeAIEmbeddings from langchain_google_genai import GoogleGenerativeAIEmbeddings
@ -19,22 +21,23 @@ from ..helpers import models_tokens
from ..models import ( from ..models import (
Anthropic, Anthropic,
AzureOpenAI, AzureOpenAI,
OpenAI,
Bedrock, Bedrock,
Gemini, Gemini,
Groq, Groq,
HuggingFace, HuggingFace,
Ollama,
OpenAI,
OneApi, OneApi,
Fireworks, Fireworks,
VertexAI, VertexAI,
Nvidia Nvidia
) )
from ..models.ernie import Ernie from ..models.ernie import Ernie
from langchain.chat_models import init_chat_model
from ..utils.logging import set_verbosity_debug, set_verbosity_warning, set_verbosity_info from ..utils.logging import set_verbosity_debug, set_verbosity_warning, set_verbosity_info
from ..helpers import models_tokens from ..helpers import models_tokens
from ..models import AzureOpenAI, Bedrock, Gemini, Groq, HuggingFace, Ollama, OpenAI, Anthropic, DeepSeek from ..models import AzureOpenAI, OpenAI, Bedrock, Gemini, Groq, HuggingFace, Anthropic, DeepSeek
class AbstractGraph(ABC): class AbstractGraph(ABC):
@ -213,8 +216,10 @@ class AbstractGraph(ABC):
except KeyError as exc: except KeyError as exc:
raise KeyError("Model not supported") from exc raise KeyError("Model not supported") from exc
return VertexAI(llm_params) return VertexAI(llm_params)
elif "ollama" in llm_params["model"]: elif "ollama" in llm_params["model"]:
llm_params["model"] = llm_params["model"].split("ollama/")[-1] llm_params["model"] = llm_params["model"].split("ollama/")[-1]
llm_params["model_provider"] = "ollama"
# allow user to set model_tokens in config # allow user to set model_tokens in config
try: try:
@ -231,7 +236,8 @@ class AbstractGraph(ABC):
except AttributeError: except AttributeError:
self.model_token = 8192 self.model_token = 8192
return Ollama(llm_params) return init_chat_model(**llm_params)
elif "hugging_face" in llm_params["model"]: elif "hugging_face" in llm_params["model"]:
llm_params["model"] = llm_params["model"].split("/")[-1] llm_params["model"] = llm_params["model"].split("/")[-1]
try: try:
@ -320,7 +326,7 @@ class AbstractGraph(ABC):
return FireworksEmbeddings(model=self.llm_model.model_name) return FireworksEmbeddings(model=self.llm_model.model_name)
elif isinstance(self.llm_model, Nvidia): elif isinstance(self.llm_model, Nvidia):
return NVIDIAEmbeddings(model=self.llm_model.model_name) return NVIDIAEmbeddings(model=self.llm_model.model_name)
elif isinstance(self.llm_model, Ollama): elif isinstance(self.llm_model, ChatOllama):
# unwrap the kwargs from the model whihc is a dict # unwrap the kwargs from the model whihc is a dict
params = self.llm_model._lc_kwargs params = self.llm_model._lc_kwargs
# remove streaming and temperature # remove streaming and temperature

View File

@ -1,17 +0,0 @@
"""
Ollama Module
"""
from langchain_community.chat_models import ChatOllama
class Ollama(ChatOllama):
"""
A wrapper for the ChatOllama class that provides default configuration
and could be extended with additional methods if needed.
Args:
llm_config (dict): Configuration parameters for the language model.
"""
def __init__(self, llm_config: dict):
super().__init__(**llm_config)