mirror of
https://github.com/VinciGit00/Scrapegraph-ai.git
synced 2026-06-28 21:01:55 +08:00
Merge pull request #680 from LorenzoPaleari/exec-info-enhanced
feat: added Bedrock and Mistral to exec info
This commit is contained in:
commit
95a5ee2d35
@ -5,7 +5,7 @@ import time
|
||||
import warnings
|
||||
from typing import Tuple
|
||||
from ..telemetry import log_graph_execution
|
||||
from ..utils import CustomOpenAiCallbackManager
|
||||
from ..utils import CustomLLMCallbackManager
|
||||
|
||||
class BaseGraph:
|
||||
"""
|
||||
@ -52,7 +52,7 @@ class BaseGraph:
|
||||
self.entry_point = entry_point.node_name
|
||||
self.graph_name = graph_name
|
||||
self.initial_state = {}
|
||||
self.callback_manager = CustomOpenAiCallbackManager()
|
||||
self.callback_manager = CustomLLMCallbackManager()
|
||||
|
||||
if nodes[0].node_name != entry_point.node_name:
|
||||
# raise a warning if the entry point is not the first node in the list
|
||||
@ -108,6 +108,7 @@ class BaseGraph:
|
||||
error_node = None
|
||||
source_type = None
|
||||
llm_model = None
|
||||
llm_model_name = None
|
||||
embedder_model = None
|
||||
source = []
|
||||
prompt = None
|
||||
@ -135,9 +136,11 @@ class BaseGraph:
|
||||
if hasattr(current_node, "llm_model") and llm_model is None:
|
||||
llm_model = current_node.llm_model
|
||||
if hasattr(llm_model, "model_name"):
|
||||
llm_model = llm_model.model_name
|
||||
llm_model_name = llm_model.model_name
|
||||
elif hasattr(llm_model, "model"):
|
||||
llm_model = llm_model.model
|
||||
llm_model_name = llm_model.model
|
||||
elif hasattr(llm_model, "model_id"):
|
||||
llm_model_name = llm_model.model_id
|
||||
|
||||
if hasattr(current_node, "embedder_model") and embedder_model is None:
|
||||
embedder_model = current_node.embedder_model
|
||||
@ -155,7 +158,7 @@ class BaseGraph:
|
||||
except Exception as e:
|
||||
schema = None
|
||||
|
||||
with self.callback_manager.exclusive_get_openai_callback() as cb:
|
||||
with self.callback_manager.exclusive_get_callback(llm_model, llm_model_name) as cb:
|
||||
try:
|
||||
result = current_node.execute(state)
|
||||
except Exception as e:
|
||||
@ -166,7 +169,7 @@ class BaseGraph:
|
||||
source=source,
|
||||
prompt=prompt,
|
||||
schema=schema,
|
||||
llm_model=llm_model,
|
||||
llm_model=llm_model_name,
|
||||
embedder_model=embedder_model,
|
||||
source_type=source_type,
|
||||
execution_time=graph_execution_time,
|
||||
@ -222,7 +225,7 @@ class BaseGraph:
|
||||
source=source,
|
||||
prompt=prompt,
|
||||
schema=schema,
|
||||
llm_model=llm_model,
|
||||
llm_model=llm_model_name,
|
||||
embedder_model=embedder_model,
|
||||
source_type=source_type,
|
||||
content=content,
|
||||
|
||||
@ -17,4 +17,4 @@ from .screenshot_scraping.screenshot_preparation import (take_screenshot,
|
||||
from .screenshot_scraping.text_detection import detect_text
|
||||
from .tokenizer import num_tokens_calculus
|
||||
from .split_text_into_chunks import split_text_into_chunks
|
||||
from .custom_openai_callback import CustomOpenAiCallbackManager
|
||||
from .llm_callback_manager import CustomLLMCallbackManager
|
||||
|
||||
157
scrapegraphai/utils/custom_callback.py
Normal file
157
scrapegraphai/utils/custom_callback.py
Normal file
@ -0,0 +1,157 @@
|
||||
"""
|
||||
Custom callback for LLM token usage statistics.
|
||||
|
||||
This module has been taken and modified from the OpenAI callback manager in langchian-community.
|
||||
https://github.com/langchain-ai/langchain/blob/master/libs/community/langchain_community/callbacks/openai_info.py
|
||||
"""
|
||||
from contextlib import contextmanager
|
||||
import threading
|
||||
from typing import Any, Dict, List, Optional
|
||||
from contextvars import ContextVar
|
||||
|
||||
from langchain_core.callbacks import BaseCallbackHandler
|
||||
from langchain_core.messages import AIMessage
|
||||
from langchain_core.outputs import ChatGeneration, LLMResult
|
||||
from langchain_core.tracers.context import register_configure_hook
|
||||
|
||||
from .model_costs import MODEL_COST_PER_1K_TOKENS_INPUT, MODEL_COST_PER_1K_TOKENS_OUTPUT
|
||||
|
||||
|
||||
def get_token_cost_for_model(
|
||||
model_name: str, num_tokens: int, is_completion: bool = False
|
||||
) -> float:
|
||||
"""
|
||||
Get the cost in USD for a given model and number of tokens.
|
||||
|
||||
Args:
|
||||
model_name: Name of the model
|
||||
num_tokens: Number of tokens.
|
||||
is_completion: Whether the model is used for completion or not.
|
||||
Defaults to False.
|
||||
|
||||
Returns:
|
||||
Cost in USD.
|
||||
"""
|
||||
if model_name not in MODEL_COST_PER_1K_TOKENS_INPUT:
|
||||
return 0.0
|
||||
if is_completion:
|
||||
return MODEL_COST_PER_1K_TOKENS_OUTPUT[model_name] * (num_tokens / 1000)
|
||||
|
||||
return MODEL_COST_PER_1K_TOKENS_INPUT[model_name] * (num_tokens / 1000)
|
||||
|
||||
|
||||
class CustomCallbackHandler(BaseCallbackHandler):
|
||||
"""Callback Handler that tracks LLMs info."""
|
||||
|
||||
total_tokens: int = 0
|
||||
prompt_tokens: int = 0
|
||||
completion_tokens: int = 0
|
||||
successful_requests: int = 0
|
||||
total_cost: float = 0.0
|
||||
|
||||
def __init__(self, llm_model_name: str) -> None:
|
||||
super().__init__()
|
||||
self._lock = threading.Lock()
|
||||
self.model_name = llm_model_name if llm_model_name else "unknown"
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"Tokens Used: {self.total_tokens}\n"
|
||||
f"\tPrompt Tokens: {self.prompt_tokens}\n"
|
||||
f"\tCompletion Tokens: {self.completion_tokens}\n"
|
||||
f"Successful Requests: {self.successful_requests}\n"
|
||||
f"Total Cost (USD): ${self.total_cost}"
|
||||
)
|
||||
|
||||
@property
|
||||
def always_verbose(self) -> bool:
|
||||
"""Whether to call verbose callbacks even if verbose is False."""
|
||||
return True
|
||||
|
||||
def on_llm_start(
|
||||
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
|
||||
) -> None:
|
||||
"""Print out the prompts."""
|
||||
pass
|
||||
|
||||
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
|
||||
"""Print out the token."""
|
||||
pass
|
||||
|
||||
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
|
||||
"""Collect token usage."""
|
||||
# Check for usage_metadata (langchain-core >= 0.2.2)
|
||||
try:
|
||||
generation = response.generations[0][0]
|
||||
except IndexError:
|
||||
generation = None
|
||||
if isinstance(generation, ChatGeneration):
|
||||
try:
|
||||
message = generation.message
|
||||
if isinstance(message, AIMessage):
|
||||
usage_metadata = message.usage_metadata
|
||||
else:
|
||||
usage_metadata = None
|
||||
except AttributeError:
|
||||
usage_metadata = None
|
||||
else:
|
||||
usage_metadata = None
|
||||
if usage_metadata:
|
||||
token_usage = {"total_tokens": usage_metadata["total_tokens"]}
|
||||
completion_tokens = usage_metadata["output_tokens"]
|
||||
prompt_tokens = usage_metadata["input_tokens"]
|
||||
|
||||
|
||||
else:
|
||||
if response.llm_output is None:
|
||||
return None
|
||||
|
||||
if "token_usage" not in response.llm_output:
|
||||
with self._lock:
|
||||
self.successful_requests += 1
|
||||
return None
|
||||
|
||||
# compute tokens and cost for this request
|
||||
token_usage = response.llm_output["token_usage"]
|
||||
completion_tokens = token_usage.get("completion_tokens", 0)
|
||||
prompt_tokens = token_usage.get("prompt_tokens", 0)
|
||||
if self.model_name in MODEL_COST_PER_1K_TOKENS_INPUT:
|
||||
completion_cost = get_token_cost_for_model(
|
||||
self.model_name, completion_tokens, is_completion=True
|
||||
)
|
||||
prompt_cost = get_token_cost_for_model(self.model_name, prompt_tokens)
|
||||
else:
|
||||
completion_cost = 0
|
||||
prompt_cost = 0
|
||||
|
||||
# update shared state behind lock
|
||||
with self._lock:
|
||||
self.total_cost += prompt_cost + completion_cost
|
||||
self.total_tokens += token_usage.get("total_tokens", 0)
|
||||
self.prompt_tokens += prompt_tokens
|
||||
self.completion_tokens += completion_tokens
|
||||
self.successful_requests += 1
|
||||
|
||||
def __copy__(self) -> "CustomCallbackHandler":
|
||||
"""Return a copy of the callback handler."""
|
||||
return self
|
||||
|
||||
def __deepcopy__(self, memo: Any) -> "CustomCallbackHandler":
|
||||
"""Return a deep copy of the callback handler."""
|
||||
return self
|
||||
|
||||
|
||||
custom_callback: ContextVar[Optional[CustomCallbackHandler]] = ContextVar(
|
||||
"custom_callback", default=None
|
||||
)
|
||||
register_configure_hook(custom_callback, True)
|
||||
|
||||
@contextmanager
|
||||
def get_custom_callback(llm_model_name: str):
|
||||
"""
|
||||
Function to get custom callback for LLM token usage statistics.
|
||||
"""
|
||||
cb = CustomCallbackHandler(llm_model_name)
|
||||
custom_callback.set(cb)
|
||||
yield cb
|
||||
custom_callback.set(None)
|
||||
@ -1,17 +0,0 @@
|
||||
import threading
|
||||
from contextlib import contextmanager
|
||||
from langchain_community.callbacks import get_openai_callback
|
||||
|
||||
class CustomOpenAiCallbackManager:
|
||||
_lock = threading.Lock()
|
||||
|
||||
@contextmanager
|
||||
def exclusive_get_openai_callback(self):
|
||||
if CustomOpenAiCallbackManager._lock.acquire(blocking=False):
|
||||
try:
|
||||
with get_openai_callback() as cb:
|
||||
yield cb
|
||||
finally:
|
||||
CustomOpenAiCallbackManager._lock.release()
|
||||
else:
|
||||
yield None
|
||||
38
scrapegraphai/utils/llm_callback_manager.py
Normal file
38
scrapegraphai/utils/llm_callback_manager.py
Normal file
@ -0,0 +1,38 @@
|
||||
"""
|
||||
This module provides a custom callback manager for the LLM models.
|
||||
"""
|
||||
import threading
|
||||
from contextlib import contextmanager
|
||||
from .custom_callback import get_custom_callback
|
||||
|
||||
from langchain_community.callbacks import get_openai_callback
|
||||
from langchain_community.callbacks.manager import get_bedrock_anthropic_callback
|
||||
from langchain_openai import ChatOpenAI, AzureChatOpenAI
|
||||
from langchain_aws import ChatBedrock
|
||||
|
||||
class CustomLLMCallbackManager:
|
||||
_lock = threading.Lock()
|
||||
|
||||
@contextmanager
|
||||
def exclusive_get_callback(self, llm_model, llm_model_name):
|
||||
if CustomLLMCallbackManager._lock.acquire(blocking=False):
|
||||
if isinstance(llm_model, ChatOpenAI) or isinstance(llm_model, AzureChatOpenAI):
|
||||
try:
|
||||
with get_openai_callback() as cb:
|
||||
yield cb
|
||||
finally:
|
||||
CustomLLMCallbackManager._lock.release()
|
||||
elif isinstance(llm_model, ChatBedrock) and llm_model_name is not None and "claude" in llm_model_name:
|
||||
try:
|
||||
with get_bedrock_anthropic_callback() as cb:
|
||||
yield cb
|
||||
finally:
|
||||
CustomLLMCallbackManager._lock.release()
|
||||
else:
|
||||
try:
|
||||
with get_custom_callback(llm_model_name) as cb:
|
||||
yield cb
|
||||
finally:
|
||||
CustomLLMCallbackManager._lock.release()
|
||||
else:
|
||||
yield None
|
||||
105
scrapegraphai/utils/model_costs.py
Normal file
105
scrapegraphai/utils/model_costs.py
Normal file
@ -0,0 +1,105 @@
|
||||
"""
|
||||
This file contains the cost of models per 1k tokens for input and output.
|
||||
The file is on a best effort basis and may not be up to date. Any contributions are welcome.
|
||||
"""
|
||||
MODEL_COST_PER_1K_TOKENS_INPUT = {
|
||||
### MistralAI
|
||||
# General Purpose
|
||||
"open-mistral-nemo": 0.00015,
|
||||
"open-mistral-nemo-2407": 0.00015,
|
||||
"mistral-large": 0.002,
|
||||
"mistral-large-2407": 0.002,
|
||||
"mistral-small": 0.0002,
|
||||
"mistral-small-2409": 0.0002,
|
||||
# Specialist Models
|
||||
"codestral": 0.0002,
|
||||
"codestral-2405": 0.0002,
|
||||
"pixtral-12b": 0.00015,
|
||||
"pixtral-12b-2409": 0.00015,
|
||||
# Legacy Models
|
||||
"open-mistral-7b": 0.00025,
|
||||
"open-mixtral-8x7b": 0.0007,
|
||||
"open-mixtral-8x22b": 0.002,
|
||||
"mistral-small-latest": 0.001,
|
||||
"mistral-medium-latest": 0.00275,
|
||||
|
||||
### Bedrock - not Claude
|
||||
#AI21 Labs
|
||||
"a121.ju-ultra-v1": 0.0188,
|
||||
"a121.ju-mid-v1": 0.0125,
|
||||
"ai21.jamba-instruct-v1:0": 0.0005,
|
||||
# Meta - LLama
|
||||
"meta.llama2-13b-chat-v1": 0.00075,
|
||||
"meta.llama2-70b-chat-v1": 0.00195,
|
||||
"meta.llama3-8b-instruct-v1:0": 0.0003,
|
||||
"meta.llama3-70b-instruct-v1:0": 0.00265,
|
||||
"meta.llama3-1-8b-instruct-v1:0": 0.00022,
|
||||
"meta.llama3-1-70b-instruct-v1:0": 0.00099,
|
||||
"meta.llama3-1-405b-instruct-v1:0": 0.00532,
|
||||
# Cohere - Command
|
||||
"cohere.command-text-v14": 0.0015,
|
||||
"cohere.command-light-text-v14": 0.0003,
|
||||
"cohere.command-r-v1:0": 0.0005,
|
||||
"cohere.command-r-plus-v1:0": 0.003,
|
||||
# Mistral
|
||||
"mistral.mistral-7b-instruct-v0:2": 0.00015,
|
||||
"mistral.mistral-large-2402-v1:0": 0.004,
|
||||
"mistral.mistral-large-2407-v1:0": 0.002,
|
||||
"mistral.mistral-small-2402-v1:0": 0.001,
|
||||
"mistral.mixtral-7x8b-instruct-v0:1": 0.00045,
|
||||
# Amazon - Titan
|
||||
"amazon.titan-text-express-v1": 0.0002,
|
||||
"amazon.titan-text-lite-v1": 0.00015,
|
||||
"amazon.titan-text-premier-v1:0": 0.0005,
|
||||
}
|
||||
|
||||
MODEL_COST_PER_1K_TOKENS_OUTPUT = {
|
||||
### MistralAI
|
||||
# General Purpose
|
||||
"open-mistral-nemo": 0.00015,
|
||||
"open-mistral-nemo-2407": 0.00015,
|
||||
"mistral-large": 0.002,
|
||||
"mistral-large-2407": 0.006,
|
||||
"mistral-small": 0.0002,
|
||||
"mistral-small-2409": 0.0006,
|
||||
# Specialist Models
|
||||
"codestral": 0.0006,
|
||||
"codestral-2405": 0.0006,
|
||||
"pixtral-12b": 0.00015,
|
||||
"pixtral-12b-2409": 0.0006,
|
||||
# Legacy Models
|
||||
"open-mistral-7b": 0.00025,
|
||||
"open-mixtral-8x7b": 0.0007,
|
||||
"open-mixtral-8x22b": 0.006,
|
||||
"mistral-small-latest": 0.003,
|
||||
"mistral-medium-latest": 0.0081,
|
||||
|
||||
### Bedrock - not Claude
|
||||
# AI21 Labs
|
||||
"a121.ju-ultra-v1": 0.0188,
|
||||
"a121.ju-mid-v1": 0.0125,
|
||||
"ai21.jamba-instruct-v1:0": 0.0007,
|
||||
# Meta - LLama
|
||||
"meta.llama2-13b-chat-v1": 0.001,
|
||||
"meta.llama2-70b-chat-v1": 0.00256,
|
||||
"meta.llama3-8b-instruct-v1:0": 0.0006,
|
||||
"meta.llama3-70b-instruct-v1:0": 0.0035,
|
||||
"meta.llama3-1-8b-instruct-v1:0": 0.00022,
|
||||
"meta.llama3-1-70b-instruct-v1:0": 0.00099,
|
||||
"meta.llama3-1-405b-instruct-v1:0": 0.016,
|
||||
# Cohere - Command
|
||||
"cohere.command-text-v14": 0.002,
|
||||
"cohere.command-light-text-v14": 0.0006,
|
||||
"cohere.command-r-v1:0": 0.0015,
|
||||
"cohere.command-r-plus-v1:0": 0.015,
|
||||
# Mistral
|
||||
"mistral.mistral-7b-instruct-v0:2": 0.0002,
|
||||
"mistral.mistral-large-2402-v1:0": 0.012,
|
||||
"mistral.mistral-large-2407-v1:0": 0.006,
|
||||
"mistral.mistral-small-2402-v1:0": 0.003,
|
||||
"mistral.mixtral-7x8b-instruct-v0:1": 0.0007,
|
||||
# Amazon - Titan
|
||||
"amazon.titan-text-express-v1": 0.0006,
|
||||
"amazon.titan-text-lite-v1": 0.0002,
|
||||
"amazon.titan-text-premier-v1:0": 0.0015,
|
||||
}
|
||||
Loading…
Reference in New Issue
Block a user