mirror of
https://github.com/VinciGit00/Scrapegraph-ai.git
synced 2026-07-01 21:00:48 +08:00
fix: Refactor code to use CustomOpenAiCallbackManager for exclusive access to get_openai_callback
This commit is contained in:
parent
063dd1a2ab
commit
e657113ebc
@ -4,8 +4,8 @@ base_graph module
|
|||||||
import time
|
import time
|
||||||
import warnings
|
import warnings
|
||||||
from typing import Tuple
|
from typing import Tuple
|
||||||
from langchain_community.callbacks import get_openai_callback
|
|
||||||
from ..telemetry import log_graph_execution
|
from ..telemetry import log_graph_execution
|
||||||
|
from ..utils import CustomOpenAiCallbackManager
|
||||||
|
|
||||||
class BaseGraph:
|
class BaseGraph:
|
||||||
"""
|
"""
|
||||||
@ -52,6 +52,7 @@ class BaseGraph:
|
|||||||
self.entry_point = entry_point.node_name
|
self.entry_point = entry_point.node_name
|
||||||
self.graph_name = graph_name
|
self.graph_name = graph_name
|
||||||
self.initial_state = {}
|
self.initial_state = {}
|
||||||
|
self.callback_manager = CustomOpenAiCallbackManager()
|
||||||
|
|
||||||
if nodes[0].node_name != entry_point.node_name:
|
if nodes[0].node_name != entry_point.node_name:
|
||||||
# raise a warning if the entry point is not the first node in the list
|
# raise a warning if the entry point is not the first node in the list
|
||||||
@ -154,7 +155,7 @@ class BaseGraph:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
schema = None
|
schema = None
|
||||||
|
|
||||||
with get_openai_callback() as cb:
|
with self.callback_manager.exclusive_get_openai_callback() as cb:
|
||||||
try:
|
try:
|
||||||
result = current_node.execute(state)
|
result = current_node.execute(state)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -176,23 +177,24 @@ class BaseGraph:
|
|||||||
node_exec_time = time.time() - curr_time
|
node_exec_time = time.time() - curr_time
|
||||||
total_exec_time += node_exec_time
|
total_exec_time += node_exec_time
|
||||||
|
|
||||||
cb_data = {
|
if cb is not None:
|
||||||
"node_name": current_node.node_name,
|
cb_data = {
|
||||||
"total_tokens": cb.total_tokens,
|
"node_name": current_node.node_name,
|
||||||
"prompt_tokens": cb.prompt_tokens,
|
"total_tokens": cb.total_tokens,
|
||||||
"completion_tokens": cb.completion_tokens,
|
"prompt_tokens": cb.prompt_tokens,
|
||||||
"successful_requests": cb.successful_requests,
|
"completion_tokens": cb.completion_tokens,
|
||||||
"total_cost_USD": cb.total_cost,
|
"successful_requests": cb.successful_requests,
|
||||||
"exec_time": node_exec_time,
|
"total_cost_USD": cb.total_cost,
|
||||||
}
|
"exec_time": node_exec_time,
|
||||||
|
}
|
||||||
|
|
||||||
exec_info.append(cb_data)
|
exec_info.append(cb_data)
|
||||||
|
|
||||||
cb_total["total_tokens"] += cb_data["total_tokens"]
|
cb_total["total_tokens"] += cb_data["total_tokens"]
|
||||||
cb_total["prompt_tokens"] += cb_data["prompt_tokens"]
|
cb_total["prompt_tokens"] += cb_data["prompt_tokens"]
|
||||||
cb_total["completion_tokens"] += cb_data["completion_tokens"]
|
cb_total["completion_tokens"] += cb_data["completion_tokens"]
|
||||||
cb_total["successful_requests"] += cb_data["successful_requests"]
|
cb_total["successful_requests"] += cb_data["successful_requests"]
|
||||||
cb_total["total_cost_USD"] += cb_data["total_cost_USD"]
|
cb_total["total_cost_USD"] += cb_data["total_cost_USD"]
|
||||||
|
|
||||||
if current_node.node_type == "conditional_node":
|
if current_node.node_type == "conditional_node":
|
||||||
current_node_name = result
|
current_node_name = result
|
||||||
|
|||||||
@ -17,3 +17,4 @@ from .screenshot_scraping.screenshot_preparation import (take_screenshot,
|
|||||||
from .screenshot_scraping.text_detection import detect_text
|
from .screenshot_scraping.text_detection import detect_text
|
||||||
from .tokenizer import num_tokens_calculus
|
from .tokenizer import num_tokens_calculus
|
||||||
from .split_text_into_chunks import split_text_into_chunks
|
from .split_text_into_chunks import split_text_into_chunks
|
||||||
|
from .custom_openai_callback import CustomOpenAiCallbackManager
|
||||||
|
|||||||
17
scrapegraphai/utils/custom_openai_callback.py
Normal file
17
scrapegraphai/utils/custom_openai_callback.py
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
import threading
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from langchain_community.callbacks import get_openai_callback
|
||||||
|
|
||||||
|
class CustomOpenAiCallbackManager:
|
||||||
|
_lock = threading.Lock()
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def exclusive_get_openai_callback(self):
|
||||||
|
if CustomOpenAiCallbackManager._lock.acquire(blocking=False):
|
||||||
|
try:
|
||||||
|
with get_openai_callback() as cb:
|
||||||
|
yield cb
|
||||||
|
finally:
|
||||||
|
CustomOpenAiCallbackManager._lock.release()
|
||||||
|
else:
|
||||||
|
yield None
|
||||||
Loading…
Reference in New Issue
Block a user