refactor(Ollama): integrate new LangChain chat init

This commit is contained in:
Federico Aguzzi 2024-07-29 11:17:47 +02:00
parent ecc5e35d5f
commit d177afb68b
4 changed files with 34 additions and 22 deletions

View File

@ -12,6 +12,7 @@ aiofiles==24.1.0
# via burr
aiohttp==3.9.5
# via langchain
# via langchain-community
# via langchain-fireworks
# via langchain-nvidia-ai-endpoints
aiosignal==1.3.1
@ -179,6 +180,7 @@ graphviz==0.20.3
# via scrapegraphai
greenlet==3.0.3
# via playwright
# via sqlalchemy
groq==0.9.0
# via langchain-groq
grpc-google-iam-v1==0.13.1
@ -262,6 +264,7 @@ langchain-core==0.2.22
# via langchain
# via langchain-anthropic
# via langchain-aws
# via langchain-community
# via langchain-fireworks
# via langchain-google-genai
# via langchain-google-vertexai
@ -285,6 +288,7 @@ langchain-text-splitters==0.2.2
# via langchain
langsmith==0.1.93
# via langchain
# via langchain-community
# via langchain-core
loguru==0.7.2
# via burr
@ -319,6 +323,7 @@ numpy==1.26.4
# via faiss-cpu
# via langchain
# via langchain-aws
# via langchain-community
# via matplotlib
# via pandas
# via pyarrow
@ -339,6 +344,7 @@ packaging==24.1
# via google-cloud-bigquery
# via huggingface-hub
# via langchain-core
# via marshmallow
# via matplotlib
# via pytest
# via sphinx
@ -429,6 +435,7 @@ pytz==2024.1
pyyaml==6.0.1
# via huggingface-hub
# via langchain
# via langchain-community
# via langchain-core
# via uvicorn
referencing==0.35.1
@ -444,6 +451,7 @@ requests==2.32.3
# via google-cloud-storage
# via huggingface-hub
# via langchain
# via langchain-community
# via langchain-fireworks
# via langsmith
# via sphinx
@ -501,12 +509,14 @@ sphinxcontrib-serializinghtml==1.1.10
# via sphinx
sqlalchemy==2.0.31
# via langchain
# via langchain-community
starlette==0.37.2
# via fastapi
streamlit==1.36.0
# via burr
tenacity==8.5.0
# via langchain
# via langchain-community
# via langchain-core
# via streamlit
tiktoken==0.7.0
@ -557,6 +567,7 @@ typing-extensions==4.12.2
# via typing-inspect
# via uvicorn
typing-inspect==0.9.0
# via dataclasses-json
# via sf-hamilton
tzdata==2024.1
# via pandas

View File

@ -10,6 +10,7 @@
-e file:.
aiohttp==3.9.5
# via langchain
# via langchain-community
# via langchain-fireworks
# via langchain-nvidia-ai-endpoints
aiosignal==1.3.1
@ -127,6 +128,7 @@ graphviz==0.20.3
# via scrapegraphai
greenlet==3.0.3
# via playwright
# via sqlalchemy
groq==0.9.0
# via langchain-groq
grpc-google-iam-v1==0.13.1
@ -183,6 +185,7 @@ langchain-core==0.2.22
# via langchain
# via langchain-anthropic
# via langchain-aws
# via langchain-community
# via langchain-fireworks
# via langchain-google-genai
# via langchain-google-vertexai
@ -206,6 +209,7 @@ langchain-text-splitters==0.2.2
# via langchain
langsmith==0.1.93
# via langchain
# via langchain-community
# via langchain-core
lxml==5.2.2
# via free-proxy
@ -226,6 +230,7 @@ numpy==1.26.4
# via faiss-cpu
# via langchain
# via langchain-aws
# via langchain-community
# via pandas
# via shapely
openai==1.37.0
@ -239,6 +244,7 @@ packaging==24.1
# via google-cloud-bigquery
# via huggingface-hub
# via langchain-core
# via marshmallow
pandas==2.2.2
# via scrapegraphai
pillow==10.4.0
@ -296,6 +302,7 @@ pytz==2024.1
pyyaml==6.0.1
# via huggingface-hub
# via langchain
# via langchain-community
# via langchain-core
regex==2024.5.15
# via tiktoken
@ -306,6 +313,7 @@ requests==2.32.3
# via google-cloud-storage
# via huggingface-hub
# via langchain
# via langchain-community
# via langchain-fireworks
# via langsmith
# via tiktoken
@ -332,6 +340,7 @@ sqlalchemy==2.0.31
# via langchain-community
tenacity==8.5.0
# via langchain
# via langchain-community
# via langchain-core
tiktoken==0.7.0
# via langchain-openai
@ -356,6 +365,9 @@ typing-extensions==4.12.2
# via pydantic-core
# via pyee
# via sqlalchemy
# via typing-inspect
typing-inspect==0.9.0
# via dataclasses-json
tzdata==2024.1
# via pandas
undetected-playwright==0.3.0

View File

@ -7,6 +7,8 @@ from typing import Optional, Union
import uuid
from pydantic import BaseModel
from langchain_community.chat_models import ChatOllama
from langchain_aws import BedrockEmbeddings
from langchain_community.embeddings import HuggingFaceHubEmbeddings, OllamaEmbeddings
from langchain_google_genai import GoogleGenerativeAIEmbeddings
@ -19,22 +21,23 @@ from ..helpers import models_tokens
from ..models import (
Anthropic,
AzureOpenAI,
OpenAI,
Bedrock,
Gemini,
Groq,
HuggingFace,
Ollama,
OpenAI,
OneApi,
Fireworks,
VertexAI,
Nvidia
)
from ..models.ernie import Ernie
from langchain.chat_models import init_chat_model
from ..utils.logging import set_verbosity_debug, set_verbosity_warning, set_verbosity_info
from ..helpers import models_tokens
from ..models import AzureOpenAI, Bedrock, Gemini, Groq, HuggingFace, Ollama, OpenAI, Anthropic, DeepSeek
from ..models import AzureOpenAI, OpenAI, Bedrock, Gemini, Groq, HuggingFace, Anthropic, DeepSeek
class AbstractGraph(ABC):
@ -213,8 +216,10 @@ class AbstractGraph(ABC):
except KeyError as exc:
raise KeyError("Model not supported") from exc
return VertexAI(llm_params)
elif "ollama" in llm_params["model"]:
llm_params["model"] = llm_params["model"].split("ollama/")[-1]
llm_params["model_provider"] = "ollama"
# allow user to set model_tokens in config
try:
@ -231,7 +236,8 @@ class AbstractGraph(ABC):
except AttributeError:
self.model_token = 8192
return Ollama(llm_params)
return init_chat_model(**llm_params)
elif "hugging_face" in llm_params["model"]:
llm_params["model"] = llm_params["model"].split("/")[-1]
try:
@ -320,7 +326,7 @@ class AbstractGraph(ABC):
return FireworksEmbeddings(model=self.llm_model.model_name)
elif isinstance(self.llm_model, Nvidia):
return NVIDIAEmbeddings(model=self.llm_model.model_name)
elif isinstance(self.llm_model, Ollama):
elif isinstance(self.llm_model, ChatOllama):
# unwrap the kwargs from the model whihc is a dict
params = self.llm_model._lc_kwargs
# remove streaming and temperature

View File

@ -1,17 +0,0 @@
"""
Ollama Module
"""
from langchain_community.chat_models import ChatOllama
class Ollama(ChatOllama):
"""
A wrapper for the ChatOllama class that provides default configuration
and could be extended with additional methods if needed.
Args:
llm_config (dict): Configuration parameters for the language model.
"""
def __init__(self, llm_config: dict):
super().__init__(**llm_config)