# surya/table_recognition.py
# 2025-01-10 09:10:23 -05:00
#
# 107 lines
# 3.9 KiB
# Python

import os
import click
import copy
import json
from collections import defaultdict
from surya.common.cli.config import CLILoader
from surya.layout import LayoutPredictor
from surya.table_rec import TableRecPredictor
from surya.debug.draw import draw_bboxes_on_image
from surya.common.util import rescale_bbox, expand_bbox
def _page_numbers(names):
    """Return the 0-based page index of each entry in *names*.

    Consecutive entries sharing the same name are treated as successive pages
    of one document; the counter restarts at 0 whenever the name changes.
    """
    pnums = []
    prev_name = None
    for name in names:
        if prev_name is None or prev_name != name:
            pnums.append(0)
        else:
            pnums.append(pnums[-1] + 1)
        prev_name = name
    return pnums


def _save_debug_images(pred, table_img, result_path, stem):
    """Save row/column and cell overlays for one recognized table.

    Writes ``{stem}_rc.png`` (rows in blue, columns in red, each labelled with
    its id) and ``{stem}_cells.png`` (cells in green) into *result_path*.
    """
    rows = [row.bbox for row in pred.rows]
    cols = [col.bbox for col in pred.cols]
    row_labels = [f"Row {row.row_id}" for row in pred.rows]
    col_labels = [f"Col {col.col_id}" for col in pred.cols]
    cells = [cell.bbox for cell in pred.cells]

    # Draw on copies so table_img itself is left untouched (draw_bboxes_on_image
    # may draw in place — its result is reassigned either way).
    rc_image = copy.deepcopy(table_img)
    rc_image = draw_bboxes_on_image(rows, rc_image, labels=row_labels, label_font_size=20, color="blue")
    rc_image = draw_bboxes_on_image(cols, rc_image, labels=col_labels, label_font_size=20, color="red")
    rc_image.save(os.path.join(result_path, f"{stem}_rc.png"))

    cell_image = copy.deepcopy(table_img)
    cell_image = draw_bboxes_on_image(cells, cell_image, color="green")
    cell_image.save(os.path.join(result_path, f"{stem}_cells.png"))


@click.command(help="Recognize table structure in an input file or folder (PDFs or images).")
@CLILoader.common_options
@click.option("--detect_boxes", is_flag=True, help="Detect table boxes.", default=False)
@click.option("--skip_table_detection", is_flag=True, help="Tables are already cropped, so don't re-detect tables.", default=False)
def main(input_path: str, detect_boxes: bool, skip_table_detection: bool, **kwargs):
    """Detect tables on each input page, recognize their structure, and save results.

    Args:
        input_path: File or folder passed to ``CLILoader``.
        detect_boxes: Accepted for CLI compatibility; not used by this function.
        skip_table_detection: Treat every input image as an already-cropped table
            instead of running layout detection to find tables.
        **kwargs: Remaining common CLI options, forwarded to ``CLILoader``.

    Writes ``results.json`` (predictions grouped by source file name, each with
    a 1-based ``page`` and a per-page ``table_idx``) plus per-table debug PNGs
    into ``loader.result_path``.
    """
    loader = CLILoader(input_path, kwargs, highres=True)
    table_rec_predictor = TableRecPredictor()
    layout_predictor = LayoutPredictor()

    # 0-based page index of each loaded image within its source file.
    pnums = _page_numbers(loader.names)

    layout_predictions = layout_predictor(loader.images)

    table_imgs = []    # flat list of table crops across all pages
    table_counts = []  # number of crops contributed by each page, in page order
    for layout_pred, img, highres_img in zip(layout_predictions, loader.images, loader.highres_images):
        if skip_table_detection:
            # The table may already be cropped: use the whole image as one table.
            table_imgs.append(highres_img)
            table_counts.append(1)
            continue

        # Bboxes of whole tables, in the coordinate space of `img`.
        table_bboxes = [box.bbox for box in layout_pred.bboxes if box.label == "Table"]
        table_counts.append(len(table_bboxes))
        for bb in table_bboxes:
            # Map the bbox onto the high-res image, then grow it slightly
            # (presumably ~1% via expansion_factor) before cropping.
            highres_bb = rescale_bbox(bb, img.size, highres_img.size)
            highres_bb = expand_bbox(highres_bb, expansion_factor=.01)
            table_imgs.append(highres_img.crop(highres_bb))

    # Nothing to recognize if no tables were found; avoid calling the model.
    table_preds = table_rec_predictor(table_imgs) if table_imgs else []

    img_idx = 0
    prev_count = 0
    table_predictions = defaultdict(list)
    # len(table_imgs) == sum(table_counts): every page appends exactly its count.
    for i, table_img in enumerate(table_imgs):
        # Advance to the page that produced crop i, skipping pages with no tables.
        while i >= prev_count + table_counts[img_idx]:
            prev_count += table_counts[img_idx]
            img_idx += 1

        pred = table_preds[i]
        orig_name = loader.names[img_idx]
        pnum = pnums[img_idx]
        table_idx = i - prev_count  # position of this table within its page

        out_pred = pred.model_dump()
        out_pred["page"] = pnum + 1
        out_pred["table_idx"] = table_idx
        table_predictions[orig_name].append(out_pred)

        # NOTE(review): `loader.images` is the list of input images, so this is
        # true whenever anything was loaded; a dedicated "save images" flag on
        # CLILoader may have been intended — confirm.
        if loader.images:
            # Name debug images after the source file this table came from.
            stem = f"{orig_name}_page{pnum + 1}_table{table_idx}"
            _save_debug_images(pred, table_img, loader.result_path, stem)

    with open(os.path.join(loader.result_path, "results.json"), "w+", encoding="utf-8") as f:
        json.dump(table_predictions, f, ensure_ascii=False)
    print(f"Wrote results to {loader.result_path}")
if __name__ == "__main__":
    # Script entry point: invoke the click command, which reads its options
    # from the process command line.
    main()