feat: add fireworks integration

This commit is contained in:
Marco Vinciguerra 2024-06-24 23:11:28 +02:00
parent 79a2f51c34
commit df0e310829
10 changed files with 149 additions and 8 deletions

View File

@ -0,0 +1 @@
FIREWORKS_APIKEY="your fireworks api key"

View File

@ -0,0 +1,52 @@
"""
Basic example of scraping pipeline using SmartScraper
"""
import os, json
from dotenv import load_dotenv
from scrapegraphai.graphs import SmartScraperGraph
from scrapegraphai.utils import prettify_exec_info
load_dotenv()
# ************************************************
# Define the configuration for the graph
# ************************************************
fireworks_api_key = os.getenv("FIREWORKS_APIKEY")
graph_config = {
"llm": {
"api_key": fireworks_api_key,
"model": "fireworks/accounts/fireworks/models/mixtral-8x7b-instruct"
},
"embeddings": {
"model": "ollama/nomic-embed-text",
"temperature": 0,
# "base_url": "http://localhost:11434", # set ollama URL arbitrarily
},
"verbose": True,
"headless": False,
}
# ************************************************
# Create the SmartScraperGraph instance and run it
# ************************************************
smart_scraper_graph = SmartScraperGraph(
prompt="List me all the projects with their description",
# also accepts a string with the already downloaded HTML code
source="https://perinim.github.io/projects/",
config=graph_config,
)
result = smart_scraper_graph.run()
print(json.dumps(result, indent=4))
# ************************************************
# Get graph execution info
# ************************************************
graph_exec_info = smart_scraper_graph.get_execution_info()
print(prettify_exec_info(graph_exec_info))

View File

@ -33,6 +33,7 @@ dependencies = [
"google==3.0.0",
"undetected-playwright==0.3.0",
"semchunk==1.0.1",
"langchain-fireworks==0.1.3"
]
license = "MIT"

View File

@ -13,6 +13,7 @@ aiofiles==23.2.1
aiohttp==3.9.5
# via langchain
# via langchain-community
# via langchain-fireworks
aiosignal==1.3.1
# via aiohttp
alabaster==0.7.16
@ -93,6 +94,8 @@ fastapi-pagination==0.12.24
# via burr
filelock==3.14.0
# via huggingface-hub
fireworks-ai==0.14.0
# via langchain-fireworks
fonttools==4.52.1
# via matplotlib
free-proxy==1.1.1
@ -158,8 +161,11 @@ httptools==0.6.1
httpx==0.27.0
# via anthropic
# via fastapi
# via fireworks-ai
# via groq
# via openai
httpx-sse==0.4.0
# via fireworks-ai
huggingface-hub==0.23.1
# via tokenizers
idna==3.7
@ -207,10 +213,13 @@ langchain-core==0.1.52
# via langchain-anthropic
# via langchain-aws
# via langchain-community
# via langchain-fireworks
# via langchain-google-genai
# via langchain-groq
# via langchain-openai
# via langchain-text-splitters
langchain-fireworks==0.1.3
# via scrapegraphai
langchain-google-genai==1.0.3
# via scrapegraphai
langchain-groq==0.1.3
@ -259,6 +268,7 @@ numpy==1.26.4
# via streamlit
openai==1.30.3
# via burr
# via langchain-fireworks
# via langchain-openai
orjson==3.10.3
# via fastapi
@ -278,6 +288,7 @@ pandas==2.2.2
# via sf-hamilton
# via streamlit
pillow==10.3.0
# via fireworks-ai
# via matplotlib
# via streamlit
playwright==1.43.0
@ -308,6 +319,7 @@ pydantic==2.7.1
# via burr
# via fastapi
# via fastapi-pagination
# via fireworks-ai
# via google-generativeai
# via groq
# via langchain
@ -359,6 +371,7 @@ requests==2.32.2
# via huggingface-hub
# via langchain
# via langchain-community
# via langchain-fireworks
# via langsmith
# via sphinx
# via streamlit

View File

@ -11,6 +11,7 @@
aiohttp==3.9.5
# via langchain
# via langchain-community
# via langchain-fireworks
aiosignal==1.3.1
# via aiohttp
annotated-types==0.7.0
@ -53,6 +54,8 @@ faiss-cpu==1.8.0
# via scrapegraphai
filelock==3.14.0
# via huggingface-hub
fireworks-ai==0.14.0
# via langchain-fireworks
free-proxy==1.1.1
# via scrapegraphai
frozenlist==1.4.1
@ -105,8 +108,11 @@ httplib2==0.22.0
# via google-auth-httplib2
httpx==0.27.0
# via anthropic
# via fireworks-ai
# via groq
# via openai
httpx-sse==0.4.0
# via fireworks-ai
huggingface-hub==0.23.1
# via tokenizers
idna==3.7
@ -137,10 +143,13 @@ langchain-core==0.1.52
# via langchain-anthropic
# via langchain-aws
# via langchain-community
# via langchain-fireworks
# via langchain-google-genai
# via langchain-groq
# via langchain-openai
# via langchain-text-splitters
langchain-fireworks==0.1.3
# via scrapegraphai
langchain-google-genai==1.0.3
# via scrapegraphai
langchain-groq==0.1.3
@ -171,6 +180,7 @@ numpy==1.26.4
# via langchain-community
# via pandas
openai==1.30.3
# via langchain-fireworks
# via langchain-openai
orjson==3.10.3
# via langsmith
@ -180,6 +190,8 @@ packaging==23.2
# via marshmallow
pandas==2.2.2
# via scrapegraphai
pillow==10.3.0
# via fireworks-ai
playwright==1.43.0
# via scrapegraphai
# via undetected-playwright
@ -200,6 +212,7 @@ pyasn1-modules==0.4.0
# via google-auth
pydantic==2.7.1
# via anthropic
# via fireworks-ai
# via google-generativeai
# via groq
# via langchain
@ -232,6 +245,7 @@ requests==2.32.2
# via huggingface-hub
# via langchain
# via langchain-community
# via langchain-fireworks
# via langsmith
# via tiktoken
rsa==4.9

View File

@ -17,4 +17,5 @@ langchain-groq==0.1.3
playwright==1.43.0
langchain-aws==0.1.2
undetected-playwright==0.3.0
semchunk==1.0.1
semchunk==1.0.1
langchain-fireworks==0.1.3

View File

@ -11,6 +11,7 @@ from langchain_aws import BedrockEmbeddings
from langchain_community.embeddings import HuggingFaceHubEmbeddings, OllamaEmbeddings
from langchain_google_genai import GoogleGenerativeAIEmbeddings
from langchain_google_genai.embeddings import GoogleGenerativeAIEmbeddings
from langchain_fireworks import FireworksEmbeddings
from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings
from ..helpers import models_tokens
@ -23,7 +24,8 @@ from ..models import (
HuggingFace,
Ollama,
OpenAI,
OneApi
OneApi,
Fireworks
)
from ..models.ernie import Ernie
from ..utils.logging import set_verbosity_debug, set_verbosity_warning, set_verbosity_info
@ -102,7 +104,7 @@ class AbstractGraph(ABC):
"embedder_model": self.embedder_model,
"cache_path": self.cache_path,
}
self.set_common_params(common_params, overwrite=True)
# set burr config
@ -125,7 +127,7 @@ class AbstractGraph(ABC):
for node in self.graph.nodes:
node.update_config(params, overwrite)
def _create_llm(self, llm_config: dict, chat=False) -> object:
"""
Create a large language model instance based on the configuration provided.
@ -160,8 +162,15 @@ class AbstractGraph(ABC):
try:
self.model_token = models_tokens["oneapi"][llm_params["model"]]
except KeyError as exc:
raise KeyError("Model Model not supported") from exc
raise KeyError("Model not supported") from exc
return OneApi(llm_params)
elif "fireworks" in llm_params["model"]:
try:
self.model_token = models_tokens["fireworks"][llm_params["model"].split("/")[-1]]
llm_params["model"] = "/".join(llm_params["model"].split("/")[1:])
except KeyError as exc:
raise KeyError("Model not supported") from exc
return Fireworks(llm_params)
elif "azure" in llm_params["model"]:
# take the model after the last dash
llm_params["model"] = llm_params["model"].split("/")[-1]
@ -172,12 +181,14 @@ class AbstractGraph(ABC):
return AzureOpenAI(llm_params)
elif "gemini" in llm_params["model"]:
llm_params["model"] = llm_params["model"].split("/")[-1]
try:
self.model_token = models_tokens["gemini"][llm_params["model"]]
except KeyError as exc:
raise KeyError("Model not supported") from exc
return Gemini(llm_params)
elif llm_params["model"].startswith("claude"):
llm_params["model"] = llm_params["model"].split("/")[-1]
try:
self.model_token = models_tokens["claude"][llm_params["model"]]
except KeyError as exc:
@ -203,6 +214,7 @@ class AbstractGraph(ABC):
return Ollama(llm_params)
elif "hugging_face" in llm_params["model"]:
llm_params["model"] = llm_params["model"].split("/")[-1]
try:
self.model_token = models_tokens["hugging_face"][llm_params["model"]]
except KeyError:
@ -277,12 +289,13 @@ class AbstractGraph(ABC):
if isinstance(self.llm_model, OpenAI):
return OpenAIEmbeddings(api_key=self.llm_model.openai_api_key, base_url=self.llm_model.openai_api_base)
elif isinstance(self.llm_model, DeepSeek):
return OpenAIEmbeddings(api_key=self.llm_model.openai_api_key)
return OpenAIEmbeddings(api_key=self.llm_model.openai_api_key)
elif isinstance(self.llm_model, AzureOpenAIEmbeddings):
return self.llm_model
elif isinstance(self.llm_model, AzureOpenAI):
return AzureOpenAIEmbeddings()
elif isinstance(self.llm_model, Fireworks):
return FireworksEmbeddings(model=self.llm_model.model_name)
elif isinstance(self.llm_model, Ollama):
# unwrap the kwargs from the model whihc is a dict
params = self.llm_model._lc_kwargs
@ -333,6 +346,13 @@ class AbstractGraph(ABC):
except KeyError as exc:
raise KeyError("Model not supported") from exc
return HuggingFaceHubEmbeddings(model=embedder_params["model"])
elif "fireworks" in embedder_params["model"]:
embedder_params["model"] = "/".join(embedder_params["model"].split("/")[1:])
try:
models_tokens["fireworks"][embedder_params["model"]]
except KeyError as exc:
raise KeyError("Model not supported") from exc
return FireworksEmbeddings(model=embedder_params["model"])
elif "gemini" in embedder_params["model"]:
try:
models_tokens["gemini"][embedder_params["model"]]

View File

@ -143,5 +143,10 @@ models_tokens = {
"ernie-bot-2-base-en": 4096,
"ernie-bot-2-base-en-zh": 4096,
"ernie-bot-2-base-zh-en": 4096
}
},
"fireworks": {
"llama-v2-7b": 4096,
"mixtral-8x7b-instruct": 4096,
"nomic-ai/nomic-embed-text-v1.5": 8192
},
}

View File

@ -14,3 +14,4 @@ from .bedrock import Bedrock
from .anthropic import Anthropic
from .deepseek import DeepSeek
from .oneapi import OneApi
from .fireworks import Fireworks

View File

@ -0,0 +1,33 @@
"""
Fireworks Module
"""
from langchain_fireworks import ChatFireworks
class Fireworks(ChatFireworks):
"""
Initializes the Fireworks class.
Args:
llm_config (dict): A dictionary containing configuration parameters for the LLM (required).
The specific keys and values will depend on the LLM implementation
used by the underlying `ChatFireworks` class. Consult its documentation
for details.
Raises:
ValueError: If required keys are missing from the llm_config dictionary.
"""
def __init__(self, llm_config: dict):
"""
Initializes the Fireworks class.
Args:
llm_config (dict): A dictionary containing configuration parameters for the LLM.
The specific keys and values will depend on the LLM implementation.
Raises:
ValueError: If required keys are missing from the llm_config dictionary.
"""
super().__init__(**llm_config)