mirror of
https://github.com/VinciGit00/Scrapegraph-ai.git
synced 2026-07-01 21:00:48 +08:00
feat: add fireworks integration
This commit is contained in:
parent
79a2f51c34
commit
df0e310829
1
examples/fireworks/.env.example
Normal file
1
examples/fireworks/.env.example
Normal file
@ -0,0 +1 @@
|
||||
FIREWORKS_APIKEY="your fireworks api key"
|
||||
52
examples/fireworks/smart_scraper_fireworks.py
Normal file
52
examples/fireworks/smart_scraper_fireworks.py
Normal file
@ -0,0 +1,52 @@
|
||||
"""
|
||||
Basic example of scraping pipeline using SmartScraper
|
||||
"""
|
||||
|
||||
import os, json
|
||||
from dotenv import load_dotenv
|
||||
from scrapegraphai.graphs import SmartScraperGraph
|
||||
from scrapegraphai.utils import prettify_exec_info
|
||||
|
||||
load_dotenv()
|
||||
|
||||
|
||||
# ************************************************
|
||||
# Define the configuration for the graph
|
||||
# ************************************************
|
||||
|
||||
fireworks_api_key = os.getenv("FIREWORKS_APIKEY")
|
||||
|
||||
graph_config = {
|
||||
"llm": {
|
||||
"api_key": fireworks_api_key,
|
||||
"model": "fireworks/accounts/fireworks/models/mixtral-8x7b-instruct"
|
||||
},
|
||||
"embeddings": {
|
||||
"model": "ollama/nomic-embed-text",
|
||||
"temperature": 0,
|
||||
# "base_url": "http://localhost:11434", # set ollama URL arbitrarily
|
||||
},
|
||||
"verbose": True,
|
||||
"headless": False,
|
||||
}
|
||||
|
||||
# ************************************************
|
||||
# Create the SmartScraperGraph instance and run it
|
||||
# ************************************************
|
||||
|
||||
smart_scraper_graph = SmartScraperGraph(
|
||||
prompt="List me all the projects with their description",
|
||||
# also accepts a string with the already downloaded HTML code
|
||||
source="https://perinim.github.io/projects/",
|
||||
config=graph_config,
|
||||
)
|
||||
|
||||
result = smart_scraper_graph.run()
|
||||
print(json.dumps(result, indent=4))
|
||||
|
||||
# ************************************************
|
||||
# Get graph execution info
|
||||
# ************************************************
|
||||
|
||||
graph_exec_info = smart_scraper_graph.get_execution_info()
|
||||
print(prettify_exec_info(graph_exec_info))
|
||||
@ -33,6 +33,7 @@ dependencies = [
|
||||
"google==3.0.0",
|
||||
"undetected-playwright==0.3.0",
|
||||
"semchunk==1.0.1",
|
||||
"langchain-fireworks==0.1.3"
|
||||
]
|
||||
|
||||
license = "MIT"
|
||||
|
||||
@ -13,6 +13,7 @@ aiofiles==23.2.1
|
||||
aiohttp==3.9.5
|
||||
# via langchain
|
||||
# via langchain-community
|
||||
# via langchain-fireworks
|
||||
aiosignal==1.3.1
|
||||
# via aiohttp
|
||||
alabaster==0.7.16
|
||||
@ -93,6 +94,8 @@ fastapi-pagination==0.12.24
|
||||
# via burr
|
||||
filelock==3.14.0
|
||||
# via huggingface-hub
|
||||
fireworks-ai==0.14.0
|
||||
# via langchain-fireworks
|
||||
fonttools==4.52.1
|
||||
# via matplotlib
|
||||
free-proxy==1.1.1
|
||||
@ -158,8 +161,11 @@ httptools==0.6.1
|
||||
httpx==0.27.0
|
||||
# via anthropic
|
||||
# via fastapi
|
||||
# via fireworks-ai
|
||||
# via groq
|
||||
# via openai
|
||||
httpx-sse==0.4.0
|
||||
# via fireworks-ai
|
||||
huggingface-hub==0.23.1
|
||||
# via tokenizers
|
||||
idna==3.7
|
||||
@ -207,10 +213,13 @@ langchain-core==0.1.52
|
||||
# via langchain-anthropic
|
||||
# via langchain-aws
|
||||
# via langchain-community
|
||||
# via langchain-fireworks
|
||||
# via langchain-google-genai
|
||||
# via langchain-groq
|
||||
# via langchain-openai
|
||||
# via langchain-text-splitters
|
||||
langchain-fireworks==0.1.3
|
||||
# via scrapegraphai
|
||||
langchain-google-genai==1.0.3
|
||||
# via scrapegraphai
|
||||
langchain-groq==0.1.3
|
||||
@ -259,6 +268,7 @@ numpy==1.26.4
|
||||
# via streamlit
|
||||
openai==1.30.3
|
||||
# via burr
|
||||
# via langchain-fireworks
|
||||
# via langchain-openai
|
||||
orjson==3.10.3
|
||||
# via fastapi
|
||||
@ -278,6 +288,7 @@ pandas==2.2.2
|
||||
# via sf-hamilton
|
||||
# via streamlit
|
||||
pillow==10.3.0
|
||||
# via fireworks-ai
|
||||
# via matplotlib
|
||||
# via streamlit
|
||||
playwright==1.43.0
|
||||
@ -308,6 +319,7 @@ pydantic==2.7.1
|
||||
# via burr
|
||||
# via fastapi
|
||||
# via fastapi-pagination
|
||||
# via fireworks-ai
|
||||
# via google-generativeai
|
||||
# via groq
|
||||
# via langchain
|
||||
@ -359,6 +371,7 @@ requests==2.32.2
|
||||
# via huggingface-hub
|
||||
# via langchain
|
||||
# via langchain-community
|
||||
# via langchain-fireworks
|
||||
# via langsmith
|
||||
# via sphinx
|
||||
# via streamlit
|
||||
|
||||
@ -11,6 +11,7 @@
|
||||
aiohttp==3.9.5
|
||||
# via langchain
|
||||
# via langchain-community
|
||||
# via langchain-fireworks
|
||||
aiosignal==1.3.1
|
||||
# via aiohttp
|
||||
annotated-types==0.7.0
|
||||
@ -53,6 +54,8 @@ faiss-cpu==1.8.0
|
||||
# via scrapegraphai
|
||||
filelock==3.14.0
|
||||
# via huggingface-hub
|
||||
fireworks-ai==0.14.0
|
||||
# via langchain-fireworks
|
||||
free-proxy==1.1.1
|
||||
# via scrapegraphai
|
||||
frozenlist==1.4.1
|
||||
@ -105,8 +108,11 @@ httplib2==0.22.0
|
||||
# via google-auth-httplib2
|
||||
httpx==0.27.0
|
||||
# via anthropic
|
||||
# via fireworks-ai
|
||||
# via groq
|
||||
# via openai
|
||||
httpx-sse==0.4.0
|
||||
# via fireworks-ai
|
||||
huggingface-hub==0.23.1
|
||||
# via tokenizers
|
||||
idna==3.7
|
||||
@ -137,10 +143,13 @@ langchain-core==0.1.52
|
||||
# via langchain-anthropic
|
||||
# via langchain-aws
|
||||
# via langchain-community
|
||||
# via langchain-fireworks
|
||||
# via langchain-google-genai
|
||||
# via langchain-groq
|
||||
# via langchain-openai
|
||||
# via langchain-text-splitters
|
||||
langchain-fireworks==0.1.3
|
||||
# via scrapegraphai
|
||||
langchain-google-genai==1.0.3
|
||||
# via scrapegraphai
|
||||
langchain-groq==0.1.3
|
||||
@ -171,6 +180,7 @@ numpy==1.26.4
|
||||
# via langchain-community
|
||||
# via pandas
|
||||
openai==1.30.3
|
||||
# via langchain-fireworks
|
||||
# via langchain-openai
|
||||
orjson==3.10.3
|
||||
# via langsmith
|
||||
@ -180,6 +190,8 @@ packaging==23.2
|
||||
# via marshmallow
|
||||
pandas==2.2.2
|
||||
# via scrapegraphai
|
||||
pillow==10.3.0
|
||||
# via fireworks-ai
|
||||
playwright==1.43.0
|
||||
# via scrapegraphai
|
||||
# via undetected-playwright
|
||||
@ -200,6 +212,7 @@ pyasn1-modules==0.4.0
|
||||
# via google-auth
|
||||
pydantic==2.7.1
|
||||
# via anthropic
|
||||
# via fireworks-ai
|
||||
# via google-generativeai
|
||||
# via groq
|
||||
# via langchain
|
||||
@ -232,6 +245,7 @@ requests==2.32.2
|
||||
# via huggingface-hub
|
||||
# via langchain
|
||||
# via langchain-community
|
||||
# via langchain-fireworks
|
||||
# via langsmith
|
||||
# via tiktoken
|
||||
rsa==4.9
|
||||
|
||||
@ -17,4 +17,5 @@ langchain-groq==0.1.3
|
||||
playwright==1.43.0
|
||||
langchain-aws==0.1.2
|
||||
undetected-playwright==0.3.0
|
||||
semchunk==1.0.1
|
||||
semchunk==1.0.1
|
||||
langchain-fireworks==0.1.3
|
||||
|
||||
@ -11,6 +11,7 @@ from langchain_aws import BedrockEmbeddings
|
||||
from langchain_community.embeddings import HuggingFaceHubEmbeddings, OllamaEmbeddings
|
||||
from langchain_google_genai import GoogleGenerativeAIEmbeddings
|
||||
from langchain_google_genai.embeddings import GoogleGenerativeAIEmbeddings
|
||||
from langchain_fireworks import FireworksEmbeddings
|
||||
from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings
|
||||
|
||||
from ..helpers import models_tokens
|
||||
@ -23,7 +24,8 @@ from ..models import (
|
||||
HuggingFace,
|
||||
Ollama,
|
||||
OpenAI,
|
||||
OneApi
|
||||
OneApi,
|
||||
Fireworks
|
||||
)
|
||||
from ..models.ernie import Ernie
|
||||
from ..utils.logging import set_verbosity_debug, set_verbosity_warning, set_verbosity_info
|
||||
@ -102,7 +104,7 @@ class AbstractGraph(ABC):
|
||||
"embedder_model": self.embedder_model,
|
||||
"cache_path": self.cache_path,
|
||||
}
|
||||
|
||||
|
||||
self.set_common_params(common_params, overwrite=True)
|
||||
|
||||
# set burr config
|
||||
@ -125,7 +127,7 @@ class AbstractGraph(ABC):
|
||||
|
||||
for node in self.graph.nodes:
|
||||
node.update_config(params, overwrite)
|
||||
|
||||
|
||||
def _create_llm(self, llm_config: dict, chat=False) -> object:
|
||||
"""
|
||||
Create a large language model instance based on the configuration provided.
|
||||
@ -160,8 +162,15 @@ class AbstractGraph(ABC):
|
||||
try:
|
||||
self.model_token = models_tokens["oneapi"][llm_params["model"]]
|
||||
except KeyError as exc:
|
||||
raise KeyError("Model Model not supported") from exc
|
||||
raise KeyError("Model not supported") from exc
|
||||
return OneApi(llm_params)
|
||||
elif "fireworks" in llm_params["model"]:
|
||||
try:
|
||||
self.model_token = models_tokens["fireworks"][llm_params["model"].split("/")[-1]]
|
||||
llm_params["model"] = "/".join(llm_params["model"].split("/")[1:])
|
||||
except KeyError as exc:
|
||||
raise KeyError("Model not supported") from exc
|
||||
return Fireworks(llm_params)
|
||||
elif "azure" in llm_params["model"]:
|
||||
# take the model after the last dash
|
||||
llm_params["model"] = llm_params["model"].split("/")[-1]
|
||||
@ -172,12 +181,14 @@ class AbstractGraph(ABC):
|
||||
return AzureOpenAI(llm_params)
|
||||
|
||||
elif "gemini" in llm_params["model"]:
|
||||
llm_params["model"] = llm_params["model"].split("/")[-1]
|
||||
try:
|
||||
self.model_token = models_tokens["gemini"][llm_params["model"]]
|
||||
except KeyError as exc:
|
||||
raise KeyError("Model not supported") from exc
|
||||
return Gemini(llm_params)
|
||||
elif llm_params["model"].startswith("claude"):
|
||||
llm_params["model"] = llm_params["model"].split("/")[-1]
|
||||
try:
|
||||
self.model_token = models_tokens["claude"][llm_params["model"]]
|
||||
except KeyError as exc:
|
||||
@ -203,6 +214,7 @@ class AbstractGraph(ABC):
|
||||
|
||||
return Ollama(llm_params)
|
||||
elif "hugging_face" in llm_params["model"]:
|
||||
llm_params["model"] = llm_params["model"].split("/")[-1]
|
||||
try:
|
||||
self.model_token = models_tokens["hugging_face"][llm_params["model"]]
|
||||
except KeyError:
|
||||
@ -277,12 +289,13 @@ class AbstractGraph(ABC):
|
||||
if isinstance(self.llm_model, OpenAI):
|
||||
return OpenAIEmbeddings(api_key=self.llm_model.openai_api_key, base_url=self.llm_model.openai_api_base)
|
||||
elif isinstance(self.llm_model, DeepSeek):
|
||||
return OpenAIEmbeddings(api_key=self.llm_model.openai_api_key)
|
||||
|
||||
return OpenAIEmbeddings(api_key=self.llm_model.openai_api_key)
|
||||
elif isinstance(self.llm_model, AzureOpenAIEmbeddings):
|
||||
return self.llm_model
|
||||
elif isinstance(self.llm_model, AzureOpenAI):
|
||||
return AzureOpenAIEmbeddings()
|
||||
elif isinstance(self.llm_model, Fireworks):
|
||||
return FireworksEmbeddings(model=self.llm_model.model_name)
|
||||
elif isinstance(self.llm_model, Ollama):
|
||||
# unwrap the kwargs from the model whihc is a dict
|
||||
params = self.llm_model._lc_kwargs
|
||||
@ -333,6 +346,13 @@ class AbstractGraph(ABC):
|
||||
except KeyError as exc:
|
||||
raise KeyError("Model not supported") from exc
|
||||
return HuggingFaceHubEmbeddings(model=embedder_params["model"])
|
||||
elif "fireworks" in embedder_params["model"]:
|
||||
embedder_params["model"] = "/".join(embedder_params["model"].split("/")[1:])
|
||||
try:
|
||||
models_tokens["fireworks"][embedder_params["model"]]
|
||||
except KeyError as exc:
|
||||
raise KeyError("Model not supported") from exc
|
||||
return FireworksEmbeddings(model=embedder_params["model"])
|
||||
elif "gemini" in embedder_params["model"]:
|
||||
try:
|
||||
models_tokens["gemini"][embedder_params["model"]]
|
||||
|
||||
@ -143,5 +143,10 @@ models_tokens = {
|
||||
"ernie-bot-2-base-en": 4096,
|
||||
"ernie-bot-2-base-en-zh": 4096,
|
||||
"ernie-bot-2-base-zh-en": 4096
|
||||
}
|
||||
},
|
||||
"fireworks": {
|
||||
"llama-v2-7b": 4096,
|
||||
"mixtral-8x7b-instruct": 4096,
|
||||
"nomic-ai/nomic-embed-text-v1.5": 8192
|
||||
},
|
||||
}
|
||||
|
||||
@ -14,3 +14,4 @@ from .bedrock import Bedrock
|
||||
from .anthropic import Anthropic
|
||||
from .deepseek import DeepSeek
|
||||
from .oneapi import OneApi
|
||||
from .fireworks import Fireworks
|
||||
|
||||
33
scrapegraphai/models/fireworks.py
Normal file
33
scrapegraphai/models/fireworks.py
Normal file
@ -0,0 +1,33 @@
|
||||
"""
|
||||
Fireworks Module
|
||||
"""
|
||||
from langchain_fireworks import ChatFireworks
|
||||
|
||||
|
||||
class Fireworks(ChatFireworks):
|
||||
"""
|
||||
Initializes the Fireworks class.
|
||||
|
||||
Args:
|
||||
llm_config (dict): A dictionary containing configuration parameters for the LLM (required).
|
||||
The specific keys and values will depend on the LLM implementation
|
||||
used by the underlying `ChatFireworks` class. Consult its documentation
|
||||
for details.
|
||||
|
||||
Raises:
|
||||
ValueError: If required keys are missing from the llm_config dictionary.
|
||||
"""
|
||||
|
||||
def __init__(self, llm_config: dict):
|
||||
"""
|
||||
Initializes the Fireworks class.
|
||||
|
||||
Args:
|
||||
llm_config (dict): A dictionary containing configuration parameters for the LLM.
|
||||
The specific keys and values will depend on the LLM implementation.
|
||||
|
||||
Raises:
|
||||
ValueError: If required keys are missing from the llm_config dictionary.
|
||||
"""
|
||||
|
||||
super().__init__(**llm_config)
|
||||
Loading…
Reference in New Issue
Block a user