feat: add grok integration

This commit is contained in:
Marco Vinciguerra 2025-05-30 14:25:24 +02:00
parent ec957a5828
commit 0c476a4a7b
4 changed files with 35 additions and 3 deletions

View File

@ -13,7 +13,7 @@ from langchain_core.rate_limiters import InMemoryRateLimiter
from pydantic import BaseModel
from ..helpers import models_tokens
from ..models import CLoD, DeepSeek, OneApi
from ..models import CLoD, DeepSeek, OneApi, XAI
from ..utils.logging import set_verbosity_info, set_verbosity_warning
@ -163,6 +163,7 @@ class AbstractGraph(ABC):
"fireworks",
"clod",
"togetherai",
"xai",
}
if "/" in llm_params["model"]:
@ -217,6 +218,7 @@ class AbstractGraph(ABC):
"deepseek",
"togetherai",
"clod",
"xai",
}:
if llm_params["model_provider"] == "bedrock":
llm_params["model_kwargs"] = {
@ -242,6 +244,9 @@ class AbstractGraph(ABC):
elif model_provider == "oneapi":
return OneApi(**llm_params)
elif model_provider == "xai":
return XAI(**llm_params)
elif model_provider == "togetherai":
try:
from langchain_together import ChatTogether

View File

@ -150,7 +150,7 @@ models_tokens = {
"llama3-70b-8192": 8192,
"mixtral-8x7b-32768": 32768,
"gemma-7b-it": 8192,
"claude-3-haiku-20240307'": 8192,
"claude-3-haiku-20240307": 8192,
},
"toghetherai": {
"meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo": 128000,
@ -303,4 +303,7 @@ models_tokens = {
"grok-2-latest": 128000,
},
"togetherai": {"Meta-Llama-3.1-70B-Instruct-Turbo": 128000},
"xai": {
"grok-1": 8192
},
}

View File

@ -7,5 +7,6 @@ from .deepseek import DeepSeek
from .oneapi import OneApi
from .openai_itt import OpenAIImageToText
from .openai_tts import OpenAITextToSpeech
from .xai import XAI
__all__ = ["DeepSeek", "OneApi", "OpenAIImageToText", "OpenAITextToSpeech", "CLoD"]
__all__ = ["DeepSeek", "OneApi", "OpenAIImageToText", "OpenAITextToSpeech", "CLoD", "XAI"]

View File

@ -0,0 +1,23 @@
"""
xAI Grok Module
"""
from langchain_groq import ChatGroq as LangchainChatGroq
class XAI(LangchainChatGroq):
"""
Wrapper for the ChatGroq class from langchain_groq, for use with xAI models.
Handles API key mapping from generic 'api_key' to 'groq_api_key' and
maps 'model' to 'model_name'.
Args:
llm_config (dict): Configuration parameters for the language model.
"""
def __init__(self, **llm_config):
if "api_key" in llm_config and "groq_api_key" not in llm_config:
llm_config["groq_api_key"] = llm_config.pop("api_key")
if "model" in llm_config and "model_name" not in llm_config:
llm_config["model_name"] = llm_config.pop("model")
super().__init__(**llm_config)