mirror of
https://github.com/VinciGit00/Scrapegraph-ai.git
synced 2026-07-01 21:00:48 +08:00
fix: Refactor code to use CustomOpenAiCallbackManager for exclusive access to get_openai_callback
This commit is contained in:
parent
063dd1a2ab
commit
e657113ebc
@ -4,8 +4,8 @@ base_graph module
|
||||
import time
|
||||
import warnings
|
||||
from typing import Tuple
|
||||
from langchain_community.callbacks import get_openai_callback
|
||||
from ..telemetry import log_graph_execution
|
||||
from ..utils import CustomOpenAiCallbackManager
|
||||
|
||||
class BaseGraph:
|
||||
"""
|
||||
@ -52,6 +52,7 @@ class BaseGraph:
|
||||
self.entry_point = entry_point.node_name
|
||||
self.graph_name = graph_name
|
||||
self.initial_state = {}
|
||||
self.callback_manager = CustomOpenAiCallbackManager()
|
||||
|
||||
if nodes[0].node_name != entry_point.node_name:
|
||||
# raise a warning if the entry point is not the first node in the list
|
||||
@ -154,7 +155,7 @@ class BaseGraph:
|
||||
except Exception as e:
|
||||
schema = None
|
||||
|
||||
with get_openai_callback() as cb:
|
||||
with self.callback_manager.exclusive_get_openai_callback() as cb:
|
||||
try:
|
||||
result = current_node.execute(state)
|
||||
except Exception as e:
|
||||
@ -176,23 +177,24 @@ class BaseGraph:
|
||||
node_exec_time = time.time() - curr_time
|
||||
total_exec_time += node_exec_time
|
||||
|
||||
cb_data = {
|
||||
"node_name": current_node.node_name,
|
||||
"total_tokens": cb.total_tokens,
|
||||
"prompt_tokens": cb.prompt_tokens,
|
||||
"completion_tokens": cb.completion_tokens,
|
||||
"successful_requests": cb.successful_requests,
|
||||
"total_cost_USD": cb.total_cost,
|
||||
"exec_time": node_exec_time,
|
||||
}
|
||||
if cb is not None:
|
||||
cb_data = {
|
||||
"node_name": current_node.node_name,
|
||||
"total_tokens": cb.total_tokens,
|
||||
"prompt_tokens": cb.prompt_tokens,
|
||||
"completion_tokens": cb.completion_tokens,
|
||||
"successful_requests": cb.successful_requests,
|
||||
"total_cost_USD": cb.total_cost,
|
||||
"exec_time": node_exec_time,
|
||||
}
|
||||
|
||||
exec_info.append(cb_data)
|
||||
exec_info.append(cb_data)
|
||||
|
||||
cb_total["total_tokens"] += cb_data["total_tokens"]
|
||||
cb_total["prompt_tokens"] += cb_data["prompt_tokens"]
|
||||
cb_total["completion_tokens"] += cb_data["completion_tokens"]
|
||||
cb_total["successful_requests"] += cb_data["successful_requests"]
|
||||
cb_total["total_cost_USD"] += cb_data["total_cost_USD"]
|
||||
cb_total["total_tokens"] += cb_data["total_tokens"]
|
||||
cb_total["prompt_tokens"] += cb_data["prompt_tokens"]
|
||||
cb_total["completion_tokens"] += cb_data["completion_tokens"]
|
||||
cb_total["successful_requests"] += cb_data["successful_requests"]
|
||||
cb_total["total_cost_USD"] += cb_data["total_cost_USD"]
|
||||
|
||||
if current_node.node_type == "conditional_node":
|
||||
current_node_name = result
|
||||
|
||||
@ -17,3 +17,4 @@ from .screenshot_scraping.screenshot_preparation import (take_screenshot,
|
||||
from .screenshot_scraping.text_detection import detect_text
|
||||
from .tokenizer import num_tokens_calculus
|
||||
from .split_text_into_chunks import split_text_into_chunks
|
||||
from .custom_openai_callback import CustomOpenAiCallbackManager
|
||||
|
||||
17
scrapegraphai/utils/custom_openai_callback.py
Normal file
17
scrapegraphai/utils/custom_openai_callback.py
Normal file
@ -0,0 +1,17 @@
|
||||
import threading
|
||||
from contextlib import contextmanager
|
||||
from langchain_community.callbacks import get_openai_callback
|
||||
|
||||
class CustomOpenAiCallbackManager:
|
||||
_lock = threading.Lock()
|
||||
|
||||
@contextmanager
|
||||
def exclusive_get_openai_callback(self):
|
||||
if CustomOpenAiCallbackManager._lock.acquire(blocking=False):
|
||||
try:
|
||||
with get_openai_callback() as cb:
|
||||
yield cb
|
||||
finally:
|
||||
CustomOpenAiCallbackManager._lock.release()
|
||||
else:
|
||||
yield None
|
||||
Loading…
Reference in New Issue
Block a user