Merge pull request #196 from VikParuchuri/dev
Table recognition, better layout
9
.github/workflows/tests.yml
vendored
@ -36,10 +36,11 @@ jobs:
|
||||
run: |
|
||||
poetry run python benchmark/layout.py --max 5
|
||||
poetry run python scripts/verify_benchmark_scores.py results/benchmark/layout_bench/results.json --bench_type layout
|
||||
- name: Run ordering benchmark text
|
||||
- name: Run ordering benchmark
|
||||
run: |
|
||||
poetry run python benchmark/ordering.py --max 5
|
||||
poetry run python scripts/verify_benchmark_scores.py results/benchmark/order_bench/results.json --bench_type ordering
|
||||
|
||||
|
||||
|
||||
- name: Run table recognition benchmark
|
||||
run: |
|
||||
poetry run python benchmark/table_recognition.py --max 5
|
||||
poetry run python scripts/verify_benchmark_scores.py results/benchmark/table_rec_bench/results.json --bench_type table_recognition
|
||||
108
README.md
@ -6,16 +6,23 @@ Surya is a document OCR toolkit that does:
|
||||
- Line-level text detection in any language
|
||||
- Layout analysis (table, image, header, etc detection)
|
||||
- Reading order detection
|
||||
- Table recognition (detecting rows/columns)
|
||||
|
||||
It works on a range of documents (see [usage](#usage) and [benchmarks](#benchmarks) for more details).
|
||||
|
||||
|
||||
| Detection | OCR |
|
||||
|:----------------------------------------------------------------:|:-----------------------------------------------------------------------:|
|
||||
|  |  |
|
||||
| <img src="static/images/excerpt.png" width="500px"/> | <img src="static/images/excerpt_text.png" width="500px"/> |
|
||||
|
||||
| Layout | Reading Order |
|
||||
|:------------------------------------------------------------------:|:--------------------------------------------------------------------------:|
|
||||
|  |  |
|
||||
| <img src="static/images/excerpt_layout.png" width="500px"/> | <img src="static/images/excerpt_reading.jpg" width="500px"/> |
|
||||
|
||||
| Table Recognition | |
|
||||
|:-----------------------------------------:|:----------------:|
|
||||
| <img src="static/images/table_rec.png" width="500px"/> | <img width="500px"/> |
|
||||
|
||||
|
||||
Surya is named for the [Hindu sun god](https://en.wikipedia.org/wiki/Surya), who has universal vision.
|
||||
|
||||
@ -25,19 +32,19 @@ Surya is named for the [Hindu sun god](https://en.wikipedia.org/wiki/Surya), who
|
||||
|
||||
## Examples
|
||||
|
||||
| Name | Detection | OCR | Layout | Order |
|
||||
|------------------|:-----------------------------------:|-----------------------------------------:|-------------------------------------------:|--------------------------------------------:|
|
||||
| Japanese | [Image](static/images/japanese.jpg) | [Image](static/images/japanese_text.jpg) | [Image](static/images/japanese_layout.jpg) | [Image](static/images/japanese_reading.jpg) |
|
||||
| Chinese | [Image](static/images/chinese.jpg) | [Image](static/images/chinese_text.jpg) | [Image](static/images/chinese_layout.jpg) | [Image](static/images/chinese_reading.jpg) |
|
||||
| Hindi | [Image](static/images/hindi.jpg) | [Image](static/images/hindi_text.jpg) | [Image](static/images/hindi_layout.jpg) | [Image](static/images/hindi_reading.jpg) |
|
||||
| Arabic | [Image](static/images/arabic.jpg) | [Image](static/images/arabic_text.jpg) | [Image](static/images/arabic_layout.jpg) | [Image](static/images/arabic_reading.jpg) |
|
||||
| Chinese + Hindi | [Image](static/images/chi_hind.jpg) | [Image](static/images/chi_hind_text.jpg) | [Image](static/images/chi_hind_layout.jpg) | [Image](static/images/chi_hind_reading.jpg) |
|
||||
| Presentation | [Image](static/images/pres.png) | [Image](static/images/pres_text.jpg) | [Image](static/images/pres_layout.jpg) | [Image](static/images/pres_reading.jpg) |
|
||||
| Scientific Paper | [Image](static/images/paper.jpg) | [Image](static/images/paper_text.jpg) | [Image](static/images/paper_layout.jpg) | [Image](static/images/paper_reading.jpg) |
|
||||
| Scanned Document | [Image](static/images/scanned.png) | [Image](static/images/scanned_text.jpg) | [Image](static/images/scanned_layout.jpg) | [Image](static/images/scanned_reading.jpg) |
|
||||
| New York Times | [Image](static/images/nyt.jpg) | [Image](static/images/nyt_text.jpg) | [Image](static/images/nyt_layout.jpg) | [Image](static/images/nyt_order.jpg) |
|
||||
| Scanned Form | [Image](static/images/funsd.png) | [Image](static/images/funsd_text.jpg) | [Image](static/images/funsd_layout.jpg) | [Image](static/images/funsd_reading.jpg) |
|
||||
| Textbook | [Image](static/images/textbook.jpg) | [Image](static/images/textbook_text.jpg) | [Image](static/images/textbook_layout.jpg) | [Image](static/images/textbook_order.jpg) |
|
||||
| Name | Detection | OCR | Layout | Order | Table Rec |
|
||||
|------------------|:-----------------------------------:|-----------------------------------------:|-------------------------------------------:|--------------------------------------------:|---------------------------------------------:|
|
||||
| Japanese | [Image](static/images/japanese.jpg) | [Image](static/images/japanese_text.jpg) | [Image](static/images/japanese_layout.jpg) | [Image](static/images/japanese_reading.jpg) | [Image](static/images/japanese_tablerec.png) |
|
||||
| Chinese | [Image](static/images/chinese.jpg) | [Image](static/images/chinese_text.jpg) | [Image](static/images/chinese_layout.jpg) | [Image](static/images/chinese_reading.jpg) | |
|
||||
| Hindi | [Image](static/images/hindi.jpg) | [Image](static/images/hindi_text.jpg) | [Image](static/images/hindi_layout.jpg) | [Image](static/images/hindi_reading.jpg) | |
|
||||
| Arabic | [Image](static/images/arabic.jpg) | [Image](static/images/arabic_text.jpg) | [Image](static/images/arabic_layout.jpg) | [Image](static/images/arabic_reading.jpg) | |
|
||||
| Chinese + Hindi | [Image](static/images/chi_hind.jpg) | [Image](static/images/chi_hind_text.jpg) | [Image](static/images/chi_hind_layout.jpg) | [Image](static/images/chi_hind_reading.jpg) | |
|
||||
| Presentation | [Image](static/images/pres.png) | [Image](static/images/pres_text.jpg) | [Image](static/images/pres_layout.jpg) | [Image](static/images/pres_reading.jpg) | [Image](static/images/pres_tablerec.png) |
|
||||
| Scientific Paper | [Image](static/images/paper.jpg) | [Image](static/images/paper_text.jpg) | [Image](static/images/paper_layout.jpg) | [Image](static/images/paper_reading.jpg) | [Image](static/images/paper_tablerec.png) |
|
||||
| Scanned Document | [Image](static/images/scanned.png) | [Image](static/images/scanned_text.jpg) | [Image](static/images/scanned_layout.jpg) | [Image](static/images/scanned_reading.jpg) | [Image](static/images/scanned_tablerec.png) |
|
||||
| New York Times | [Image](static/images/nyt.jpg) | [Image](static/images/nyt_text.jpg) | [Image](static/images/nyt_layout.jpg) | [Image](static/images/nyt_order.jpg) | |
|
||||
| Scanned Form | [Image](static/images/funsd.png) | [Image](static/images/funsd_text.jpg) | [Image](static/images/funsd_layout.jpg) | [Image](static/images/funsd_reading.jpg) | [Image](static/images/scanned_tablerec2.png) |
|
||||
| Textbook | [Image](static/images/textbook.jpg) | [Image](static/images/textbook_text.jpg) | [Image](static/images/textbook_layout.jpg) | [Image](static/images/textbook_order.jpg) | |
|
||||
|
||||
# Hosted API
|
||||
|
||||
@ -272,6 +279,43 @@ processor = load_processor()
|
||||
order_predictions = batch_ordering([image], [bboxes], model, processor)
|
||||
```
|
||||
|
||||
## Table Recognition
|
||||
|
||||
This command will write out a json file with the detected table cells and row/column ids, along with row/column bounding boxes.
|
||||
|
||||
```shell
|
||||
surya_table DATA_PATH
|
||||
```
|
||||
|
||||
- `DATA_PATH` can be an image, pdf, or folder of images/pdfs
|
||||
- `--images` will save images of the pages and detected table cells + rows and columns (optional)
|
||||
- `--max` specifies the maximum number of pages to process if you don't want to process everything
|
||||
- `--results_dir` specifies the directory to save results to instead of the default
|
||||
- `--detect_boxes` specifies if cells should be detected. By default, they're pulled out of the PDF, but this is not always possible.
|
||||
- `--skip_table_detection` tells table recognition not to detect tables first. Use this if your image is already cropped to a table.
|
||||
|
||||
The `results.json` file will contain a json dictionary where the keys are the input filenames without extensions. Each value will be a list of dictionaries, one per page of the input document. Each page dictionary contains:
|
||||
|
||||
- `cells` - detected table cells
|
||||
- `bbox` - the axis-aligned rectangle for the text line in (x1, y1, x2, y2) format. (x1, y1) is the top left corner, and (x2, y2) is the bottom right corner.
|
||||
- `row_id` - the id of the row this cell belongs to.
|
||||
- `col_id` - the id of the column this cell belongs to.
|
||||
- `text` - if text could be pulled out of the pdf, the text of this cell.
|
||||
- `rows` - detected table rows
|
||||
- `bbox` - the bounding box of the table row
|
||||
- `row_id` - the id of the row
|
||||
- `cols` - detected table columns
|
||||
- `bbox` - the bounding box of the table column
|
||||
- `col_id`- the id of the column
|
||||
- `page` - the page number in the file
|
||||
- `table_idx` - the index of the table on the page (sorted in vertical order)
|
||||
- `image_bbox` - the bbox for the image in (x1, y1, x2, y2) format. (x1, y1) is the top left corner, and (x2, y2) is the bottom right corner. All line bboxes will be contained within this bbox.
|
||||
|
||||
**Performance tips**
|
||||
|
||||
Setting the `TABLE_REC_BATCH_SIZE` env var properly will make a big difference when using a GPU. Each batch item will use `150MB` of VRAM, so very high batch sizes are possible. The default is a batch size `64`, which will use about 10GB of VRAM. Depending on your CPU core count, it might help, too - the default CPU batch size is `8`.
|
||||
|
||||
|
||||
# Limitations
|
||||
|
||||
- This is specialized for document OCR. It will likely not work on photos or other images.
|
||||
@ -381,10 +425,23 @@ I benchmarked the layout analysis on [Publaynet](https://github.com/ibm-aur-nlp/
|
||||
|
||||
**Methodology**
|
||||
|
||||
I benchmarked the layout analysis on the layout dataset from [here](https://www.icst.pku.edu.cn/cpdp/sjzy/), which was not in the training data. Unfortunately, this dataset is fairly noisy, and not all the labels are correct. It was very hard to find a dataset annotated with reading order and also layout information. I wanted to avoid using a cloud service for the ground truth.
|
||||
I benchmarked the reading order on the layout dataset from [here](https://www.icst.pku.edu.cn/cpdp/sjzy/), which was not in the training data. Unfortunately, this dataset is fairly noisy, and not all the labels are correct. It was very hard to find a dataset annotated with reading order and also layout information. I wanted to avoid using a cloud service for the ground truth.
|
||||
|
||||
The accuracy is computed by finding if each pair of layout boxes is in the correct order, then taking the % that are correct.
|
||||
|
||||
## Table Recognition
|
||||
|
||||
| Model | Row Intersection | Col Intersection | Time Per Image |
|
||||
|-------------------|------------------|------------------|------------------|
|
||||
| Surya | 0.97 | 0.93 | 0.03 |
|
||||
| Table transformer | 0.72 | 0.84 | 0.02 |
|
||||
|
||||
Higher is better for intersection, which the percentage of the actual row/column overlapped by the predictions.
|
||||
|
||||
**Methodology**
|
||||
|
||||
The benchmark uses a subset of [Fintabnet](https://developer.ibm.com/exchanges/data/all/fintabnet/) from IBM. It has labeled rows and columns. After table recognition is run, the predicted rows and columns are compared to the ground truth. There is an additional penalty for predicting too many or too few rows/columns.
|
||||
|
||||
## Running your own benchmarks
|
||||
|
||||
You can benchmark the performance of surya on your machine.
|
||||
@ -396,7 +453,7 @@ You can benchmark the performance of surya on your machine.
|
||||
|
||||
This will evaluate tesseract and surya for text line detection across a randomly sampled set of images from [doclaynet](https://huggingface.co/datasets/vikp/doclaynet_bench).
|
||||
|
||||
```
|
||||
```shell
|
||||
python benchmark/detection.py --max 256
|
||||
```
|
||||
|
||||
@ -409,7 +466,7 @@ python benchmark/detection.py --max 256
|
||||
|
||||
This will evaluate surya and optionally tesseract on multilingual pdfs from common crawl (with synthetic data for missing languages).
|
||||
|
||||
```
|
||||
```shell
|
||||
python benchmark/recognition.py --tesseract
|
||||
```
|
||||
|
||||
@ -425,7 +482,7 @@ python benchmark/recognition.py --tesseract
|
||||
|
||||
This will evaluate surya on the publaynet dataset.
|
||||
|
||||
```
|
||||
```shell
|
||||
python benchmark/layout.py
|
||||
```
|
||||
|
||||
@ -435,7 +492,7 @@ python benchmark/layout.py
|
||||
|
||||
**Reading Order**
|
||||
|
||||
```
|
||||
```shell
|
||||
python benchmark/ordering.py
|
||||
```
|
||||
|
||||
@ -443,6 +500,17 @@ python benchmark/ordering.py
|
||||
- `--debug` will render images with detected text
|
||||
- `--results_dir` will let you specify a directory to save results to instead of the default one
|
||||
|
||||
**Table Recognition**
|
||||
|
||||
```shell
|
||||
python benchmark/table_recognition.py --max 1024 --tatr
|
||||
```
|
||||
|
||||
- `--max` controls how many images to process for the benchmark
|
||||
- `--debug` will render images with detected text
|
||||
- `--results_dir` will let you specify a directory to save results to instead of the default one
|
||||
- `--tatr` specifies whether to also run table transformer
|
||||
|
||||
# Training
|
||||
|
||||
Text detection was trained on 4x A6000s for 3 days. It used a diverse set of images as training data. It was trained from scratch using a modified efficientvit architecture for semantic segmentation.
|
||||
|
||||
143
benchmark/table_recognition.py
Normal file
@ -0,0 +1,143 @@
|
||||
import argparse
|
||||
import collections
|
||||
import copy
|
||||
import json
|
||||
|
||||
from tabulate import tabulate
|
||||
|
||||
from surya.input.processing import convert_if_not_rgb
|
||||
from surya.model.table_rec.model import load_model
|
||||
from surya.model.table_rec.processor import load_processor
|
||||
from surya.tables import batch_table_recognition, get_batch_size
|
||||
from surya.settings import settings
|
||||
from surya.benchmark.metrics import rank_accuracy, penalized_iou_score
|
||||
from surya.benchmark.tatr import load_tatr, batch_inference_tatr
|
||||
import os
|
||||
import time
|
||||
import datasets
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Benchmark surya table recognition model.")
|
||||
parser.add_argument("--results_dir", type=str, help="Path to JSON file with benchmark results.", default=os.path.join(settings.RESULT_DIR, "benchmark"))
|
||||
parser.add_argument("--max", type=int, help="Maximum number of images to run benchmark on.", default=None)
|
||||
parser.add_argument("--tatr", action="store_true", help="Run table transformer.", default=False)
|
||||
args = parser.parse_args()
|
||||
|
||||
model = load_model()
|
||||
processor = load_processor()
|
||||
|
||||
pathname = "table_rec_bench"
|
||||
# These have already been shuffled randomly, so sampling from the start is fine
|
||||
split = "train"
|
||||
if args.max is not None:
|
||||
split = f"train[:{args.max}]"
|
||||
dataset = datasets.load_dataset(settings.TABLE_REC_BENCH_DATASET_NAME, split=split)
|
||||
images = list(dataset["image"])
|
||||
images = convert_if_not_rgb(images)
|
||||
bboxes = list(dataset["bboxes"])
|
||||
|
||||
start = time.time()
|
||||
bboxes = [[{"bbox": b, "text": None} for b in bb] for bb in bboxes]
|
||||
table_rec_predictions = batch_table_recognition(images, bboxes, model, processor)
|
||||
surya_time = time.time() - start
|
||||
|
||||
folder_name = os.path.basename(pathname).split(".")[0]
|
||||
result_path = os.path.join(args.results_dir, folder_name)
|
||||
os.makedirs(result_path, exist_ok=True)
|
||||
|
||||
page_metrics = collections.OrderedDict()
|
||||
mean_col_iou = 0
|
||||
mean_row_iou = 0
|
||||
for idx, pred in enumerate(table_rec_predictions):
|
||||
row = dataset[idx]
|
||||
pred_row_boxes = [p.bbox for p in pred.rows]
|
||||
pred_col_bboxes = [p.bbox for p in pred.cols]
|
||||
actual_row_bboxes = row["rows"]
|
||||
actual_col_bboxes = row["cols"]
|
||||
row_score = penalized_iou_score(pred_row_boxes, actual_row_bboxes)
|
||||
col_score = penalized_iou_score(pred_col_bboxes, actual_col_bboxes)
|
||||
page_results = {
|
||||
"row_score": row_score,
|
||||
"col_score": col_score,
|
||||
"row_count": len(actual_row_bboxes),
|
||||
"col_count": len(actual_col_bboxes)
|
||||
}
|
||||
|
||||
mean_col_iou += col_score
|
||||
mean_row_iou += row_score
|
||||
|
||||
page_metrics[idx] = page_results
|
||||
|
||||
mean_col_iou /= len(table_rec_predictions)
|
||||
mean_row_iou /= len(table_rec_predictions)
|
||||
|
||||
out_data = {"surya": {
|
||||
"time": surya_time,
|
||||
"mean_row_iou": mean_row_iou,
|
||||
"mean_col_iou": mean_col_iou,
|
||||
"page_metrics": page_metrics
|
||||
}}
|
||||
|
||||
if args.tatr:
|
||||
tatr_model = load_tatr()
|
||||
start = time.time()
|
||||
tatr_predictions = batch_inference_tatr(tatr_model, images, 1)
|
||||
tatr_time = time.time() - start
|
||||
|
||||
page_metrics = collections.OrderedDict()
|
||||
mean_col_iou = 0
|
||||
mean_row_iou = 0
|
||||
for idx, pred in enumerate(tatr_predictions):
|
||||
row = dataset[idx]
|
||||
pred_row_boxes = [p["bbox"] for p in pred["rows"]]
|
||||
pred_col_bboxes = [p["bbox"] for p in pred["cols"]]
|
||||
actual_row_bboxes = row["rows"]
|
||||
actual_col_bboxes = row["cols"]
|
||||
row_score = penalized_iou_score(pred_row_boxes, actual_row_bboxes)
|
||||
col_score = penalized_iou_score(pred_col_bboxes, actual_col_bboxes)
|
||||
page_results = {
|
||||
"row_score": row_score,
|
||||
"col_score": col_score,
|
||||
"row_count": len(actual_row_bboxes),
|
||||
"col_count": len(actual_col_bboxes)
|
||||
}
|
||||
|
||||
mean_col_iou += col_score
|
||||
mean_row_iou += row_score
|
||||
|
||||
page_metrics[idx] = page_results
|
||||
|
||||
mean_col_iou /= len(tatr_predictions)
|
||||
mean_row_iou /= len(tatr_predictions)
|
||||
|
||||
out_data["tatr"] = {
|
||||
"time": tatr_time,
|
||||
"mean_row_iou": mean_row_iou,
|
||||
"mean_col_iou": mean_col_iou,
|
||||
"page_metrics": page_metrics
|
||||
}
|
||||
|
||||
|
||||
with open(os.path.join(result_path, "results.json"), "w+") as f:
|
||||
json.dump(out_data, f, indent=4)
|
||||
|
||||
table = [
|
||||
["Model", "Row Intersection", "Col Intersection", "Time Per Image"],
|
||||
["Surya", f"{out_data['surya']['mean_row_iou']:.2f}", f"{out_data['surya']['mean_col_iou']:.2f}",
|
||||
f"{surya_time / len(images):.2f}"],
|
||||
]
|
||||
|
||||
if args.tatr:
|
||||
table.append(["Table transformer", f"{out_data['tatr']['mean_row_iou']:.2f}", f"{out_data['tatr']['mean_col_iou']:.2f}",
|
||||
f"{tatr_time / len(images):.2f}"])
|
||||
|
||||
print(tabulate(table, headers="firstrow", tablefmt="github"))
|
||||
|
||||
print("Intersection is the average of the intersection % between each actual row/column, and the predictions. With penalties for too many/few predictions.")
|
||||
print("Note that table transformers is unbatched, since the example code in the repo is unbatched.")
|
||||
print(f"Wrote results to {result_path}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@ -1,3 +1,4 @@
|
||||
import pypdfium2 # Causes a warning if not the top import
|
||||
import argparse
|
||||
import copy
|
||||
import json
|
||||
@ -27,10 +28,10 @@ def main():
|
||||
det_processor = load_processor()
|
||||
|
||||
if os.path.isdir(args.input_path):
|
||||
images, names = load_from_folder(args.input_path, args.max)
|
||||
images, names, _ = load_from_folder(args.input_path, args.max)
|
||||
folder_name = os.path.basename(args.input_path)
|
||||
else:
|
||||
images, names = load_from_file(args.input_path, args.max)
|
||||
images, names, _ = load_from_file(args.input_path, args.max)
|
||||
folder_name = os.path.basename(args.input_path).split(".")[0]
|
||||
|
||||
line_predictions = batch_text_detection(images, det_model, det_processor)
|
||||
|
||||
@ -28,10 +28,10 @@ def main():
|
||||
processor = load_processor(checkpoint=checkpoint)
|
||||
|
||||
if os.path.isdir(args.input_path):
|
||||
images, names = load_from_folder(args.input_path, args.max)
|
||||
images, names, _ = load_from_folder(args.input_path, args.max)
|
||||
folder_name = os.path.basename(args.input_path)
|
||||
else:
|
||||
images, names = load_from_file(args.input_path, args.max)
|
||||
images, names, _ = load_from_file(args.input_path, args.max)
|
||||
folder_name = os.path.basename(args.input_path).split(".")[0]
|
||||
|
||||
start = time.time()
|
||||
|
||||
85
ocr_app.py
@ -4,21 +4,27 @@ from typing import List
|
||||
import pypdfium2
|
||||
import streamlit as st
|
||||
from surya.detection import batch_text_detection
|
||||
from surya.input.pdflines import get_page_text_lines, get_table_blocks
|
||||
from surya.layout import batch_layout_detection
|
||||
from surya.model.detection.model import load_model, load_processor
|
||||
from surya.model.recognition.model import load_model as load_rec_model
|
||||
from surya.model.recognition.processor import load_processor as load_rec_processor
|
||||
from surya.model.ordering.processor import load_processor as load_order_processor
|
||||
from surya.model.ordering.model import load_model as load_order_model
|
||||
from surya.model.table_rec.model import load_model as load_table_model
|
||||
from surya.model.table_rec.processor import load_processor as load_table_processor
|
||||
from surya.ordering import batch_ordering
|
||||
from surya.postprocessing.heatmap import draw_polys_on_image
|
||||
from surya.postprocessing.heatmap import draw_polys_on_image, draw_bboxes_on_image
|
||||
from surya.ocr import run_ocr
|
||||
from surya.postprocessing.text import draw_text_on_image
|
||||
from PIL import Image
|
||||
from surya.languages import CODE_TO_LANGUAGE
|
||||
from surya.input.langs import replace_lang_with_code
|
||||
from surya.schema import OCRResult, TextDetectionResult, LayoutResult, OrderResult
|
||||
from surya.schema import OCRResult, TextDetectionResult, LayoutResult, OrderResult, TableResult
|
||||
from surya.settings import settings
|
||||
from surya.tables import batch_table_recognition
|
||||
from surya.postprocessing.util import rescale_bboxes, rescale_bbox
|
||||
|
||||
|
||||
@st.cache_resource()
|
||||
def load_det_cached():
|
||||
@ -40,6 +46,11 @@ def load_order_cached():
|
||||
return load_order_model(), load_order_processor()
|
||||
|
||||
|
||||
@st.cache_resource()
|
||||
def load_table_cached():
|
||||
return load_table_model(), load_table_processor()
|
||||
|
||||
|
||||
def text_detection(img) -> (Image.Image, TextDetectionResult):
|
||||
pred = batch_text_detection([img], det_model, det_processor)[0]
|
||||
polygons = [p.polygon for p in pred.bboxes]
|
||||
@ -52,7 +63,7 @@ def layout_detection(img) -> (Image.Image, LayoutResult):
|
||||
pred = batch_layout_detection([img], layout_model, layout_processor, [det_pred])[0]
|
||||
polygons = [p.polygon for p in pred.bboxes]
|
||||
labels = [p.label for p in pred.bboxes]
|
||||
layout_img = draw_polys_on_image(polygons, img.copy(), labels=labels)
|
||||
layout_img = draw_polys_on_image(polygons, img.copy(), labels=labels, label_font_size=18)
|
||||
return layout_img, pred
|
||||
|
||||
|
||||
@ -62,14 +73,56 @@ def order_detection(img) -> (Image.Image, OrderResult):
|
||||
pred = batch_ordering([img], [bboxes], order_model, order_processor)[0]
|
||||
polys = [l.polygon for l in pred.bboxes]
|
||||
positions = [str(l.position) for l in pred.bboxes]
|
||||
order_img = draw_polys_on_image(polys, img.copy(), labels=positions, label_font_size=20)
|
||||
order_img = draw_polys_on_image(polys, img.copy(), labels=positions, label_font_size=18)
|
||||
return order_img, pred
|
||||
|
||||
|
||||
def table_recognition(img, highres_img, filepath, page_idx: int, use_pdf_boxes: bool, skip_table_detection: bool) -> (Image.Image, List[TableResult]):
|
||||
if skip_table_detection:
|
||||
layout_tables = [(0, 0, highres_img.size[0], highres_img.size[1])]
|
||||
table_imgs = [highres_img]
|
||||
else:
|
||||
_, layout_pred = layout_detection(img)
|
||||
layout_tables_lowres = [l.bbox for l in layout_pred.bboxes if l.label == "Table"]
|
||||
table_imgs = []
|
||||
layout_tables = []
|
||||
for tb in layout_tables_lowres:
|
||||
highres_bbox = rescale_bbox(tb, img.size, highres_img.size)
|
||||
table_imgs.append(
|
||||
highres_img.crop(highres_bbox)
|
||||
)
|
||||
layout_tables.append(highres_bbox)
|
||||
|
||||
page_text = get_page_text_lines(filepath, [page_idx], [highres_img.size])[0]
|
||||
table_bboxes = get_table_blocks(layout_tables, page_text, highres_img.size)
|
||||
|
||||
if not use_pdf_boxes or any(len(tb) == 0 for tb in table_bboxes):
|
||||
det_results = batch_text_detection(table_imgs, det_model, det_processor)
|
||||
table_bboxes = [[{"bbox": tb.bbox, "text": None} for tb in det_result.bboxes] for det_result in det_results]
|
||||
|
||||
table_preds = batch_table_recognition(table_imgs, table_bboxes, table_model, table_processor)
|
||||
table_img = img.copy()
|
||||
|
||||
for results, table_bbox in zip(table_preds, layout_tables):
|
||||
adjusted_bboxes = []
|
||||
labels = []
|
||||
|
||||
for item in results.cells:
|
||||
adjusted_bboxes.append([
|
||||
(item.bbox[0] + table_bbox[0]),
|
||||
(item.bbox[1] + table_bbox[1]),
|
||||
(item.bbox[2] + table_bbox[0]),
|
||||
(item.bbox[3] + table_bbox[1])
|
||||
])
|
||||
labels.append(f"{item.row_id} / {item.col_id}")
|
||||
table_img = draw_bboxes_on_image(adjusted_bboxes, highres_img, labels=labels, label_font_size=18)
|
||||
return table_img, table_preds
|
||||
|
||||
|
||||
# Function for OCR
|
||||
def ocr(img, langs: List[str]) -> (Image.Image, OCRResult):
|
||||
def ocr(img, highres_img, langs: List[str]) -> (Image.Image, OCRResult):
|
||||
replace_lang_with_code(langs)
|
||||
img_pred = run_ocr([img], [langs], det_model, det_processor, rec_model, rec_processor)[0]
|
||||
img_pred = run_ocr([img], [langs], det_model, det_processor, rec_model, rec_processor, highres_images=[highres_img])[0]
|
||||
|
||||
bboxes = [l.bbox for l in img_pred.text_lines]
|
||||
text = [l.text for l in img_pred.text_lines]
|
||||
@ -83,7 +136,7 @@ def open_pdf(pdf_file):
|
||||
|
||||
|
||||
@st.cache_data()
|
||||
def get_page_image(pdf_file, page_num, dpi=96):
|
||||
def get_page_image(pdf_file, page_num, dpi=settings.IMAGE_DPI):
|
||||
doc = open_pdf(pdf_file)
|
||||
renderer = doc.render(
|
||||
pypdfium2.PdfBitmap.to_pil,
|
||||
@ -108,6 +161,7 @@ det_model, det_processor = load_det_cached()
|
||||
rec_model, rec_processor = load_rec_cached()
|
||||
layout_model, layout_processor = load_layout_cached()
|
||||
order_model, order_processor = load_order_cached()
|
||||
table_model, table_processor = load_table_cached()
|
||||
|
||||
|
||||
st.markdown("""
|
||||
@ -136,14 +190,20 @@ if "pdf" in filetype:
|
||||
page_count = page_count(in_file)
|
||||
page_number = st.sidebar.number_input(f"Page number out of {page_count}:", min_value=1, value=1, max_value=page_count)
|
||||
|
||||
pil_image = get_page_image(in_file, page_number)
|
||||
pil_image = get_page_image(in_file, page_number, settings.IMAGE_DPI)
|
||||
pil_image_highres = get_page_image(in_file, page_number, dpi=settings.IMAGE_DPI_HIGHRES)
|
||||
else:
|
||||
pil_image = Image.open(in_file).convert("RGB")
|
||||
pil_image_highres = pil_image
|
||||
page_number = None
|
||||
|
||||
text_det = st.sidebar.button("Run Text Detection")
|
||||
text_rec = st.sidebar.button("Run OCR")
|
||||
layout_det = st.sidebar.button("Run Layout Analysis")
|
||||
order_det = st.sidebar.button("Run Reading Order")
|
||||
table_rec = st.sidebar.button("Run Table Rec")
|
||||
use_pdf_boxes = st.sidebar.checkbox("PDF table boxes", value=True, help="Table recognition only: Use the bounding boxes from the PDF file vs text detection model.")
|
||||
skip_table_detection = st.sidebar.checkbox("Skip table detection", value=False, help="Table recognition only: Skip table detection and treat the whole image/page as a table.")
|
||||
|
||||
if pil_image is None:
|
||||
st.stop()
|
||||
@ -165,7 +225,7 @@ if layout_det:
|
||||
|
||||
# Run OCR
|
||||
if text_rec:
|
||||
rec_img, pred = ocr(pil_image, languages)
|
||||
rec_img, pred = ocr(pil_image, pil_image_highres, languages)
|
||||
with col1:
|
||||
st.image(rec_img, caption="OCR Result", use_column_width=True)
|
||||
json_tab, text_tab = st.tabs(["JSON", "Text Lines (for debugging)"])
|
||||
@ -180,5 +240,12 @@ if order_det:
|
||||
st.image(order_img, caption="Reading Order", use_column_width=True)
|
||||
st.json(pred.model_dump(), expanded=True)
|
||||
|
||||
|
||||
if table_rec:
|
||||
table_img, pred = table_recognition(pil_image, pil_image_highres, in_file, page_number - 1 if page_number else None, use_pdf_boxes, skip_table_detection)
|
||||
with col1:
|
||||
st.image(table_img, caption="Table Recognition", use_column_width=True)
|
||||
st.json([p.model_dump() for p in pred], expanded=True)
|
||||
|
||||
with col2:
|
||||
st.image(pil_image, caption="Uploaded Image", use_column_width=True)
|
||||
10
ocr_text.py
@ -30,10 +30,12 @@ def main():
|
||||
args = parser.parse_args()
|
||||
|
||||
if os.path.isdir(args.input_path):
|
||||
images, names = load_from_folder(args.input_path, args.max, args.start_page)
|
||||
images, names, _ = load_from_folder(args.input_path, args.max, args.start_page)
|
||||
highres_images, _, _ = load_from_folder(args.input_path, args.max, args.start_page, settings.IMAGE_DPI_HIGHRES)
|
||||
folder_name = os.path.basename(args.input_path)
|
||||
else:
|
||||
images, names = load_from_file(args.input_path, args.max, args.start_page)
|
||||
images, names, _ = load_from_file(args.input_path, args.max, args.start_page)
|
||||
highres_images, _, _ = load_from_file(args.input_path, args.max, args.start_page, settings.IMAGE_DPI_HIGHRES)
|
||||
folder_name = os.path.basename(args.input_path).split(".")[0]
|
||||
|
||||
if args.lang_file:
|
||||
@ -60,7 +62,7 @@ def main():
|
||||
os.makedirs(result_path, exist_ok=True)
|
||||
|
||||
start = time.time()
|
||||
predictions_by_image = run_ocr(images, image_langs, det_model, det_processor, rec_model, rec_processor)
|
||||
predictions_by_image = run_ocr(images, image_langs, det_model, det_processor, rec_model, rec_processor, highres_images=highres_images)
|
||||
if args.debug:
|
||||
print(f"OCR took {time.time() - start:.2f} seconds")
|
||||
max_chars = max([len(l.text) for p in predictions_by_image for l in p.text_lines])
|
||||
@ -70,7 +72,7 @@ def main():
|
||||
for idx, (name, image, pred, langs) in enumerate(zip(names, images, predictions_by_image, image_langs)):
|
||||
bboxes = [l.bbox for l in pred.text_lines]
|
||||
pred_text = [l.text for l in pred.text_lines]
|
||||
page_image = draw_text_on_image(bboxes, pred_text, image.size, langs, has_math="_math" in langs)
|
||||
page_image = draw_text_on_image(bboxes, pred_text, image.size, langs, has_math="_math" in langs if langs else False)
|
||||
page_image.save(os.path.join(result_path, f"{name}_{idx}_text.png"))
|
||||
|
||||
out_preds = defaultdict(list)
|
||||
|
||||
123
poetry.lock
generated
@ -605,6 +605,23 @@ files = [
|
||||
{file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "coloredlogs"
|
||||
version = "15.0.1"
|
||||
description = "Colored terminal output for Python's logging module"
|
||||
optional = false
|
||||
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"
|
||||
files = [
|
||||
{file = "coloredlogs-15.0.1-py2.py3-none-any.whl", hash = "sha256:612ee75c546f53e92e70049c9dbfcc18c935a2b9a53b66085ce9ef6a6e5c0934"},
|
||||
{file = "coloredlogs-15.0.1.tar.gz", hash = "sha256:7c991aa71a4577af2f82600d8f8f3a89f936baeaf9b50a9c197da014e5bf16b0"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
humanfriendly = ">=9.1"
|
||||
|
||||
[package.extras]
|
||||
cron = ["capturer (>=2.4)"]
|
||||
|
||||
[[package]]
|
||||
name = "comm"
|
||||
version = "0.2.2"
|
||||
@ -803,6 +820,17 @@ files = [
|
||||
{file = "filetype-1.2.0.tar.gz", hash = "sha256:66b56cd6474bf41d8c54660347d37afcc3f7d1970648de365c102ef77548aadb"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "flatbuffers"
|
||||
version = "24.3.25"
|
||||
description = "The FlatBuffers serialization format for Python"
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
files = [
|
||||
{file = "flatbuffers-24.3.25-py2.py3-none-any.whl", hash = "sha256:8dbdec58f935f3765e4f7f3cf635ac3a77f83568138d6a2311f524ec96364812"},
|
||||
{file = "flatbuffers-24.3.25.tar.gz", hash = "sha256:de2ec5b203f21441716617f38443e0a8ebf3d25bf0d9c0bb0ce68fa00ad546a4"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "fqdn"
|
||||
version = "1.5.1"
|
||||
@ -1148,6 +1176,20 @@ testing = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "fastapi", "gr
|
||||
torch = ["safetensors", "torch"]
|
||||
typing = ["types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "humanfriendly"
|
||||
version = "10.0"
|
||||
description = "Human friendly output for text interfaces using Python"
|
||||
optional = false
|
||||
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"
|
||||
files = [
|
||||
{file = "humanfriendly-10.0-py2.py3-none-any.whl", hash = "sha256:1697e1a8a8f550fd43c2865cd84542fc175a61dcb779b6fee18cf6b6ccba1477"},
|
||||
{file = "humanfriendly-10.0.tar.gz", hash = "sha256:6b0b831ce8f15f7300721aa49829fc4e83921a9a301cc7f606be6686a2288ddc"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
pyreadline3 = {version = "*", markers = "sys_platform == \"win32\" and python_version >= \"3.8\""}
|
||||
|
||||
[[package]]
|
||||
name = "idna"
|
||||
version = "3.7"
|
||||
@ -2296,6 +2338,48 @@ files = [
|
||||
{file = "nvidia_nvtx_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:65f4d98982b31b60026e0e6de73fbdfc09d08a96f4656dd3665ca616a11e1e82"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "onnxruntime"
|
||||
version = "1.19.2"
|
||||
description = "ONNX Runtime is a runtime accelerator for Machine Learning models"
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
files = [
|
||||
{file = "onnxruntime-1.19.2-cp310-cp310-macosx_11_0_universal2.whl", hash = "sha256:84fa57369c06cadd3c2a538ae2a26d76d583e7c34bdecd5769d71ca5c0fc750e"},
|
||||
{file = "onnxruntime-1.19.2-cp310-cp310-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:bdc471a66df0c1cdef774accef69e9f2ca168c851ab5e4f2f3341512c7ef4666"},
|
||||
{file = "onnxruntime-1.19.2-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:e3a4ce906105d99ebbe817f536d50a91ed8a4d1592553f49b3c23c4be2560ae6"},
|
||||
{file = "onnxruntime-1.19.2-cp310-cp310-win32.whl", hash = "sha256:4b3d723cc154c8ddeb9f6d0a8c0d6243774c6b5930847cc83170bfe4678fafb3"},
|
||||
{file = "onnxruntime-1.19.2-cp310-cp310-win_amd64.whl", hash = "sha256:17ed7382d2c58d4b7354fb2b301ff30b9bf308a1c7eac9546449cd122d21cae5"},
|
||||
{file = "onnxruntime-1.19.2-cp311-cp311-macosx_11_0_universal2.whl", hash = "sha256:d863e8acdc7232d705d49e41087e10b274c42f09e259016a46f32c34e06dc4fd"},
|
||||
{file = "onnxruntime-1.19.2-cp311-cp311-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c1dfe4f660a71b31caa81fc298a25f9612815215a47b286236e61d540350d7b6"},
|
||||
{file = "onnxruntime-1.19.2-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a36511dc07c5c964b916697e42e366fa43c48cdb3d3503578d78cef30417cb84"},
|
||||
{file = "onnxruntime-1.19.2-cp311-cp311-win32.whl", hash = "sha256:50cbb8dc69d6befad4746a69760e5b00cc3ff0a59c6c3fb27f8afa20e2cab7e7"},
|
||||
{file = "onnxruntime-1.19.2-cp311-cp311-win_amd64.whl", hash = "sha256:1c3e5d415b78337fa0b1b75291e9ea9fb2a4c1f148eb5811e7212fed02cfffa8"},
|
||||
{file = "onnxruntime-1.19.2-cp312-cp312-macosx_11_0_universal2.whl", hash = "sha256:68e7051bef9cfefcbb858d2d2646536829894d72a4130c24019219442b1dd2ed"},
|
||||
{file = "onnxruntime-1.19.2-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:d2d366fbcc205ce68a8a3bde2185fd15c604d9645888703785b61ef174265168"},
|
||||
{file = "onnxruntime-1.19.2-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:477b93df4db467e9cbf34051662a4b27c18e131fa1836e05974eae0d6e4cf29b"},
|
||||
{file = "onnxruntime-1.19.2-cp312-cp312-win32.whl", hash = "sha256:9a174073dc5608fad05f7cf7f320b52e8035e73d80b0a23c80f840e5a97c0147"},
|
||||
{file = "onnxruntime-1.19.2-cp312-cp312-win_amd64.whl", hash = "sha256:190103273ea4507638ffc31d66a980594b237874b65379e273125150eb044857"},
|
||||
{file = "onnxruntime-1.19.2-cp38-cp38-macosx_11_0_universal2.whl", hash = "sha256:636bc1d4cc051d40bc52e1f9da87fbb9c57d9d47164695dfb1c41646ea51ea66"},
|
||||
{file = "onnxruntime-1.19.2-cp38-cp38-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:5bd8b875757ea941cbcfe01582970cc299893d1b65bd56731e326a8333f638a3"},
|
||||
{file = "onnxruntime-1.19.2-cp38-cp38-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:b2046fc9560f97947bbc1acbe4c6d48585ef0f12742744307d3364b131ac5778"},
|
||||
{file = "onnxruntime-1.19.2-cp38-cp38-win32.whl", hash = "sha256:31c12840b1cde4ac1f7d27d540c44e13e34f2345cf3642762d2a3333621abb6a"},
|
||||
{file = "onnxruntime-1.19.2-cp38-cp38-win_amd64.whl", hash = "sha256:016229660adea180e9a32ce218b95f8f84860a200f0f13b50070d7d90e92956c"},
|
||||
{file = "onnxruntime-1.19.2-cp39-cp39-macosx_11_0_universal2.whl", hash = "sha256:006c8d326835c017a9e9f74c9c77ebb570a71174a1e89fe078b29a557d9c3848"},
|
||||
{file = "onnxruntime-1.19.2-cp39-cp39-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:df2a94179a42d530b936f154615b54748239c2908ee44f0d722cb4df10670f68"},
|
||||
{file = "onnxruntime-1.19.2-cp39-cp39-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:fae4b4de45894b9ce7ae418c5484cbf0341db6813effec01bb2216091c52f7fb"},
|
||||
{file = "onnxruntime-1.19.2-cp39-cp39-win32.whl", hash = "sha256:dc5430f473e8706fff837ae01323be9dcfddd3ea471c900a91fa7c9b807ec5d3"},
|
||||
{file = "onnxruntime-1.19.2-cp39-cp39-win_amd64.whl", hash = "sha256:38475e29a95c5f6c62c2c603d69fc7d4c6ccbf4df602bd567b86ae1138881c49"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
coloredlogs = "*"
|
||||
flatbuffers = "*"
|
||||
numpy = ">=1.21.6"
|
||||
packaging = "*"
|
||||
protobuf = "*"
|
||||
sympy = "*"
|
||||
|
||||
[[package]]
|
||||
name = "opencv-python"
|
||||
version = "4.10.0.84"
|
||||
@ -2315,11 +2399,11 @@ files = [
|
||||
[package.dependencies]
|
||||
numpy = [
|
||||
{version = ">=1.21.0", markers = "python_version == \"3.9\" and platform_system == \"Darwin\" and platform_machine == \"arm64\""},
|
||||
{version = ">=1.26.0", markers = "python_version >= \"3.12\""},
|
||||
{version = ">=1.23.5", markers = "python_version >= \"3.11\" and python_version < \"3.12\""},
|
||||
{version = ">=1.21.4", markers = "python_version >= \"3.10\" and platform_system == \"Darwin\" and python_version < \"3.11\""},
|
||||
{version = ">=1.21.2", markers = "platform_system != \"Darwin\" and python_version >= \"3.10\" and python_version < \"3.11\""},
|
||||
{version = ">=1.19.3", markers = "platform_system == \"Linux\" and platform_machine == \"aarch64\" and python_version >= \"3.8\" and python_version < \"3.10\" or python_version > \"3.9\" and python_version < \"3.10\" or python_version >= \"3.9\" and platform_system != \"Darwin\" and python_version < \"3.10\" or python_version >= \"3.9\" and platform_machine != \"arm64\" and python_version < \"3.10\""},
|
||||
{version = ">=1.23.5", markers = "python_version >= \"3.11\" and python_version < \"3.12\""},
|
||||
{version = ">=1.26.0", markers = "python_version >= \"3.12\""},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@ -2385,8 +2469,8 @@ files = [
|
||||
[package.dependencies]
|
||||
numpy = [
|
||||
{version = ">=1.22.4", markers = "python_version < \"3.11\""},
|
||||
{version = ">=1.23.2", markers = "python_version == \"3.11\""},
|
||||
{version = ">=1.26.0", markers = "python_version >= \"3.12\""},
|
||||
{version = ">=1.23.2", markers = "python_version == \"3.11\""},
|
||||
]
|
||||
python-dateutil = ">=2.8.2"
|
||||
pytz = ">=2020.1"
|
||||
@ -2443,6 +2527,23 @@ files = [
|
||||
qa = ["flake8 (==5.0.4)", "mypy (==0.971)", "types-setuptools (==67.2.0.1)"]
|
||||
testing = ["docopt", "pytest"]
|
||||
|
||||
[[package]]
|
||||
name = "pdftext"
|
||||
version = "0.3.12"
|
||||
description = "Extract structured text from pdfs quickly"
|
||||
optional = false
|
||||
python-versions = "!=2.7.*,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,!=3.7.*,!=3.8.*,>=3.9"
|
||||
files = [
|
||||
{file = "pdftext-0.3.12-py3-none-any.whl", hash = "sha256:4903349d5be23984dcb417d6d37611fec9595cf65b535fc54c609f8bc702b847"},
|
||||
{file = "pdftext-0.3.12.tar.gz", hash = "sha256:a82cccff535c97f806f8f32708bf336e19d1005b7b4a8e08cc44a37522a5cd62"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
onnxruntime = ">=1.19.2,<2.0.0"
|
||||
pydantic = ">=2.7.1,<3.0.0"
|
||||
pydantic-settings = ">=2.2.1,<3.0.0"
|
||||
pypdfium2 = ">=4.29.0,<5.0.0"
|
||||
|
||||
[[package]]
|
||||
name = "pexpect"
|
||||
version = "4.9.0"
|
||||
@ -3016,6 +3117,20 @@ files = [
|
||||
{file = "pypdfium2-4.30.0.tar.gz", hash = "sha256:48b5b7e5566665bc1015b9d69c1ebabe21f6aee468b509531c3c8318eeee2e16"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pyreadline3"
|
||||
version = "3.5.4"
|
||||
description = "A python implementation of GNU readline."
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "pyreadline3-3.5.4-py3-none-any.whl", hash = "sha256:eaf8e6cc3c49bcccf145fc6067ba8643d1df34d604a1ec0eccbf7a18e6d3fae6"},
|
||||
{file = "pyreadline3-3.5.4.tar.gz", hash = "sha256:8d57d53039a1c75adba8e50dd3d992b28143480816187ea5efbd5c78e6c885b7"},
|
||||
]
|
||||
|
||||
[package.extras]
|
||||
dev = ["build", "flake8", "mypy", "pytest", "twine"]
|
||||
|
||||
[[package]]
|
||||
name = "pytesseract"
|
||||
version = "0.3.10"
|
||||
@ -4819,4 +4934,4 @@ test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools",
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = ">=3.9,<3.13,!=3.9.7"
|
||||
content-hash = "d250e5223075069c0561f95e970624731feb7ddc20f1bc7b8ef6dd826a8f3085"
|
||||
content-hash = "58776d15cd2b20d1a735106b587c2595e18fd521dc91ef732192610e8e93d8ff"
|
||||
|
||||
@ -1,12 +1,12 @@
|
||||
[tool.poetry]
|
||||
name = "surya-ocr"
|
||||
version = "0.5.0"
|
||||
description = "OCR, layout, reading order, and line detection in 90+ languages"
|
||||
version = "0.6.0"
|
||||
description = "OCR, layout, reading order, and table recognition in 90+ languages"
|
||||
authors = ["Vik Paruchuri <vik.paruchuri@gmail.com>"]
|
||||
readme = "README.md"
|
||||
license = "GPL-3.0-or-later"
|
||||
repository = "https://github.com/VikParuchuri/surya"
|
||||
keywords = ["ocr", "pdf", "text detection", "text recognition"]
|
||||
keywords = ["ocr", "pdf", "text detection", "text recognition", "tables"]
|
||||
packages = [
|
||||
{include = "surya"}
|
||||
]
|
||||
@ -17,6 +17,7 @@ include = [
|
||||
"run_ocr_app.py",
|
||||
"detect_layout.py",
|
||||
"reading_order.py",
|
||||
"table_recognition.py"
|
||||
]
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
@ -32,6 +33,7 @@ opencv-python = "^4.9.0.80"
|
||||
tabulate = "^0.9.0"
|
||||
filetype = "^1.2.0"
|
||||
ftfy = "^6.1.3"
|
||||
pdftext = "^0.3.12"
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
jupyter = "^1.0.0"
|
||||
@ -50,6 +52,7 @@ surya_ocr = "ocr_text:main"
|
||||
surya_layout = "detect_layout:main"
|
||||
surya_gui = "run_ocr_app:run_app"
|
||||
surya_order = "reading_order:main"
|
||||
surya_table = "table_recognition:main"
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core"]
|
||||
|
||||
@ -33,10 +33,10 @@ def main():
|
||||
det_processor = load_det_processor()
|
||||
|
||||
if os.path.isdir(args.input_path):
|
||||
images, names = load_from_folder(args.input_path, args.max)
|
||||
images, names, _ = load_from_folder(args.input_path, args.max)
|
||||
folder_name = os.path.basename(args.input_path)
|
||||
else:
|
||||
images, names = load_from_file(args.input_path, args.max)
|
||||
images, names, _ = load_from_file(args.input_path, args.max)
|
||||
folder_name = os.path.basename(args.input_path).split(".")[0]
|
||||
|
||||
line_predictions = batch_text_detection(images, det_model, det_processor)
|
||||
|
||||
@ -27,6 +27,14 @@ def verify_order(data):
|
||||
raise ValueError("Scores do not meet the required threshold")
|
||||
|
||||
|
||||
def verify_table_rec(data):
|
||||
row_score = data["surya"]["mean_row_iou"]
|
||||
col_score = data["surya"]["mean_col_iou"]
|
||||
|
||||
if row_score < 0.75 or col_score < 0.75:
|
||||
raise ValueError("Scores do not meet the required threshold")
|
||||
|
||||
|
||||
def verify_scores(file_path, bench_type):
|
||||
with open(file_path, 'r') as file:
|
||||
data = json.load(file)
|
||||
@ -39,6 +47,8 @@ def verify_scores(file_path, bench_type):
|
||||
verify_layout(data)
|
||||
elif bench_type == "ordering":
|
||||
verify_order(data)
|
||||
elif bench_type == "table_recognition":
|
||||
verify_table_rec(data)
|
||||
else:
|
||||
raise ValueError("Invalid benchmark type")
|
||||
|
||||
|
||||
BIN
static/images/benchmark_tablerec_acc.png
Normal file
|
After Width: | Height: | Size: 25 KiB |
BIN
static/images/benchmark_tablerec_speed.png
Normal file
|
After Width: | Height: | Size: 22 KiB |
BIN
static/images/japanese_tablerec.png
Normal file
|
After Width: | Height: | Size: 351 KiB |
BIN
static/images/paper_tablerec.png
Normal file
|
After Width: | Height: | Size: 2.1 MiB |
BIN
static/images/pres_tablerec.png
Normal file
|
After Width: | Height: | Size: 934 KiB |
BIN
static/images/scanned_tablerec.png
Normal file
|
After Width: | Height: | Size: 711 KiB |
BIN
static/images/scanned_tablerec2.png
Normal file
|
After Width: | Height: | Size: 1.6 MiB |
BIN
static/images/table_rec.png
Normal file
|
After Width: | Height: | Size: 2.0 MiB |
@ -4,6 +4,7 @@ from itertools import repeat
|
||||
import numpy as np
|
||||
from concurrent.futures import ProcessPoolExecutor
|
||||
|
||||
|
||||
def intersection_area(box1, box2):
|
||||
x_left = max(box1[0], box2[0])
|
||||
y_top = max(box1[1], box2[1])
|
||||
@ -15,6 +16,59 @@ def intersection_area(box1, box2):
|
||||
|
||||
return (x_right - x_left) * (y_bottom - y_top)
|
||||
|
||||
def box_area(box):
|
||||
return (box[2] - box[0]) * (box[3] - box[1])
|
||||
|
||||
|
||||
def calculate_iou(box1, box2, box1_only=False):
|
||||
intersection = intersection_area(box1, box2)
|
||||
union = box_area(box1)
|
||||
if not box1_only:
|
||||
union += box_area(box2) - intersection
|
||||
|
||||
if union == 0:
|
||||
return 0
|
||||
return intersection / union
|
||||
|
||||
|
||||
def match_boxes(preds, references):
|
||||
num_actual = len(references)
|
||||
num_predicted = len(preds)
|
||||
|
||||
iou_matrix = np.zeros((num_actual, num_predicted))
|
||||
for i, actual in enumerate(references):
|
||||
for j, pred in enumerate(preds):
|
||||
iou_matrix[i, j] = calculate_iou(actual, pred, box1_only=True)
|
||||
|
||||
sorted_indices = np.argsort(iou_matrix, axis=None)[::-1]
|
||||
sorted_ious = iou_matrix.flatten()[sorted_indices]
|
||||
actual_indices, predicted_indices = np.unravel_index(sorted_indices, iou_matrix.shape)
|
||||
|
||||
assigned_actual = set()
|
||||
assigned_pred = set()
|
||||
|
||||
matches = []
|
||||
for idx, iou in zip(zip(actual_indices, predicted_indices), sorted_ious):
|
||||
i, j = idx
|
||||
if i not in assigned_actual and j not in assigned_pred:
|
||||
iou_val = iou_matrix[i, j]
|
||||
if iou_val > .95: # Account for rounding on box edges
|
||||
iou_val = 1.0
|
||||
matches.append((i, j, iou_val))
|
||||
assigned_actual.add(i)
|
||||
assigned_pred.add(j)
|
||||
|
||||
unassigned_actual = set(range(num_actual)) - assigned_actual
|
||||
unassigned_pred = set(range(num_predicted)) - assigned_pred
|
||||
matches.extend([(i, None, -1.0) for i in unassigned_actual])
|
||||
matches.extend([(None, j, 0.0) for j in unassigned_pred])
|
||||
|
||||
return matches
|
||||
|
||||
def penalized_iou_score(preds, references):
|
||||
matches = match_boxes(preds, references)
|
||||
iou = sum([match[2] for match in matches]) / len(matches)
|
||||
return iou
|
||||
|
||||
def intersection_pixels(box1, box2):
|
||||
x_left = max(box1[0], box2[0])
|
||||
|
||||
117
surya/benchmark/tatr.py
Normal file
@ -0,0 +1,117 @@
|
||||
import torch
|
||||
from transformers import DetrFeatureExtractor, AutoModelForObjectDetection
|
||||
from surya.settings import settings
|
||||
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
|
||||
|
||||
class MaxResize(object):
|
||||
def __init__(self, max_size=800):
|
||||
self.max_size = max_size
|
||||
|
||||
def __call__(self, image):
|
||||
width, height = image.size
|
||||
current_max_size = max(width, height)
|
||||
scale = self.max_size / current_max_size
|
||||
resized_image = image.resize((int(round(scale * width)), int(round(scale * height))))
|
||||
|
||||
return resized_image
|
||||
|
||||
|
||||
def to_tensor(image):
|
||||
# Convert PIL Image to NumPy array
|
||||
np_image = np.array(image).astype(np.float32)
|
||||
|
||||
# Rearrange dimensions to [C, H, W] format
|
||||
np_image = np_image.transpose((2, 0, 1))
|
||||
|
||||
# Normalize to [0.0, 1.0]
|
||||
np_image /= 255.0
|
||||
|
||||
return torch.from_numpy(np_image)
|
||||
|
||||
|
||||
def normalize(tensor, mean, std):
|
||||
for t, m, s in zip(tensor, mean, std):
|
||||
t.sub_(m).div_(s)
|
||||
return tensor
|
||||
|
||||
|
||||
def structure_transform(image):
|
||||
image = MaxResize(1000)(image)
|
||||
tensor = to_tensor(image)
|
||||
normalized_tensor = normalize(tensor, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
|
||||
return normalized_tensor
|
||||
|
||||
|
||||
def box_cxcywh_to_xyxy(x):
|
||||
x_c, y_c, w, h = x.unbind(-1)
|
||||
b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)]
|
||||
return torch.stack(b, dim=1)
|
||||
|
||||
|
||||
def rescale_bboxes(out_bbox, size):
|
||||
width, height = size
|
||||
boxes = box_cxcywh_to_xyxy(out_bbox)
|
||||
boxes = boxes * torch.tensor([width, height, width, height], dtype=torch.float32)
|
||||
return boxes
|
||||
|
||||
|
||||
def outputs_to_objects(outputs, img_sizes, id2label):
|
||||
m = outputs.logits.softmax(-1).max(-1)
|
||||
batch_labels = list(m.indices.detach().cpu().numpy())
|
||||
batch_scores = list(m.values.detach().cpu().numpy())
|
||||
batch_bboxes = outputs['pred_boxes'].detach().cpu()
|
||||
|
||||
batch_objects = []
|
||||
for i in range(len(img_sizes)):
|
||||
pred_bboxes = [elem.tolist() for elem in rescale_bboxes(batch_bboxes[i], img_sizes[i])]
|
||||
pred_scores = batch_scores[i]
|
||||
pred_labels = batch_labels[i]
|
||||
|
||||
objects = []
|
||||
for label, score, bbox in zip(pred_labels, pred_scores, pred_bboxes):
|
||||
class_label = id2label[int(label)]
|
||||
if not class_label == 'no object':
|
||||
objects.append({
|
||||
'label': class_label,
|
||||
'score': float(score),
|
||||
'bbox': [float(elem) for elem in bbox]}
|
||||
)
|
||||
|
||||
rows = []
|
||||
cols = []
|
||||
for i, cell in enumerate(objects):
|
||||
if cell["label"] == "table column":
|
||||
cols.append(cell)
|
||||
|
||||
if cell["label"] == "table row":
|
||||
rows.append(cell)
|
||||
batch_objects.append({
|
||||
"rows": rows,
|
||||
"cols": cols
|
||||
})
|
||||
|
||||
return batch_objects
|
||||
|
||||
|
||||
def load_tatr():
|
||||
return AutoModelForObjectDetection.from_pretrained("microsoft/table-transformer-structure-recognition-v1.1-all").to(settings.TORCH_DEVICE_MODEL)
|
||||
|
||||
|
||||
def batch_inference_tatr(model, images, batch_size):
|
||||
device = model.device
|
||||
rows_cols = []
|
||||
for i in range(0, len(images), batch_size):
|
||||
batch_images = images[i:i + batch_size]
|
||||
pixel_values = torch.stack([structure_transform(img) for img in batch_images], dim=0).to(device)
|
||||
|
||||
# forward pass
|
||||
with torch.no_grad():
|
||||
outputs = model(pixel_values)
|
||||
|
||||
id2label = model.config.id2label
|
||||
id2label[len(model.config.id2label)] = "no object"
|
||||
rows_cols.extend(outputs_to_objects(outputs, [img.size for img in batch_images], id2label))
|
||||
return rows_cols
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import List, Tuple
|
||||
from typing import List, Tuple, Generator
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
@ -26,14 +26,17 @@ def get_batch_size():
|
||||
return batch_size
|
||||
|
||||
|
||||
def batch_detection(images: List, model: EfficientViTForSemanticSegmentation, processor, batch_size=None) -> Tuple[List[List[np.ndarray]], List[Tuple[int, int]]]:
|
||||
def batch_detection(
|
||||
images: List,
|
||||
model: EfficientViTForSemanticSegmentation,
|
||||
processor,
|
||||
batch_size=None
|
||||
) -> Generator[Tuple[List[List[np.ndarray]], List[Tuple[int, int]]], None, None]:
|
||||
assert all([isinstance(image, Image.Image) for image in images])
|
||||
if batch_size is None:
|
||||
batch_size = get_batch_size()
|
||||
heatmap_count = model.config.num_labels
|
||||
|
||||
images = [image.convert("RGB") for image in images] # also copies the images
|
||||
|
||||
orig_sizes = [image.size for image in images]
|
||||
splits_per_image = [get_total_splits(size, processor) for size in orig_sizes]
|
||||
|
||||
@ -52,10 +55,9 @@ def batch_detection(images: List, model: EfficientViTForSemanticSegmentation, pr
|
||||
if len(current_batch) > 0:
|
||||
batches.append(current_batch)
|
||||
|
||||
all_preds = []
|
||||
for batch_idx in tqdm(range(len(batches)), desc="Detecting bboxes"):
|
||||
batch_image_idxs = batches[batch_idx]
|
||||
batch_images = convert_if_not_rgb([images[j] for j in batch_image_idxs])
|
||||
batch_images = [images[j].convert("RGB") for j in batch_image_idxs]
|
||||
|
||||
split_index = []
|
||||
split_heights = []
|
||||
@ -98,11 +100,7 @@ def batch_detection(images: List, model: EfficientViTForSemanticSegmentation, pr
|
||||
heatmaps[k] = np.vstack([heatmaps[k], pred_heatmaps[k]])
|
||||
preds[idx] = heatmaps
|
||||
|
||||
all_preds.extend(preds)
|
||||
|
||||
assert len(all_preds) == len(images)
|
||||
assert all([len(pred) == heatmap_count for pred in all_preds])
|
||||
return all_preds, orig_sizes
|
||||
yield preds, [orig_sizes[j] for j in batch_image_idxs]
|
||||
|
||||
|
||||
def parallel_get_lines(preds, orig_sizes):
|
||||
@ -125,16 +123,21 @@ def parallel_get_lines(preds, orig_sizes):
|
||||
|
||||
|
||||
def batch_text_detection(images: List, model, processor, batch_size=None) -> List[TextDetectionResult]:
|
||||
preds, orig_sizes = batch_detection(images, model, processor, batch_size=batch_size)
|
||||
detection_generator = batch_detection(images, model, processor, batch_size=batch_size)
|
||||
|
||||
results = []
|
||||
if settings.IN_STREAMLIT or len(images) < settings.DETECTOR_MIN_PARALLEL_THRESH: # Ensures we don't parallelize with streamlit, or with very few images
|
||||
for i in range(len(images)):
|
||||
result = parallel_get_lines(preds[i], orig_sizes[i])
|
||||
results.append(result)
|
||||
else:
|
||||
max_workers = min(settings.DETECTOR_POSTPROCESSING_CPU_WORKERS, len(images))
|
||||
max_workers = min(settings.DETECTOR_POSTPROCESSING_CPU_WORKERS, len(images))
|
||||
parallelize = not settings.IN_STREAMLIT and len(images) >= settings.DETECTOR_MIN_PARALLEL_THRESH
|
||||
|
||||
if parallelize:
|
||||
with ProcessPoolExecutor(max_workers=max_workers) as executor:
|
||||
results = list(executor.map(parallel_get_lines, preds, orig_sizes))
|
||||
for preds, orig_sizes in detection_generator:
|
||||
batch_results = list(executor.map(parallel_get_lines, preds, orig_sizes))
|
||||
results.extend(batch_results)
|
||||
else:
|
||||
for preds, orig_sizes in detection_generator:
|
||||
for pred, orig_size in zip(preds, orig_sizes):
|
||||
results.append(parallel_get_lines(pred, orig_size))
|
||||
|
||||
return results
|
||||
|
||||
|
||||
@ -1,17 +1,19 @@
|
||||
import PIL
|
||||
|
||||
from surya.input.processing import open_pdf, get_page_images
|
||||
from surya.settings import settings
|
||||
import os
|
||||
import filetype
|
||||
from PIL import Image
|
||||
import json
|
||||
|
||||
|
||||
|
||||
def get_name_from_path(path):
|
||||
return os.path.basename(path).split(".")[0]
|
||||
|
||||
|
||||
def load_pdf(pdf_path, max_pages=None, start_page=None):
|
||||
def load_pdf(pdf_path, max_pages=None, start_page=None, dpi=settings.IMAGE_DPI, load_text_lines=False):
|
||||
doc = open_pdf(pdf_path)
|
||||
last_page = len(doc)
|
||||
|
||||
@ -25,47 +27,58 @@ def load_pdf(pdf_path, max_pages=None, start_page=None):
|
||||
last_page = min(start_page + max_pages, last_page)
|
||||
|
||||
page_indices = list(range(start_page, last_page))
|
||||
images = get_page_images(doc, page_indices)
|
||||
images = get_page_images(doc, page_indices, dpi=dpi)
|
||||
text_lines = None
|
||||
if load_text_lines:
|
||||
from surya.input.pdflines import get_page_text_lines # Putting import here because pypdfium2 causes warnings if its not the top import
|
||||
text_lines = get_page_text_lines(
|
||||
pdf_path,
|
||||
page_indices,
|
||||
[i.size for i in images]
|
||||
)
|
||||
doc.close()
|
||||
names = [get_name_from_path(pdf_path) for _ in page_indices]
|
||||
return images, names
|
||||
return images, names, text_lines
|
||||
|
||||
|
||||
def load_image(image_path):
|
||||
image = Image.open(image_path).convert("RGB")
|
||||
name = get_name_from_path(image_path)
|
||||
return [image], [name]
|
||||
return [image], [name], [None]
|
||||
|
||||
|
||||
def load_from_file(input_path, max_pages=None, start_page=None):
|
||||
def load_from_file(input_path, max_pages=None, start_page=None, dpi=settings.IMAGE_DPI, load_text_lines=False):
|
||||
input_type = filetype.guess(input_path)
|
||||
if input_type.extension == "pdf":
|
||||
return load_pdf(input_path, max_pages, start_page)
|
||||
return load_pdf(input_path, max_pages, start_page, dpi=dpi, load_text_lines=load_text_lines)
|
||||
else:
|
||||
return load_image(input_path)
|
||||
|
||||
|
||||
def load_from_folder(folder_path, max_pages=None, start_page=None):
|
||||
def load_from_folder(folder_path, max_pages=None, start_page=None, dpi=settings.IMAGE_DPI, load_text_lines=False):
|
||||
image_paths = [os.path.join(folder_path, image_name) for image_name in os.listdir(folder_path) if not image_name.startswith(".")]
|
||||
image_paths = [ip for ip in image_paths if not os.path.isdir(ip)]
|
||||
|
||||
images = []
|
||||
names = []
|
||||
text_lines = []
|
||||
for path in image_paths:
|
||||
extension = filetype.guess(path)
|
||||
if extension and extension.extension == "pdf":
|
||||
image, name = load_pdf(path, max_pages, start_page)
|
||||
image, name, text_line = load_pdf(path, max_pages, start_page, dpi=dpi, load_text_lines=load_text_lines)
|
||||
images.extend(image)
|
||||
names.extend(name)
|
||||
text_lines.extend(text_line)
|
||||
else:
|
||||
try:
|
||||
image, name = load_image(path)
|
||||
image, name, text_line = load_image(path)
|
||||
images.extend(image)
|
||||
names.extend(name)
|
||||
text_lines.extend(text_line)
|
||||
except PIL.UnidentifiedImageError:
|
||||
print(f"Could not load image {path}")
|
||||
continue
|
||||
return images, names
|
||||
return images, names, text_lines
|
||||
|
||||
|
||||
def load_lang_file(lang_path, names):
|
||||
|
||||
86
surya/input/pdflines.py
Normal file
@ -0,0 +1,86 @@
|
||||
from pdftext.extraction import dictionary_output
|
||||
|
||||
from surya.postprocessing.text import sort_text_lines
|
||||
from surya.schema import PolygonBox
|
||||
|
||||
|
||||
def get_page_text_lines(filepath: str, page_idxs: list, out_sizes: list) -> list:
|
||||
assert len(page_idxs) == len(out_sizes)
|
||||
pages_text = dictionary_output(filepath, sort=False, page_range=page_idxs, keep_chars=True)
|
||||
for full_text, out_size in zip(pages_text, out_sizes):
|
||||
width = full_text["width"]
|
||||
height = full_text["height"]
|
||||
text_w_scale = out_size[0] / width
|
||||
text_h_scale = out_size[1] / height
|
||||
for block in full_text["blocks"]:
|
||||
for line in block["lines"]:
|
||||
line["bbox"] = [line["bbox"][0] * text_w_scale, line["bbox"][1] * text_h_scale,
|
||||
line["bbox"][2] * text_w_scale, line["bbox"][3] * text_h_scale]
|
||||
for span in line["spans"]:
|
||||
for char in span["chars"]:
|
||||
char["bbox"] = [char["bbox"][0] * text_w_scale, char["bbox"][1] * text_h_scale,
|
||||
char["bbox"][2] * text_w_scale, char["bbox"][3] * text_h_scale]
|
||||
return pages_text
|
||||
|
||||
|
||||
def get_table_blocks(tables: list, full_text: dict, img_size: list, table_thresh=.8):
|
||||
# Returns coordinates relative to input table, not full image
|
||||
table_texts = []
|
||||
for table in tables:
|
||||
table_poly = PolygonBox(polygon=[
|
||||
[table[0], table[1]],
|
||||
[table[2], table[1]],
|
||||
[table[2], table[3]],
|
||||
[table[0], table[3]]
|
||||
])
|
||||
table_text = []
|
||||
rotation = full_text["rotation"]
|
||||
for block in full_text["blocks"]:
|
||||
for line in block["lines"]:
|
||||
line_poly = PolygonBox(polygon=[
|
||||
[line["bbox"][0], line["bbox"][1]],
|
||||
[line["bbox"][2], line["bbox"][1]],
|
||||
[line["bbox"][2], line["bbox"][3]],
|
||||
[line["bbox"][0], line["bbox"][3]]
|
||||
])
|
||||
if line_poly.intersection_pct(table_poly) < table_thresh:
|
||||
continue
|
||||
curr_span = None
|
||||
curr_box = None
|
||||
for span in line["spans"]:
|
||||
for char in span["chars"]:
|
||||
same_span = False
|
||||
if curr_span:
|
||||
if rotation == 90:
|
||||
same_span = (char["bbox"][0] - curr_box[0]) / img_size[0] < 0.01 and abs(char["bbox"][1] - curr_box[3]) / img_size[1] < 0.01
|
||||
elif rotation == 180:
|
||||
same_span = (char["bbox"][2] - curr_box[0]) / img_size[0] < 0.01 and (char["bbox"][1] - curr_box[1]) / img_size[1] < 0.01
|
||||
elif rotation == 270:
|
||||
same_span = (char["bbox"][0] - curr_box[0]) / img_size[0] < 0.01 and abs(char["bbox"][3] - curr_box[1]) / img_size[1] < 0.01
|
||||
else:
|
||||
same_span = (char["bbox"][0] - curr_box[2]) / img_size[0] < 0.01 and (char["bbox"][1] - curr_box[1]) / img_size[1] < 0.01
|
||||
|
||||
if curr_span is None:
|
||||
curr_span = char["char"]
|
||||
curr_box = char["bbox"]
|
||||
elif same_span:
|
||||
curr_span += char["char"]
|
||||
curr_box = [min(curr_box[0], char["bbox"][0]), min(curr_box[1], char["bbox"][1]),
|
||||
max(curr_box[2], char["bbox"][2]), max(curr_box[3], char["bbox"][3])]
|
||||
else:
|
||||
table_text.append({"text": curr_span, "bbox": curr_box})
|
||||
curr_span = char["char"]
|
||||
curr_box = char["bbox"]
|
||||
if curr_span is not None:
|
||||
table_text.append({"text": curr_span, "bbox": curr_box})
|
||||
# Adjust to be relative to input table
|
||||
for item in table_text:
|
||||
item["bbox"] = [
|
||||
item["bbox"][0] - table[0],
|
||||
item["bbox"][1] - table[1],
|
||||
item["bbox"][2] - table[0],
|
||||
item["bbox"][3] - table[1]
|
||||
]
|
||||
table_text = sort_text_lines(table_text)
|
||||
table_texts.append(table_text)
|
||||
return table_texts
|
||||
@ -115,4 +115,4 @@ def slice_and_pad_poly(image_array: np.array, coordinates):
|
||||
cropped_polygon[mask == 0] = settings.RECOGNITION_PAD_VALUE
|
||||
rectangle_image = Image.fromarray(cropped_polygon)
|
||||
|
||||
return rectangle_image
|
||||
return rectangle_image
|
||||
|
||||
@ -12,7 +12,7 @@ from surya.settings import settings
|
||||
|
||||
def get_regions_from_detection_result(detection_result: TextDetectionResult, heatmaps: List[np.ndarray], orig_size, id2label, segment_assignment, vertical_line_width=20) -> List[LayoutBox]:
|
||||
logits = np.stack(heatmaps, axis=0)
|
||||
vertical_line_bboxes = [line for line in detection_result.vertical_lines]
|
||||
vertical_line_bboxes = detection_result.vertical_lines
|
||||
line_bboxes = detection_result.bboxes
|
||||
|
||||
# Scale back to processor size
|
||||
@ -38,6 +38,8 @@ def get_regions_from_detection_result(detection_result: TextDetectionResult, hea
|
||||
detected_boxes = []
|
||||
for heatmap_idx in range(1, len(id2label)): # Skip the blank class
|
||||
heatmap = logits[heatmap_idx]
|
||||
if np.max(heatmap) < settings.DETECTOR_BLANK_THRESHOLD:
|
||||
continue
|
||||
bboxes = get_detected_boxes(heatmap)
|
||||
bboxes = [bbox for bbox in bboxes if bbox.area > 25]
|
||||
for bb in bboxes:
|
||||
@ -89,14 +91,12 @@ def get_regions_from_detection_result(detection_result: TextDetectionResult, hea
|
||||
max_x = max(max_x, max_x_box)
|
||||
max_y = max(max_y, max_y_box)
|
||||
|
||||
bbox.polygon[0][0] = min_x
|
||||
bbox.polygon[0][1] = min_y
|
||||
bbox.polygon[1][0] = max_x
|
||||
bbox.polygon[1][1] = min_y
|
||||
bbox.polygon[2][0] = max_x
|
||||
bbox.polygon[2][1] = max_y
|
||||
bbox.polygon[3][0] = min_x
|
||||
bbox.polygon[3][1] = max_y
|
||||
bbox.polygon = [
|
||||
[min_x, min_y],
|
||||
[max_x, min_y],
|
||||
[max_x, max_y],
|
||||
[min_x, max_y]
|
||||
]
|
||||
|
||||
if bbox_idx in box_lines and bbox.label in ["Picture"]:
|
||||
bbox.label = "Figure"
|
||||
@ -104,17 +104,18 @@ def get_regions_from_detection_result(detection_result: TextDetectionResult, hea
|
||||
new_boxes.append(bbox)
|
||||
|
||||
# Merge tables together (sometimes one column is detected as a separate table)
|
||||
for i in range(5): # Up to 5 rounds of merging
|
||||
mergeable_types = ["Table", "Picture", "Figure"]
|
||||
for ftype in mergeable_types:
|
||||
to_remove = set()
|
||||
for bbox_idx, bbox in enumerate(new_boxes):
|
||||
if bbox.label != "Table" or bbox_idx in to_remove:
|
||||
if bbox.label != ftype or bbox_idx in to_remove:
|
||||
continue
|
||||
|
||||
for bbox_idx2, bbox2 in enumerate(new_boxes):
|
||||
if bbox2.label != "Table" or bbox_idx2 in to_remove or bbox_idx == bbox_idx2:
|
||||
if bbox2.label != ftype or bbox_idx2 in to_remove or bbox_idx == bbox_idx2:
|
||||
continue
|
||||
|
||||
if bbox.intersection_pct(bbox2) > 0:
|
||||
if bbox.intersection_pct(bbox2, x_margin=.25) > .1:
|
||||
bbox.merge(bbox2)
|
||||
to_remove.add(bbox_idx2)
|
||||
|
||||
@ -151,10 +152,14 @@ def get_regions(heatmaps: List[np.ndarray], orig_size, id2label, segment_assignm
|
||||
heatmap = heatmaps[i]
|
||||
assert heatmap.shape == segment_assignment.shape
|
||||
heatmap[segment_assignment != i] = 0 # zero out where another segment is
|
||||
|
||||
# Skip processing empty labels
|
||||
if np.max(heatmap) < settings.DETECTOR_BLANK_THRESHOLD:
|
||||
continue
|
||||
|
||||
bbox = get_and_clean_boxes(heatmap, list(reversed(heatmap.shape)), orig_size)
|
||||
for bb in bbox:
|
||||
bboxes.append(LayoutBox(polygon=bb.polygon, label=id2label[i]))
|
||||
heatmaps.append(heatmap)
|
||||
|
||||
bboxes = keep_largest_boxes(bboxes)
|
||||
return bboxes
|
||||
@ -182,23 +187,43 @@ def parallel_get_regions(heatmaps: List[np.ndarray], orig_size, id2label, detect
|
||||
|
||||
|
||||
def batch_layout_detection(images: List, model, processor, detection_results: Optional[List[TextDetectionResult]] = None, batch_size=None) -> List[LayoutResult]:
|
||||
preds, orig_sizes = batch_detection(images, model, processor, batch_size=batch_size)
|
||||
layout_generator = batch_detection(images, model, processor, batch_size=batch_size)
|
||||
id2label = model.config.id2label
|
||||
|
||||
results = []
|
||||
if settings.IN_STREAMLIT or len(images) < settings.DETECTOR_MIN_PARALLEL_THRESH: # Ensures we don't parallelize with streamlit or too few images
|
||||
for i in range(len(images)):
|
||||
result = parallel_get_regions(preds[i], orig_sizes[i], id2label, detection_results[i] if detection_results else None)
|
||||
results.append(result)
|
||||
else:
|
||||
futures = []
|
||||
max_workers = min(settings.DETECTOR_POSTPROCESSING_CPU_WORKERS, len(images))
|
||||
with ProcessPoolExecutor(max_workers=max_workers) as executor:
|
||||
for i in range(len(images)):
|
||||
future = executor.submit(parallel_get_regions, preds[i], orig_sizes[i], id2label, detection_results[i] if detection_results else None)
|
||||
futures.append(future)
|
||||
max_workers = min(settings.DETECTOR_POSTPROCESSING_CPU_WORKERS, len(images))
|
||||
parallelize = not settings.IN_STREAMLIT and len(images) >= settings.DETECTOR_MIN_PARALLEL_THRESH
|
||||
|
||||
for future in futures:
|
||||
results.append(future.result())
|
||||
if parallelize:
|
||||
with ProcessPoolExecutor(max_workers=max_workers) as executor:
|
||||
img_idx = 0
|
||||
for preds, orig_sizes in layout_generator:
|
||||
futures = []
|
||||
for pred, orig_size in zip(preds, orig_sizes):
|
||||
future = executor.submit(
|
||||
parallel_get_regions,
|
||||
pred,
|
||||
orig_size,
|
||||
id2label,
|
||||
detection_results[img_idx] if detection_results else None
|
||||
)
|
||||
|
||||
futures.append(future)
|
||||
img_idx += 1
|
||||
|
||||
for future in futures:
|
||||
results.append(future.result())
|
||||
else:
|
||||
img_idx = 0
|
||||
for preds, orig_sizes in layout_generator:
|
||||
for pred, orig_size in zip(preds, orig_sizes):
|
||||
results.append(parallel_get_regions(
|
||||
pred,
|
||||
orig_size,
|
||||
id2label,
|
||||
detection_results[img_idx] if detection_results else None
|
||||
))
|
||||
|
||||
img_idx += 1
|
||||
|
||||
return results
|
||||
260
surya/model/table_rec/config.py
Normal file
@ -0,0 +1,260 @@
|
||||
from transformers import PretrainedConfig
|
||||
from surya.settings import settings
|
||||
|
||||
BOX_DIM = 1024
|
||||
SPECIAL_TOKENS = 7
|
||||
MAX_ROWS = 384
|
||||
|
||||
|
||||
class SuryaTableRecConfig(PretrainedConfig):
|
||||
model_type = "vision-encoder-decoder"
|
||||
is_composition = True
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
encoder_config = kwargs.pop("encoder")
|
||||
decoder_config = kwargs.pop("decoder")
|
||||
text_enc_config = kwargs.pop("text_encoder")
|
||||
|
||||
self.encoder = encoder_config
|
||||
self.decoder = decoder_config
|
||||
self.text_encoder = text_enc_config
|
||||
self.is_encoder_decoder = True
|
||||
|
||||
if isinstance(decoder_config, dict):
|
||||
self.decoder_start_token_id = decoder_config["bos_token_id"]
|
||||
self.pad_token_id = decoder_config["pad_token_id"]
|
||||
self.eos_token_id = decoder_config["eos_token_id"]
|
||||
else:
|
||||
self.decoder_start_token_id = decoder_config.bos_token_id
|
||||
self.pad_token_id = decoder_config.pad_token_id
|
||||
self.eos_token_id = decoder_config.eos_token_id
|
||||
|
||||
|
||||
class DonutSwinTableRecConfig(PretrainedConfig):
|
||||
model_type = "donut-swin"
|
||||
|
||||
attribute_map = {
|
||||
"num_attention_heads": "num_heads",
|
||||
"num_hidden_layers": "num_layers",
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
image_size=(settings.TABLE_REC_IMAGE_SIZE["width"], settings.TABLE_REC_IMAGE_SIZE["height"]),
|
||||
patch_size=4,
|
||||
num_channels=3,
|
||||
embed_dim=128,
|
||||
depths=[2, 2, 14, 2],
|
||||
num_heads=[4, 8, 16, 32],
|
||||
num_kv_heads=[4, 8, 16, 32],
|
||||
window_size=8,
|
||||
mlp_ratio=4.0,
|
||||
qkv_bias=True,
|
||||
hidden_dropout_prob=0.0,
|
||||
attention_probs_dropout_prob=0.0,
|
||||
drop_path_rate=0.1,
|
||||
hidden_act="gelu",
|
||||
use_absolute_embeddings=True,
|
||||
initializer_range=0.02,
|
||||
layer_norm_eps=1e-5,
|
||||
encoder_length=1024,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.image_size = image_size
|
||||
self.patch_size = patch_size
|
||||
self.num_channels = num_channels
|
||||
self.embed_dim = embed_dim
|
||||
self.depths = depths
|
||||
self.num_layers = len(depths)
|
||||
self.num_heads = num_heads
|
||||
self.num_kv_heads = num_kv_heads
|
||||
self.window_size = window_size
|
||||
self.mlp_ratio = mlp_ratio
|
||||
self.qkv_bias = qkv_bias
|
||||
self.hidden_dropout_prob = hidden_dropout_prob
|
||||
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
||||
self.drop_path_rate = drop_path_rate
|
||||
self.hidden_act = hidden_act
|
||||
self.use_absolute_embeddings = use_absolute_embeddings
|
||||
self.layer_norm_eps = layer_norm_eps
|
||||
self.initializer_range = initializer_range
|
||||
# we set the hidden_size attribute in order to make Swin work with VisionEncoderDecoderModel
|
||||
# this indicates the channel dimension after the last stage of the model
|
||||
self.hidden_size = int(embed_dim * 2 ** (len(depths) - 1))
|
||||
self.encoder_length = encoder_length
|
||||
|
||||
|
||||
class SuryaTableRecDecoderConfig(PretrainedConfig):
|
||||
model_type = "surya_tablerec"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_hidden_layers=3,
|
||||
vocab_size=settings.TABLE_REC_MAX_ROWS + SPECIAL_TOKENS,
|
||||
hidden_size=512,
|
||||
intermediate_size=4 * 512,
|
||||
encoder_hidden_size=1024,
|
||||
num_attention_heads=8,
|
||||
lru_width=None,
|
||||
attention_window_size=16,
|
||||
conv1d_width=4,
|
||||
logits_soft_cap=30.0,
|
||||
rms_norm_eps=1e-6,
|
||||
use_cache=True,
|
||||
pad_token_id=0,
|
||||
eos_token_id=1,
|
||||
bos_token_id=2,
|
||||
hidden_activation="gelu_pytorch_tanh",
|
||||
rope_theta=10000.0,
|
||||
block_types=("attention",),
|
||||
cross_attn_layers=(0, 1, 2, 3),
|
||||
encoder_cross_attn_layers=(0, 1, 2, 3),
|
||||
self_attn_layers=(0, 1, 2, 3),
|
||||
global_attn_layers=(0, 1, 2, 3),
|
||||
attention_dropout=0.0,
|
||||
num_key_value_heads=4,
|
||||
attention_bias=False,
|
||||
w_init_variance_scale=0.01,
|
||||
init_std=0.02,
|
||||
tie_word_embeddings=False,
|
||||
aux_heads=0, # How many n-token-ahead heads to add
|
||||
causal=True,
|
||||
max_classes=2 + SPECIAL_TOKENS,
|
||||
max_width=1024 + SPECIAL_TOKENS,
|
||||
max_height=1024 + SPECIAL_TOKENS,
|
||||
out_box_size=1024,
|
||||
**kwargs,
|
||||
):
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.vocab_size = vocab_size
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.lru_width = lru_width if lru_width is not None else hidden_size
|
||||
self.attention_window_size = attention_window_size
|
||||
self.conv1d_width = conv1d_width
|
||||
self.logits_soft_cap = logits_soft_cap
|
||||
self.rms_norm_eps = rms_norm_eps
|
||||
self.use_cache = use_cache
|
||||
self.rope_theta = rope_theta
|
||||
self.block_types = list(block_types)
|
||||
self.hidden_activation = hidden_activation
|
||||
self.head_dim = self.hidden_size // self.num_attention_heads
|
||||
self.num_key_value_heads = num_key_value_heads if num_key_value_heads is not None else num_attention_heads
|
||||
if self.num_key_value_heads > self.num_attention_heads:
|
||||
raise ValueError("The number of `num_key_value_heads` must be smaller than `num_attention_heads`")
|
||||
self.cross_attn_layers = cross_attn_layers
|
||||
self.self_attn_layers = self_attn_layers
|
||||
self.global_attn_layers = global_attn_layers
|
||||
self.attention_dropout = attention_dropout
|
||||
self.attention_bias = attention_bias
|
||||
self.w_init_variance_scale = w_init_variance_scale
|
||||
self.final_w_init_variance_scale = 2.0 / self.num_hidden_layers
|
||||
self.init_std = init_std
|
||||
self.tie_word_embeddings = tie_word_embeddings
|
||||
self.aux_heads = aux_heads
|
||||
self.encoder_hidden_size=encoder_hidden_size
|
||||
self.causal = causal
|
||||
self.encoder_cross_attn_layers = encoder_cross_attn_layers
|
||||
self.max_classes = max_classes
|
||||
self.max_width = max_width
|
||||
self.max_height = max_height
|
||||
self.out_box_size = out_box_size
|
||||
|
||||
super().__init__(
|
||||
pad_token_id=pad_token_id,
|
||||
bos_token_id=bos_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@property
|
||||
def layers_block_type(self):
|
||||
return (self.block_types * 100)[: self.num_hidden_layers]
|
||||
|
||||
|
||||
class SuryaTableRecTextEncoderConfig(PretrainedConfig):
|
||||
model_type = "surya_tablerec"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_hidden_layers=4,
|
||||
vocab_size=settings.TABLE_REC_MAX_ROWS + SPECIAL_TOKENS,
|
||||
hidden_size=1024,
|
||||
intermediate_size=4 * 1024,
|
||||
encoder_hidden_size=1024,
|
||||
num_attention_heads=16,
|
||||
lru_width=None,
|
||||
attention_window_size=16,
|
||||
conv1d_width=4,
|
||||
logits_soft_cap=30.0,
|
||||
rms_norm_eps=1e-6,
|
||||
use_cache=True,
|
||||
pad_token_id=0,
|
||||
eos_token_id=1,
|
||||
bos_token_id=2,
|
||||
hidden_activation="gelu_pytorch_tanh",
|
||||
rope_theta=10000.0,
|
||||
block_types=("attention",),
|
||||
cross_attn_layers=(0, 1, 2, 3, 4, 5),
|
||||
self_attn_layers=(0, 1, 2, 3, 4, 5),
|
||||
global_attn_layers=(0, 1, 2, 3, 4, 5),
|
||||
attention_dropout=0.0,
|
||||
num_key_value_heads=16,
|
||||
attention_bias=False,
|
||||
w_init_variance_scale=0.01,
|
||||
init_std=0.02,
|
||||
tie_word_embeddings=False,
|
||||
causal=False,
|
||||
max_width=BOX_DIM + SPECIAL_TOKENS,
|
||||
max_height=BOX_DIM + SPECIAL_TOKENS,
|
||||
max_position_embeddings=1024,
|
||||
**kwargs,
|
||||
):
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.vocab_size = vocab_size
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.lru_width = lru_width if lru_width is not None else hidden_size
|
||||
self.attention_window_size = attention_window_size
|
||||
self.conv1d_width = conv1d_width
|
||||
self.logits_soft_cap = logits_soft_cap
|
||||
self.rms_norm_eps = rms_norm_eps
|
||||
self.use_cache = use_cache
|
||||
self.rope_theta = rope_theta
|
||||
self.block_types = list(block_types)
|
||||
self.hidden_activation = hidden_activation
|
||||
self.head_dim = self.hidden_size // self.num_attention_heads
|
||||
self.num_key_value_heads = num_key_value_heads if num_key_value_heads is not None else num_attention_heads
|
||||
if self.num_key_value_heads > self.num_attention_heads:
|
||||
raise ValueError("The number of `num_key_value_heads` must be smaller than `num_attention_heads`")
|
||||
self.cross_attn_layers = cross_attn_layers
|
||||
self.self_attn_layers = self_attn_layers
|
||||
self.global_attn_layers = global_attn_layers
|
||||
self.attention_dropout = attention_dropout
|
||||
self.attention_bias = attention_bias
|
||||
self.w_init_variance_scale = w_init_variance_scale
|
||||
self.final_w_init_variance_scale = 2.0 / self.num_hidden_layers
|
||||
self.init_std = init_std
|
||||
self.tie_word_embeddings = tie_word_embeddings
|
||||
self.encoder_hidden_size = encoder_hidden_size
|
||||
self.causal = causal
|
||||
self.max_width = max_width
|
||||
self.max_height = max_height
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
|
||||
super().__init__(
|
||||
pad_token_id=pad_token_id,
|
||||
bos_token_id=bos_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@property
|
||||
def layers_block_type(self):
|
||||
return (self.block_types * 100)[: self.num_hidden_layers]
|
||||
795
surya/model/table_rec/decoder.py
Normal file
@ -0,0 +1,795 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
from transformers.utils import ModelOutput
|
||||
|
||||
from surya.model.table_rec.config import SuryaTableRecDecoderConfig, SuryaTableRecTextEncoderConfig
|
||||
from transformers import PreTrainedModel
|
||||
from transformers.activations import ACT2FN
|
||||
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
|
||||
from transformers.modeling_outputs import BaseModelOutputWithNoAttention, CausalLMOutput
|
||||
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
|
||||
|
||||
from surya.settings import settings
|
||||
|
||||
_MAX_SQRT_GRADIENT = 1000.0
|
||||
|
||||
@dataclass
|
||||
class TableRecModelOutput(ModelOutput):
|
||||
bbox_logits: torch.Tensor
|
||||
class_logits: torch.Tensor | None = None
|
||||
hidden_states: torch.Tensor | None = None
|
||||
|
||||
|
||||
class SuryaTableRecDecoderRMSNorm(nn.Module):
|
||||
def __init__(self, dim: int, eps: float = 1e-6):
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.weight = nn.Parameter(torch.zeros(dim))
|
||||
|
||||
def _norm(self, x):
|
||||
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
|
||||
|
||||
def forward(self, x):
|
||||
output = self._norm(x.float())
|
||||
# Llama does x.to(float16) * w whilst SuryaTableRecDecoder is (x * w).to(float16)
|
||||
# See https://github.com/huggingface/transformers/pull/29402
|
||||
output = output * (1.0 + self.weight.float())
|
||||
return output.type_as(x)
|
||||
|
||||
def extra_repr(self):
|
||||
return f"{tuple(self.weight.shape)}, eps={self.eps}"
|
||||
|
||||
|
||||
ALL_LAYERNORM_LAYERS.append(SuryaTableRecDecoderRMSNorm)
|
||||
|
||||
|
||||
class SuryaTableRecDecoderRotaryEmbedding(nn.Module):
|
||||
def __init__(self, dim, base=10000, device=None):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.base = base
|
||||
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim))
|
||||
self.register_buffer("inv_freq", tensor=inv_freq, persistent=False)
|
||||
|
||||
@torch.no_grad()
|
||||
# Copied from transformers.models.gemma.modeling_gemma.GemmaRotaryEmbedding.forward with Gemma->SuryaTableRecDecoder
|
||||
def forward(self, x, position_ids, seq_len=None):
|
||||
# x: [bs, num_attention_heads, seq_len, head_size]
|
||||
self.inv_freq.to(x.device)
|
||||
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
|
||||
position_ids_expanded = position_ids[:, None, :].float()
|
||||
|
||||
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
cos = emb.cos()
|
||||
sin = emb.sin()
|
||||
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
||||
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.rotate_half
|
||||
def rotate_half(x):
|
||||
"""Rotates half the hidden dims of the input."""
|
||||
x1 = x[..., : x.shape[-1] // 2]
|
||||
x2 = x[..., x.shape[-1] // 2 :]
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
|
||||
def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
|
||||
"""Applies Rotary Position Embedding to the query and key tensors.
|
||||
|
||||
Args:
|
||||
q (`torch.Tensor`): The query tensor.
|
||||
k (`torch.Tensor`): The key tensor.
|
||||
cos (`torch.Tensor`): The cosine part of the rotary embedding.
|
||||
sin (`torch.Tensor`): The sine part of the rotary embedding.
|
||||
unsqueeze_dim (`int`, *optional*, defaults to 1):
|
||||
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
|
||||
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
|
||||
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
|
||||
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
|
||||
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
|
||||
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
|
||||
Returns:
|
||||
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
|
||||
"""
|
||||
cos = cos.unsqueeze(unsqueeze_dim)
|
||||
sin = sin.unsqueeze(unsqueeze_dim)
|
||||
q_embed = (q * cos) + (rotate_half(q) * sin)
|
||||
k_embed = (k * cos) + (rotate_half(k) * sin)
|
||||
return q_embed, k_embed
|
||||
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.repeat_kv
|
||||
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
||||
"""
|
||||
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
|
||||
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
|
||||
"""
|
||||
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
||||
if n_rep == 1:
|
||||
return hidden_states
|
||||
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
|
||||
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
||||
|
||||
|
||||
class SuryaTableRecDecoderSdpaCrossAttention(nn.Module):
|
||||
"""Multi-headed attention from 'Attention Is All You Need' paper
|
||||
Modified for GQA
|
||||
"""
|
||||
|
||||
def __init__(self, config: SuryaTableRecDecoderConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.attention_dropout = config.attention_dropout
|
||||
self.hidden_size = config.hidden_size
|
||||
self.num_attention_heads = config.num_attention_heads
|
||||
self.head_dim = config.head_dim
|
||||
self.num_key_value_heads = config.num_key_value_heads
|
||||
self.num_key_value_groups = self.num_attention_heads // self.num_key_value_heads
|
||||
|
||||
self.q_proj = nn.Linear(self.hidden_size, self.num_attention_heads * self.head_dim, bias=config.attention_bias)
|
||||
self.k_proj = nn.Linear(self.config.encoder_hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
|
||||
self.v_proj = nn.Linear(self.config.encoder_hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
|
||||
self.o_proj = nn.Linear(self.num_attention_heads * self.head_dim, self.hidden_size, bias=True)
|
||||
self.rotary_emb = SuryaTableRecDecoderRotaryEmbedding(
|
||||
self.head_dim,
|
||||
base=config.rope_theta,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||
use_cache: bool = False,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
# Encoder attention mask currently ignored
|
||||
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
_, v_len, _ = encoder_hidden_states.size()
|
||||
|
||||
query_states = self.q_proj(hidden_states)
|
||||
query_states = query_states.view(bsz, q_len, self.num_attention_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
if self.key_states is None:
|
||||
key_states = self.k_proj(encoder_hidden_states)
|
||||
value_states = self.v_proj(encoder_hidden_states)
|
||||
key_states = key_states.view(bsz, v_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
value_states = value_states.view(bsz, v_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
if use_cache:
|
||||
self._update_cache(key_states, value_states)
|
||||
else:
|
||||
key_states = self.key_states
|
||||
value_states = self.value_states
|
||||
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
|
||||
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
||||
query_states.contiguous(),
|
||||
key_states.contiguous(),
|
||||
value_states.contiguous(),
|
||||
attn_mask=None,
|
||||
dropout_p=self.attention_dropout if self.training else 0.0,
|
||||
scale=self.head_dim**-0.5,
|
||||
)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
attn_output = attn_output.view(bsz, q_len, self.hidden_size)
|
||||
attn_output = self.o_proj(attn_output)
|
||||
return attn_output
|
||||
|
||||
def _setup_cache(self, batch_size, device, dtype=None):
|
||||
# Setup initial caches
|
||||
self.value_states = None
|
||||
self.key_states = None
|
||||
|
||||
@torch.no_grad()
|
||||
def _update_cache(self, key_states, value_states, **cache_kwargs):
|
||||
self.value_states = value_states
|
||||
self.key_states = key_states
|
||||
|
||||
|
||||
class SuryaTableRecDecoderSdpaAttention(nn.Module):
|
||||
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
||||
|
||||
def __init__(self, config: SuryaTableRecDecoderConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.attention_dropout = config.attention_dropout
|
||||
self.hidden_size = config.hidden_size
|
||||
self.num_attention_heads = config.num_attention_heads
|
||||
self.head_dim = config.head_dim
|
||||
self.num_key_value_heads = config.num_key_value_heads
|
||||
self.num_key_value_groups = self.num_attention_heads // self.num_key_value_heads
|
||||
|
||||
self.q_proj = nn.Linear(self.hidden_size, self.num_attention_heads * self.head_dim, bias=config.attention_bias)
|
||||
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
|
||||
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
|
||||
self.o_proj = nn.Linear(self.num_attention_heads * self.head_dim, self.hidden_size, bias=True)
|
||||
self.rotary_emb = SuryaTableRecDecoderRotaryEmbedding(
|
||||
self.head_dim,
|
||||
base=config.rope_theta,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
use_cache: bool = False,
|
||||
window_attn: bool = False,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
# Final is bsz, num_attention_heads, seq_len, head_dim
|
||||
query_states = query_states.view(bsz, q_len, self.num_attention_heads, self.head_dim).transpose(1, 2)
|
||||
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
||||
|
||||
if use_cache and hasattr(self, "key_states"):
|
||||
cache_kwargs = {"cache_position": cache_position, "window_attn": window_attn}
|
||||
key_states, value_states = self._update_cache(key_states, value_states, **cache_kwargs)
|
||||
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
|
||||
causal_mask = attention_mask
|
||||
if attention_mask is not None:
|
||||
# Mask is batch, head, seq_len, kv_len
|
||||
causal_mask = causal_mask[:, :, :, :key_states.shape[-2]]
|
||||
current_cache_position = cache_position[-1].item() if cache_position is not None else None
|
||||
if current_cache_position and settings.RECOGNITION_STATIC_CACHE:
|
||||
# Mask out future cache positions
|
||||
position_mask = torch.ones_like(causal_mask, dtype=torch.bool, device=causal_mask.device)
|
||||
position_mask[:, :, :, :current_cache_position + 1] = False
|
||||
causal_mask = torch.where(position_mask, torch.finfo(causal_mask.dtype).min, causal_mask)
|
||||
|
||||
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
||||
query_states.contiguous(),
|
||||
key_states.contiguous(),
|
||||
value_states.contiguous(),
|
||||
attn_mask=causal_mask,
|
||||
dropout_p=self.attention_dropout if self.training else 0.0,
|
||||
scale=self.head_dim**-0.5,
|
||||
)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
attn_output = attn_output.view(bsz, q_len, self.hidden_size)
|
||||
attn_output = self.o_proj(attn_output)
|
||||
return attn_output
|
||||
|
||||
def _setup_cache(self, batch_size, device, dtype=None):
|
||||
if dtype is None and self.config.torch_dtype is not None:
|
||||
dtype = self.config.torch_dtype
|
||||
dtype = dtype if dtype is not None else torch.float32
|
||||
|
||||
# Setup initial caches
|
||||
self.value_states = None
|
||||
self.key_states = None
|
||||
|
||||
if settings.RECOGNITION_STATIC_CACHE:
|
||||
cache_shape = (batch_size, self.num_key_value_heads, settings.RECOGNITION_MAX_TOKENS, self.head_dim)
|
||||
self.value_states = torch.zeros(cache_shape, dtype=dtype, device=device)
|
||||
self.key_states = torch.zeros(cache_shape, dtype=dtype, device=device)
|
||||
|
||||
def _update_static_cache(self, key_states, value_states, **cache_kwargs):
|
||||
cache_position = cache_kwargs.get("cache_position")
|
||||
k_out, v_out = self.key_states.to(key_states.device), self.value_states.to(value_states.device)
|
||||
|
||||
k_out[:, :, cache_position] = key_states.to(k_out.dtype)
|
||||
v_out[:, :, cache_position] = value_states.to(v_out.dtype)
|
||||
|
||||
self.key_states, self.value_states = k_out, v_out
|
||||
return k_out, v_out
|
||||
|
||||
def _update_dynamic_cache(self, key_states, value_states, **cache_kwargs):
|
||||
k_out = key_states
|
||||
if self.key_states is not None:
|
||||
k_out = torch.cat([self.key_states, key_states], dim=2)
|
||||
|
||||
v_out = value_states
|
||||
if self.value_states is not None:
|
||||
v_out = torch.cat([self.value_states, value_states], dim=2)
|
||||
|
||||
self.key_states, self.value_states = k_out, v_out
|
||||
return k_out, v_out
|
||||
|
||||
@torch.no_grad()
|
||||
def _update_cache(self, key_states, value_states, **cache_kwargs):
|
||||
if settings.RECOGNITION_STATIC_CACHE:
|
||||
return self._update_static_cache(key_states, value_states, **cache_kwargs)
|
||||
|
||||
return self._update_dynamic_cache(key_states, value_states, **cache_kwargs)
|
||||
|
||||
|
||||
class SuryaTableRecDecoderMlp(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.hidden_size = config.hidden_size
|
||||
self.intermediate_size = config.intermediate_size
|
||||
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
||||
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
||||
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
|
||||
if config.hidden_activation is None:
|
||||
config.hidden_activation = "gelu_pytorch_tanh"
|
||||
hidden_activation = config.hidden_activation
|
||||
self.act_fn = ACT2FN[hidden_activation]
|
||||
|
||||
def forward(self, x):
|
||||
return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
||||
|
||||
|
||||
class SuryaTableRecDecoderLayer(nn.Module):
|
||||
def __init__(self, config, layer_idx):
|
||||
super().__init__()
|
||||
super().__init__()
|
||||
self.cross_pre_norm = SuryaTableRecDecoderRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.temporal_pre_norm = SuryaTableRecDecoderRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
self.temporal_block = None
|
||||
if layer_idx in config.self_attn_layers:
|
||||
self.temporal_block = SuryaTableRecDecoderSdpaAttention(config)
|
||||
|
||||
self.cross_attn_block = None
|
||||
if layer_idx in config.cross_attn_layers:
|
||||
self.cross_attn_block = SuryaTableRecDecoderSdpaCrossAttention(config)
|
||||
|
||||
self.window_attn = layer_idx not in config.global_attn_layers
|
||||
self.channel_pre_norm = SuryaTableRecDecoderRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.mlp_block = SuryaTableRecDecoderMlp(config)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
activations: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
attention_mask: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor = None,
|
||||
encoder_attention_mask: torch.Tensor = None,
|
||||
cache_position: torch.Tensor = None,
|
||||
use_cache: bool = None,
|
||||
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
|
||||
raw_activations = activations
|
||||
|
||||
if self.cross_attn_block is not None:
|
||||
# Do cross-attention on encoder outputs
|
||||
cross_attn_inputs = self.cross_pre_norm(activations)
|
||||
cross_attn_path = self.cross_attn_block(
|
||||
cross_attn_inputs, encoder_hidden_states, attention_mask, encoder_attention_mask, use_cache=use_cache
|
||||
)
|
||||
cross_attn_output = cross_attn_path + raw_activations
|
||||
else:
|
||||
cross_attn_output = raw_activations
|
||||
|
||||
if self.temporal_block is not None:
|
||||
inputs_normalized = self.temporal_pre_norm(cross_attn_output) # RMSNorm introduces slight slight differences
|
||||
hidden_states = self.temporal_block(
|
||||
inputs_normalized, position_ids, attention_mask, cache_position=cache_position, use_cache=use_cache, window_attn=self.window_attn
|
||||
)
|
||||
|
||||
residual = hidden_states + raw_activations
|
||||
else:
|
||||
residual = cross_attn_output
|
||||
|
||||
hidden_states = self.channel_pre_norm(residual)
|
||||
hidden_states = self.mlp_block(hidden_states)
|
||||
|
||||
hidden_states = hidden_states + residual
|
||||
return hidden_states
|
||||
|
||||
|
||||
class SuryaTableRecDecoderPreTrainedModel(PreTrainedModel):
|
||||
config_class = SuryaTableRecDecoderConfig
|
||||
base_model_prefix = "model"
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["SuryaTableRecDecoderLayer"]
|
||||
_skip_keys_device_placement = ["cache"]
|
||||
_supports_flash_attn_2 = False
|
||||
_supports_sdpa = False # we can't compare with eager for now
|
||||
_supports_cache_class = True
|
||||
_supports_quantized_cache = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
if isinstance(module, SuryaTableRecDecoderSdpaAttention):
|
||||
torch.nn.init.normal_(module.q_proj.weight, mean=0.0, std=self.config.init_std)
|
||||
torch.nn.init.normal_(module.k_proj.weight, mean=0.0, std=self.config.init_std)
|
||||
torch.nn.init.normal_(module.v_proj.weight, mean=0.0, std=self.config.init_std)
|
||||
|
||||
torch.nn.init.normal_(module.o_proj.weight, mean=0.0, std=self.config.init_std)
|
||||
elif isinstance(module, nn.Linear):
|
||||
torch.nn.init.normal_(module.weight, mean=0.0, std=self.config.init_std)
|
||||
if getattr(module, "bias", None) is not None:
|
||||
torch.nn.init.zeros_(module.bias)
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.init_std)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
|
||||
def _setup_cache(self, config, batch, device, dtype):
|
||||
layers = getattr(self, "model", self).layers
|
||||
for layer in layers:
|
||||
if layer.temporal_block:
|
||||
layer.temporal_block._setup_cache(batch, device, dtype)
|
||||
if layer.cross_attn_block:
|
||||
layer.cross_attn_block._setup_cache(batch, device, dtype)
|
||||
|
||||
def reset_cache(self, batch, device, dtype):
|
||||
pass
|
||||
|
||||
def _tie_weights(self):
|
||||
pass
|
||||
|
||||
def tie_weights(self):
|
||||
pass
|
||||
|
||||
|
||||
class LabelEmbedding(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.vocab_size = config.vocab_size
|
||||
self.x1_embed = nn.Embedding(config.max_width, config.hidden_size)
|
||||
self.y1_embed = nn.Embedding(config.max_height, config.hidden_size)
|
||||
self.x2_embed = nn.Embedding(config.max_width, config.hidden_size)
|
||||
self.y2_embed = nn.Embedding(config.max_height, config.hidden_size)
|
||||
self.w_embed = nn.Embedding(config.max_width, config.hidden_size)
|
||||
self.h_embed = nn.Embedding(config.max_height, config.hidden_size)
|
||||
self.cx_embed = nn.Embedding(config.max_width, config.hidden_size)
|
||||
self.cy_embed = nn.Embedding(config.max_height, config.hidden_size)
|
||||
self.class_embed = nn.Embedding(config.max_classes, config.hidden_size)
|
||||
self.max_width = config.max_width
|
||||
self.max_height = config.max_height
|
||||
self.max_classes = config.max_classes
|
||||
|
||||
def forward(self, labels: torch.LongTensor, input_box_counts: torch.LongTensor):
|
||||
cx, cy, w, h, class_ = labels.to(torch.long).unbind(dim=-1)
|
||||
# Shape is (batch_size, num_boxes/seq len, d_model)
|
||||
x1 = (cx - w // 2).long()
|
||||
y1 = (cy - h // 2).long()
|
||||
x2 = (cx + w // 2).long()
|
||||
y2 = (cy + h // 2).long()
|
||||
x1 = torch.clamp(x1, 0, self.max_width - 1)
|
||||
y1 = torch.clamp(y1, 0, self.max_height - 1)
|
||||
x2 = torch.clamp(x2, 0, self.max_width - 1)
|
||||
y2 = torch.clamp(y2, 0, self.max_height - 1)
|
||||
|
||||
class_ = torch.clamp(class_, 0, self.max_classes - 1).long()
|
||||
|
||||
w = torch.clamp(w, 0, self.max_width - 1).long()
|
||||
h = torch.clamp(h, 0, self.max_height - 1).long()
|
||||
cx = torch.clamp(cx, 0, self.max_width - 1).long()
|
||||
cy = torch.clamp(cy, 0, self.max_height - 1).long()
|
||||
|
||||
coord_embeds = self.x1_embed(x1) + self.y1_embed(y1) + self.x2_embed(x2) + self.y2_embed(y2)
|
||||
class_embeds = self.class_embed(class_)
|
||||
embedded = coord_embeds + self.w_embed(w) + self.h_embed(h) + self.cx_embed(cx) + self.cy_embed(cy) + class_embeds
|
||||
|
||||
return embedded
|
||||
|
||||
|
||||
class BboxEmbedding(nn.Module):
|
||||
def __init__(self, config, embed_positions=False):
|
||||
super().__init__()
|
||||
self.x1_embed = nn.Embedding(config.max_width, config.hidden_size)
|
||||
self.y1_embed = nn.Embedding(config.max_height, config.hidden_size)
|
||||
self.x2_embed = nn.Embedding(config.max_width, config.hidden_size)
|
||||
self.y2_embed = nn.Embedding(config.max_height, config.hidden_size)
|
||||
self.w_embed = nn.Embedding(config.max_width, config.hidden_size)
|
||||
self.h_embed = nn.Embedding(config.max_height, config.hidden_size)
|
||||
self.cx_embed = nn.Embedding(config.max_width, config.hidden_size)
|
||||
self.cy_embed = nn.Embedding(config.max_height, config.hidden_size)
|
||||
self.box_pos_embed = nn.Embedding(config.max_position_embeddings, config.hidden_size)
|
||||
self.max_width = config.max_width
|
||||
self.max_height = config.max_height
|
||||
self.embed_positions = embed_positions
|
||||
|
||||
def forward(self, boxes: torch.LongTensor, input_box_counts: torch.LongTensor):
|
||||
x1, y1, x2, y2 = boxes.unbind(dim=-1)
|
||||
x1 = torch.clamp(x1, 0, self.max_width - 1).long()
|
||||
y1 = torch.clamp(y1, 0, self.max_height - 1).long()
|
||||
x2 = torch.clamp(x2, 0, self.max_width - 1).long()
|
||||
y2 = torch.clamp(y2, 0, self.max_height - 1).long()
|
||||
|
||||
# Shape is (batch_size, num_boxes/seq len, d_model)
|
||||
w = x2 - x1
|
||||
h = y2 - y1
|
||||
# Center x and y in torch long tensors
|
||||
cx = (x1 + x2) / 2
|
||||
cy = (y1 + y2) / 2
|
||||
cx = cx.long()
|
||||
cy = cy.long()
|
||||
|
||||
w = torch.clamp(w, 0, self.max_width - 1).long()
|
||||
h = torch.clamp(h, 0, self.max_height - 1).long()
|
||||
cx = torch.clamp(cx, 0, self.max_width - 1).long()
|
||||
cy = torch.clamp(cy, 0, self.max_height - 1).long()
|
||||
|
||||
coord_embeds = self.x1_embed(x1) + self.y1_embed(y1) + self.x2_embed(x2) + self.y2_embed(y2)
|
||||
embedded = coord_embeds + self.w_embed(w) + self.h_embed(h) + self.cx_embed(cx) + self.cy_embed(cy)
|
||||
|
||||
# Add in positional embeddings for the boxes and labels
|
||||
if self.embed_positions:
|
||||
for j in range(embedded.shape[0]):
|
||||
box_start = input_box_counts[j, 0]
|
||||
box_end = input_box_counts[j, 1] - 1 # Skip the sep token
|
||||
box_count = box_end - box_start
|
||||
embedded[j, box_start:box_end] = embedded[j, box_start:box_end] + self.box_pos_embed.weight[:box_count]
|
||||
|
||||
return embedded
|
||||
|
||||
|
||||
class SuryaTableRecDecoderModel(SuryaTableRecDecoderPreTrainedModel):
|
||||
"""
|
||||
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`SuryaTableRecDecoderDecoderLayer`]
|
||||
|
||||
Args:
|
||||
config: SuryaTableRecDecoderConfig
|
||||
"""
|
||||
|
||||
def __init__(self, config: SuryaTableRecDecoderConfig, embed_labels=False, embed_positions=True):
|
||||
super().__init__(config)
|
||||
self.padding_idx = config.pad_token_id
|
||||
self.vocab_size = config.vocab_size
|
||||
self.causal = config.causal
|
||||
|
||||
if embed_labels:
|
||||
self.embed_tokens = LabelEmbedding(config)
|
||||
else:
|
||||
self.embed_tokens = BboxEmbedding(config, embed_positions=embed_positions)
|
||||
|
||||
self.layers = nn.ModuleList(
|
||||
[SuryaTableRecDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
|
||||
)
|
||||
self.final_norm = SuryaTableRecDecoderRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
self.register_buffer(
|
||||
"normalizer", torch.tensor(self.config.hidden_size**0.5, dtype=torch.float32), persistent=False
|
||||
)
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaModel.get_input_embeddings
|
||||
def get_input_embeddings(self):
|
||||
return self.embed_tokens
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaModel.set_input_embeddings
|
||||
def set_input_embeddings(self, value):
|
||||
self.embed_tokens = value
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
input_boxes_counts: torch.LongTensor = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
prefill: bool = False
|
||||
) -> Union[Tuple, BaseModelOutputWithNoAttention]:
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
if self.gradient_checkpointing and self.training and use_cache:
|
||||
use_cache = False
|
||||
|
||||
inputs_embeds = self.embed_tokens(input_ids, input_boxes_counts)
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
if use_cache and prefill:
|
||||
self._setup_cache(self.config, hidden_states.shape[0], hidden_states.device, hidden_states.dtype)
|
||||
|
||||
if cache_position is None:
|
||||
cache_position = torch.arange(hidden_states.shape[1], device=hidden_states.device)
|
||||
if position_ids is None:
|
||||
position_ids = cache_position.unsqueeze(0)
|
||||
|
||||
causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)
|
||||
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
for i, residual_block in enumerate(self.layers):
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
if self.gradient_checkpointing and self.training:
|
||||
hidden_states = self._gradient_checkpointing_func(
|
||||
residual_block.__call__, hidden_states, position_ids, causal_mask, encoder_hidden_states, encoder_attention_mask, cache_position, use_cache
|
||||
)
|
||||
else:
|
||||
hidden_states = residual_block(hidden_states, position_ids, causal_mask, encoder_hidden_states, encoder_attention_mask, cache_position, use_cache)
|
||||
|
||||
hidden_states = self.final_norm(hidden_states)
|
||||
|
||||
# add hidden states from the last decoder layer
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
if not return_dict:
|
||||
return tuple(v for v in [hidden_states, all_hidden_states] if v is not None)
|
||||
|
||||
return BaseModelOutputWithNoAttention(
|
||||
last_hidden_state=hidden_states,
|
||||
hidden_states=all_hidden_states,
|
||||
)
|
||||
|
||||
# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
|
||||
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
|
||||
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
|
||||
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
|
||||
# Ignore copy
|
||||
def _update_causal_mask(self, attention_mask, input_tensor, cache_position):
|
||||
if not self.causal:
|
||||
return None
|
||||
|
||||
dtype, device = input_tensor.dtype, input_tensor.device
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
sequence_length = input_tensor.shape[1]
|
||||
target_length = max(settings.TABLE_REC_MAX_BOXES, sequence_length)
|
||||
|
||||
diagonal = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
|
||||
causal_mask = diagonal
|
||||
if sequence_length != 1:
|
||||
# Select the upper triangular part of the matrix, but unmask current token (the diagonal)
|
||||
# triu will be the min_dtype, everything else is 0 (attended to)
|
||||
causal_mask = torch.triu(diagonal, diagonal=1)
|
||||
|
||||
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
|
||||
causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
|
||||
if attention_mask is not None:
|
||||
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
||||
if attention_mask.dim() == 2:
|
||||
# Mask positions in the causal mask that are masked in the attention mask
|
||||
mask_length = attention_mask.shape[-1]
|
||||
padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
|
||||
causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)
|
||||
|
||||
if attention_mask is not None and attention_mask.device.type == "cuda":
|
||||
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
|
||||
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
|
||||
# Details: https://github.com/pytorch/pytorch/issues/110213
|
||||
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
|
||||
|
||||
return causal_mask
|
||||
|
||||
|
||||
class SuryaTableRecDecoder(SuryaTableRecDecoderPreTrainedModel):
|
||||
_tied_weights_keys = None
|
||||
|
||||
def __init__(self, config, **kwargs):
|
||||
super().__init__(config)
|
||||
self.model = SuryaTableRecDecoderModel(config, embed_labels=True, embed_positions=False)
|
||||
self.vocab_size = config.vocab_size
|
||||
|
||||
self.bbox_head = nn.Linear(config.hidden_size, config.max_width * 4, bias=False)
|
||||
self.class_head = nn.Linear(config.hidden_size, config.max_classes, bias=False)
|
||||
self.max_width = config.max_width
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.model.embed_tokens
|
||||
|
||||
def set_input_embeddings(self, value):
|
||||
self.model.embed_tokens = value
|
||||
|
||||
def get_output_embeddings(self):
|
||||
return self.lm_head
|
||||
|
||||
def set_output_embeddings(self, new_embeddings):
|
||||
self.lm_head = new_embeddings
|
||||
|
||||
def set_decoder(self, decoder):
|
||||
self.model = decoder
|
||||
|
||||
def get_decoder(self):
|
||||
return self.model
|
||||
|
||||
# Ignore copy
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
prefill: bool = False,
|
||||
**kwargs
|
||||
) -> Union[Tuple, TableRecModelOutput]:
|
||||
outputs = self.model(
|
||||
input_ids=input_ids,
|
||||
cache_position=cache_position,
|
||||
attention_mask=attention_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
use_cache=use_cache,
|
||||
output_hidden_states=True,
|
||||
return_dict=True,
|
||||
prefill=prefill,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
bbox_logits = self.bbox_head(hidden_states)
|
||||
class_logits = self.class_head(hidden_states)
|
||||
bsz, seq_len = class_logits.shape[:2]
|
||||
bbox_logits = bbox_logits.view(bsz, seq_len, 4, self.max_width)
|
||||
|
||||
return TableRecModelOutput(
|
||||
bbox_logits=bbox_logits,
|
||||
class_logits=class_logits,
|
||||
hidden_states=hidden_states,
|
||||
)
|
||||
@dataclass
|
||||
class TextEncoderOutput(CausalLMOutput):
|
||||
hidden_states: torch.FloatTensor = None
|
||||
|
||||
|
||||
class SuryaTableRecTextEncoder(SuryaTableRecDecoderPreTrainedModel):
|
||||
_tied_weights_keys = None
|
||||
config_class = SuryaTableRecTextEncoderConfig
|
||||
|
||||
def __init__(self, config, **kwargs):
|
||||
super().__init__(config)
|
||||
self.model = SuryaTableRecDecoderModel(config, embed_labels=False, embed_positions=True)
|
||||
self.vocab_size = config.vocab_size
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.model.embed_tokens
|
||||
|
||||
def set_input_embeddings(self, value):
|
||||
self.model.embed_tokens = value
|
||||
|
||||
def set_decoder(self, decoder):
|
||||
self.model = decoder
|
||||
|
||||
def get_decoder(self):
|
||||
return self.model
|
||||
|
||||
# Ignore copy
|
||||
def forward(
|
||||
self,
|
||||
input_boxes: Optional[torch.LongTensor] = None,
|
||||
input_boxes_counts: Optional[torch.LongTensor] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
**kwargs
|
||||
) -> Union[Tuple, CausalLMOutput]:
|
||||
outputs = self.model(
|
||||
input_ids=input_boxes,
|
||||
input_boxes_counts=input_boxes_counts,
|
||||
cache_position=cache_position,
|
||||
attention_mask=attention_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
use_cache=use_cache,
|
||||
output_hidden_states=True,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
return TextEncoderOutput(
|
||||
hidden_states=outputs.last_hidden_state,
|
||||
)
|
||||
135
surya/model/table_rec/encoderdecoder.py
Normal file
@ -0,0 +1,135 @@
|
||||
import random
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Union, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import CrossEntropyLoss
|
||||
from transformers import PreTrainedModel, VisionEncoderDecoderConfig, PretrainedConfig
|
||||
from transformers.modeling_outputs import Seq2SeqLMOutput, BaseModelOutput
|
||||
from transformers.models.vision_encoder_decoder.modeling_vision_encoder_decoder import shift_tokens_right
|
||||
from surya.model.table_rec.decoder import SuryaTableRecTextEncoder, SuryaTableRecDecoder
|
||||
from surya.model.recognition.encoder import DonutSwinModel
|
||||
import torch.nn.functional as F
|
||||
from transformers.utils import ModelOutput
|
||||
|
||||
|
||||
@dataclass
|
||||
class TableRecOutput(ModelOutput):
|
||||
row_logits: torch.FloatTensor = None
|
||||
col_logits: torch.FloatTensor = None
|
||||
decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
||||
|
||||
|
||||
class TableRecEncoderDecoderModel(PreTrainedModel):
|
||||
config_class = VisionEncoderDecoderConfig
|
||||
base_model_prefix = "vision_encoder_decoder"
|
||||
main_input_name = "pixel_values"
|
||||
supports_gradient_checkpointing = True
|
||||
_supports_param_buffer_assignment = False
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: Optional[PretrainedConfig] = None,
|
||||
encoder: Optional[PreTrainedModel] = None,
|
||||
text_encoder: Optional[PreTrainedModel] = None,
|
||||
decoder: Optional[PreTrainedModel] = None,
|
||||
):
|
||||
# initialize with config
|
||||
# make sure input & output embeddings is not tied
|
||||
config.tie_word_embeddings = False
|
||||
config.decoder.tie_word_embeddings = False
|
||||
super().__init__(config)
|
||||
|
||||
if encoder is None:
|
||||
encoder = DonutSwinModel(config.encoder)
|
||||
|
||||
if text_encoder is None:
|
||||
text_encoder = SuryaTableRecTextEncoder(config.text_encoder, attn_implementation=config._attn_implementation)
|
||||
|
||||
if decoder is None:
|
||||
decoder = SuryaTableRecDecoder(config.decoder, attn_implementation=config._attn_implementation)
|
||||
|
||||
self.encoder = encoder
|
||||
self.decoder = decoder
|
||||
self.text_encoder = text_encoder
|
||||
|
||||
# make sure that the individual model's config refers to the shared config
|
||||
# so that the updates to the config will be synced
|
||||
self.encoder.config = self.config.encoder
|
||||
self.decoder.config = self.config.decoder
|
||||
self.text_encoder.config = self.config.text_encoder
|
||||
|
||||
def get_encoder(self):
|
||||
return self.encoder
|
||||
|
||||
def get_decoder(self):
|
||||
return self.decoder
|
||||
|
||||
def get_output_embeddings(self):
|
||||
return self.decoder.get_output_embeddings()
|
||||
|
||||
def set_output_embeddings(self, new_embeddings):
|
||||
return self.decoder.set_output_embeddings(new_embeddings)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
decoder_input_ids: torch.LongTensor = None,
|
||||
decoder_cache_position: Optional[torch.LongTensor] = None,
|
||||
decoder_attention_mask: Optional[torch.LongTensor] = None,
|
||||
encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
**kwargs,
|
||||
) -> Union[Tuple[torch.FloatTensor], TableRecOutput]:
|
||||
kwargs_encoder = {argument: value for argument, value in kwargs.items() if not argument.startswith("decoder_")}
|
||||
|
||||
kwargs_decoder = {
|
||||
argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
|
||||
}
|
||||
|
||||
# Decode
|
||||
decoder_outputs = self.decoder(
|
||||
input_labels=decoder_input_ids,
|
||||
input_boxes_counts=None,
|
||||
cache_position=decoder_cache_position,
|
||||
attention_mask=decoder_attention_mask,
|
||||
encoder_hidden_states=encoder_outputs,
|
||||
encoder_attention_mask=None,
|
||||
use_cache=use_cache,
|
||||
**kwargs_decoder,
|
||||
)
|
||||
|
||||
return TableRecOutput(
|
||||
row_logits=decoder_outputs.row_logits,
|
||||
col_logits=decoder_outputs.col_logits,
|
||||
decoder_hidden_states=decoder_outputs.hidden_states,
|
||||
)
|
||||
|
||||
def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
|
||||
return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
|
||||
|
||||
def prepare_inputs_for_generation(
|
||||
self, input_ids, past_key_values=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs
|
||||
):
|
||||
decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids, past_key_values=past_key_values)
|
||||
decoder_attention_mask = decoder_inputs["attention_mask"] if "attention_mask" in decoder_inputs else None
|
||||
input_dict = {
|
||||
"attention_mask": attention_mask,
|
||||
"decoder_attention_mask": decoder_attention_mask,
|
||||
"decoder_input_ids": decoder_inputs["input_ids"],
|
||||
"encoder_outputs": encoder_outputs,
|
||||
"past_key_values": decoder_inputs["past_key_values"],
|
||||
"use_cache": use_cache,
|
||||
}
|
||||
return input_dict
|
||||
|
||||
def resize_token_embeddings(self, *args, **kwargs):
|
||||
raise NotImplementedError(
|
||||
"Resizing the embedding layers via the VisionEncoderDecoderModel directly is not supported.Please use the"
|
||||
" respective methods of the wrapped decoder object (model.decoder.resize_token_embeddings(...))"
|
||||
)
|
||||
|
||||
def _reorder_cache(self, past_key_values, beam_idx):
|
||||
# apply decoder cache reordering here
|
||||
return self.decoder._reorder_cache(past_key_values, beam_idx)
|
||||
34
surya/model/table_rec/model.py
Normal file
@ -0,0 +1,34 @@
|
||||
from surya.model.recognition.encoder import DonutSwinModel
|
||||
from surya.model.table_rec.config import SuryaTableRecConfig, SuryaTableRecDecoderConfig, DonutSwinTableRecConfig, \
|
||||
SuryaTableRecTextEncoderConfig
|
||||
from surya.model.table_rec.decoder import SuryaTableRecDecoder, SuryaTableRecTextEncoder
|
||||
from surya.model.table_rec.encoderdecoder import TableRecEncoderDecoderModel
|
||||
from surya.settings import settings
|
||||
|
||||
|
||||
def load_model(checkpoint=settings.TABLE_REC_MODEL_CHECKPOINT, device=settings.TORCH_DEVICE_MODEL, dtype=settings.MODEL_DTYPE):
|
||||
|
||||
config = SuryaTableRecConfig.from_pretrained(checkpoint)
|
||||
decoder_config = config.decoder
|
||||
decoder = SuryaTableRecDecoderConfig(**decoder_config)
|
||||
config.decoder = decoder
|
||||
|
||||
encoder_config = config.encoder
|
||||
encoder = DonutSwinTableRecConfig(**encoder_config)
|
||||
config.encoder = encoder
|
||||
|
||||
text_encoder_config = config.text_encoder
|
||||
text_encoder = SuryaTableRecTextEncoderConfig(**text_encoder_config)
|
||||
config.text_encoder = text_encoder
|
||||
|
||||
model = TableRecEncoderDecoderModel.from_pretrained(checkpoint, config=config, torch_dtype=dtype)
|
||||
|
||||
assert isinstance(model.decoder, SuryaTableRecDecoder)
|
||||
assert isinstance(model.encoder, DonutSwinModel)
|
||||
assert isinstance(model.text_encoder, SuryaTableRecTextEncoder)
|
||||
|
||||
model = model.to(device)
|
||||
model = model.eval()
|
||||
|
||||
print(f"Loaded recognition model {checkpoint} on device {device} with dtype {dtype}")
|
||||
return model
|
||||
248
surya/model/table_rec/processor.py
Normal file
@ -0,0 +1,248 @@
|
||||
import math
|
||||
from typing import Dict, Union, Optional, List, Iterable
|
||||
|
||||
import cv2
|
||||
import torch
|
||||
from torch import TensorType
|
||||
from transformers import DonutImageProcessor, DonutProcessor
|
||||
from transformers.image_processing_utils import BatchFeature
|
||||
from transformers.image_transforms import pad, normalize
|
||||
from transformers.image_utils import PILImageResampling, ImageInput, ChannelDimension, make_list_of_images, get_image_size
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
import PIL
|
||||
from surya.model.recognition.tokenizer import Byt5LangTokenizer
|
||||
from surya.settings import settings
|
||||
from surya.model.table_rec.config import BOX_DIM, SPECIAL_TOKENS
|
||||
|
||||
|
||||
def load_processor():
|
||||
processor = SuryaProcessor()
|
||||
processor.image_processor.train = False
|
||||
processor.image_processor.max_size = settings.TABLE_REC_IMAGE_SIZE
|
||||
|
||||
processor.token_pad_id = 0
|
||||
processor.token_eos_id = 1
|
||||
processor.token_bos_id = 2
|
||||
processor.token_row_id = 3
|
||||
processor.token_unused_id = 4
|
||||
processor.box_size = (BOX_DIM, BOX_DIM)
|
||||
processor.special_token_count = SPECIAL_TOKENS
|
||||
return processor
|
||||
|
||||
|
||||
class SuryaImageProcessor(DonutImageProcessor):
|
||||
def __init__(self, *args, max_size=None, train=False, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
self.patch_size = kwargs.get("patch_size", (4, 4))
|
||||
self.max_size = max_size
|
||||
self.train = train
|
||||
|
||||
@classmethod
|
||||
def numpy_resize(cls, image: np.ndarray, size, interpolation=cv2.INTER_LANCZOS4):
|
||||
max_width, max_height = size["width"], size["height"]
|
||||
|
||||
resized_image = cv2.resize(image, (max_width, max_height), interpolation=interpolation)
|
||||
resized_image = resized_image.transpose(2, 0, 1)
|
||||
|
||||
return resized_image
|
||||
|
||||
def process_inner(self, images: List[np.ndarray]):
|
||||
assert images[0].shape[2] == 3 # RGB input images, channel dim last
|
||||
|
||||
# This also applies the right channel dim format, to channel x height x width
|
||||
images = [SuryaImageProcessor.numpy_resize(img, self.max_size, self.resample) for img in images]
|
||||
assert images[0].shape[0] == 3 # RGB input images, channel dim first
|
||||
|
||||
# Convert to float32 for rescale/normalize
|
||||
images = [img.astype(np.float32) for img in images]
|
||||
|
||||
# Pads with 255 (whitespace)
|
||||
# Pad to max size to improve performance
|
||||
max_size = self.max_size
|
||||
images = [
|
||||
SuryaImageProcessor.pad_image(
|
||||
image=image,
|
||||
size=max_size,
|
||||
input_data_format=ChannelDimension.FIRST,
|
||||
pad_value=settings.RECOGNITION_PAD_VALUE
|
||||
)
|
||||
for image in images
|
||||
]
|
||||
# Rescale and normalize
|
||||
for idx in range(len(images)):
|
||||
images[idx] = images[idx] * self.rescale_factor
|
||||
images = [
|
||||
SuryaImageProcessor.normalize(img, mean=self.image_mean, std=self.image_std, input_data_format=ChannelDimension.FIRST)
|
||||
for img in images
|
||||
]
|
||||
|
||||
return images
|
||||
|
||||
def preprocess(
|
||||
self,
|
||||
images: ImageInput,
|
||||
do_resize: bool = None,
|
||||
size: Dict[str, int] = None,
|
||||
resample: PILImageResampling = None,
|
||||
do_thumbnail: bool = None,
|
||||
do_align_long_axis: bool = None,
|
||||
do_pad: bool = None,
|
||||
random_padding: bool = False,
|
||||
do_rescale: bool = None,
|
||||
rescale_factor: float = None,
|
||||
do_normalize: bool = None,
|
||||
image_mean: Optional[Union[float, List[float]]] = None,
|
||||
image_std: Optional[Union[float, List[float]]] = None,
|
||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||
data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
**kwargs,
|
||||
) -> PIL.Image.Image:
|
||||
images = make_list_of_images(images)
|
||||
|
||||
# Convert to numpy for later processing steps
|
||||
images = [np.array(img) for img in images]
|
||||
images = self.process_inner(images)
|
||||
|
||||
data = {"pixel_values": images}
|
||||
return BatchFeature(data=data, tensor_type=return_tensors)
|
||||
|
||||
@classmethod
|
||||
def pad_image(
|
||||
cls,
|
||||
image: np.ndarray,
|
||||
size: Dict[str, int],
|
||||
data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
pad_value: float = 0.0,
|
||||
) -> np.ndarray:
|
||||
output_height, output_width = size["height"], size["width"]
|
||||
input_height, input_width = get_image_size(image, channel_dim=input_data_format)
|
||||
|
||||
delta_width = output_width - input_width
|
||||
delta_height = output_height - input_height
|
||||
|
||||
assert delta_width >= 0 and delta_height >= 0
|
||||
|
||||
pad_top = delta_height // 2
|
||||
pad_left = delta_width // 2
|
||||
|
||||
pad_bottom = delta_height - pad_top
|
||||
pad_right = delta_width - pad_left
|
||||
|
||||
padding = ((pad_top, pad_bottom), (pad_left, pad_right))
|
||||
return pad(image, padding, data_format=data_format, input_data_format=input_data_format, constant_values=pad_value)
|
||||
|
||||
@classmethod
|
||||
def align_long_axis(
|
||||
cls,
|
||||
image: np.ndarray,
|
||||
size: Dict[str, int],
|
||||
data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
) -> np.ndarray:
|
||||
input_height, input_width = image.shape[:2]
|
||||
output_height, output_width = size["height"], size["width"]
|
||||
|
||||
if (output_width < output_height and input_width > input_height) or (
|
||||
output_width > output_height and input_width < input_height
|
||||
):
|
||||
image = np.rot90(image, 3)
|
||||
|
||||
return image
|
||||
|
||||
@classmethod
|
||||
def normalize(
|
||||
cls,
|
||||
image: np.ndarray,
|
||||
mean: Union[float, Iterable[float]],
|
||||
std: Union[float, Iterable[float]],
|
||||
data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
**kwargs,
|
||||
) -> np.ndarray:
|
||||
return normalize(
|
||||
image, mean=mean, std=std, data_format=data_format, input_data_format=input_data_format, **kwargs
|
||||
)
|
||||
|
||||
|
||||
class SuryaProcessor(DonutProcessor):
|
||||
def __init__(self, image_processor=None, tokenizer=None, train=False, **kwargs):
|
||||
image_processor = SuryaImageProcessor.from_pretrained(settings.RECOGNITION_MODEL_CHECKPOINT)
|
||||
if image_processor is None:
|
||||
raise ValueError("You need to specify an `image_processor`.")
|
||||
|
||||
tokenizer = Byt5LangTokenizer()
|
||||
super().__init__(image_processor, tokenizer)
|
||||
self.current_processor = self.image_processor
|
||||
self._in_target_context_manager = False
|
||||
self.max_input_boxes = kwargs.get("max_input_boxes", 256)
|
||||
self.extra_input_boxes = kwargs.get("extra_input_boxes", 32)
|
||||
|
||||
def resize_boxes(self, img, boxes):
|
||||
width, height = img.size
|
||||
box_width, box_height = self.box_size
|
||||
for box in boxes:
|
||||
# Rescale to 0-1024
|
||||
box[0] = box[0] / width * box_width
|
||||
box[1] = box[1] / height * box_height
|
||||
box[2] = box[2] / width * box_width
|
||||
box[3] = box[3] / height * box_height
|
||||
|
||||
if box[0] < 0:
|
||||
box[0] = 0
|
||||
if box[1] < 0:
|
||||
box[1] = 0
|
||||
if box[2] > box_width:
|
||||
box[2] = box_width
|
||||
if box[3] > box_height:
|
||||
box[3] = box_height
|
||||
|
||||
return boxes
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
images = kwargs.pop("images", [])
|
||||
boxes = kwargs.pop("boxes", [])
|
||||
assert len(images) == len(boxes)
|
||||
|
||||
if len(args) > 0:
|
||||
images = args[0]
|
||||
args = args[1:]
|
||||
|
||||
for i in range(len(boxes)):
|
||||
if len(boxes[i]) > self.max_input_boxes:
|
||||
downsample_ratio = math.ceil(len(boxes[i]) / self.max_input_boxes)
|
||||
boxes[i] = boxes[i][::downsample_ratio]
|
||||
|
||||
new_boxes = []
|
||||
max_len = self.max_input_boxes + self.extra_input_boxes
|
||||
box_masks = []
|
||||
box_ends = []
|
||||
for i in range(len(boxes)):
|
||||
nb = self.resize_boxes(images[i], boxes[i])
|
||||
nb = [[b + self.special_token_count for b in box] for box in nb] # shift up
|
||||
nb = nb[:self.max_input_boxes - 1]
|
||||
|
||||
nb.insert(0, [self.token_row_id] * 4) # Insert special token for max rows/cols
|
||||
for _ in range(self.extra_input_boxes):
|
||||
nb.append([self.token_unused_id] * 4)
|
||||
|
||||
pad_length = max_len - len(nb)
|
||||
box_mask = [1] * len(nb) + [1] * (pad_length)
|
||||
box_ends.append(len(nb))
|
||||
nb = nb + [[self.token_unused_id] * 4] * pad_length
|
||||
|
||||
new_boxes.append(nb)
|
||||
box_masks.append(box_mask)
|
||||
|
||||
box_ends = torch.tensor(box_ends, dtype=torch.long)
|
||||
box_starts = torch.tensor([0] * len(boxes), dtype=torch.long)
|
||||
box_ranges = torch.stack([box_starts, box_ends], dim=1)
|
||||
|
||||
inputs = self.image_processor(images, *args, **kwargs)
|
||||
inputs["input_boxes"] = torch.tensor(new_boxes, dtype=torch.long)
|
||||
inputs["input_boxes_mask"] = torch.tensor(box_masks, dtype=torch.long)
|
||||
inputs["input_boxes_counts"] = box_ranges
|
||||
return inputs
|
||||
13
surya/ocr.py
@ -60,17 +60,24 @@ def run_recognition(images: List[Image.Image], langs: List[List[str] | None], re
|
||||
return predictions_by_image
|
||||
|
||||
|
||||
def run_ocr(images: List[Image.Image], langs: List[List[str] | None], det_model, det_processor, rec_model, rec_processor, batch_size=None) -> List[OCRResult]:
|
||||
def run_ocr(images: List[Image.Image], langs: List[List[str] | None], det_model, det_processor, rec_model, rec_processor, batch_size=None, highres_images: List[Image.Image] | None = None) -> List[OCRResult]:
|
||||
images = convert_if_not_rgb(images)
|
||||
highres_images = convert_if_not_rgb(highres_images) if highres_images is not None else [None] * len(images)
|
||||
det_predictions = batch_text_detection(images, det_model, det_processor)
|
||||
|
||||
all_slices = []
|
||||
slice_map = []
|
||||
all_langs = []
|
||||
|
||||
for idx, (det_pred, image, lang) in enumerate(zip(det_predictions, images, langs)):
|
||||
for idx, (det_pred, image, highres_image, lang) in enumerate(zip(det_predictions, images, highres_images, langs)):
|
||||
polygons = [p.polygon for p in det_pred.bboxes]
|
||||
slices = slice_polys_from_image(image, polygons)
|
||||
if highres_image:
|
||||
width_scaler = highres_image.size[0] / image.size[0]
|
||||
height_scaler = highres_image.size[1] / image.size[1]
|
||||
scaled_polygons = [[[int(p[0] * width_scaler), int(p[1] * height_scaler)] for p in polygon] for polygon in polygons]
|
||||
slices = slice_polys_from_image(highres_image, scaled_polygons)
|
||||
else:
|
||||
slices = slice_polys_from_image(image, polygons)
|
||||
slice_map.append(len(slices))
|
||||
all_langs.extend([lang] * len(slices))
|
||||
all_slices.extend(slices)
|
||||
|
||||
@ -82,7 +82,6 @@ def detect_boxes(linemap, text_threshold, low_text):
|
||||
det = []
|
||||
confidences = []
|
||||
max_confidence = 0
|
||||
segmap = np.zeros_like(labels, dtype=np.uint8)
|
||||
|
||||
for k in range(1, label_count):
|
||||
# size filtering
|
||||
@ -90,39 +89,37 @@ def detect_boxes(linemap, text_threshold, low_text):
|
||||
if size < 10:
|
||||
continue
|
||||
|
||||
mask = labels == k
|
||||
selected_linemap = linemap[mask]
|
||||
|
||||
# thresholding
|
||||
if np.max(selected_linemap) < text_threshold:
|
||||
continue
|
||||
|
||||
# make segmentation map
|
||||
x, y, w, h = stats[k, [cv2.CC_STAT_LEFT, cv2.CC_STAT_TOP, cv2.CC_STAT_WIDTH, cv2.CC_STAT_HEIGHT]]
|
||||
|
||||
try:
|
||||
niter = int(np.sqrt(min(w, h)) * 2)
|
||||
niter = int(np.sqrt(min(w, h)))
|
||||
except ValueError:
|
||||
# Overflow in sqrt term
|
||||
niter = 0
|
||||
|
||||
buffer = 1
|
||||
sx, sy = max(0, x - niter), max(0, y - niter)
|
||||
sx, sy = max(0, x - niter - buffer), max(0, y - niter - buffer)
|
||||
ex, ey = min(img_w, x + w + niter + buffer), min(img_h, y + h + niter + buffer)
|
||||
|
||||
segmap.fill(0)
|
||||
segmap[mask] = 1
|
||||
mask = (labels[sy:ey, sx:ex] == k)
|
||||
selected_linemap = linemap[sy:ey, sx:ex][mask]
|
||||
line_max = np.max(selected_linemap)
|
||||
|
||||
# thresholding
|
||||
if line_max < text_threshold:
|
||||
continue
|
||||
|
||||
segmap = mask.astype(np.uint8)
|
||||
|
||||
ksize = buffer + niter
|
||||
kernel = cv2.getStructuringElement(cv2.MORPH_RECT,(ksize, ksize))
|
||||
|
||||
# Doesn't work well without the zero start (ie, you can't trim the map tightly around the detected region)
|
||||
selected_segmap = segmap[0:ey, 0:ex]
|
||||
selected_segmap[sy:ey, sx:ex] = cv2.dilate(segmap[sy:ey, sx:ex], kernel)
|
||||
selected_segmap = cv2.dilate(segmap, kernel)
|
||||
|
||||
# make box
|
||||
indices = np.nonzero(selected_segmap)
|
||||
np_contours = np.column_stack((indices[1], indices[0]))
|
||||
x_inds = indices[1] + sx
|
||||
y_inds = indices[0] + sy
|
||||
np_contours = np.column_stack((x_inds, y_inds))
|
||||
rectangle = cv2.minAreaRect(np_contours)
|
||||
box = cv2.boxPoints(rectangle)
|
||||
|
||||
@ -139,8 +136,8 @@ def detect_boxes(linemap, text_threshold, low_text):
|
||||
box = np.roll(box, 4-startidx, 0)
|
||||
box = np.array(box)
|
||||
|
||||
confidence = np.mean(selected_linemap[selected_linemap > low_text])
|
||||
max_confidence = max(max_confidence, confidence)
|
||||
confidence = line_max
|
||||
max_confidence = max(max_confidence, line_max)
|
||||
|
||||
confidences.append(confidence)
|
||||
det.append(box)
|
||||
@ -175,16 +172,23 @@ def get_and_clean_boxes(textmap, processor_size, image_size, text_threshold=None
|
||||
return bboxes
|
||||
|
||||
|
||||
def draw_bboxes_on_image(bboxes, image, labels=None):
|
||||
draw = ImageDraw.Draw(image)
|
||||
|
||||
for bbox in bboxes:
|
||||
draw.rectangle(bbox, outline="red", width=1)
|
||||
def draw_bboxes_on_image(bboxes, image, labels=None, label_font_size=10, color: str | list='red'):
|
||||
polys = []
|
||||
for bb in bboxes:
|
||||
# Clockwise polygon
|
||||
poly = [
|
||||
[bb[0], bb[1]],
|
||||
[bb[2], bb[1]],
|
||||
[bb[2], bb[3]],
|
||||
[bb[0], bb[3]]
|
||||
]
|
||||
polys.append(poly)
|
||||
|
||||
return image
|
||||
return draw_polys_on_image(polys, image, labels, label_font_size=label_font_size, color=color)
|
||||
|
||||
|
||||
def draw_polys_on_image(corners, image, labels=None, box_padding=-1, label_offset=1, label_font_size=10):
|
||||
def draw_polys_on_image(corners, image, labels=None, box_padding=-1, label_offset=1, label_font_size=10, color: str | list='red'):
|
||||
draw = ImageDraw.Draw(image)
|
||||
font_path = get_font_path()
|
||||
label_font = ImageFont.truetype(font_path, label_font_size)
|
||||
@ -192,7 +196,7 @@ def draw_polys_on_image(corners, image, labels=None, box_padding=-1, label_offse
|
||||
for i in range(len(corners)):
|
||||
poly = corners[i]
|
||||
poly = [(int(p[0]), int(p[1])) for p in poly]
|
||||
draw.polygon(poly, outline='red', width=1)
|
||||
draw.polygon(poly, outline=color[i] if isinstance(color, list) else color, width=1)
|
||||
|
||||
if labels is not None:
|
||||
label = labels[i]
|
||||
@ -211,7 +215,7 @@ def draw_polys_on_image(corners, image, labels=None, box_padding=-1, label_offse
|
||||
draw.text(
|
||||
text_position,
|
||||
label,
|
||||
fill="red",
|
||||
fill=color[i] if isinstance(color, list) else color,
|
||||
font=label_font
|
||||
)
|
||||
|
||||
|
||||
@ -10,12 +10,12 @@ from surya.settings import settings
|
||||
from surya.postprocessing.math.latex import is_latex
|
||||
|
||||
|
||||
def sort_text_lines(lines: List[TextLine], tolerance=1.25):
|
||||
def sort_text_lines(lines: List[TextLine] | List[dict], tolerance=1.25):
|
||||
# Sorts in reading order. Not 100% accurate, this should only
|
||||
# be used as a starting point for more advanced sorting.
|
||||
vertical_groups = {}
|
||||
for line in lines:
|
||||
group_key = round(line.bbox[1] / tolerance) * tolerance
|
||||
group_key = round(line.bbox[1] if isinstance(line, TextLine) else line["bbox"][1] / tolerance) * tolerance
|
||||
if group_key not in vertical_groups:
|
||||
vertical_groups[group_key] = []
|
||||
vertical_groups[group_key].append(line)
|
||||
@ -23,7 +23,7 @@ def sort_text_lines(lines: List[TextLine], tolerance=1.25):
|
||||
# Sort each group horizontally and flatten the groups into a single list
|
||||
sorted_lines = []
|
||||
for _, group in sorted(vertical_groups.items()):
|
||||
sorted_group = sorted(group, key=lambda x: x.bbox[0])
|
||||
sorted_group = sorted(group, key=lambda x: x.bbox[0] if isinstance(x, TextLine) else x["bbox"][0])
|
||||
sorted_lines.extend(sorted_group)
|
||||
|
||||
return sorted_lines
|
||||
|
||||
@ -26,6 +26,10 @@ def rescale_bbox(bbox, processor_size, image_size):
|
||||
return new_bbox
|
||||
|
||||
|
||||
def rescale_bboxes(bboxes, orig_size, new_size):
|
||||
return [rescale_bbox(bbox, orig_size, new_size) for bbox in bboxes]
|
||||
|
||||
|
||||
def rescale_point(point, processor_size, image_size):
|
||||
# Point is in x, y format
|
||||
page_width, page_height = processor_size
|
||||
|
||||
@ -71,19 +71,23 @@ class PolygonBox(BaseModel):
|
||||
y2 = max(self.bbox[3], other.bbox[3])
|
||||
self.polygon = [[x1, y1], [x2, y1], [x2, y2], [x1, y2]]
|
||||
|
||||
def intersection_area(self, other, margin=0):
|
||||
x_overlap = max(0, min(self.bbox[2], other.bbox[2] - margin) - max(self.bbox[0], other.bbox[0] + margin))
|
||||
y_overlap = max(0, min(self.bbox[3], other.bbox[3] - margin) - max(self.bbox[1], other.bbox[1] + margin))
|
||||
def intersection_area(self, other, x_margin=0, y_margin=0):
|
||||
x_overlap = max(0, min(self.bbox[2] + x_margin, other.bbox[2] + x_margin) - max(self.bbox[0] - x_margin, other.bbox[0] - x_margin))
|
||||
y_overlap = max(0, min(self.bbox[3] + y_margin, other.bbox[3] + y_margin) - max(self.bbox[1] - y_margin, other.bbox[1] - y_margin))
|
||||
return x_overlap * y_overlap
|
||||
|
||||
def intersection_pct(self, other, margin=0):
|
||||
assert 0 <= margin <= 1
|
||||
def intersection_pct(self, other, x_margin=0, y_margin=0):
|
||||
assert 0 <= x_margin <= 1
|
||||
assert 0 <= y_margin <= 1
|
||||
if self.area == 0:
|
||||
return 0
|
||||
|
||||
if margin:
|
||||
margin = int(min(self.width, other.width) * margin)
|
||||
intersection = self.intersection_area(other, margin)
|
||||
if x_margin:
|
||||
x_margin = int(min(self.width, other.width) * x_margin)
|
||||
if y_margin:
|
||||
y_margin = int(min(self.height, other.height) * y_margin)
|
||||
|
||||
intersection = self.intersection_area(other, x_margin, y_margin)
|
||||
return intersection / self.area
|
||||
|
||||
|
||||
@ -119,6 +123,18 @@ class Bbox(BaseModel):
|
||||
def polygon(self):
|
||||
return [[self.bbox[0], self.bbox[1]], [self.bbox[2], self.bbox[1]], [self.bbox[2], self.bbox[3]], [self.bbox[0], self.bbox[3]]]
|
||||
|
||||
@property
|
||||
def center(self):
|
||||
return [(self.bbox[0] + self.bbox[2]) / 2, (self.bbox[1] + self.bbox[3]) / 2]
|
||||
|
||||
def intersection_pct(self, other):
|
||||
if self.area == 0:
|
||||
return 0
|
||||
|
||||
x_overlap = max(0, min(self.bbox[2], other.bbox[2]) - max(self.bbox[0], other.bbox[0]))
|
||||
y_overlap = max(0, min(self.bbox[3], other.bbox[3]) - max(self.bbox[1], other.bbox[1]))
|
||||
intersection = x_overlap * y_overlap
|
||||
return intersection / self.area
|
||||
|
||||
class LayoutBox(PolygonBox):
|
||||
label: str
|
||||
@ -161,3 +177,16 @@ class LayoutResult(BaseModel):
|
||||
class OrderResult(BaseModel):
|
||||
bboxes: List[OrderBox]
|
||||
image_bbox: List[float]
|
||||
|
||||
|
||||
class TableCell(Bbox):
|
||||
row_id: int | None = None
|
||||
col_id: int | None = None
|
||||
text: str | None = None
|
||||
|
||||
|
||||
class TableResult(BaseModel):
|
||||
cells: List[TableCell]
|
||||
rows: List[TableCell]
|
||||
cols: List[TableCell]
|
||||
image_bbox: List[float]
|
||||
|
||||
@ -10,7 +10,8 @@ import os
|
||||
class Settings(BaseSettings):
|
||||
# General
|
||||
TORCH_DEVICE: Optional[str] = None
|
||||
IMAGE_DPI: int = 192
|
||||
IMAGE_DPI: int = 96 # Used for detection, layout, reading order
|
||||
IMAGE_DPI_HIGHRES: int = 192 # Used for OCR, table rec
|
||||
IN_STREAMLIT: bool = False # Whether we're running in streamlit
|
||||
ENABLE_EFFICIENT_ATTENTION: bool = True # Usually keep True, but if you get CUDA errors, setting to False can help
|
||||
|
||||
@ -71,6 +72,14 @@ class Settings(BaseSettings):
|
||||
ORDER_BATCH_SIZE: Optional[int] = None # Defaults to 4 for CPU/MPS, 32 otherwise
|
||||
ORDER_BENCH_DATASET_NAME: str = "vikp/order_bench"
|
||||
|
||||
# Table Rec
|
||||
TABLE_REC_MODEL_CHECKPOINT: str = "vikp/surya_tablerec"
|
||||
TABLE_REC_IMAGE_SIZE: Dict = {"height": 640, "width": 640}
|
||||
TABLE_REC_MAX_BOXES: int = 512
|
||||
TABLE_REC_MAX_ROWS: int = 384
|
||||
TABLE_REC_BATCH_SIZE: Optional[int] = None
|
||||
TABLE_REC_BENCH_DATASET_NAME: str = "vikp/fintabnet_bench"
|
||||
|
||||
# Tesseract (for benchmarks only)
|
||||
TESSDATA_PREFIX: Optional[str] = None
|
||||
|
||||
|
||||
259
surya/tables.py
Normal file
@ -0,0 +1,259 @@
|
||||
from collections import defaultdict
|
||||
from copy import deepcopy
|
||||
from typing import List, Dict
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from surya.model.ordering.encoderdecoder import OrderVisionEncoderDecoderModel
|
||||
from surya.schema import TableResult, TableCell, Bbox
|
||||
from surya.settings import settings
|
||||
from tqdm import tqdm
|
||||
import numpy as np
|
||||
from surya.model.table_rec.config import SPECIAL_TOKENS
|
||||
|
||||
|
||||
def get_batch_size():
|
||||
batch_size = settings.TABLE_REC_BATCH_SIZE
|
||||
if batch_size is None:
|
||||
batch_size = 8
|
||||
if settings.TORCH_DEVICE_MODEL == "mps":
|
||||
batch_size = 8
|
||||
if settings.TORCH_DEVICE_MODEL == "cuda":
|
||||
batch_size = 64
|
||||
return batch_size
|
||||
|
||||
|
||||
def sort_bboxes(bboxes, tolerance=1):
|
||||
vertical_groups = {}
|
||||
for block in bboxes:
|
||||
group_key = round(block["bbox"][1] / tolerance) * tolerance
|
||||
if group_key not in vertical_groups:
|
||||
vertical_groups[group_key] = []
|
||||
vertical_groups[group_key].append(block)
|
||||
|
||||
# Sort each group horizontally and flatten the groups into a single list
|
||||
sorted_page_blocks = []
|
||||
for _, group in sorted(vertical_groups.items()):
|
||||
sorted_group = sorted(group, key=lambda x: x["bbox"][0])
|
||||
sorted_page_blocks.extend(sorted_group)
|
||||
|
||||
return sorted_page_blocks
|
||||
|
||||
|
||||
def is_rotated(rows, cols):
|
||||
# Determine if the table is rotated by looking at row and column width / height ratios
|
||||
# Rows should have a >1 ratio, cols <1
|
||||
widths = sum([r.width for r in rows])
|
||||
heights = sum([c.height for c in rows]) + 1
|
||||
r_ratio = widths / heights
|
||||
|
||||
widths = sum([c.width for c in cols])
|
||||
heights = sum([r.height for r in cols]) + 1
|
||||
c_ratio = widths / heights
|
||||
|
||||
return r_ratio * 2 < c_ratio
|
||||
|
||||
def batch_table_recognition(images: List, table_cells: List[List[Dict]], model: OrderVisionEncoderDecoderModel, processor, batch_size=None) -> List[TableResult]:
|
||||
assert all([isinstance(image, Image.Image) for image in images])
|
||||
assert len(images) == len(table_cells)
|
||||
if batch_size is None:
|
||||
batch_size = get_batch_size()
|
||||
|
||||
output_order = []
|
||||
for i in tqdm(range(0, len(images), batch_size), desc="Recognizing tables"):
|
||||
batch_table_cells = deepcopy(table_cells[i:i+batch_size])
|
||||
batch_table_cells = [sort_bboxes(page_bboxes) for page_bboxes in batch_table_cells] # Sort bboxes before passing in
|
||||
batch_list_bboxes = [[block["bbox"] for block in page] for page in batch_table_cells]
|
||||
|
||||
batch_images = images[i:i+batch_size]
|
||||
batch_images = [image.convert("RGB") for image in batch_images] # also copies the images
|
||||
current_batch_size = len(batch_images)
|
||||
|
||||
orig_sizes = [image.size for image in batch_images]
|
||||
model_inputs = processor(images=batch_images, boxes=deepcopy(batch_list_bboxes))
|
||||
|
||||
batch_pixel_values = model_inputs["pixel_values"]
|
||||
batch_bboxes = model_inputs["input_boxes"]
|
||||
batch_bbox_mask = model_inputs["input_boxes_mask"]
|
||||
batch_bbox_counts = model_inputs["input_boxes_counts"]
|
||||
|
||||
batch_bboxes = torch.from_numpy(np.array(batch_bboxes, dtype=np.int32)).to(model.device)
|
||||
batch_bbox_mask = torch.from_numpy(np.array(batch_bbox_mask, dtype=np.int32)).to(model.device)
|
||||
batch_pixel_values = torch.tensor(np.array(batch_pixel_values), dtype=model.dtype).to(model.device)
|
||||
batch_bbox_counts = torch.tensor(np.array(batch_bbox_counts), dtype=torch.long).to(model.device)
|
||||
|
||||
# Setup inputs for the decoder
|
||||
batch_decoder_input = [[[model.config.decoder.bos_token_id] * 5] for _ in range(current_batch_size)]
|
||||
batch_decoder_input = torch.tensor(np.stack(batch_decoder_input, axis=0), dtype=torch.long, device=model.device)
|
||||
inference_token_count = batch_decoder_input.shape[1]
|
||||
|
||||
max_tokens = min(batch_bbox_counts[:, 1].max().item(), settings.TABLE_REC_MAX_BOXES)
|
||||
decoder_position_ids = torch.ones_like(batch_decoder_input[0, :, 0], dtype=torch.int64, device=model.device).cumsum(0) - 1
|
||||
model.decoder.model._setup_cache(model.config, batch_size, model.device, model.dtype)
|
||||
model.text_encoder.model._setup_cache(model.config, batch_size, model.device, model.dtype)
|
||||
|
||||
batch_predictions = [[] for _ in range(current_batch_size)]
|
||||
|
||||
with torch.inference_mode():
|
||||
encoder_hidden_states = model.encoder(pixel_values=batch_pixel_values).last_hidden_state
|
||||
text_encoder_hidden_states = model.text_encoder(
|
||||
input_boxes=batch_bboxes,
|
||||
input_boxes_counts=batch_bbox_counts,
|
||||
cache_position=None,
|
||||
attention_mask=batch_bbox_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=None,
|
||||
use_cache=False
|
||||
).hidden_states
|
||||
|
||||
token_count = 0
|
||||
all_done = torch.zeros(current_batch_size, dtype=torch.bool, device=model.device)
|
||||
|
||||
while token_count < max_tokens:
|
||||
is_prefill = token_count == 0
|
||||
return_dict = model.decoder(
|
||||
input_ids=batch_decoder_input,
|
||||
encoder_hidden_states=text_encoder_hidden_states,
|
||||
cache_position=decoder_position_ids,
|
||||
use_cache=True,
|
||||
prefill=is_prefill
|
||||
)
|
||||
|
||||
decoder_position_ids = decoder_position_ids[-1:] + 1
|
||||
box_logits = return_dict["bbox_logits"][:, -1, :].detach()
|
||||
rowcol_logits = return_dict["class_logits"][:, -1, :].detach()
|
||||
|
||||
rowcol_preds = torch.argmax(rowcol_logits, dim=-1)
|
||||
box_preds = torch.argmax(box_logits, dim=-1)
|
||||
|
||||
done = (rowcol_preds == processor.tokenizer.eos_id) | (rowcol_preds == processor.tokenizer.pad_id)
|
||||
done = done
|
||||
all_done = all_done | done
|
||||
|
||||
if all_done.all():
|
||||
break
|
||||
|
||||
batch_decoder_input = torch.cat([box_preds.unsqueeze(1), rowcol_preds.unsqueeze(1).unsqueeze(1)], dim=-1)
|
||||
|
||||
for j, (pred, status) in enumerate(zip(batch_decoder_input, all_done)):
|
||||
if not status:
|
||||
batch_predictions[j].append(pred[0].tolist())
|
||||
|
||||
token_count += inference_token_count
|
||||
inference_token_count = batch_decoder_input.shape[1]
|
||||
|
||||
for j, (preds, input_cells, orig_size) in enumerate(zip(batch_predictions, batch_table_cells, orig_sizes)):
|
||||
img_w, img_h = orig_size
|
||||
width_scaler = img_w / model.config.decoder.out_box_size
|
||||
height_scaler = img_h / model.config.decoder.out_box_size
|
||||
|
||||
# cx, cy to corners
|
||||
for i, pred in enumerate(preds):
|
||||
w = pred[2] / 2
|
||||
h = pred[3] / 2
|
||||
x1 = pred[0] - w
|
||||
y1 = pred[1] - h
|
||||
x2 = pred[0] + w
|
||||
y2 = pred[1] + h
|
||||
class_ = int(pred[4] - SPECIAL_TOKENS)
|
||||
|
||||
preds[i] = [x1 * width_scaler, y1 * height_scaler, x2 * width_scaler, y2 * height_scaler, class_]
|
||||
|
||||
# Get rows and columns
|
||||
bb_rows = [p[:4] for p in preds if p[4] == 0]
|
||||
bb_cols = [p[:4] for p in preds if p[4] == 1]
|
||||
|
||||
rows = []
|
||||
cols = []
|
||||
for row_idx, row in enumerate(bb_rows):
|
||||
cell = TableCell(
|
||||
bbox=row,
|
||||
row_id=row_idx
|
||||
)
|
||||
rows.append(cell)
|
||||
|
||||
for col_idx, col in enumerate(bb_cols):
|
||||
cell = TableCell(
|
||||
bbox=col,
|
||||
col_id=col_idx,
|
||||
)
|
||||
cols.append(cell)
|
||||
|
||||
# Assign cells to rows/columns
|
||||
cells = []
|
||||
for cell in input_cells:
|
||||
max_intersection = 0
|
||||
row_pred = None
|
||||
for row_idx, row in enumerate(rows):
|
||||
intersection_pct = Bbox(bbox=cell["bbox"]).intersection_pct(row)
|
||||
if intersection_pct > max_intersection:
|
||||
max_intersection = intersection_pct
|
||||
row_pred = row_idx
|
||||
|
||||
max_intersection = 0
|
||||
col_pred = None
|
||||
for col_idx, col in enumerate(cols):
|
||||
intersection_pct = Bbox(bbox=cell["bbox"]).intersection_pct(col)
|
||||
if intersection_pct > max_intersection:
|
||||
max_intersection = intersection_pct
|
||||
col_pred = col_idx
|
||||
|
||||
cells.append(
|
||||
TableCell(
|
||||
bbox=cell["bbox"],
|
||||
text=cell.get("text"),
|
||||
row_id=row_pred,
|
||||
col_id=col_pred
|
||||
)
|
||||
)
|
||||
|
||||
rotated = is_rotated(rows, cols)
|
||||
for cell in cells:
|
||||
if cell.row_id is None:
|
||||
closest_row = None
|
||||
closest_row_dist = None
|
||||
for cell2 in cells:
|
||||
if cell2.row_id is None:
|
||||
continue
|
||||
if rotated:
|
||||
cell_y_center = cell.center[0]
|
||||
cell2_y_center = cell2.center[0]
|
||||
else:
|
||||
cell_y_center = cell.center[1]
|
||||
cell2_y_center = cell2.center[1]
|
||||
y_dist = abs(cell_y_center - cell2_y_center)
|
||||
if closest_row_dist is None or y_dist < closest_row_dist:
|
||||
closest_row = cell2.row_id
|
||||
closest_row_dist = y_dist
|
||||
cell.row_id = closest_row
|
||||
|
||||
if cell.col_id is None:
|
||||
closest_col = None
|
||||
closest_col_dist = None
|
||||
for cell2 in cells:
|
||||
if cell2.col_id is None:
|
||||
continue
|
||||
if rotated:
|
||||
cell_x_center = cell.center[1]
|
||||
cell2_x_center = cell2.center[1]
|
||||
else:
|
||||
cell_x_center = cell.center[0]
|
||||
cell2_x_center = cell2.center[0]
|
||||
|
||||
x_dist = abs(cell2_x_center - cell_x_center)
|
||||
if closest_col_dist is None or x_dist < closest_col_dist:
|
||||
closest_col = cell2.col_id
|
||||
closest_col_dist = x_dist
|
||||
|
||||
cell.col_id = closest_col
|
||||
|
||||
result = TableResult(
|
||||
cells=cells,
|
||||
rows=rows,
|
||||
cols=cols,
|
||||
image_bbox=[0, 0, img_w, img_h],
|
||||
)
|
||||
|
||||
output_order.append(result)
|
||||
|
||||
return output_order
|
||||
146
table_recognition.py
Normal file
@ -0,0 +1,146 @@
|
||||
import pypdfium2 as pdfium # Needs to be on top to avoid warning
|
||||
import os
|
||||
import argparse
|
||||
import copy
|
||||
import json
|
||||
from collections import defaultdict
|
||||
|
||||
from surya.detection import batch_text_detection
|
||||
from surya.input.load import load_from_folder, load_from_file
|
||||
from surya.input.pdflines import get_table_blocks
|
||||
from surya.layout import batch_layout_detection
|
||||
from surya.model.detection.model import load_model as load_det_model, load_processor as load_det_processor
|
||||
from surya.model.table_rec.model import load_model as load_model
|
||||
from surya.model.table_rec.processor import load_processor
|
||||
from surya.tables import batch_table_recognition
|
||||
from surya.postprocessing.heatmap import draw_bboxes_on_image
|
||||
from surya.settings import settings
|
||||
from surya.postprocessing.util import rescale_bboxes, rescale_bbox
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Find reading order of an input file or folder (PDFs or image).")
|
||||
parser.add_argument("input_path", type=str, help="Path to pdf or image file or folder to find reading order in.")
|
||||
parser.add_argument("--results_dir", type=str, help="Path to JSON file with layout results.", default=os.path.join(settings.RESULT_DIR, "surya"))
|
||||
parser.add_argument("--max", type=int, help="Maximum number of pages to process.", default=None)
|
||||
parser.add_argument("--images", action="store_true", help="Save images of detected layout bboxes.", default=False)
|
||||
parser.add_argument("--detect_boxes", action="store_true", help="Detect table boxes.", default=False)
|
||||
parser.add_argument("--skip_table_detection", action="store_true", help="Tables are already cropped, so don't re-detect tables.", default=False)
|
||||
args = parser.parse_args()
|
||||
|
||||
model = load_model()
|
||||
processor = load_processor()
|
||||
|
||||
layout_model = load_det_model(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT)
|
||||
layout_processor = load_det_processor(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT)
|
||||
|
||||
det_model = load_det_model()
|
||||
det_processor = load_det_processor()
|
||||
|
||||
if os.path.isdir(args.input_path):
|
||||
images, _, _ = load_from_folder(args.input_path, args.max)
|
||||
highres_images, names, text_lines = load_from_folder(args.input_path, args.max, dpi=settings.IMAGE_DPI_HIGHRES, load_text_lines=True)
|
||||
folder_name = os.path.basename(args.input_path)
|
||||
else:
|
||||
images, _, _ = load_from_file(args.input_path, args.max)
|
||||
highres_images, names, text_lines = load_from_file(args.input_path, args.max, dpi=settings.IMAGE_DPI_HIGHRES, load_text_lines=True)
|
||||
folder_name = os.path.basename(args.input_path).split(".")[0]
|
||||
|
||||
pnums = []
|
||||
prev_name = None
|
||||
for i, name in enumerate(names):
|
||||
if prev_name is None or prev_name != name:
|
||||
pnums.append(0)
|
||||
else:
|
||||
pnums.append(pnums[-1] + 1)
|
||||
|
||||
prev_name = name
|
||||
|
||||
line_predictions = batch_text_detection(images, det_model, det_processor)
|
||||
layout_predictions = batch_layout_detection(images, layout_model, layout_processor, line_predictions)
|
||||
table_cells = []
|
||||
|
||||
table_imgs = []
|
||||
table_counts = []
|
||||
|
||||
for layout_pred, text_line, img, highres_img in zip(layout_predictions, text_lines, images, highres_images):
|
||||
# The table may already be cropped
|
||||
if args.skip_table_detection:
|
||||
table_imgs.append(highres_img)
|
||||
table_counts.append(1)
|
||||
page_table_imgs = [highres_img]
|
||||
highres_bbox = [[0, 0, highres_img.size[0], highres_img.size[1]]]
|
||||
else:
|
||||
# The bbox for the entire table
|
||||
bbox = [l.bbox for l in layout_pred.bboxes if l.label == "Table"]
|
||||
# Number of tables per page
|
||||
table_counts.append(len(bbox))
|
||||
|
||||
if len(bbox) == 0:
|
||||
continue
|
||||
|
||||
page_table_imgs = []
|
||||
highres_bbox = []
|
||||
for bb in bbox:
|
||||
highres_bb = rescale_bbox(bb, img.size, highres_img.size)
|
||||
page_table_imgs.append(highres_img.crop(highres_bb))
|
||||
highres_bbox.append(highres_bb)
|
||||
|
||||
table_imgs.extend(page_table_imgs)
|
||||
|
||||
# The text cells inside each table
|
||||
table_blocks = get_table_blocks(highres_bbox, text_line, highres_img.size) if text_line is not None else None
|
||||
if text_line is None or args.detect_boxes or any(len(tb) == 0 for tb in table_blocks):
|
||||
det_results = batch_text_detection(page_table_imgs, det_model, det_processor,)
|
||||
cell_bboxes = [[{"bbox": tb.bbox, "text": None} for tb in det_result.bboxes] for det_result in det_results]
|
||||
table_cells.extend(cell_bboxes)
|
||||
else:
|
||||
table_cells.extend(table_blocks)
|
||||
|
||||
table_preds = batch_table_recognition(table_imgs, table_cells, model, processor)
|
||||
result_path = os.path.join(args.results_dir, folder_name)
|
||||
os.makedirs(result_path, exist_ok=True)
|
||||
|
||||
img_idx = 0
|
||||
prev_count = 0
|
||||
table_predictions = defaultdict(list)
|
||||
for i in range(sum(table_counts)):
|
||||
while i >= prev_count + table_counts[img_idx]:
|
||||
prev_count += table_counts[img_idx]
|
||||
img_idx += 1
|
||||
|
||||
pred = table_preds[i]
|
||||
orig_name = names[img_idx]
|
||||
pnum = pnums[img_idx]
|
||||
table_img = table_imgs[i]
|
||||
|
||||
out_pred = pred.model_dump()
|
||||
out_pred["page"] = pnum + 1
|
||||
table_idx = i - prev_count
|
||||
out_pred["table_idx"] = table_idx
|
||||
table_predictions[orig_name].append(out_pred)
|
||||
|
||||
if args.images:
|
||||
boxes = [l.bbox for l in pred.cells]
|
||||
labels = [f"{l.row_id}/{l.col_id}" for l in pred.cells]
|
||||
bbox_image = draw_bboxes_on_image(boxes, copy.deepcopy(table_img), labels=labels, label_font_size=20)
|
||||
bbox_image.save(os.path.join(result_path, f"{name}_page{pnum + 1}_table{table_idx}_cells.png"))
|
||||
|
||||
rows = [l.bbox for l in pred.rows]
|
||||
cols = [l.bbox for l in pred.cols]
|
||||
row_labels = [f"Row {l.row_id}" for l in pred.rows]
|
||||
col_labels = [f"Col {l.col_id}" for l in pred.cols]
|
||||
|
||||
rc_image = copy.deepcopy(table_img)
|
||||
rc_image = draw_bboxes_on_image(rows, rc_image, labels=row_labels, label_font_size=20, color="blue")
|
||||
rc_image = draw_bboxes_on_image(cols, rc_image, labels=col_labels, label_font_size=20, color="red")
|
||||
rc_image.save(os.path.join(result_path, f"{name}_page{pnum + 1}_table{table_idx}_rc.png"))
|
||||
|
||||
with open(os.path.join(result_path, "results.json"), "w+", encoding="utf-8") as f:
|
||||
json.dump(table_predictions, f, ensure_ascii=False)
|
||||
|
||||
print(f"Wrote results to {result_path}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||