Scrapegraph-ai/scrapegraphai/utils/custom_callback.py
2025-01-06 15:10:35 +01:00

158 lines
5.3 KiB
Python

"""
Custom callback for LLM token usage statistics.
This module has been taken and modified from the OpenAI callback manager in langchain-community.
https://github.com/langchain-ai/langchain/blob/master/libs/community/langchain_community/callbacks/openai_info.py
"""
import threading
from contextlib import contextmanager
from contextvars import ContextVar
from typing import Any, Dict, List, Optional
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.messages import AIMessage
from langchain_core.outputs import ChatGeneration, LLMResult
from langchain_core.tracers.context import register_configure_hook
from .model_costs import MODEL_COST_PER_1K_TOKENS_INPUT, MODEL_COST_PER_1K_TOKENS_OUTPUT
def get_token_cost_for_model(
    model_name: str, num_tokens: int, is_completion: bool = False
) -> float:
    """
    Get the cost in USD for a given model and number of tokens.

    Args:
        model_name: Name of the model.
        num_tokens: Number of tokens.
        is_completion: Whether the tokens are completion (output) tokens
            rather than prompt (input) tokens. Defaults to False.

    Returns:
        Cost in USD. 0.0 when the model has no entry in the input price
        table, or when a completion price is requested and the model has no
        entry in the output price table.
    """
    # Models without an input price are treated as unpriced altogether.
    if model_name not in MODEL_COST_PER_1K_TOKENS_INPUT:
        return 0.0

    if is_completion:
        # The output table is not guaranteed to mirror the input table; a
        # missing entry must not raise KeyError from inside a stats callback.
        rate_per_1k = MODEL_COST_PER_1K_TOKENS_OUTPUT.get(model_name, 0.0)
    else:
        rate_per_1k = MODEL_COST_PER_1K_TOKENS_INPUT[model_name]

    return rate_per_1k * (num_tokens / 1000)
class CustomCallbackHandler(BaseCallbackHandler):
    """Callback handler that accumulates token usage and cost for LLM calls.

    The counters below start at zero and are updated by ``on_llm_end`` while
    holding an internal lock.
    """

    # Running totals across the requests this handler has observed.
    total_tokens: int = 0
    prompt_tokens: int = 0
    completion_tokens: int = 0
    successful_requests: int = 0
    total_cost: float = 0.0

    def __init__(self, llm_model_name: str) -> None:
        """Create a handler pricing usage for ``llm_model_name``.

        A falsy model name is stored as ``"unknown"``.
        """
        super().__init__()
        self._lock = threading.Lock()
        self.model_name = llm_model_name or "unknown"

    def __repr__(self) -> str:
        """Return a multi-line, human-readable summary of the counters."""
        summary = (
            f"Tokens Used: {self.total_tokens}\n"
            f"\tPrompt Tokens: {self.prompt_tokens}\n"
            f"\tCompletion Tokens: {self.completion_tokens}\n"
            f"Successful Requests: {self.successful_requests}\n"
            f"Total Cost (USD): ${self.total_cost}"
        )
        return summary

    @property
    def always_verbose(self) -> bool:
        """Whether to call verbose callbacks even if verbose is False."""
        return True

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """Do nothing when an LLM call starts."""

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Do nothing when a new token is streamed."""

    @staticmethod
    def _read_usage_metadata(response: LLMResult) -> Any:
        """Return the usage metadata of the first generation, or None.

        None is returned when there is no first generation, when it is not a
        ``ChatGeneration``, when its message is not an ``AIMessage``, or when
        reading the message/metadata raises ``AttributeError``.
        """
        try:
            first = response.generations[0][0]
        except IndexError:
            return None
        if not isinstance(first, ChatGeneration):
            return None
        try:
            msg = first.message
            return msg.usage_metadata if isinstance(msg, AIMessage) else None
        except AttributeError:
            return None

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Record token counts and cost for one finished LLM call.

        Usage metadata on the first generation's message is preferred; when
        absent, ``response.llm_output["token_usage"]`` is used instead. With
        no ``llm_output`` nothing is recorded; with an ``llm_output`` lacking
        ``"token_usage"`` only the request counter is incremented.
        """
        usage = self._read_usage_metadata(response)
        if usage:
            # Usage reported on the message itself (langchain-core >= 0.2.2).
            totals = {"total_tokens": usage["total_tokens"]}
            n_completion = usage["output_tokens"]
            n_prompt = usage["input_tokens"]
        else:
            llm_output = response.llm_output
            if llm_output is None:
                return
            if "token_usage" not in llm_output:
                with self._lock:
                    self.successful_requests += 1
                return
            totals = llm_output["token_usage"]
            n_completion = totals.get("completion_tokens", 0)
            n_prompt = totals.get("prompt_tokens", 0)

        # Price this request; models absent from the input table cost nothing.
        cost_of_prompt = 0
        cost_of_completion = 0
        if self.model_name in MODEL_COST_PER_1K_TOKENS_INPUT:
            cost_of_completion = get_token_cost_for_model(
                self.model_name, n_completion, is_completion=True
            )
            cost_of_prompt = get_token_cost_for_model(self.model_name, n_prompt)

        # All counter updates happen together under the lock.
        with self._lock:
            self.total_cost += cost_of_prompt + cost_of_completion
            self.total_tokens += totals.get("total_tokens", 0)
            self.prompt_tokens += n_prompt
            self.completion_tokens += n_completion
            self.successful_requests += 1

    def __copy__(self) -> "CustomCallbackHandler":
        """Return this same instance instead of a shallow copy."""
        return self

    def __deepcopy__(self, memo: Any) -> "CustomCallbackHandler":
        """Return this same instance instead of a deep copy."""
        return self
# Context variable holding the CustomCallbackHandler for the current context;
# defaults to None when no handler has been set.
custom_callback: ContextVar[Optional[CustomCallbackHandler]] = ContextVar(
    "custom_callback", default=None
)
# Register the context variable with langchain-core's configure hook.
# NOTE(review): the second positional argument is presumably the
# `inheritable` flag of register_configure_hook — confirm against the
# langchain-core version in use.
register_configure_hook(custom_callback, True)
@contextmanager
def get_custom_callback(llm_model_name: str):
    """
    Context manager providing a callback for LLM token usage statistics.

    A new ``CustomCallbackHandler`` is created for ``llm_model_name`` and
    stored in the ``custom_callback`` context variable for the duration of
    the ``with`` block.

    Args:
        llm_model_name: Name of the model whose usage is being tracked.

    Yields:
        The ``CustomCallbackHandler`` collecting the statistics.
    """
    cb = CustomCallbackHandler(llm_model_name)
    custom_callback.set(cb)
    try:
        yield cb
    finally:
        # Clear the variable even when the with-body raises, so a failed
        # block does not leave a stale handler installed.
        custom_callback.set(None)