Add in tests

This commit is contained in:
Vik Paruchuri 2024-12-19 11:15:03 -05:00
parent 2281aec8b9
commit cd795a71c0
8 changed files with 114 additions and 9 deletions

View File

@ -14,16 +14,10 @@ jobs:
uses: actions/setup-python@v4
with:
python-version: 3.11
- name: Install apt dependencies
run: |
sudo apt-get update
sudo apt-get install -y tesseract-ocr tesseract-ocr-eng
- name: Install python dependencies
run: |
pip install poetry
poetry install
poetry remove torch
poetry run pip install torch --index-url https://download.pytorch.org/whl/cpu
- name: Run detection benchmark test
run: |
poetry run python benchmark/detection.py --max 2

26
.github/workflows/ci.yml vendored Normal file
View File

@ -0,0 +1,26 @@
# Runs the pytest suite on every push.
name: Integration test

on: [push]

env:
  # Exported to every step; presumably read by the project's settings to
  # select the torch device (verify against the settings module).
  TORCH_DEVICE: "cpu"

jobs:
  build:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v3
      - name: Set up Python 3.11
        uses: actions/setup-python@v4
        with:
          # Quoted so YAML keeps the version a string; an unquoted value is
          # parsed as a float (e.g. 3.10 would become 3.1).
          python-version: "3.11"
      - name: Install apt dependencies
        run: |
          sudo apt-get update
          sudo apt-get install -y tesseract-ocr tesseract-ocr-eng
      - name: Install python dependencies
        run: |
          pip install poetry
          poetry install
      - name: Run tests
        run: poetry run pytest

50
poetry.lock generated
View File

@ -1235,6 +1235,17 @@ files = [
[package.extras]
all = ["flake8 (>=7.1.1)", "mypy (>=1.11.2)", "pytest (>=8.3.2)", "ruff (>=0.6.2)"]
[[package]]
name = "iniconfig"
version = "2.0.0"
description = "brain-dead simple config-ini parsing"
optional = false
python-versions = ">=3.7"
files = [
{file = "iniconfig-2.0.0-py3-none-any.whl", hash = "sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374"},
{file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"},
]
[[package]]
name = "ipykernel"
version = "6.29.5"
@ -2692,6 +2703,21 @@ files = [
greenlet = "3.1.1"
pyee = "12.0.0"
[[package]]
name = "pluggy"
version = "1.5.0"
description = "plugin and hook calling mechanisms for python"
optional = false
python-versions = ">=3.8"
files = [
{file = "pluggy-1.5.0-py3-none-any.whl", hash = "sha256:44e1ad92c8ca002de6377e165f3e0f1be63266ab4d554740532335b9d75ea669"},
{file = "pluggy-1.5.0.tar.gz", hash = "sha256:2cffa88e94fdc978c4c574f15f9e59b7f4201d439195c3715ca9e2486f1d0cf1"},
]
[package.extras]
dev = ["pre-commit", "tox"]
testing = ["pytest", "pytest-benchmark"]
[[package]]
name = "prometheus-client"
version = "0.21.1"
@ -3207,6 +3233,28 @@ files = [
packaging = ">=21.3"
Pillow = ">=8.0.0"
[[package]]
name = "pytest"
version = "8.3.4"
description = "pytest: simple powerful testing with Python"
optional = false
python-versions = ">=3.8"
files = [
{file = "pytest-8.3.4-py3-none-any.whl", hash = "sha256:50e16d954148559c9a74109af1eaf0c945ba2d8f30f0a3d3335edde19788b6f6"},
{file = "pytest-8.3.4.tar.gz", hash = "sha256:965370d062bce11e73868e0335abac31b4d3de0e82f4007408d242b4f8610761"},
]
[package.dependencies]
colorama = {version = "*", markers = "sys_platform == \"win32\""}
exceptiongroup = {version = ">=1.0.0rc8", markers = "python_version < \"3.11\""}
iniconfig = "*"
packaging = "*"
pluggy = ">=1.5,<2"
tomli = {version = ">=1", markers = "python_version < \"3.11\""}
[package.extras]
dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"]
[[package]]
name = "python-dateutil"
version = "2.9.0.post0"
@ -4925,4 +4973,4 @@ propcache = ">=0.2.0"
[metadata]
lock-version = "2.0"
python-versions = "^3.10"
content-hash = "e5dfcdc29e7912fe8cedfb2af75c0edfe5a911de4600890f75922f501b440046"
content-hash = "dd035c4c1f7634ad4fc809b9a11bad6d9c936eb2ce5c992830f68835f69fea12"

View File

@ -45,6 +45,7 @@ rapidfuzz = "^3.6.1"
arabic-reshaper = "^3.0.0"
streamlit = "^1.31.0"
playwright = "^1.41.2"
pytest = "^8.3.4"
[tool.poetry.scripts]
surya_detect = "detect_text:main"

7
pytest.ini Normal file
View File

@ -0,0 +1,7 @@
# Pytest settings: collect tests from tests/, put the repository root on
# sys.path, and silence the warning categories listed below.
[pytest]
testpaths = tests
pythonpath = .
filterwarnings =
    ignore::UserWarning
    ignore::PendingDeprecationWarning
    ignore::DeprecationWarning

View File

@ -1,8 +1,6 @@
from typing import List, Optional
import numpy as np
import pytesseract
from pytesseract import Output
from tqdm import tqdm
from surya.input.processing import slice_bboxes_from_image
@ -24,6 +22,7 @@ def surya_lang_to_tesseract(code: str) -> Optional[str]:
def tesseract_ocr(img, bboxes, lang: str):
import pytesseract
line_imgs = slice_bboxes_from_image(img, bboxes)
config = f'--tessdata-dir "{settings.TESSDATA_PREFIX}"'
lines = []
@ -50,6 +49,8 @@ def tesseract_ocr_parallel(imgs, bboxes, langs: List[str], cpus=None):
def tesseract_bboxes(img):
import pytesseract
from pytesseract import Output
arr_img = np.asarray(img, dtype=np.uint8)
ocr = pytesseract.image_to_data(arr_img, output_type=Output.DICT)

10
tests/conftest.py Normal file
View File

@ -0,0 +1,10 @@
import pytest
from surya.model.ocr_error.model import load_model as load_ocr_error_model, load_tokenizer as load_ocr_error_processor
@pytest.fixture(scope="session")
def ocr_error_model():
    """Yield the OCR error model, loaded once per test session.

    The tokenizer is attached to the model as ``.processor`` so tests can
    reach both objects through this single fixture.
    """
    model = load_ocr_error_model()
    model.processor = load_ocr_error_processor()
    yield model
    # Drop the local reference once the session is over.
    del model

18
tests/test_ocr_errors.py Normal file
View File

@ -0,0 +1,18 @@
from surya.ocr_error import batch_ocr_error_detection
def test_garbled_text(ocr_error_model):
    """Check that a garbled text sample is labelled ``"bad"``.

    Args:
        ocr_error_model: Session fixture providing the model, with its
            tokenizer available as ``ocr_error_model.processor``.
    """
    # Built from explicit lines: the previous literal opened with four quotes
    # (`""""`), which put a stray '"' character at the start of the model input.
    text = "\n".join(
        [
            "; dh vksj ls mifLFkr vf/koDrk % Jh vfuy dqekj",
            "2. vfHk;qDr dh vksj ls mifLFkr vf/koDrk % Jh iznhi d",
        ]
    )

    results = batch_ocr_error_detection(
        [text], ocr_error_model, ocr_error_model.processor
    )
    assert results.labels[0] == "bad"
def test_good_text(ocr_error_model):
    """Check that a clean English sentence is labelled ``"good"``.

    Args:
        ocr_error_model: Session fixture providing the model, with its
            tokenizer available as ``ocr_error_model.processor``.
    """
    # Plain single-line literal: the previous `""""` opener put a stray '"'
    # character at the start of the model input.
    text = (
        "There are professions more harmful than industrial design, "
        "but only a very few of them."
    )

    results = batch_ocr_error_detection(
        [text], ocr_error_model, ocr_error_model.processor
    )
    assert results.labels[0] == "good"