feat: add refactoring of default temperature

This commit is contained in:
Marco Vinciguerra 2024-08-10 11:51:37 +02:00
parent b470d974cf
commit 6c3b37ab00
2 changed files with 5 additions and 3 deletions

View File

@ -14,7 +14,6 @@ graph_config = {
"format": "json", # Ollama needs the format to be specified explicitly
# "base_url": "http://localhost:11434", # set ollama URL arbitrarily
},
"verbose": True,
"headless": False
}

View File

@ -53,6 +53,9 @@ class AbstractGraph(ABC):
def __init__(self, prompt: str, config: dict,
source: Optional[str] = None, schema: Optional[BaseModel] = None):
if config.get("llm").get("temperature") is None:
config["llm"]["temperature"] = 0
self.prompt = prompt
self.source = source
self.config = config
@ -212,7 +215,7 @@ class AbstractGraph(ABC):
print("model not found, using default token size (8192)")
self.model_token = 8192
return ErnieBotChat(llm_params)
if "oneapi" in llm_params["model"]:
# take the model after the last dash
llm_params["model"] = llm_params["model"].split("/")[-1]
@ -221,7 +224,7 @@ class AbstractGraph(ABC):
except KeyError as exc:
raise KeyError("Model not supported") from exc
return OneApi(llm_params)
if "nvidia" in llm_params["model"]:
try:
self.model_token = models_tokens["nvidia"][llm_params["model"].split("/")[-1]]