mirror of
https://github.com/VinciGit00/Scrapegraph-ai.git
synced 2026-06-04 21:01:04 +08:00
feat: add grok integration
This commit is contained in:
parent
ec957a5828
commit
0c476a4a7b
@ -13,7 +13,7 @@ from langchain_core.rate_limiters import InMemoryRateLimiter
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ..helpers import models_tokens
|
||||
from ..models import CLoD, DeepSeek, OneApi
|
||||
from ..models import CLoD, DeepSeek, OneApi, XAI
|
||||
from ..utils.logging import set_verbosity_info, set_verbosity_warning
|
||||
|
||||
|
||||
@ -163,6 +163,7 @@ class AbstractGraph(ABC):
|
||||
"fireworks",
|
||||
"clod",
|
||||
"togetherai",
|
||||
"xai",
|
||||
}
|
||||
|
||||
if "/" in llm_params["model"]:
|
||||
@ -217,6 +218,7 @@ class AbstractGraph(ABC):
|
||||
"deepseek",
|
||||
"togetherai",
|
||||
"clod",
|
||||
"xai",
|
||||
}:
|
||||
if llm_params["model_provider"] == "bedrock":
|
||||
llm_params["model_kwargs"] = {
|
||||
@ -242,6 +244,9 @@ class AbstractGraph(ABC):
|
||||
elif model_provider == "oneapi":
|
||||
return OneApi(**llm_params)
|
||||
|
||||
elif model_provider == "xai":
|
||||
return XAI(**llm_params)
|
||||
|
||||
elif model_provider == "togetherai":
|
||||
try:
|
||||
from langchain_together import ChatTogether
|
||||
|
||||
@ -150,7 +150,7 @@ models_tokens = {
|
||||
"llama3-70b-8192": 8192,
|
||||
"mixtral-8x7b-32768": 32768,
|
||||
"gemma-7b-it": 8192,
|
||||
"claude-3-haiku-20240307'": 8192,
|
||||
"claude-3-haiku-20240307": 8192,
|
||||
},
|
||||
"toghetherai": {
|
||||
"meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo": 128000,
|
||||
@ -303,4 +303,7 @@ models_tokens = {
|
||||
"grok-2-latest": 128000,
|
||||
},
|
||||
"togetherai": {"Meta-Llama-3.1-70B-Instruct-Turbo": 128000},
|
||||
"xai": {
|
||||
"grok-1": 8192
|
||||
},
|
||||
}
|
||||
|
||||
@ -7,5 +7,6 @@ from .deepseek import DeepSeek
|
||||
from .oneapi import OneApi
|
||||
from .openai_itt import OpenAIImageToText
|
||||
from .openai_tts import OpenAITextToSpeech
|
||||
from .xai import XAI
|
||||
|
||||
__all__ = ["DeepSeek", "OneApi", "OpenAIImageToText", "OpenAITextToSpeech", "CLoD"]
|
||||
__all__ = ["DeepSeek", "OneApi", "OpenAIImageToText", "OpenAITextToSpeech", "CLoD", "XAI"]
|
||||
|
||||
23
scrapegraphai/models/xai.py
Normal file
23
scrapegraphai/models/xai.py
Normal file
@ -0,0 +1,23 @@
|
||||
"""
|
||||
xAI Grok Module
|
||||
"""
|
||||
from langchain_groq import ChatGroq as LangchainChatGroq
|
||||
|
||||
class XAI(LangchainChatGroq):
|
||||
"""
|
||||
Wrapper for the ChatGroq class from langchain_groq, for use with xAI models.
|
||||
Handles API key mapping from generic 'api_key' to 'groq_api_key' and
|
||||
maps 'model' to 'model_name'.
|
||||
|
||||
Args:
|
||||
llm_config (dict): Configuration parameters for the language model.
|
||||
"""
|
||||
|
||||
def __init__(self, **llm_config):
|
||||
if "api_key" in llm_config and "groq_api_key" not in llm_config:
|
||||
llm_config["groq_api_key"] = llm_config.pop("api_key")
|
||||
|
||||
if "model" in llm_config and "model_name" not in llm_config:
|
||||
llm_config["model_name"] = llm_config.pop("model")
|
||||
|
||||
super().__init__(**llm_config)
|
||||
Loading…
Reference in New Issue
Block a user