feat: refactoring of the tokenization function

This commit is contained in:
Marco Vinciguerra 2024-09-12 20:21:00 +02:00
parent 4ab26a24a3
commit ec6b164653
4 changed files with 53 additions and 13 deletions

View File

@ -71,6 +71,7 @@ cycler==0.12.1
dataclasses-json==0.6.7 dataclasses-json==0.6.7
# via langchain-community # via langchain-community
dill==0.3.8 dill==0.3.8
# via multiprocess
# via pylint # via pylint
distro==1.9.0 distro==1.9.0
# via openai # via openai
@ -87,6 +88,7 @@ fastapi-pagination==0.12.26
# via burr # via burr
filelock==3.15.4 filelock==3.15.4
# via huggingface-hub # via huggingface-hub
# via transformers
fonttools==4.53.1 fonttools==4.53.1
# via matplotlib # via matplotlib
free-proxy==1.1.1 free-proxy==1.1.1
@ -152,6 +154,7 @@ httpx-sse==0.4.0
# via langchain-mistralai # via langchain-mistralai
huggingface-hub==0.24.5 huggingface-hub==0.24.5
# via tokenizers # via tokenizers
# via transformers
idna==3.7 idna==3.7
# via anyio # via anyio
# via httpx # via httpx
@ -235,9 +238,13 @@ mdurl==0.1.2
# via markdown-it-py # via markdown-it-py
minify-html==0.15.0 minify-html==0.15.0
# via scrapegraphai # via scrapegraphai
mpire==2.10.2
# via semchunk
multidict==6.0.5 multidict==6.0.5
# via aiohttp # via aiohttp
# via yarl # via yarl
multiprocess==0.70.16
# via mpire
mypy-extensions==1.0.0 mypy-extensions==1.0.0
# via typing-inspect # via typing-inspect
narwhals==1.3.0 narwhals==1.3.0
@ -254,6 +261,7 @@ numpy==1.26.4
# via pydeck # via pydeck
# via sf-hamilton # via sf-hamilton
# via streamlit # via streamlit
# via transformers
ollama==0.3.2 ollama==0.3.2
# via langchain-ollama # via langchain-ollama
openai==1.40.3 openai==1.40.3
@ -271,6 +279,7 @@ packaging==24.1
# via pytest # via pytest
# via sphinx # via sphinx
# via streamlit # via streamlit
# via transformers
pandas==2.2.2 pandas==2.2.2
# via scrapegraphai # via scrapegraphai
# via sf-hamilton # via sf-hamilton
@ -320,6 +329,7 @@ pyee==11.1.0
# via playwright # via playwright
pygments==2.18.0 pygments==2.18.0
# via furo # via furo
# via mpire
# via rich # via rich
# via sphinx # via sphinx
pylint==3.2.6 pylint==3.2.6
@ -342,11 +352,13 @@ pyyaml==6.0.2
# via langchain # via langchain
# via langchain-community # via langchain-community
# via langchain-core # via langchain-core
# via transformers
referencing==0.35.1 referencing==0.35.1
# via jsonschema # via jsonschema
# via jsonschema-specifications # via jsonschema-specifications
regex==2024.7.24 regex==2024.7.24
# via tiktoken # via tiktoken
# via transformers
requests==2.32.3 requests==2.32.3
# via burr # via burr
# via free-proxy # via free-proxy
@ -358,6 +370,7 @@ requests==2.32.3
# via sphinx # via sphinx
# via streamlit # via streamlit
# via tiktoken # via tiktoken
# via transformers
rich==13.7.1 rich==13.7.1
# via streamlit # via streamlit
rpds-py==0.20.0 rpds-py==0.20.0
@ -367,6 +380,10 @@ rsa==4.9
# via google-auth # via google-auth
s3transfer==0.10.2 s3transfer==0.10.2
# via boto3 # via boto3
safetensors==0.4.5
# via transformers
semchunk==2.2.0
# via scrapegraphai
sf-hamilton==1.73.1 sf-hamilton==1.73.1
# via burr # via burr
six==1.16.0 six==1.16.0
@ -416,6 +433,7 @@ tiktoken==0.7.0
# via scrapegraphai # via scrapegraphai
tokenizers==0.19.1 tokenizers==0.19.1
# via langchain-mistralai # via langchain-mistralai
# via transformers
toml==0.10.2 toml==0.10.2
# via streamlit # via streamlit
tomli==2.0.1 tomli==2.0.1
@ -428,8 +446,13 @@ tornado==6.4.1
tqdm==4.66.5 tqdm==4.66.5
# via google-generativeai # via google-generativeai
# via huggingface-hub # via huggingface-hub
# via mpire
# via openai # via openai
# via scrapegraphai # via scrapegraphai
# via semchunk
# via transformers
transformers==4.44.2
# via scrapegraphai
typing-extensions==4.12.2 typing-extensions==4.12.2
# via altair # via altair
# via anyio # via anyio

View File

@ -41,6 +41,8 @@ charset-normalizer==3.3.2
# via requests # via requests
dataclasses-json==0.6.7 dataclasses-json==0.6.7
# via langchain-community # via langchain-community
dill==0.3.8
# via multiprocess
distro==1.9.0 distro==1.9.0
# via openai # via openai
exceptiongroup==1.2.2 exceptiongroup==1.2.2
@ -49,6 +51,7 @@ faiss-cpu==1.8.0.post1
# via scrapegraphai # via scrapegraphai
filelock==3.15.4 filelock==3.15.4
# via huggingface-hub # via huggingface-hub
# via transformers
free-proxy==1.1.1 free-proxy==1.1.1
# via scrapegraphai # via scrapegraphai
frozenlist==1.4.1 frozenlist==1.4.1
@ -103,6 +106,7 @@ httpx-sse==0.4.0
# via langchain-mistralai # via langchain-mistralai
huggingface-hub==0.24.1 huggingface-hub==0.24.1
# via tokenizers # via tokenizers
# via transformers
idna==3.7 idna==3.7
# via anyio # via anyio
# via httpx # via httpx
@ -153,9 +157,13 @@ marshmallow==3.21.3
# via dataclasses-json # via dataclasses-json
minify-html==0.15.0 minify-html==0.15.0
# via scrapegraphai # via scrapegraphai
mpire==2.10.2
# via semchunk
multidict==6.0.5 multidict==6.0.5
# via aiohttp # via aiohttp
# via yarl # via yarl
multiprocess==0.70.16
# via mpire
mypy-extensions==1.0.0 mypy-extensions==1.0.0
# via typing-inspect # via typing-inspect
numpy==1.26.4 numpy==1.26.4
@ -164,6 +172,7 @@ numpy==1.26.4
# via langchain-aws # via langchain-aws
# via langchain-community # via langchain-community
# via pandas # via pandas
# via transformers
ollama==0.3.2 ollama==0.3.2
# via langchain-ollama # via langchain-ollama
openai==1.41.0 openai==1.41.0
@ -175,6 +184,7 @@ packaging==24.1
# via huggingface-hub # via huggingface-hub
# via langchain-core # via langchain-core
# via marshmallow # via marshmallow
# via transformers
pandas==2.2.2 pandas==2.2.2
# via scrapegraphai # via scrapegraphai
playwright==1.45.1 playwright==1.45.1
@ -205,6 +215,8 @@ pydantic-core==2.20.1
# via pydantic # via pydantic
pyee==11.1.0 pyee==11.1.0
# via playwright # via playwright
pygments==2.18.0
# via mpire
pyparsing==3.1.2 pyparsing==3.1.2
# via httplib2 # via httplib2
python-dateutil==2.9.0.post0 python-dateutil==2.9.0.post0
@ -219,8 +231,10 @@ pyyaml==6.0.1
# via langchain # via langchain
# via langchain-community # via langchain-community
# via langchain-core # via langchain-core
# via transformers
regex==2024.5.15 regex==2024.5.15
# via tiktoken # via tiktoken
# via transformers
requests==2.32.3 requests==2.32.3
# via free-proxy # via free-proxy
# via google-api-core # via google-api-core
@ -229,10 +243,15 @@ requests==2.32.3
# via langchain-community # via langchain-community
# via langsmith # via langsmith
# via tiktoken # via tiktoken
# via transformers
rsa==4.9 rsa==4.9
# via google-auth # via google-auth
s3transfer==0.10.2 s3transfer==0.10.2
# via boto3 # via boto3
safetensors==0.4.5
# via transformers
semchunk==2.2.0
# via scrapegraphai
six==1.16.0 six==1.16.0
# via python-dateutil # via python-dateutil
sniffio==1.3.1 sniffio==1.3.1
@ -253,11 +272,17 @@ tiktoken==0.7.0
# via scrapegraphai # via scrapegraphai
tokenizers==0.19.1 tokenizers==0.19.1
# via langchain-mistralai # via langchain-mistralai
# via transformers
tqdm==4.66.4 tqdm==4.66.4
# via google-generativeai # via google-generativeai
# via huggingface-hub # via huggingface-hub
# via mpire
# via openai # via openai
# via scrapegraphai # via scrapegraphai
# via semchunk
# via transformers
transformers==4.44.2
# via scrapegraphai
typing-extensions==4.12.2 typing-extensions==4.12.2
# via anyio # via anyio
# via google-generativeai # via google-generativeai

View File

@ -23,7 +23,8 @@ def num_tokens_calculus(string: str, llm_model: BaseChatModel) -> int:
num_tokens_fn = num_tokens_ollama num_tokens_fn = num_tokens_ollama
else: else:
raise NotImplementedError(f"There is no tokenization implementation for model '{llm_model}'") from .tokenizers.tokenizer_openai import num_tokens_openai
num_tokens_fn = num_tokens_openai
num_tokens = num_tokens_fn(string, llm_model) num_tokens = num_tokens_fn(string, llm_model)
return num_tokens return num_tokens

View File

@ -21,17 +21,8 @@ def num_tokens_openai(text: str, llm_model:BaseChatModel) -> int:
logger = get_logger() logger = get_logger()
logger.debug(f"Counting tokens for text of {len(text)} characters") logger.debug(f"Counting tokens for text of {len(text)} characters")
try:
model = llm_model.model_name
except AttributeError:
raise NotImplementedError(f"The model provider you are using ('{llm_model}') "
"does not give us a model name so we cannot identify which encoding to use")
try: encoding = tiktoken.encoding_for_model("gpt-4")
encoding = tiktoken.encoding_for_model(model)
except KeyError:
raise NotImplementedError(f"Tiktoken does not support identifying the encoding for "
"the model '{model}'")
num_tokens = len(encoding.encode(text)) num_tokens = len(encoding.encode(text))
return num_tokens return num_tokens