diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 1f0bd36..174b4a3 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -24,10 +24,14 @@ jobs: poetry install poetry remove torch poetry run pip install torch --index-url https://download.pytorch.org/whl/cpu - - name: Run benchmark test + - name: Run detection benchmark test run: | poetry run python benchmark/detection.py --max 2 - poetry run python scripts/verify_benchmark_scores.py results/benchmark/doclaynet_bench/results.json + poetry run python scripts/verify_benchmark_scores.py results/benchmark/det_bench/results.json --bench_type detection + - name: Run recognition benchmark test + run: | + poetry run python benchmark/recognition.py --max 2 + poetry run python scripts/verify_benchmark_scores.py results/benchmark/rec_bench/results.json --bench_type recognition diff --git a/README.md b/README.md index 0a69f3c..03ef6ba 100644 --- a/README.md +++ b/README.md @@ -1,9 +1,9 @@ # Surya -Surya is a multilingual document OCR toolkit. It can do: +Surya is for multilingual document OCR. It can do: -- Accurate line-level text detection in any language -- Text recognition in 90+ languages +- Accurate OCR in 90+ languages +- Line-level text detection in any language - Table and chart detection (coming soon) It works on a range of documents (see [usage](#usage) and [benchmarks](#benchmarks) for more details). @@ -39,23 +39,23 @@ Install with: pip install surya-ocr ``` -Model weights will automatically download the first time you run surya. +Model weights will automatically download the first time you run surya. Note that this does not work with the latest version of transformers `4.37+` [yet](https://github.com/huggingface/transformers/issues/28846#issuecomment-1926109135), so you will need to keep `4.36.2`, which is installed with surya. # Usage - Inspect the settings in `surya/settings.py`. You can override any settings with environment variables. -- Your torch device will be automatically detected, but you can override this. For example, `TORCH_DEVICE=cuda`. Note that the `mps` device has a bug (on the [Apple side](https://github.com/pytorch/pytorch/issues/84936)) that may prevent it from working properly. +- Your torch device will be automatically detected, but you can override this. For example, `TORCH_DEVICE=cuda`. For text detection, the `mps` device has a bug (on the [Apple side](https://github.com/pytorch/pytorch/issues/84936)) that may prevent it from working properly. ## OCR (text recognition) -You can detect text lines in an image, pdf, or folder of images/pdfs with the following command. This will write out a json file with the detected text and bboxes, and optionally save images of the reconstructed page. +You can detect text in an image, pdf, or folder of images/pdfs with the following command. This will write out a json file with the detected text and bboxes, and optionally save images of the reconstructed page. ``` surya_ocr DATA_PATH --images --langs hi,en ``` - `DATA_PATH` can be an image, pdf, or folder of images/pdfs -- `--langs` specifies the language(s) to use for OCR. You can comma separate multiple languages. Use the language name or two-letter ISO code from [here](https://en.wikipedia.org/wiki/List_of_ISO_639_language_codes). Surya supports the 90+ languages found in `surya/languages.py`. +- `--langs` specifies the language(s) to use for OCR. You can comma separate multiple languages (I don't recommend using more than `4`). Use the language name or two-letter ISO code from [here](https://en.wikipedia.org/wiki/List_of_ISO_639_language_codes). Surya supports the 90+ languages found in `surya/languages.py`. - `--lang_file` if you want to use a different language for different PDFs/images, you can specify languages here. The format is a JSON dict with the keys being filenames and the values as a list, like `{"file1.pdf": ["en", "hi"], "file2.pdf": ["en"]}`. - `--images` will save images of the pages and detected text lines (optional) - `--results_dir` specifies the directory to save results to instead of the default @@ -158,15 +158,17 @@ If you want to develop surya, you can install it manually: - `git clone https://github.com/VikParuchuri/surya.git` - `cd surya` -- `poetry install` # Installs main and dev dependencies +- `poetry install` - installs main and dev dependencies +- `poetry shell` - activates the virtual environment # Limitations -- Math will not be detected well with the main model. Use `DETECTOR_MODEL_CHECKPOINT=vikp/line_detector_math` for better results. - This is specialized for document OCR. It will likely not work on photos or other images. - It is for printed text, not handwriting. - The model has trained itself to ignore advertisements. - You can find language support for OCR in `surya/languages.py`. Text detection should work with any language. +- Math will not be detected well with the main detector model. Use `DETECTOR_MODEL_CHECKPOINT=vikp/line_detector_math` for better results. + # Benchmarks @@ -207,7 +209,7 @@ Then we calculate precision and recall for the whole dataset. You can benchmark the performance of surya on your machine. - Follow the manual install instructions above. -- `poetry install --group dev` # Installs dev dependencies +- `poetry install --group dev` - installs dev dependencies **Text line detection** @@ -222,10 +224,23 @@ python benchmark/detection.py --max 256 - `--pdf_path` will let you specify a pdf to benchmark instead of the default data - `--results_dir` will let you specify a directory to save results to instead of the default one +**Text recognition** + +This will evaluate surya and optionally tesseract on multilingual pdfs from common crawl. + +``` +python benchmark/recognition.py --max 256 +``` + +- `--max` controls how many images to process for the benchmark +- `--debug` will render images with detected text +- `--results_dir` will let you specify a directory to save results to instead of the default one +- `--tesseract` will run the benchmark with tesseract. You have to run `sudo apt-get install tesseract-ocr-all` to install all tesseract data, and set `TESSDATA_PREFIX` to the path to the tesseract data folder. + # Training -The text detection was trained on 4x A6000s for about 3 days. It used a diverse set of images as training data. It was trained from scratch using a modified segformer architecture that reduces inference RAM requirements. +Text detection was trained on 4x A6000s for 3 days. It used a diverse set of images as training data. It was trained from scratch using a modified segformer architecture that reduces inference RAM requirements. Text recognition was trained on 4x A6000s for 2 weeks. It was trained using a modified donut model (GQA, MoE layer, UTF-16 decoding, layer config changes). diff --git a/benchmark/detection.py b/benchmark/detection.py index 69eea9f..af3b203 100644 --- a/benchmark/detection.py +++ b/benchmark/detection.py @@ -42,7 +42,7 @@ def main(): image_sizes = [img.size for img in images] correct_boxes = get_pdf_lines(args.pdf_path, image_sizes) else: - pathname = "doclaynet_bench" + pathname = "det_bench" # These have already been shuffled randomly, so sampling from the start is fine dataset = datasets.load_dataset(settings.DETECTOR_BENCH_DATASET_NAME, split=f"train[:{args.max}]") images = list(dataset["image"]) diff --git a/benchmark/recognition.py b/benchmark/recognition.py index e5ee9d2..8b92cd8 100644 --- a/benchmark/recognition.py +++ b/benchmark/recognition.py @@ -4,14 +4,18 @@ from collections import defaultdict from benchmark.scoring import overlap_score from surya.model.recognition.model import load_model as load_recognition_model from surya.model.recognition.processor import load_processor as load_recognition_processor -from surya.ocr import run_ocr, run_recognition +from surya.ocr import run_recognition from surya.postprocessing.text import draw_text_on_image from surya.settings import settings -from surya.languages import CODE_TO_LANGUAGE, is_arabic -import arabic_reshaper +from surya.languages import CODE_TO_LANGUAGE +from surya.benchmark.tesseract import tesseract_ocr_parallel, surya_lang_to_tesseract, TESS_CODE_TO_LANGUAGE import os import datasets import json +import time +from tabulate import tabulate + +KEY_LANGUAGES = ["Chinese", "Spanish", "English", "Arabic", "Hindi", "Bengali", "Russian", "Japanese"] def main(): @@ -19,6 +23,7 @@ def main(): parser.add_argument("--results_dir", type=str, help="Path to JSON file with OCR results.", default=os.path.join(settings.RESULT_DIR, "benchmark")) parser.add_argument("--max", type=int, help="Maximum number of pdf pages to OCR.", default=None) parser.add_argument("--debug", action="store_true", help="Run in debug mode.", default=False) + parser.add_argument("--tesseract", action="store_true", help="Run tesseract instead of surya.", default=False) args = parser.parse_args() rec_model = load_recognition_model() @@ -44,25 +49,74 @@ def main(): else: lang_list.append(l) + start = time.time() predictions_by_image = run_recognition(images, lang_list, rec_model, rec_processor, bboxes=bboxes) + surya_time = time.time() - start - image_scores = defaultdict(list) + surya_scores = defaultdict(list) for idx, (pred, ref_text, lang) in enumerate(zip(predictions_by_image, line_text, lang_list)): - if any(is_arabic(l) for l in lang): - ref_text = [arabic_reshaper.reshape(t) for t in ref_text] - pred["text_lines"] = [arabic_reshaper.reshape(t) for t in pred["text_lines"]] image_score = overlap_score(pred["text_lines"], ref_text) for l in lang: - image_scores[CODE_TO_LANGUAGE[l]].append(image_score) + surya_scores[CODE_TO_LANGUAGE[l]].append(image_score) - image_avgs = {l: sum(scores) / len(scores) for l, scores in image_scores.items()} - print(image_avgs) + flat_surya_scores = [s for l in surya_scores for s in surya_scores[l]] + benchmark_stats = { + "surya": { + "avg_score": sum(flat_surya_scores) / len(flat_surya_scores), + "lang_scores": {l: sum(scores) / len(scores) for l, scores in surya_scores.items()}, + "time_per_img": surya_time / len(images) + } + } + + if args.tesseract: + tess_valid = [] + tess_langs = [] + for idx, lang in enumerate(lang_list): + # Tesseract does not support all languages + tess_lang = surya_lang_to_tesseract(lang[0]) + if tess_lang is None: + continue + + tess_valid.append(idx) + tess_langs.append(tess_lang) + + tess_imgs = [images[i] for i in tess_valid] + tess_bboxes = [bboxes[i] for i in tess_valid] + tess_reference = [line_text[i] for i in tess_valid] + start = time.time() + tess_predictions = tesseract_ocr_parallel(tess_imgs, tess_bboxes, tess_langs) + tesseract_time = time.time() - start + + tess_scores = defaultdict(list) + for idx, (pred, ref_text, lang) in enumerate(zip(tess_predictions, tess_reference, tess_langs)): + image_score = overlap_score(pred, ref_text) + tess_scores[TESS_CODE_TO_LANGUAGE[lang]].append(image_score) + + flat_tess_scores = [s for l in tess_scores for s in tess_scores[l]] + benchmark_stats["tesseract"] = { + "avg_score": sum(flat_tess_scores) / len(flat_tess_scores), + "lang_scores": {l: sum(scores) / len(scores) for l, scores in tess_scores.items()}, + "time_per_img": tesseract_time / len(tess_imgs) + } result_path = os.path.join(args.results_dir, "rec_bench") os.makedirs(result_path, exist_ok=True) with open(os.path.join(result_path, "results.json"), "w+") as f: - json.dump(image_scores, f) + json.dump(benchmark_stats, f) + + key_languages = [k for k in KEY_LANGUAGES if k in surya_scores] + table_headers = ["Model", "Time per page (s)", "Avg Score"] + KEY_LANGUAGES + table_data = [ + ["surya", benchmark_stats["surya"]["time_per_img"], benchmark_stats["surya"]["avg_score"]] + [benchmark_stats["surya"]["lang_scores"][l] for l in key_languages], + ] + if args.tesseract: + table_data.append( + ["tesseract", benchmark_stats["tesseract"]["time_per_img"], benchmark_stats["tesseract"]["avg_score"]] + [benchmark_stats["tesseract"]["lang_scores"].get(l, 0) for l in key_languages] + ) + + print(tabulate(table_data, headers=table_headers, tablefmt="github")) + print("Only a few major languages are displayed. See the result path for additional languages.") if args.debug: for idx, (image, pred, ref_text, bbox, lang) in enumerate(zip(images, predictions_by_image, line_text, bboxes, lang_list)): diff --git a/benchmark/scoring.py b/benchmark/scoring.py index 868c880..50bf089 100644 --- a/benchmark/scoring.py +++ b/benchmark/scoring.py @@ -1,8 +1,10 @@ import math +from typing import List + from rapidfuzz import fuzz -def overlap_score(pred_lines, reference_lines): +def overlap_score(pred_lines: List[str], reference_lines: List[str]): line_scores = [] line_weights = [] for i, pred_line in enumerate(pred_lines): diff --git a/scripts/verify_benchmark_scores.py b/scripts/verify_benchmark_scores.py index 7f17b34..5113ffb 100644 --- a/scripts/verify_benchmark_scores.py +++ b/scripts/verify_benchmark_scores.py @@ -2,19 +2,33 @@ import json import argparse -def verify_scores(file_path): +def verify_det(data): + scores = data["metrics"]["surya"] + if scores["precision"] <= 0.9 or scores["recall"] <= 0.9: + raise ValueError("Scores do not meet the required threshold") + + +def verify_rec(data): + scores = data["surya"] + if scores["avg_score"] <= 0.9: + raise ValueError("Scores do not meet the required threshold") + + +def verify_scores(file_path, bench_type): with open(file_path, 'r') as file: data = json.load(file) - scores = data["metrics"]["surya"] - - if scores["precision"] <= 0.9 or scores["recall"] <= 0.9: - print(scores) - raise ValueError("Scores do not meet the required threshold") + if bench_type == "detection": + verify_det(data) + elif bench_type == "recognition": + verify_rec(data) + else: + raise ValueError("Invalid benchmark type") if __name__ == "__main__": parser = argparse.ArgumentParser(description="Verify benchmark scores") parser.add_argument("file_path", type=str, help="Path to the json file") + parser.add_argument("--bench_type", type=str, help="Type of benchmark to verify", default="detection") args = parser.parse_args() - verify_scores(args.file_path) + verify_scores(args.file_path, args.bench_type) diff --git a/surya/benchmark/tesseract.py b/surya/benchmark/tesseract.py index 4019d24..f642d89 100644 --- a/surya/benchmark/tesseract.py +++ b/surya/benchmark/tesseract.py @@ -1,10 +1,49 @@ +from typing import List, Optional + import numpy as np import pytesseract from pytesseract import Output + +from surya.input.processing import slice_bboxes_from_image from surya.settings import settings import os from concurrent.futures import ProcessPoolExecutor -from surya.detection import get_batch_size +from surya.detection import get_batch_size as get_det_batch_size +from surya.recognition import get_batch_size as get_rec_batch_size +from surya.languages import CODE_TO_LANGUAGE + + +def surya_lang_to_tesseract(code: str) -> Optional[str]: + lang_str = CODE_TO_LANGUAGE[code] + try: + tess_lang = TESS_LANGUAGE_TO_CODE[lang_str] + except KeyError: + return None + return tess_lang + + +def tesseract_ocr(img, bboxes, lang: str): + line_imgs = slice_bboxes_from_image(img, bboxes) + config = f'--tessdata-dir "{settings.TESSDATA_PREFIX}"' + lines = [] + for line_img in line_imgs: + line = pytesseract.image_to_string(line_img, lang=lang, config=config) + lines.append(line) + return lines + + +def tesseract_ocr_parallel(imgs, bboxes, langs: List[str]): + tess_parallel_cores = min(len(imgs), get_rec_batch_size()) + cpus = os.cpu_count() + tess_parallel_cores = min(tess_parallel_cores, cpus) + + # Tesseract uses 4 threads per instance + tess_parallel = max(tess_parallel_cores // 4, 1) + + with ProcessPoolExecutor(max_workers=tess_parallel) as executor: + tess_text = executor.map(tesseract_ocr, imgs, bboxes, langs) + tess_text = list(tess_text) + return tess_text def tesseract_bboxes(img): @@ -24,7 +63,7 @@ def tesseract_bboxes(img): def tesseract_parallel(imgs): # Tesseract uses 4 threads per instance - tess_parallel_cores = min(len(imgs), get_batch_size()) + tess_parallel_cores = min(len(imgs), get_det_batch_size()) cpus = os.cpu_count() tess_parallel_cores = min(tess_parallel_cores, cpus) @@ -34,4 +73,106 @@ def tesseract_parallel(imgs): with ProcessPoolExecutor(max_workers=tess_parallel) as executor: tess_bboxes = executor.map(tesseract_bboxes, imgs) tess_bboxes = list(tess_bboxes) - return tess_bboxes \ No newline at end of file + return tess_bboxes + + +TESS_CODE_TO_LANGUAGE = { + "afr": "Afrikaans", + "amh": "Amharic", + "ara": "Arabic", + "asm": "Assamese", + "aze": "Azerbaijani", + "bel": "Belarusian", + "ben": "Bengali", + "bod": "Tibetan", + "bos": "Bosnian", + "bre": "Breton", + "bul": "Bulgarian", + "cat": "Catalan", + "ceb": "Cebuano", + "ces": "Czech", + "chi_sim": "Chinese", + "chr": "Cherokee", + "cym": "Welsh", + "dan": "Danish", + "deu": "German", + "dzo": "Dzongkha", + "ell": "Greek", + "eng": "English", + "epo": "Esperanto", + "est": "Estonian", + "eus": "Basque", + "fas": "Persian", + "fin": "Finnish", + "fra": "French", + "fry": "Western Frisian", + "guj": "Gujarati", + "gla": "Scottish Gaelic", + "gle": "Irish", + "glg": "Galician", + "heb": "Hebrew", + "hin": "Hindi", + "hrv": "Croatian", + "hun": "Hungarian", + "hye": "Armenian", + "iku": "Inuktitut", + "ind": "Indonesian", + "isl": "Icelandic", + "ita": "Italian", + "jav": "Javanese", + "jpn": "Japanese", + "kan": "Kannada", + "kat": "Georgian", + "kaz": "Kazakh", + "khm": "Khmer", + "kir": "Kyrgyz", + "kor": "Korean", + "kur": "Kurdish", + "lao": "Lao", + "lat": "Latin", + "lav": "Latvian", + "lit": "Lithuanian", + "mal": "Malayalam", + "mar": "Marathi", + "mkd": "Macedonian", + "mlt": "Maltese", + "mon": "Mongolian", + "msa": "Malay", + "mya": "Burmese", + "nep": "Nepali", + "nld": "Dutch", + "nor": "Norwegian", + "ori": "Oriya", + "pan": "Punjabi", + "pol": "Polish", + "por": "Portuguese", + "pus": "Pashto", + "ron": "Romanian", + "rus": "Russian", + "san": "Sanskrit", + "sin": "Sinhala", + "slk": "Slovak", + "slv": "Slovenian", + "snd": "Sindhi", + "spa": "Spanish", + "sqi": "Albanian", + "srp": "Serbian", + "swa": "Swahili", + "swe": "Swedish", + "syr": "Syriac", + "tam": "Tamil", + "tel": "Telugu", + "tgk": "Tajik", + "tgl": "Tagalog", + "tha": "Thai", + "tir": "Tigrinya", + "tur": "Turkish", + "uig": "Uyghur", + "ukr": "Ukrainian", + "urd": "Urdu", + "uzb": "Uzbek", + "vie": "Vietnamese", + "yid": "Yiddish" +} + +TESS_LANGUAGE_TO_CODE = {v:k for k,v in TESS_CODE_TO_LANGUAGE.items()} diff --git a/surya/input/load.py b/surya/input/load.py index a04df78..1c2bbbe 100644 --- a/surya/input/load.py +++ b/surya/input/load.py @@ -19,8 +19,8 @@ def load_pdf(pdf_path, max_pages=None, start_page=None): start_page = 0 if max_pages: - assert max_pages + start_page <= last_page and max_pages >= 0, f"Max pages must be between 0 and {last_page}" - last_page = start_page + max_pages + assert max_pages >= 0, f"Max pages must be greater than 0" + last_page = min(start_page + max_pages, last_page) page_indices = list(range(start_page, last_page)) images = get_page_images(doc, page_indices) diff --git a/surya/languages.py b/surya/languages.py index dc98d75..f31f83a 100644 --- a/surya/languages.py +++ b/surya/languages.py @@ -61,7 +61,7 @@ CODE_TO_LANGUAGE = { 'nl': 'Dutch', 'no': 'Norwegian', 'om': 'Oromo', - 'or': 'Odia', + 'or': 'Oriya', 'pa': 'Punjabi', 'pl': 'Polish', 'ps': 'Pashto', diff --git a/surya/settings.py b/surya/settings.py index 769da09..34ef7be 100644 --- a/surya/settings.py +++ b/surya/settings.py @@ -59,6 +59,9 @@ class Settings(BaseSettings): RECOGNITION_FONT_DL_PATH: str = "https://github.com/satbyy/go-noto-universal/releases/download/v7.0/GoNotoKurrent-Regular.ttf" RECOGNITION_BENCH_DATASET_NAME: str = "vikp/rec_bench" + # Tesseract (for benchmarks only) + TESSDATA_PREFIX: Optional[str] = None + @computed_field @property def MODEL_DTYPE(self) -> torch.dtype: