Add in tesseract to benchmark

This commit is contained in:
Vik Paruchuri 2024-02-07 11:03:37 -08:00
parent 6371081b76
commit e0785d1ec6
10 changed files with 272 additions and 39 deletions

View File

@ -24,10 +24,14 @@ jobs:
poetry install
poetry remove torch
poetry run pip install torch --index-url https://download.pytorch.org/whl/cpu
- name: Run benchmark test
- name: Run detection benchmark test
run: |
poetry run python benchmark/detection.py --max 2
poetry run python scripts/verify_benchmark_scores.py results/benchmark/doclaynet_bench/results.json
poetry run python scripts/verify_benchmark_scores.py results/benchmark/det_bench/results.json --bench_type detection
- name: Run recognition benchmark test
run: |
poetry run python benchmark/recognition.py --max 2
poetry run python scripts/verify_benchmark_scores.py results/benchmark/rec_bench/results.json --bench_type recognition

View File

@ -1,9 +1,9 @@
# Surya
Surya is a multilingual document OCR toolkit. It can do:
Surya is for multilingual document OCR. It can do:
- Accurate line-level text detection in any language
- Text recognition in 90+ languages
- Accurate OCR in 90+ languages
- Line-level text detection in any language
- Table and chart detection (coming soon)
It works on a range of documents (see [usage](#usage) and [benchmarks](#benchmarks) for more details).
@ -39,23 +39,23 @@ Install with:
pip install surya-ocr
```
Model weights will automatically download the first time you run surya.
Model weights will automatically download the first time you run surya. Note that this does not work with the latest version of transformers `4.37+` [yet](https://github.com/huggingface/transformers/issues/28846#issuecomment-1926109135), so you will need to keep `4.36.2`, which is installed with surya.
# Usage
- Inspect the settings in `surya/settings.py`. You can override any settings with environment variables.
- Your torch device will be automatically detected, but you can override this. For example, `TORCH_DEVICE=cuda`. Note that the `mps` device has a bug (on the [Apple side](https://github.com/pytorch/pytorch/issues/84936)) that may prevent it from working properly.
- Your torch device will be automatically detected, but you can override this. For example, `TORCH_DEVICE=cuda`. For text detection, the `mps` device has a bug (on the [Apple side](https://github.com/pytorch/pytorch/issues/84936)) that may prevent it from working properly.
## OCR (text recognition)
You can detect text lines in an image, pdf, or folder of images/pdfs with the following command. This will write out a json file with the detected text and bboxes, and optionally save images of the reconstructed page.
You can detect text in an image, pdf, or folder of images/pdfs with the following command. This will write out a json file with the detected text and bboxes, and optionally save images of the reconstructed page.
```
surya_ocr DATA_PATH --images --langs hi,en
```
- `DATA_PATH` can be an image, pdf, or folder of images/pdfs
- `--langs` specifies the language(s) to use for OCR. You can comma separate multiple languages. Use the language name or two-letter ISO code from [here](https://en.wikipedia.org/wiki/List_of_ISO_639_language_codes). Surya supports the 90+ languages found in `surya/languages.py`.
- `--langs` specifies the language(s) to use for OCR. You can comma separate multiple languages (I don't recommend using more than `4`). Use the language name or two-letter ISO code from [here](https://en.wikipedia.org/wiki/List_of_ISO_639_language_codes). Surya supports the 90+ languages found in `surya/languages.py`.
- `--lang_file` if you want to use a different language for different PDFs/images, you can specify languages here. The format is a JSON dict with the keys being filenames and the values as a list, like `{"file1.pdf": ["en", "hi"], "file2.pdf": ["en"]}`.
- `--images` will save images of the pages and detected text lines (optional)
- `--results_dir` specifies the directory to save results to instead of the default
@ -158,15 +158,17 @@ If you want to develop surya, you can install it manually:
- `git clone https://github.com/VikParuchuri/surya.git`
- `cd surya`
- `poetry install` # Installs main and dev dependencies
- `poetry install` - installs main and dev dependencies
- `poetry shell` - activates the virtual environment
# Limitations
- Math will not be detected well with the main model. Use `DETECTOR_MODEL_CHECKPOINT=vikp/line_detector_math` for better results.
- This is specialized for document OCR. It will likely not work on photos or other images.
- It is for printed text, not handwriting.
- The model has trained itself to ignore advertisements.
- You can find language support for OCR in `surya/languages.py`. Text detection should work with any language.
- Math will not be detected well with the main detector model. Use `DETECTOR_MODEL_CHECKPOINT=vikp/line_detector_math` for better results.
# Benchmarks
@ -207,7 +209,7 @@ Then we calculate precision and recall for the whole dataset.
You can benchmark the performance of surya on your machine.
- Follow the manual install instructions above.
- `poetry install --group dev` # Installs dev dependencies
- `poetry install --group dev` - installs dev dependencies
**Text line detection**
@ -222,10 +224,23 @@ python benchmark/detection.py --max 256
- `--pdf_path` will let you specify a pdf to benchmark instead of the default data
- `--results_dir` will let you specify a directory to save results to instead of the default one
**Text recognition**
This will evaluate surya and optionally tesseract on multilingual pdfs from common crawl.
```
python benchmark/recognition.py --max 256
```
- `--max` controls how many images to process for the benchmark
- `--debug` will render images with detected text
- `--results_dir` will let you specify a directory to save results to instead of the default one
- `--tesseract` will run the benchmark with tesseract. You have to run `sudo apt-get install tesseract-ocr-all` to install all tesseract data, and set `TESSDATA_PREFIX` to the path to the tesseract data folder.
# Training
The text detection was trained on 4x A6000s for about 3 days. It used a diverse set of images as training data. It was trained from scratch using a modified segformer architecture that reduces inference RAM requirements.
Text detection was trained on 4x A6000s for 3 days. It used a diverse set of images as training data. It was trained from scratch using a modified segformer architecture that reduces inference RAM requirements.
Text recognition was trained on 4x A6000s for 2 weeks. It was trained using a modified donut model (GQA, MoE layer, UTF-16 decoding, layer config changes).

View File

@ -42,7 +42,7 @@ def main():
image_sizes = [img.size for img in images]
correct_boxes = get_pdf_lines(args.pdf_path, image_sizes)
else:
pathname = "doclaynet_bench"
pathname = "det_bench"
# These have already been shuffled randomly, so sampling from the start is fine
dataset = datasets.load_dataset(settings.DETECTOR_BENCH_DATASET_NAME, split=f"train[:{args.max}]")
images = list(dataset["image"])

View File

@ -4,14 +4,18 @@ from collections import defaultdict
from benchmark.scoring import overlap_score
from surya.model.recognition.model import load_model as load_recognition_model
from surya.model.recognition.processor import load_processor as load_recognition_processor
from surya.ocr import run_ocr, run_recognition
from surya.ocr import run_recognition
from surya.postprocessing.text import draw_text_on_image
from surya.settings import settings
from surya.languages import CODE_TO_LANGUAGE, is_arabic
import arabic_reshaper
from surya.languages import CODE_TO_LANGUAGE
from surya.benchmark.tesseract import tesseract_ocr_parallel, surya_lang_to_tesseract, TESS_CODE_TO_LANGUAGE
import os
import datasets
import json
import time
from tabulate import tabulate
KEY_LANGUAGES = ["Chinese", "Spanish", "English", "Arabic", "Hindi", "Bengali", "Russian", "Japanese"]
def main():
@ -19,6 +23,7 @@ def main():
parser.add_argument("--results_dir", type=str, help="Path to JSON file with OCR results.", default=os.path.join(settings.RESULT_DIR, "benchmark"))
parser.add_argument("--max", type=int, help="Maximum number of pdf pages to OCR.", default=None)
parser.add_argument("--debug", action="store_true", help="Run in debug mode.", default=False)
parser.add_argument("--tesseract", action="store_true", help="Run tesseract instead of surya.", default=False)
args = parser.parse_args()
rec_model = load_recognition_model()
@ -44,25 +49,74 @@ def main():
else:
lang_list.append(l)
start = time.time()
predictions_by_image = run_recognition(images, lang_list, rec_model, rec_processor, bboxes=bboxes)
surya_time = time.time() - start
image_scores = defaultdict(list)
surya_scores = defaultdict(list)
for idx, (pred, ref_text, lang) in enumerate(zip(predictions_by_image, line_text, lang_list)):
if any(is_arabic(l) for l in lang):
ref_text = [arabic_reshaper.reshape(t) for t in ref_text]
pred["text_lines"] = [arabic_reshaper.reshape(t) for t in pred["text_lines"]]
image_score = overlap_score(pred["text_lines"], ref_text)
for l in lang:
image_scores[CODE_TO_LANGUAGE[l]].append(image_score)
surya_scores[CODE_TO_LANGUAGE[l]].append(image_score)
image_avgs = {l: sum(scores) / len(scores) for l, scores in image_scores.items()}
print(image_avgs)
flat_surya_scores = [s for l in surya_scores for s in surya_scores[l]]
benchmark_stats = {
"surya": {
"avg_score": sum(flat_surya_scores) / len(flat_surya_scores),
"lang_scores": {l: sum(scores) / len(scores) for l, scores in surya_scores.items()},
"time_per_img": surya_time / len(images)
}
}
if args.tesseract:
tess_valid = []
tess_langs = []
for idx, lang in enumerate(lang_list):
# Tesseract does not support all languages
tess_lang = surya_lang_to_tesseract(lang[0])
if tess_lang is None:
continue
tess_valid.append(idx)
tess_langs.append(tess_lang)
tess_imgs = [images[i] for i in tess_valid]
tess_bboxes = [bboxes[i] for i in tess_valid]
tess_reference = [line_text[i] for i in tess_valid]
start = time.time()
tess_predictions = tesseract_ocr_parallel(tess_imgs, tess_bboxes, tess_langs)
tesseract_time = time.time() - start
tess_scores = defaultdict(list)
for idx, (pred, ref_text, lang) in enumerate(zip(tess_predictions, tess_reference, tess_langs)):
image_score = overlap_score(pred, ref_text)
tess_scores[TESS_CODE_TO_LANGUAGE[lang]].append(image_score)
flat_tess_scores = [s for l in tess_scores for s in tess_scores[l]]
benchmark_stats["tesseract"] = {
"avg_score": sum(flat_tess_scores) / len(flat_tess_scores),
"lang_scores": {l: sum(scores) / len(scores) for l, scores in tess_scores.items()},
"time_per_img": tesseract_time / len(tess_imgs)
}
result_path = os.path.join(args.results_dir, "rec_bench")
os.makedirs(result_path, exist_ok=True)
with open(os.path.join(result_path, "results.json"), "w+") as f:
json.dump(image_scores, f)
json.dump(benchmark_stats, f)
key_languages = [k for k in KEY_LANGUAGES if k in surya_scores]
table_headers = ["Model", "Time per page (s)", "Avg Score"] + KEY_LANGUAGES
table_data = [
["surya", benchmark_stats["surya"]["time_per_img"], benchmark_stats["surya"]["avg_score"]] + [benchmark_stats["surya"]["lang_scores"][l] for l in key_languages],
]
if args.tesseract:
table_data.append(
["tesseract", benchmark_stats["tesseract"]["time_per_img"], benchmark_stats["tesseract"]["avg_score"]] + [benchmark_stats["tesseract"]["lang_scores"].get(l, 0) for l in key_languages]
)
print(tabulate(table_data, headers=table_headers, tablefmt="github"))
print("Only a few major languages are displayed. See the result path for additional languages.")
if args.debug:
for idx, (image, pred, ref_text, bbox, lang) in enumerate(zip(images, predictions_by_image, line_text, bboxes, lang_list)):

View File

@ -1,8 +1,10 @@
import math
from typing import List
from rapidfuzz import fuzz
def overlap_score(pred_lines, reference_lines):
def overlap_score(pred_lines: List[str], reference_lines: List[str]):
line_scores = []
line_weights = []
for i, pred_line in enumerate(pred_lines):

View File

@ -2,19 +2,33 @@ import json
import argparse
def verify_scores(file_path):
def verify_det(data):
scores = data["metrics"]["surya"]
if scores["precision"] <= 0.9 or scores["recall"] <= 0.9:
raise ValueError("Scores do not meet the required threshold")
def verify_rec(data):
scores = data["surya"]
if scores["avg_score"] <= 0.9:
raise ValueError("Scores do not meet the required threshold")
def verify_scores(file_path, bench_type):
with open(file_path, 'r') as file:
data = json.load(file)
scores = data["metrics"]["surya"]
if scores["precision"] <= 0.9 or scores["recall"] <= 0.9:
print(scores)
raise ValueError("Scores do not meet the required threshold")
if bench_type == "detection":
verify_det(data)
elif bench_type == "recognition":
verify_rec(data)
else:
raise ValueError("Invalid benchmark type")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Verify benchmark scores")
parser.add_argument("file_path", type=str, help="Path to the json file")
parser.add_argument("--bench_type", type=str, help="Type of benchmark to verify", default="detection")
args = parser.parse_args()
verify_scores(args.file_path)
verify_scores(args.file_path, args.bench_type)

View File

@ -1,10 +1,49 @@
from typing import List, Optional
import numpy as np
import pytesseract
from pytesseract import Output
from surya.input.processing import slice_bboxes_from_image
from surya.settings import settings
import os
from concurrent.futures import ProcessPoolExecutor
from surya.detection import get_batch_size
from surya.detection import get_batch_size as get_det_batch_size
from surya.recognition import get_batch_size as get_rec_batch_size
from surya.languages import CODE_TO_LANGUAGE
def surya_lang_to_tesseract(code: str) -> Optional[str]:
lang_str = CODE_TO_LANGUAGE[code]
try:
tess_lang = TESS_LANGUAGE_TO_CODE[lang_str]
except KeyError:
return None
return tess_lang
def tesseract_ocr(img, bboxes, lang: str):
line_imgs = slice_bboxes_from_image(img, bboxes)
config = f'--tessdata-dir "{settings.TESSDATA_PREFIX}"'
lines = []
for line_img in line_imgs:
line = pytesseract.image_to_string(line_img, lang=lang, config=config)
lines.append(line)
return lines
def tesseract_ocr_parallel(imgs, bboxes, langs: List[str]):
tess_parallel_cores = min(len(imgs), get_rec_batch_size())
cpus = os.cpu_count()
tess_parallel_cores = min(tess_parallel_cores, cpus)
# Tesseract uses 4 threads per instance
tess_parallel = max(tess_parallel_cores // 4, 1)
with ProcessPoolExecutor(max_workers=tess_parallel) as executor:
tess_text = executor.map(tesseract_ocr, imgs, bboxes, langs)
tess_text = list(tess_text)
return tess_text
def tesseract_bboxes(img):
@ -24,7 +63,7 @@ def tesseract_bboxes(img):
def tesseract_parallel(imgs):
# Tesseract uses 4 threads per instance
tess_parallel_cores = min(len(imgs), get_batch_size())
tess_parallel_cores = min(len(imgs), get_det_batch_size())
cpus = os.cpu_count()
tess_parallel_cores = min(tess_parallel_cores, cpus)
@ -34,4 +73,106 @@ def tesseract_parallel(imgs):
with ProcessPoolExecutor(max_workers=tess_parallel) as executor:
tess_bboxes = executor.map(tesseract_bboxes, imgs)
tess_bboxes = list(tess_bboxes)
return tess_bboxes
return tess_bboxes
TESS_CODE_TO_LANGUAGE = {
"afr": "Afrikaans",
"amh": "Amharic",
"ara": "Arabic",
"asm": "Assamese",
"aze": "Azerbaijani",
"bel": "Belarusian",
"ben": "Bengali",
"bod": "Tibetan",
"bos": "Bosnian",
"bre": "Breton",
"bul": "Bulgarian",
"cat": "Catalan",
"ceb": "Cebuano",
"ces": "Czech",
"chi_sim": "Chinese",
"chr": "Cherokee",
"cym": "Welsh",
"dan": "Danish",
"deu": "German",
"dzo": "Dzongkha",
"ell": "Greek",
"eng": "English",
"epo": "Esperanto",
"est": "Estonian",
"eus": "Basque",
"fas": "Persian",
"fin": "Finnish",
"fra": "French",
"fry": "Western Frisian",
"guj": "Gujarati",
"gla": "Scottish Gaelic",
"gle": "Irish",
"glg": "Galician",
"heb": "Hebrew",
"hin": "Hindi",
"hrv": "Croatian",
"hun": "Hungarian",
"hye": "Armenian",
"iku": "Inuktitut",
"ind": "Indonesian",
"isl": "Icelandic",
"ita": "Italian",
"jav": "Javanese",
"jpn": "Japanese",
"kan": "Kannada",
"kat": "Georgian",
"kaz": "Kazakh",
"khm": "Khmer",
"kir": "Kyrgyz",
"kor": "Korean",
"kur": "Kurdish",
"lao": "Lao",
"lat": "Latin",
"lav": "Latvian",
"lit": "Lithuanian",
"mal": "Malayalam",
"mar": "Marathi",
"mkd": "Macedonian",
"mlt": "Maltese",
"mon": "Mongolian",
"msa": "Malay",
"mya": "Burmese",
"nep": "Nepali",
"nld": "Dutch",
"nor": "Norwegian",
"ori": "Oriya",
"pan": "Punjabi",
"pol": "Polish",
"por": "Portuguese",
"pus": "Pashto",
"ron": "Romanian",
"rus": "Russian",
"san": "Sanskrit",
"sin": "Sinhala",
"slk": "Slovak",
"slv": "Slovenian",
"snd": "Sindhi",
"spa": "Spanish",
"sqi": "Albanian",
"srp": "Serbian",
"swa": "Swahili",
"swe": "Swedish",
"syr": "Syriac",
"tam": "Tamil",
"tel": "Telugu",
"tgk": "Tajik",
"tgl": "Tagalog",
"tha": "Thai",
"tir": "Tigrinya",
"tur": "Turkish",
"uig": "Uyghur",
"ukr": "Ukrainian",
"urd": "Urdu",
"uzb": "Uzbek",
"vie": "Vietnamese",
"yid": "Yiddish"
}
TESS_LANGUAGE_TO_CODE = {v:k for k,v in TESS_CODE_TO_LANGUAGE.items()}

View File

@ -19,8 +19,8 @@ def load_pdf(pdf_path, max_pages=None, start_page=None):
start_page = 0
if max_pages:
assert max_pages + start_page <= last_page and max_pages >= 0, f"Max pages must be between 0 and {last_page}"
last_page = start_page + max_pages
assert max_pages >= 0, f"Max pages must be greater than 0"
last_page = min(start_page + max_pages, last_page)
page_indices = list(range(start_page, last_page))
images = get_page_images(doc, page_indices)

View File

@ -61,7 +61,7 @@ CODE_TO_LANGUAGE = {
'nl': 'Dutch',
'no': 'Norwegian',
'om': 'Oromo',
'or': 'Odia',
'or': 'Oriya',
'pa': 'Punjabi',
'pl': 'Polish',
'ps': 'Pashto',

View File

@ -59,6 +59,9 @@ class Settings(BaseSettings):
RECOGNITION_FONT_DL_PATH: str = "https://github.com/satbyy/go-noto-universal/releases/download/v7.0/GoNotoKurrent-Regular.ttf"
RECOGNITION_BENCH_DATASET_NAME: str = "vikp/rec_bench"
# Tesseract (for benchmarks only)
TESSDATA_PREFIX: Optional[str] = None
@computed_field
@property
def MODEL_DTYPE(self) -> torch.dtype: