feat: update exception

Co-Authored-By: Federico Aguzzi <62149513+f-aguzzi@users.noreply.github.com>
This commit is contained in:
Marco Vinciguerra 2024-09-24 09:35:33 +02:00
parent 3b5ee767cb
commit 3876cb7be8

View File

@ -154,12 +154,13 @@ class AbstractGraph(ABC):
try: try:
self.model_token = models_tokens[llm_params["model_provider"]][llm_params["model"]] self.model_token = models_tokens[llm_params["model_provider"]][llm_params["model"]]
except KeyError: except KeyError:
print(f"""Model {llm_params['model_provider']}/{llm_params['model']} not found, print(f"""Model {llm_params['model_provider']}/{llm_params['model']} not found,
using default token size (8192)""") using default token size (8192)""")
self.model_token = 8192 self.model_token = 8192
try: try:
if llm_params["model_provider"] not in {"oneapi","nvidia","ernie","deepseek","togetherai"}: if llm_params["model_provider"] not in \
{"oneapi","nvidia","ernie","deepseek","togetherai"}:
if llm_params["model_provider"] == "bedrock": if llm_params["model_provider"] == "bedrock":
llm_params["model_kwargs"] = { "temperature" : llm_params.pop("temperature") } llm_params["model_kwargs"] = { "temperature" : llm_params.pop("temperature") }
with warnings.catch_warnings(): with warnings.catch_warnings():
@ -195,7 +196,7 @@ class AbstractGraph(ABC):
return ChatNVIDIA(**llm_params) return ChatNVIDIA(**llm_params)
except Exception as e: except Exception as e:
print(f"Error instancing model: {e}") raise Exception(f"Error instancing model: {e}")
def get_state(self, key=None) -> dict: def get_state(self, key=None) -> dict: