feat: refactoring of the tokenization function

This commit is contained in:
Marco Vinciguerra 2024-09-12 20:21:00 +02:00
parent 4ab26a24a3
commit ec6b164653
4 changed files with 53 additions and 13 deletions

View File

@ -71,6 +71,7 @@ cycler==0.12.1
dataclasses-json==0.6.7
# via langchain-community
dill==0.3.8
# via multiprocess
# via pylint
distro==1.9.0
# via openai
@ -87,6 +88,7 @@ fastapi-pagination==0.12.26
# via burr
filelock==3.15.4
# via huggingface-hub
# via transformers
fonttools==4.53.1
# via matplotlib
free-proxy==1.1.1
@ -152,6 +154,7 @@ httpx-sse==0.4.0
# via langchain-mistralai
huggingface-hub==0.24.5
# via tokenizers
# via transformers
idna==3.7
# via anyio
# via httpx
@ -235,9 +238,13 @@ mdurl==0.1.2
# via markdown-it-py
minify-html==0.15.0
# via scrapegraphai
mpire==2.10.2
# via semchunk
multidict==6.0.5
# via aiohttp
# via yarl
multiprocess==0.70.16
# via mpire
mypy-extensions==1.0.0
# via typing-inspect
narwhals==1.3.0
@ -254,6 +261,7 @@ numpy==1.26.4
# via pydeck
# via sf-hamilton
# via streamlit
# via transformers
ollama==0.3.2
# via langchain-ollama
openai==1.40.3
@ -271,6 +279,7 @@ packaging==24.1
# via pytest
# via sphinx
# via streamlit
# via transformers
pandas==2.2.2
# via scrapegraphai
# via sf-hamilton
@ -320,6 +329,7 @@ pyee==11.1.0
# via playwright
pygments==2.18.0
# via furo
# via mpire
# via rich
# via sphinx
pylint==3.2.6
@ -342,11 +352,13 @@ pyyaml==6.0.2
# via langchain
# via langchain-community
# via langchain-core
# via transformers
referencing==0.35.1
# via jsonschema
# via jsonschema-specifications
regex==2024.7.24
# via tiktoken
# via transformers
requests==2.32.3
# via burr
# via free-proxy
@ -358,6 +370,7 @@ requests==2.32.3
# via sphinx
# via streamlit
# via tiktoken
# via transformers
rich==13.7.1
# via streamlit
rpds-py==0.20.0
@ -367,6 +380,10 @@ rsa==4.9
# via google-auth
s3transfer==0.10.2
# via boto3
safetensors==0.4.5
# via transformers
semchunk==2.2.0
# via scrapegraphai
sf-hamilton==1.73.1
# via burr
six==1.16.0
@ -416,6 +433,7 @@ tiktoken==0.7.0
# via scrapegraphai
tokenizers==0.19.1
# via langchain-mistralai
# via transformers
toml==0.10.2
# via streamlit
tomli==2.0.1
@ -428,8 +446,13 @@ tornado==6.4.1
tqdm==4.66.5
# via google-generativeai
# via huggingface-hub
# via mpire
# via openai
# via scrapegraphai
# via semchunk
# via transformers
transformers==4.44.2
# via scrapegraphai
typing-extensions==4.12.2
# via altair
# via anyio

View File

@ -41,6 +41,8 @@ charset-normalizer==3.3.2
# via requests
dataclasses-json==0.6.7
# via langchain-community
dill==0.3.8
# via multiprocess
distro==1.9.0
# via openai
exceptiongroup==1.2.2
@ -49,6 +51,7 @@ faiss-cpu==1.8.0.post1
# via scrapegraphai
filelock==3.15.4
# via huggingface-hub
# via transformers
free-proxy==1.1.1
# via scrapegraphai
frozenlist==1.4.1
@ -103,6 +106,7 @@ httpx-sse==0.4.0
# via langchain-mistralai
huggingface-hub==0.24.1
# via tokenizers
# via transformers
idna==3.7
# via anyio
# via httpx
@ -153,9 +157,13 @@ marshmallow==3.21.3
# via dataclasses-json
minify-html==0.15.0
# via scrapegraphai
mpire==2.10.2
# via semchunk
multidict==6.0.5
# via aiohttp
# via yarl
multiprocess==0.70.16
# via mpire
mypy-extensions==1.0.0
# via typing-inspect
numpy==1.26.4
@ -164,6 +172,7 @@ numpy==1.26.4
# via langchain-aws
# via langchain-community
# via pandas
# via transformers
ollama==0.3.2
# via langchain-ollama
openai==1.41.0
@ -175,6 +184,7 @@ packaging==24.1
# via huggingface-hub
# via langchain-core
# via marshmallow
# via transformers
pandas==2.2.2
# via scrapegraphai
playwright==1.45.1
@ -205,6 +215,8 @@ pydantic-core==2.20.1
# via pydantic
pyee==11.1.0
# via playwright
pygments==2.18.0
# via mpire
pyparsing==3.1.2
# via httplib2
python-dateutil==2.9.0.post0
@ -219,8 +231,10 @@ pyyaml==6.0.1
# via langchain
# via langchain-community
# via langchain-core
# via transformers
regex==2024.5.15
# via tiktoken
# via transformers
requests==2.32.3
# via free-proxy
# via google-api-core
@ -229,10 +243,15 @@ requests==2.32.3
# via langchain-community
# via langsmith
# via tiktoken
# via transformers
rsa==4.9
# via google-auth
s3transfer==0.10.2
# via boto3
safetensors==0.4.5
# via transformers
semchunk==2.2.0
# via scrapegraphai
six==1.16.0
# via python-dateutil
sniffio==1.3.1
@ -253,11 +272,17 @@ tiktoken==0.7.0
# via scrapegraphai
tokenizers==0.19.1
# via langchain-mistralai
# via transformers
tqdm==4.66.4
# via google-generativeai
# via huggingface-hub
# via mpire
# via openai
# via scrapegraphai
# via semchunk
# via transformers
transformers==4.44.2
# via scrapegraphai
typing-extensions==4.12.2
# via anyio
# via google-generativeai

View File

@ -23,7 +23,8 @@ def num_tokens_calculus(string: str, llm_model: BaseChatModel) -> int:
num_tokens_fn = num_tokens_ollama
else:
raise NotImplementedError(f"There is no tokenization implementation for model '{llm_model}'")
from .tokenizers.tokenizer_openai import num_tokens_openai
num_tokens_fn = num_tokens_openai
num_tokens = num_tokens_fn(string, llm_model)
return num_tokens

View File

@ -21,17 +21,8 @@ def num_tokens_openai(text: str, llm_model:BaseChatModel) -> int:
logger = get_logger()
logger.debug(f"Counting tokens for text of {len(text)} characters")
try:
model = llm_model.model_name
except AttributeError:
raise NotImplementedError(f"The model provider you are using ('{llm_model}') "
"does not give us a model name so we cannot identify which encoding to use")
try:
encoding = tiktoken.encoding_for_model(model)
except KeyError:
raise NotImplementedError(f"Tiktoken does not support identifying the encoding for "
"the model '{model}'")
encoding = tiktoken.encoding_for_model("gpt-4")
num_tokens = len(encoding.encode(text))
return num_tokens