surya/tests/test_table_rec.py
2026-05-14 15:29:07 -04:00

58 lines
1.6 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from PIL import Image, ImageDraw
def test_table_rec(table_rec_predictor):
data = [
["Name", "Age", "City"],
["Alice", 25, "New York"],
["Bob", 30, "Los Angeles"],
["Charlie", 35, "Chicago"],
]
test_image = draw_table(data)
results = table_rec_predictor([test_image])
assert len(results) == 1
assert results[0].image_bbox == [0, 0, test_image.size[0], test_image.size[1]]
if results[0].error:
return
rows = results[0].rows
cols = results[0].cols
cells = results[0].cells
assert len(rows) >= 1
assert len(cols) >= 1
# Geometric cells = rows × cols
assert len(cells) == len(rows) * len(cols) or len(cells) <= len(rows) * len(cols)
def draw_table(data, cell_width=100, cell_height=40):
rows = len(data)
cols = len(data[0])
width = cols * cell_width
height = rows * cell_height
image = Image.new("RGB", (width, height), "white")
draw = ImageDraw.Draw(image)
for i in range(rows + 1):
y = i * cell_height
draw.line([(0, y), (width, y)], fill="black", width=1)
for i in range(cols + 1):
x = i * cell_width
draw.line([(x, 0), (x, height)], fill="black", width=1)
for i in range(rows):
for j in range(cols):
text = str(data[i][j])
text_bbox = draw.textbbox((0, 0), text)
text_width = text_bbox[2] - text_bbox[0]
text_height = text_bbox[3] - text_bbox[1]
x = j * cell_width + (cell_width - text_width) // 2
y = i * cell_height + (cell_height - text_height) // 2
draw.text((x, y), text, fill="black")
return image