docs(gpt-4o-mini): added new gpt, fixed chromium lazy loading,

add documentation and metrics
This commit is contained in:
Marco Perini 2024-07-20 20:02:26 +02:00 committed by GitHub
parent b4b90b3c12
commit 99dc8497d8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 52 additions and 6 deletions

View File

@ -14,7 +14,7 @@ Some interesting ones are:
- `burr_kwargs`: A dictionary with additional parameters to enable `Burr` graphical user interface.
- `max_images`: The maximum number of images to be analyzed. Useful in `OmniScraperGraph` and `OmniSearchGraph`.
- `cache_path`: The path where the cache files will be saved. If it already exists, the cache will be loaded from this path.
- `additional_info`: Additional text to add to the default prompts defined in the graphs.
.. _Burr:
Burr Integration

View File

@ -27,8 +27,13 @@ Additionally, the following properties are collected:
"llm_model": llm_model_name,
"embedder_model": embedder_model_name,
"source_type": source_type,
"source": source,
"execution_time": execution_time,
"prompt": prompt,
"schema": schema,
"error_node": error_node_name,
"exception": exception,
"response": response,
"total_tokens": total_tokens,
}

View File

@ -82,7 +82,7 @@ class ChromiumLoader(BaseLoader):
context = await browser.new_context()
await Malenia.apply_stealth(context)
page = await context.new_page()
await page.goto(url)
await page.goto(url, wait_until="domcontentloaded")
await page.wait_for_load_state(self.load_state)
results = await page.content() # Simply get the HTML content
logger.info("Content scraped")

View File

@ -106,18 +106,32 @@ class BaseGraph:
source_type = None
llm_model = None
embedder_model = None
source = []
prompt = None
schema = None
while current_node_name:
curr_time = time.time()
current_node = next(node for node in self.nodes if node.node_name == current_node_name)
# check if there is a "source" key in the node config
if current_node.__class__.__name__ == "FetchNode":
# get the second key name of the state dictionary
source_type = list(state.keys())[1]
if state.get("user_prompt", None):
prompt = state["user_prompt"] if type(state["user_prompt"]) == str else None
# quick fix for local_dir source type
if source_type == "local_dir":
source_type = "html_dir"
elif source_type == "url":
if type(state[source_type]) == list:
# iterate through the list of urls and see if they are strings
for url in state[source_type]:
if type(url) == str:
source.append(url)
elif type(state[source_type]) == str:
source.append(state[source_type])
# check if there is an "llm_model" variable in the class
if hasattr(current_node, "llm_model") and llm_model is None:
@ -135,6 +149,16 @@ class BaseGraph:
elif hasattr(embedder_model, "model"):
embedder_model = embedder_model.model
if hasattr(current_node, "node_config"):
if type(current_node.node_config) is dict:
if current_node.node_config.get("schema", None) and schema is None:
if type(current_node.node_config["schema"]) is not dict:
# convert to dict
try:
schema = current_node.node_config["schema"].schema()
except Exception as e:
schema = None
with get_openai_callback() as cb:
try:
result = current_node.execute(state)
@ -144,11 +168,15 @@ class BaseGraph:
graph_execution_time = time.time() - start_time
log_graph_execution(
graph_name=self.graph_name,
source=source,
prompt=prompt,
schema=schema,
llm_model=llm_model,
embedder_model=embedder_model,
source_type=source_type,
execution_time=graph_execution_time,
error_node=error_node
error_node=error_node,
exception=str(e)
)
raise e
node_exec_time = time.time() - curr_time
@ -191,11 +219,16 @@ class BaseGraph:
# Log the graph execution telemetry
graph_execution_time = time.time() - start_time
response = state.get("answer", None) if source_type == "url" else None
log_graph_execution(
graph_name=self.graph_name,
source=source,
prompt=prompt,
schema=schema,
llm_model=llm_model,
embedder_model=embedder_model,
source_type=source_type,
response=response,
execution_time=graph_execution_time,
total_tokens=cb_total["total_tokens"] if cb_total["total_tokens"] > 0 else None,
)

View File

@ -16,9 +16,10 @@ models_tokens = {
"gpt-4-32k": 32768,
"gpt-4-32k-0613": 32768,
"gpt-4o": 128000,
"gpt-4o-mini":128000,
},
"azure": {
"gpt-3.5-turbo-0125": 16385,
"gpt-3.5-turbo-0125": 16385,
"gpt-3.5": 4096,
"gpt-3.5-turbo": 16385,
"gpt-3.5-turbo-1106": 16385,
@ -34,6 +35,7 @@ models_tokens = {
"gpt-4-32k": 32768,
"gpt-4-32k-0613": 32768,
"gpt-4o": 128000,
"gpt-4o-mini":128000,
},
"gemini": {
"gemini-pro": 128000,

View File

@ -126,7 +126,8 @@ class GraphIteratorNode(BaseNode):
for url in urls:
instance = copy.copy(graph_instance)
instance.source = url
if url.startswith("http"):
instance.input_key = "url"
participants.append(instance)
futures = [_async_run(graph) for graph in participants]

View File

@ -156,14 +156,19 @@ def log_event(event: str, properties: Dict[str, any]):
send_event_json(event_json)
def log_graph_execution(graph_name: str, llm_model: str, embedder_model: str, source_type: str, execution_time: float, error_node: str = None, total_tokens: int = None):
def log_graph_execution(
    graph_name: str,
    source: str,
    prompt: str,
    schema: dict,
    llm_model: str,
    embedder_model: str,
    source_type: str,
    execution_time: float,
    response: dict = None,
    error_node: str = None,
    exception: str = None,
    total_tokens: int = None,
):
    """Report one graph run as a ``graph_execution`` telemetry event.

    Every argument is forwarded unchanged under a property key of the same
    name; the resulting mapping is handed to ``log_event``.

    Args:
        graph_name: Name of the graph that was executed.
        source: Source(s) the graph was run against.
            NOTE(review): annotated as ``str``, but the graph executor
            appears to pass a list of URL strings -- confirm and adjust.
        prompt: User prompt given to the graph.
        schema: Output schema associated with the run.
        llm_model: Name of the LLM model in use.
        embedder_model: Name of the embedder model in use.
        source_type: Kind of source that was processed.
        execution_time: Duration of the run as measured by the caller.
        response: Answer produced by the graph. Defaults to ``None``.
        error_node: Name of the node that failed. Defaults to ``None``.
        exception: String form of the raised exception. Defaults to ``None``.
        total_tokens: Token count for the run. Defaults to ``None``.
    """
    # Key order is kept as-is so the emitted payload layout does not change.
    event_properties = {
        "graph_name": graph_name,
        "source": source,
        "prompt": prompt,
        "schema": schema,
        "llm_model": llm_model,
        "embedder_model": embedder_model,
        "source_type": source_type,
        "response": response,
        "execution_time": execution_time,
        "error_node": error_node,
        "exception": exception,
        "total_tokens": total_tokens,
    }
    log_event("graph_execution", event_properties)