mirror of
https://github.com/VinciGit00/Scrapegraph-ai.git
synced 2026-06-28 21:01:55 +08:00
parent
50c9c6bd8c
commit
63a5d18486
@ -128,7 +128,7 @@ class AbstractGraph(ABC):
|
||||
return llm_params["model_instance"]
|
||||
|
||||
known_providers = {"openai", "azure_openai", "google_genai", "google_vertexai",
|
||||
"ollama", "oneapi", "nvidia", "groq", "anthropic" "bedrock", "mistralai",
|
||||
"ollama", "oneapi", "nvidia", "groq", "anthropic", "bedrock", "mistralai",
|
||||
"hugging_face", "deepseek", "ernie", "fireworks", "togetherai"}
|
||||
|
||||
split_model_provider = llm_params["model"].split("/", 1)
|
||||
@ -146,6 +146,8 @@ class AbstractGraph(ABC):
|
||||
|
||||
try:
|
||||
if llm_params["model_provider"] not in {"oneapi", "nvidia", "ernie", "deepseek", "togetherai"}:
|
||||
if llm_params["model_provider"] == "bedrock":
|
||||
llm_params["model_kwargs"] = { "temperature" : llm_params.pop("temperature") }
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
return init_chat_model(**llm_params)
|
||||
|
||||
@ -12,6 +12,7 @@ from scrapegraphai.models import OneApi, DeepSeek
|
||||
from langchain_openai import ChatOpenAI, AzureChatOpenAI
|
||||
from langchain_ollama import ChatOllama
|
||||
from langchain_google_genai import ChatGoogleGenerativeAI
|
||||
from langchain_aws import ChatBedrock
|
||||
|
||||
|
||||
|
||||
@ -71,6 +72,7 @@ class TestAbstractGraph:
|
||||
({"model": "ollama/llama2"}, ChatOllama),
|
||||
({"model": "oneapi/qwen-turbo", "api_key": "oneapi-api-key"}, OneApi),
|
||||
({"model": "deepseek/deepseek-coder", "api_key": "deepseek-api-key"}, DeepSeek),
|
||||
({"model": "bedrock/anthropic.claude-3-sonnet-20240229-v1:0", "region_name": "IDK"}, ChatBedrock),
|
||||
])
|
||||
|
||||
def test_create_llm(self, llm_config, expected_model):
|
||||
|
||||
Loading…
Reference in New Issue
Block a user