mirror of
https://github.com/VinciGit00/Scrapegraph-ai.git
synced 2026-07-04 21:00:36 +08:00
fix: removed tokenizer
This commit is contained in:
parent
58b11334d3
commit
a18471688f
@ -6,7 +6,6 @@ from langchain_openai import ChatOpenAI
|
|||||||
from langchain_ollama import ChatOllama
|
from langchain_ollama import ChatOllama
|
||||||
from langchain_mistralai import ChatMistralAI
|
from langchain_mistralai import ChatMistralAI
|
||||||
from langchain_core.language_models.chat_models import BaseChatModel
|
from langchain_core.language_models.chat_models import BaseChatModel
|
||||||
from transformers import GPT2TokenizerFast
|
|
||||||
|
|
||||||
def num_tokens_calculus(string: str, llm_model: BaseChatModel) -> int:
|
def num_tokens_calculus(string: str, llm_model: BaseChatModel) -> int:
|
||||||
"""
|
"""
|
||||||
@ -24,13 +23,6 @@ def num_tokens_calculus(string: str, llm_model: BaseChatModel) -> int:
|
|||||||
from .tokenizers.tokenizer_ollama import num_tokens_ollama
|
from .tokenizers.tokenizer_ollama import num_tokens_ollama
|
||||||
num_tokens_fn = num_tokens_ollama
|
num_tokens_fn = num_tokens_ollama
|
||||||
|
|
||||||
elif isinstance(llm_model, GPT2TokenizerFast):
|
|
||||||
def num_tokens_gpt2(text: str, model: BaseChatModel) -> int:
|
|
||||||
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
|
|
||||||
tokens = tokenizer.encode(text)
|
|
||||||
return len(tokens)
|
|
||||||
num_tokens_fn = num_tokens_gpt2
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
from .tokenizers.tokenizer_openai import num_tokens_openai
|
from .tokenizers.tokenizer_openai import num_tokens_openai
|
||||||
num_tokens_fn = num_tokens_openai
|
num_tokens_fn = num_tokens_openai
|
||||||
|
|||||||
@ -3,7 +3,6 @@ Tokenization utilities for Ollama models
|
|||||||
"""
|
"""
|
||||||
from langchain_core.language_models.chat_models import BaseChatModel
|
from langchain_core.language_models.chat_models import BaseChatModel
|
||||||
from ..logging import get_logger
|
from ..logging import get_logger
|
||||||
from transformers import GPT2TokenizerFast
|
|
||||||
|
|
||||||
def num_tokens_ollama(text: str, llm_model:BaseChatModel) -> int:
|
def num_tokens_ollama(text: str, llm_model:BaseChatModel) -> int:
|
||||||
"""
|
"""
|
||||||
@ -22,12 +21,8 @@ def num_tokens_ollama(text: str, llm_model:BaseChatModel) -> int:
|
|||||||
|
|
||||||
logger.debug(f"Counting tokens for text of {len(text)} characters")
|
logger.debug(f"Counting tokens for text of {len(text)} characters")
|
||||||
|
|
||||||
if isinstance(llm_model, GPT2TokenizerFast):
|
|
||||||
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
|
|
||||||
tokens = tokenizer.encode(text)
|
|
||||||
return len(tokens)
|
|
||||||
|
|
||||||
# Use langchain token count implementation
|
# Use langchain token count implementation
|
||||||
# NB: https://github.com/ollama/ollama/issues/1716#issuecomment-2074265507
|
# NB: https://github.com/ollama/ollama/issues/1716#issuecomment-2074265507
|
||||||
tokens = llm_model.get_num_tokens(text)
|
tokens = llm_model.get_num_tokens(text)
|
||||||
return tokens
|
return tokens
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user