From d177afb68be036465ede1f567d2562b145d77d36 Mon Sep 17 00:00:00 2001 From: Federico Aguzzi <62149513+f-aguzzi@users.noreply.github.com> Date: Mon, 29 Jul 2024 11:17:47 +0200 Subject: [PATCH] refactor(Ollama): integrate new LangChain chat init --- requirements-dev.lock | 11 +++++++++++ requirements.lock | 12 ++++++++++++ scrapegraphai/graphs/abstract_graph.py | 16 +++++++++++----- scrapegraphai/models/ollama.py | 17 ----------------- 4 files changed, 34 insertions(+), 22 deletions(-) delete mode 100644 scrapegraphai/models/ollama.py diff --git a/requirements-dev.lock b/requirements-dev.lock index 405395c4..bce18810 100644 --- a/requirements-dev.lock +++ b/requirements-dev.lock @@ -12,6 +12,7 @@ aiofiles==24.1.0 # via burr aiohttp==3.9.5 # via langchain + # via langchain-community # via langchain-fireworks # via langchain-nvidia-ai-endpoints aiosignal==1.3.1 @@ -179,6 +180,7 @@ graphviz==0.20.3 # via scrapegraphai greenlet==3.0.3 # via playwright + # via sqlalchemy groq==0.9.0 # via langchain-groq grpc-google-iam-v1==0.13.1 @@ -262,6 +264,7 @@ langchain-core==0.2.22 # via langchain # via langchain-anthropic # via langchain-aws + # via langchain-community # via langchain-fireworks # via langchain-google-genai # via langchain-google-vertexai @@ -285,6 +288,7 @@ langchain-text-splitters==0.2.2 # via langchain langsmith==0.1.93 # via langchain + # via langchain-community # via langchain-core loguru==0.7.2 # via burr @@ -319,6 +323,7 @@ numpy==1.26.4 # via faiss-cpu # via langchain # via langchain-aws + # via langchain-community # via matplotlib # via pandas # via pyarrow @@ -339,6 +344,7 @@ packaging==24.1 # via google-cloud-bigquery # via huggingface-hub # via langchain-core + # via marshmallow # via matplotlib # via pytest # via sphinx @@ -429,6 +435,7 @@ pytz==2024.1 pyyaml==6.0.1 # via huggingface-hub # via langchain + # via langchain-community # via langchain-core # via uvicorn referencing==0.35.1 @@ -444,6 +451,7 @@ requests==2.32.3 # via google-cloud-storage # via huggingface-hub # via langchain + # via langchain-community # via langchain-fireworks # via langsmith # via sphinx @@ -501,12 +509,14 @@ sphinxcontrib-serializinghtml==1.1.10 # via sphinx sqlalchemy==2.0.31 # via langchain + # via langchain-community starlette==0.37.2 # via fastapi streamlit==1.36.0 # via burr tenacity==8.5.0 # via langchain + # via langchain-community # via langchain-core # via streamlit tiktoken==0.7.0 @@ -557,6 +567,7 @@ typing-extensions==4.12.2 # via typing-inspect # via uvicorn typing-inspect==0.9.0 + # via dataclasses-json # via sf-hamilton tzdata==2024.1 # via pandas diff --git a/requirements.lock b/requirements.lock index 9d0602db..aa03fd14 100644 --- a/requirements.lock +++ b/requirements.lock @@ -10,6 +10,7 @@ -e file:. aiohttp==3.9.5 # via langchain + # via langchain-community # via langchain-fireworks # via langchain-nvidia-ai-endpoints aiosignal==1.3.1 @@ -127,6 +128,7 @@ graphviz==0.20.3 # via scrapegraphai greenlet==3.0.3 # via playwright + # via sqlalchemy groq==0.9.0 # via langchain-groq grpc-google-iam-v1==0.13.1 @@ -183,6 +185,7 @@ langchain-core==0.2.22 # via langchain # via langchain-anthropic # via langchain-aws + # via langchain-community # via langchain-fireworks # via langchain-google-genai # via langchain-google-vertexai @@ -206,6 +209,7 @@ langchain-text-splitters==0.2.2 # via langchain langsmith==0.1.93 # via langchain + # via langchain-community # via langchain-core lxml==5.2.2 # via free-proxy @@ -226,6 +230,7 @@ numpy==1.26.4 # via faiss-cpu # via langchain # via langchain-aws + # via langchain-community # via pandas # via shapely openai==1.37.0 @@ -239,6 +244,7 @@ packaging==24.1 # via google-cloud-bigquery # via huggingface-hub # via langchain-core + # via marshmallow pandas==2.2.2 # via scrapegraphai pillow==10.4.0 @@ -296,6 +302,7 @@ pytz==2024.1 pyyaml==6.0.1 # via huggingface-hub # via langchain + # via langchain-community # via langchain-core regex==2024.5.15 # via tiktoken @@ -306,6 +313,7 @@ requests==2.32.3 # via google-cloud-storage # via huggingface-hub # via langchain + # via langchain-community # via langchain-fireworks # via langsmith # via tiktoken @@ -332,6 +340,7 @@ sqlalchemy==2.0.31 # via langchain-community tenacity==8.5.0 # via langchain + # via langchain-community # via langchain-core tiktoken==0.7.0 # via langchain-openai @@ -356,6 +365,9 @@ typing-extensions==4.12.2 # via pydantic-core # via pyee # via sqlalchemy + # via typing-inspect +typing-inspect==0.9.0 + # via dataclasses-json tzdata==2024.1 # via pandas undetected-playwright==0.3.0 diff --git a/scrapegraphai/graphs/abstract_graph.py b/scrapegraphai/graphs/abstract_graph.py index 91396ae0..f1c9ff92 100644 --- a/scrapegraphai/graphs/abstract_graph.py +++ b/scrapegraphai/graphs/abstract_graph.py @@ -7,6 +7,8 @@ from typing import Optional, Union import uuid from pydantic import BaseModel +from langchain_community.chat_models import ChatOllama + from langchain_aws import BedrockEmbeddings from langchain_community.embeddings import HuggingFaceHubEmbeddings, OllamaEmbeddings from langchain_google_genai import GoogleGenerativeAIEmbeddings @@ -19,22 +21,23 @@ from ..helpers import models_tokens from ..models import ( Anthropic, AzureOpenAI, + OpenAI, Bedrock, Gemini, Groq, HuggingFace, - Ollama, - OpenAI, OneApi, Fireworks, VertexAI, Nvidia ) from ..models.ernie import Ernie +from langchain.chat_models import init_chat_model + from ..utils.logging import set_verbosity_debug, set_verbosity_warning, set_verbosity_info from ..helpers import models_tokens -from ..models import AzureOpenAI, Bedrock, Gemini, Groq, HuggingFace, Ollama, OpenAI, Anthropic, DeepSeek +from ..models import AzureOpenAI, OpenAI, Bedrock, Gemini, Groq, HuggingFace, Anthropic, DeepSeek class AbstractGraph(ABC): @@ -213,8 +216,10 @@ class AbstractGraph(ABC): except KeyError as exc: raise KeyError("Model not supported") from exc return VertexAI(llm_params) + elif "ollama" in llm_params["model"]: llm_params["model"] = llm_params["model"].split("ollama/")[-1] + llm_params["model_provider"] = "ollama" # allow user to set model_tokens in config try: @@ -231,7 +236,8 @@ class AbstractGraph(ABC): except AttributeError: self.model_token = 8192 - return Ollama(llm_params) + return init_chat_model(**llm_params) + elif "hugging_face" in llm_params["model"]: llm_params["model"] = llm_params["model"].split("/")[-1] try: @@ -320,7 +326,7 @@ class AbstractGraph(ABC): return FireworksEmbeddings(model=self.llm_model.model_name) elif isinstance(self.llm_model, Nvidia): return NVIDIAEmbeddings(model=self.llm_model.model_name) - elif isinstance(self.llm_model, Ollama): + elif isinstance(self.llm_model, ChatOllama): # unwrap the kwargs from the model whihc is a dict params = self.llm_model._lc_kwargs # remove streaming and temperature diff --git a/scrapegraphai/models/ollama.py b/scrapegraphai/models/ollama.py deleted file mode 100644 index 4bf48178..00000000 --- a/scrapegraphai/models/ollama.py +++ /dev/null @@ -1,17 +0,0 @@ -""" -Ollama Module -""" -from langchain_community.chat_models import ChatOllama - - -class Ollama(ChatOllama): - """ - A wrapper for the ChatOllama class that provides default configuration - and could be extended with additional methods if needed. - - Args: - llm_config (dict): Configuration parameters for the language model. - """ - - def __init__(self, llm_config: dict): - super().__init__(**llm_config)