From df0e3108299071b849d7e055bd11d72764d24f08 Mon Sep 17 00:00:00 2001 From: Marco Vinciguerra Date: Mon, 24 Jun 2024 23:11:28 +0200 Subject: [PATCH] feat: add fireworks integration --- examples/fireworks/.env.example | 1 + examples/fireworks/smart_scraper_fireworks.py | 52 +++++++++++++++++++ pyproject.toml | 1 + requirements-dev.lock | 13 +++++ requirements.lock | 14 +++++ requirements.txt | 3 +- scrapegraphai/graphs/abstract_graph.py | 32 +++++++++--- scrapegraphai/helpers/models_tokens.py | 7 ++- scrapegraphai/models/__init__.py | 1 + scrapegraphai/models/fireworks.py | 33 ++++++++++++ 10 files changed, 149 insertions(+), 8 deletions(-) create mode 100644 examples/fireworks/.env.example create mode 100644 examples/fireworks/smart_scraper_fireworks.py create mode 100644 scrapegraphai/models/fireworks.py diff --git a/examples/fireworks/.env.example b/examples/fireworks/.env.example new file mode 100644 index 00000000..ab200215 --- /dev/null +++ b/examples/fireworks/.env.example @@ -0,0 +1 @@ +FIREWORKS_APIKEY="your fireworks api key" diff --git a/examples/fireworks/smart_scraper_fireworks.py b/examples/fireworks/smart_scraper_fireworks.py new file mode 100644 index 00000000..40071d8f --- /dev/null +++ b/examples/fireworks/smart_scraper_fireworks.py @@ -0,0 +1,52 @@ +""" +Basic example of scraping pipeline using SmartScraper +""" + +import os, json +from dotenv import load_dotenv +from scrapegraphai.graphs import SmartScraperGraph +from scrapegraphai.utils import prettify_exec_info + +load_dotenv() + + +# ************************************************ +# Define the configuration for the graph +# ************************************************ + +fireworks_api_key = os.getenv("FIREWORKS_APIKEY") + +graph_config = { + "llm": { + "api_key": fireworks_api_key, + "model": "fireworks/accounts/fireworks/models/mixtral-8x7b-instruct" + }, + "embeddings": { + "model": "ollama/nomic-embed-text", + "temperature": 0, + # "base_url": "http://localhost:11434", # set ollama URL arbitrarily + }, + "verbose": True, + "headless": False, +} + +# ************************************************ +# Create the SmartScraperGraph instance and run it +# ************************************************ + +smart_scraper_graph = SmartScraperGraph( + prompt="List me all the projects with their description", + # also accepts a string with the already downloaded HTML code + source="https://perinim.github.io/projects/", + config=graph_config, +) + +result = smart_scraper_graph.run() +print(json.dumps(result, indent=4)) + +# ************************************************ +# Get graph execution info +# ************************************************ + +graph_exec_info = smart_scraper_graph.get_execution_info() +print(prettify_exec_info(graph_exec_info)) diff --git a/pyproject.toml b/pyproject.toml index 02114c26..0b296be9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,6 +33,7 @@ dependencies = [ "google==3.0.0", "undetected-playwright==0.3.0", "semchunk==1.0.1", + "langchain-fireworks==0.1.3" ] license = "MIT" diff --git a/requirements-dev.lock b/requirements-dev.lock index 52c5faa4..963ceaa9 100644 --- a/requirements-dev.lock +++ b/requirements-dev.lock @@ -13,6 +13,7 @@ aiofiles==23.2.1 aiohttp==3.9.5 # via langchain # via langchain-community + # via langchain-fireworks aiosignal==1.3.1 # via aiohttp alabaster==0.7.16 @@ -93,6 +94,8 @@ fastapi-pagination==0.12.24 # via burr filelock==3.14.0 # via huggingface-hub +fireworks-ai==0.14.0 + # via langchain-fireworks fonttools==4.52.1 # via matplotlib free-proxy==1.1.1 @@ -158,8 +161,11 @@ httptools==0.6.1 httpx==0.27.0 # via anthropic # via fastapi + # via fireworks-ai # via groq # via openai +httpx-sse==0.4.0 + # via fireworks-ai huggingface-hub==0.23.1 # via tokenizers idna==3.7 @@ -207,10 +213,13 @@ langchain-core==0.1.52 # via langchain-anthropic # via langchain-aws # via langchain-community + # via langchain-fireworks # via langchain-google-genai # via langchain-groq # via langchain-openai # via langchain-text-splitters +langchain-fireworks==0.1.3 + # via scrapegraphai langchain-google-genai==1.0.3 # via scrapegraphai langchain-groq==0.1.3 @@ -259,6 +268,7 @@ numpy==1.26.4 # via streamlit openai==1.30.3 # via burr + # via langchain-fireworks # via langchain-openai orjson==3.10.3 # via fastapi @@ -278,6 +288,7 @@ pandas==2.2.2 # via sf-hamilton # via streamlit pillow==10.3.0 + # via fireworks-ai # via matplotlib # via streamlit playwright==1.43.0 @@ -308,6 +319,7 @@ pydantic==2.7.1 # via burr # via fastapi # via fastapi-pagination + # via fireworks-ai # via google-generativeai # via groq # via langchain @@ -359,6 +371,7 @@ requests==2.32.2 # via huggingface-hub # via langchain # via langchain-community + # via langchain-fireworks # via langsmith # via sphinx # via streamlit diff --git a/requirements.lock b/requirements.lock index 1dc6ef4f..a27966ba 100644 --- a/requirements.lock +++ b/requirements.lock @@ -11,6 +11,7 @@ aiohttp==3.9.5 # via langchain # via langchain-community + # via langchain-fireworks aiosignal==1.3.1 # via aiohttp annotated-types==0.7.0 @@ -53,6 +54,8 @@ faiss-cpu==1.8.0 # via scrapegraphai filelock==3.14.0 # via huggingface-hub +fireworks-ai==0.14.0 + # via langchain-fireworks free-proxy==1.1.1 # via scrapegraphai frozenlist==1.4.1 @@ -105,8 +108,11 @@ httplib2==0.22.0 # via google-auth-httplib2 httpx==0.27.0 # via anthropic + # via fireworks-ai # via groq # via openai +httpx-sse==0.4.0 + # via fireworks-ai huggingface-hub==0.23.1 # via tokenizers idna==3.7 @@ -137,10 +143,13 @@ langchain-core==0.1.52 # via langchain-anthropic # via langchain-aws # via langchain-community + # via langchain-fireworks # via langchain-google-genai # via langchain-groq # via langchain-openai # via langchain-text-splitters +langchain-fireworks==0.1.3 + # via scrapegraphai langchain-google-genai==1.0.3 # via scrapegraphai langchain-groq==0.1.3 @@ -171,6 +180,7 @@ numpy==1.26.4 # via langchain-community # via pandas openai==1.30.3 + # via langchain-fireworks # via langchain-openai orjson==3.10.3 # via langsmith @@ -180,6 +190,8 @@ packaging==23.2 # via marshmallow pandas==2.2.2 # via scrapegraphai +pillow==10.3.0 + # via fireworks-ai playwright==1.43.0 # via scrapegraphai # via undetected-playwright @@ -200,6 +212,7 @@ pyasn1-modules==0.4.0 # via google-auth pydantic==2.7.1 # via anthropic + # via fireworks-ai # via google-generativeai # via groq # via langchain @@ -232,6 +245,7 @@ requests==2.32.2 # via huggingface-hub # via langchain # via langchain-community + # via langchain-fireworks # via langsmith # via tiktoken rsa==4.9 diff --git a/requirements.txt b/requirements.txt index 46ae491a..d69066df 100644 --- a/requirements.txt +++ b/requirements.txt @@ -17,4 +17,5 @@ langchain-groq==0.1.3 playwright==1.43.0 langchain-aws==0.1.2 undetected-playwright==0.3.0 -semchunk==1.0.1 \ No newline at end of file +semchunk==1.0.1 +langchain-fireworks==0.1.3 diff --git a/scrapegraphai/graphs/abstract_graph.py b/scrapegraphai/graphs/abstract_graph.py index ccd3158a..c04b6efd 100644 --- a/scrapegraphai/graphs/abstract_graph.py +++ b/scrapegraphai/graphs/abstract_graph.py @@ -11,6 +11,7 @@ from langchain_aws import BedrockEmbeddings from langchain_community.embeddings import HuggingFaceHubEmbeddings, OllamaEmbeddings from langchain_google_genai import GoogleGenerativeAIEmbeddings from langchain_google_genai.embeddings import GoogleGenerativeAIEmbeddings +from langchain_fireworks import FireworksEmbeddings from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings from ..helpers import models_tokens @@ -23,7 +24,8 @@ from ..models import ( HuggingFace, Ollama, OpenAI, - OneApi + OneApi, + Fireworks ) from ..models.ernie import Ernie from ..utils.logging import set_verbosity_debug, set_verbosity_warning, set_verbosity_info @@ -102,7 +104,7 @@ class AbstractGraph(ABC): "embedder_model": self.embedder_model, "cache_path": self.cache_path, } - + self.set_common_params(common_params, overwrite=True) # set burr config @@ -125,7 +127,7 @@ class AbstractGraph(ABC): for node in self.graph.nodes: node.update_config(params, overwrite) - + def _create_llm(self, llm_config: dict, chat=False) -> object: """ Create a large language model instance based on the configuration provided. @@ -160,8 +162,15 @@ class AbstractGraph(ABC): try: self.model_token = models_tokens["oneapi"][llm_params["model"]] except KeyError as exc: - raise KeyError("Model Model not supported") from exc + raise KeyError("Model not supported") from exc return OneApi(llm_params) + elif "fireworks" in llm_params["model"]: + try: + self.model_token = models_tokens["fireworks"][llm_params["model"].split("/")[-1]] + llm_params["model"] = "/".join(llm_params["model"].split("/")[1:]) + except KeyError as exc: + raise KeyError("Model not supported") from exc + return Fireworks(llm_params) elif "azure" in llm_params["model"]: # take the model after the last dash llm_params["model"] = llm_params["model"].split("/")[-1] @@ -172,12 +181,14 @@ class AbstractGraph(ABC): return AzureOpenAI(llm_params) elif "gemini" in llm_params["model"]: + llm_params["model"] = llm_params["model"].split("/")[-1] try: self.model_token = models_tokens["gemini"][llm_params["model"]] except KeyError as exc: raise KeyError("Model not supported") from exc return Gemini(llm_params) elif llm_params["model"].startswith("claude"): + llm_params["model"] = llm_params["model"].split("/")[-1] try: self.model_token = models_tokens["claude"][llm_params["model"]] except KeyError as exc: @@ -203,6 +214,7 @@ class AbstractGraph(ABC): return Ollama(llm_params) elif "hugging_face" in llm_params["model"]: + llm_params["model"] = llm_params["model"].split("/")[-1] try: self.model_token = models_tokens["hugging_face"][llm_params["model"]] except KeyError: @@ -277,12 +289,13 @@ class AbstractGraph(ABC): if isinstance(self.llm_model, OpenAI): return OpenAIEmbeddings(api_key=self.llm_model.openai_api_key, base_url=self.llm_model.openai_api_base) elif isinstance(self.llm_model, DeepSeek): - return OpenAIEmbeddings(api_key=self.llm_model.openai_api_key) - + return OpenAIEmbeddings(api_key=self.llm_model.openai_api_key) elif isinstance(self.llm_model, AzureOpenAIEmbeddings): return self.llm_model elif isinstance(self.llm_model, AzureOpenAI): return AzureOpenAIEmbeddings() + elif isinstance(self.llm_model, Fireworks): + return FireworksEmbeddings(model=self.llm_model.model_name) elif isinstance(self.llm_model, Ollama): # unwrap the kwargs from the model whihc is a dict params = self.llm_model._lc_kwargs @@ -333,6 +346,13 @@ class AbstractGraph(ABC): except KeyError as exc: raise KeyError("Model not supported") from exc return HuggingFaceHubEmbeddings(model=embedder_params["model"]) + elif "fireworks" in embedder_params["model"]: + embedder_params["model"] = "/".join(embedder_params["model"].split("/")[1:]) + try: + models_tokens["fireworks"][embedder_params["model"]] + except KeyError as exc: + raise KeyError("Model not supported") from exc + return FireworksEmbeddings(model=embedder_params["model"]) elif "gemini" in embedder_params["model"]: try: models_tokens["gemini"][embedder_params["model"]] diff --git a/scrapegraphai/helpers/models_tokens.py b/scrapegraphai/helpers/models_tokens.py index 4cc88c04..c9b03f13 100644 --- a/scrapegraphai/helpers/models_tokens.py +++ b/scrapegraphai/helpers/models_tokens.py @@ -143,5 +143,10 @@ models_tokens = { "ernie-bot-2-base-en": 4096, "ernie-bot-2-base-en-zh": 4096, "ernie-bot-2-base-zh-en": 4096 - } + }, + "fireworks": { + "llama-v2-7b": 4096, + "mixtral-8x7b-instruct": 4096, + "nomic-ai/nomic-embed-text-v1.5": 8192 + }, } diff --git a/scrapegraphai/models/__init__.py b/scrapegraphai/models/__init__.py index 0a1ad2af..6c90dc0f 100644 --- a/scrapegraphai/models/__init__.py +++ b/scrapegraphai/models/__init__.py @@ -14,3 +14,4 @@ from .bedrock import Bedrock from .anthropic import Anthropic from .deepseek import DeepSeek from .oneapi import OneApi +from .fireworks import Fireworks diff --git a/scrapegraphai/models/fireworks.py b/scrapegraphai/models/fireworks.py new file mode 100644 index 00000000..445c4846 --- /dev/null +++ b/scrapegraphai/models/fireworks.py @@ -0,0 +1,33 @@ +""" +Fireworks Module +""" +from langchain_fireworks import ChatFireworks + + +class Fireworks(ChatFireworks): + """ + Initializes the Fireworks class. + + Args: + llm_config (dict): A dictionary containing configuration parameters for the LLM (required). + The specific keys and values will depend on the LLM implementation + used by the underlying `ChatFireworks` class. Consult its documentation + for details. + + Raises: + ValueError: If required keys are missing from the llm_config dictionary. + """ + + def __init__(self, llm_config: dict): + """ + Initializes the Fireworks class. + + Args: + llm_config (dict): A dictionary containing configuration parameters for the LLM. + The specific keys and values will depend on the LLM implementation. + + Raises: + ValueError: If required keys are missing from the llm_config dictionary. + """ + + super().__init__(**llm_config)