diff --git a/requirements-dev.lock b/requirements-dev.lock index fd04d800..66a0ec32 100644 --- a/requirements-dev.lock +++ b/requirements-dev.lock @@ -71,6 +71,7 @@ cycler==0.12.1 dataclasses-json==0.6.7 # via langchain-community dill==0.3.8 + # via multiprocess # via pylint distro==1.9.0 # via openai @@ -87,6 +88,7 @@ fastapi-pagination==0.12.26 # via burr filelock==3.15.4 # via huggingface-hub + # via transformers fonttools==4.53.1 # via matplotlib free-proxy==1.1.1 @@ -152,6 +154,7 @@ httpx-sse==0.4.0 # via langchain-mistralai huggingface-hub==0.24.5 # via tokenizers + # via transformers idna==3.7 # via anyio # via httpx @@ -235,9 +238,13 @@ mdurl==0.1.2 # via markdown-it-py minify-html==0.15.0 # via scrapegraphai +mpire==2.10.2 + # via semchunk multidict==6.0.5 # via aiohttp # via yarl +multiprocess==0.70.16 + # via mpire mypy-extensions==1.0.0 # via typing-inspect narwhals==1.3.0 @@ -254,6 +261,7 @@ numpy==1.26.4 # via pydeck # via sf-hamilton # via streamlit + # via transformers ollama==0.3.2 # via langchain-ollama openai==1.40.3 @@ -271,6 +279,7 @@ packaging==24.1 # via pytest # via sphinx # via streamlit + # via transformers pandas==2.2.2 # via scrapegraphai # via sf-hamilton @@ -320,6 +329,7 @@ pyee==11.1.0 # via playwright pygments==2.18.0 # via furo + # via mpire # via rich # via sphinx pylint==3.2.6 @@ -342,11 +352,13 @@ pyyaml==6.0.2 # via langchain # via langchain-community # via langchain-core + # via transformers referencing==0.35.1 # via jsonschema # via jsonschema-specifications regex==2024.7.24 # via tiktoken + # via transformers requests==2.32.3 # via burr # via free-proxy @@ -358,6 +370,7 @@ requests==2.32.3 # via sphinx # via streamlit # via tiktoken + # via transformers rich==13.7.1 # via streamlit rpds-py==0.20.0 @@ -367,6 +380,10 @@ rsa==4.9 # via google-auth s3transfer==0.10.2 # via boto3 +safetensors==0.4.5 + # via transformers +semchunk==2.2.0 + # via scrapegraphai sf-hamilton==1.73.1 # via burr six==1.16.0 @@ -416,6 +433,7 @@ tiktoken==0.7.0 # via scrapegraphai tokenizers==0.19.1 # via langchain-mistralai + # via transformers toml==0.10.2 # via streamlit tomli==2.0.1 @@ -428,8 +446,13 @@ tornado==6.4.1 tqdm==4.66.5 # via google-generativeai # via huggingface-hub + # via mpire # via openai # via scrapegraphai + # via semchunk + # via transformers +transformers==4.44.2 + # via scrapegraphai typing-extensions==4.12.2 # via altair # via anyio diff --git a/requirements.lock b/requirements.lock index b34c9290..f29ac340 100644 --- a/requirements.lock +++ b/requirements.lock @@ -41,6 +41,8 @@ charset-normalizer==3.3.2 # via requests dataclasses-json==0.6.7 # via langchain-community +dill==0.3.8 + # via multiprocess distro==1.9.0 # via openai exceptiongroup==1.2.2 @@ -49,6 +51,7 @@ faiss-cpu==1.8.0.post1 # via scrapegraphai filelock==3.15.4 # via huggingface-hub + # via transformers free-proxy==1.1.1 # via scrapegraphai frozenlist==1.4.1 @@ -103,6 +106,7 @@ httpx-sse==0.4.0 # via langchain-mistralai huggingface-hub==0.24.1 # via tokenizers + # via transformers idna==3.7 # via anyio # via httpx @@ -153,9 +157,13 @@ marshmallow==3.21.3 # via dataclasses-json minify-html==0.15.0 # via scrapegraphai +mpire==2.10.2 + # via semchunk multidict==6.0.5 # via aiohttp # via yarl +multiprocess==0.70.16 + # via mpire mypy-extensions==1.0.0 # via typing-inspect numpy==1.26.4 @@ -164,6 +172,7 @@ numpy==1.26.4 # via langchain-aws # via langchain-community # via pandas + # via transformers ollama==0.3.2 # via langchain-ollama openai==1.41.0 @@ -175,6 +184,7 @@ packaging==24.1 # via huggingface-hub # via langchain-core # via marshmallow + # via transformers pandas==2.2.2 # via scrapegraphai playwright==1.45.1 @@ -205,6 +215,8 @@ pydantic-core==2.20.1 # via pydantic pyee==11.1.0 # via playwright +pygments==2.18.0 + # via mpire pyparsing==3.1.2 # via httplib2 python-dateutil==2.9.0.post0 @@ -219,8 +231,10 @@ pyyaml==6.0.1 # via langchain # via langchain-community # via langchain-core + # via transformers regex==2024.5.15 # via tiktoken + # via transformers requests==2.32.3 # via free-proxy # via google-api-core @@ -229,10 +243,15 @@ requests==2.32.3 # via langchain-community # via langsmith # via tiktoken + # via transformers rsa==4.9 # via google-auth s3transfer==0.10.2 # via boto3 +safetensors==0.4.5 + # via transformers +semchunk==2.2.0 + # via scrapegraphai six==1.16.0 # via python-dateutil sniffio==1.3.1 @@ -253,11 +272,17 @@ tiktoken==0.7.0 # via scrapegraphai tokenizers==0.19.1 # via langchain-mistralai + # via transformers tqdm==4.66.4 # via google-generativeai # via huggingface-hub + # via mpire # via openai # via scrapegraphai + # via semchunk + # via transformers +transformers==4.44.2 + # via scrapegraphai typing-extensions==4.12.2 # via anyio # via google-generativeai diff --git a/scrapegraphai/utils/tokenizer.py b/scrapegraphai/utils/tokenizer.py index 5ed94250..2e20a244 100644 --- a/scrapegraphai/utils/tokenizer.py +++ b/scrapegraphai/utils/tokenizer.py @@ -23,7 +23,8 @@ def num_tokens_calculus(string: str, llm_model: BaseChatModel) -> int: num_tokens_fn = num_tokens_ollama else: - raise NotImplementedError(f"There is no tokenization implementation for model '{llm_model}'") - + from .tokenizers.tokenizer_openai import num_tokens_openai + num_tokens_fn = num_tokens_openai + num_tokens = num_tokens_fn(string, llm_model) return num_tokens diff --git a/scrapegraphai/utils/tokenizers/tokenizer_openai.py b/scrapegraphai/utils/tokenizers/tokenizer_openai.py index ef70aa28..ede53905 100644 --- a/scrapegraphai/utils/tokenizers/tokenizer_openai.py +++ b/scrapegraphai/utils/tokenizers/tokenizer_openai.py @@ -21,17 +21,8 @@ def num_tokens_openai(text: str, llm_model:BaseChatModel) -> int: logger = get_logger() logger.debug(f"Counting tokens for text of {len(text)} characters") - try: - model = llm_model.model_name - except AttributeError: - raise NotImplementedError(f"The model provider you are using ('{llm_model}') " - "does not give us a model name so we cannot identify which encoding to use") - try: - encoding = tiktoken.encoding_for_model(model) - except KeyError: - raise NotImplementedError(f"Tiktoken does not support identifying the encoding for " - "the model '{model}'") - + encoding = tiktoken.encoding_for_model("gpt-4") + num_tokens = len(encoding.encode(text)) return num_tokens