fix: abstract graph

Co-Authored-By: Federico Aguzzi <62149513+f-aguzzi@users.noreply.github.com>
This commit is contained in:
Marco Vinciguerra 2024-08-23 18:43:49 +02:00
parent 20410c9294
commit cf1fada36a
2 changed files with 9 additions and 20 deletions

View File

@ -63,13 +63,10 @@ class AbstractGraph(ABC):
self.cache_path = self.config.get("cache_path", False) self.cache_path = self.config.get("cache_path", False)
self.browser_base = self.config.get("browser_base") self.browser_base = self.config.get("browser_base")
# Create the graph
self.graph = self._create_graph() self.graph = self._create_graph()
self.final_state = None self.final_state = None
self.execution_info = None self.execution_info = None
# Set common configuration parameters
verbose = bool(config and config.get("verbose")) verbose = bool(config and config.get("verbose"))
if verbose: if verbose:
@ -87,12 +84,10 @@ class AbstractGraph(ABC):
self.set_common_params(common_params, overwrite=True) self.set_common_params(common_params, overwrite=True)
# set burr config
self.burr_kwargs = config.get("burr_kwargs", None) self.burr_kwargs = config.get("burr_kwargs", None)
if self.burr_kwargs is not None: if self.burr_kwargs is not None:
self.graph.use_burr = True self.graph.use_burr = True
if "app_instance_id" not in self.burr_kwargs: if "app_instance_id" not in self.burr_kwargs:
# set a random uuid for the app_instance_id to avoid conflicts
self.burr_kwargs["app_instance_id"] = str(uuid.uuid4()) self.burr_kwargs["app_instance_id"] = str(uuid.uuid4())
self.graph.burr_config = self.burr_kwargs self.graph.burr_config = self.burr_kwargs
@ -125,7 +120,6 @@ class AbstractGraph(ABC):
llm_defaults = {"temperature": 0, "streaming": False} llm_defaults = {"temperature": 0, "streaming": False}
llm_params = {**llm_defaults, **llm_config} llm_params = {**llm_defaults, **llm_config}
# If model instance is passed directly instead of the model details
if "model_instance" in llm_params: if "model_instance" in llm_params:
try: try:
self.model_token = llm_params["model_tokens"] self.model_token = llm_params["model_tokens"]
@ -145,18 +139,14 @@ class AbstractGraph(ABC):
warnings.simplefilter("ignore") warnings.simplefilter("ignore")
return init_chat_model(**llm_params) return init_chat_model(**llm_params)
known_models = ["chatgpt","gpt","openai", "azure_openai", "google_genai", known_models = {"chatgpt","gpt","openai", "azure_openai", "google_genai",
"ollama", "oneapi", "nvidia", "groq", "google_vertexai", "ollama", "oneapi", "nvidia", "groq", "google_vertexai",
"bedrock", "mistralai", "hugging_face", "deepseek", "ernie", "fireworks"] "bedrock", "mistralai", "hugging_face", "deepseek", "ernie", "fireworks"}
if llm_params["model"].split("/")[0] not in known_models and llm_params["model"].split("-")[0] not in known_models: if llm_params["model"].split("/")[0] not in known_models and llm_params["model"].split("-")[0] not in known_models:
raise ValueError(f"Model '{llm_params['model']}' is not supported") raise ValueError(f"Model '{llm_params['model']}' is not supported")
try: try:
if "azure" in llm_params["model"]:
model_name = llm_params["model"].split("/")[-1]
return handle_model(model_name, "azure_openai", model_name)
if "fireworks" in llm_params["model"]: if "fireworks" in llm_params["model"]:
model_name = "/".join(llm_params["model"].split("/")[1:]) model_name = "/".join(llm_params["model"].split("/")[1:])
token_key = llm_params["model"].split("/")[-1] token_key = llm_params["model"].split("/")[-1]
@ -207,7 +197,6 @@ class AbstractGraph(ABC):
return ErnieBotChat(llm_params) return ErnieBotChat(llm_params)
elif "oneapi" in llm_params["model"]: elif "oneapi" in llm_params["model"]:
# take the model after the last dash
llm_params["model"] = llm_params["model"].split("/")[-1] llm_params["model"] = llm_params["model"].split("/")[-1]
try: try:
self.model_token = models_tokens["oneapi"][llm_params["model"]] self.model_token = models_tokens["oneapi"][llm_params["model"]]

View File

@ -50,6 +50,11 @@ models_tokens = {
"gemini-1.5-flash-latest": 128000, "gemini-1.5-flash-latest": 128000,
"gemini-1.5-pro-latest": 128000, "gemini-1.5-pro-latest": 128000,
"models/embedding-001": 2048 "models/embedding-001": 2048
},
"google_vertexai": {
"gemini-1.5-flash": 128000,
"gemini-1.5-pro": 128000,
"gemini-1.0-pro": 128000,
}, },
"ollama": { "ollama": {
"command-r": 12800, "command-r": 12800,
@ -96,7 +101,7 @@ models_tokens = {
"oneapi": { "oneapi": {
"qwen-turbo": 6000, "qwen-turbo": 6000,
}, },
"nv dia": { "nvdia": {
"meta/llama3-70b-instruct": 419, "meta/llama3-70b-instruct": 419,
"meta/llama3-8b-instruct": 419, "meta/llama3-8b-instruct": 419,
"nemotron-4-340b-instruct": 1024, "nemotron-4-340b-instruct": 1024,
@ -132,11 +137,6 @@ models_tokens = {
"claude-3-haiku-20240307": 200000, "claude-3-haiku-20240307": 200000,
"claude-3-5-sonnet-20240620": 200000, "claude-3-5-sonnet-20240620": 200000,
}, },
"google_vertexai": {
"gemini-1.5-flash": 128000,
"gemini-1.5-pro": 128000,
"gemini-1.0-pro": 128000,
},
"bedrock": { "bedrock": {
"anthropic.claude-3-haiku-20240307-v1:0": 200000, "anthropic.claude-3-haiku-20240307-v1:0": 200000,
"anthropic.claude-3-sonnet-20240229-v1:0": 200000, "anthropic.claude-3-sonnet-20240229-v1:0": 200000,