mirror of
https://github.com/VikParuchuri/surya.git
synced 2026-06-04 21:03:53 +08:00
Heavy cleanups
This commit is contained in:
parent
8e40566be0
commit
05874ae9ad
@ -1 +0,0 @@
|
||||
{"sessionId":"4129570b-108b-41b8-87b0-725c0b77135e","pid":5040,"procStart":"Tue May 5 13:42:45 2026","acquiredAt":1778008965041}
|
||||
@ -6,6 +6,6 @@ authors:
|
||||
given-names: Vikas
|
||||
- name: Datalab Team
|
||||
date-released: 2025-05-13
|
||||
url: https://github.com/VikParuchuri/surya
|
||||
version: 0.14.0
|
||||
repository-code: https://github.com/VikParuchuri/surya
|
||||
url: https://github.com/datalab-to/surya
|
||||
version: 1.0.0
|
||||
repository-code: https://github.com/datalab-to/surya
|
||||
@ -1,11 +1,11 @@
|
||||
[project]
|
||||
name = "surya-ocr"
|
||||
version = "2.0.0a0"
|
||||
description = "OCR, layout, reading order, and table recognition via VLM (vllm + llama.cpp)."
|
||||
version = "1.0.0a0"
|
||||
description = "OCR, layout, reading order, and table recognition in 90+ languages."
|
||||
readme = "README.md"
|
||||
license = { text = "Apache-2.0" }
|
||||
authors = [
|
||||
{ name = "Vik Paruchuri", email = "vik.paruchuri@gmail.com" },
|
||||
{ name = "Vik Paruchuri", email = "vik@datalab.to" },
|
||||
]
|
||||
requires-python = ">=3.10,<4"
|
||||
keywords = ["ocr", "pdf", "text detection", "text recognition", "tables"]
|
||||
@ -30,7 +30,7 @@ dependencies = [
|
||||
]
|
||||
|
||||
[project.urls]
|
||||
Repository = "https://github.com/VikParuchuri/surya"
|
||||
Repository = "https://github.com/datalab-to/surya"
|
||||
|
||||
[project.scripts]
|
||||
surya_detect = "surya.scripts.detect_text:detect_text_cli"
|
||||
|
||||
@ -80,23 +80,6 @@ class PolygonBox(BaseModel):
|
||||
corner[1] = max(min(corner[1], bounds[3]), bounds[1])
|
||||
self.polygon = new_corners
|
||||
|
||||
def merge(self, other):
    """Grow this box to the axis-aligned union of both boxes' bboxes.

    Indices 0/1 of each ``bbox`` are combined with ``min`` and indices 2/3
    with ``max``. ``self.polygon`` is then replaced by the four corners of
    the resulting rectangle, in the order (x1, y1), (x2, y1), (x2, y2),
    (x1, y2). ``other`` is not modified.
    """
    left = min(self.bbox[0], other.bbox[0])
    top = min(self.bbox[1], other.bbox[1])
    right = max(self.bbox[2], other.bbox[2])
    bottom = max(self.bbox[3], other.bbox[3])
    corners = [(left, top), (right, top), (right, bottom), (left, bottom)]
    self.polygon = [list(corner) for corner in corners]
|
||||
|
||||
def merge_left(self, other):
    """Pull this box's left edge out to the smaller of the two left x values.

    Only the x coordinate of polygon corners 0 and 3 is overwritten; the
    y coordinates and the other two corners are left as they are.
    """
    leftmost = min(self.bbox[0], other.bbox[0])
    for corner_idx in (0, 3):
        self.polygon[corner_idx][0] = leftmost
|
||||
|
||||
def merge_right(self, other):
    """Push this box's right edge out to the larger of the two right x values.

    Only the x coordinate of polygon corners 1 and 2 is overwritten; the
    y coordinates and the other two corners are left as they are.
    """
    rightmost = max(self.bbox[2], other.bbox[2])
    for corner_idx in (1, 2):
        self.polygon[corner_idx][0] = rightmost
|
||||
|
||||
def expand(self, x_margin: float, y_margin: float):
|
||||
new_polygon = []
|
||||
x_margin = x_margin * self.width
|
||||
@ -112,33 +95,6 @@ class PolygonBox(BaseModel):
|
||||
new_polygon.append([int(poly[0] - x_margin), int(poly[1] + y_margin)])
|
||||
self.polygon = new_polygon
|
||||
|
||||
def intersection_polygon(self, other) -> List[List[float]]:
    """Build a four-corner polygon from the corner-wise overlap of two polygons.

    Corner ``i`` of the result is computed from corner ``i`` of each input:

    * corner 0: larger x, larger y
    * corner 1: smaller x, larger y
    * corner 2: smaller x, smaller y
    * corner 3: larger x, smaller y

    No check is made that the two polygons actually overlap.
    """
    # One (x picker, y picker) pair per corner index.
    pickers = ((max, max), (min, max), (min, min), (max, min))
    return [
        [
            pick_x(self.polygon[idx][0], other.polygon[idx][0]),
            pick_y(self.polygon[idx][1], other.polygon[idx][1]),
        ]
        for idx, (pick_x, pick_y) in enumerate(pickers)
    ]
|
||||
|
||||
def intersection_area(self, other, x_margin=0, y_margin=0):
|
||||
x_overlap = self.x_overlap(other, x_margin)
|
||||
y_overlap = self.y_overlap(other, y_margin)
|
||||
@ -158,44 +114,9 @@ class PolygonBox(BaseModel):
|
||||
- max(self.bbox[1] - y_margin, other.bbox[1] - y_margin),
|
||||
)
|
||||
|
||||
def intersection_pct(self, other, x_margin=0, y_margin=0):
    """Return the share of this box's area that intersects ``other``.

    ``x_margin`` / ``y_margin`` are fractions in [0, 1]. A non-zero margin is
    converted to an integer amount by scaling the smaller of the two boxes'
    width (or height) and truncating, then forwarded to
    ``self.intersection_area``. A zero margin is forwarded unchanged.

    Returns 0 when ``self.area`` is 0, otherwise intersection / ``self.area``.
    Raises AssertionError if a margin is outside [0, 1].
    """
    assert 0 <= x_margin <= 1
    assert 0 <= y_margin <= 1
    if self.area == 0:
        return 0

    x_pad = int(min(self.width, other.width) * x_margin) if x_margin else x_margin
    y_pad = int(min(self.height, other.height) * y_margin) if y_margin else y_margin

    return self.intersection_area(other, x_pad, y_pad) / self.area
|
||||
|
||||
def shift(self, x_shift: float | None = None, y_shift: float | None = None):
|
||||
if x_shift is not None:
|
||||
for corner in self.polygon:
|
||||
corner[0] += x_shift
|
||||
if y_shift is not None:
|
||||
for corner in self.polygon:
|
||||
corner[1] += y_shift
|
||||
|
||||
def clamp(self, bbox: List[float]):
    """Clip every polygon corner, in place, to the rectangle ``bbox``.

    x is limited to ``[bbox[0], bbox[2]]`` and y to ``[bbox[1], bbox[3]]``.
    """
    for corner in self.polygon:
        for axis in (0, 1):
            lower, upper = bbox[axis], bbox[axis + 2]
            corner[axis] = max(min(corner[axis], upper), lower)
|
||||
|
||||
@property
def center(self):
    """Midpoint of ``self.bbox`` as ``[x, y]``."""
    box = self.bbox
    mid_x = (box[0] + box[2]) / 2
    mid_y = (box[1] + box[3]) / 2
    return [mid_x, mid_y]
|
||||
|
||||
def distance(self, other):
    """Euclidean distance between this box's ``center`` and ``other``'s."""
    here = self.center
    there = other.center
    dx = here[0] - there[0]
    dy = here[1] - there[1]
    return (dx ** 2 + dy ** 2) ** 0.5
|
||||
|
||||
def __hash__(self):
    """Hash the box by the coordinates of ``self.bbox``."""
    coords = tuple(self.bbox)
    return hash(coords)
|
||||
|
||||
@ -1,9 +1,4 @@
|
||||
import copy
|
||||
from typing import List
|
||||
import torch
|
||||
from functools import lru_cache
|
||||
|
||||
import torch.nn.functional as F
|
||||
|
||||
from surya.common.polygon import PolygonBox
|
||||
|
||||
@ -38,21 +33,6 @@ def clean_boxes(boxes: List[PolygonBox]) -> List[PolygonBox]:
|
||||
return new_boxes
|
||||
|
||||
|
||||
def rescale_bbox(bbox, processor_size, image_size):
    """Map a bbox from processor-size coordinates to image-size coordinates.

    ``processor_size`` and ``image_size`` are ``(width, height)`` pairs.
    Elements 0 and 2 of the bbox are scaled by the width ratio, elements 1
    and 3 by the height ratio, and each result is truncated with ``int``.
    The input bbox is deep-copied first, so the caller's object is unchanged.
    """
    page_width, page_height = processor_size
    img_width, img_height = image_size
    # Index parity selects the ratio: even -> width, odd -> height.
    scalers = (img_width / page_width, img_height / page_height)

    scaled = copy.deepcopy(bbox)
    for idx in range(4):
        scaled[idx] = int(scaled[idx] * scalers[idx % 2])
    return scaled
|
||||
|
||||
|
||||
def expand_bbox(bbox, expansion_factor=0.01):
|
||||
expansion_low = 1 - expansion_factor
|
||||
expansion_high = 1 + expansion_factor
|
||||
@ -62,210 +42,3 @@ def expand_bbox(bbox, expansion_factor=0.01):
|
||||
bbox[2] * expansion_high,
|
||||
bbox[3] * expansion_high,
|
||||
]
|
||||
|
||||
# Maps a script name to its special token, "<SCRIPT-" + upper-cased name + ">".
SCRIPT_TOKEN_MAPPING = {
    script: f"<SCRIPT-{script.upper()}>"
    for script in (
        "latin",
        "punctuation",
        "cyrillic",
        "arabic",
        "chinese",
        "japanese",
        "korean",
        "symbols",
        "greek",
        "armenian",
        "hebrew",
        "devanagari",
        "bengali",
        "gurmukhi",
        "gujarati",
        "oriya",
        "tamil",
        "telugu",
        "kannada",
        "malayalam",
        "sinhala",
        "thai",
        "lao",
        "myanmar",
        "georgian",
        "ethiopic",
        "khmer",
        "mongolian",
        "math",
    )
}
|
||||
|
||||
@lru_cache(maxsize=1)
def script_ranges():
    """Return the script tables as ``(script_categories, flat_ranges)``.

    ``script_categories`` maps a script name to a list of inclusive
    ``(start, end)`` code-point ranges. ``flat_ranges`` maps the same names
    to a set holding every code point in those ranges. The result is cached,
    so every call returns the same two dict objects.
    """
    # Scripts that span several Unicode blocks.
    script_categories = {
        # Latin-based scripts (used by English, French, German, etc.)
        "latin": [
            (0x0041, 0x005A),  # Latin uppercase A-Z
            (0x0061, 0x007A),  # Latin lowercase a-z
            (0x0080, 0x00FF),  # Latin-1 Supplement
            (0x0100, 0x017F),  # Latin Extended-A
            (0x0180, 0x024F),  # Latin Extended-B
            (0x0250, 0x02AF),  # IPA Extensions
            (0x02B0, 0x02FF),  # Spacing Modifier Letters
            (0x0300, 0x036F),  # Combining Diacritical Marks
            (0x1E00, 0x1EFF),  # Latin Extended Additional
            (0x2C60, 0x2C7F),  # Latin Extended-C
            (0xA720, 0xA7FF),  # Latin Extended-D
        ],
        # Punctuation, universal characters, and general symbols
        "punctuation": [
            (0x0020, 0x0020),  # Space
            (0x0021, 0x002F),  # Basic punctuation and symbols
            (0x0030, 0x0039),  # Digits 0-9
            (0x003A, 0x0040),  # More punctuation and symbols
            (0x005B, 0x0060),  # More punctuation and symbols
            (0x007B, 0x007F),  # More punctuation and symbols
            (0x2000, 0x206F),  # General Punctuation
        ],
        # Cyrillic scripts (used by Russian, Ukrainian, etc.)
        "cyrillic": [
            (0x0400, 0x04FF),  # Cyrillic
            (0x0500, 0x052F),  # Cyrillic Supplement
        ],
        # Arabic scripts
        "arabic": [
            (0x0600, 0x06FF),  # Arabic
            (0x0750, 0x077F),  # Arabic Supplement
            (0x08A0, 0x08FF),  # Arabic Extended-A
        ],
        # Chinese characters
        "chinese": [
            (0x4E00, 0x9FFF),  # Common CJK Unified Ideographs
            (0x3400, 0x4DBF),  # CJK Extension A
            (0x20000, 0x2A6DF),  # CJK Extension B
        ],
        # Japanese-specific scripts (excluding shared CJK)
        "japanese": [
            (0x3040, 0x30FF),  # Hiragana and Katakana
        ],
        # Korean-specific scripts
        "korean": [
            (0x1100, 0x11FF),  # Hangul Jamo
            (0x3130, 0x318F),  # Hangul Compatibility Jamo
            (0xAC00, 0xD7AF),  # Hangul Syllables
        ],
        # Various mathematical and technical symbols
        "symbols": [
            (0x2070, 0x209F),  # Superscripts and Subscripts
            (0x20A0, 0x20CF),  # Currency Symbols
            (0x2100, 0x214F),  # Letterlike Symbols
            (0x2150, 0x218F),  # Number Forms
            (0x2190, 0x21FF),  # Arrows
            (0x2200, 0x22FF),  # Mathematical Operators
            (0x2300, 0x23FF),  # Miscellaneous Technical
            (0x2500, 0x257F),  # Box Drawing
            (0x2580, 0x259F),  # Block Elements
            (0x25A0, 0x25FF),  # Geometric Shapes
            (0x2600, 0x26FF),  # Miscellaneous Symbols
            (0x2700, 0x27BF),  # Dingbats
            (0x27C0, 0x27EF),  # Miscellaneous Mathematical Symbols-A
            (0x2980, 0x29FF),  # Miscellaneous Mathematical Symbols-B
            (0x2A00, 0x2AFF),  # Supplemental Mathematical Operators
            (0x1D400, 0x1D7FF),  # Mathematical Alphanumeric Symbols
        ],
    }

    # Individual scripts for languages with unique writing systems; each one
    # is a single Unicode block, appended after the multi-block scripts.
    single_block_scripts = (
        ("greek", (0x0370, 0x03FF)),  # Greek and Coptic
        ("armenian", (0x0530, 0x058F)),  # Armenian
        ("hebrew", (0x0590, 0x05FF)),  # Hebrew
        ("devanagari", (0x0900, 0x097F)),  # Devanagari (Hindi, Sanskrit)
        ("bengali", (0x0980, 0x09FF)),  # Bengali
        ("gurmukhi", (0x0A00, 0x0A7F)),  # Gurmukhi (Punjabi)
        ("gujarati", (0x0A80, 0x0AFF)),  # Gujarati
        ("oriya", (0x0B00, 0x0B7F)),  # Oriya
        ("tamil", (0x0B80, 0x0BFF)),  # Tamil
        ("telugu", (0x0C00, 0x0C7F)),  # Telugu
        ("kannada", (0x0C80, 0x0CFF)),  # Kannada
        ("malayalam", (0x0D00, 0x0D7F)),  # Malayalam
        ("sinhala", (0x0D80, 0x0DFF)),  # Sinhala
        ("thai", (0x0E00, 0x0E7F)),  # Thai
        ("lao", (0x0E80, 0x0EFF)),  # Lao
        ("myanmar", (0x1000, 0x109F)),  # Myanmar
        ("georgian", (0x10A0, 0x10FF)),  # Georgian
        ("ethiopic", (0x1200, 0x137F)),  # Ethiopic
        ("khmer", (0x1780, 0x17FF)),  # Khmer
        ("mongolian", (0x1800, 0x18AF)),  # Mongolian
    )
    for script_name, block_span in single_block_scripts:
        script_categories[script_name] = [block_span]

    # Expand every inclusive range into one set of code points per script.
    flat_ranges = {
        script_name: {
            code_point
            for first, last in spans
            for code_point in range(first, last + 1)
        }
        for script_name, spans in script_categories.items()
    }

    return script_categories, flat_ranges
|
||||
|
||||
def get_top_scripts(text: str, max_scripts: int = 5):
    """Rank the scripts present in ``text`` by character count.

    Each character is credited to the first script (in table order) whose
    code-point set contains it; characters matching no script are ignored.
    Scripts with a zero count are dropped and the rest are sorted by count,
    highest first. If the substring ``"<math"`` occurs in ``text``,
    ``"math"`` is placed at the front. At most ``max_scripts`` names are
    returned.
    """
    script_categories, flat_ranges = script_ranges()
    tally = dict.fromkeys(script_categories.keys(), 0)

    for code_point in map(ord, text):
        owner = next(
            (name for name, members in flat_ranges.items() if code_point in members),
            None,
        )
        if owner is not None:
            tally[owner] += 1

    # sorted() is stable, so ties keep the table's insertion order.
    ranked = [
        name
        for name, hits in sorted(tally.items(), key=lambda kv: kv[1], reverse=True)
        if hits > 0
    ]
    if "<math" in text:
        ranked.insert(0, "math")

    return ranked[:max_scripts]
|
||||
|
||||
def is_flash_attn_2_supported(device: str | torch.device) -> bool:
|
||||
if not torch.cuda.is_available():
|
||||
return False
|
||||
|
||||
if "cuda" not in str(device):
|
||||
return False
|
||||
|
||||
# Check CUDA version >= 12.0
|
||||
cuda_version_str = torch.version.cuda
|
||||
if cuda_version_str is None:
|
||||
return False
|
||||
cuda_version = tuple(map(int, cuda_version_str.split(".")))
|
||||
if cuda_version < (12, 0):
|
||||
return False
|
||||
|
||||
# Check GPU compute capability (Ampere, Ada, Hopper GPUs)
|
||||
major, minor = torch.cuda.get_device_capability()
|
||||
compute_capability = major + minor / 10
|
||||
if compute_capability < 8.0:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def pad_to_batch_size_repeat(tensor: torch.Tensor, batch_size: int):
    """Pad ``tensor`` along dim 0 up to ``batch_size`` by repeating its last row.

    Works for any tensor with at least one dimension; the previous version
    hard-coded three dimensions in the ``repeat`` call.

    Returns the input unchanged when it already has ``batch_size`` or more
    rows, or when it has no rows (there is no last row to repeat). Otherwise
    returns a new tensor with ``batch_size`` rows.
    """
    current_batch_size = tensor.shape[0]
    pad_size = batch_size - current_batch_size
    if pad_size <= 0 or current_batch_size == 0:
        return tensor

    # Repeat only along dim 0; every trailing dim keeps its size.
    repeats = (pad_size,) + (1,) * (tensor.dim() - 1)
    filler = tensor[-1:].repeat(*repeats)

    return torch.cat([tensor, filler], dim=0)
|
||||
|
||||
|
||||
def pad_to_batch_size(tensor: torch.Tensor, batch_size: int):
    """Zero-pad ``tensor`` at the end of dim 0 until it has ``batch_size`` rows.

    Returns the input unchanged when it already has ``batch_size`` or more
    rows.
    """
    missing = batch_size - tensor.shape[0]
    if missing <= 0:
        return tensor

    # F.pad takes (before, after) pairs starting from the LAST dimension, so
    # the pair for dim 0 goes last; all other dims get (0, 0).
    untouched_dims = (0, 0) * (tensor.dim() - 1)
    pad_spec = untouched_dims + (0, missing)

    return F.pad(tensor, pad_spec, mode="constant", value=0)
|
||||
|
||||
@ -1,24 +1,4 @@
|
||||
import re
|
||||
from io import BytesIO
|
||||
from typing import List, Tuple
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
|
||||
from surya.debug.fonts import get_font_path
|
||||
from surya.debug.render_html import render_text_as_html
|
||||
|
||||
try:
|
||||
from playwright.sync_api import sync_playwright
|
||||
|
||||
has_playwright = True
|
||||
except ImportError:
|
||||
has_playwright = False
|
||||
|
||||
|
||||
def strip_html_tags(html_text):
    """Remove tag-like spans from ``html_text`` and return the remaining text.

    A span is removed when it starts with ``<`` immediately followed by a
    word character or ``/`` and runs up to the next ``>``. A ``<`` followed
    by anything else (for example a space) is left in place.
    """
    return re.sub(r"<[\w/][^>]*>", "", html_text)
|
||||
from PIL import Image, ImageDraw
|
||||
|
||||
|
||||
def get_text_size(text, font):
|
||||
@ -26,75 +6,3 @@ def get_text_size(text, font):
|
||||
draw = ImageDraw.Draw(im)
|
||||
_, _, width, height = draw.textbbox((0, 0), text=text, font=font)
|
||||
return width, height
|
||||
|
||||
|
||||
def render_text(draw, text, s_bbox, bbox_width, bbox_height, font_path, box_font_size):
    """Draw ``text`` in black inside a box, shrinking the font until it fits.

    Starting at ``box_font_size``, the font size is reduced one point at a
    time while the measured text is wider than ``bbox_width`` or taller than
    ``bbox_height``, stopping at size 6 even if the text still overflows.
    The text is drawn left-aligned at ``s_bbox[0]`` and vertically centred
    within ``bbox_height`` starting from ``s_bbox[1]``.
    """
    font = ImageFont.truetype(font_path, box_font_size)
    text_width, text_height = get_text_size(text, font)
    while (text_width > bbox_width or text_height > bbox_height) and box_font_size > 6:
        box_font_size -= 1
        font = ImageFont.truetype(font_path, box_font_size)
        text_width, text_height = get_text_size(text, font)

    # text_height already matches the final font, so no re-measurement is
    # needed. Only the vertical axis is centred; x stays at the left edge.
    x = s_bbox[0]
    y = s_bbox[1] + (bbox_height - text_height) / 2

    draw.text((x, y), text, fill="black", font=font)
|
||||
|
||||
|
||||
def draw_text_with_playwright(
    bboxes, texts: List[str], image_size: Tuple[int, int]
) -> Image.Image:
    """Render ``texts`` at ``bboxes`` through headless Chromium and return the image.

    The HTML comes from ``render_text_as_html``, which also returns the image
    size used for the browser viewport. The page ``<body>`` is screenshotted
    after a fixed one-second wait and the bytes are opened as a PIL image.

    Raises:
        ImportError: if playwright is not installed.
    """
    # Fail fast before doing any rendering work.
    if not has_playwright:
        raise ImportError(
            "Playwright is not installed. Please install it using `pip install playwright`"
        )

    html_content, image_size = render_text_as_html(bboxes, texts, image_size)

    with sync_playwright() as p:
        browser = p.chromium.launch(headless=True)
        try:
            page = browser.new_page(
                viewport={"width": image_size[0], "height": image_size[1]}
            )
            page.set_content(html_content)
            page.wait_for_timeout(1000)
            body = page.query_selector("body")
            image = body.screenshot()
        finally:
            # Close the browser even when a page call above raises.
            browser.close()

    pil_img = Image.open(BytesIO(image))
    return pil_img
|
||||
|
||||
|
||||
def draw_text_on_image(
    bboxes,
    texts,
    image_size: Tuple[int, int],
    font_path=None,
    max_font_size=60,
    res_upscale=2,
) -> Image.Image:
    """Render ``texts`` into their ``bboxes`` on a blank white image.

    When playwright is importable, the work is delegated to
    ``draw_text_with_playwright`` and the remaining arguments are ignored.
    Otherwise tags are stripped from each text, a white RGB canvas of
    ``image_size * res_upscale`` is created, and each text is drawn with
    ``render_text`` into its upscaled bbox. The starting font size is 75% of
    the upscaled bbox height, kept between 6 and ``max_font_size``.
    """
    if has_playwright:
        return draw_text_with_playwright(bboxes, texts, image_size)

    plain_texts = [strip_html_tags(text) for text in texts]
    if font_path is None:
        font_path = get_font_path()

    canvas_width = image_size[0] * res_upscale
    canvas_height = image_size[1] * res_upscale
    image = Image.new("RGB", (canvas_width, canvas_height), color="white")
    draw = ImageDraw.Draw(image)

    for bbox, text in zip(bboxes, plain_texts):
        scaled = [int(coord * res_upscale) for coord in bbox]
        box_w = scaled[2] - scaled[0]
        box_h = scaled[3] - scaled[1]

        # Initial size only; render_text shrinks it further if the text overflows.
        start_size = max(6, min(int(0.75 * box_h), max_font_size))
        render_text(draw, text, scaled, box_w, box_h, font_path, start_size)

    return image
|
||||
|
||||
@ -19,7 +19,11 @@ from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from transformers.image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
|
||||
from transformers.image_processing_utils import (
|
||||
BaseImageProcessor,
|
||||
BatchFeature,
|
||||
get_size_dict,
|
||||
)
|
||||
from transformers.image_transforms import to_channel_dimension_format
|
||||
from transformers.image_utils import (
|
||||
IMAGENET_DEFAULT_MEAN,
|
||||
@ -34,7 +38,6 @@ from transformers.utils import TensorType
|
||||
|
||||
|
||||
import PIL.Image
|
||||
import torch
|
||||
|
||||
from surya.common.s3 import S3DownloaderMixin
|
||||
|
||||
@ -107,7 +110,9 @@ class SegformerImageProcessor(S3DownloaderMixin, BaseImageProcessor):
|
||||
self.do_rescale = do_rescale
|
||||
self.rescale_factor = rescale_factor
|
||||
self.do_normalize = do_normalize
|
||||
self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN
|
||||
self.image_mean = (
|
||||
image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN
|
||||
)
|
||||
self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD
|
||||
self.do_reduce_labels = do_reduce_labels
|
||||
self._valid_processor_keys = [
|
||||
@ -152,12 +157,18 @@ class SegformerImageProcessor(S3DownloaderMixin, BaseImageProcessor):
|
||||
image_std: Optional[Union[float, List[float]]] = None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
):
|
||||
|
||||
if do_rescale:
|
||||
image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
|
||||
image = self.rescale(
|
||||
image=image, scale=rescale_factor, input_data_format=input_data_format
|
||||
)
|
||||
|
||||
if do_normalize:
|
||||
image = self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
|
||||
image = self.normalize(
|
||||
image=image,
|
||||
mean=image_mean,
|
||||
std=image_std,
|
||||
input_data_format=input_data_format,
|
||||
)
|
||||
|
||||
return image
|
||||
|
||||
@ -193,7 +204,9 @@ class SegformerImageProcessor(S3DownloaderMixin, BaseImageProcessor):
|
||||
input_data_format=input_data_format,
|
||||
)
|
||||
if data_format is not None:
|
||||
image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
|
||||
image = to_channel_dimension_format(
|
||||
image, data_format, input_channel_dim=input_data_format
|
||||
)
|
||||
return image
|
||||
|
||||
def __call__(self, images, segmentation_maps=None, **kwargs):
|
||||
@ -276,7 +289,9 @@ class SegformerImageProcessor(S3DownloaderMixin, BaseImageProcessor):
|
||||
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
|
||||
resample = resample if resample is not None else self.resample
|
||||
size = size if size is not None else self.size
|
||||
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
|
||||
rescale_factor = (
|
||||
rescale_factor if rescale_factor is not None else self.rescale_factor
|
||||
)
|
||||
image_mean = image_mean if image_mean is not None else self.image_mean
|
||||
image_std = image_std if image_std is not None else self.image_std
|
||||
|
||||
@ -299,4 +314,4 @@ class SegformerImageProcessor(S3DownloaderMixin, BaseImageProcessor):
|
||||
]
|
||||
|
||||
data = {"pixel_values": images}
|
||||
return BatchFeature(data=data, tensor_type=return_tensors)
|
||||
return BatchFeature(data=data, tensor_type=return_tensors)
|
||||
|
||||
@ -3,7 +3,7 @@
|
||||
One process owns one SuryaInferenceManager. The manager wraps a single backend
|
||||
(vllm | llamacpp) which speaks OpenAI-compatible chat completions.
|
||||
|
||||
Predictors take the manager via explicit injection (see surya/models.py).
|
||||
Predictors take the manager via explicit injection at construction time.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
"""vllm backend: spawns the vllm/vllm-openai docker image with MTP=3."""
|
||||
"""vllm backend: spawns the vllm/vllm-openai docker image with MTP=2."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
@ -80,7 +80,6 @@ class VllmBackend(Backend):
|
||||
def __init__(self):
|
||||
self.handle: Optional[ServerHandle] = None
|
||||
self._client: Optional[OpenAI] = None
|
||||
self._container_name: Optional[str] = None
|
||||
|
||||
def start(self) -> ServerHandle:
|
||||
if self.handle is not None:
|
||||
@ -113,7 +112,6 @@ class VllmBackend(Backend):
|
||||
|
||||
def spawn_fn(port: int) -> SpawnHandle:
|
||||
container_name = f"surya-vllm-{port}"
|
||||
self._container_name = container_name
|
||||
hf_cache = os.path.expanduser(settings.DOCKER_HF_CACHE_PATH)
|
||||
cmd = [
|
||||
docker,
|
||||
|
||||
@ -1,401 +0,0 @@
|
||||
"""Pipelined layout + block-OCR orchestrator.
|
||||
|
||||
Standard flow runs all layouts (batch) then all blocks (batch). The latter
|
||||
can't start until every page has finished layout, leaving server slots
|
||||
idle while layouts trickle in.
|
||||
|
||||
This orchestrator overlaps the two phases:
|
||||
* Submit every layout request into one ThreadPoolExecutor.
|
||||
* As each layout future completes, immediately submit that page's block
|
||||
OCR requests, **reverse-sorted by count** (LPT scheduling) so big
|
||||
blocks start their decode earliest.
|
||||
* If a layout fails to parse, fall back to HIGH_ACCURACY_BBOX_PROMPT on
|
||||
the full OCR-DPI page — its HTML output yields layout + content in one
|
||||
shot (no per-block phase needed for fallback pages).
|
||||
* Drain remaining block futures + fallback futures.
|
||||
|
||||
Returns the same shape as the serial pair: (List[LayoutResult], List[PageOCRResult]).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from PIL import Image
|
||||
|
||||
from surya.inference import SuryaInferenceManager
|
||||
from surya.inference.backends.openai_client import _generate_one
|
||||
from surya.inference.parsers import (
|
||||
clean_block_html,
|
||||
denorm_bbox,
|
||||
parse_full_page_html,
|
||||
parse_layout,
|
||||
)
|
||||
from surya.inference.prompts import LAYOUT_JSON_SCHEMA
|
||||
from surya.inference.schema import (
|
||||
BatchInputItem,
|
||||
PROMPT_TYPE_BLOCK,
|
||||
PROMPT_TYPE_HIGH_ACCURACY_BBOX,
|
||||
PROMPT_TYPE_LAYOUT,
|
||||
)
|
||||
from surya.inference.util import image_token_budget
|
||||
from surya.layout.label import LAYOUT_PRED_RELABEL
|
||||
from surya.layout.schema import LayoutBox, LayoutResult
|
||||
from surya.logging import get_logger
|
||||
from surya.recognition import SKIP_CANON_LABELS, _crop_block
|
||||
from surya.recognition.schema import BlockOCRResult, PageOCRResult
|
||||
from surya.settings import settings
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
def _make_layout_result(
    raw: str,
    error: bool,
    mean_token_prob: Optional[float],
    target_size: Tuple[int, int],
) -> LayoutResult:
    """Turn one raw layout generation into a ``LayoutResult``.

    ``target_size`` is ``(width, height)``; it defines the page bbox and is
    passed to ``denorm_bbox`` for each parsed block.

    An error result (no boxes, ``error=True``) is returned when ``error`` is
    set, when ``raw`` is empty, or when ``parse_layout`` raises; the parse
    failure is logged as a warning. Otherwise each parsed block becomes a
    ``LayoutBox`` whose label goes through ``LAYOUT_PRED_RELABEL`` and whose
    confidence is ``mean_token_prob`` (1.0 when that is ``None``).
    """
    w, h = target_size
    page_bbox = [0, 0, float(w), float(h)]

    def _failed() -> LayoutResult:
        return LayoutResult(bboxes=[], image_bbox=page_bbox, raw=raw, error=True)

    if error or not raw:
        return _failed()

    try:
        parsed = parse_layout(raw)
    except Exception as e:
        logger.warning(f"Layout parse failed: {e}; raw[:200]={raw[:200]!r}")
        return _failed()

    confidence = 1.0 if mean_token_prob is None else mean_token_prob
    boxes: List[LayoutBox] = [
        LayoutBox(
            polygon=list(denorm_bbox(blk.bbox, w, h, scale=settings.BBOX_SCALE)),
            label=LAYOUT_PRED_RELABEL.get(blk.label, blk.label),
            raw_label=blk.label,
            position=idx,
            count=blk.count,
            confidence=confidence,
        )
        for idx, blk in enumerate(parsed)
    ]
    return LayoutResult(bboxes=boxes, image_bbox=page_bbox, raw=raw, error=False)
|
||||
|
||||
|
||||
def _make_block_result(
    page_box: LayoutBox,
    raw: str,
    error: bool,
    mean_token_prob: Optional[float],
    skipped: bool,
) -> BlockOCRResult:
    """Build the ``BlockOCRResult`` for one layout box.

    Geometry, labels and reading order are always copied from ``page_box``.
    The remaining fields depend on the outcome, checked in this order:

    * ``skipped``: empty html, ``skipped=True``, confidence 1.0 (the
      ``error`` field is not passed and keeps the model's default),
    * ``error``: empty html, ``error=True``, confidence 0.0,
    * otherwise: html from ``clean_block_html(raw)`` and confidence from
      ``mean_token_prob`` (1.0 when that is ``None``).
    """
    shared = dict(
        polygon=page_box.polygon,
        label=page_box.label,
        raw_label=page_box.raw_label,
        reading_order=page_box.position,
    )

    if skipped:
        return BlockOCRResult(**shared, html="", skipped=True, confidence=1.0)

    if error:
        return BlockOCRResult(
            **shared, html="", skipped=False, error=True, confidence=0.0
        )

    html = clean_block_html(raw)
    conf = 1.0 if mean_token_prob is None else mean_token_prob
    return BlockOCRResult(
        **shared, html=html, skipped=False, error=False, confidence=conf
    )
|
||||
|
||||
|
||||
def _full_page_to_results(
    raw: str,
    mean_token_prob: Optional[float],
    target_size: Tuple[int, int],
) -> Tuple[LayoutResult, PageOCRResult]:
    """Parse HIGH_ACCURACY_BBOX_PROMPT output into a LayoutResult + PageOCRResult
    pair (one block per top-level <div>).

    Each item from ``parse_full_page_html`` yields one ``LayoutBox`` and one
    ``BlockOCRResult`` that share the same four-corner polygon and index.
    Blocks whose relabelled label is in ``SKIP_CANON_LABELS`` are marked
    skipped and get empty html. Every box and block carries
    ``mean_token_prob`` as confidence (1.0 when that is ``None``).
    """
    w, h = target_size
    page_bbox = [0, 0, float(w), float(h)]
    parsed = parse_full_page_html(raw)
    confidence = 1.0 if mean_token_prob is None else mean_token_prob

    # (x index, y index) into the pixel bbox for each polygon corner,
    # walking top-left, top-right, bottom-right, bottom-left.
    corner_indices = ((0, 1), (2, 1), (2, 3), (0, 3))

    boxes: List[LayoutBox] = []
    blocks: List[BlockOCRResult] = []
    for idx, item in enumerate(parsed):
        px = denorm_bbox(item.bbox, w, h, scale=settings.BBOX_SCALE)
        canon = LAYOUT_PRED_RELABEL.get(item.label, item.label)
        polygon = [[px[xi], px[yi]] for xi, yi in corner_indices]

        layout_box = LayoutBox(
            polygon=polygon,
            label=canon,
            raw_label=item.label,
            position=idx,
            count=0,
            confidence=confidence,
        )
        boxes.append(layout_box)

        skipped = canon in SKIP_CANON_LABELS
        ocr_block = BlockOCRResult(
            polygon=polygon,
            label=canon,
            raw_label=item.label,
            reading_order=idx,
            html="" if skipped else item.html,
            skipped=skipped,
            error=False,
            confidence=confidence,
        )
        blocks.append(ocr_block)

    layout = LayoutResult(bboxes=boxes, image_bbox=page_bbox, raw=raw, error=False)
    page = PageOCRResult(blocks=blocks, image_bbox=page_bbox)
    return layout, page
|
||||
|
||||
|
||||
def layout_then_blocks(
|
||||
manager: SuryaInferenceManager,
|
||||
ocr_images: List[Image.Image],
|
||||
layout_images: Optional[List[Image.Image]] = None,
|
||||
max_workers: Optional[int] = None,
|
||||
) -> Tuple[List[LayoutResult], List[PageOCRResult]]:
|
||||
"""Run layout for all pages with block OCR pipelined per page.
|
||||
|
||||
layout_images: optional low-DPI renders for layout. If None, ocr_images are
|
||||
used for both. Bbox coords are returned in ocr_images coord space either way.
|
||||
|
||||
On layout failure, falls back to HIGH_ACCURACY_BBOX_PROMPT on the full OCR
|
||||
image (when settings.SURYA_LAYOUT_FALLBACK_FULL_PAGE is True).
|
||||
"""
|
||||
if not ocr_images:
|
||||
return [], []
|
||||
if layout_images is None:
|
||||
layout_images = ocr_images
|
||||
|
||||
manager.start()
|
||||
backend = manager.backend
|
||||
client = backend._client
|
||||
model_name = backend.handle.model_name
|
||||
|
||||
timeout = settings.SURYA_INFERENCE_TIMEOUT_SECONDS
|
||||
request_logprobs = settings.SURYA_INFERENCE_LOGPROBS
|
||||
n_workers = max_workers or settings.SURYA_INFERENCE_PARALLEL
|
||||
|
||||
common_kwargs = dict(
|
||||
client=client,
|
||||
model_name=model_name,
|
||||
max_tokens_default=settings.SURYA_MAX_TOKENS_LAYOUT,
|
||||
temperature=0.0,
|
||||
top_p=0.1,
|
||||
timeout=timeout,
|
||||
request_logprobs_default=request_logprobs,
|
||||
)
|
||||
|
||||
n_pages = len(ocr_images)
|
||||
layout_results: List[Optional[LayoutResult]] = [None] * n_pages
|
||||
page_results: List[Optional[PageOCRResult]] = [None] * n_pages
|
||||
block_results: dict = {} # (page_idx, block_idx) -> (raw, error, mean_p, skipped)
|
||||
|
||||
fallback_enabled = settings.SURYA_LAYOUT_FALLBACK_FULL_PAGE
|
||||
|
||||
with ThreadPoolExecutor(max_workers=n_workers) as executor:
|
||||
# ---- Phase 1: submit all layout requests ----
|
||||
guided = LAYOUT_JSON_SCHEMA if settings.SURYA_GUIDED_LAYOUT else None
|
||||
layout_futures = {}
|
||||
for page_idx, lo_img in enumerate(layout_images):
|
||||
item = BatchInputItem(
|
||||
image=lo_img,
|
||||
prompt_type=PROMPT_TYPE_LAYOUT,
|
||||
max_tokens=settings.SURYA_MAX_TOKENS_LAYOUT,
|
||||
guided_json=guided,
|
||||
metadata={"page_idx": page_idx},
|
||||
)
|
||||
fut = executor.submit(_generate_one, item, **common_kwargs)
|
||||
layout_futures[fut] = page_idx
|
||||
|
||||
# ---- Phase 2: as layouts land, submit blocks (LPT sort) or fallback ----
|
||||
block_futures = {}
|
||||
fallback_futures = {}
|
||||
|
||||
def _submit_fallback(page_idx: int):
|
||||
"""Schedule a HIGH_ACCURACY_BBOX_PROMPT request for this page."""
|
||||
item = BatchInputItem(
|
||||
image=ocr_images[page_idx],
|
||||
prompt_type=PROMPT_TYPE_HIGH_ACCURACY_BBOX,
|
||||
max_tokens=settings.SURYA_MAX_TOKENS_FULL_PAGE,
|
||||
metadata={"page_idx": page_idx},
|
||||
)
|
||||
fb_fut = executor.submit(_generate_one, item, **common_kwargs)
|
||||
fallback_futures[fb_fut] = page_idx
|
||||
|
||||
for fut in as_completed(layout_futures):
|
||||
page_idx = layout_futures[fut]
|
||||
try:
|
||||
gen = fut.result()
|
||||
except Exception as e:
|
||||
logger.warning(f"Layout request failed for page {page_idx}: {e}")
|
||||
if fallback_enabled:
|
||||
_submit_fallback(page_idx)
|
||||
else:
|
||||
w, h = ocr_images[page_idx].size
|
||||
layout_results[page_idx] = LayoutResult(
|
||||
bboxes=[],
|
||||
image_bbox=[0, 0, float(w), float(h)],
|
||||
raw=None,
|
||||
error=True,
|
||||
)
|
||||
continue
|
||||
|
||||
target_size = ocr_images[page_idx].size
|
||||
layout_result = _make_layout_result(
|
||||
raw=gen.raw,
|
||||
error=gen.error,
|
||||
mean_token_prob=gen.mean_token_prob,
|
||||
target_size=target_size,
|
||||
)
|
||||
if layout_result.error:
|
||||
if fallback_enabled:
|
||||
logger.info(
|
||||
f"Layout failed for page {page_idx}, falling back to full-page"
|
||||
)
|
||||
_submit_fallback(page_idx)
|
||||
continue
|
||||
layout_results[page_idx] = layout_result
|
||||
continue
|
||||
|
||||
layout_results[page_idx] = layout_result
|
||||
|
||||
# Per-page LPT: reverse-sort blocks by count
|
||||
blocks_sorted = sorted(
|
||||
enumerate(layout_result.bboxes),
|
||||
key=lambda kv: -kv[1].count,
|
||||
)
|
||||
for block_idx, blk in blocks_sorted:
|
||||
if blk.label in SKIP_CANON_LABELS:
|
||||
block_results[(page_idx, block_idx)] = ("", False, None, True)
|
||||
continue
|
||||
crop = _crop_block(ocr_images[page_idx], blk.polygon)
|
||||
max_tokens = image_token_budget(
|
||||
blk.count, ceiling=settings.SURYA_MAX_TOKENS_BLOCK_CEILING
|
||||
)
|
||||
block_item = BatchInputItem(
|
||||
image=crop,
|
||||
prompt_type=PROMPT_TYPE_BLOCK,
|
||||
max_tokens=max_tokens,
|
||||
metadata={"page_idx": page_idx, "block_idx": block_idx},
|
||||
)
|
||||
bfut = executor.submit(_generate_one, block_item, **common_kwargs)
|
||||
block_futures[bfut] = (page_idx, block_idx)
|
||||
|
||||
# ---- Phase 3: drain block futures ----
|
||||
for fut in as_completed(block_futures):
|
||||
key = block_futures[fut]
|
||||
try:
|
||||
gen = fut.result()
|
||||
except Exception as e:
|
||||
logger.warning(f"Block request failed for {key}: {e}")
|
||||
block_results[key] = ("", True, None, False)
|
||||
continue
|
||||
block_results[key] = (gen.raw, gen.error, gen.mean_token_prob, False)
|
||||
|
||||
# ---- Phase 4: drain fallback futures ----
|
||||
for fut in as_completed(fallback_futures):
|
||||
page_idx = fallback_futures[fut]
|
||||
target_size = ocr_images[page_idx].size
|
||||
page_bbox = [0, 0, float(target_size[0]), float(target_size[1])]
|
||||
try:
|
||||
gen = fut.result()
|
||||
except Exception as e:
|
||||
logger.warning(f"Fallback request failed for page {page_idx}: {e}")
|
||||
layout_results[page_idx] = LayoutResult(
|
||||
bboxes=[],
|
||||
image_bbox=page_bbox,
|
||||
raw=None,
|
||||
error=True,
|
||||
)
|
||||
page_results[page_idx] = PageOCRResult(blocks=[], image_bbox=page_bbox)
|
||||
continue
|
||||
if gen.error or not gen.raw:
|
||||
layout_results[page_idx] = LayoutResult(
|
||||
bboxes=[],
|
||||
image_bbox=page_bbox,
|
||||
raw=gen.raw,
|
||||
error=True,
|
||||
)
|
||||
page_results[page_idx] = PageOCRResult(blocks=[], image_bbox=page_bbox)
|
||||
continue
|
||||
try:
|
||||
layout, page = _full_page_to_results(
|
||||
gen.raw, gen.mean_token_prob, target_size
|
||||
)
|
||||
layout_results[page_idx] = layout
|
||||
page_results[page_idx] = page
|
||||
except Exception as e:
|
||||
logger.warning(f"Fallback parse failed for page {page_idx}: {e}")
|
||||
layout_results[page_idx] = LayoutResult(
|
||||
bboxes=[],
|
||||
image_bbox=page_bbox,
|
||||
raw=gen.raw,
|
||||
error=True,
|
||||
)
|
||||
page_results[page_idx] = PageOCRResult(blocks=[], image_bbox=page_bbox)
|
||||
|
||||
# ---- Assemble PageOCRResult for non-fallback pages ----
|
||||
for page_idx, layout_result in enumerate(layout_results):
|
||||
if page_results[page_idx] is not None:
|
||||
# Already populated by fallback path
|
||||
continue
|
||||
if layout_result is None:
|
||||
w, h = ocr_images[page_idx].size
|
||||
page_results[page_idx] = PageOCRResult(
|
||||
blocks=[], image_bbox=[0, 0, float(w), float(h)]
|
||||
)
|
||||
continue
|
||||
blocks: List[BlockOCRResult] = []
|
||||
for block_idx, blk in enumerate(layout_result.bboxes):
|
||||
entry = block_results.get((page_idx, block_idx))
|
||||
if entry is None:
|
||||
blocks.append(
|
||||
BlockOCRResult(
|
||||
polygon=blk.polygon,
|
||||
label=blk.label,
|
||||
raw_label=blk.raw_label,
|
||||
reading_order=blk.position,
|
||||
html="",
|
||||
error=True,
|
||||
confidence=0.0,
|
||||
)
|
||||
)
|
||||
continue
|
||||
raw, error, mean_p, skipped = entry
|
||||
blocks.append(_make_block_result(blk, raw, error, mean_p, skipped))
|
||||
w, h = ocr_images[page_idx].size
|
||||
page_results[page_idx] = PageOCRResult(
|
||||
blocks=blocks, image_bbox=[0, 0, float(w), float(h)]
|
||||
)
|
||||
|
||||
return list(layout_results), list(page_results) # type: ignore[arg-type]
|
||||
@ -67,40 +67,6 @@ ALLOWED_ATTRIBUTES = [
|
||||
# Block labels we don't run OCR on.
|
||||
SKIP_OCR_LABELS = {"Figure", "Image", "Diagram", "Blank-Page"}
|
||||
|
||||
LAYOUT_LABELS = """- Caption
|
||||
- Footnote
|
||||
- Equation-Block
|
||||
- List-Group
|
||||
- Page-Header
|
||||
- Page-Footer
|
||||
- Image
|
||||
- Section-Header
|
||||
- Table
|
||||
- Text
|
||||
- Complex-Block
|
||||
- Code-Block
|
||||
- Form
|
||||
- Table-Of-Contents
|
||||
- Figure
|
||||
- Chemical-Block
|
||||
- Diagram
|
||||
- Bibliography
|
||||
- Blank-Page"""
|
||||
|
||||
GUIDELINES = f"""Only use these tags {ALLOWED_TAGS}, and these attributes {ALLOWED_ATTRIBUTES}.
|
||||
|
||||
Guidelines:
|
||||
* Inline math: Surround math with <math>...</math> tags. Math expressions should be rendered in KaTeX-compatible LaTeX. Use display for block math.
|
||||
* Tables: Use colspan and rowspan attributes to match table structure.
|
||||
* Formatting: Maintain consistent formatting with the image, including spacing, indentation, subscripts/superscripts, and special characters.
|
||||
* Images: Include a description of any images in the alt attribute of an <img> tag. Do not fill out the src property. Describe in detail inside the div tag. Also convert charts to high fidelity data, and convert diagrams to mermaid.
|
||||
* Forms: Mark checkboxes and radio buttons properly.
|
||||
* Text: join lines together properly into paragraphs using <p>...</p> tags. Use <br> tags for line breaks within paragraphs, but only when absolutely necessary to maintain meaning.
|
||||
* Chemistry: Use <chem>...</chem> tags for chemical formulas with reactive SMILES.
|
||||
* Lists: Preserve indents and proper list markers.
|
||||
* Use the simplest possible HTML structure that accurately represents the content of the block.
|
||||
* Make sure the text is accurate and easy for a human to read and interpret. Reading order should be correct and natural."""
|
||||
|
||||
LAYOUT_PROMPT = (
|
||||
"Output the layout of this image as JSON. Each entry is a dict with "
|
||||
'"label", "bbox", and "count" fields. Bbox is x0 y0 x1 y1, normalized 0-1000.'
|
||||
|
||||
@ -7,7 +7,6 @@ from surya.settings import settings
|
||||
import os
|
||||
import filetype
|
||||
from PIL import Image
|
||||
import json
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
@ -76,9 +75,3 @@ def load_from_folder(
|
||||
logger.warning(f"Could not load image {path}")
|
||||
continue
|
||||
return images, names
|
||||
|
||||
|
||||
def load_lang_file(lang_path, names):
    """Load per-image language lists from a JSON file.

    Args:
        lang_path: Path to a JSON file holding a mapping from image name to
            a copyable value (the code calls ``.copy()`` on each entry).
        names: Iterable of keys to look up, in the order results are wanted.

    Returns:
        A list with one shallow copy of the mapped value per entry in
        ``names``; repeated names yield independent copies.

    Raises:
        KeyError: If a name is not present in the JSON mapping.
    """
    # JSON text is UTF-8; do not depend on the platform's locale encoding.
    with open(lang_path, "r", encoding="utf-8") as f:
        lang_dict = json.load(f)
    return [lang_dict[name].copy() for name in names]
|
||||
|
||||
@ -1,24 +1,9 @@
|
||||
from typing import List
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import pypdfium2
|
||||
from PIL import Image
|
||||
|
||||
from surya.logging import get_logger
|
||||
from surya.settings import settings
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
def convert_if_not_rgb(images: List[Image.Image]) -> List[Image.Image]:
    """Return a new list in which every image is in RGB mode.

    Images whose ``mode`` is already ``"RGB"`` are passed through as the same
    object; all others are replaced by the result of ``convert("RGB")``.
    The input list itself is not modified.
    """
    return [
        img if img.mode == "RGB" else img.convert("RGB")
        for img in images
    ]
|
||||
|
||||
|
||||
def open_pdf(pdf_filepath):
    """Open ``pdf_filepath`` with pypdfium2 and return the ``PdfDocument``."""
    document = pypdfium2.PdfDocument(pdf_filepath)
    return document
|
||||
@ -30,72 +15,3 @@ def get_page_images(doc, indices: List, dpi=settings.IMAGE_DPI):
|
||||
]
|
||||
images = [image.convert("RGB") for image in images]
|
||||
return images
|
||||
|
||||
|
||||
def slice_bboxes_from_image(image: np.ndarray, bboxes):
    """Cut one copied crop out of ``image`` for each ``[x0, y0, x1, y1]`` box.

    Each box is cast to int32 and negative values are clamped to zero. A box
    with non-positive width or height is widened to one pixel, after which the
    far edges are capped at the image size. A crop that still ends up empty is
    logged and appended anyway, so the output always has one entry per box.
    """
    crops = []
    for raw_box in bboxes:
        # Cast and clamp in one step so no negative index reaches the slice.
        bbox = np.clip(np.array(raw_box, dtype=np.int32), 0, None)

        # Force at least a 1-pixel extent on each axis.
        if bbox[2] <= bbox[0]:
            bbox[2] = bbox[0] + 1
        if bbox[3] <= bbox[1]:
            bbox[3] = bbox[1] + 1

        # Cap the far edges at the image dimensions.
        bbox[3] = min(bbox[3], image.shape[0])
        bbox[2] = min(bbox[2], image.shape[1])

        left, top, right, bottom = bbox
        crop = image[top:bottom, left:right].copy()
        if crop.size == 0:
            logger.warning(f"Warning: found an empty line with bbox {bbox}")
        crops.append(crop)
    return crops
|
||||
|
||||
|
||||
def slice_polys_from_image(image: np.ndarray, polys):
    """Return one ``slice_and_pad_poly`` crop per polygon, in input order."""
    return [slice_and_pad_poly(image, poly) for poly in polys]
|
||||
|
||||
|
||||
def slice_and_pad_poly(image_array: np.ndarray, coordinates):
    """Crop a polygon's bounding box and paint everything outside it white.

    Args:
        image_array: Image array indexed ``[row, col]`` or
            ``[row, col, channel]``; any channel count is accepted.
        coordinates: Polygon corners. The first two items of each corner are
            taken as ``(x, y)`` and used directly as slice bounds.

    Returns:
        A copy of the bounding-box crop with pixels outside the polygon set
        to 255. If the box is degenerate, the crop is empty, or fewer than
        three corners are given, the crop is returned unmasked. If OpenCV
        rejects the polygon, a warning is logged and the crop is returned
        as it stands.
    """
    coordinates = [(corner[0], corner[1]) for corner in coordinates]
    xs = [pt[0] for pt in coordinates]
    ys = [pt[1] for pt in coordinates]
    bbox = [min(xs), min(ys), max(xs), max(ys)]

    # Work on a copy so the caller's image is never modified.
    cropped_polygon = image_array[bbox[1] : bbox[3], bbox[0] : bbox[2]].copy()
    height, width = cropped_polygon.shape[:2]

    # Shift the polygon into the crop's coordinate frame.
    coordinates = [(x - bbox[0], y - bbox[1]) for x, y in coordinates]

    # Nothing sensible to mask for degenerate boxes, empty crops, or
    # shapes with fewer than three corners.
    if (
        bbox[3] <= bbox[1]
        or bbox[2] <= bbox[0]
        or len(coordinates) < 3
        or height == 0
        or width == 0
    ):
        return cropped_polygon

    try:
        mask = np.zeros((height, width), dtype=np.uint8)
        cv2.fillPoly(mask, [np.int32(coordinates)], 1)

        # A 2-D boolean mask indexes the two leading axes, so this handles
        # grayscale, RGB and RGBA crops alike (previously hard-coded to 3
        # channels, which raised on anything else).
        cropped_polygon[mask == 0] = 255  # white pad
    except cv2.error as e:
        logger.warning(f"Warning: issue while processing polygon: {e}")

    return cropped_polygon
|
||||
|
||||
@ -1,29 +1,6 @@
|
||||
"""Surya2 layout labels emitted by the model + canonicalization to surya's
|
||||
public label vocabulary."""
|
||||
|
||||
# Labels the model emits via LAYOUT_PROMPT (see surya/inference/prompts.py).
|
||||
LAYOUT_LABELS = (
|
||||
"Caption",
|
||||
"Footnote",
|
||||
"Equation-Block",
|
||||
"List-Group",
|
||||
"Page-Header",
|
||||
"Page-Footer",
|
||||
"Image",
|
||||
"Section-Header",
|
||||
"Table",
|
||||
"Text",
|
||||
"Complex-Block",
|
||||
"Code-Block",
|
||||
"Form",
|
||||
"Table-Of-Contents",
|
||||
"Figure",
|
||||
"Chemical-Block",
|
||||
"Diagram",
|
||||
"Bibliography",
|
||||
"Blank-Page",
|
||||
)
|
||||
|
||||
# Canonicalize raw model labels to public surya label names. Marker and other
|
||||
# downstream consumers depend on these names.
|
||||
LAYOUT_PRED_RELABEL = {
|
||||
|
||||
@ -1,35 +0,0 @@
|
||||
from typing import Dict, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from surya.detection import DetectionPredictor
|
||||
from surya.inference import SuryaInferenceManager
|
||||
from surya.layout import LayoutPredictor
|
||||
from surya.logging import configure_logging
|
||||
from surya.ocr_error import OCRErrorPredictor
|
||||
from surya.recognition import RecognitionPredictor
|
||||
from surya.table_rec import TableRecPredictor
|
||||
|
||||
configure_logging()
|
||||
|
||||
|
||||
def load_predictors(
    device: str | torch.device | None = None,
    dtype: torch.dtype | str | None = None,
    manager: Optional[SuryaInferenceManager] = None,
) -> Dict[str, object]:
    """Construct the standard set of surya predictors.

    Layout, recognition and table-rec predictors are all handed the same
    ``SuryaInferenceManager``; when none is supplied, one is created with
    ``lazy=True``. The detection and OCR-error predictors are built from
    ``device`` and ``dtype`` instead.

    Returns:
        A dict keyed by ``"layout"``, ``"recognition"``, ``"table_rec"``,
        ``"detection"``, ``"ocr_error"`` and ``"manager"``.
    """
    shared_manager = (
        manager if manager is not None else SuryaInferenceManager(lazy=True)
    )

    predictors: Dict[str, object] = {}
    # Predictors that share the single inference manager.
    predictors["layout"] = LayoutPredictor(shared_manager)
    predictors["recognition"] = RecognitionPredictor(shared_manager)
    predictors["table_rec"] = TableRecPredictor(shared_manager)
    # Predictors configured directly with device/dtype.
    predictors["detection"] = DetectionPredictor(device=device, dtype=dtype)
    predictors["ocr_error"] = OCRErrorPredictor(device=device, dtype=dtype)
    predictors["manager"] = shared_manager
    return predictors
|
||||
@ -1,18 +1,20 @@
|
||||
import collections
|
||||
import os
|
||||
import json
|
||||
import unicodedata
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from tokenizers import normalizers
|
||||
|
||||
from transformers.tokenization_utils import PreTrainedTokenizer, _is_control, _is_punctuation, _is_whitespace
|
||||
from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
|
||||
from transformers.tokenization_utils import (
|
||||
PreTrainedTokenizer,
|
||||
_is_control,
|
||||
_is_punctuation,
|
||||
_is_whitespace,
|
||||
)
|
||||
|
||||
from surya.common.s3 import S3DownloaderMixin
|
||||
|
||||
VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}
|
||||
|
||||
|
||||
# Copied from transformers.models.bert.tokenization_bert.load_vocab
|
||||
def load_vocab(vocab_file):
|
||||
"""Loads a vocabulary file into a dictionary."""
|
||||
@ -101,7 +103,9 @@ class DistilBertTokenizer(S3DownloaderMixin, PreTrainedTokenizer):
|
||||
" model use `tokenizer = DistilBertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`"
|
||||
)
|
||||
self.vocab = load_vocab(vocab_file)
|
||||
self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])
|
||||
self.ids_to_tokens = collections.OrderedDict(
|
||||
[(ids, tok) for tok, ids in self.vocab.items()]
|
||||
)
|
||||
self.do_basic_tokenize = do_basic_tokenize
|
||||
if do_basic_tokenize:
|
||||
self.basic_tokenizer = BasicTokenizer(
|
||||
@ -110,7 +114,9 @@ class DistilBertTokenizer(S3DownloaderMixin, PreTrainedTokenizer):
|
||||
tokenize_chinese_chars=tokenize_chinese_chars,
|
||||
strip_accents=strip_accents,
|
||||
)
|
||||
self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=str(unk_token))
|
||||
self.wordpiece_tokenizer = WordpieceTokenizer(
|
||||
vocab=self.vocab, unk_token=str(unk_token)
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
do_lower_case=do_lower_case,
|
||||
@ -145,7 +151,10 @@ class DistilBertTokenizer(S3DownloaderMixin, PreTrainedTokenizer):
|
||||
split_tokens = []
|
||||
if self.do_basic_tokenize:
|
||||
for token in self.basic_tokenizer.tokenize(
|
||||
text, never_split=self.all_special_tokens if not split_special_tokens else None
|
||||
text,
|
||||
never_split=self.all_special_tokens
|
||||
if not split_special_tokens
|
||||
else None,
|
||||
):
|
||||
# If the token is part of the never_split set
|
||||
if token in self.basic_tokenizer.never_split:
|
||||
@ -200,7 +209,10 @@ class DistilBertTokenizer(S3DownloaderMixin, PreTrainedTokenizer):
|
||||
|
||||
# Copied from transformers.models.bert.tokenization_bert.BertTokenizer.get_special_tokens_mask
|
||||
def get_special_tokens_mask(
|
||||
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
|
||||
self,
|
||||
token_ids_0: List[int],
|
||||
token_ids_1: Optional[List[int]] = None,
|
||||
already_has_special_tokens: bool = False,
|
||||
) -> List[int]:
|
||||
"""
|
||||
Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
|
||||
@ -220,7 +232,9 @@ class DistilBertTokenizer(S3DownloaderMixin, PreTrainedTokenizer):
|
||||
|
||||
if already_has_special_tokens:
|
||||
return super().get_special_tokens_mask(
|
||||
token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
|
||||
token_ids_0=token_ids_0,
|
||||
token_ids_1=token_ids_1,
|
||||
already_has_special_tokens=True,
|
||||
)
|
||||
|
||||
if token_ids_1 is not None:
|
||||
@ -258,14 +272,20 @@ class DistilBertTokenizer(S3DownloaderMixin, PreTrainedTokenizer):
|
||||
return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
|
||||
|
||||
# Copied from transformers.models.bert.tokenization_bert.BertTokenizer.save_vocabulary
|
||||
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
|
||||
def save_vocabulary(
|
||||
self, save_directory: str, filename_prefix: Optional[str] = None
|
||||
) -> Tuple[str]:
|
||||
index = 0
|
||||
if os.path.isdir(save_directory):
|
||||
vocab_file = os.path.join(
|
||||
save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
|
||||
save_directory,
|
||||
(filename_prefix + "-" if filename_prefix else "")
|
||||
+ VOCAB_FILES_NAMES["vocab_file"],
|
||||
)
|
||||
else:
|
||||
vocab_file = (filename_prefix + "-" if filename_prefix else "") + save_directory
|
||||
vocab_file = (
|
||||
filename_prefix + "-" if filename_prefix else ""
|
||||
) + save_directory
|
||||
with open(vocab_file, "w", encoding="utf-8") as writer:
|
||||
for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
|
||||
if index != token_index:
|
||||
@ -329,7 +349,11 @@ class BasicTokenizer(object):
|
||||
[`PreTrainedTokenizer.tokenize`]) List of token not to split.
|
||||
"""
|
||||
# union() returns a new set by concatenating the two sets.
|
||||
never_split = self.never_split.union(set(never_split)) if never_split else self.never_split
|
||||
never_split = (
|
||||
self.never_split.union(set(never_split))
|
||||
if never_split
|
||||
else self.never_split
|
||||
)
|
||||
text = self._clean_text(text)
|
||||
|
||||
# This was added on November 1st, 2018 for the multilingual and Chinese
|
||||
@ -370,7 +394,9 @@ class BasicTokenizer(object):
|
||||
|
||||
def _run_split_on_punc(self, text, never_split=None):
|
||||
"""Splits punctuation on a piece of text."""
|
||||
if not self.do_split_on_punc or (never_split is not None and text in never_split):
|
||||
if not self.do_split_on_punc or (
|
||||
never_split is not None and text in never_split
|
||||
):
|
||||
return [text]
|
||||
chars = list(text)
|
||||
i = 0
|
||||
@ -496,4 +522,4 @@ class WordpieceTokenizer(object):
|
||||
output_tokens.append(self.unk_token)
|
||||
else:
|
||||
output_tokens.extend(sub_tokens)
|
||||
return output_tokens
|
||||
return output_tokens
|
||||
|
||||
@ -24,9 +24,7 @@ from surya.layout.schema import LayoutResult
|
||||
from surya.logging import get_logger
|
||||
from surya.recognition.schema import (
|
||||
BlockOCRResult,
|
||||
OCRResult,
|
||||
PageOCRResult,
|
||||
TextLine,
|
||||
)
|
||||
from surya.settings import settings
|
||||
|
||||
@ -251,23 +249,3 @@ class RecognitionPredictor:
|
||||
)
|
||||
results.append(PageOCRResult(blocks=blocks, image_bbox=page_bbox))
|
||||
return results
|
||||
|
||||
def to_legacy_ocr_results(
    self, page_results: List[PageOCRResult]
) -> List[OCRResult]:
    """Compatibility shim: convert block-level pages into legacy ``OCRResult``s.

    Every block becomes exactly one ``TextLine`` that carries the block's
    polygon, its HTML as ``text``, its confidence, and an empty ``chars``
    list. The page's ``image_bbox`` is copied across unchanged.
    """
    return [
        OCRResult(
            text_lines=[
                TextLine(
                    polygon=block.polygon,
                    text=block.html,
                    chars=[],
                    confidence=block.confidence,
                )
                for block in page.blocks
            ],
            image_bbox=page.image_bbox,
        )
        for page in page_results
    ]
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import Any, List, Optional
|
||||
from typing import List
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
@ -12,30 +12,8 @@ class BlockOCRResult(PolygonBox):
|
||||
html: str = "" # block HTML (BLOCK_PROMPT output, "" if skipped)
|
||||
skipped: bool = False # True if label was in SKIP_OCR_LABELS
|
||||
error: bool = False
|
||||
char_confidences: Optional[List[float]] = None # phase 2
|
||||
raw_logprobs: Optional[List[Any]] = None # phase 2 debugging
|
||||
|
||||
|
||||
class PageOCRResult(BaseModel):
|
||||
blocks: List[BlockOCRResult]
|
||||
image_bbox: List[float]
|
||||
|
||||
|
||||
# ---- Back-compat shims for code paths that still expect text_lines ----
|
||||
# These are intentionally minimal; downstream consumers should migrate to
|
||||
# BlockOCRResult / PageOCRResult.
|
||||
|
||||
|
||||
class TextChar(BaseModel):
|
||||
text: str
|
||||
confidence: float = 0.0
|
||||
|
||||
|
||||
class TextLine(PolygonBox):
|
||||
text: str = ""
|
||||
chars: List[TextChar] = []
|
||||
|
||||
|
||||
class OCRResult(BaseModel):
|
||||
text_lines: List[TextLine]
|
||||
image_bbox: List[float]
|
||||
|
||||
@ -15,7 +15,6 @@ class Settings(BaseSettings):
|
||||
IMAGE_DPI: int = 96 # used for layout, recognition, and table rec
|
||||
IMAGE_DPI_HIGHRES: int = 192
|
||||
IN_STREAMLIT: bool = False
|
||||
FLATTEN_PDF: bool = True
|
||||
DISABLE_TQDM: bool = False
|
||||
S3_BASE_URL: str = "https://models.datalab.to"
|
||||
PARALLEL_DOWNLOAD_WORKERS: int = 10
|
||||
@ -23,7 +22,6 @@ class Settings(BaseSettings):
|
||||
LOGLEVEL: str = "INFO"
|
||||
|
||||
# Paths
|
||||
DATA_DIR: str = "data"
|
||||
RESULT_DIR: str = "results"
|
||||
BASE_DIR: str = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
FONT_DIR: str = os.path.join(BASE_DIR, "static", "fonts")
|
||||
@ -39,7 +37,7 @@ class Settings(BaseSettings):
|
||||
return "cpu"
|
||||
|
||||
# ---- Surya2 inference (VLM-backed: vllm | llamacpp) ---------------------
|
||||
SURYA_MODEL_CHECKPOINT: str = "datalab-to/surya-2.1.2"
|
||||
SURYA_MODEL_CHECKPOINT: str = "datalab-to/surya-2.1.2-mtp"
|
||||
SURYA_GGUF_REPO: str = "datalab-to/surya-ocr-2-gguf"
|
||||
SURYA_GGUF_MODEL_FILE: str = "surya-2.gguf"
|
||||
SURYA_GGUF_MMPROJ_FILE: str = "surya-2-mmproj.gguf"
|
||||
@ -90,7 +88,7 @@ class Settings(BaseSettings):
|
||||
VLLM_MAX_MODEL_LEN: int = 18000
|
||||
VLLM_GPU_MEMORY_UTILIZATION: float = 0.85
|
||||
VLLM_ENABLE_MTP: bool = True
|
||||
VLLM_MTP_TOKENS: int = 3
|
||||
VLLM_MTP_TOKENS: int = 2
|
||||
VLLM_EXTRA_ARGS: Optional[str] = None
|
||||
DOCKER_HF_CACHE_PATH: str = "~/.cache/huggingface"
|
||||
|
||||
@ -103,7 +101,6 @@ class Settings(BaseSettings):
|
||||
# ---- Detection (kept) ---------------------------------------------------
|
||||
DETECTOR_BATCH_SIZE: Optional[int] = None
|
||||
DETECTOR_MODEL_CHECKPOINT: str = "s3://text_detection/2025_05_07"
|
||||
DETECTOR_BENCH_DATASET_NAME: str = "vikp/doclaynet_bench"
|
||||
DETECTOR_IMAGE_CHUNK_HEIGHT: int = 1400
|
||||
DETECTOR_TEXT_THRESHOLD: float = 0.6
|
||||
DETECTOR_BLANK_THRESHOLD: float = 0.35
|
||||
@ -115,11 +112,6 @@ class Settings(BaseSettings):
|
||||
OCR_ERROR_MODEL_CHECKPOINT: str = "s3://ocr_error_detection/2025_02_18"
|
||||
OCR_ERROR_BATCH_SIZE: Optional[int] = None
|
||||
|
||||
# ---- Layout / Recognition / Table-rec batch sizes (CLI scripts) -------
|
||||
LAYOUT_BATCH_SIZE: Optional[int] = None
|
||||
RECOGNITION_BATCH_SIZE: Optional[int] = None
|
||||
TABLE_REC_BATCH_SIZE: Optional[int] = None
|
||||
|
||||
# ---- Debug / draw fonts (label rendering on annotated images) ----------
|
||||
RECOGNITION_RENDER_FONTS: Dict[str, str] = {
|
||||
"all": os.path.join(FONT_DIR, "GoNotoCurrent-Regular.ttf"),
|
||||
@ -131,9 +123,6 @@ class Settings(BaseSettings):
|
||||
"https://github.com/satbyy/go-noto-universal/releases/download/v7.0"
|
||||
)
|
||||
|
||||
# ---- Tesseract ---------------------------------------------------------
|
||||
TESSDATA_PREFIX: Optional[str] = None
|
||||
|
||||
@computed_field
|
||||
def MODEL_DTYPE(self) -> torch.dtype:
|
||||
if self.TORCH_DEVICE_MODEL == "cpu":
|
||||
|
||||
Loading…
Reference in New Issue
Block a user