Merge pull request #196 from VikParuchuri/dev

Table recognition, better layout
This commit is contained in:
Vik Paruchuri 2024-10-08 09:03:02 -07:00 committed by GitHub
commit a87dede439
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
39 changed files with 2789 additions and 151 deletions

View File

@ -36,10 +36,11 @@ jobs:
run: |
poetry run python benchmark/layout.py --max 5
poetry run python scripts/verify_benchmark_scores.py results/benchmark/layout_bench/results.json --bench_type layout
- name: Run ordering benchmark text
- name: Run ordering benchmark
run: |
poetry run python benchmark/ordering.py --max 5
poetry run python scripts/verify_benchmark_scores.py results/benchmark/order_bench/results.json --bench_type ordering
- name: Run table recognition benchmark
run: |
poetry run python benchmark/table_recognition.py --max 5
poetry run python scripts/verify_benchmark_scores.py results/benchmark/table_rec_bench/results.json --bench_type table_recognition

108
README.md
View File

@ -6,16 +6,23 @@ Surya is a document OCR toolkit that does:
- Line-level text detection in any language
- Layout analysis (table, image, header, etc detection)
- Reading order detection
- Table recognition (detecting rows/columns)
It works on a range of documents (see [usage](#usage) and [benchmarks](#benchmarks) for more details).
| Detection | OCR |
|:----------------------------------------------------------------:|:-----------------------------------------------------------------------:|
| ![New York Times Article Detection](static/images/excerpt.png) | ![New York Times Article Recognition](static/images/excerpt_text.png) |
| <img src="static/images/excerpt.png" width="500px"/> | <img src="static/images/excerpt_text.png" width="500px"/> |
| Layout | Reading Order |
|:------------------------------------------------------------------:|:--------------------------------------------------------------------------:|
| ![New York Times Article Layout](static/images/excerpt_layout.png) | ![New York Times Article Reading Order](static/images/excerpt_reading.jpg) |
| <img src="static/images/excerpt_layout.png" width="500px"/> | <img src="static/images/excerpt_reading.jpg" width="500px"/> |
| Table Recognition | |
|:-----------------------------------------:|:----------------:|
| <img src="static/images/table_rec.png" width="500px"/> | <img width="500px"/> |
Surya is named for the [Hindu sun god](https://en.wikipedia.org/wiki/Surya), who has universal vision.
@ -25,19 +32,19 @@ Surya is named for the [Hindu sun god](https://en.wikipedia.org/wiki/Surya), who
## Examples
| Name | Detection | OCR | Layout | Order |
|------------------|:-----------------------------------:|-----------------------------------------:|-------------------------------------------:|--------------------------------------------:|
| Japanese | [Image](static/images/japanese.jpg) | [Image](static/images/japanese_text.jpg) | [Image](static/images/japanese_layout.jpg) | [Image](static/images/japanese_reading.jpg) |
| Chinese | [Image](static/images/chinese.jpg) | [Image](static/images/chinese_text.jpg) | [Image](static/images/chinese_layout.jpg) | [Image](static/images/chinese_reading.jpg) |
| Hindi | [Image](static/images/hindi.jpg) | [Image](static/images/hindi_text.jpg) | [Image](static/images/hindi_layout.jpg) | [Image](static/images/hindi_reading.jpg) |
| Arabic | [Image](static/images/arabic.jpg) | [Image](static/images/arabic_text.jpg) | [Image](static/images/arabic_layout.jpg) | [Image](static/images/arabic_reading.jpg) |
| Chinese + Hindi | [Image](static/images/chi_hind.jpg) | [Image](static/images/chi_hind_text.jpg) | [Image](static/images/chi_hind_layout.jpg) | [Image](static/images/chi_hind_reading.jpg) |
| Presentation | [Image](static/images/pres.png) | [Image](static/images/pres_text.jpg) | [Image](static/images/pres_layout.jpg) | [Image](static/images/pres_reading.jpg) |
| Scientific Paper | [Image](static/images/paper.jpg) | [Image](static/images/paper_text.jpg) | [Image](static/images/paper_layout.jpg) | [Image](static/images/paper_reading.jpg) |
| Scanned Document | [Image](static/images/scanned.png) | [Image](static/images/scanned_text.jpg) | [Image](static/images/scanned_layout.jpg) | [Image](static/images/scanned_reading.jpg) |
| New York Times | [Image](static/images/nyt.jpg) | [Image](static/images/nyt_text.jpg) | [Image](static/images/nyt_layout.jpg) | [Image](static/images/nyt_order.jpg) |
| Scanned Form | [Image](static/images/funsd.png) | [Image](static/images/funsd_text.jpg) | [Image](static/images/funsd_layout.jpg) | [Image](static/images/funsd_reading.jpg) |
| Textbook | [Image](static/images/textbook.jpg) | [Image](static/images/textbook_text.jpg) | [Image](static/images/textbook_layout.jpg) | [Image](static/images/textbook_order.jpg) |
| Name | Detection | OCR | Layout | Order | Table Rec |
|------------------|:-----------------------------------:|-----------------------------------------:|-------------------------------------------:|--------------------------------------------:|---------------------------------------------:|
| Japanese | [Image](static/images/japanese.jpg) | [Image](static/images/japanese_text.jpg) | [Image](static/images/japanese_layout.jpg) | [Image](static/images/japanese_reading.jpg) | [Image](static/images/japanese_tablerec.png) |
| Chinese | [Image](static/images/chinese.jpg) | [Image](static/images/chinese_text.jpg) | [Image](static/images/chinese_layout.jpg) | [Image](static/images/chinese_reading.jpg) | |
| Hindi | [Image](static/images/hindi.jpg) | [Image](static/images/hindi_text.jpg) | [Image](static/images/hindi_layout.jpg) | [Image](static/images/hindi_reading.jpg) | |
| Arabic | [Image](static/images/arabic.jpg) | [Image](static/images/arabic_text.jpg) | [Image](static/images/arabic_layout.jpg) | [Image](static/images/arabic_reading.jpg) | |
| Chinese + Hindi | [Image](static/images/chi_hind.jpg) | [Image](static/images/chi_hind_text.jpg) | [Image](static/images/chi_hind_layout.jpg) | [Image](static/images/chi_hind_reading.jpg) | |
| Presentation | [Image](static/images/pres.png) | [Image](static/images/pres_text.jpg) | [Image](static/images/pres_layout.jpg) | [Image](static/images/pres_reading.jpg) | [Image](static/images/pres_tablerec.png) |
| Scientific Paper | [Image](static/images/paper.jpg) | [Image](static/images/paper_text.jpg) | [Image](static/images/paper_layout.jpg) | [Image](static/images/paper_reading.jpg) | [Image](static/images/paper_tablerec.png) |
| Scanned Document | [Image](static/images/scanned.png) | [Image](static/images/scanned_text.jpg) | [Image](static/images/scanned_layout.jpg) | [Image](static/images/scanned_reading.jpg) | [Image](static/images/scanned_tablerec.png) |
| New York Times | [Image](static/images/nyt.jpg) | [Image](static/images/nyt_text.jpg) | [Image](static/images/nyt_layout.jpg) | [Image](static/images/nyt_order.jpg) | |
| Scanned Form | [Image](static/images/funsd.png) | [Image](static/images/funsd_text.jpg) | [Image](static/images/funsd_layout.jpg) | [Image](static/images/funsd_reading.jpg) | [Image](static/images/scanned_tablerec2.png) |
| Textbook | [Image](static/images/textbook.jpg) | [Image](static/images/textbook_text.jpg) | [Image](static/images/textbook_layout.jpg) | [Image](static/images/textbook_order.jpg) | |
# Hosted API
@ -272,6 +279,43 @@ processor = load_processor()
order_predictions = batch_ordering([image], [bboxes], model, processor)
```
## Table Recognition
This command will write out a json file with the detected table cells and row/column ids, along with row/column bounding boxes.
```shell
surya_table DATA_PATH
```
- `DATA_PATH` can be an image, pdf, or folder of images/pdfs
- `--images` will save images of the pages and detected table cells + rows and columns (optional)
- `--max` specifies the maximum number of pages to process if you don't want to process everything
- `--results_dir` specifies the directory to save results to instead of the default
- `--detect_boxes` specifies if cells should be detected. By default, they're pulled out of the PDF, but this is not always possible.
- `--skip_table_detection` tells table recognition not to detect tables first. Use this if your image is already cropped to a table.
The `results.json` file will contain a json dictionary where the keys are the input filenames without extensions. Each value will be a list of dictionaries, one per page of the input document. Each page dictionary contains:
- `cells` - detected table cells
- `bbox` - the axis-aligned rectangle for the table cell in (x1, y1, x2, y2) format. (x1, y1) is the top left corner, and (x2, y2) is the bottom right corner.
- `row_id` - the id of the row this cell belongs to.
- `col_id` - the id of the column this cell belongs to.
- `text` - if text could be pulled out of the pdf, the text of this cell.
- `rows` - detected table rows
- `bbox` - the bounding box of the table row
- `row_id` - the id of the row
- `cols` - detected table columns
- `bbox` - the bounding box of the table column
- `col_id`- the id of the column
- `page` - the page number in the file
- `table_idx` - the index of the table on the page (sorted in vertical order)
- `image_bbox` - the bbox for the image in (x1, y1, x2, y2) format. (x1, y1) is the top left corner, and (x2, y2) is the bottom right corner. All cell, row, and column bboxes will be contained within this bbox.
**Performance tips**
Setting the `TABLE_REC_BATCH_SIZE` env var properly will make a big difference when using a GPU. Each batch item will use `150MB` of VRAM, so very high batch sizes are possible. The default is a batch size of `64`, which will use about 10GB of VRAM. Tuning the batch size may also help on CPU, depending on your core count - the default CPU batch size is `8`.
# Limitations
- This is specialized for document OCR. It will likely not work on photos or other images.
@ -381,10 +425,23 @@ I benchmarked the layout analysis on [Publaynet](https://github.com/ibm-aur-nlp/
**Methodology**
I benchmarked the layout analysis on the layout dataset from [here](https://www.icst.pku.edu.cn/cpdp/sjzy/), which was not in the training data. Unfortunately, this dataset is fairly noisy, and not all the labels are correct. It was very hard to find a dataset annotated with reading order and also layout information. I wanted to avoid using a cloud service for the ground truth.
I benchmarked the reading order on the layout dataset from [here](https://www.icst.pku.edu.cn/cpdp/sjzy/), which was not in the training data. Unfortunately, this dataset is fairly noisy, and not all the labels are correct. It was very hard to find a dataset annotated with reading order and also layout information. I wanted to avoid using a cloud service for the ground truth.
The accuracy is computed by finding if each pair of layout boxes is in the correct order, then taking the % that are correct.
## Table Recognition
| Model | Row Intersection | Col Intersection | Time Per Image |
|-------------------|------------------|------------------|------------------|
| Surya | 0.97 | 0.93 | 0.03 |
| Table transformer | 0.72 | 0.84 | 0.02 |
Higher is better for intersection, which is the percentage of the actual row/column overlapped by the predictions.
**Methodology**
The benchmark uses a subset of [Fintabnet](https://developer.ibm.com/exchanges/data/all/fintabnet/) from IBM. It has labeled rows and columns. After table recognition is run, the predicted rows and columns are compared to the ground truth. There is an additional penalty for predicting too many or too few rows/columns.
## Running your own benchmarks
You can benchmark the performance of surya on your machine.
@ -396,7 +453,7 @@ You can benchmark the performance of surya on your machine.
This will evaluate tesseract and surya for text line detection across a randomly sampled set of images from [doclaynet](https://huggingface.co/datasets/vikp/doclaynet_bench).
```
```shell
python benchmark/detection.py --max 256
```
@ -409,7 +466,7 @@ python benchmark/detection.py --max 256
This will evaluate surya and optionally tesseract on multilingual pdfs from common crawl (with synthetic data for missing languages).
```
```shell
python benchmark/recognition.py --tesseract
```
@ -425,7 +482,7 @@ python benchmark/recognition.py --tesseract
This will evaluate surya on the publaynet dataset.
```
```shell
python benchmark/layout.py
```
@ -435,7 +492,7 @@ python benchmark/layout.py
**Reading Order**
```
```shell
python benchmark/ordering.py
```
@ -443,6 +500,17 @@ python benchmark/ordering.py
- `--debug` will render images with detected text
- `--results_dir` will let you specify a directory to save results to instead of the default one
**Table Recognition**
```shell
python benchmark/table_recognition.py --max 1024 --tatr
```
- `--max` controls how many images to process for the benchmark
- `--debug` will render images with detected text
- `--results_dir` will let you specify a directory to save results to instead of the default one
- `--tatr` specifies whether to also run table transformer
# Training
Text detection was trained on 4x A6000s for 3 days. It used a diverse set of images as training data. It was trained from scratch using a modified efficientvit architecture for semantic segmentation.

View File

@ -0,0 +1,143 @@
import argparse
import collections
import copy
import json
from tabulate import tabulate
from surya.input.processing import convert_if_not_rgb
from surya.model.table_rec.model import load_model
from surya.model.table_rec.processor import load_processor
from surya.tables import batch_table_recognition, get_batch_size
from surya.settings import settings
from surya.benchmark.metrics import rank_accuracy, penalized_iou_score
from surya.benchmark.tatr import load_tatr, batch_inference_tatr
import os
import time
import datasets
def _score_predictions(predictions, dataset, get_row_bboxes, get_col_bboxes):
    """Score predicted rows/columns against the ground truth in ``dataset``.

    Args:
        predictions: One prediction per dataset item, in dataset order.
        dataset: Indexable benchmark dataset; each item provides the
            ground-truth boxes under the "rows" and "cols" keys.
        get_row_bboxes: Callable returning the predicted row boxes of one
            prediction.
        get_col_bboxes: Callable returning the predicted column boxes of one
            prediction.

    Returns:
        A dict with "mean_row_iou", "mean_col_iou" and the per-item
        "page_metrics" (keyed by dataset index). The means are 0 when there
        are no predictions.
    """
    page_metrics = collections.OrderedDict()
    row_total = 0
    col_total = 0
    for idx, pred in enumerate(predictions):
        truth = dataset[idx]
        actual_row_bboxes = truth["rows"]
        actual_col_bboxes = truth["cols"]
        row_score = penalized_iou_score(get_row_bboxes(pred), actual_row_bboxes)
        col_score = penalized_iou_score(get_col_bboxes(pred), actual_col_bboxes)

        page_metrics[idx] = {
            "row_score": row_score,
            "col_score": col_score,
            "row_count": len(actual_row_bboxes),
            "col_count": len(actual_col_bboxes),
        }
        row_total += row_score
        col_total += col_score

    # Guard the division so an empty run (e.g. ``--max 0``) reports zeros
    # instead of raising ZeroDivisionError.
    count = len(predictions)
    return {
        "mean_row_iou": row_total / count if count else 0,
        "mean_col_iou": col_total / count if count else 0,
        "page_metrics": page_metrics,
    }


def main():
    """Benchmark surya table recognition, optionally against table transformer.

    Loads the benchmark dataset named by ``settings.TABLE_REC_BENCH_DATASET_NAME``,
    runs table recognition on every image, scores the predicted rows and
    columns with ``penalized_iou_score``, writes ``results.json`` under
    ``<results_dir>/table_rec_bench`` and prints a summary table.
    """
    parser = argparse.ArgumentParser(description="Benchmark surya table recognition model.")
    parser.add_argument("--results_dir", type=str, help="Directory to save benchmark results to.", default=os.path.join(settings.RESULT_DIR, "benchmark"))
    parser.add_argument("--max", type=int, help="Maximum number of images to run benchmark on.", default=None)
    parser.add_argument("--tatr", action="store_true", help="Run table transformer.", default=False)
    args = parser.parse_args()

    model = load_model()
    processor = load_processor()

    pathname = "table_rec_bench"
    # These have already been shuffled randomly, so sampling from the start is fine
    split = "train"
    if args.max is not None:
        split = f"train[:{args.max}]"
    dataset = datasets.load_dataset(settings.TABLE_REC_BENCH_DATASET_NAME, split=split)
    images = list(dataset["image"])
    images = convert_if_not_rgb(images)
    bboxes = list(dataset["bboxes"])

    start = time.time()
    bboxes = [[{"bbox": b, "text": None} for b in bb] for bb in bboxes]
    table_rec_predictions = batch_table_recognition(images, bboxes, model, processor)
    surya_time = time.time() - start

    result_path = os.path.join(args.results_dir, pathname)
    os.makedirs(result_path, exist_ok=True)

    surya_scores = _score_predictions(
        table_rec_predictions,
        dataset,
        lambda pred: [p.bbox for p in pred.rows],
        lambda pred: [p.bbox for p in pred.cols],
    )
    out_data = {"surya": {"time": surya_time, **surya_scores}}

    if args.tatr:
        tatr_model = load_tatr()
        start = time.time()
        tatr_predictions = batch_inference_tatr(tatr_model, images, 1)
        tatr_time = time.time() - start

        tatr_scores = _score_predictions(
            tatr_predictions,
            dataset,
            lambda pred: [p["bbox"] for p in pred["rows"]],
            lambda pred: [p["bbox"] for p in pred["cols"]],
        )
        out_data["tatr"] = {"time": tatr_time, **tatr_scores}

    with open(os.path.join(result_path, "results.json"), "w", encoding="utf-8") as f:
        json.dump(out_data, f, indent=4)

    # Avoid dividing by zero in the per-image timing when no images were loaded.
    image_count = max(len(images), 1)
    table = [
        ["Model", "Row Intersection", "Col Intersection", "Time Per Image"],
        ["Surya", f"{out_data['surya']['mean_row_iou']:.2f}", f"{out_data['surya']['mean_col_iou']:.2f}",
         f"{surya_time / image_count:.2f}"],
    ]

    if args.tatr:
        table.append(["Table transformer", f"{out_data['tatr']['mean_row_iou']:.2f}", f"{out_data['tatr']['mean_col_iou']:.2f}",
                      f"{tatr_time / image_count:.2f}"])

    print(tabulate(table, headers="firstrow", tablefmt="github"))

    print("Intersection is the average of the intersection % between each actual row/column, and the predictions.  With penalties for too many/few predictions.")
    print("Note that table transformers is unbatched, since the example code in the repo is unbatched.")
    print(f"Wrote results to {result_path}")


if __name__ == "__main__":
    main()

View File

@ -1,3 +1,4 @@
import pypdfium2 # Causes a warning if not the top import
import argparse
import copy
import json
@ -27,10 +28,10 @@ def main():
det_processor = load_processor()
if os.path.isdir(args.input_path):
images, names = load_from_folder(args.input_path, args.max)
images, names, _ = load_from_folder(args.input_path, args.max)
folder_name = os.path.basename(args.input_path)
else:
images, names = load_from_file(args.input_path, args.max)
images, names, _ = load_from_file(args.input_path, args.max)
folder_name = os.path.basename(args.input_path).split(".")[0]
line_predictions = batch_text_detection(images, det_model, det_processor)

View File

@ -28,10 +28,10 @@ def main():
processor = load_processor(checkpoint=checkpoint)
if os.path.isdir(args.input_path):
images, names = load_from_folder(args.input_path, args.max)
images, names, _ = load_from_folder(args.input_path, args.max)
folder_name = os.path.basename(args.input_path)
else:
images, names = load_from_file(args.input_path, args.max)
images, names, _ = load_from_file(args.input_path, args.max)
folder_name = os.path.basename(args.input_path).split(".")[0]
start = time.time()

View File

@ -4,21 +4,27 @@ from typing import List
import pypdfium2
import streamlit as st
from surya.detection import batch_text_detection
from surya.input.pdflines import get_page_text_lines, get_table_blocks
from surya.layout import batch_layout_detection
from surya.model.detection.model import load_model, load_processor
from surya.model.recognition.model import load_model as load_rec_model
from surya.model.recognition.processor import load_processor as load_rec_processor
from surya.model.ordering.processor import load_processor as load_order_processor
from surya.model.ordering.model import load_model as load_order_model
from surya.model.table_rec.model import load_model as load_table_model
from surya.model.table_rec.processor import load_processor as load_table_processor
from surya.ordering import batch_ordering
from surya.postprocessing.heatmap import draw_polys_on_image
from surya.postprocessing.heatmap import draw_polys_on_image, draw_bboxes_on_image
from surya.ocr import run_ocr
from surya.postprocessing.text import draw_text_on_image
from PIL import Image
from surya.languages import CODE_TO_LANGUAGE
from surya.input.langs import replace_lang_with_code
from surya.schema import OCRResult, TextDetectionResult, LayoutResult, OrderResult
from surya.schema import OCRResult, TextDetectionResult, LayoutResult, OrderResult, TableResult
from surya.settings import settings
from surya.tables import batch_table_recognition
from surya.postprocessing.util import rescale_bboxes, rescale_bbox
@st.cache_resource()
def load_det_cached():
@ -40,6 +46,11 @@ def load_order_cached():
return load_order_model(), load_order_processor()
@st.cache_resource()
def load_table_cached():
    """Load the table recognition model and its processor as a (model, processor) pair."""
    table_rec_model = load_table_model()
    table_rec_processor = load_table_processor()
    return table_rec_model, table_rec_processor
def text_detection(img) -> (Image.Image, TextDetectionResult):
pred = batch_text_detection([img], det_model, det_processor)[0]
polygons = [p.polygon for p in pred.bboxes]
@ -52,7 +63,7 @@ def layout_detection(img) -> (Image.Image, LayoutResult):
pred = batch_layout_detection([img], layout_model, layout_processor, [det_pred])[0]
polygons = [p.polygon for p in pred.bboxes]
labels = [p.label for p in pred.bboxes]
layout_img = draw_polys_on_image(polygons, img.copy(), labels=labels)
layout_img = draw_polys_on_image(polygons, img.copy(), labels=labels, label_font_size=18)
return layout_img, pred
@ -62,14 +73,56 @@ def order_detection(img) -> (Image.Image, OrderResult):
pred = batch_ordering([img], [bboxes], order_model, order_processor)[0]
polys = [l.polygon for l in pred.bboxes]
positions = [str(l.position) for l in pred.bboxes]
order_img = draw_polys_on_image(polys, img.copy(), labels=positions, label_font_size=20)
order_img = draw_polys_on_image(polys, img.copy(), labels=positions, label_font_size=18)
return order_img, pred
def table_recognition(img, highres_img, filepath, page_idx: int, use_pdf_boxes: bool, skip_table_detection: bool) -> (Image.Image, List[TableResult]):
    """Find the tables on a page and recognize their cells, rows and columns.

    Args:
        img: Page image used for layout detection when locating tables.
        highres_img: Higher-resolution image of the same page. Tables are
            cropped from it and the annotations are drawn at its resolution.
        filepath: Source file handed to ``get_page_text_lines`` to pull text
            boxes out of the PDF.
        page_idx: Zero-based page index, or None when the input is an image
            rather than a PDF page.
        use_pdf_boxes: Use the boxes pulled from the PDF; when False (or when
            any table has no PDF boxes) the text detection model is used.
        skip_table_detection: Treat the whole high-resolution image as one
            table instead of running layout detection first.

    Returns:
        A tuple of (annotated copy of ``highres_img``, list of table
        predictions, one per table).
    """
    if skip_table_detection:
        layout_tables = [(0, 0, highres_img.size[0], highres_img.size[1])]
        table_imgs = [highres_img]
    else:
        _, layout_pred = layout_detection(img)
        layout_tables_lowres = [l.bbox for l in layout_pred.bboxes if l.label == "Table"]
        table_imgs = []
        layout_tables = []
        for tb in layout_tables_lowres:
            # Layout runs on the low-res image; map each table box to high-res coordinates.
            highres_bbox = rescale_bbox(tb, img.size, highres_img.size)
            table_imgs.append(highres_img.crop(highres_bbox))
            layout_tables.append(highres_bbox)

    if page_idx is None:
        # Image input: there is no PDF page to pull text from, so leave every
        # table without boxes and let the text detection fallback below run.
        table_bboxes = [[] for _ in layout_tables]
    else:
        page_text = get_page_text_lines(filepath, [page_idx], [highres_img.size])[0]
        table_bboxes = get_table_blocks(layout_tables, page_text, highres_img.size)

    if not use_pdf_boxes or any(len(tb) == 0 for tb in table_bboxes):
        det_results = batch_text_detection(table_imgs, det_model, det_processor)
        table_bboxes = [[{"bbox": tb.bbox, "text": None} for tb in det_result.bboxes] for det_result in det_results]

    table_preds = batch_table_recognition(table_imgs, table_bboxes, table_model, table_processor)

    # Draw on a copy so the caller's image is never modified, and keep drawing
    # on the same copy so every table's cells end up in the final image.
    # NOTE(review): the other callers in this file pass ``img.copy()`` to the
    # draw helpers, so they presumably draw in place - confirm.
    table_img = highres_img.copy()
    for results, table_bbox in zip(table_preds, layout_tables):
        x_offset, y_offset = table_bbox[0], table_bbox[1]
        # Cell boxes are relative to the cropped table; shift them back into page coordinates.
        adjusted_bboxes = [
            [
                item.bbox[0] + x_offset,
                item.bbox[1] + y_offset,
                item.bbox[2] + x_offset,
                item.bbox[3] + y_offset,
            ]
            for item in results.cells
        ]
        labels = [f"{item.row_id} / {item.col_id}" for item in results.cells]
        table_img = draw_bboxes_on_image(adjusted_bboxes, table_img, labels=labels, label_font_size=18)
    return table_img, table_preds
# Function for OCR
def ocr(img, langs: List[str]) -> (Image.Image, OCRResult):
def ocr(img, highres_img, langs: List[str]) -> (Image.Image, OCRResult):
replace_lang_with_code(langs)
img_pred = run_ocr([img], [langs], det_model, det_processor, rec_model, rec_processor)[0]
img_pred = run_ocr([img], [langs], det_model, det_processor, rec_model, rec_processor, highres_images=[highres_img])[0]
bboxes = [l.bbox for l in img_pred.text_lines]
text = [l.text for l in img_pred.text_lines]
@ -83,7 +136,7 @@ def open_pdf(pdf_file):
@st.cache_data()
def get_page_image(pdf_file, page_num, dpi=96):
def get_page_image(pdf_file, page_num, dpi=settings.IMAGE_DPI):
doc = open_pdf(pdf_file)
renderer = doc.render(
pypdfium2.PdfBitmap.to_pil,
@ -108,6 +161,7 @@ det_model, det_processor = load_det_cached()
rec_model, rec_processor = load_rec_cached()
layout_model, layout_processor = load_layout_cached()
order_model, order_processor = load_order_cached()
table_model, table_processor = load_table_cached()
st.markdown("""
@ -136,14 +190,20 @@ if "pdf" in filetype:
page_count = page_count(in_file)
page_number = st.sidebar.number_input(f"Page number out of {page_count}:", min_value=1, value=1, max_value=page_count)
pil_image = get_page_image(in_file, page_number)
pil_image = get_page_image(in_file, page_number, settings.IMAGE_DPI)
pil_image_highres = get_page_image(in_file, page_number, dpi=settings.IMAGE_DPI_HIGHRES)
else:
pil_image = Image.open(in_file).convert("RGB")
pil_image_highres = pil_image
page_number = None
text_det = st.sidebar.button("Run Text Detection")
text_rec = st.sidebar.button("Run OCR")
layout_det = st.sidebar.button("Run Layout Analysis")
order_det = st.sidebar.button("Run Reading Order")
table_rec = st.sidebar.button("Run Table Rec")
use_pdf_boxes = st.sidebar.checkbox("PDF table boxes", value=True, help="Table recognition only: Use the bounding boxes from the PDF file vs text detection model.")
skip_table_detection = st.sidebar.checkbox("Skip table detection", value=False, help="Table recognition only: Skip table detection and treat the whole image/page as a table.")
if pil_image is None:
st.stop()
@ -165,7 +225,7 @@ if layout_det:
# Run OCR
if text_rec:
rec_img, pred = ocr(pil_image, languages)
rec_img, pred = ocr(pil_image, pil_image_highres, languages)
with col1:
st.image(rec_img, caption="OCR Result", use_column_width=True)
json_tab, text_tab = st.tabs(["JSON", "Text Lines (for debugging)"])
@ -180,5 +240,12 @@ if order_det:
st.image(order_img, caption="Reading Order", use_column_width=True)
st.json(pred.model_dump(), expanded=True)
if table_rec:
table_img, pred = table_recognition(pil_image, pil_image_highres, in_file, page_number - 1 if page_number else None, use_pdf_boxes, skip_table_detection)
with col1:
st.image(table_img, caption="Table Recognition", use_column_width=True)
st.json([p.model_dump() for p in pred], expanded=True)
with col2:
st.image(pil_image, caption="Uploaded Image", use_column_width=True)

View File

@ -30,10 +30,12 @@ def main():
args = parser.parse_args()
if os.path.isdir(args.input_path):
images, names = load_from_folder(args.input_path, args.max, args.start_page)
images, names, _ = load_from_folder(args.input_path, args.max, args.start_page)
highres_images, _, _ = load_from_folder(args.input_path, args.max, args.start_page, settings.IMAGE_DPI_HIGHRES)
folder_name = os.path.basename(args.input_path)
else:
images, names = load_from_file(args.input_path, args.max, args.start_page)
images, names, _ = load_from_file(args.input_path, args.max, args.start_page)
highres_images, _, _ = load_from_file(args.input_path, args.max, args.start_page, settings.IMAGE_DPI_HIGHRES)
folder_name = os.path.basename(args.input_path).split(".")[0]
if args.lang_file:
@ -60,7 +62,7 @@ def main():
os.makedirs(result_path, exist_ok=True)
start = time.time()
predictions_by_image = run_ocr(images, image_langs, det_model, det_processor, rec_model, rec_processor)
predictions_by_image = run_ocr(images, image_langs, det_model, det_processor, rec_model, rec_processor, highres_images=highres_images)
if args.debug:
print(f"OCR took {time.time() - start:.2f} seconds")
max_chars = max([len(l.text) for p in predictions_by_image for l in p.text_lines])
@ -70,7 +72,7 @@ def main():
for idx, (name, image, pred, langs) in enumerate(zip(names, images, predictions_by_image, image_langs)):
bboxes = [l.bbox for l in pred.text_lines]
pred_text = [l.text for l in pred.text_lines]
page_image = draw_text_on_image(bboxes, pred_text, image.size, langs, has_math="_math" in langs)
page_image = draw_text_on_image(bboxes, pred_text, image.size, langs, has_math="_math" in langs if langs else False)
page_image.save(os.path.join(result_path, f"{name}_{idx}_text.png"))
out_preds = defaultdict(list)

123
poetry.lock generated
View File

@ -605,6 +605,23 @@ files = [
{file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"},
]
[[package]]
name = "coloredlogs"
version = "15.0.1"
description = "Colored terminal output for Python's logging module"
optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"
files = [
{file = "coloredlogs-15.0.1-py2.py3-none-any.whl", hash = "sha256:612ee75c546f53e92e70049c9dbfcc18c935a2b9a53b66085ce9ef6a6e5c0934"},
{file = "coloredlogs-15.0.1.tar.gz", hash = "sha256:7c991aa71a4577af2f82600d8f8f3a89f936baeaf9b50a9c197da014e5bf16b0"},
]
[package.dependencies]
humanfriendly = ">=9.1"
[package.extras]
cron = ["capturer (>=2.4)"]
[[package]]
name = "comm"
version = "0.2.2"
@ -803,6 +820,17 @@ files = [
{file = "filetype-1.2.0.tar.gz", hash = "sha256:66b56cd6474bf41d8c54660347d37afcc3f7d1970648de365c102ef77548aadb"},
]
[[package]]
name = "flatbuffers"
version = "24.3.25"
description = "The FlatBuffers serialization format for Python"
optional = false
python-versions = "*"
files = [
{file = "flatbuffers-24.3.25-py2.py3-none-any.whl", hash = "sha256:8dbdec58f935f3765e4f7f3cf635ac3a77f83568138d6a2311f524ec96364812"},
{file = "flatbuffers-24.3.25.tar.gz", hash = "sha256:de2ec5b203f21441716617f38443e0a8ebf3d25bf0d9c0bb0ce68fa00ad546a4"},
]
[[package]]
name = "fqdn"
version = "1.5.1"
@ -1148,6 +1176,20 @@ testing = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "fastapi", "gr
torch = ["safetensors", "torch"]
typing = ["types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)"]
[[package]]
name = "humanfriendly"
version = "10.0"
description = "Human friendly output for text interfaces using Python"
optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"
files = [
{file = "humanfriendly-10.0-py2.py3-none-any.whl", hash = "sha256:1697e1a8a8f550fd43c2865cd84542fc175a61dcb779b6fee18cf6b6ccba1477"},
{file = "humanfriendly-10.0.tar.gz", hash = "sha256:6b0b831ce8f15f7300721aa49829fc4e83921a9a301cc7f606be6686a2288ddc"},
]
[package.dependencies]
pyreadline3 = {version = "*", markers = "sys_platform == \"win32\" and python_version >= \"3.8\""}
[[package]]
name = "idna"
version = "3.7"
@ -2296,6 +2338,48 @@ files = [
{file = "nvidia_nvtx_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:65f4d98982b31b60026e0e6de73fbdfc09d08a96f4656dd3665ca616a11e1e82"},
]
[[package]]
name = "onnxruntime"
version = "1.19.2"
description = "ONNX Runtime is a runtime accelerator for Machine Learning models"
optional = false
python-versions = "*"
files = [
{file = "onnxruntime-1.19.2-cp310-cp310-macosx_11_0_universal2.whl", hash = "sha256:84fa57369c06cadd3c2a538ae2a26d76d583e7c34bdecd5769d71ca5c0fc750e"},
{file = "onnxruntime-1.19.2-cp310-cp310-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:bdc471a66df0c1cdef774accef69e9f2ca168c851ab5e4f2f3341512c7ef4666"},
{file = "onnxruntime-1.19.2-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:e3a4ce906105d99ebbe817f536d50a91ed8a4d1592553f49b3c23c4be2560ae6"},
{file = "onnxruntime-1.19.2-cp310-cp310-win32.whl", hash = "sha256:4b3d723cc154c8ddeb9f6d0a8c0d6243774c6b5930847cc83170bfe4678fafb3"},
{file = "onnxruntime-1.19.2-cp310-cp310-win_amd64.whl", hash = "sha256:17ed7382d2c58d4b7354fb2b301ff30b9bf308a1c7eac9546449cd122d21cae5"},
{file = "onnxruntime-1.19.2-cp311-cp311-macosx_11_0_universal2.whl", hash = "sha256:d863e8acdc7232d705d49e41087e10b274c42f09e259016a46f32c34e06dc4fd"},
{file = "onnxruntime-1.19.2-cp311-cp311-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c1dfe4f660a71b31caa81fc298a25f9612815215a47b286236e61d540350d7b6"},
{file = "onnxruntime-1.19.2-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a36511dc07c5c964b916697e42e366fa43c48cdb3d3503578d78cef30417cb84"},
{file = "onnxruntime-1.19.2-cp311-cp311-win32.whl", hash = "sha256:50cbb8dc69d6befad4746a69760e5b00cc3ff0a59c6c3fb27f8afa20e2cab7e7"},
{file = "onnxruntime-1.19.2-cp311-cp311-win_amd64.whl", hash = "sha256:1c3e5d415b78337fa0b1b75291e9ea9fb2a4c1f148eb5811e7212fed02cfffa8"},
{file = "onnxruntime-1.19.2-cp312-cp312-macosx_11_0_universal2.whl", hash = "sha256:68e7051bef9cfefcbb858d2d2646536829894d72a4130c24019219442b1dd2ed"},
{file = "onnxruntime-1.19.2-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:d2d366fbcc205ce68a8a3bde2185fd15c604d9645888703785b61ef174265168"},
{file = "onnxruntime-1.19.2-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:477b93df4db467e9cbf34051662a4b27c18e131fa1836e05974eae0d6e4cf29b"},
{file = "onnxruntime-1.19.2-cp312-cp312-win32.whl", hash = "sha256:9a174073dc5608fad05f7cf7f320b52e8035e73d80b0a23c80f840e5a97c0147"},
{file = "onnxruntime-1.19.2-cp312-cp312-win_amd64.whl", hash = "sha256:190103273ea4507638ffc31d66a980594b237874b65379e273125150eb044857"},
{file = "onnxruntime-1.19.2-cp38-cp38-macosx_11_0_universal2.whl", hash = "sha256:636bc1d4cc051d40bc52e1f9da87fbb9c57d9d47164695dfb1c41646ea51ea66"},
{file = "onnxruntime-1.19.2-cp38-cp38-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:5bd8b875757ea941cbcfe01582970cc299893d1b65bd56731e326a8333f638a3"},
{file = "onnxruntime-1.19.2-cp38-cp38-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:b2046fc9560f97947bbc1acbe4c6d48585ef0f12742744307d3364b131ac5778"},
{file = "onnxruntime-1.19.2-cp38-cp38-win32.whl", hash = "sha256:31c12840b1cde4ac1f7d27d540c44e13e34f2345cf3642762d2a3333621abb6a"},
{file = "onnxruntime-1.19.2-cp38-cp38-win_amd64.whl", hash = "sha256:016229660adea180e9a32ce218b95f8f84860a200f0f13b50070d7d90e92956c"},
{file = "onnxruntime-1.19.2-cp39-cp39-macosx_11_0_universal2.whl", hash = "sha256:006c8d326835c017a9e9f74c9c77ebb570a71174a1e89fe078b29a557d9c3848"},
{file = "onnxruntime-1.19.2-cp39-cp39-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:df2a94179a42d530b936f154615b54748239c2908ee44f0d722cb4df10670f68"},
{file = "onnxruntime-1.19.2-cp39-cp39-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:fae4b4de45894b9ce7ae418c5484cbf0341db6813effec01bb2216091c52f7fb"},
{file = "onnxruntime-1.19.2-cp39-cp39-win32.whl", hash = "sha256:dc5430f473e8706fff837ae01323be9dcfddd3ea471c900a91fa7c9b807ec5d3"},
{file = "onnxruntime-1.19.2-cp39-cp39-win_amd64.whl", hash = "sha256:38475e29a95c5f6c62c2c603d69fc7d4c6ccbf4df602bd567b86ae1138881c49"},
]
[package.dependencies]
coloredlogs = "*"
flatbuffers = "*"
numpy = ">=1.21.6"
packaging = "*"
protobuf = "*"
sympy = "*"
[[package]]
name = "opencv-python"
version = "4.10.0.84"
@ -2315,11 +2399,11 @@ files = [
[package.dependencies]
numpy = [
{version = ">=1.21.0", markers = "python_version == \"3.9\" and platform_system == \"Darwin\" and platform_machine == \"arm64\""},
{version = ">=1.26.0", markers = "python_version >= \"3.12\""},
{version = ">=1.23.5", markers = "python_version >= \"3.11\" and python_version < \"3.12\""},
{version = ">=1.21.4", markers = "python_version >= \"3.10\" and platform_system == \"Darwin\" and python_version < \"3.11\""},
{version = ">=1.21.2", markers = "platform_system != \"Darwin\" and python_version >= \"3.10\" and python_version < \"3.11\""},
{version = ">=1.19.3", markers = "platform_system == \"Linux\" and platform_machine == \"aarch64\" and python_version >= \"3.8\" and python_version < \"3.10\" or python_version > \"3.9\" and python_version < \"3.10\" or python_version >= \"3.9\" and platform_system != \"Darwin\" and python_version < \"3.10\" or python_version >= \"3.9\" and platform_machine != \"arm64\" and python_version < \"3.10\""},
{version = ">=1.23.5", markers = "python_version >= \"3.11\" and python_version < \"3.12\""},
{version = ">=1.26.0", markers = "python_version >= \"3.12\""},
]
[[package]]
@ -2385,8 +2469,8 @@ files = [
[package.dependencies]
numpy = [
{version = ">=1.22.4", markers = "python_version < \"3.11\""},
{version = ">=1.23.2", markers = "python_version == \"3.11\""},
{version = ">=1.26.0", markers = "python_version >= \"3.12\""},
{version = ">=1.23.2", markers = "python_version == \"3.11\""},
]
python-dateutil = ">=2.8.2"
pytz = ">=2020.1"
@ -2443,6 +2527,23 @@ files = [
qa = ["flake8 (==5.0.4)", "mypy (==0.971)", "types-setuptools (==67.2.0.1)"]
testing = ["docopt", "pytest"]
[[package]]
name = "pdftext"
version = "0.3.12"
description = "Extract structured text from pdfs quickly"
optional = false
python-versions = "!=2.7.*,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,!=3.7.*,!=3.8.*,>=3.9"
files = [
{file = "pdftext-0.3.12-py3-none-any.whl", hash = "sha256:4903349d5be23984dcb417d6d37611fec9595cf65b535fc54c609f8bc702b847"},
{file = "pdftext-0.3.12.tar.gz", hash = "sha256:a82cccff535c97f806f8f32708bf336e19d1005b7b4a8e08cc44a37522a5cd62"},
]
[package.dependencies]
onnxruntime = ">=1.19.2,<2.0.0"
pydantic = ">=2.7.1,<3.0.0"
pydantic-settings = ">=2.2.1,<3.0.0"
pypdfium2 = ">=4.29.0,<5.0.0"
[[package]]
name = "pexpect"
version = "4.9.0"
@ -3016,6 +3117,20 @@ files = [
{file = "pypdfium2-4.30.0.tar.gz", hash = "sha256:48b5b7e5566665bc1015b9d69c1ebabe21f6aee468b509531c3c8318eeee2e16"},
]
[[package]]
name = "pyreadline3"
version = "3.5.4"
description = "A python implementation of GNU readline."
optional = false
python-versions = ">=3.8"
files = [
{file = "pyreadline3-3.5.4-py3-none-any.whl", hash = "sha256:eaf8e6cc3c49bcccf145fc6067ba8643d1df34d604a1ec0eccbf7a18e6d3fae6"},
{file = "pyreadline3-3.5.4.tar.gz", hash = "sha256:8d57d53039a1c75adba8e50dd3d992b28143480816187ea5efbd5c78e6c885b7"},
]
[package.extras]
dev = ["build", "flake8", "mypy", "pytest", "twine"]
[[package]]
name = "pytesseract"
version = "0.3.10"
@ -4819,4 +4934,4 @@ test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools",
[metadata]
lock-version = "2.0"
python-versions = ">=3.9,<3.13,!=3.9.7"
content-hash = "d250e5223075069c0561f95e970624731feb7ddc20f1bc7b8ef6dd826a8f3085"
content-hash = "58776d15cd2b20d1a735106b587c2595e18fd521dc91ef732192610e8e93d8ff"

View File

@ -1,12 +1,12 @@
[tool.poetry]
name = "surya-ocr"
version = "0.5.0"
description = "OCR, layout, reading order, and line detection in 90+ languages"
version = "0.6.0"
description = "OCR, layout, reading order, and table recognition in 90+ languages"
authors = ["Vik Paruchuri <vik.paruchuri@gmail.com>"]
readme = "README.md"
license = "GPL-3.0-or-later"
repository = "https://github.com/VikParuchuri/surya"
keywords = ["ocr", "pdf", "text detection", "text recognition"]
keywords = ["ocr", "pdf", "text detection", "text recognition", "tables"]
packages = [
{include = "surya"}
]
@ -17,6 +17,7 @@ include = [
"run_ocr_app.py",
"detect_layout.py",
"reading_order.py",
"table_recognition.py"
]
[tool.poetry.dependencies]
@ -32,6 +33,7 @@ opencv-python = "^4.9.0.80"
tabulate = "^0.9.0"
filetype = "^1.2.0"
ftfy = "^6.1.3"
pdftext = "^0.3.12"
[tool.poetry.group.dev.dependencies]
jupyter = "^1.0.0"
@ -50,6 +52,7 @@ surya_ocr = "ocr_text:main"
surya_layout = "detect_layout:main"
surya_gui = "run_ocr_app:run_app"
surya_order = "reading_order:main"
surya_table = "table_recognition:main"
[build-system]
requires = ["poetry-core"]

View File

@ -33,10 +33,10 @@ def main():
det_processor = load_det_processor()
if os.path.isdir(args.input_path):
images, names = load_from_folder(args.input_path, args.max)
images, names, _ = load_from_folder(args.input_path, args.max)
folder_name = os.path.basename(args.input_path)
else:
images, names = load_from_file(args.input_path, args.max)
images, names, _ = load_from_file(args.input_path, args.max)
folder_name = os.path.basename(args.input_path).split(".")[0]
line_predictions = batch_text_detection(images, det_model, det_processor)

View File

@ -27,6 +27,14 @@ def verify_order(data):
raise ValueError("Scores do not meet the required threshold")
def verify_table_rec(data):
    """Validate the table recognition benchmark scores stored under ``data["surya"]``.

    Raises:
        ValueError: if the mean row IoU or the mean column IoU is below 0.75.
    """
    surya_scores = data["surya"]
    scores = (surya_scores["mean_row_iou"], surya_scores["mean_col_iou"])
    if any(score < 0.75 for score in scores):
        raise ValueError("Scores do not meet the required threshold")
def verify_scores(file_path, bench_type):
with open(file_path, 'r') as file:
data = json.load(file)
@ -39,6 +47,8 @@ def verify_scores(file_path, bench_type):
verify_layout(data)
elif bench_type == "ordering":
verify_order(data)
elif bench_type == "table_recognition":
verify_table_rec(data)
else:
raise ValueError("Invalid benchmark type")

Binary file not shown.

After

Width:  |  Height:  |  Size: 25 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 22 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 351 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.1 MiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 934 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 711 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.6 MiB

BIN
static/images/table_rec.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.0 MiB

View File

@ -4,6 +4,7 @@ from itertools import repeat
import numpy as np
from concurrent.futures import ProcessPoolExecutor
def intersection_area(box1, box2):
x_left = max(box1[0], box2[0])
y_top = max(box1[1], box2[1])
@ -15,6 +16,59 @@ def intersection_area(box1, box2):
return (x_right - x_left) * (y_bottom - y_top)
def box_area(box):
    """Return the area of a box given as ``[x1, y1, x2, y2]`` (width times height)."""
    width = box[2] - box[0]
    height = box[3] - box[1]
    return width * height
def calculate_iou(box1, box2, box1_only=False):
    """Return the overlap ratio between two boxes.

    By default the denominator is the union of both boxes (standard IoU).
    With ``box1_only=True`` the denominator is the area of ``box1`` alone,
    i.e. the fraction of ``box1`` that is covered by ``box2``.

    Returns 0 when the denominator is zero.
    """
    overlap = intersection_area(box1, box2)
    if box1_only:
        denominator = box_area(box1)
    else:
        denominator = box_area(box1) + (box_area(box2) - overlap)

    if denominator == 0:
        return 0
    return overlap / denominator
def match_boxes(preds, references):
    """Greedily pair predicted boxes with reference boxes by overlap.

    The overlap of every (reference, prediction) pair is computed with
    ``calculate_iou(..., box1_only=True)``; pairs are then taken in order of
    descending overlap, skipping any pair whose reference or prediction is
    already used.

    Returns:
        A list of ``(reference_index, prediction_index, score)`` tuples.
        Matched scores above .95 are recorded as 1.0. Unmatched references
        are appended as ``(index, None, -1.0)`` and unmatched predictions as
        ``(None, index, 0.0)``.
    """
    ref_count = len(references)
    pred_count = len(preds)

    overlaps = np.zeros((ref_count, pred_count))
    for ref_idx, ref_box in enumerate(references):
        for pred_idx, pred_box in enumerate(preds):
            overlaps[ref_idx, pred_idx] = calculate_iou(ref_box, pred_box, box1_only=True)

    # Flat indices of the overlap matrix, highest overlap first
    descending = np.argsort(overlaps, axis=None)[::-1]
    ref_order, pred_order = np.unravel_index(descending, overlaps.shape)

    used_refs = set()
    used_preds = set()
    matches = []
    for ref_idx, pred_idx in zip(ref_order, pred_order):
        if ref_idx in used_refs or pred_idx in used_preds:
            continue

        score = overlaps[ref_idx, pred_idx]
        if score > .95:  # Account for rounding on box edges
            score = 1.0
        matches.append((ref_idx, pred_idx, score))
        used_refs.add(ref_idx)
        used_preds.add(pred_idx)

    for ref_idx in set(range(ref_count)) - used_refs:
        matches.append((ref_idx, None, -1.0))
    for pred_idx in set(range(pred_count)) - used_preds:
        matches.append((None, pred_idx, 0.0))

    return matches
def penalized_iou_score(preds, references):
    """Return the mean match score over all entries produced by ``match_boxes``.

    Unmatched references and predictions are part of the match list, so they
    pull the mean down. Raises ZeroDivisionError if the match list is empty.
    """
    matches = match_boxes(preds, references)
    total = sum(score for _, _, score in matches)
    return total / len(matches)
def intersection_pixels(box1, box2):
x_left = max(box1[0], box2[0])

117
surya/benchmark/tatr.py Normal file
View File

@ -0,0 +1,117 @@
import torch
from transformers import DetrFeatureExtractor, AutoModelForObjectDetection
from surya.settings import settings
from PIL import Image
import numpy as np
class MaxResize(object):
    """Callable that rescales an image so its longer side equals ``max_size``."""

    def __init__(self, max_size=800):
        self.max_size = max_size

    def __call__(self, image):
        """Return ``image`` resized by the factor ``max_size / max(width, height)``."""
        width, height = image.size
        scale = self.max_size / max(width, height)
        target_size = (int(round(scale * width)), int(round(scale * height)))
        return image.resize(target_size)
def to_tensor(image):
    """Convert an image to a float32 torch tensor with the last axis moved first.

    The input is copied into a NumPy float32 array, divided by 255.0, and
    transposed with axes ``(2, 0, 1)`` before being wrapped by torch.
    """
    pixels = np.array(image).astype(np.float32)
    # Scale in place on the fresh float copy, then expose it channel-first
    pixels /= 255.0
    channel_first = pixels.transpose((2, 0, 1))
    return torch.from_numpy(channel_first)
def normalize(tensor, mean, std):
    """Normalize ``tensor`` in place along its first axis and return it.

    Each slice along the first axis has the matching ``mean`` entry subtracted
    and is then divided by the matching ``std`` entry.
    """
    for channel, channel_mean, channel_std in zip(tensor, mean, std):
        channel.sub_(channel_mean)
        channel.div_(channel_std)
    return tensor
def structure_transform(image):
    """Prepare an image for the model: resize, convert to tensor, normalize.

    The image is resized with ``MaxResize(1000)``, converted with
    ``to_tensor``, and normalized with fixed per-channel mean and std values.
    """
    resized = MaxResize(1000)(image)
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    return normalize(to_tensor(resized), channel_mean, channel_std)
def box_cxcywh_to_xyxy(x):
    """Convert boxes from (center x, center y, width, height) to (x1, y1, x2, y2).

    The last axis of ``x`` is unbound into the four components and the
    converted corners are stacked along dim 1.
    """
    x_c, y_c, w, h = x.unbind(-1)
    half_w = 0.5 * w
    half_h = 0.5 * h
    corners = [x_c - half_w, y_c - half_h, x_c + half_w, y_c + half_h]
    return torch.stack(corners, dim=1)
def rescale_bboxes(out_bbox, size):
    """Convert center-format boxes to corner format and scale them by ``size``.

    ``size`` is unpacked as ``(width, height)``; x coordinates are multiplied
    by the width and y coordinates by the height.
    """
    width, height = size
    corner_boxes = box_cxcywh_to_xyxy(out_bbox)
    scale = torch.tensor([width, height, width, height], dtype=torch.float32)
    return corner_boxes * scale
def outputs_to_objects(outputs, img_sizes, id2label):
    """Turn raw detection outputs into per-image row and column objects.

    Args:
        outputs: model output exposing ``.logits`` and a ``'pred_boxes'`` item.
        img_sizes: one size per image, passed to ``rescale_bboxes``.
        id2label: mapping from class index to label string; entries labelled
            ``'no object'`` are discarded.

    Returns:
        A list with one dict per image: ``{"rows": [...], "cols": [...]}``,
        where each entry is ``{'label', 'score', 'bbox'}`` and the lists hold
        the objects labelled "table row" and "table column" respectively.
    """
    # Highest-probability class (and its probability) for every predicted box
    best = outputs.logits.softmax(-1).max(-1)
    batch_labels = list(best.indices.detach().cpu().numpy())
    batch_scores = list(best.values.detach().cpu().numpy())
    batch_bboxes = outputs['pred_boxes'].detach().cpu()

    batch_objects = []
    for img_idx, img_size in enumerate(img_sizes):
        pred_bboxes = [elem.tolist() for elem in rescale_bboxes(batch_bboxes[img_idx], img_size)]

        objects = []
        for label, score, bbox in zip(batch_labels[img_idx], batch_scores[img_idx], pred_bboxes):
            class_label = id2label[int(label)]
            if class_label != 'no object':
                objects.append({
                    'label': class_label,
                    'score': float(score),
                    'bbox': [float(elem) for elem in bbox]
                })

        # Distinct loop names here: the image index must not be shadowed
        batch_objects.append({
            "rows": [obj for obj in objects if obj["label"] == "table row"],
            "cols": [obj for obj in objects if obj["label"] == "table column"]
        })

    return batch_objects
def load_tatr():
    """Load the pretrained table structure recognition model onto the configured device."""
    checkpoint = "microsoft/table-transformer-structure-recognition-v1.1-all"
    model = AutoModelForObjectDetection.from_pretrained(checkpoint)
    return model.to(settings.TORCH_DEVICE_MODEL)
def batch_inference_tatr(model, images, batch_size):
    """Run the model over ``images`` in batches and collect rows/columns per image.

    Args:
        model: detection model exposing ``.device`` and ``.config.id2label``.
        images: list of images accepted by ``structure_transform``.
        batch_size: number of images per forward pass.

    Returns:
        A list with one ``{"rows": [...], "cols": [...]}`` dict per image, as
        produced by ``outputs_to_objects``.
    """
    device = model.device

    # Extend a copy of the label map with the extra "no object" class once.
    # Writing into model.config.id2label directly would add another key on
    # every batch and every call.
    id2label = dict(model.config.id2label)
    id2label[len(id2label)] = "no object"

    rows_cols = []
    for start in range(0, len(images), batch_size):
        batch_images = images[start:start + batch_size]
        pixel_values = torch.stack([structure_transform(img) for img in batch_images], dim=0).to(device)

        # forward pass
        with torch.no_grad():
            outputs = model(pixel_values)

        rows_cols.extend(outputs_to_objects(outputs, [img.size for img in batch_images], id2label))
    return rows_cols

View File

@ -1,4 +1,4 @@
from typing import List, Tuple
from typing import List, Tuple, Generator
import torch
import numpy as np
@ -26,14 +26,17 @@ def get_batch_size():
return batch_size
def batch_detection(images: List, model: EfficientViTForSemanticSegmentation, processor, batch_size=None) -> Tuple[List[List[np.ndarray]], List[Tuple[int, int]]]:
def batch_detection(
images: List,
model: EfficientViTForSemanticSegmentation,
processor,
batch_size=None
) -> Generator[Tuple[List[List[np.ndarray]], List[Tuple[int, int]]], None, None]:
assert all([isinstance(image, Image.Image) for image in images])
if batch_size is None:
batch_size = get_batch_size()
heatmap_count = model.config.num_labels
images = [image.convert("RGB") for image in images] # also copies the images
orig_sizes = [image.size for image in images]
splits_per_image = [get_total_splits(size, processor) for size in orig_sizes]
@ -52,10 +55,9 @@ def batch_detection(images: List, model: EfficientViTForSemanticSegmentation, pr
if len(current_batch) > 0:
batches.append(current_batch)
all_preds = []
for batch_idx in tqdm(range(len(batches)), desc="Detecting bboxes"):
batch_image_idxs = batches[batch_idx]
batch_images = convert_if_not_rgb([images[j] for j in batch_image_idxs])
batch_images = [images[j].convert("RGB") for j in batch_image_idxs]
split_index = []
split_heights = []
@ -98,11 +100,7 @@ def batch_detection(images: List, model: EfficientViTForSemanticSegmentation, pr
heatmaps[k] = np.vstack([heatmaps[k], pred_heatmaps[k]])
preds[idx] = heatmaps
all_preds.extend(preds)
assert len(all_preds) == len(images)
assert all([len(pred) == heatmap_count for pred in all_preds])
return all_preds, orig_sizes
yield preds, [orig_sizes[j] for j in batch_image_idxs]
def parallel_get_lines(preds, orig_sizes):
@ -125,16 +123,21 @@ def parallel_get_lines(preds, orig_sizes):
def batch_text_detection(images: List, model, processor, batch_size=None) -> List[TextDetectionResult]:
preds, orig_sizes = batch_detection(images, model, processor, batch_size=batch_size)
detection_generator = batch_detection(images, model, processor, batch_size=batch_size)
results = []
if settings.IN_STREAMLIT or len(images) < settings.DETECTOR_MIN_PARALLEL_THRESH: # Ensures we don't parallelize with streamlit, or with very few images
for i in range(len(images)):
result = parallel_get_lines(preds[i], orig_sizes[i])
results.append(result)
else:
max_workers = min(settings.DETECTOR_POSTPROCESSING_CPU_WORKERS, len(images))
max_workers = min(settings.DETECTOR_POSTPROCESSING_CPU_WORKERS, len(images))
parallelize = not settings.IN_STREAMLIT and len(images) >= settings.DETECTOR_MIN_PARALLEL_THRESH
if parallelize:
with ProcessPoolExecutor(max_workers=max_workers) as executor:
results = list(executor.map(parallel_get_lines, preds, orig_sizes))
for preds, orig_sizes in detection_generator:
batch_results = list(executor.map(parallel_get_lines, preds, orig_sizes))
results.extend(batch_results)
else:
for preds, orig_sizes in detection_generator:
for pred, orig_size in zip(preds, orig_sizes):
results.append(parallel_get_lines(pred, orig_size))
return results

View File

@ -1,17 +1,19 @@
import PIL
from surya.input.processing import open_pdf, get_page_images
from surya.settings import settings
import os
import filetype
from PIL import Image
import json
def get_name_from_path(path):
return os.path.basename(path).split(".")[0]
def load_pdf(pdf_path, max_pages=None, start_page=None):
def load_pdf(pdf_path, max_pages=None, start_page=None, dpi=settings.IMAGE_DPI, load_text_lines=False):
doc = open_pdf(pdf_path)
last_page = len(doc)
@ -25,47 +27,58 @@ def load_pdf(pdf_path, max_pages=None, start_page=None):
last_page = min(start_page + max_pages, last_page)
page_indices = list(range(start_page, last_page))
images = get_page_images(doc, page_indices)
images = get_page_images(doc, page_indices, dpi=dpi)
text_lines = None
if load_text_lines:
from surya.input.pdflines import get_page_text_lines # Putting import here because pypdfium2 causes warnings if its not the top import
text_lines = get_page_text_lines(
pdf_path,
page_indices,
[i.size for i in images]
)
doc.close()
names = [get_name_from_path(pdf_path) for _ in page_indices]
return images, names
return images, names, text_lines
def load_image(image_path):
image = Image.open(image_path).convert("RGB")
name = get_name_from_path(image_path)
return [image], [name]
return [image], [name], [None]
def load_from_file(input_path, max_pages=None, start_page=None):
def load_from_file(input_path, max_pages=None, start_page=None, dpi=settings.IMAGE_DPI, load_text_lines=False):
input_type = filetype.guess(input_path)
if input_type.extension == "pdf":
return load_pdf(input_path, max_pages, start_page)
return load_pdf(input_path, max_pages, start_page, dpi=dpi, load_text_lines=load_text_lines)
else:
return load_image(input_path)
def load_from_folder(folder_path, max_pages=None, start_page=None):
def load_from_folder(folder_path, max_pages=None, start_page=None, dpi=settings.IMAGE_DPI, load_text_lines=False):
image_paths = [os.path.join(folder_path, image_name) for image_name in os.listdir(folder_path) if not image_name.startswith(".")]
image_paths = [ip for ip in image_paths if not os.path.isdir(ip)]
images = []
names = []
text_lines = []
for path in image_paths:
extension = filetype.guess(path)
if extension and extension.extension == "pdf":
image, name = load_pdf(path, max_pages, start_page)
image, name, text_line = load_pdf(path, max_pages, start_page, dpi=dpi, load_text_lines=load_text_lines)
images.extend(image)
names.extend(name)
text_lines.extend(text_line)
else:
try:
image, name = load_image(path)
image, name, text_line = load_image(path)
images.extend(image)
names.extend(name)
text_lines.extend(text_line)
except PIL.UnidentifiedImageError:
print(f"Could not load image {path}")
continue
return images, names
return images, names, text_lines
def load_lang_file(lang_path, names):

86
surya/input/pdflines.py Normal file
View File

@ -0,0 +1,86 @@
from pdftext.extraction import dictionary_output
from surya.postprocessing.text import sort_text_lines
from surya.schema import PolygonBox
def get_page_text_lines(filepath: str, page_idxs: list, out_sizes: list) -> list:
    """Extract text for the given pages and rescale boxes to the output sizes.

    Each page dict returned by ``dictionary_output`` has its line and character
    bounding boxes multiplied by ``out_size / page size`` (x by the width
    ratio, y by the height ratio). The page dicts are modified in place and
    returned.
    """
    def scale_bbox(bbox, w_scale, h_scale):
        return [bbox[0] * w_scale, bbox[1] * h_scale,
                bbox[2] * w_scale, bbox[3] * h_scale]

    assert len(page_idxs) == len(out_sizes)
    pages_text = dictionary_output(filepath, sort=False, page_range=page_idxs, keep_chars=True)

    for page, out_size in zip(pages_text, out_sizes):
        page_width = page["width"]
        page_height = page["height"]
        w_scale = out_size[0] / page_width
        h_scale = out_size[1] / page_height

        for block in page["blocks"]:
            for line in block["lines"]:
                line["bbox"] = scale_bbox(line["bbox"], w_scale, h_scale)
                for span in line["spans"]:
                    for char in span["chars"]:
                        char["bbox"] = scale_bbox(char["bbox"], w_scale, h_scale)
    return pages_text
def get_table_blocks(tables: list, full_text: dict, img_size: list, table_thresh=.8):
    """Group the characters inside each table box into text spans.

    For every table box, lines whose ``intersection_pct`` with the table is at
    least ``table_thresh`` are scanned character by character. Consecutive
    characters are merged into one span while they stay within 1% of the image
    size of the running span box (which edges are compared depends on
    ``full_text["rotation"]``); otherwise the running span is emitted and a new
    one starts.

    Returns one list of ``{"text", "bbox"}`` items per table, passed through
    ``sort_text_lines``. Coordinates are relative to the input table, not the
    full image.
    """
    def rect_polygon(bbox):
        return PolygonBox(polygon=[
            [bbox[0], bbox[1]],
            [bbox[2], bbox[1]],
            [bbox[2], bbox[3]],
            [bbox[0], bbox[3]]
        ])

    def continues_span(rotation, char_box, span_box):
        # True when the character is close enough to the running span box
        if rotation == 90:
            return (char_box[0] - span_box[0]) / img_size[0] < 0.01 and abs(char_box[1] - span_box[3]) / img_size[1] < 0.01
        elif rotation == 180:
            return (char_box[2] - span_box[0]) / img_size[0] < 0.01 and (char_box[1] - span_box[1]) / img_size[1] < 0.01
        elif rotation == 270:
            return (char_box[0] - span_box[0]) / img_size[0] < 0.01 and abs(char_box[3] - span_box[1]) / img_size[1] < 0.01
        else:
            return (char_box[0] - span_box[2]) / img_size[0] < 0.01 and (char_box[1] - span_box[1]) / img_size[1] < 0.01

    table_texts = []
    for table in tables:
        table_poly = rect_polygon(table)
        table_text = []
        rotation = full_text["rotation"]

        for block in full_text["blocks"]:
            for line in block["lines"]:
                if rect_polygon(line["bbox"]).intersection_pct(table_poly) < table_thresh:
                    continue

                span_text = None
                span_box = None
                for span in line["spans"]:
                    for char in span["chars"]:
                        if span_text is None:
                            # First character of the line starts the first span
                            span_text = char["char"]
                            span_box = char["bbox"]
                        elif span_text and continues_span(rotation, char["bbox"], span_box):
                            span_text += char["char"]
                            span_box = [min(span_box[0], char["bbox"][0]), min(span_box[1], char["bbox"][1]),
                                        max(span_box[2], char["bbox"][2]), max(span_box[3], char["bbox"][3])]
                        else:
                            table_text.append({"text": span_text, "bbox": span_box})
                            span_text = char["char"]
                            span_box = char["bbox"]

                # Flush the span still open at the end of the line
                if span_text is not None:
                    table_text.append({"text": span_text, "bbox": span_box})

        # Adjust to be relative to input table
        x_offset = table[0]
        y_offset = table[1]
        for item in table_text:
            item["bbox"] = [
                item["bbox"][0] - x_offset,
                item["bbox"][1] - y_offset,
                item["bbox"][2] - x_offset,
                item["bbox"][3] - y_offset
            ]

        table_texts.append(sort_text_lines(table_text))
    return table_texts

View File

@ -115,4 +115,4 @@ def slice_and_pad_poly(image_array: np.array, coordinates):
cropped_polygon[mask == 0] = settings.RECOGNITION_PAD_VALUE
rectangle_image = Image.fromarray(cropped_polygon)
return rectangle_image
return rectangle_image

View File

@ -12,7 +12,7 @@ from surya.settings import settings
def get_regions_from_detection_result(detection_result: TextDetectionResult, heatmaps: List[np.ndarray], orig_size, id2label, segment_assignment, vertical_line_width=20) -> List[LayoutBox]:
logits = np.stack(heatmaps, axis=0)
vertical_line_bboxes = [line for line in detection_result.vertical_lines]
vertical_line_bboxes = detection_result.vertical_lines
line_bboxes = detection_result.bboxes
# Scale back to processor size
@ -38,6 +38,8 @@ def get_regions_from_detection_result(detection_result: TextDetectionResult, hea
detected_boxes = []
for heatmap_idx in range(1, len(id2label)): # Skip the blank class
heatmap = logits[heatmap_idx]
if np.max(heatmap) < settings.DETECTOR_BLANK_THRESHOLD:
continue
bboxes = get_detected_boxes(heatmap)
bboxes = [bbox for bbox in bboxes if bbox.area > 25]
for bb in bboxes:
@ -89,14 +91,12 @@ def get_regions_from_detection_result(detection_result: TextDetectionResult, hea
max_x = max(max_x, max_x_box)
max_y = max(max_y, max_y_box)
bbox.polygon[0][0] = min_x
bbox.polygon[0][1] = min_y
bbox.polygon[1][0] = max_x
bbox.polygon[1][1] = min_y
bbox.polygon[2][0] = max_x
bbox.polygon[2][1] = max_y
bbox.polygon[3][0] = min_x
bbox.polygon[3][1] = max_y
bbox.polygon = [
[min_x, min_y],
[max_x, min_y],
[max_x, max_y],
[min_x, max_y]
]
if bbox_idx in box_lines and bbox.label in ["Picture"]:
bbox.label = "Figure"
@ -104,17 +104,18 @@ def get_regions_from_detection_result(detection_result: TextDetectionResult, hea
new_boxes.append(bbox)
# Merge tables together (sometimes one column is detected as a separate table)
for i in range(5): # Up to 5 rounds of merging
mergeable_types = ["Table", "Picture", "Figure"]
for ftype in mergeable_types:
to_remove = set()
for bbox_idx, bbox in enumerate(new_boxes):
if bbox.label != "Table" or bbox_idx in to_remove:
if bbox.label != ftype or bbox_idx in to_remove:
continue
for bbox_idx2, bbox2 in enumerate(new_boxes):
if bbox2.label != "Table" or bbox_idx2 in to_remove or bbox_idx == bbox_idx2:
if bbox2.label != ftype or bbox_idx2 in to_remove or bbox_idx == bbox_idx2:
continue
if bbox.intersection_pct(bbox2) > 0:
if bbox.intersection_pct(bbox2, x_margin=.25) > .1:
bbox.merge(bbox2)
to_remove.add(bbox_idx2)
@ -151,10 +152,14 @@ def get_regions(heatmaps: List[np.ndarray], orig_size, id2label, segment_assignm
heatmap = heatmaps[i]
assert heatmap.shape == segment_assignment.shape
heatmap[segment_assignment != i] = 0 # zero out where another segment is
# Skip processing empty labels
if np.max(heatmap) < settings.DETECTOR_BLANK_THRESHOLD:
continue
bbox = get_and_clean_boxes(heatmap, list(reversed(heatmap.shape)), orig_size)
for bb in bbox:
bboxes.append(LayoutBox(polygon=bb.polygon, label=id2label[i]))
heatmaps.append(heatmap)
bboxes = keep_largest_boxes(bboxes)
return bboxes
@ -182,23 +187,43 @@ def parallel_get_regions(heatmaps: List[np.ndarray], orig_size, id2label, detect
def batch_layout_detection(images: List, model, processor, detection_results: Optional[List[TextDetectionResult]] = None, batch_size=None) -> List[LayoutResult]:
preds, orig_sizes = batch_detection(images, model, processor, batch_size=batch_size)
layout_generator = batch_detection(images, model, processor, batch_size=batch_size)
id2label = model.config.id2label
results = []
if settings.IN_STREAMLIT or len(images) < settings.DETECTOR_MIN_PARALLEL_THRESH: # Ensures we don't parallelize with streamlit or too few images
for i in range(len(images)):
result = parallel_get_regions(preds[i], orig_sizes[i], id2label, detection_results[i] if detection_results else None)
results.append(result)
else:
futures = []
max_workers = min(settings.DETECTOR_POSTPROCESSING_CPU_WORKERS, len(images))
with ProcessPoolExecutor(max_workers=max_workers) as executor:
for i in range(len(images)):
future = executor.submit(parallel_get_regions, preds[i], orig_sizes[i], id2label, detection_results[i] if detection_results else None)
futures.append(future)
max_workers = min(settings.DETECTOR_POSTPROCESSING_CPU_WORKERS, len(images))
parallelize = not settings.IN_STREAMLIT and len(images) >= settings.DETECTOR_MIN_PARALLEL_THRESH
for future in futures:
results.append(future.result())
if parallelize:
with ProcessPoolExecutor(max_workers=max_workers) as executor:
img_idx = 0
for preds, orig_sizes in layout_generator:
futures = []
for pred, orig_size in zip(preds, orig_sizes):
future = executor.submit(
parallel_get_regions,
pred,
orig_size,
id2label,
detection_results[img_idx] if detection_results else None
)
futures.append(future)
img_idx += 1
for future in futures:
results.append(future.result())
else:
img_idx = 0
for preds, orig_sizes in layout_generator:
for pred, orig_size in zip(preds, orig_sizes):
results.append(parallel_get_regions(
pred,
orig_size,
id2label,
detection_results[img_idx] if detection_results else None
))
img_idx += 1
return results

View File

@ -0,0 +1,260 @@
from transformers import PretrainedConfig
from surya.settings import settings
BOX_DIM = 1024
SPECIAL_TOKENS = 7
MAX_ROWS = 384
class SuryaTableRecConfig(PretrainedConfig):
    """Composite config holding encoder, decoder and text encoder sub-configs.

    ``encoder``, ``decoder`` and ``text_encoder`` are required keyword
    arguments. The decoder config (a dict or an object with attributes)
    supplies the start, pad and end token ids.
    """
    model_type = "vision-encoder-decoder"
    is_composition = True

    def __init__(self, **kwargs):
        # All kwargs go to the parent initializer before the sub-configs are popped
        super().__init__(**kwargs)

        encoder_config, decoder_config, text_enc_config = (
            kwargs.pop(key) for key in ("encoder", "decoder", "text_encoder")
        )

        self.encoder = encoder_config
        self.decoder = decoder_config
        self.text_encoder = text_enc_config
        self.is_encoder_decoder = True

        # The decoder config may be a plain dict or a config object
        if isinstance(decoder_config, dict):
            read_token = decoder_config.__getitem__
        else:
            def read_token(name):
                return getattr(decoder_config, name)

        self.decoder_start_token_id = read_token("bos_token_id")
        self.pad_token_id = read_token("pad_token_id")
        self.eos_token_id = read_token("eos_token_id")
class DonutSwinTableRecConfig(PretrainedConfig):
    """Configuration for the Swin-style image encoder used by table recognition.

    All constructor arguments are stored as attributes of the same name.
    ``num_layers`` is derived from ``len(depths)`` and ``hidden_size`` from
    ``embed_dim`` and the number of stages.
    """
    model_type = "donut-swin"

    attribute_map = {
        "num_attention_heads": "num_heads",
        "num_hidden_layers": "num_layers",
    }

    def __init__(
        self,
        image_size=(settings.TABLE_REC_IMAGE_SIZE["width"], settings.TABLE_REC_IMAGE_SIZE["height"]),
        patch_size=4,
        num_channels=3,
        embed_dim=128,
        depths=None,
        num_heads=None,
        num_kv_heads=None,
        window_size=8,
        mlp_ratio=4.0,
        qkv_bias=True,
        hidden_dropout_prob=0.0,
        attention_probs_dropout_prob=0.0,
        drop_path_rate=0.1,
        hidden_act="gelu",
        use_absolute_embeddings=True,
        initializer_range=0.02,
        layer_norm_eps=1e-5,
        encoder_length=1024,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Build the list defaults per instance so that mutating one config's
        # lists cannot leak into other instances through a shared default.
        if depths is None:
            depths = [2, 2, 14, 2]
        if num_heads is None:
            num_heads = [4, 8, 16, 32]
        if num_kv_heads is None:
            num_kv_heads = [4, 8, 16, 32]

        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.embed_dim = embed_dim
        self.depths = depths
        self.num_layers = len(depths)
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.window_size = window_size
        self.mlp_ratio = mlp_ratio
        self.qkv_bias = qkv_bias
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.drop_path_rate = drop_path_rate
        self.hidden_act = hidden_act
        self.use_absolute_embeddings = use_absolute_embeddings
        self.layer_norm_eps = layer_norm_eps
        self.initializer_range = initializer_range
        # we set the hidden_size attribute in order to make Swin work with VisionEncoderDecoderModel
        # this indicates the channel dimension after the last stage of the model
        self.hidden_size = int(embed_dim * 2 ** (len(depths) - 1))
        self.encoder_length = encoder_length
class SuryaTableRecDecoderConfig(PretrainedConfig):
    """Hyper-parameters of the table-recognition decoder.

    Constructor arguments are stored on the instance under their own names,
    with these exceptions:

    * ``lru_width`` falls back to ``hidden_size`` when ``None``.
    * ``num_key_value_heads`` falls back to ``num_attention_heads`` when ``None``.
    * ``block_types`` is stored as a list.
    * ``head_dim`` and ``final_w_init_variance_scale`` are derived values.
    * The three special-token ids are forwarded to ``PretrainedConfig``.

    Raises:
        ValueError: if ``num_key_value_heads`` exceeds ``num_attention_heads``.
    """

    model_type = "surya_tablerec"

    def __init__(
        self,
        num_hidden_layers=3,
        vocab_size=settings.TABLE_REC_MAX_ROWS + SPECIAL_TOKENS,
        hidden_size=512,
        intermediate_size=4 * 512,
        encoder_hidden_size=1024,
        num_attention_heads=8,
        lru_width=None,
        attention_window_size=16,
        conv1d_width=4,
        logits_soft_cap=30.0,
        rms_norm_eps=1e-6,
        use_cache=True,
        pad_token_id=0,
        eos_token_id=1,
        bos_token_id=2,
        hidden_activation="gelu_pytorch_tanh",
        rope_theta=10000.0,
        block_types=("attention",),
        cross_attn_layers=(0, 1, 2, 3),
        encoder_cross_attn_layers=(0, 1, 2, 3),
        self_attn_layers=(0, 1, 2, 3),
        global_attn_layers=(0, 1, 2, 3),
        attention_dropout=0.0,
        num_key_value_heads=4,
        attention_bias=False,
        w_init_variance_scale=0.01,
        init_std=0.02,
        tie_word_embeddings=False,
        aux_heads=0,  # How many n-token-ahead heads to add
        causal=True,
        max_classes=2 + SPECIAL_TOKENS,
        max_width=1024 + SPECIAL_TOKENS,
        max_height=1024 + SPECIAL_TOKENS,
        out_box_size=1024,
        **kwargs,
    ):
        # --- core transformer sizes ---
        self.num_hidden_layers = num_hidden_layers
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_attention_heads = num_attention_heads
        self.lru_width = hidden_size if lru_width is None else lru_width
        self.attention_window_size = attention_window_size
        self.conv1d_width = conv1d_width
        self.logits_soft_cap = logits_soft_cap
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.block_types = list(block_types)
        self.hidden_activation = hidden_activation

        # --- attention geometry ---
        self.head_dim = self.hidden_size // self.num_attention_heads
        self.num_key_value_heads = (
            num_attention_heads if num_key_value_heads is None else num_key_value_heads
        )
        if self.num_key_value_heads > self.num_attention_heads:
            raise ValueError("The number of `num_key_value_heads` must be smaller than `num_attention_heads`")

        # --- which layer indices get which sub-block ---
        self.cross_attn_layers = cross_attn_layers
        self.self_attn_layers = self_attn_layers
        self.global_attn_layers = global_attn_layers

        # --- regularisation / initialisation ---
        self.attention_dropout = attention_dropout
        self.attention_bias = attention_bias
        self.w_init_variance_scale = w_init_variance_scale
        self.final_w_init_variance_scale = 2.0 / self.num_hidden_layers
        self.init_std = init_std
        self.tie_word_embeddings = tie_word_embeddings
        self.aux_heads = aux_heads

        # --- encoder coupling and box vocabulary ---
        self.encoder_hidden_size = encoder_hidden_size
        self.causal = causal
        self.encoder_cross_attn_layers = encoder_cross_attn_layers
        self.max_classes = max_classes
        self.max_width = max_width
        self.max_height = max_height
        self.out_box_size = out_box_size

        # The base initialiser runs last, exactly as before.
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            **kwargs,
        )

    @property
    def layers_block_type(self):
        """Block type per layer: ``block_types`` cycled and cut to ``num_hidden_layers``."""
        repeated = self.block_types * 100
        return repeated[: self.num_hidden_layers]
class SuryaTableRecTextEncoderConfig(PretrainedConfig):
    """Hyper-parameters of the table-recognition text (box) encoder.

    Constructor arguments are stored on the instance under their own names,
    with these exceptions:

    * ``lru_width`` falls back to ``hidden_size`` when ``None``.
    * ``num_key_value_heads`` falls back to ``num_attention_heads`` when ``None``.
    * ``block_types`` is stored as a list.
    * ``head_dim`` and ``final_w_init_variance_scale`` are derived values.
    * The three special-token ids are forwarded to ``PretrainedConfig``.

    Raises:
        ValueError: if ``num_key_value_heads`` exceeds ``num_attention_heads``.
    """

    model_type = "surya_tablerec"

    def __init__(
        self,
        num_hidden_layers=4,
        vocab_size=settings.TABLE_REC_MAX_ROWS + SPECIAL_TOKENS,
        hidden_size=1024,
        intermediate_size=4 * 1024,
        encoder_hidden_size=1024,
        num_attention_heads=16,
        lru_width=None,
        attention_window_size=16,
        conv1d_width=4,
        logits_soft_cap=30.0,
        rms_norm_eps=1e-6,
        use_cache=True,
        pad_token_id=0,
        eos_token_id=1,
        bos_token_id=2,
        hidden_activation="gelu_pytorch_tanh",
        rope_theta=10000.0,
        block_types=("attention",),
        cross_attn_layers=(0, 1, 2, 3, 4, 5),
        self_attn_layers=(0, 1, 2, 3, 4, 5),
        global_attn_layers=(0, 1, 2, 3, 4, 5),
        attention_dropout=0.0,
        num_key_value_heads=16,
        attention_bias=False,
        w_init_variance_scale=0.01,
        init_std=0.02,
        tie_word_embeddings=False,
        causal=False,
        max_width=BOX_DIM + SPECIAL_TOKENS,
        max_height=BOX_DIM + SPECIAL_TOKENS,
        max_position_embeddings=1024,
        **kwargs,
    ):
        # --- core transformer sizes ---
        self.num_hidden_layers = num_hidden_layers
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_attention_heads = num_attention_heads
        self.lru_width = hidden_size if lru_width is None else lru_width
        self.attention_window_size = attention_window_size
        self.conv1d_width = conv1d_width
        self.logits_soft_cap = logits_soft_cap
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.block_types = list(block_types)
        self.hidden_activation = hidden_activation

        # --- attention geometry ---
        self.head_dim = self.hidden_size // self.num_attention_heads
        self.num_key_value_heads = (
            num_attention_heads if num_key_value_heads is None else num_key_value_heads
        )
        if self.num_key_value_heads > self.num_attention_heads:
            raise ValueError("The number of `num_key_value_heads` must be smaller than `num_attention_heads`")

        # --- which layer indices get which sub-block ---
        self.cross_attn_layers = cross_attn_layers
        self.self_attn_layers = self_attn_layers
        self.global_attn_layers = global_attn_layers

        # --- regularisation / initialisation ---
        self.attention_dropout = attention_dropout
        self.attention_bias = attention_bias
        self.w_init_variance_scale = w_init_variance_scale
        self.final_w_init_variance_scale = 2.0 / self.num_hidden_layers
        self.init_std = init_std
        self.tie_word_embeddings = tie_word_embeddings

        # --- encoder coupling and box vocabulary ---
        self.encoder_hidden_size = encoder_hidden_size
        self.causal = causal
        self.max_width = max_width
        self.max_height = max_height
        self.max_position_embeddings = max_position_embeddings

        # The base initialiser runs last, exactly as before.
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            **kwargs,
        )

    @property
    def layers_block_type(self):
        """Block type per layer: ``block_types`` cycled and cut to ``num_hidden_layers``."""
        repeated = self.block_types * 100
        return repeated[: self.num_hidden_layers]

View File

@ -0,0 +1,795 @@
from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union
import torch
import torch.utils.checkpoint
from torch import nn
from transformers.utils import ModelOutput
from surya.model.table_rec.config import SuryaTableRecDecoderConfig, SuryaTableRecTextEncoderConfig
from transformers import PreTrainedModel
from transformers.activations import ACT2FN
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
from transformers.modeling_outputs import BaseModelOutputWithNoAttention, CausalLMOutput
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
from surya.settings import settings
# NOTE(review): not referenced anywhere in the visible part of this module;
# presumably a leftover from the implementation this file was adapted from.
# Confirm it is unused before removing.
_MAX_SQRT_GRADIENT = 1000.0
@dataclass
class TableRecModelOutput(ModelOutput):
    """Output container for the table-recognition decoder.

    Only ``bbox_logits`` is required; the other two fields default to ``None``.
    """

    # Box-coordinate logits.
    bbox_logits: torch.Tensor
    # Class logits, when produced.
    class_logits: Optional[torch.Tensor] = None
    # Hidden states the logits were computed from, when returned.
    hidden_states: Optional[torch.Tensor] = None
class SuryaTableRecDecoderRMSNorm(nn.Module):
    """RMS normalisation over the last dimension with a learned ``1 + weight`` gain.

    ``weight`` starts at zero, so a fresh module applies a gain of exactly 1.
    The computation runs in float32 and the result is cast back to the input dtype.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.zeros(dim))

    def _norm(self, x):
        """Divide ``x`` by the root-mean-square of its last dimension (``eps``-stabilised)."""
        mean_square = x.pow(2).mean(-1, keepdim=True)
        return x * torch.rsqrt(mean_square + self.eps)

    def forward(self, x):
        normalized = self._norm(x.float())
        # Llama does x.to(float16) * w whilst SuryaTableRecDecoder is (x * w).to(float16)
        # See https://github.com/huggingface/transformers/pull/29402
        gain = 1.0 + self.weight.float()
        scaled = normalized * gain
        return scaled.type_as(x)

    def extra_repr(self):
        shape = tuple(self.weight.shape)
        return f"{shape}, eps={self.eps}"
# Register the RMSNorm class in transformers' ALL_LAYERNORM_LAYERS list
# (presumably so utilities that special-case layer norms treat it as one -- TODO confirm).
ALL_LAYERNORM_LAYERS.append(SuryaTableRecDecoderRMSNorm)
class SuryaTableRecDecoderRotaryEmbedding(nn.Module):
    """Produces the cos/sin tables used by rotary position embeddings.

    Args:
        dim: size of the rotated (head) dimension; ``dim // 2`` frequencies are created.
        base: base of the geometric frequency progression.
        device: accepted for signature compatibility; not used.
    """

    def __init__(self, dim, base=10000, device=None):
        super().__init__()
        self.dim = dim
        self.base = base
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim))
        # Non-persistent: recomputed on construction, never stored in checkpoints.
        self.register_buffer("inv_freq", tensor=inv_freq, persistent=False)

    @torch.no_grad()
    # Copied from transformers.models.gemma.modeling_gemma.GemmaRotaryEmbedding.forward with Gemma->SuryaTableRecDecoder
    def forward(self, x, position_ids, seq_len=None):
        """Return ``(cos, sin)`` for ``position_ids``, cast to ``x.dtype``.

        Args:
            x: tensor used only for its device and dtype.
            position_ids: integer positions; indexed as ``[batch, seq_len]``.
            seq_len: unused.

        Returns:
            Two tensors shaped ``[batch, seq_len, dim]`` (the frequency table
            concatenated with itself along the last axis).
        """
        # x: [bs, num_attention_heads, seq_len, head_size]
        # ``Tensor.to`` is not in-place; the previous code discarded its result,
        # so the buffer was never actually moved. Use the returned tensor.
        inv_freq = self.inv_freq.to(x.device)
        inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
        position_ids_expanded = position_ids[:, None, :].float()

        freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
        emb = torch.cat((freqs, freqs), dim=-1)
        cos = emb.cos()
        sin = emb.sin()
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
# Copied from transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    split_at = x.shape[-1] // 2
    front, back = x[..., :split_at], x[..., split_at:]
    # The back half is negated and moved to the front.
    return torch.cat((-back, front), dim=-1)
# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    """Apply rotary position embedding to a query/key pair.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            Axis inserted into `cos` and `sin` so they broadcast against `q` and `k`.
            Use 1 when q/k are laid out as [batch_size, heads, seq_len, head_dim] and
            2 when they are [batch_size, seq_len, heads, head_dim].
    Returns:
        `tuple(torch.Tensor)`: the rotated query and the rotated key, in that order.
    """
    cos_b = cos.unsqueeze(unsqueeze_dim)
    sin_b = sin.unsqueeze(unsqueeze_dim)

    def _rotate(t):
        return (t * cos_b) + (rotate_half(t) * sin_b)

    return _rotate(q), _rotate(k)
# Copied from transformers.models.llama.modeling_llama.repeat_kv
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Repeat every key/value head ``n_rep`` times along the head axis, matching
    torch.repeat_interleave(x, dim=1, repeats=n_rep): (batch, num_key_value_heads,
    seqlen, head_dim) -> (batch, num_key_value_heads * n_rep, seqlen, head_dim).
    The input is returned untouched when ``n_rep`` is 1.
    """
    bsz, kv_heads, seq_len, dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # Insert a repeat axis after the head axis, broadcast it, then fold it into the heads.
    tiled = hidden_states.unsqueeze(2).expand(bsz, kv_heads, n_rep, seq_len, dim)
    return tiled.reshape(bsz, kv_heads * n_rep, seq_len, dim)
class SuryaTableRecDecoderSdpaCrossAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper
    Modified for GQA

    Queries come from the decoder hidden states; keys and values come from the
    encoder hidden states. The projected encoder keys/values can be cached on
    the module so they are computed only once per sequence.
    """

    def __init__(self, config: SuryaTableRecDecoderConfig):
        super().__init__()
        self.config = config
        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.num_attention_heads = config.num_attention_heads
        self.head_dim = config.head_dim
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_attention_heads // self.num_key_value_heads

        self.q_proj = nn.Linear(self.hidden_size, self.num_attention_heads * self.head_dim, bias=config.attention_bias)
        self.k_proj = nn.Linear(self.config.encoder_hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.v_proj = nn.Linear(self.config.encoder_hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.o_proj = nn.Linear(self.num_attention_heads * self.head_dim, self.hidden_size, bias=True)
        self.rotary_emb = SuryaTableRecDecoderRotaryEmbedding(
            self.head_dim,
            base=config.rope_theta,
        )

        # Start with an empty encoder K/V cache. forward() reads these
        # attributes unconditionally, so without this a call made before
        # _setup_cache() (e.g. any run with use_cache=False) raised AttributeError.
        self.key_states = None
        self.value_states = None

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        use_cache: bool = False,
    ) -> torch.Tensor:
        """Attend from ``hidden_states`` to ``encoder_hidden_states``.

        Args:
            hidden_states: decoder activations, ``(batch, q_len, hidden_size)``.
            encoder_hidden_states: encoder activations, ``(batch, v_len, encoder_hidden_size)``.
            attention_mask: accepted but not used.
            encoder_attention_mask: accepted but not used.
            use_cache: when the cache is empty, store the freshly projected
                encoder keys/values for later calls.

        Returns:
            Tensor of shape ``(batch, q_len, hidden_size)``.

        Note:
            A populated cache is used whenever it exists, regardless of
            ``use_cache``; call ``_setup_cache`` to clear it between sequences.
        """
        # Encoder attention mask currently ignored

        bsz, q_len, _ = hidden_states.size()
        _, v_len, _ = encoder_hidden_states.size()

        query_states = self.q_proj(hidden_states)
        query_states = query_states.view(bsz, q_len, self.num_attention_heads, self.head_dim).transpose(1, 2)

        if self.key_states is None:
            key_states = self.k_proj(encoder_hidden_states)
            value_states = self.v_proj(encoder_hidden_states)
            key_states = key_states.view(bsz, v_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
            value_states = value_states.view(bsz, v_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
            if use_cache:
                self._update_cache(key_states, value_states)
        else:
            key_states = self.key_states
            value_states = self.value_states

        # Expand the KV heads so every query head has a matching key/value head.
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states.contiguous(),
            key_states.contiguous(),
            value_states.contiguous(),
            attn_mask=None,
            dropout_p=self.attention_dropout if self.training else 0.0,
            scale=self.head_dim**-0.5,
        )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(bsz, q_len, self.hidden_size)
        attn_output = self.o_proj(attn_output)
        return attn_output

    def _setup_cache(self, batch_size, device, dtype=None):
        """Clear the cached encoder keys/values (the arguments are not used)."""
        # Setup initial caches
        self.value_states = None
        self.key_states = None

    @torch.no_grad()
    def _update_cache(self, key_states, value_states, **cache_kwargs):
        """Store the projected encoder keys/values for reuse by later calls."""
        self.value_states = value_states
        self.key_states = key_states
class SuryaTableRecDecoderSdpaAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper

    Self-attention with rotary position embeddings, grouped key/value heads
    and an optional per-module key/value cache (static or dynamic, selected by
    ``settings.RECOGNITION_STATIC_CACHE``).
    """

    def __init__(self, config: SuryaTableRecDecoderConfig):
        super().__init__()
        self.config = config
        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.num_attention_heads = config.num_attention_heads
        self.head_dim = config.head_dim
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_attention_heads // self.num_key_value_heads

        self.q_proj = nn.Linear(self.hidden_size, self.num_attention_heads * self.head_dim, bias=config.attention_bias)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.o_proj = nn.Linear(self.num_attention_heads * self.head_dim, self.hidden_size, bias=True)
        self.rotary_emb = SuryaTableRecDecoderRotaryEmbedding(
            self.head_dim,
            base=config.rope_theta,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        use_cache: bool = False,
        window_attn: bool = False,
    ) -> torch.Tensor:
        """Run self-attention over ``hidden_states``.

        Args:
            hidden_states: ``(batch, q_len, hidden_size)`` activations.
            position_ids: positions fed to the rotary embedding.
            attention_mask: additive mask indexed as ``(batch, head, q_len, kv_len)``;
                it is cropped to the current key length.
            cache_position: cache slots written by this call (static cache) and
                used to mask slots that have not been written yet.
            use_cache: read/write the key/value cache; only takes effect once
                ``_setup_cache`` has created the cache attributes.
            window_attn: forwarded to the cache update, which currently ignores it.

        Returns:
            Tensor of shape ``(batch, q_len, hidden_size)``.
        """
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # Final is bsz, num_attention_heads, seq_len, head_dim
        query_states = query_states.view(bsz, q_len, self.num_attention_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if use_cache and hasattr(self, "key_states"):
            cache_kwargs = {"cache_position": cache_position, "window_attn": window_attn}
            key_states, value_states = self._update_cache(key_states, value_states, **cache_kwargs)

        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        causal_mask = attention_mask
        if attention_mask is not None:
            # Mask is batch, head, seq_len, kv_len
            causal_mask = causal_mask[:, :, :, :key_states.shape[-2]]
            current_cache_position = cache_position[-1].item() if cache_position is not None else None
            # Compare against None explicitly: position 0 is a valid cache slot,
            # and a plain truthiness test skipped the masking for it.
            if current_cache_position is not None and settings.RECOGNITION_STATIC_CACHE:
                # Mask out future cache positions
                position_mask = torch.ones_like(causal_mask, dtype=torch.bool, device=causal_mask.device)
                position_mask[:, :, :, :current_cache_position + 1] = False
                causal_mask = torch.where(position_mask, torch.finfo(causal_mask.dtype).min, causal_mask)

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states.contiguous(),
            key_states.contiguous(),
            value_states.contiguous(),
            attn_mask=causal_mask,
            dropout_p=self.attention_dropout if self.training else 0.0,
            scale=self.head_dim**-0.5,
        )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(bsz, q_len, self.hidden_size)
        attn_output = self.o_proj(attn_output)
        return attn_output

    def _setup_cache(self, batch_size, device, dtype=None):
        """Reset the key/value cache.

        With the static cache enabled, zero-filled buffers of shape
        ``(batch_size, num_key_value_heads, RECOGNITION_MAX_TOKENS, head_dim)``
        are allocated; otherwise both entries are set to ``None``. ``dtype``
        falls back to ``config.torch_dtype`` and then to float32.
        """
        if dtype is None and self.config.torch_dtype is not None:
            dtype = self.config.torch_dtype
        dtype = dtype if dtype is not None else torch.float32

        # Setup initial caches
        self.value_states = None
        self.key_states = None

        if settings.RECOGNITION_STATIC_CACHE:
            cache_shape = (batch_size, self.num_key_value_heads, settings.RECOGNITION_MAX_TOKENS, self.head_dim)
            self.value_states = torch.zeros(cache_shape, dtype=dtype, device=device)
            self.key_states = torch.zeros(cache_shape, dtype=dtype, device=device)

    def _update_static_cache(self, key_states, value_states, **cache_kwargs):
        """Write the new keys/values into the slots at ``cache_position``; return the full buffers."""
        cache_position = cache_kwargs.get("cache_position")
        k_out, v_out = self.key_states.to(key_states.device), self.value_states.to(value_states.device)

        k_out[:, :, cache_position] = key_states.to(k_out.dtype)
        v_out[:, :, cache_position] = value_states.to(v_out.dtype)

        self.key_states, self.value_states = k_out, v_out
        return k_out, v_out

    def _update_dynamic_cache(self, key_states, value_states, **cache_kwargs):
        """Append the new keys/values along the sequence axis; return the grown tensors."""
        k_out = key_states
        if self.key_states is not None:
            k_out = torch.cat([self.key_states, key_states], dim=2)

        v_out = value_states
        if self.value_states is not None:
            v_out = torch.cat([self.value_states, value_states], dim=2)

        self.key_states, self.value_states = k_out, v_out
        return k_out, v_out

    @torch.no_grad()
    def _update_cache(self, key_states, value_states, **cache_kwargs):
        """Dispatch to the static or dynamic cache update depending on settings."""
        if settings.RECOGNITION_STATIC_CACHE:
            return self._update_static_cache(key_states, value_states, **cache_kwargs)

        return self._update_dynamic_cache(key_states, value_states, **cache_kwargs)
class SuryaTableRecDecoderMlp(nn.Module):
    """Gated feed-forward block: ``down(act(gate(x)) * up(x))`` with bias-free projections."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size

        narrow, wide = self.hidden_size, self.intermediate_size
        self.gate_proj = nn.Linear(narrow, wide, bias=False)
        self.up_proj = nn.Linear(narrow, wide, bias=False)
        self.down_proj = nn.Linear(wide, narrow, bias=False)

        # A missing activation name is filled in on the config object itself.
        if config.hidden_activation is None:
            config.hidden_activation = "gelu_pytorch_tanh"
        self.act_fn = ACT2FN[config.hidden_activation]

    def forward(self, x):
        gated = self.act_fn(self.gate_proj(x))
        projected = self.up_proj(x)
        return self.down_proj(gated * projected)
class SuryaTableRecDecoderLayer(nn.Module):
    """One decoder layer: optional cross-attention, optional self-attention, then an MLP.

    Which attention sub-blocks exist is decided by whether ``layer_idx`` appears
    in ``config.cross_attn_layers`` / ``config.self_attn_layers``. The layer is
    flagged as windowed (``window_attn``) when ``layer_idx`` is absent from
    ``config.global_attn_layers``.
    """

    def __init__(self, config, layer_idx):
        # Previously called twice in a row; once is sufficient.
        super().__init__()
        self.cross_pre_norm = SuryaTableRecDecoderRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.temporal_pre_norm = SuryaTableRecDecoderRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.temporal_block = None
        if layer_idx in config.self_attn_layers:
            self.temporal_block = SuryaTableRecDecoderSdpaAttention(config)

        self.cross_attn_block = None
        if layer_idx in config.cross_attn_layers:
            self.cross_attn_block = SuryaTableRecDecoderSdpaCrossAttention(config)

        self.window_attn = layer_idx not in config.global_attn_layers
        self.channel_pre_norm = SuryaTableRecDecoderRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.mlp_block = SuryaTableRecDecoderMlp(config)

    def forward(
        self,
        activations: torch.Tensor,
        position_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        encoder_hidden_states: torch.Tensor = None,
        encoder_attention_mask: torch.Tensor = None,
        cache_position: torch.Tensor = None,
        use_cache: bool = None,
    ) -> torch.Tensor:
        """Apply the layer and return the updated activations (a single tensor).

        Args:
            activations: layer input.
            position_ids: positions passed to the self-attention block.
            attention_mask: mask passed to both attention blocks.
            encoder_hidden_states: encoder output consumed by cross-attention.
            encoder_attention_mask: passed through to cross-attention.
            cache_position: cache slots passed to the self-attention block.
            use_cache: forwarded to both attention blocks.
        """
        raw_activations = activations
        if self.cross_attn_block is not None:
            # Do cross-attention on encoder outputs
            cross_attn_inputs = self.cross_pre_norm(activations)
            cross_attn_path = self.cross_attn_block(
                cross_attn_inputs, encoder_hidden_states, attention_mask, encoder_attention_mask, use_cache=use_cache
            )
            cross_attn_output = cross_attn_path + raw_activations
        else:
            cross_attn_output = raw_activations

        if self.temporal_block is not None:
            inputs_normalized = self.temporal_pre_norm(cross_attn_output)  # RMSNorm introduces slight differences
            hidden_states = self.temporal_block(
                inputs_normalized, position_ids, attention_mask, cache_position=cache_position, use_cache=use_cache, window_attn=self.window_attn
            )

            # NOTE: the skip connection adds the layer *input*, not the
            # cross-attention output. Preserved exactly as written; changing it
            # would alter the function computed by the layer.
            residual = hidden_states + raw_activations
        else:
            residual = cross_attn_output

        hidden_states = self.channel_pre_norm(residual)
        hidden_states = self.mlp_block(hidden_states)

        hidden_states = hidden_states + residual
        return hidden_states
class SuryaTableRecDecoderPreTrainedModel(PreTrainedModel):
    """Shared base class: weight initialisation and cache set-up for the decoder models.

    Weight tying is disabled: both ``_tie_weights`` and ``tie_weights`` are no-ops.
    """

    config_class = SuryaTableRecDecoderConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["SuryaTableRecDecoderLayer"]
    _skip_keys_device_placement = ["cache"]
    _supports_flash_attn_2 = False
    _supports_sdpa = False  # we can't compare with eager for now
    _supports_cache_class = True
    _supports_quantized_cache = True

    def _init_weights(self, module):
        """Initialise ``module`` from a zero-mean normal with std ``config.init_std``."""
        std = self.config.init_std

        if isinstance(module, SuryaTableRecDecoderSdpaAttention):
            # Projection weights only, in q, k, v, o order.
            for projection in (module.q_proj, module.k_proj, module.v_proj, module.o_proj):
                torch.nn.init.normal_(projection.weight, mean=0.0, std=std)
            return

        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
            if getattr(module, "bias", None) is not None:
                torch.nn.init.zeros_(module.bias)
            return

        if isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    def _setup_cache(self, config, batch, device, dtype):
        """Reset the attention caches of every layer that has the corresponding block."""
        owner = getattr(self, "model", self)
        for layer in owner.layers:
            if layer.temporal_block:
                layer.temporal_block._setup_cache(batch, device, dtype)
            if layer.cross_attn_block:
                layer.cross_attn_block._setup_cache(batch, device, dtype)

    def reset_cache(self, batch, device, dtype):
        """Intentionally does nothing."""
        pass

    def _tie_weights(self):
        """Intentionally does nothing."""
        pass

    def tie_weights(self):
        """Intentionally does nothing."""
        pass
class LabelEmbedding(nn.Module):
    """Embed ``(cx, cy, w, h, class)`` label rows as a sum of nine learned lookups.

    The corner coordinates are derived from the centre and size, every index is
    clamped to the valid range of its table, and the nine embeddings are added.
    """

    def __init__(self, config):
        super().__init__()
        dim = config.hidden_size
        self.vocab_size = config.vocab_size
        # Tables are created in a fixed order so seeded initialisation is reproducible.
        self.x1_embed = nn.Embedding(config.max_width, dim)
        self.y1_embed = nn.Embedding(config.max_height, dim)
        self.x2_embed = nn.Embedding(config.max_width, dim)
        self.y2_embed = nn.Embedding(config.max_height, dim)
        self.w_embed = nn.Embedding(config.max_width, dim)
        self.h_embed = nn.Embedding(config.max_height, dim)
        self.cx_embed = nn.Embedding(config.max_width, dim)
        self.cy_embed = nn.Embedding(config.max_height, dim)
        self.class_embed = nn.Embedding(config.max_classes, dim)
        self.max_width = config.max_width
        self.max_height = config.max_height
        self.max_classes = config.max_classes

    def forward(self, labels: torch.LongTensor, input_box_counts: torch.LongTensor):
        """Return the summed embedding for each label row.

        ``labels`` carries five values on its last axis; ``input_box_counts``
        is accepted but not used. Output shape is
        (batch_size, num_boxes/seq len, d_model).
        """
        cx, cy, w, h, class_ = labels.to(torch.long).unbind(dim=-1)
        half_w = w // 2
        half_h = h // 2
        x_hi = self.max_width - 1
        y_hi = self.max_height - 1
        class_hi = self.max_classes - 1

        # (table, raw index, upper bound) in the order the terms are summed.
        lookups = (
            (self.x1_embed, cx - half_w, x_hi),
            (self.y1_embed, cy - half_h, y_hi),
            (self.x2_embed, cx + half_w, x_hi),
            (self.y2_embed, cy + half_h, y_hi),
            (self.w_embed, w, x_hi),
            (self.h_embed, h, y_hi),
            (self.cx_embed, cx, x_hi),
            (self.cy_embed, cy, y_hi),
            (self.class_embed, class_, class_hi),
        )

        embedded = None
        for table, index, upper in lookups:
            term = table(torch.clamp(index.long(), 0, upper).long())
            embedded = term if embedded is None else embedded + term
        return embedded
class BboxEmbedding(nn.Module):
    """Embed ``(x1, y1, x2, y2)`` boxes as a sum of eight learned lookups.

    Width, height and centre are derived from the clamped corners. When
    ``embed_positions`` is set, a learned per-box position embedding is added
    to the span of boxes described by ``input_box_counts``.
    """

    def __init__(self, config, embed_positions=False):
        super().__init__()
        dim = config.hidden_size
        # Tables are created in a fixed order so seeded initialisation is reproducible.
        self.x1_embed = nn.Embedding(config.max_width, dim)
        self.y1_embed = nn.Embedding(config.max_height, dim)
        self.x2_embed = nn.Embedding(config.max_width, dim)
        self.y2_embed = nn.Embedding(config.max_height, dim)
        self.w_embed = nn.Embedding(config.max_width, dim)
        self.h_embed = nn.Embedding(config.max_height, dim)
        self.cx_embed = nn.Embedding(config.max_width, dim)
        self.cy_embed = nn.Embedding(config.max_height, dim)
        self.box_pos_embed = nn.Embedding(config.max_position_embeddings, dim)
        self.max_width = config.max_width
        self.max_height = config.max_height
        self.embed_positions = embed_positions

    def forward(self, boxes: torch.LongTensor, input_box_counts: torch.LongTensor):
        """Return the summed embedding for each box.

        ``boxes`` carries four values on its last axis. ``input_box_counts`` is
        read only when ``embed_positions`` is set: for batch row ``j``, column 0
        is the first box index and column 1 minus one is the end index.
        Output shape is (batch_size, num_boxes/seq len, d_model).
        """
        x_hi = self.max_width - 1
        y_hi = self.max_height - 1

        raw_x1, raw_y1, raw_x2, raw_y2 = boxes.unbind(dim=-1)
        x1 = torch.clamp(raw_x1, 0, x_hi).long()
        y1 = torch.clamp(raw_y1, 0, y_hi).long()
        x2 = torch.clamp(raw_x2, 0, x_hi).long()
        y2 = torch.clamp(raw_y2, 0, y_hi).long()

        # Size and centre come from the already-clamped corners.
        w = torch.clamp(x2 - x1, 0, x_hi).long()
        h = torch.clamp(y2 - y1, 0, y_hi).long()
        # Center x and y in torch long tensors
        cx = torch.clamp(((x1 + x2) / 2).long(), 0, x_hi).long()
        cy = torch.clamp(((y1 + y2) / 2).long(), 0, y_hi).long()

        embedded = self.x1_embed(x1) + self.y1_embed(y1) + self.x2_embed(x2) + self.y2_embed(y2)
        embedded = embedded + self.w_embed(w) + self.h_embed(h) + self.cx_embed(cx) + self.cy_embed(cy)

        if not self.embed_positions:
            return embedded

        # Add in positional embeddings for the boxes and labels
        position_table = self.box_pos_embed.weight
        for j in range(embedded.shape[0]):
            first = input_box_counts[j, 0]
            last = input_box_counts[j, 1] - 1  # Skip the sep token
            span = last - first
            embedded[j, first:last] = embedded[j, first:last] + position_table[:span]
        return embedded
class SuryaTableRecDecoderModel(SuryaTableRecDecoderPreTrainedModel):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`SuryaTableRecDecoderLayer`]

    Args:
        config: SuryaTableRecDecoderConfig
        embed_labels: use `LabelEmbedding` for the inputs instead of `BboxEmbedding`.
        embed_positions: passed to `BboxEmbedding` when that embedding is used.
    """

    def __init__(self, config: SuryaTableRecDecoderConfig, embed_labels=False, embed_positions=True):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.causal = config.causal

        if embed_labels:
            self.embed_tokens = LabelEmbedding(config)
        else:
            self.embed_tokens = BboxEmbedding(config, embed_positions=embed_positions)

        decoder_layers = [
            SuryaTableRecDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)
        ]
        self.layers = nn.ModuleList(decoder_layers)
        self.final_norm = SuryaTableRecDecoderRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.gradient_checkpointing = False

        self.register_buffer(
            "normalizer", torch.tensor(self.config.hidden_size**0.5, dtype=torch.float32), persistent=False
        )
        # Initialize weights and apply final processing
        self.post_init()

    # Copied from transformers.models.llama.modeling_llama.LlamaModel.get_input_embeddings
    def get_input_embeddings(self):
        return self.embed_tokens

    # Copied from transformers.models.llama.modeling_llama.LlamaModel.set_input_embeddings
    def set_input_embeddings(self, value):
        self.embed_tokens = value

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        input_boxes_counts: torch.LongTensor = None,
        position_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        prefill: bool = False
    ) -> Union[Tuple, BaseModelOutputWithNoAttention]:
        """Embed the inputs, run every decoder layer and apply the final norm.

        ``use_cache`` and ``return_dict`` default to the config values. Caches
        are (re)initialised only when ``use_cache`` and ``prefill`` are both
        set. ``cache_position`` defaults to ``arange(seq_len)`` and
        ``position_ids`` to ``cache_position`` with a leading batch axis.
        """
        if use_cache is None:
            use_cache = self.config.use_cache
        if return_dict is None:
            return_dict = self.config.use_return_dict

        checkpointing = self.gradient_checkpointing and self.training
        if checkpointing and use_cache:
            use_cache = False

        inputs_embeds = self.embed_tokens(input_ids, input_boxes_counts)
        hidden_states = inputs_embeds

        if use_cache and prefill:
            self._setup_cache(self.config, hidden_states.shape[0], hidden_states.device, hidden_states.dtype)

        if cache_position is None:
            cache_position = torch.arange(hidden_states.shape[1], device=hidden_states.device)
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)

        all_hidden_states = () if output_hidden_states else None
        for residual_block in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)
            layer_args = (
                hidden_states, position_ids, causal_mask, encoder_hidden_states,
                encoder_attention_mask, cache_position, use_cache,
            )
            if checkpointing:
                hidden_states = self._gradient_checkpointing_func(residual_block.__call__, *layer_args)
            else:
                hidden_states = residual_block(*layer_args)

        hidden_states = self.final_norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states] if v is not None)

        return BaseModelOutputWithNoAttention(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
        )

    # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
    # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
    # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
    # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
    # Ignore copy
    def _update_causal_mask(self, attention_mask, input_tensor, cache_position):
        """Build the additive mask handed to the layers, or ``None`` for a non-causal model.

        The result is shaped ``(batch, 1, seq_len, target_len)`` where
        ``target_len = max(settings.TABLE_REC_MAX_BOXES, seq_len)``. Masked
        entries hold the most negative value of the input dtype.
        """
        if not self.causal:
            return None

        dtype, device = input_tensor.dtype, input_tensor.device
        min_dtype = torch.finfo(dtype).min
        sequence_length = input_tensor.shape[1]
        target_length = max(settings.TABLE_REC_MAX_BOXES, sequence_length)

        causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
        if sequence_length != 1:
            # Select the upper triangular part of the matrix, but unmask current token (the diagonal)
            # triu will be the min_dtype, everything else is 0 (attended to)
            causal_mask = torch.triu(causal_mask, diagonal=1)

        # Keep the mask only on key slots that lie beyond each query's cache position.
        causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
        causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)

        if attention_mask is None:
            return causal_mask

        causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
        if attention_mask.dim() == 2:
            # Mask positions in the causal mask that are masked in the attention mask
            mask_length = attention_mask.shape[-1]
            visible = causal_mask[..., :mask_length]
            padding_mask = visible.eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
            causal_mask[..., :mask_length] = visible.masked_fill(padding_mask, min_dtype)

        if attention_mask.device.type == "cuda":
            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
            # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
            # Details: https://github.com/pytorch/pytorch/issues/110213
            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)

        return causal_mask
class SuryaTableRecDecoder(SuryaTableRecDecoderPreTrainedModel):
    """Decoder with two linear heads: bounding-box logits and class logits.

    Wraps a ``SuryaTableRecDecoderModel`` and projects its hidden states to
    ``4 * max_width`` box logits and ``max_classes`` class logits.
    """

    _tied_weights_keys = None

    def __init__(self, config, **kwargs):
        super().__init__(config)
        self.model = SuryaTableRecDecoderModel(config, embed_labels=True, embed_positions=False)
        self.vocab_size = config.vocab_size
        self.bbox_head = nn.Linear(config.hidden_size, config.max_width * 4, bias=False)
        self.class_head = nn.Linear(config.hidden_size, config.max_classes, bias=False)
        self.max_width = config.max_width

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        # __init__ creates no ``lm_head`` (the outputs come from bbox_head /
        # class_head), so a plain attribute access would raise AttributeError.
        # Return None unless a head was installed via set_output_embeddings.
        return getattr(self, "lm_head", None)

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    # Ignore copy
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        prefill: bool = False,
        **kwargs
    ) -> Union[Tuple, TableRecModelOutput]:
        """Run the wrapped decoder and project its hidden states.

        Extra ``**kwargs`` are accepted but not forwarded to the wrapped model.

        Returns:
            TableRecModelOutput with ``bbox_logits`` viewed as
            ``(batch, seq_len, 4, max_width)``, ``class_logits`` and the
            decoder ``hidden_states``.
        """
        outputs = self.model(
            input_ids=input_ids,
            cache_position=cache_position,
            attention_mask=attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            use_cache=use_cache,
            output_hidden_states=True,
            return_dict=True,
            prefill=prefill,
        )

        hidden_states = outputs[0]
        bbox_logits = self.bbox_head(hidden_states)
        class_logits = self.class_head(hidden_states)
        bsz, seq_len = class_logits.shape[:2]
        # Split the flat box projection into one max_width-way distribution
        # per box coordinate.
        bbox_logits = bbox_logits.view(bsz, seq_len, 4, self.max_width)

        return TableRecModelOutput(
            bbox_logits=bbox_logits,
            class_logits=class_logits,
            hidden_states=hidden_states,
        )
@dataclass
class TextEncoderOutput(CausalLMOutput):
    """Output container returned by ``SuryaTableRecTextEncoder.forward``."""

    # Filled with the wrapped model's ``last_hidden_state``.
    hidden_states: torch.FloatTensor = None
class SuryaTableRecTextEncoder(SuryaTableRecDecoderPreTrainedModel):
    """Encoder over input boxes built on ``SuryaTableRecDecoderModel``.

    The wrapped model is created with position embeddings enabled and label
    embeddings disabled.
    """

    _tied_weights_keys = None
    config_class = SuryaTableRecTextEncoderConfig

    def __init__(self, config, **kwargs):
        super().__init__(config)
        self.model = SuryaTableRecDecoderModel(config, embed_labels=False, embed_positions=True)
        self.vocab_size = config.vocab_size

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    # Ignore copy
    def forward(
        self,
        input_boxes: Optional[torch.LongTensor] = None,
        input_boxes_counts: Optional[torch.LongTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        **kwargs
    ) -> Union[Tuple, CausalLMOutput]:
        """Encode the input boxes and return the final hidden states.

        ``input_boxes`` is passed to the wrapped model as ``input_ids``.
        Extra ``**kwargs`` are accepted but not forwarded.
        """
        encoded = self.model(
            input_ids=input_boxes,
            input_boxes_counts=input_boxes_counts,
            cache_position=cache_position,
            attention_mask=attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            use_cache=use_cache,
            output_hidden_states=True,
            return_dict=True,
        )
        final_states = encoded.last_hidden_state
        return TextEncoderOutput(hidden_states=final_states)

View File

@ -0,0 +1,135 @@
import random
from dataclasses import dataclass
from typing import Optional, Union, Tuple
import torch
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers import PreTrainedModel, VisionEncoderDecoderConfig, PretrainedConfig
from transformers.modeling_outputs import Seq2SeqLMOutput, BaseModelOutput
from transformers.models.vision_encoder_decoder.modeling_vision_encoder_decoder import shift_tokens_right
from surya.model.table_rec.decoder import SuryaTableRecTextEncoder, SuryaTableRecDecoder
from surya.model.recognition.encoder import DonutSwinModel
import torch.nn.functional as F
from transformers.utils import ModelOutput
@dataclass
class TableRecOutput(ModelOutput):
    """Output container returned by ``TableRecEncoderDecoderModel.forward``."""

    # Copied from the decoder output's ``row_logits``.
    row_logits: torch.FloatTensor = None
    # Copied from the decoder output's ``col_logits``.
    col_logits: torch.FloatTensor = None
    # Copied from the decoder output's ``hidden_states``.
    decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
class TableRecEncoderDecoderModel(PreTrainedModel):
    """Container model holding an image encoder, a text (box) encoder and a decoder.

    Sub-models that are not passed in are built from the matching sub-config
    of ``config``.
    """

    config_class = VisionEncoderDecoderConfig
    base_model_prefix = "vision_encoder_decoder"
    main_input_name = "pixel_values"
    supports_gradient_checkpointing = True
    _supports_param_buffer_assignment = False

    def __init__(
        self,
        config: Optional[PretrainedConfig] = None,
        encoder: Optional[PreTrainedModel] = None,
        text_encoder: Optional[PreTrainedModel] = None,
        decoder: Optional[PreTrainedModel] = None,
    ):
        """Build the model.

        Raises:
            ValueError: if ``config`` is None. The signature allows None, but
                the config is required to build and wire the sub-models.
        """
        if config is None:
            raise ValueError("A config is required to build TableRecEncoderDecoderModel.")

        # initialize with config
        # make sure input & output embeddings is not tied
        config.tie_word_embeddings = False
        config.decoder.tie_word_embeddings = False
        super().__init__(config)

        if encoder is None:
            encoder = DonutSwinModel(config.encoder)

        if text_encoder is None:
            text_encoder = SuryaTableRecTextEncoder(config.text_encoder, attn_implementation=config._attn_implementation)

        if decoder is None:
            decoder = SuryaTableRecDecoder(config.decoder, attn_implementation=config._attn_implementation)

        self.encoder = encoder
        self.decoder = decoder
        self.text_encoder = text_encoder

        # make sure that the individual model's config refers to the shared config
        # so that the updates to the config will be synced
        self.encoder.config = self.config.encoder
        self.decoder.config = self.config.decoder
        self.text_encoder.config = self.config.text_encoder

    def get_encoder(self):
        return self.encoder

    def get_decoder(self):
        return self.decoder

    def get_output_embeddings(self):
        return self.decoder.get_output_embeddings()

    def set_output_embeddings(self, new_embeddings):
        return self.decoder.set_output_embeddings(new_embeddings)

    def forward(
        self,
        decoder_input_ids: torch.LongTensor = None,
        decoder_cache_position: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.LongTensor] = None,
        encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None,
        use_cache: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple[torch.FloatTensor], TableRecOutput]:
        """Run the decoder on precomputed encoder outputs.

        Keyword arguments prefixed with ``decoder_`` are forwarded to the
        decoder with the prefix removed; all other extra keyword arguments are
        ignored. The encoders are not invoked here.
        """
        prefix = "decoder_"
        kwargs_decoder = {
            argument[len(prefix):]: value for argument, value in kwargs.items() if argument.startswith(prefix)
        }

        # Decode
        # NOTE(review): SuryaTableRecDecoder.forward in
        # surya.model.table_rec.decoder appears to take ``input_ids`` and
        # return bbox_logits / class_logits, while this call passes
        # ``input_labels`` and reads row_logits / col_logits below. Confirm
        # which decoder interface this wrapper is meant to drive.
        decoder_outputs = self.decoder(
            input_labels=decoder_input_ids,
            input_boxes_counts=None,
            cache_position=decoder_cache_position,
            attention_mask=decoder_attention_mask,
            encoder_hidden_states=encoder_outputs,
            encoder_attention_mask=None,
            use_cache=use_cache,
            **kwargs_decoder,
        )

        return TableRecOutput(
            row_logits=decoder_outputs.row_logits,
            col_logits=decoder_outputs.col_logits,
            decoder_hidden_states=decoder_outputs.hidden_states,
        )

    def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
        return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs
    ):
        """Assemble the forward() keyword arguments for one generation step."""
        decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids, past_key_values=past_key_values)
        decoder_attention_mask = decoder_inputs["attention_mask"] if "attention_mask" in decoder_inputs else None
        input_dict = {
            "attention_mask": attention_mask,
            "decoder_attention_mask": decoder_attention_mask,
            "decoder_input_ids": decoder_inputs["input_ids"],
            "encoder_outputs": encoder_outputs,
            "past_key_values": decoder_inputs["past_key_values"],
            "use_cache": use_cache,
        }
        return input_dict

    def resize_token_embeddings(self, *args, **kwargs):
        raise NotImplementedError(
            "Resizing the embedding layers via the VisionEncoderDecoderModel directly is not supported.Please use the"
            " respective methods of the wrapped decoder object (model.decoder.resize_token_embeddings(...))"
        )

    def _reorder_cache(self, past_key_values, beam_idx):
        # apply decoder cache reordering here
        return self.decoder._reorder_cache(past_key_values, beam_idx)

View File

@ -0,0 +1,34 @@
from surya.model.recognition.encoder import DonutSwinModel
from surya.model.table_rec.config import SuryaTableRecConfig, SuryaTableRecDecoderConfig, DonutSwinTableRecConfig, \
SuryaTableRecTextEncoderConfig
from surya.model.table_rec.decoder import SuryaTableRecDecoder, SuryaTableRecTextEncoder
from surya.model.table_rec.encoderdecoder import TableRecEncoderDecoderModel
from surya.settings import settings
def load_model(checkpoint=settings.TABLE_REC_MODEL_CHECKPOINT, device=settings.TORCH_DEVICE_MODEL, dtype=settings.MODEL_DTYPE):
    """Load the table recognition model, move it to ``device`` and set eval mode.

    Args:
        checkpoint: checkpoint identifier passed to ``from_pretrained``.
        device: target device for the loaded model.
        dtype: passed to ``from_pretrained`` as ``torch_dtype``.

    Returns:
        The loaded ``TableRecEncoderDecoderModel`` in eval mode.
    """
    config = SuryaTableRecConfig.from_pretrained(checkpoint)

    # The loaded sub-configs are unpacked with ** into their typed config
    # classes, so the sub-models are built from the right config type.
    config.decoder = SuryaTableRecDecoderConfig(**config.decoder)
    config.encoder = DonutSwinTableRecConfig(**config.encoder)
    config.text_encoder = SuryaTableRecTextEncoderConfig(**config.text_encoder)

    model = TableRecEncoderDecoderModel.from_pretrained(checkpoint, config=config, torch_dtype=dtype)

    # Sanity-check that the expected sub-model classes were instantiated.
    assert isinstance(model.decoder, SuryaTableRecDecoder)
    assert isinstance(model.encoder, DonutSwinModel)
    assert isinstance(model.text_encoder, SuryaTableRecTextEncoder)

    model = model.to(device)
    model = model.eval()

    # This loader handles the table recognition checkpoint, so the message says so.
    print(f"Loaded table recognition model {checkpoint} on device {device} with dtype {dtype}")
    return model

View File

@ -0,0 +1,248 @@
import math
from typing import Dict, Union, Optional, List, Iterable
import cv2
import torch
from torch import TensorType
from transformers import DonutImageProcessor, DonutProcessor
from transformers.image_processing_utils import BatchFeature
from transformers.image_transforms import pad, normalize
from transformers.image_utils import PILImageResampling, ImageInput, ChannelDimension, make_list_of_images, get_image_size
import numpy as np
from PIL import Image
import PIL
from surya.model.recognition.tokenizer import Byt5LangTokenizer
from surya.settings import settings
from surya.model.table_rec.config import BOX_DIM, SPECIAL_TOKENS
def load_processor():
    """Create a ``SuryaProcessor`` configured for table recognition inference."""
    processor = SuryaProcessor()

    # Image side: inference mode, resized to the table recognition input size.
    image_processor = processor.image_processor
    image_processor.train = False
    image_processor.max_size = settings.TABLE_REC_IMAGE_SIZE

    # Special token ids used when encoding the input boxes.
    special_ids = {
        "token_pad_id": 0,
        "token_eos_id": 1,
        "token_bos_id": 2,
        "token_row_id": 3,
        "token_unused_id": 4,
    }
    for attr_name, token_id in special_ids.items():
        setattr(processor, attr_name, token_id)

    processor.box_size = (BOX_DIM, BOX_DIM)
    processor.special_token_count = SPECIAL_TOKENS
    return processor
class SuryaImageProcessor(DonutImageProcessor):
    """Image preprocessor: OpenCV resize to ``max_size``, then rescale and normalize.

    ``max_size`` is a dict with ``"height"`` and ``"width"`` keys and must be
    set before ``preprocess`` is called.
    """

    def __init__(self, *args, max_size=None, train=False, **kwargs):
        super().__init__(*args, **kwargs)

        self.patch_size = kwargs.get("patch_size", (4, 4))
        self.max_size = max_size
        self.train = train

    @classmethod
    def numpy_resize(cls, image: np.ndarray, size, interpolation=cv2.INTER_LANCZOS4):
        """Resize a channel-last image to ``size`` and return it channel-first."""
        max_width, max_height = size["width"], size["height"]

        resized_image = cv2.resize(image, (max_width, max_height), interpolation=interpolation)
        resized_image = resized_image.transpose(2, 0, 1)

        return resized_image

    def process_inner(self, images: List[np.ndarray]):
        """Resize, pad, rescale and normalize a list of RGB arrays.

        Returns a list of channel-first float32 arrays.
        """
        assert images[0].shape[2] == 3  # RGB input images, channel dim last

        # This also applies the right channel dim format, to channel x height x width
        # NOTE(review): self.resample is presumably the PIL resampling setting
        # inherited from DonutImageProcessor, but it is handed to cv2.resize as
        # an interpolation flag - confirm the two enums are meant to line up.
        images = [SuryaImageProcessor.numpy_resize(img, self.max_size, self.resample) for img in images]
        assert images[0].shape[0] == 3  # RGB input images, channel dim first

        # Convert to float32 for rescale/normalize
        images = [img.astype(np.float32) for img in images]

        # Pad to max size with the configured pad value. The resize above
        # already produces exactly max_size, so this adds no pixels here.
        max_size = self.max_size
        images = [
            SuryaImageProcessor.pad_image(
                image=image,
                size=max_size,
                input_data_format=ChannelDimension.FIRST,
                pad_value=settings.RECOGNITION_PAD_VALUE
            )
            for image in images
        ]

        # Rescale and normalize
        images = [img * self.rescale_factor for img in images]
        images = [
            SuryaImageProcessor.normalize(img, mean=self.image_mean, std=self.image_std, input_data_format=ChannelDimension.FIRST)
            for img in images
        ]

        return images

    def preprocess(
        self,
        images: ImageInput,
        do_resize: bool = None,
        size: Dict[str, int] = None,
        resample: PILImageResampling = None,
        do_thumbnail: bool = None,
        do_align_long_axis: bool = None,
        do_pad: bool = None,
        random_padding: bool = False,
        do_rescale: bool = None,
        rescale_factor: float = None,
        do_normalize: bool = None,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs,
    ) -> BatchFeature:
        """Preprocess one image or a list of images.

        Only ``images`` and ``return_tensors`` are used; the other parameters
        are accepted and ignored.

        Returns:
            BatchFeature with a single ``"pixel_values"`` entry.
        """
        images = make_list_of_images(images)

        # Convert to numpy for later processing steps
        images = [np.array(img) for img in images]
        images = self.process_inner(images)

        data = {"pixel_values": images}
        return BatchFeature(data=data, tensor_type=return_tensors)

    @classmethod
    def pad_image(
        cls,
        image: np.ndarray,
        size: Dict[str, int],
        data_format: Optional[Union[str, ChannelDimension]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        pad_value: float = 0.0,
    ) -> np.ndarray:
        """Pad ``image`` to ``size`` with ``pad_value``, keeping it centered.

        The image must not be larger than ``size`` in either dimension.
        """
        output_height, output_width = size["height"], size["width"]
        input_height, input_width = get_image_size(image, channel_dim=input_data_format)

        delta_width = output_width - input_width
        delta_height = output_height - input_height

        assert delta_width >= 0 and delta_height >= 0

        # Any odd leftover pixel goes to the bottom / right side.
        pad_top = delta_height // 2
        pad_left = delta_width // 2

        pad_bottom = delta_height - pad_top
        pad_right = delta_width - pad_left

        padding = ((pad_top, pad_bottom), (pad_left, pad_right))
        return pad(image, padding, data_format=data_format, input_data_format=input_data_format, constant_values=pad_value)

    @classmethod
    def align_long_axis(
        cls,
        image: np.ndarray,
        size: Dict[str, int],
        data_format: Optional[Union[str, ChannelDimension]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
    ) -> np.ndarray:
        """Rotate ``image`` with ``np.rot90(image, 3)`` when its orientation differs from ``size``.

        Reads height and width from the first two axes of ``image``.
        """
        input_height, input_width = image.shape[:2]
        output_height, output_width = size["height"], size["width"]

        if (output_width < output_height and input_width > input_height) or (
            output_width > output_height and input_width < input_height
        ):
            image = np.rot90(image, 3)

        return image

    @classmethod
    def normalize(
        cls,
        image: np.ndarray,
        mean: Union[float, Iterable[float]],
        std: Union[float, Iterable[float]],
        data_format: Optional[Union[str, ChannelDimension]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs,
    ) -> np.ndarray:
        """Delegate to ``transformers.image_transforms.normalize``."""
        return normalize(
            image, mean=mean, std=std, data_format=data_format, input_data_format=input_data_format, **kwargs
        )
class SuryaProcessor(DonutProcessor):
    """Processor that prepares images and cell boxes for table recognition.

    ``box_size``, ``special_token_count``, ``token_row_id`` and
    ``token_unused_id`` are not set in ``__init__``; they must be assigned
    before ``resize_boxes`` or ``__call__`` are used.
    """

    def __init__(self, image_processor=None, tokenizer=None, train=False, **kwargs):
        # NOTE: the ``image_processor`` and ``tokenizer`` arguments are always
        # replaced by the defaults constructed below.
        image_processor = SuryaImageProcessor.from_pretrained(settings.RECOGNITION_MODEL_CHECKPOINT)
        if image_processor is None:
            raise ValueError("You need to specify an `image_processor`.")

        tokenizer = Byt5LangTokenizer()
        super().__init__(image_processor, tokenizer)
        self.current_processor = self.image_processor
        self._in_target_context_manager = False
        self.max_input_boxes = kwargs.get("max_input_boxes", 256)
        self.extra_input_boxes = kwargs.get("extra_input_boxes", 32)

    def resize_boxes(self, img, boxes):
        """Rescale ``boxes`` in place from image pixels to the ``box_size`` grid.

        The first two coordinates are clamped at 0 and the last two at the
        grid size. Returns the same ``boxes`` list.
        """
        width, height = img.size
        box_width, box_height = self.box_size
        for box in boxes:
            # Rescale from image pixel coordinates to the box_size grid
            box[0] = box[0] / width * box_width
            box[1] = box[1] / height * box_height
            box[2] = box[2] / width * box_width
            box[3] = box[3] / height * box_height

            if box[0] < 0:
                box[0] = 0
            if box[1] < 0:
                box[1] = 0
            if box[2] > box_width:
                box[2] = box_width
            if box[3] > box_height:
                box[3] = box_height

        return boxes

    def __call__(self, *args, **kwargs):
        """Build model inputs from images and per-image box lists.

        ``images`` may be given as a keyword argument or as the first
        positional argument; ``boxes`` is a keyword argument. ``boxes`` is
        modified in place, so pass a copy if the original must be kept.

        Returns:
            The image processor output extended with ``input_boxes``,
            ``input_boxes_mask`` and ``input_boxes_counts``.
        """
        images = kwargs.pop("images", [])
        boxes = kwargs.pop("boxes", [])

        # Resolve positional images BEFORE validating. Previously the length
        # check ran against the empty keyword default and rejected every
        # positional call that also passed boxes.
        if len(args) > 0:
            images = args[0]
            args = args[1:]

        assert len(images) == len(boxes)

        # Keep at most max_input_boxes per image by taking every n-th box.
        for i, image_boxes in enumerate(boxes):
            if len(image_boxes) > self.max_input_boxes:
                downsample_ratio = math.ceil(len(image_boxes) / self.max_input_boxes)
                boxes[i] = image_boxes[::downsample_ratio]

        new_boxes = []
        max_len = self.max_input_boxes + self.extra_input_boxes
        box_masks = []
        box_ends = []
        for image, image_boxes in zip(images, boxes):
            nb = self.resize_boxes(image, image_boxes)
            nb = [[b + self.special_token_count for b in box] for box in nb]  # shift up
            nb = nb[:self.max_input_boxes - 1]
            nb.insert(0, [self.token_row_id] * 4)  # Insert special token for max rows/cols
            nb.extend([self.token_unused_id] * 4 for _ in range(self.extra_input_boxes))

            pad_length = max_len - len(nb)
            # NOTE(review): padding positions are also marked 1, so this mask
            # is all ones. Kept as-is; confirm this is what the model expects.
            box_mask = [1] * len(nb) + [1] * (pad_length)
            box_ends.append(len(nb))
            nb = nb + [[self.token_unused_id] * 4] * pad_length

            new_boxes.append(nb)
            box_masks.append(box_mask)

        box_ends = torch.tensor(box_ends, dtype=torch.long)
        box_starts = torch.tensor([0] * len(boxes), dtype=torch.long)
        box_ranges = torch.stack([box_starts, box_ends], dim=1)

        inputs = self.image_processor(images, *args, **kwargs)
        inputs["input_boxes"] = torch.tensor(new_boxes, dtype=torch.long)
        inputs["input_boxes_mask"] = torch.tensor(box_masks, dtype=torch.long)
        inputs["input_boxes_counts"] = box_ranges
        return inputs

View File

@ -60,17 +60,24 @@ def run_recognition(images: List[Image.Image], langs: List[List[str] | None], re
return predictions_by_image
def run_ocr(images: List[Image.Image], langs: List[List[str] | None], det_model, det_processor, rec_model, rec_processor, batch_size=None) -> List[OCRResult]:
def run_ocr(images: List[Image.Image], langs: List[List[str] | None], det_model, det_processor, rec_model, rec_processor, batch_size=None, highres_images: List[Image.Image] | None = None) -> List[OCRResult]:
images = convert_if_not_rgb(images)
highres_images = convert_if_not_rgb(highres_images) if highres_images is not None else [None] * len(images)
det_predictions = batch_text_detection(images, det_model, det_processor)
all_slices = []
slice_map = []
all_langs = []
for idx, (det_pred, image, lang) in enumerate(zip(det_predictions, images, langs)):
for idx, (det_pred, image, highres_image, lang) in enumerate(zip(det_predictions, images, highres_images, langs)):
polygons = [p.polygon for p in det_pred.bboxes]
slices = slice_polys_from_image(image, polygons)
if highres_image:
width_scaler = highres_image.size[0] / image.size[0]
height_scaler = highres_image.size[1] / image.size[1]
scaled_polygons = [[[int(p[0] * width_scaler), int(p[1] * height_scaler)] for p in polygon] for polygon in polygons]
slices = slice_polys_from_image(highres_image, scaled_polygons)
else:
slices = slice_polys_from_image(image, polygons)
slice_map.append(len(slices))
all_langs.extend([lang] * len(slices))
all_slices.extend(slices)

View File

@ -82,7 +82,6 @@ def detect_boxes(linemap, text_threshold, low_text):
det = []
confidences = []
max_confidence = 0
segmap = np.zeros_like(labels, dtype=np.uint8)
for k in range(1, label_count):
# size filtering
@ -90,39 +89,37 @@ def detect_boxes(linemap, text_threshold, low_text):
if size < 10:
continue
mask = labels == k
selected_linemap = linemap[mask]
# thresholding
if np.max(selected_linemap) < text_threshold:
continue
# make segmentation map
x, y, w, h = stats[k, [cv2.CC_STAT_LEFT, cv2.CC_STAT_TOP, cv2.CC_STAT_WIDTH, cv2.CC_STAT_HEIGHT]]
try:
niter = int(np.sqrt(min(w, h)) * 2)
niter = int(np.sqrt(min(w, h)))
except ValueError:
# Overflow in sqrt term
niter = 0
buffer = 1
sx, sy = max(0, x - niter), max(0, y - niter)
sx, sy = max(0, x - niter - buffer), max(0, y - niter - buffer)
ex, ey = min(img_w, x + w + niter + buffer), min(img_h, y + h + niter + buffer)
segmap.fill(0)
segmap[mask] = 1
mask = (labels[sy:ey, sx:ex] == k)
selected_linemap = linemap[sy:ey, sx:ex][mask]
line_max = np.max(selected_linemap)
# thresholding
if line_max < text_threshold:
continue
segmap = mask.astype(np.uint8)
ksize = buffer + niter
kernel = cv2.getStructuringElement(cv2.MORPH_RECT,(ksize, ksize))
# Doesn't work well without the zero start (ie, you can't trim the map tightly around the detected region)
selected_segmap = segmap[0:ey, 0:ex]
selected_segmap[sy:ey, sx:ex] = cv2.dilate(segmap[sy:ey, sx:ex], kernel)
selected_segmap = cv2.dilate(segmap, kernel)
# make box
indices = np.nonzero(selected_segmap)
np_contours = np.column_stack((indices[1], indices[0]))
x_inds = indices[1] + sx
y_inds = indices[0] + sy
np_contours = np.column_stack((x_inds, y_inds))
rectangle = cv2.minAreaRect(np_contours)
box = cv2.boxPoints(rectangle)
@ -139,8 +136,8 @@ def detect_boxes(linemap, text_threshold, low_text):
box = np.roll(box, 4-startidx, 0)
box = np.array(box)
confidence = np.mean(selected_linemap[selected_linemap > low_text])
max_confidence = max(max_confidence, confidence)
confidence = line_max
max_confidence = max(max_confidence, line_max)
confidences.append(confidence)
det.append(box)
@ -175,16 +172,23 @@ def get_and_clean_boxes(textmap, processor_size, image_size, text_threshold=None
return bboxes
def draw_bboxes_on_image(bboxes, image, labels=None):
draw = ImageDraw.Draw(image)
for bbox in bboxes:
draw.rectangle(bbox, outline="red", width=1)
def draw_bboxes_on_image(bboxes, image, labels=None, label_font_size=10, color: str | list='red'):
polys = []
for bb in bboxes:
# Clockwise polygon
poly = [
[bb[0], bb[1]],
[bb[2], bb[1]],
[bb[2], bb[3]],
[bb[0], bb[3]]
]
polys.append(poly)
return image
return draw_polys_on_image(polys, image, labels, label_font_size=label_font_size, color=color)
def draw_polys_on_image(corners, image, labels=None, box_padding=-1, label_offset=1, label_font_size=10):
def draw_polys_on_image(corners, image, labels=None, box_padding=-1, label_offset=1, label_font_size=10, color: str | list='red'):
draw = ImageDraw.Draw(image)
font_path = get_font_path()
label_font = ImageFont.truetype(font_path, label_font_size)
@ -192,7 +196,7 @@ def draw_polys_on_image(corners, image, labels=None, box_padding=-1, label_offse
for i in range(len(corners)):
poly = corners[i]
poly = [(int(p[0]), int(p[1])) for p in poly]
draw.polygon(poly, outline='red', width=1)
draw.polygon(poly, outline=color[i] if isinstance(color, list) else color, width=1)
if labels is not None:
label = labels[i]
@ -211,7 +215,7 @@ def draw_polys_on_image(corners, image, labels=None, box_padding=-1, label_offse
draw.text(
text_position,
label,
fill="red",
fill=color[i] if isinstance(color, list) else color,
font=label_font
)

View File

@ -10,12 +10,12 @@ from surya.settings import settings
from surya.postprocessing.math.latex import is_latex
def sort_text_lines(lines: "List[TextLine] | List[dict]", tolerance=1.25):
    """Sort lines roughly into reading order (top-to-bottom, then left-to-right).

    Not 100% accurate; this should only be used as a starting point for more
    advanced sorting.

    Args:
        lines: items that are either dicts with a ``"bbox"`` key or objects
            with a ``bbox`` attribute (such as TextLine).
        tolerance: vertical bucket size. Lines whose top coordinate rounds to
            the same multiple of ``tolerance`` are treated as one row.

    Returns:
        A new list with the same items, ordered row by row.
    """
    def bbox_of(line):
        return line["bbox"] if isinstance(line, dict) else line.bbox

    vertical_groups = {}
    for line in lines:
        # Divide by tolerance for BOTH input kinds. The previous conditional
        # expression applied the division only to the dict branch, so
        # attribute-style lines were never bucketed by tolerance.
        group_key = round(bbox_of(line)[1] / tolerance) * tolerance
        vertical_groups.setdefault(group_key, []).append(line)

    # Sort each group horizontally and flatten the groups into a single list
    sorted_lines = []
    for _, group in sorted(vertical_groups.items()):
        sorted_group = sorted(group, key=lambda x: bbox_of(x)[0])
        sorted_lines.extend(sorted_group)

    return sorted_lines

View File

@ -26,6 +26,10 @@ def rescale_bbox(bbox, processor_size, image_size):
return new_bbox
def rescale_bboxes(bboxes, orig_size, new_size):
    """Apply ``rescale_bbox`` to every box and return the results as a list."""
    return [
        rescale_bbox(single_box, orig_size, new_size)
        for single_box in bboxes
    ]
def rescale_point(point, processor_size, image_size):
# Point is in x, y format
page_width, page_height = processor_size

View File

@ -71,19 +71,23 @@ class PolygonBox(BaseModel):
y2 = max(self.bbox[3], other.bbox[3])
self.polygon = [[x1, y1], [x2, y1], [x2, y2], [x1, y2]]
def intersection_area(self, other, margin=0):
x_overlap = max(0, min(self.bbox[2], other.bbox[2] - margin) - max(self.bbox[0], other.bbox[0] + margin))
y_overlap = max(0, min(self.bbox[3], other.bbox[3] - margin) - max(self.bbox[1], other.bbox[1] + margin))
def intersection_area(self, other, x_margin=0, y_margin=0):
x_overlap = max(0, min(self.bbox[2] + x_margin, other.bbox[2] + x_margin) - max(self.bbox[0] - x_margin, other.bbox[0] - x_margin))
y_overlap = max(0, min(self.bbox[3] + y_margin, other.bbox[3] + y_margin) - max(self.bbox[1] - y_margin, other.bbox[1] - y_margin))
return x_overlap * y_overlap
def intersection_pct(self, other, margin=0):
assert 0 <= margin <= 1
def intersection_pct(self, other, x_margin=0, y_margin=0):
assert 0 <= x_margin <= 1
assert 0 <= y_margin <= 1
if self.area == 0:
return 0
if margin:
margin = int(min(self.width, other.width) * margin)
intersection = self.intersection_area(other, margin)
if x_margin:
x_margin = int(min(self.width, other.width) * x_margin)
if y_margin:
y_margin = int(min(self.height, other.height) * y_margin)
intersection = self.intersection_area(other, x_margin, y_margin)
return intersection / self.area
@ -119,6 +123,18 @@ class Bbox(BaseModel):
def polygon(self):
return [[self.bbox[0], self.bbox[1]], [self.bbox[2], self.bbox[1]], [self.bbox[2], self.bbox[3]], [self.bbox[0], self.bbox[3]]]
@property
def center(self):
    """Return the ``[x, y]`` midpoint of the bounding box."""
    mid_x = (self.bbox[0] + self.bbox[2]) / 2
    mid_y = (self.bbox[1] + self.bbox[3]) / 2
    return [mid_x, mid_y]
def intersection_pct(self, other):
    """Return the fraction of this box's area that overlaps ``other``.

    Returns 0 when this box has zero area.
    """
    if self.area == 0:
        return 0

    # Overlap rectangle edges; a negative extent means no overlap on that axis.
    left = max(self.bbox[0], other.bbox[0])
    right = min(self.bbox[2], other.bbox[2])
    top = max(self.bbox[1], other.bbox[1])
    bottom = min(self.bbox[3], other.bbox[3])

    overlap_width = max(0, right - left)
    overlap_height = max(0, bottom - top)
    return overlap_width * overlap_height / self.area
class LayoutBox(PolygonBox):
label: str
@ -161,3 +177,16 @@ class LayoutResult(BaseModel):
class OrderResult(BaseModel):
    """Result model holding a list of ``OrderBox`` entries and an image bbox."""

    bboxes: List[OrderBox]
    # presumably the bounding box of the source image - TODO confirm
    image_bbox: List[float]
class TableCell(Bbox):
    """A ``Bbox`` with optional row / column ids and optional text."""

    # All three fields are optional and default to None.
    row_id: int | None = None
    col_id: int | None = None
    text: str | None = None
class TableResult(BaseModel):
    """Result model for one table: its cells, rows and columns as ``TableCell`` lists."""

    cells: List[TableCell]
    rows: List[TableCell]
    cols: List[TableCell]
    # presumably the bounding box of the table image - TODO confirm
    image_bbox: List[float]

View File

@ -10,7 +10,8 @@ import os
class Settings(BaseSettings):
# General
TORCH_DEVICE: Optional[str] = None
IMAGE_DPI: int = 192
IMAGE_DPI: int = 96 # Used for detection, layout, reading order
IMAGE_DPI_HIGHRES: int = 192 # Used for OCR, table rec
IN_STREAMLIT: bool = False # Whether we're running in streamlit
ENABLE_EFFICIENT_ATTENTION: bool = True # Usually keep True, but if you get CUDA errors, setting to False can help
@ -71,6 +72,14 @@ class Settings(BaseSettings):
ORDER_BATCH_SIZE: Optional[int] = None # Defaults to 4 for CPU/MPS, 32 otherwise
ORDER_BENCH_DATASET_NAME: str = "vikp/order_bench"
# Table Rec
TABLE_REC_MODEL_CHECKPOINT: str = "vikp/surya_tablerec"
TABLE_REC_IMAGE_SIZE: Dict = {"height": 640, "width": 640}
TABLE_REC_MAX_BOXES: int = 512
TABLE_REC_MAX_ROWS: int = 384
TABLE_REC_BATCH_SIZE: Optional[int] = None
TABLE_REC_BENCH_DATASET_NAME: str = "vikp/fintabnet_bench"
# Tesseract (for benchmarks only)
TESSDATA_PREFIX: Optional[str] = None

259
surya/tables.py Normal file
View File

@ -0,0 +1,259 @@
from collections import defaultdict
from copy import deepcopy
from typing import List, Dict
import torch
from PIL import Image
from surya.model.ordering.encoderdecoder import OrderVisionEncoderDecoderModel
from surya.schema import TableResult, TableCell, Bbox
from surya.settings import settings
from tqdm import tqdm
import numpy as np
from surya.model.table_rec.config import SPECIAL_TOKENS
def get_batch_size():
    """Return the table recognition batch size.

    An explicit ``settings.TABLE_REC_BATCH_SIZE`` wins; otherwise the default
    is 64 on CUDA and 8 on every other device.
    """
    configured = settings.TABLE_REC_BATCH_SIZE
    if configured is not None:
        return configured

    if settings.TORCH_DEVICE_MODEL == "cuda":
        return 64
    # CPU and MPS share the same default.
    return 8
def sort_bboxes(bboxes, tolerance=1):
    """Order box dicts top-to-bottom, then left-to-right within a row.

    Boxes whose top coordinate (``bbox[1]``) rounds to the same multiple of
    ``tolerance`` are treated as one row. Returns a new list.
    """
    rows_by_top = {}
    for entry in bboxes:
        top_bucket = round(entry["bbox"][1] / tolerance) * tolerance
        rows_by_top.setdefault(top_bucket, []).append(entry)

    # Walk the rows from top to bottom, ordering each by its left edge.
    ordered = []
    for top_bucket in sorted(rows_by_top):
        ordered += sorted(rows_by_top[top_bucket], key=lambda item: item["bbox"][0])
    return ordered
def is_rotated(rows, cols):
    """Guess whether a table is rotated from its row and column shapes.

    Compares the summed width / height ratio of the rows against that of the
    columns. Rows should have a ratio above 1 and columns below 1, so the
    table is reported as rotated when twice the row ratio is still smaller
    than the column ratio.
    """
    def aspect_ratio(boxes):
        total_width = sum([box.width for box in boxes])
        # +1 keeps the division defined when there are no boxes.
        total_height = sum([box.height for box in boxes]) + 1
        return total_width / total_height

    row_ratio = aspect_ratio(rows)
    col_ratio = aspect_ratio(cols)
    return row_ratio * 2 < col_ratio
def _best_overlap_index(cell_bbox, bands):
    """Return the index of the band with the largest ``intersection_pct`` for the cell.

    Returns None when no band has a positive overlap. On ties the earliest
    band wins.
    """
    best_idx = None
    best_pct = 0
    for idx, band in enumerate(bands):
        pct = cell_bbox.intersection_pct(band)
        if pct > best_pct:
            best_pct = pct
            best_idx = idx
    return best_idx


def _closest_assigned_id(cell, cells, id_attr, axis):
    """Return ``id_attr`` of the cell whose center is nearest along ``axis``.

    Only cells that already have a non-None ``id_attr`` are candidates, so a
    cell that is still unassigned never matches itself. Returns None when no
    candidate exists. On ties the earliest candidate wins.
    """
    origin = cell.center[axis]
    closest_id = None
    closest_dist = None
    for other in cells:
        other_id = getattr(other, id_attr)
        if other_id is None:
            continue
        dist = abs(origin - other.center[axis])
        if closest_dist is None or dist < closest_dist:
            closest_id = other_id
            closest_dist = dist
    return closest_id


def batch_table_recognition(images: List, table_cells: List[List[Dict]], model: OrderVisionEncoderDecoderModel, processor, batch_size=None) -> List[TableResult]:
    """Predict the row/column structure of a list of table images.

    Args:
        images: PIL images, one per table.
        table_cells: For each image, the cells inside the table as dicts with
            a "bbox" entry and an optional "text" entry. The input is deep
            copied and never modified.
        model: Model exposing ``encoder``, ``text_encoder`` and ``decoder``.
        processor: Callable that turns images and boxes into model inputs; its
            ``tokenizer`` supplies ``eos_id`` and ``pad_id``.
        batch_size: Number of tables per forward pass. Defaults to
            ``get_batch_size()``.

    Returns:
        One TableResult per input image, in input order. The cells of each
        result follow the order produced by ``sort_bboxes``.
    """
    assert all([isinstance(image, Image.Image) for image in images])
    assert len(images) == len(table_cells)
    if batch_size is None:
        batch_size = get_batch_size()

    output_order = []
    for i in tqdm(range(0, len(images), batch_size), desc="Recognizing tables"):
        batch_table_cells = deepcopy(table_cells[i:i+batch_size])
        batch_table_cells = [sort_bboxes(page_bboxes) for page_bboxes in batch_table_cells]  # Sort bboxes before passing in
        batch_list_bboxes = [[block["bbox"] for block in page] for page in batch_table_cells]

        batch_images = images[i:i+batch_size]
        batch_images = [image.convert("RGB") for image in batch_images]  # also copies the images
        current_batch_size = len(batch_images)

        orig_sizes = [image.size for image in batch_images]
        model_inputs = processor(images=batch_images, boxes=deepcopy(batch_list_bboxes))

        batch_pixel_values = model_inputs["pixel_values"]
        batch_bboxes = model_inputs["input_boxes"]
        batch_bbox_mask = model_inputs["input_boxes_mask"]
        batch_bbox_counts = model_inputs["input_boxes_counts"]

        batch_bboxes = torch.from_numpy(np.array(batch_bboxes, dtype=np.int32)).to(model.device)
        batch_bbox_mask = torch.from_numpy(np.array(batch_bbox_mask, dtype=np.int32)).to(model.device)
        batch_pixel_values = torch.tensor(np.array(batch_pixel_values), dtype=model.dtype).to(model.device)
        batch_bbox_counts = torch.tensor(np.array(batch_bbox_counts), dtype=torch.long).to(model.device)

        # Setup inputs for the decoder
        batch_decoder_input = [[[model.config.decoder.bos_token_id] * 5] for _ in range(current_batch_size)]
        batch_decoder_input = torch.tensor(np.stack(batch_decoder_input, axis=0), dtype=torch.long, device=model.device)
        inference_token_count = batch_decoder_input.shape[1]

        # Never decode more steps than the configured box limit
        max_tokens = min(batch_bbox_counts[:, 1].max().item(), settings.TABLE_REC_MAX_BOXES)
        decoder_position_ids = torch.ones_like(batch_decoder_input[0, :, 0], dtype=torch.int64, device=model.device).cumsum(0) - 1
        # The caches are sized with the full batch_size, also for a smaller final batch
        model.decoder.model._setup_cache(model.config, batch_size, model.device, model.dtype)
        model.text_encoder.model._setup_cache(model.config, batch_size, model.device, model.dtype)

        batch_predictions = [[] for _ in range(current_batch_size)]

        with torch.inference_mode():
            encoder_hidden_states = model.encoder(pixel_values=batch_pixel_values).last_hidden_state
            text_encoder_hidden_states = model.text_encoder(
                input_boxes=batch_bboxes,
                input_boxes_counts=batch_bbox_counts,
                cache_position=None,
                attention_mask=batch_bbox_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=None,
                use_cache=False
            ).hidden_states

            token_count = 0
            all_done = torch.zeros(current_batch_size, dtype=torch.bool, device=model.device)

            while token_count < max_tokens:
                is_prefill = token_count == 0
                return_dict = model.decoder(
                    input_ids=batch_decoder_input,
                    encoder_hidden_states=text_encoder_hidden_states,
                    cache_position=decoder_position_ids,
                    use_cache=True,
                    prefill=is_prefill
                )

                decoder_position_ids = decoder_position_ids[-1:] + 1
                box_logits = return_dict["bbox_logits"][:, -1, :].detach()
                rowcol_logits = return_dict["class_logits"][:, -1, :].detach()

                rowcol_preds = torch.argmax(rowcol_logits, dim=-1)
                box_preds = torch.argmax(box_logits, dim=-1)

                # A sequence is finished once it emits eos or pad, and stays finished
                done = (rowcol_preds == processor.tokenizer.eos_id) | (rowcol_preds == processor.tokenizer.pad_id)
                all_done = all_done | done

                if all_done.all():
                    break

                batch_decoder_input = torch.cat([box_preds.unsqueeze(1), rowcol_preds.unsqueeze(1).unsqueeze(1)], dim=-1)

                # Only sequences that are still running record this step's prediction
                for seq_preds, pred, status in zip(batch_predictions, batch_decoder_input, all_done):
                    if not status:
                        seq_preds.append(pred[0].tolist())

                token_count += inference_token_count
                inference_token_count = batch_decoder_input.shape[1]

        for preds, input_cells, orig_size in zip(batch_predictions, batch_table_cells, orig_sizes):
            img_w, img_h = orig_size
            # Scale from the decoder's output box space to the original image size
            width_scaler = img_w / model.config.decoder.out_box_size
            height_scaler = img_h / model.config.decoder.out_box_size

            # cx, cy to corners
            scaled_preds = []
            for pred in preds:
                half_w = pred[2] / 2
                half_h = pred[3] / 2
                x1 = pred[0] - half_w
                y1 = pred[1] - half_h
                x2 = pred[0] + half_w
                y2 = pred[1] + half_h
                class_ = int(pred[4] - SPECIAL_TOKENS)
                scaled_preds.append([x1 * width_scaler, y1 * height_scaler, x2 * width_scaler, y2 * height_scaler, class_])

            # Get rows and columns (class 0 is a row, class 1 is a column)
            bb_rows = [p[:4] for p in scaled_preds if p[4] == 0]
            bb_cols = [p[:4] for p in scaled_preds if p[4] == 1]

            rows = [TableCell(bbox=row, row_id=row_idx) for row_idx, row in enumerate(bb_rows)]
            cols = [TableCell(bbox=col, col_id=col_idx) for col_idx, col in enumerate(bb_cols)]

            # Assign cells to rows/columns
            cells = []
            for cell in input_cells:
                # Built once per cell and reused for every row and column comparison
                cell_bbox = Bbox(bbox=cell["bbox"])
                cells.append(
                    TableCell(
                        bbox=cell["bbox"],
                        text=cell.get("text"),
                        row_id=_best_overlap_index(cell_bbox, rows),
                        col_id=_best_overlap_index(cell_bbox, cols)
                    )
                )

            # Cells that overlap no row/column inherit the id of the nearest assigned cell.
            # A rotated table swaps which center coordinate separates rows from columns.
            rotated = is_rotated(rows, cols)
            row_axis, col_axis = (0, 1) if rotated else (1, 0)
            for cell in cells:
                if cell.row_id is None:
                    cell.row_id = _closest_assigned_id(cell, cells, "row_id", row_axis)

                if cell.col_id is None:
                    cell.col_id = _closest_assigned_id(cell, cells, "col_id", col_axis)

            result = TableResult(
                cells=cells,
                rows=rows,
                cols=cols,
                image_bbox=[0, 0, img_w, img_h],
            )

            output_order.append(result)

    return output_order

146
table_recognition.py Normal file
View File

@ -0,0 +1,146 @@
import pypdfium2 as pdfium # Needs to be on top to avoid warning
import os
import argparse
import copy
import json
from collections import defaultdict
from surya.detection import batch_text_detection
from surya.input.load import load_from_folder, load_from_file
from surya.input.pdflines import get_table_blocks
from surya.layout import batch_layout_detection
from surya.model.detection.model import load_model as load_det_model, load_processor as load_det_processor
from surya.model.table_rec.model import load_model as load_model
from surya.model.table_rec.processor import load_processor
from surya.tables import batch_table_recognition
from surya.postprocessing.heatmap import draw_bboxes_on_image
from surya.settings import settings
from surya.postprocessing.util import rescale_bboxes, rescale_bbox
def main():
    """Command-line entry point: recognize tables in a PDF/image file or folder.

    Loads the table recognition, layout, and text detection models, finds the
    tables on every page (unless --skip_table_detection is given), collects
    the text cells of each table, and runs table recognition on them. Results
    are written to ``results.json`` inside ``<results_dir>/<input name>``,
    keyed by input file name; with --images, annotated PNGs are saved next to
    it.
    """
    parser = argparse.ArgumentParser(description="Recognize table rows, columns, and cells in an input file or folder (PDFs or image).")
    parser.add_argument("input_path", type=str, help="Path to pdf or image file or folder to recognize tables in.")
    parser.add_argument("--results_dir", type=str, help="Path to directory to save results in.", default=os.path.join(settings.RESULT_DIR, "surya"))
    parser.add_argument("--max", type=int, help="Maximum number of pages to process.", default=None)
    parser.add_argument("--images", action="store_true", help="Save images of detected table cells, rows, and columns.", default=False)
    parser.add_argument("--detect_boxes", action="store_true", help="Detect table boxes.", default=False)
    parser.add_argument("--skip_table_detection", action="store_true", help="Tables are already cropped, so don't re-detect tables.", default=False)
    args = parser.parse_args()

    model = load_model()
    processor = load_processor()

    layout_model = load_det_model(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT)
    layout_processor = load_det_processor(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT)

    det_model = load_det_model()
    det_processor = load_det_processor()

    # Pages are loaded twice: at the default resolution for detection and at high resolution for table crops
    if os.path.isdir(args.input_path):
        images, _, _ = load_from_folder(args.input_path, args.max)
        highres_images, names, text_lines = load_from_folder(args.input_path, args.max, dpi=settings.IMAGE_DPI_HIGHRES, load_text_lines=True)
        folder_name = os.path.basename(args.input_path)
    else:
        images, _, _ = load_from_file(args.input_path, args.max)
        highres_images, names, text_lines = load_from_file(args.input_path, args.max, dpi=settings.IMAGE_DPI_HIGHRES, load_text_lines=True)
        folder_name = os.path.basename(args.input_path).split(".")[0]

    # Zero-based page number of each image within its file; consecutive equal names are pages of one file
    pnums = []
    prev_name = None
    for page_name in names:
        if prev_name is None or prev_name != page_name:
            pnums.append(0)
        else:
            pnums.append(pnums[-1] + 1)

        prev_name = page_name

    line_predictions = batch_text_detection(images, det_model, det_processor)
    layout_predictions = batch_layout_detection(images, layout_model, layout_processor, line_predictions)
    table_cells = []

    table_imgs = []
    table_counts = []
    for layout_pred, text_line, img, highres_img in zip(layout_predictions, text_lines, images, highres_images):
        # The table may already be cropped
        if args.skip_table_detection:
            table_imgs.append(highres_img)
            table_counts.append(1)
            page_table_imgs = [highres_img]
            highres_bbox = [[0, 0, highres_img.size[0], highres_img.size[1]]]
        else:
            # The bbox for the entire table
            bbox = [l.bbox for l in layout_pred.bboxes if l.label == "Table"]
            # Number of tables per page
            table_counts.append(len(bbox))

            if len(bbox) == 0:
                continue

            page_table_imgs = []
            highres_bbox = []
            for bb in bbox:
                highres_bb = rescale_bbox(bb, img.size, highres_img.size)
                page_table_imgs.append(highres_img.crop(highres_bb))
                highres_bbox.append(highres_bb)

            table_imgs.extend(page_table_imgs)

        # The text cells inside each table
        table_blocks = get_table_blocks(highres_bbox, text_line, highres_img.size) if text_line is not None else None
        # Fall back to text detection when there are no text lines, when forced, or when any table came back empty
        if text_line is None or args.detect_boxes or any(len(tb) == 0 for tb in table_blocks):
            det_results = batch_text_detection(page_table_imgs, det_model, det_processor)
            cell_bboxes = [[{"bbox": tb.bbox, "text": None} for tb in det_result.bboxes] for det_result in det_results]
            table_cells.extend(cell_bboxes)
        else:
            table_cells.extend(table_blocks)

    table_preds = batch_table_recognition(table_imgs, table_cells, model, processor)
    result_path = os.path.join(args.results_dir, folder_name)
    os.makedirs(result_path, exist_ok=True)

    # Map each flat table index back to its page using the per-page table counts
    img_idx = 0
    prev_count = 0
    table_predictions = defaultdict(list)

    for i in range(sum(table_counts)):
        while i >= prev_count + table_counts[img_idx]:
            prev_count += table_counts[img_idx]
            img_idx += 1

        pred = table_preds[i]
        orig_name = names[img_idx]
        pnum = pnums[img_idx]
        table_img = table_imgs[i]

        out_pred = pred.model_dump()
        out_pred["page"] = pnum + 1
        table_idx = i - prev_count
        out_pred["table_idx"] = table_idx
        table_predictions[orig_name].append(out_pred)

        if args.images:
            boxes = [l.bbox for l in pred.cells]
            labels = [f"{l.row_id}/{l.col_id}" for l in pred.cells]
            bbox_image = draw_bboxes_on_image(boxes, copy.deepcopy(table_img), labels=labels, label_font_size=20)
            # Name the file after the table's own source file, not whichever name was iterated last
            bbox_image.save(os.path.join(result_path, f"{orig_name}_page{pnum + 1}_table{table_idx}_cells.png"))

            rows = [l.bbox for l in pred.rows]
            cols = [l.bbox for l in pred.cols]
            row_labels = [f"Row {l.row_id}" for l in pred.rows]
            col_labels = [f"Col {l.col_id}" for l in pred.cols]

            rc_image = copy.deepcopy(table_img)
            rc_image = draw_bboxes_on_image(rows, rc_image, labels=row_labels, label_font_size=20, color="blue")
            rc_image = draw_bboxes_on_image(cols, rc_image, labels=col_labels, label_font_size=20, color="red")
            rc_image.save(os.path.join(result_path, f"{orig_name}_page{pnum + 1}_table{table_idx}_rc.png"))

    with open(os.path.join(result_path, "results.json"), "w+", encoding="utf-8") as f:
        json.dump(table_predictions, f, ensure_ascii=False)

    print(f"Wrote results to {result_path}")


if __name__ == "__main__":
    main()