fix: Refactor code to use CustomOpenAiCallbackManager for exclusive access to get_openai_callback

This commit is contained in:
Lorenzo Paleari 2024-09-14 02:06:52 +02:00
parent 063dd1a2ab
commit e657113ebc
No known key found for this signature in database
GPG Key ID: 010F47E3CB681DED
3 changed files with 37 additions and 17 deletions

View File

@ -4,8 +4,8 @@ base_graph module
import time
import warnings
from typing import Tuple
from langchain_community.callbacks import get_openai_callback
from ..telemetry import log_graph_execution
from ..utils import CustomOpenAiCallbackManager
class BaseGraph:
"""
@ -52,6 +52,7 @@ class BaseGraph:
self.entry_point = entry_point.node_name
self.graph_name = graph_name
self.initial_state = {}
self.callback_manager = CustomOpenAiCallbackManager()
if nodes[0].node_name != entry_point.node_name:
# raise a warning if the entry point is not the first node in the list
@ -154,7 +155,7 @@ class BaseGraph:
except Exception as e:
schema = None
with get_openai_callback() as cb:
with self.callback_manager.exclusive_get_openai_callback() as cb:
try:
result = current_node.execute(state)
except Exception as e:
@ -176,23 +177,24 @@ class BaseGraph:
node_exec_time = time.time() - curr_time
total_exec_time += node_exec_time
cb_data = {
"node_name": current_node.node_name,
"total_tokens": cb.total_tokens,
"prompt_tokens": cb.prompt_tokens,
"completion_tokens": cb.completion_tokens,
"successful_requests": cb.successful_requests,
"total_cost_USD": cb.total_cost,
"exec_time": node_exec_time,
}
if cb is not None:
cb_data = {
"node_name": current_node.node_name,
"total_tokens": cb.total_tokens,
"prompt_tokens": cb.prompt_tokens,
"completion_tokens": cb.completion_tokens,
"successful_requests": cb.successful_requests,
"total_cost_USD": cb.total_cost,
"exec_time": node_exec_time,
}
exec_info.append(cb_data)
exec_info.append(cb_data)
cb_total["total_tokens"] += cb_data["total_tokens"]
cb_total["prompt_tokens"] += cb_data["prompt_tokens"]
cb_total["completion_tokens"] += cb_data["completion_tokens"]
cb_total["successful_requests"] += cb_data["successful_requests"]
cb_total["total_cost_USD"] += cb_data["total_cost_USD"]
cb_total["total_tokens"] += cb_data["total_tokens"]
cb_total["prompt_tokens"] += cb_data["prompt_tokens"]
cb_total["completion_tokens"] += cb_data["completion_tokens"]
cb_total["successful_requests"] += cb_data["successful_requests"]
cb_total["total_cost_USD"] += cb_data["total_cost_USD"]
if current_node.node_type == "conditional_node":
current_node_name = result

View File

@ -17,3 +17,4 @@ from .screenshot_scraping.screenshot_preparation import (take_screenshot,
from .screenshot_scraping.text_detection import detect_text
from .tokenizer import num_tokens_calculus
from .split_text_into_chunks import split_text_into_chunks
from .custom_openai_callback import CustomOpenAiCallbackManager

View File

@ -0,0 +1,17 @@
import threading
from contextlib import contextmanager
from langchain_community.callbacks import get_openai_callback
class CustomOpenAiCallbackManager:
_lock = threading.Lock()
@contextmanager
def exclusive_get_openai_callback(self):
if CustomOpenAiCallbackManager._lock.acquire(blocking=False):
try:
with get_openai_callback() as cb:
yield cb
finally:
CustomOpenAiCallbackManager._lock.release()
else:
yield None