mirror of
https://github.com/VinciGit00/Scrapegraph-ai.git
synced 2026-07-01 21:00:48 +08:00
feat: add refactoring of default temperature
This commit is contained in:
parent
b470d974cf
commit
6c3b37ab00
@ -14,7 +14,6 @@ graph_config = {
|
||||
"format": "json", # Ollama needs the format to be specified explicitly
|
||||
# "base_url": "http://localhost:11434", # set ollama URL arbitrarily
|
||||
},
|
||||
|
||||
"verbose": True,
|
||||
"headless": False
|
||||
}
|
||||
|
||||
@ -53,6 +53,9 @@ class AbstractGraph(ABC):
|
||||
def __init__(self, prompt: str, config: dict,
|
||||
source: Optional[str] = None, schema: Optional[BaseModel] = None):
|
||||
|
||||
if config.get("llm").get("temperature") is None:
|
||||
config["llm"]["temperature"] = 0
|
||||
|
||||
self.prompt = prompt
|
||||
self.source = source
|
||||
self.config = config
|
||||
@ -212,7 +215,7 @@ class AbstractGraph(ABC):
|
||||
print("model not found, using default token size (8192)")
|
||||
self.model_token = 8192
|
||||
return ErnieBotChat(llm_params)
|
||||
|
||||
|
||||
if "oneapi" in llm_params["model"]:
|
||||
# take the model after the last dash
|
||||
llm_params["model"] = llm_params["model"].split("/")[-1]
|
||||
@ -221,7 +224,7 @@ class AbstractGraph(ABC):
|
||||
except KeyError as exc:
|
||||
raise KeyError("Model not supported") from exc
|
||||
return OneApi(llm_params)
|
||||
|
||||
|
||||
if "nvidia" in llm_params["model"]:
|
||||
try:
|
||||
self.model_token = models_tokens["nvidia"][llm_params["model"].split("/")[-1]]
|
||||
|
||||
Loading…
Reference in New Issue
Block a user