fix(AbstractGraph): Bedrock init issues

Closes #633
This commit is contained in:
Federico Aguzzi 2024-09-05 10:19:47 +02:00
parent 50c9c6bd8c
commit 63a5d18486
2 changed files with 5 additions and 1 deletions

View File

@ -128,7 +128,7 @@ class AbstractGraph(ABC):
return llm_params["model_instance"]
known_providers = {"openai", "azure_openai", "google_genai", "google_vertexai",
"ollama", "oneapi", "nvidia", "groq", "anthropic" "bedrock", "mistralai",
"ollama", "oneapi", "nvidia", "groq", "anthropic", "bedrock", "mistralai",
"hugging_face", "deepseek", "ernie", "fireworks", "togetherai"}
split_model_provider = llm_params["model"].split("/", 1)
@ -146,6 +146,8 @@ class AbstractGraph(ABC):
try:
if llm_params["model_provider"] not in {"oneapi", "nvidia", "ernie", "deepseek", "togetherai"}:
if llm_params["model_provider"] == "bedrock":
llm_params["model_kwargs"] = { "temperature" : llm_params.pop("temperature") }
with warnings.catch_warnings():
warnings.simplefilter("ignore")
return init_chat_model(**llm_params)

View File

@ -12,6 +12,7 @@ from scrapegraphai.models import OneApi, DeepSeek
from langchain_openai import ChatOpenAI, AzureChatOpenAI
from langchain_ollama import ChatOllama
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_aws import ChatBedrock
@ -71,6 +72,7 @@ class TestAbstractGraph:
({"model": "ollama/llama2"}, ChatOllama),
({"model": "oneapi/qwen-turbo", "api_key": "oneapi-api-key"}, OneApi),
({"model": "deepseek/deepseek-coder", "api_key": "deepseek-api-key"}, DeepSeek),
({"model": "bedrock/anthropic.claude-3-sonnet-20240229-v1:0", "region_name": "IDK"}, ChatBedrock),
])
def test_create_llm(self, llm_config, expected_model):