feat: add fireworks integration

This commit is contained in:
Marco Vinciguerra 2024-06-24 23:11:28 +02:00
parent 79a2f51c34
commit df0e310829
10 changed files with 149 additions and 8 deletions

View File

@ -0,0 +1 @@
FIREWORKS_APIKEY="your fireworks api key"

View File

@ -0,0 +1,52 @@
"""
Basic example of scraping pipeline using SmartScraper
"""
import json
import os

from dotenv import load_dotenv

from scrapegraphai.graphs import SmartScraperGraph
from scrapegraphai.utils import prettify_exec_info

# Load variables from a .env file (python-dotenv) before reading the API key.
load_dotenv()

# ------------------------------------------------
# Graph configuration
# ------------------------------------------------
# The key is read from the FIREWORKS_APIKEY environment variable; os.getenv
# yields None when the variable is not set.
fireworks_api_key = os.getenv("FIREWORKS_APIKEY")

# Chat model settings: the model name carries the "fireworks/" provider prefix.
llm_settings = {
    "api_key": fireworks_api_key,
    "model": "fireworks/accounts/fireworks/models/mixtral-8x7b-instruct",
}

# Embedding model settings (served through Ollama in this example).
embeddings_settings = {
    "model": "ollama/nomic-embed-text",
    "temperature": 0,
    # "base_url": "http://localhost:11434",  # set ollama URL arbitrarily
}

graph_config = {
    "llm": llm_settings,
    "embeddings": embeddings_settings,
    "verbose": True,
    "headless": False,
}

# ------------------------------------------------
# Build the SmartScraperGraph and execute it
# ------------------------------------------------
smart_scraper_graph = SmartScraperGraph(
    prompt="List me all the projects with their description",
    # also accepts a string with the already downloaded HTML code
    source="https://perinim.github.io/projects/",
    config=graph_config,
)

result = smart_scraper_graph.run()
print(json.dumps(result, indent=4))

# ------------------------------------------------
# Report the execution info collected by the graph
# ------------------------------------------------
graph_exec_info = smart_scraper_graph.get_execution_info()
print(prettify_exec_info(graph_exec_info))

View File

@ -33,6 +33,7 @@ dependencies = [
"google==3.0.0", "google==3.0.0",
"undetected-playwright==0.3.0", "undetected-playwright==0.3.0",
"semchunk==1.0.1", "semchunk==1.0.1",
"langchain-fireworks==0.1.3"
] ]
license = "MIT" license = "MIT"

View File

@ -13,6 +13,7 @@ aiofiles==23.2.1
aiohttp==3.9.5 aiohttp==3.9.5
# via langchain # via langchain
# via langchain-community # via langchain-community
# via langchain-fireworks
aiosignal==1.3.1 aiosignal==1.3.1
# via aiohttp # via aiohttp
alabaster==0.7.16 alabaster==0.7.16
@ -93,6 +94,8 @@ fastapi-pagination==0.12.24
# via burr # via burr
filelock==3.14.0 filelock==3.14.0
# via huggingface-hub # via huggingface-hub
fireworks-ai==0.14.0
# via langchain-fireworks
fonttools==4.52.1 fonttools==4.52.1
# via matplotlib # via matplotlib
free-proxy==1.1.1 free-proxy==1.1.1
@ -158,8 +161,11 @@ httptools==0.6.1
httpx==0.27.0 httpx==0.27.0
# via anthropic # via anthropic
# via fastapi # via fastapi
# via fireworks-ai
# via groq # via groq
# via openai # via openai
httpx-sse==0.4.0
# via fireworks-ai
huggingface-hub==0.23.1 huggingface-hub==0.23.1
# via tokenizers # via tokenizers
idna==3.7 idna==3.7
@ -207,10 +213,13 @@ langchain-core==0.1.52
# via langchain-anthropic # via langchain-anthropic
# via langchain-aws # via langchain-aws
# via langchain-community # via langchain-community
# via langchain-fireworks
# via langchain-google-genai # via langchain-google-genai
# via langchain-groq # via langchain-groq
# via langchain-openai # via langchain-openai
# via langchain-text-splitters # via langchain-text-splitters
langchain-fireworks==0.1.3
# via scrapegraphai
langchain-google-genai==1.0.3 langchain-google-genai==1.0.3
# via scrapegraphai # via scrapegraphai
langchain-groq==0.1.3 langchain-groq==0.1.3
@ -259,6 +268,7 @@ numpy==1.26.4
# via streamlit # via streamlit
openai==1.30.3 openai==1.30.3
# via burr # via burr
# via langchain-fireworks
# via langchain-openai # via langchain-openai
orjson==3.10.3 orjson==3.10.3
# via fastapi # via fastapi
@ -278,6 +288,7 @@ pandas==2.2.2
# via sf-hamilton # via sf-hamilton
# via streamlit # via streamlit
pillow==10.3.0 pillow==10.3.0
# via fireworks-ai
# via matplotlib # via matplotlib
# via streamlit # via streamlit
playwright==1.43.0 playwright==1.43.0
@ -308,6 +319,7 @@ pydantic==2.7.1
# via burr # via burr
# via fastapi # via fastapi
# via fastapi-pagination # via fastapi-pagination
# via fireworks-ai
# via google-generativeai # via google-generativeai
# via groq # via groq
# via langchain # via langchain
@ -359,6 +371,7 @@ requests==2.32.2
# via huggingface-hub # via huggingface-hub
# via langchain # via langchain
# via langchain-community # via langchain-community
# via langchain-fireworks
# via langsmith # via langsmith
# via sphinx # via sphinx
# via streamlit # via streamlit

View File

@ -11,6 +11,7 @@
aiohttp==3.9.5 aiohttp==3.9.5
# via langchain # via langchain
# via langchain-community # via langchain-community
# via langchain-fireworks
aiosignal==1.3.1 aiosignal==1.3.1
# via aiohttp # via aiohttp
annotated-types==0.7.0 annotated-types==0.7.0
@ -53,6 +54,8 @@ faiss-cpu==1.8.0
# via scrapegraphai # via scrapegraphai
filelock==3.14.0 filelock==3.14.0
# via huggingface-hub # via huggingface-hub
fireworks-ai==0.14.0
# via langchain-fireworks
free-proxy==1.1.1 free-proxy==1.1.1
# via scrapegraphai # via scrapegraphai
frozenlist==1.4.1 frozenlist==1.4.1
@ -105,8 +108,11 @@ httplib2==0.22.0
# via google-auth-httplib2 # via google-auth-httplib2
httpx==0.27.0 httpx==0.27.0
# via anthropic # via anthropic
# via fireworks-ai
# via groq # via groq
# via openai # via openai
httpx-sse==0.4.0
# via fireworks-ai
huggingface-hub==0.23.1 huggingface-hub==0.23.1
# via tokenizers # via tokenizers
idna==3.7 idna==3.7
@ -137,10 +143,13 @@ langchain-core==0.1.52
# via langchain-anthropic # via langchain-anthropic
# via langchain-aws # via langchain-aws
# via langchain-community # via langchain-community
# via langchain-fireworks
# via langchain-google-genai # via langchain-google-genai
# via langchain-groq # via langchain-groq
# via langchain-openai # via langchain-openai
# via langchain-text-splitters # via langchain-text-splitters
langchain-fireworks==0.1.3
# via scrapegraphai
langchain-google-genai==1.0.3 langchain-google-genai==1.0.3
# via scrapegraphai # via scrapegraphai
langchain-groq==0.1.3 langchain-groq==0.1.3
@ -171,6 +180,7 @@ numpy==1.26.4
# via langchain-community # via langchain-community
# via pandas # via pandas
openai==1.30.3 openai==1.30.3
# via langchain-fireworks
# via langchain-openai # via langchain-openai
orjson==3.10.3 orjson==3.10.3
# via langsmith # via langsmith
@ -180,6 +190,8 @@ packaging==23.2
# via marshmallow # via marshmallow
pandas==2.2.2 pandas==2.2.2
# via scrapegraphai # via scrapegraphai
pillow==10.3.0
# via fireworks-ai
playwright==1.43.0 playwright==1.43.0
# via scrapegraphai # via scrapegraphai
# via undetected-playwright # via undetected-playwright
@ -200,6 +212,7 @@ pyasn1-modules==0.4.0
# via google-auth # via google-auth
pydantic==2.7.1 pydantic==2.7.1
# via anthropic # via anthropic
# via fireworks-ai
# via google-generativeai # via google-generativeai
# via groq # via groq
# via langchain # via langchain
@ -232,6 +245,7 @@ requests==2.32.2
# via huggingface-hub # via huggingface-hub
# via langchain # via langchain
# via langchain-community # via langchain-community
# via langchain-fireworks
# via langsmith # via langsmith
# via tiktoken # via tiktoken
rsa==4.9 rsa==4.9

View File

@ -17,4 +17,5 @@ langchain-groq==0.1.3
playwright==1.43.0 playwright==1.43.0
langchain-aws==0.1.2 langchain-aws==0.1.2
undetected-playwright==0.3.0 undetected-playwright==0.3.0
semchunk==1.0.1 semchunk==1.0.1
langchain-fireworks==0.1.3

View File

@ -11,6 +11,7 @@ from langchain_aws import BedrockEmbeddings
from langchain_community.embeddings import HuggingFaceHubEmbeddings, OllamaEmbeddings from langchain_community.embeddings import HuggingFaceHubEmbeddings, OllamaEmbeddings
from langchain_google_genai import GoogleGenerativeAIEmbeddings from langchain_google_genai import GoogleGenerativeAIEmbeddings
from langchain_google_genai.embeddings import GoogleGenerativeAIEmbeddings from langchain_google_genai.embeddings import GoogleGenerativeAIEmbeddings
from langchain_fireworks import FireworksEmbeddings
from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings
from ..helpers import models_tokens from ..helpers import models_tokens
@ -23,7 +24,8 @@ from ..models import (
HuggingFace, HuggingFace,
Ollama, Ollama,
OpenAI, OpenAI,
OneApi OneApi,
Fireworks
) )
from ..models.ernie import Ernie from ..models.ernie import Ernie
from ..utils.logging import set_verbosity_debug, set_verbosity_warning, set_verbosity_info from ..utils.logging import set_verbosity_debug, set_verbosity_warning, set_verbosity_info
@ -102,7 +104,7 @@ class AbstractGraph(ABC):
"embedder_model": self.embedder_model, "embedder_model": self.embedder_model,
"cache_path": self.cache_path, "cache_path": self.cache_path,
} }
self.set_common_params(common_params, overwrite=True) self.set_common_params(common_params, overwrite=True)
# set burr config # set burr config
@ -125,7 +127,7 @@ class AbstractGraph(ABC):
for node in self.graph.nodes: for node in self.graph.nodes:
node.update_config(params, overwrite) node.update_config(params, overwrite)
def _create_llm(self, llm_config: dict, chat=False) -> object: def _create_llm(self, llm_config: dict, chat=False) -> object:
""" """
Create a large language model instance based on the configuration provided. Create a large language model instance based on the configuration provided.
@ -160,8 +162,15 @@ class AbstractGraph(ABC):
try: try:
self.model_token = models_tokens["oneapi"][llm_params["model"]] self.model_token = models_tokens["oneapi"][llm_params["model"]]
except KeyError as exc: except KeyError as exc:
raise KeyError("Model Model not supported") from exc raise KeyError("Model not supported") from exc
return OneApi(llm_params) return OneApi(llm_params)
elif "fireworks" in llm_params["model"]:
try:
self.model_token = models_tokens["fireworks"][llm_params["model"].split("/")[-1]]
llm_params["model"] = "/".join(llm_params["model"].split("/")[1:])
except KeyError as exc:
raise KeyError("Model not supported") from exc
return Fireworks(llm_params)
elif "azure" in llm_params["model"]: elif "azure" in llm_params["model"]:
# take the model after the last slash # take the model after the last slash
llm_params["model"] = llm_params["model"].split("/")[-1] llm_params["model"] = llm_params["model"].split("/")[-1]
@ -172,12 +181,14 @@ class AbstractGraph(ABC):
return AzureOpenAI(llm_params) return AzureOpenAI(llm_params)
elif "gemini" in llm_params["model"]: elif "gemini" in llm_params["model"]:
llm_params["model"] = llm_params["model"].split("/")[-1]
try: try:
self.model_token = models_tokens["gemini"][llm_params["model"]] self.model_token = models_tokens["gemini"][llm_params["model"]]
except KeyError as exc: except KeyError as exc:
raise KeyError("Model not supported") from exc raise KeyError("Model not supported") from exc
return Gemini(llm_params) return Gemini(llm_params)
elif llm_params["model"].startswith("claude"): elif llm_params["model"].startswith("claude"):
llm_params["model"] = llm_params["model"].split("/")[-1]
try: try:
self.model_token = models_tokens["claude"][llm_params["model"]] self.model_token = models_tokens["claude"][llm_params["model"]]
except KeyError as exc: except KeyError as exc:
@ -203,6 +214,7 @@ class AbstractGraph(ABC):
return Ollama(llm_params) return Ollama(llm_params)
elif "hugging_face" in llm_params["model"]: elif "hugging_face" in llm_params["model"]:
llm_params["model"] = llm_params["model"].split("/")[-1]
try: try:
self.model_token = models_tokens["hugging_face"][llm_params["model"]] self.model_token = models_tokens["hugging_face"][llm_params["model"]]
except KeyError: except KeyError:
@ -277,12 +289,13 @@ class AbstractGraph(ABC):
if isinstance(self.llm_model, OpenAI): if isinstance(self.llm_model, OpenAI):
return OpenAIEmbeddings(api_key=self.llm_model.openai_api_key, base_url=self.llm_model.openai_api_base) return OpenAIEmbeddings(api_key=self.llm_model.openai_api_key, base_url=self.llm_model.openai_api_base)
elif isinstance(self.llm_model, DeepSeek): elif isinstance(self.llm_model, DeepSeek):
return OpenAIEmbeddings(api_key=self.llm_model.openai_api_key) return OpenAIEmbeddings(api_key=self.llm_model.openai_api_key)
elif isinstance(self.llm_model, AzureOpenAIEmbeddings): elif isinstance(self.llm_model, AzureOpenAIEmbeddings):
return self.llm_model return self.llm_model
elif isinstance(self.llm_model, AzureOpenAI): elif isinstance(self.llm_model, AzureOpenAI):
return AzureOpenAIEmbeddings() return AzureOpenAIEmbeddings()
elif isinstance(self.llm_model, Fireworks):
return FireworksEmbeddings(model=self.llm_model.model_name)
elif isinstance(self.llm_model, Ollama): elif isinstance(self.llm_model, Ollama):
# unwrap the kwargs from the model which is a dict # unwrap the kwargs from the model which is a dict
params = self.llm_model._lc_kwargs params = self.llm_model._lc_kwargs
@ -333,6 +346,13 @@ class AbstractGraph(ABC):
except KeyError as exc: except KeyError as exc:
raise KeyError("Model not supported") from exc raise KeyError("Model not supported") from exc
return HuggingFaceHubEmbeddings(model=embedder_params["model"]) return HuggingFaceHubEmbeddings(model=embedder_params["model"])
elif "fireworks" in embedder_params["model"]:
embedder_params["model"] = "/".join(embedder_params["model"].split("/")[1:])
try:
models_tokens["fireworks"][embedder_params["model"]]
except KeyError as exc:
raise KeyError("Model not supported") from exc
return FireworksEmbeddings(model=embedder_params["model"])
elif "gemini" in embedder_params["model"]: elif "gemini" in embedder_params["model"]:
try: try:
models_tokens["gemini"][embedder_params["model"]] models_tokens["gemini"][embedder_params["model"]]

View File

@ -143,5 +143,10 @@ models_tokens = {
"ernie-bot-2-base-en": 4096, "ernie-bot-2-base-en": 4096,
"ernie-bot-2-base-en-zh": 4096, "ernie-bot-2-base-en-zh": 4096,
"ernie-bot-2-base-zh-en": 4096 "ernie-bot-2-base-zh-en": 4096
} },
"fireworks": {
"llama-v2-7b": 4096,
"mixtral-8x7b-instruct": 4096,
"nomic-ai/nomic-embed-text-v1.5": 8192
},
} }

View File

@ -14,3 +14,4 @@ from .bedrock import Bedrock
from .anthropic import Anthropic from .anthropic import Anthropic
from .deepseek import DeepSeek from .deepseek import DeepSeek
from .oneapi import OneApi from .oneapi import OneApi
from .fireworks import Fireworks

View File

@ -0,0 +1,33 @@
"""
Fireworks Module
"""
from langchain_fireworks import ChatFireworks
class Fireworks(ChatFireworks):
    """
    Chat model wrapper built on top of `ChatFireworks`.

    Instead of individual keyword arguments, the constructor takes a single
    configuration dictionary and forwards every entry to `ChatFireworks`.

    Args:
        llm_config (dict): Configuration parameters for the LLM (required).
            Which keys are accepted is decided by the underlying
            `ChatFireworks` class; consult its documentation for details.

    NOTE(review): no validation happens in this class itself. Any error for
    missing or invalid keys would come from `ChatFireworks`; the exact
    exception type is not visible here and should be confirmed.
    """

    def __init__(self, llm_config: dict):
        """
        Create the model from a configuration dictionary.

        Args:
            llm_config (dict): Configuration parameters for the LLM. Each
                key/value pair is passed to `ChatFireworks.__init__` as a
                keyword argument.
        """
        # Unpack the dictionary so each entry becomes a keyword argument.
        super().__init__(**llm_config)