Heavy cleanups

This commit is contained in:
Vik Paruchuri 2026-05-14 15:50:11 -04:00
parent 8e40566be0
commit 05874ae9ad
19 changed files with 79 additions and 1078 deletions

View File

@ -1 +0,0 @@
{"sessionId":"4129570b-108b-41b8-87b0-725c0b77135e","pid":5040,"procStart":"Tue May 5 13:42:45 2026","acquiredAt":1778008965041}

View File

@ -6,6 +6,6 @@ authors:
given-names: Vikas
- name: Datalab Team
date-released: 2025-05-13
url: https://github.com/VikParuchuri/surya
version: 0.14.0
repository-code: https://github.com/VikParuchuri/surya
url: https://github.com/datalab-to/surya
version: 1.0.0
repository-code: https://github.com/datalab-to/surya

View File

@ -1,11 +1,11 @@
[project]
name = "surya-ocr"
version = "2.0.0a0"
description = "OCR, layout, reading order, and table recognition via VLM (vllm + llama.cpp)."
version = "1.0.0a0"
description = "OCR, layout, reading order, and table recognition in 90+ languages."
readme = "README.md"
license = { text = "Apache-2.0" }
authors = [
{ name = "Vik Paruchuri", email = "vik.paruchuri@gmail.com" },
{ name = "Vik Paruchuri", email = "vik@datalab.to" },
]
requires-python = ">=3.10,<4"
keywords = ["ocr", "pdf", "text detection", "text recognition", "tables"]
@ -30,7 +30,7 @@ dependencies = [
]
[project.urls]
Repository = "https://github.com/VikParuchuri/surya"
Repository = "https://github.com/datalab-to/surya"
[project.scripts]
surya_detect = "surya.scripts.detect_text:detect_text_cli"

View File

@ -80,23 +80,6 @@ class PolygonBox(BaseModel):
corner[1] = max(min(corner[1], bounds[3]), bounds[1])
self.polygon = new_corners
def merge(self, other):
    """Replace this box's polygon with the axis-aligned union of both bboxes.

    The new polygon is written clockwise from the top-left corner:
    top-left, top-right, bottom-right, bottom-left.
    """
    mine = self.bbox
    theirs = other.bbox
    left = min(mine[0], theirs[0])
    top = min(mine[1], theirs[1])
    right = max(mine[2], theirs[2])
    bottom = max(mine[3], theirs[3])
    self.polygon = [[left, top], [right, top], [right, bottom], [left, bottom]]
def merge_left(self, other):
    """Pull this box's left edge out to the leftmost x of the two bboxes.

    Only the x coordinate of corners 0 and 3 is rewritten.
    """
    left = min(self.bbox[0], other.bbox[0])
    for corner_idx in (0, 3):
        self.polygon[corner_idx][0] = left
def merge_right(self, other):
    """Push this box's right edge out to the rightmost x of the two bboxes.

    Only the x coordinate of corners 1 and 2 is rewritten.
    """
    right = max(self.bbox[2], other.bbox[2])
    for corner_idx in (1, 2):
        self.polygon[corner_idx][0] = right
def expand(self, x_margin: float, y_margin: float):
new_polygon = []
x_margin = x_margin * self.width
@ -112,33 +95,6 @@ class PolygonBox(BaseModel):
new_polygon.append([int(poly[0] - x_margin), int(poly[1] + y_margin)])
self.polygon = new_polygon
def intersection_polygon(self, other) -> List[List[float]]:
    """Return the four corners of the overlap between this polygon and ``other``.

    Corners are combined index by index: corner 0 takes the max of both
    coordinates, corner 1 the min x / max y, corner 2 the min of both, and
    corner 3 the max x / min y. No check is made that the boxes overlap.
    """
    a = self.polygon
    b = other.polygon
    return [
        [max(a[0][0], b[0][0]), max(a[0][1], b[0][1])],
        [min(a[1][0], b[1][0]), max(a[1][1], b[1][1])],
        [min(a[2][0], b[2][0]), min(a[2][1], b[2][1])],
        [max(a[3][0], b[3][0]), min(a[3][1], b[3][1])],
    ]
def intersection_area(self, other, x_margin=0, y_margin=0):
x_overlap = self.x_overlap(other, x_margin)
y_overlap = self.y_overlap(other, y_margin)
@ -158,44 +114,9 @@ class PolygonBox(BaseModel):
- max(self.bbox[1] - y_margin, other.bbox[1] - y_margin),
)
def intersection_pct(self, other, x_margin=0, y_margin=0):
    """Return the fraction of this box's area that intersects ``other``.

    ``x_margin`` and ``y_margin`` are fractions in [0, 1]. A non-zero margin
    is converted to whole pixels using the smaller of the two boxes' width
    (or height) before being handed to ``intersection_area``. A box with
    zero area yields 0.
    """
    assert 0 <= x_margin <= 1
    assert 0 <= y_margin <= 1
    if self.area == 0:
        return 0

    x_px = int(min(self.width, other.width) * x_margin) if x_margin else x_margin
    y_px = int(min(self.height, other.height) * y_margin) if y_margin else y_margin

    return self.intersection_area(other, x_px, y_px) / self.area
def shift(self, x_shift: float | None = None, y_shift: float | None = None):
if x_shift is not None:
for corner in self.polygon:
corner[0] += x_shift
if y_shift is not None:
for corner in self.polygon:
corner[1] += y_shift
def clamp(self, bbox: List[float]):
    """Clip every polygon corner, in place, to ``bbox`` given as [x0, y0, x1, y1]."""
    x_lo, y_lo, x_hi, y_hi = bbox[0], bbox[1], bbox[2], bbox[3]
    for corner in self.polygon:
        corner[0] = max(min(corner[0], x_hi), x_lo)
        corner[1] = max(min(corner[1], y_hi), y_lo)
@property
def center(self):
    """Midpoint of the bounding box as ``[x, y]``."""
    box = self.bbox
    mid_x = (box[0] + box[2]) / 2
    mid_y = (box[1] + box[3]) / 2
    return [mid_x, mid_y]
def distance(self, other):
    """Euclidean distance between the centers of this box and ``other``."""
    mine = self.center
    theirs = other.center
    dx = mine[0] - theirs[0]
    dy = mine[1] - theirs[1]
    return (dx ** 2 + dy ** 2) ** 0.5
def __hash__(self):
    """Hash the box by its bounding-box coordinates."""
    bbox_key = tuple(self.bbox)
    return hash(bbox_key)

View File

@ -1,9 +1,4 @@
import copy
from typing import List
import torch
from functools import lru_cache
import torch.nn.functional as F
from surya.common.polygon import PolygonBox
@ -38,21 +33,6 @@ def clean_boxes(boxes: List[PolygonBox]) -> List[PolygonBox]:
return new_boxes
def rescale_bbox(bbox, processor_size, image_size):
    """Map a bbox from processor coordinates to image coordinates.

    Args:
        bbox: ``[x0, y0, x1, y1]`` in ``processor_size`` coordinates.
        processor_size: ``(width, height)`` the bbox is currently expressed in.
        image_size: ``(width, height)`` of the target image.

    Returns:
        A deep copy of ``bbox`` whose first four entries are scaled and
        truncated to ints; the input is left untouched.
    """
    page_width, page_height = processor_size
    img_width, img_height = image_size
    x_scale = img_width / page_width
    y_scale = img_height / page_height

    scaled = copy.deepcopy(bbox)
    for idx in range(4):
        # Even indices are x coordinates, odd indices are y coordinates.
        factor = x_scale if idx % 2 == 0 else y_scale
        scaled[idx] = int(scaled[idx] * factor)
    return scaled
def expand_bbox(bbox, expansion_factor=0.01):
expansion_low = 1 - expansion_factor
expansion_high = 1 + expansion_factor
@ -62,210 +42,3 @@ def expand_bbox(bbox, expansion_factor=0.01):
bbox[2] * expansion_high,
bbox[3] * expansion_high,
]
# Script name -> special token; every token is "<SCRIPT-" + the upper-cased
# script name + ">". Insertion order matches the original literal.
SCRIPT_TOKEN_MAPPING = {
    script: f"<SCRIPT-{script.upper()}>"
    for script in (
        "latin",
        "punctuation",
        "cyrillic",
        "arabic",
        "chinese",
        "japanese",
        "korean",
        "symbols",
        "greek",
        "armenian",
        "hebrew",
        "devanagari",
        "bengali",
        "gurmukhi",
        "gujarati",
        "oriya",
        "tamil",
        "telugu",
        "kannada",
        "malayalam",
        "sinhala",
        "thai",
        "lao",
        "myanmar",
        "georgian",
        "ethiopic",
        "khmer",
        "mongolian",
        "math",
    )
}
@lru_cache(maxsize=1)
def script_ranges():
    """Build the Unicode script tables used for script detection.

    Returns:
        A ``(script_categories, flat_ranges)`` pair. ``script_categories``
        maps a script name to a list of inclusive ``(start, end)`` code point
        ranges; ``flat_ranges`` maps the same names to the set of every code
        point covered. The result is cached after the first call.
    """
    script_categories = {
        # Latin-based scripts (English, French, German, ...)
        "latin": [
            (0x0041, 0x005A),  # A-Z
            (0x0061, 0x007A),  # a-z
            (0x0080, 0x00FF),  # Latin-1 Supplement
            (0x0100, 0x017F),  # Latin Extended-A
            (0x0180, 0x024F),  # Latin Extended-B
            (0x0250, 0x02AF),  # IPA Extensions
            (0x02B0, 0x02FF),  # Spacing Modifier Letters
            (0x0300, 0x036F),  # Combining Diacritical Marks
            (0x1E00, 0x1EFF),  # Latin Extended Additional
            (0x2C60, 0x2C7F),  # Latin Extended-C
            (0xA720, 0xA7FF),  # Latin Extended-D
        ],
        # Punctuation, digits and other script-neutral characters
        "punctuation": [
            (0x0020, 0x0020),  # Space
            (0x0021, 0x002F),  # Basic punctuation and symbols
            (0x0030, 0x0039),  # Digits 0-9
            (0x003A, 0x0040),  # More punctuation and symbols
            (0x005B, 0x0060),  # More punctuation and symbols
            (0x007B, 0x007F),  # More punctuation and symbols
            (0x2000, 0x206F),  # General Punctuation
        ],
        # Cyrillic (Russian, Ukrainian, ...)
        "cyrillic": [
            (0x0400, 0x04FF),  # Cyrillic
            (0x0500, 0x052F),  # Cyrillic Supplement
        ],
        "arabic": [
            (0x0600, 0x06FF),  # Arabic
            (0x0750, 0x077F),  # Arabic Supplement
            (0x08A0, 0x08FF),  # Arabic Extended-A
        ],
        "chinese": [
            (0x4E00, 0x9FFF),  # Common CJK Unified Ideographs
            (0x3400, 0x4DBF),  # CJK Extension A
            (0x20000, 0x2A6DF),  # CJK Extension B
        ],
        # Japanese-only blocks (shared CJK ideographs live under "chinese")
        "japanese": [
            (0x3040, 0x30FF),  # Hiragana and Katakana
        ],
        "korean": [
            (0x1100, 0x11FF),  # Hangul Jamo
            (0x3130, 0x318F),  # Hangul Compatibility Jamo
            (0xAC00, 0xD7AF),  # Hangul Syllables
        ],
        # Mathematical and technical symbols
        "symbols": [
            (0x2070, 0x209F),  # Superscripts and Subscripts
            (0x20A0, 0x20CF),  # Currency Symbols
            (0x2100, 0x214F),  # Letterlike Symbols
            (0x2150, 0x218F),  # Number Forms
            (0x2190, 0x21FF),  # Arrows
            (0x2200, 0x22FF),  # Mathematical Operators
            (0x2300, 0x23FF),  # Miscellaneous Technical
            (0x2500, 0x257F),  # Box Drawing
            (0x2580, 0x259F),  # Block Elements
            (0x25A0, 0x25FF),  # Geometric Shapes
            (0x2600, 0x26FF),  # Miscellaneous Symbols
            (0x2700, 0x27BF),  # Dingbats
            (0x27C0, 0x27EF),  # Miscellaneous Mathematical Symbols-A
            (0x2980, 0x29FF),  # Miscellaneous Mathematical Symbols-B
            (0x2A00, 0x2AFF),  # Supplemental Mathematical Operators
            (0x1D400, 0x1D7FF),  # Mathematical Alphanumeric Symbols
        ],
        # Scripts covered by a single Unicode block
        "greek": [(0x0370, 0x03FF)],  # Greek and Coptic
        "armenian": [(0x0530, 0x058F)],
        "hebrew": [(0x0590, 0x05FF)],
        "devanagari": [(0x0900, 0x097F)],  # Hindi, Sanskrit
        "bengali": [(0x0980, 0x09FF)],
        "gurmukhi": [(0x0A00, 0x0A7F)],  # Punjabi
        "gujarati": [(0x0A80, 0x0AFF)],
        "oriya": [(0x0B00, 0x0B7F)],
        "tamil": [(0x0B80, 0x0BFF)],
        "telugu": [(0x0C00, 0x0C7F)],
        "kannada": [(0x0C80, 0x0CFF)],
        "malayalam": [(0x0D00, 0x0D7F)],
        "sinhala": [(0x0D80, 0x0DFF)],
        "thai": [(0x0E00, 0x0E7F)],
        "lao": [(0x0E80, 0x0EFF)],
        "myanmar": [(0x1000, 0x109F)],
        "georgian": [(0x10A0, 0x10FF)],
        "ethiopic": [(0x1200, 0x137F)],
        "khmer": [(0x1780, 0x17FF)],
        "mongolian": [(0x1800, 0x18AF)],
    }

    # Expand each inclusive (start, end) pair into a per-script set of code
    # points so membership can be tested directly.
    flat_ranges = {
        category: {
            codepoint
            for start, end in ranges
            for codepoint in range(start, end + 1)
        }
        for category, ranges in script_categories.items()
    }
    return script_categories, flat_ranges
def get_top_scripts(text: str, max_scripts: int = 5):
    """Rank the scripts present in ``text`` by character count.

    Each character is attributed to the first script (in table order) whose
    code point set contains it; characters matching no script are ignored.
    Scripts with zero hits are dropped, ties keep table order, and "math" is
    put first whenever the text contains the substring "<math".

    Returns:
        At most ``max_scripts`` script names, most frequent first.
    """
    script_categories, flat_ranges = script_ranges()
    tallies = dict.fromkeys(script_categories.keys(), 0)

    for codepoint in map(ord, text):
        owner = next(
            (name for name, members in flat_ranges.items() if codepoint in members),
            None,
        )
        if owner is not None:
            tallies[owner] += 1

    # Sorting on the negated count is stable, so equal counts keep table order.
    ranked = [
        name
        for name, hits in sorted(tallies.items(), key=lambda pair: -pair[1])
        if hits > 0
    ]

    if "<math" in text:
        ranked.insert(0, "math")
    return ranked[:max_scripts]
def is_flash_attn_2_supported(device: str | torch.device) -> bool:
if not torch.cuda.is_available():
return False
if "cuda" not in str(device):
return False
# Check CUDA version >= 12.0
cuda_version_str = torch.version.cuda
if cuda_version_str is None:
return False
cuda_version = tuple(map(int, cuda_version_str.split(".")))
if cuda_version < (12, 0):
return False
# Check GPU compute capability (Ampere, Ada, Hopper GPUs)
major, minor = torch.cuda.get_device_capability()
compute_capability = major + minor / 10
if compute_capability < 8.0:
return False
return True
def pad_to_batch_size_repeat(tensor: torch.Tensor, batch_size: int):
    """Pad ``tensor`` along dim 0 up to ``batch_size`` by repeating its last row.

    Works for tensors of any rank >= 1 (the original hard-coded rank 3).

    Args:
        tensor: Input whose first dimension is the batch dimension.
        batch_size: Desired size of the first dimension.

    Returns:
        ``tensor`` itself when it already has at least ``batch_size`` rows or
        has no rows to repeat; otherwise a new tensor with ``batch_size`` rows.
    """
    current_batch_size = tensor.shape[0]
    pad_size = batch_size - current_batch_size
    # Nothing to add, or no last row to copy from.
    if pad_size <= 0 or current_batch_size == 0:
        return tensor

    # Repeat only along the batch dimension, whatever the tensor's rank.
    repeats = (pad_size,) + (1,) * (tensor.dim() - 1)
    filler = tensor[-1:].repeat(*repeats)

    return torch.cat([tensor, filler], dim=0)
def pad_to_batch_size(tensor: torch.Tensor, batch_size: int):
    """Zero-pad ``tensor`` at the end of dim 0 until it has ``batch_size`` rows.

    Tensors that already have at least ``batch_size`` rows are returned as is.
    """
    missing_rows = batch_size - tensor.shape[0]
    if missing_rows <= 0:
        return tensor

    # F.pad takes (before, after) pairs starting from the LAST dimension, so
    # every trailing dimension gets (0, 0) and the batch dimension comes last.
    pad_spec = (0, 0) * (tensor.dim() - 1) + (0, missing_rows)
    return F.pad(tensor, pad_spec, mode="constant", value=0)

View File

@ -1,24 +1,4 @@
import re
from io import BytesIO
from typing import List, Tuple
from PIL import Image, ImageDraw, ImageFont
from surya.debug.fonts import get_font_path
from surya.debug.render_html import render_text_as_html
try:
from playwright.sync_api import sync_playwright
has_playwright = True
except ImportError:
has_playwright = False
def strip_html_tags(html_text):
    """Remove anything that looks like an HTML tag from ``html_text``.

    A tag is ``<`` followed by a word character or ``/`` and then everything
    up to the next ``>``; a bare ``<`` followed by anything else is kept.
    """
    return re.sub(r"<[\w/][^>]*>", "", html_text)
from PIL import Image, ImageDraw
def get_text_size(text, font):
@ -26,75 +6,3 @@ def get_text_size(text, font):
draw = ImageDraw.Draw(im)
_, _, width, height = draw.textbbox((0, 0), text=text, font=font)
return width, height
def render_text(draw, text, s_bbox, bbox_width, bbox_height, font_path, box_font_size):
    """Draw ``text`` in black inside a box, shrinking the font until it fits.

    The font size starts at ``box_font_size`` and is reduced one point at a
    time while the rendered text is wider or taller than the box, stopping
    at size 6 even if the text still overflows.

    Args:
        draw: Drawing object whose ``text`` method is called.
        text: String to render.
        s_bbox: Box coordinates; index 0 is used as x and index 1 as the top y.
        bbox_width: Available width.
        bbox_height: Available height.
        font_path: Path handed to ``ImageFont.truetype``.
        box_font_size: Initial font size to try.
    """
    font = ImageFont.truetype(font_path, box_font_size)
    text_width, text_height = get_text_size(text, font)
    while (text_width > bbox_width or text_height > bbox_height) and box_font_size > 6:
        box_font_size -= 1
        font = ImageFont.truetype(font_path, box_font_size)
        text_width, text_height = get_text_size(text, font)

    # text_width/text_height already describe the final font, so no extra
    # measurement is needed. The text is left-aligned and vertically centred.
    x = s_bbox[0]
    y = s_bbox[1] + (bbox_height - text_height) / 2

    draw.text((x, y), text, fill="black", font=font)
def draw_text_with_playwright(
    bboxes, texts: List[str], image_size: Tuple[int, int]
) -> Image.Image:
    """Render ``texts`` at ``bboxes`` as HTML and screenshot it with headless Chromium.

    Args:
        bboxes: Boxes passed through to ``render_text_as_html``.
        texts: Text for each box.
        image_size: ``(width, height)``; ``render_text_as_html`` may return an
            adjusted size, which is then used as the browser viewport.

    Returns:
        A PIL image decoded from the screenshot of the page ``<body>``.

    Raises:
        ImportError: If playwright is not installed.
    """
    # Fail fast before doing any rendering work.
    if not has_playwright:
        raise ImportError(
            "Playwright is not installed. Please install it using `pip install playwright`"
        )

    html_content, image_size = render_text_as_html(bboxes, texts, image_size)

    with sync_playwright() as p:
        browser = p.chromium.launch(headless=True)
        try:
            page = browser.new_page(
                viewport={"width": image_size[0], "height": image_size[1]}
            )
            page.set_content(html_content)
            # Fixed one-second wait before the screenshot (kept from the original).
            page.wait_for_timeout(1000)
            body = page.query_selector("body")
            image = body.screenshot()
        finally:
            # Close the browser even if rendering or the screenshot fails.
            browser.close()

    return Image.open(BytesIO(image))
def draw_text_on_image(
    bboxes,
    texts,
    image_size: Tuple[int, int],
    font_path=None,
    max_font_size=60,
    res_upscale=2,
) -> Image.Image:
    """Render ``texts`` into their ``bboxes`` on a blank white page.

    When playwright is importable the work is delegated to
    ``draw_text_with_playwright``. Otherwise HTML tags are stripped and the
    text is drawn with PIL on a canvas ``res_upscale`` times larger than
    ``image_size``.

    Args:
        bboxes: ``[x0, y0, x1, y1]`` boxes in ``image_size`` coordinates.
        texts: Text for each box.
        image_size: ``(width, height)`` of the page.
        font_path: Font file; defaults to ``get_font_path()``.
        max_font_size: Upper bound for the starting font size.
        res_upscale: Factor applied to the canvas and to every box.

    Returns:
        The rendered image.
    """
    if has_playwright:
        return draw_text_with_playwright(bboxes, texts, image_size)

    plain_texts = [strip_html_tags(text) for text in texts]
    if font_path is None:
        font_path = get_font_path()

    canvas_size = (image_size[0] * res_upscale, image_size[1] * res_upscale)
    canvas = Image.new("RGB", canvas_size, color="white")
    pen = ImageDraw.Draw(canvas)

    for bbox, text in zip(bboxes, plain_texts):
        scaled_box = [int(coord * res_upscale) for coord in bbox]
        box_width = scaled_box[2] - scaled_box[0]
        box_height = scaled_box[3] - scaled_box[1]

        # Start at 75% of the box height, bounded to [6, max_font_size];
        # render_text shrinks it further if the text does not fit.
        start_font_size = max(6, min(int(0.75 * box_height), max_font_size))
        render_text(
            pen, text, scaled_box, box_width, box_height, font_path, start_font_size
        )

    return canvas

View File

@ -19,7 +19,11 @@ from typing import Any, Dict, List, Optional, Union
import numpy as np
from transformers.image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from transformers.image_processing_utils import (
BaseImageProcessor,
BatchFeature,
get_size_dict,
)
from transformers.image_transforms import to_channel_dimension_format
from transformers.image_utils import (
IMAGENET_DEFAULT_MEAN,
@ -34,7 +38,6 @@ from transformers.utils import TensorType
import PIL.Image
import torch
from surya.common.s3 import S3DownloaderMixin
@ -107,7 +110,9 @@ class SegformerImageProcessor(S3DownloaderMixin, BaseImageProcessor):
self.do_rescale = do_rescale
self.rescale_factor = rescale_factor
self.do_normalize = do_normalize
self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN
self.image_mean = (
image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN
)
self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD
self.do_reduce_labels = do_reduce_labels
self._valid_processor_keys = [
@ -152,12 +157,18 @@ class SegformerImageProcessor(S3DownloaderMixin, BaseImageProcessor):
image_std: Optional[Union[float, List[float]]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
):
if do_rescale:
image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
image = self.rescale(
image=image, scale=rescale_factor, input_data_format=input_data_format
)
if do_normalize:
image = self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
image = self.normalize(
image=image,
mean=image_mean,
std=image_std,
input_data_format=input_data_format,
)
return image
@ -193,7 +204,9 @@ class SegformerImageProcessor(S3DownloaderMixin, BaseImageProcessor):
input_data_format=input_data_format,
)
if data_format is not None:
image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
image = to_channel_dimension_format(
image, data_format, input_channel_dim=input_data_format
)
return image
def __call__(self, images, segmentation_maps=None, **kwargs):
@ -276,7 +289,9 @@ class SegformerImageProcessor(S3DownloaderMixin, BaseImageProcessor):
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
resample = resample if resample is not None else self.resample
size = size if size is not None else self.size
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
rescale_factor = (
rescale_factor if rescale_factor is not None else self.rescale_factor
)
image_mean = image_mean if image_mean is not None else self.image_mean
image_std = image_std if image_std is not None else self.image_std
@ -299,4 +314,4 @@ class SegformerImageProcessor(S3DownloaderMixin, BaseImageProcessor):
]
data = {"pixel_values": images}
return BatchFeature(data=data, tensor_type=return_tensors)
return BatchFeature(data=data, tensor_type=return_tensors)

View File

@ -3,7 +3,7 @@
One process owns one SuryaInferenceManager. The manager wraps a single backend
(vllm | llamacpp) which speaks OpenAI-compatible chat completions.
Predictors take the manager via explicit injection (see surya/models.py).
Predictors take the manager via explicit injection at construction time.
"""
from __future__ import annotations

View File

@ -1,4 +1,4 @@
"""vllm backend: spawns the vllm/vllm-openai docker image with MTP=3."""
"""vllm backend: spawns the vllm/vllm-openai docker image with MTP=2."""
from __future__ import annotations
@ -80,7 +80,6 @@ class VllmBackend(Backend):
def __init__(self):
self.handle: Optional[ServerHandle] = None
self._client: Optional[OpenAI] = None
self._container_name: Optional[str] = None
def start(self) -> ServerHandle:
if self.handle is not None:
@ -113,7 +112,6 @@ class VllmBackend(Backend):
def spawn_fn(port: int) -> SpawnHandle:
container_name = f"surya-vllm-{port}"
self._container_name = container_name
hf_cache = os.path.expanduser(settings.DOCKER_HF_CACHE_PATH)
cmd = [
docker,

View File

@ -1,401 +0,0 @@
"""Pipelined layout + block-OCR orchestrator.
Standard flow runs all layouts (batch) then all blocks (batch). The latter
can't start until every page has finished layout, leaving server slots
idle while layouts trickle in.
This orchestrator overlaps the two phases:
* Submit every layout request into one ThreadPoolExecutor.
* As each layout future completes, immediately submit that page's block
OCR requests, **reverse-sorted by count** (LPT scheduling) so big
blocks start their decode earliest.
* If a layout fails to parse, fall back to HIGH_ACCURACY_BBOX_PROMPT on
the full OCR-DPI page; its HTML output yields layout + content in one
shot (no per-block phase needed for fallback pages).
* Drain remaining block futures + fallback futures.
Returns the same shape as the serial pair: (List[LayoutResult], List[PageOCRResult]).
"""
from __future__ import annotations
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import List, Optional, Tuple
from PIL import Image
from surya.inference import SuryaInferenceManager
from surya.inference.backends.openai_client import _generate_one
from surya.inference.parsers import (
clean_block_html,
denorm_bbox,
parse_full_page_html,
parse_layout,
)
from surya.inference.prompts import LAYOUT_JSON_SCHEMA
from surya.inference.schema import (
BatchInputItem,
PROMPT_TYPE_BLOCK,
PROMPT_TYPE_HIGH_ACCURACY_BBOX,
PROMPT_TYPE_LAYOUT,
)
from surya.inference.util import image_token_budget
from surya.layout.label import LAYOUT_PRED_RELABEL
from surya.layout.schema import LayoutBox, LayoutResult
from surya.logging import get_logger
from surya.recognition import SKIP_CANON_LABELS, _crop_block
from surya.recognition.schema import BlockOCRResult, PageOCRResult
from surya.settings import settings
logger = get_logger()
def _make_layout_result(
    raw: str,
    error: bool,
    mean_token_prob: Optional[float],
    target_size: Tuple[int, int],
) -> LayoutResult:
    """Turn a raw layout generation into a ``LayoutResult`` in pixel coordinates.

    Args:
        raw: Raw model output for the layout prompt.
        error: Whether the generation itself reported an error.
        mean_token_prob: Confidence applied to every box; 1.0 when ``None``.
        target_size: ``(width, height)`` the boxes are de-normalized to.

    Returns:
        A result with ``error=True`` and no boxes when the generation failed,
        was empty, or could not be parsed; otherwise one ``LayoutBox`` per
        parsed entry, with ``position`` set to the entry's index.
    """
    width, height = target_size
    page_bbox = [0, 0, float(width), float(height)]

    def _failed() -> LayoutResult:
        return LayoutResult(bboxes=[], image_bbox=page_bbox, raw=raw, error=True)

    if error or not raw:
        return _failed()

    try:
        parsed = parse_layout(raw)
    except Exception as e:
        logger.warning(f"Layout parse failed: {e}; raw[:200]={raw[:200]!r}")
        return _failed()

    confidence = 1.0 if mean_token_prob is None else mean_token_prob
    boxes: List[LayoutBox] = [
        LayoutBox(
            polygon=list(
                denorm_bbox(blk.bbox, width, height, scale=settings.BBOX_SCALE)
            ),
            # Fall back to the raw label when no relabel entry exists.
            label=LAYOUT_PRED_RELABEL.get(blk.label, blk.label),
            raw_label=blk.label,
            position=idx,
            count=blk.count,
            confidence=confidence,
        )
        for idx, blk in enumerate(parsed)
    ]
    return LayoutResult(bboxes=boxes, image_bbox=page_bbox, raw=raw, error=False)
def _make_block_result(
    page_box: LayoutBox,
    raw: str,
    error: bool,
    mean_token_prob: Optional[float],
    skipped: bool,
) -> BlockOCRResult:
    """Build the ``BlockOCRResult`` for one layout box.

    Geometry, labels and reading order are copied from ``page_box``. The
    remaining fields depend on the outcome:

    * skipped: empty HTML, ``skipped=True``, confidence 1.0.
    * error: empty HTML, ``error=True``, confidence 0.0.
    * otherwise: HTML from ``clean_block_html(raw)`` and confidence
      ``mean_token_prob`` (1.0 when ``None``).
    """
    from_layout = dict(
        polygon=page_box.polygon,
        label=page_box.label,
        raw_label=page_box.raw_label,
        reading_order=page_box.position,
    )

    if skipped:
        return BlockOCRResult(**from_layout, html="", skipped=True, confidence=1.0)

    if error:
        return BlockOCRResult(
            **from_layout, html="", skipped=False, error=True, confidence=0.0
        )

    html = clean_block_html(raw)
    confidence = 1.0 if mean_token_prob is None else mean_token_prob
    return BlockOCRResult(
        **from_layout, html=html, skipped=False, error=False, confidence=confidence
    )
def _full_page_to_results(
    raw: str,
    mean_token_prob: Optional[float],
    target_size: Tuple[int, int],
) -> Tuple[LayoutResult, PageOCRResult]:
    """Convert HIGH_ACCURACY_BBOX_PROMPT output into layout and OCR results.

    Every item returned by ``parse_full_page_html`` yields one ``LayoutBox``
    (``count=0``) and one ``BlockOCRResult`` that share the same polygon,
    label and index. Blocks whose canonical label is in ``SKIP_CANON_LABELS``
    are marked skipped and carry empty HTML.

    Args:
        raw: Full-page HTML produced by the model.
        mean_token_prob: Confidence applied to every box; 1.0 when ``None``.
        target_size: ``(width, height)`` the boxes are de-normalized to.

    Returns:
        ``(LayoutResult, PageOCRResult)`` for the page.
    """
    width, height = target_size
    page_bbox = [0, 0, float(width), float(height)]
    parsed_items = parse_full_page_html(raw)
    confidence = 1.0 if mean_token_prob is None else mean_token_prob

    layout_boxes: List[LayoutBox] = []
    ocr_blocks: List[BlockOCRResult] = []
    for order, item in enumerate(parsed_items):
        px = denorm_bbox(item.bbox, width, height, scale=settings.BBOX_SCALE)
        canon_label = LAYOUT_PRED_RELABEL.get(item.label, item.label)
        # Rectangle corners clockwise from the top-left.
        polygon = [
            [px[0], px[1]],
            [px[2], px[1]],
            [px[2], px[3]],
            [px[0], px[3]],
        ]
        layout_boxes.append(
            LayoutBox(
                polygon=polygon,
                label=canon_label,
                raw_label=item.label,
                position=order,
                count=0,
                confidence=confidence,
            )
        )

        is_skipped = canon_label in SKIP_CANON_LABELS
        ocr_blocks.append(
            BlockOCRResult(
                polygon=polygon,
                label=canon_label,
                raw_label=item.label,
                reading_order=order,
                html="" if is_skipped else item.html,
                skipped=is_skipped,
                error=False,
                confidence=confidence,
            )
        )

    layout = LayoutResult(
        bboxes=layout_boxes, image_bbox=page_bbox, raw=raw, error=False
    )
    page = PageOCRResult(blocks=ocr_blocks, image_bbox=page_bbox)
    return layout, page
def layout_then_blocks(
manager: SuryaInferenceManager,
ocr_images: List[Image.Image],
layout_images: Optional[List[Image.Image]] = None,
max_workers: Optional[int] = None,
) -> Tuple[List[LayoutResult], List[PageOCRResult]]:
"""Run layout for all pages with block OCR pipelined per page.

All requests go through one ThreadPoolExecutor. Layout requests are submitted
first; as each one completes, that page's block OCR requests are submitted,
largest ``count`` first. Pages whose layout fails are re-requested with the
HIGH_ACCURACY_BBOX prompt on the full OCR image when
settings.SURYA_LAYOUT_FALLBACK_FULL_PAGE is True.

Args:
manager: Inference manager; it is started here and its backend's client
and model name are used for every request.
ocr_images: One image per page, used for block crops and fallback requests.
layout_images: Optional low-DPI renders for layout. If None, ocr_images are
used for both. Bbox coords are returned in ocr_images coord space either way.
max_workers: Thread pool size; falls back to settings.SURYA_INFERENCE_PARALLEL
when None or 0.

Returns:
``(layout_results, page_results)``, each with one entry per page in input
order. Both are empty lists when ``ocr_images`` is empty.
"""
if not ocr_images:
return [], []
if layout_images is None:
layout_images = ocr_images
manager.start()
backend = manager.backend
# NOTE(review): reaches into the backend's private ``_client`` attribute.
client = backend._client
model_name = backend.handle.model_name
timeout = settings.SURYA_INFERENCE_TIMEOUT_SECONDS
request_logprobs = settings.SURYA_INFERENCE_LOGPROBS
n_workers = max_workers or settings.SURYA_INFERENCE_PARALLEL
# Keyword arguments shared by every _generate_one call (layout, block and
# fallback alike); per-request limits are set on each BatchInputItem.
common_kwargs = dict(
client=client,
model_name=model_name,
max_tokens_default=settings.SURYA_MAX_TOKENS_LAYOUT,
temperature=0.0,
top_p=0.1,
timeout=timeout,
request_logprobs_default=request_logprobs,
)
n_pages = len(ocr_images)
# Slots are filled by page index as results arrive; None means "not set yet".
layout_results: List[Optional[LayoutResult]] = [None] * n_pages
page_results: List[Optional[PageOCRResult]] = [None] * n_pages
block_results: dict = {} # (page_idx, block_idx) -> (raw, error, mean_p, skipped)
fallback_enabled = settings.SURYA_LAYOUT_FALLBACK_FULL_PAGE
with ThreadPoolExecutor(max_workers=n_workers) as executor:
# ---- Phase 1: submit all layout requests ----
guided = LAYOUT_JSON_SCHEMA if settings.SURYA_GUIDED_LAYOUT else None
layout_futures = {}
# NOTE(review): page_idx comes from layout_images but is also used to index
# ocr_images below - assumes both lists have the same length; confirm with callers.
for page_idx, lo_img in enumerate(layout_images):
item = BatchInputItem(
image=lo_img,
prompt_type=PROMPT_TYPE_LAYOUT,
max_tokens=settings.SURYA_MAX_TOKENS_LAYOUT,
guided_json=guided,
metadata={"page_idx": page_idx},
)
fut = executor.submit(_generate_one, item, **common_kwargs)
layout_futures[fut] = page_idx
# ---- Phase 2: as layouts land, submit blocks (LPT sort) or fallback ----
block_futures = {}
fallback_futures = {}
def _submit_fallback(page_idx: int):
"""Schedule a HIGH_ACCURACY_BBOX_PROMPT request for this page."""
item = BatchInputItem(
image=ocr_images[page_idx],
prompt_type=PROMPT_TYPE_HIGH_ACCURACY_BBOX,
max_tokens=settings.SURYA_MAX_TOKENS_FULL_PAGE,
metadata={"page_idx": page_idx},
)
fb_fut = executor.submit(_generate_one, item, **common_kwargs)
fallback_futures[fb_fut] = page_idx
for fut in as_completed(layout_futures):
page_idx = layout_futures[fut]
try:
gen = fut.result()
except Exception as e:
# The request itself raised: fall back, or record an empty error result.
logger.warning(f"Layout request failed for page {page_idx}: {e}")
if fallback_enabled:
_submit_fallback(page_idx)
else:
w, h = ocr_images[page_idx].size
layout_results[page_idx] = LayoutResult(
bboxes=[],
image_bbox=[0, 0, float(w), float(h)],
raw=None,
error=True,
)
continue
# Boxes are de-normalized against the OCR image size, not the layout image.
target_size = ocr_images[page_idx].size
layout_result = _make_layout_result(
raw=gen.raw,
error=gen.error,
mean_token_prob=gen.mean_token_prob,
target_size=target_size,
)
if layout_result.error:
if fallback_enabled:
logger.info(
f"Layout failed for page {page_idx}, falling back to full-page"
)
_submit_fallback(page_idx)
continue
layout_results[page_idx] = layout_result
continue
layout_results[page_idx] = layout_result
# Per-page LPT: reverse-sort blocks by count
blocks_sorted = sorted(
enumerate(layout_result.bboxes),
key=lambda kv: -kv[1].count,
)
for block_idx, blk in blocks_sorted:
# Skipped labels get a placeholder entry and no request is sent.
if blk.label in SKIP_CANON_LABELS:
block_results[(page_idx, block_idx)] = ("", False, None, True)
continue
crop = _crop_block(ocr_images[page_idx], blk.polygon)
max_tokens = image_token_budget(
blk.count, ceiling=settings.SURYA_MAX_TOKENS_BLOCK_CEILING
)
block_item = BatchInputItem(
image=crop,
prompt_type=PROMPT_TYPE_BLOCK,
max_tokens=max_tokens,
metadata={"page_idx": page_idx, "block_idx": block_idx},
)
bfut = executor.submit(_generate_one, block_item, **common_kwargs)
block_futures[bfut] = (page_idx, block_idx)
# ---- Phase 3: drain block futures ----
for fut in as_completed(block_futures):
key = block_futures[fut]
try:
gen = fut.result()
except Exception as e:
logger.warning(f"Block request failed for {key}: {e}")
block_results[key] = ("", True, None, False)
continue
block_results[key] = (gen.raw, gen.error, gen.mean_token_prob, False)
# ---- Phase 4: drain fallback futures ----
# Every failure mode below stores an empty error LayoutResult together with
# an empty PageOCRResult for the page.
for fut in as_completed(fallback_futures):
page_idx = fallback_futures[fut]
target_size = ocr_images[page_idx].size
page_bbox = [0, 0, float(target_size[0]), float(target_size[1])]
try:
gen = fut.result()
except Exception as e:
logger.warning(f"Fallback request failed for page {page_idx}: {e}")
layout_results[page_idx] = LayoutResult(
bboxes=[],
image_bbox=page_bbox,
raw=None,
error=True,
)
page_results[page_idx] = PageOCRResult(blocks=[], image_bbox=page_bbox)
continue
if gen.error or not gen.raw:
layout_results[page_idx] = LayoutResult(
bboxes=[],
image_bbox=page_bbox,
raw=gen.raw,
error=True,
)
page_results[page_idx] = PageOCRResult(blocks=[], image_bbox=page_bbox)
continue
try:
layout, page = _full_page_to_results(
gen.raw, gen.mean_token_prob, target_size
)
layout_results[page_idx] = layout
page_results[page_idx] = page
except Exception as e:
logger.warning(f"Fallback parse failed for page {page_idx}: {e}")
layout_results[page_idx] = LayoutResult(
bboxes=[],
image_bbox=page_bbox,
raw=gen.raw,
error=True,
)
page_results[page_idx] = PageOCRResult(blocks=[], image_bbox=page_bbox)
# ---- Assemble PageOCRResult for non-fallback pages ----
for page_idx, layout_result in enumerate(layout_results):
if page_results[page_idx] is not None:
# Already populated by fallback path
continue
if layout_result is None:
# No layout was ever recorded for this page: emit an empty page.
w, h = ocr_images[page_idx].size
page_results[page_idx] = PageOCRResult(
blocks=[], image_bbox=[0, 0, float(w), float(h)]
)
continue
blocks: List[BlockOCRResult] = []
# Iterate in layout order (not LPT order) so blocks keep their layout sequence.
for block_idx, blk in enumerate(layout_result.bboxes):
entry = block_results.get((page_idx, block_idx))
if entry is None:
# No recorded outcome for this block: report it as an error.
blocks.append(
BlockOCRResult(
polygon=blk.polygon,
label=blk.label,
raw_label=blk.raw_label,
reading_order=blk.position,
html="",
error=True,
confidence=0.0,
)
)
continue
raw, error, mean_p, skipped = entry
blocks.append(_make_block_result(blk, raw, error, mean_p, skipped))
w, h = ocr_images[page_idx].size
page_results[page_idx] = PageOCRResult(
blocks=blocks, image_bbox=[0, 0, float(w), float(h)]
)
return list(layout_results), list(page_results) # type: ignore[arg-type]

View File

@ -67,40 +67,6 @@ ALLOWED_ATTRIBUTES = [
# Block labels we don't run OCR on.
SKIP_OCR_LABELS = {"Figure", "Image", "Diagram", "Blank-Page"}
# Bulleted list of the layout labels, one "- <label>" entry per line and no
# trailing newline.
LAYOUT_LABELS = "\n".join(
    f"- {label}"
    for label in (
        "Caption",
        "Footnote",
        "Equation-Block",
        "List-Group",
        "Page-Header",
        "Page-Footer",
        "Image",
        "Section-Header",
        "Table",
        "Text",
        "Complex-Block",
        "Code-Block",
        "Form",
        "Table-Of-Contents",
        "Figure",
        "Chemical-Block",
        "Diagram",
        "Bibliography",
        "Blank-Page",
    )
)
# Formatting rules text with ALLOWED_TAGS and ALLOWED_ATTRIBUTES interpolated
# at import time. Where this string is consumed is not visible in this chunk.
GUIDELINES = f"""Only use these tags {ALLOWED_TAGS}, and these attributes {ALLOWED_ATTRIBUTES}.
Guidelines:
* Inline math: Surround math with <math>...</math> tags. Math expressions should be rendered in KaTeX-compatible LaTeX. Use display for block math.
* Tables: Use colspan and rowspan attributes to match table structure.
* Formatting: Maintain consistent formatting with the image, including spacing, indentation, subscripts/superscripts, and special characters.
* Images: Include a description of any images in the alt attribute of an <img> tag. Do not fill out the src property. Describe in detail inside the div tag. Also convert charts to high fidelity data, and convert diagrams to mermaid.
* Forms: Mark checkboxes and radio buttons properly.
* Text: join lines together properly into paragraphs using <p>...</p> tags. Use <br> tags for line breaks within paragraphs, but only when absolutely necessary to maintain meaning.
* Chemistry: Use <chem>...</chem> tags for chemical formulas with reactive SMILES.
* Lists: Preserve indents and proper list markers.
* Use the simplest possible HTML structure that accurately represents the content of the block.
* Make sure the text is accurate and easy for a human to read and interpret. Reading order should be correct and natural."""
LAYOUT_PROMPT = (
"Output the layout of this image as JSON. Each entry is a dict with "
'"label", "bbox", and "count" fields. Bbox is x0 y0 x1 y1, normalized 0-1000.'

View File

@ -7,7 +7,6 @@ from surya.settings import settings
import os
import filetype
from PIL import Image
import json
logger = get_logger()
@ -76,9 +75,3 @@ def load_from_folder(
logger.warning(f"Could not load image {path}")
continue
return images, names
def load_lang_file(lang_path, names):
    """Load the per-file language lists for ``names`` from a JSON mapping.

    Args:
        lang_path: Path to a JSON file mapping a name to its language list.
        names: Names to look up, in the order the results are wanted.

    Returns:
        One copied language list per entry in ``names``.

    Raises:
        KeyError: If a name is missing from the file.
    """
    # JSON is UTF-8 by specification; do not depend on the locale default.
    with open(lang_path, "r", encoding="utf-8") as f:
        lang_dict = json.load(f)
    # Copy each entry so callers can mutate their list independently.
    return [lang_dict[name].copy() for name in names]

View File

@ -1,24 +1,9 @@
from typing import List
import cv2
import numpy as np
import pypdfium2
from PIL import Image
from surya.logging import get_logger
from surya.settings import settings
logger = get_logger()
def convert_if_not_rgb(images: List[Image.Image]) -> List[Image.Image]:
    """Return a new list in which every image is in RGB mode.

    Images already in ``"RGB"`` mode are passed through as the same object;
    all others are replaced by the result of ``image.convert("RGB")``.
    """
    return [
        image if image.mode == "RGB" else image.convert("RGB")
        for image in images
    ]
def open_pdf(pdf_filepath):
    """Open ``pdf_filepath`` with pypdfium2 and return the ``PdfDocument`` handle."""
    document = pypdfium2.PdfDocument(pdf_filepath)
    return document
@ -30,72 +15,3 @@ def get_page_images(doc, indices: List, dpi=settings.IMAGE_DPI):
]
images = [image.convert("RGB") for image in images]
return images
def slice_bboxes_from_image(image: np.ndarray, bboxes):
    """Crop one sub-image per ``(x0, y0, x1, y1)`` box out of ``image``.

    Each box is truncated to int32, negative coordinates are clamped to 0,
    zero/negative-extent boxes are widened to one pixel, and the far edges
    are clamped to the image size. A box that still yields an empty crop
    (e.g. it starts past the image edge) is logged and kept in the output,
    so the result always has one entry per input box.
    """
    crops = []
    for raw_bbox in bboxes:
        # Integer pixel indices; clipping at 0 prevents negative (wrap-around) slicing.
        box = np.clip(np.array(raw_bbox, dtype=np.int32), 0, None)

        # (start, end) index pairs: rows first, then columns.
        for start, end in ((1, 3), (0, 2)):
            if box[end] <= box[start]:
                box[end] = box[start] + 1

        # Keep the far edges inside the image.
        box[2] = min(box[2], image.shape[1])
        box[3] = min(box[3], image.shape[0])

        crop = image[box[1] : box[3], box[0] : box[2]].copy()
        if crop.size == 0:
            logger.warning(f"Warning: found an empty line with bbox {box}")
        crops.append(crop)
    return crops
def slice_polys_from_image(image: np.ndarray, polys):
    """Return one masked crop per polygon, in input order.

    Each polygon is handed to ``slice_and_pad_poly`` together with ``image``.
    """
    return [slice_and_pad_poly(image, polygon) for polygon in polys]
def slice_and_pad_poly(image_array: np.ndarray, coordinates):
    """Crop a polygon's bounding box from an image and whiten pixels outside it.

    Args:
        image_array: Image array indexed as ``[row, col]`` or
            ``[row, col, channel]``. Any channel count is accepted.
        coordinates: Sequence of polygon corners; element 0 of each corner is
            used as x (column) and element 1 as y (row). Values are used
            directly as slice indices, so they must be integers.

    Returns:
        A copy of the bounding-box crop. When the polygon has at least three
        corners and the crop is non-empty, every pixel outside the polygon is
        set to 255; otherwise the plain crop is returned. If OpenCV rejects
        the polygon, the problem is logged and the crop is returned as it is
        at that point.
    """
    points = [(corner[0], corner[1]) for corner in coordinates]
    xs = [x for x, _ in points]
    ys = [y for _, y in points]
    x_min, y_min, x_max, y_max = min(xs), min(ys), max(xs), max(ys)

    # Work on a copy so the caller's image is never modified.
    cropped_polygon = image_array[y_min:y_max, x_min:x_max].copy()
    height, width = cropped_polygon.shape[:2]

    # Nothing to mask for degenerate boxes, non-polygons, or empty crops.
    if (
        y_max <= y_min
        or x_max <= x_min
        or len(points) < 3
        or height == 0
        or width == 0
    ):
        return cropped_polygon

    # Express the polygon relative to the crop's top-left corner.
    local_points = [(x - x_min, y - y_min) for x, y in points]

    try:
        mask = np.zeros((height, width), dtype=np.uint8)
        cv2.fillPoly(mask, [np.int32(local_points)], 1)
        # A 2-D boolean mask selects whole pixels, so this works for
        # grayscale, RGB and RGBA crops alike (the previous 3-channel stacked
        # mask raised IndexError for anything that was not H x W x 3).
        cropped_polygon[mask == 0] = 255  # white pad
    except cv2.error as e:
        logger.warning(f"Warning: issue while processing polygon: {e}")

    return cropped_polygon

View File

@ -1,29 +1,6 @@
"""Surya2 layout labels emitted by the model + canonicalization to surya's
public label vocabulary."""
# Raw label strings the model emits in response to LAYOUT_PROMPT
# (see surya/inference/prompts.py). Stored as a tuple, so the vocabulary
# cannot be mutated at runtime.
LAYOUT_LABELS = (
"Caption",
"Footnote",
"Equation-Block",
"List-Group",
"Page-Header",
"Page-Footer",
"Image",
"Section-Header",
"Table",
"Text",
"Complex-Block",
"Code-Block",
"Form",
"Table-Of-Contents",
"Figure",
"Chemical-Block",
"Diagram",
"Bibliography",
"Blank-Page",
)
# Canonicalize raw model labels to public surya label names. Marker and other
# downstream consumers depend on these names.
LAYOUT_PRED_RELABEL = {

View File

@ -1,35 +0,0 @@
from typing import Dict, Optional
import torch
from surya.detection import DetectionPredictor
from surya.inference import SuryaInferenceManager
from surya.layout import LayoutPredictor
from surya.logging import configure_logging
from surya.ocr_error import OCRErrorPredictor
from surya.recognition import RecognitionPredictor
from surya.table_rec import TableRecPredictor
configure_logging()
def load_predictors(
    device: str | torch.device | None = None,
    dtype: torch.dtype | str | None = None,
    manager: Optional[SuryaInferenceManager] = None,
) -> Dict[str, object]:
    """Construct the standard set of surya predictors.

    Layout, recognition and table recognition are VLM-backed and are all
    built on the same SuryaInferenceManager. Detection and OCR-error
    predictors hold their own torch models and receive ``device``/``dtype``.

    Args:
        device: Forwarded to the detection and OCR-error predictors.
        dtype: Forwarded to the detection and OCR-error predictors.
        manager: Inference manager to share; when ``None`` a lazy one is
            created with ``SuryaInferenceManager(lazy=True)``.

    Returns:
        Dict with keys ``layout``, ``recognition``, ``table_rec``,
        ``detection``, ``ocr_error`` and ``manager`` (the manager in use).
    """
    shared_manager = (
        manager if manager is not None else SuryaInferenceManager(lazy=True)
    )

    predictors: Dict[str, object] = {}

    # VLM-backed predictors: one shared manager.
    vlm_backed = (
        ("layout", LayoutPredictor),
        ("recognition", RecognitionPredictor),
        ("table_rec", TableRecPredictor),
    )
    for key, predictor_cls in vlm_backed:
        predictors[key] = predictor_cls(shared_manager)

    # Torch-backed predictors: each loads its own model.
    torch_backed = (
        ("detection", DetectionPredictor),
        ("ocr_error", OCRErrorPredictor),
    )
    for key, predictor_cls in torch_backed:
        predictors[key] = predictor_cls(device=device, dtype=dtype)

    predictors["manager"] = shared_manager
    return predictors

View File

@ -1,18 +1,20 @@
import collections
import os
import json
import unicodedata
from typing import List, Optional, Tuple
from tokenizers import normalizers
from transformers.tokenization_utils import PreTrainedTokenizer, _is_control, _is_punctuation, _is_whitespace
from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
from transformers.tokenization_utils import (
PreTrainedTokenizer,
_is_control,
_is_punctuation,
_is_whitespace,
)
from surya.common.s3 import S3DownloaderMixin
VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}
# Copied from transformers.models.bert.tokenization_bert.load_vocab
def load_vocab(vocab_file):
"""Loads a vocabulary file into a dictionary."""
@ -101,7 +103,9 @@ class DistilBertTokenizer(S3DownloaderMixin, PreTrainedTokenizer):
" model use `tokenizer = DistilBertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`"
)
self.vocab = load_vocab(vocab_file)
self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])
self.ids_to_tokens = collections.OrderedDict(
[(ids, tok) for tok, ids in self.vocab.items()]
)
self.do_basic_tokenize = do_basic_tokenize
if do_basic_tokenize:
self.basic_tokenizer = BasicTokenizer(
@ -110,7 +114,9 @@ class DistilBertTokenizer(S3DownloaderMixin, PreTrainedTokenizer):
tokenize_chinese_chars=tokenize_chinese_chars,
strip_accents=strip_accents,
)
self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=str(unk_token))
self.wordpiece_tokenizer = WordpieceTokenizer(
vocab=self.vocab, unk_token=str(unk_token)
)
super().__init__(
do_lower_case=do_lower_case,
@ -145,7 +151,10 @@ class DistilBertTokenizer(S3DownloaderMixin, PreTrainedTokenizer):
split_tokens = []
if self.do_basic_tokenize:
for token in self.basic_tokenizer.tokenize(
text, never_split=self.all_special_tokens if not split_special_tokens else None
text,
never_split=self.all_special_tokens
if not split_special_tokens
else None,
):
# If the token is part of the never_split set
if token in self.basic_tokenizer.never_split:
@ -200,7 +209,10 @@ class DistilBertTokenizer(S3DownloaderMixin, PreTrainedTokenizer):
# Copied from transformers.models.bert.tokenization_bert.BertTokenizer.get_special_tokens_mask
def get_special_tokens_mask(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
self,
token_ids_0: List[int],
token_ids_1: Optional[List[int]] = None,
already_has_special_tokens: bool = False,
) -> List[int]:
"""
Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
@ -220,7 +232,9 @@ class DistilBertTokenizer(S3DownloaderMixin, PreTrainedTokenizer):
if already_has_special_tokens:
return super().get_special_tokens_mask(
token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
token_ids_0=token_ids_0,
token_ids_1=token_ids_1,
already_has_special_tokens=True,
)
if token_ids_1 is not None:
@ -258,14 +272,20 @@ class DistilBertTokenizer(S3DownloaderMixin, PreTrainedTokenizer):
return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
# Copied from transformers.models.bert.tokenization_bert.BertTokenizer.save_vocabulary
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
def save_vocabulary(
self, save_directory: str, filename_prefix: Optional[str] = None
) -> Tuple[str]:
index = 0
if os.path.isdir(save_directory):
vocab_file = os.path.join(
save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
save_directory,
(filename_prefix + "-" if filename_prefix else "")
+ VOCAB_FILES_NAMES["vocab_file"],
)
else:
vocab_file = (filename_prefix + "-" if filename_prefix else "") + save_directory
vocab_file = (
filename_prefix + "-" if filename_prefix else ""
) + save_directory
with open(vocab_file, "w", encoding="utf-8") as writer:
for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
if index != token_index:
@ -329,7 +349,11 @@ class BasicTokenizer(object):
[`PreTrainedTokenizer.tokenize`]) List of token not to split.
"""
# union() returns a new set by concatenating the two sets.
never_split = self.never_split.union(set(never_split)) if never_split else self.never_split
never_split = (
self.never_split.union(set(never_split))
if never_split
else self.never_split
)
text = self._clean_text(text)
# This was added on November 1st, 2018 for the multilingual and Chinese
@ -370,7 +394,9 @@ class BasicTokenizer(object):
def _run_split_on_punc(self, text, never_split=None):
"""Splits punctuation on a piece of text."""
if not self.do_split_on_punc or (never_split is not None and text in never_split):
if not self.do_split_on_punc or (
never_split is not None and text in never_split
):
return [text]
chars = list(text)
i = 0
@ -496,4 +522,4 @@ class WordpieceTokenizer(object):
output_tokens.append(self.unk_token)
else:
output_tokens.extend(sub_tokens)
return output_tokens
return output_tokens

View File

@ -24,9 +24,7 @@ from surya.layout.schema import LayoutResult
from surya.logging import get_logger
from surya.recognition.schema import (
BlockOCRResult,
OCRResult,
PageOCRResult,
TextLine,
)
from surya.settings import settings
@ -251,23 +249,3 @@ class RecognitionPredictor:
)
results.append(PageOCRResult(blocks=blocks, image_bbox=page_bbox))
return results
def to_legacy_ocr_results(
    self, page_results: List[PageOCRResult]
) -> List[OCRResult]:
    """Convert block-level page results to the legacy ``OCRResult`` shape.

    Compatibility shim for downstream code that still reads ``text_lines``:
    every block becomes exactly one ``TextLine`` carrying the block's
    polygon, its HTML as ``text``, its confidence, and an empty ``chars``
    list. The page's ``image_bbox`` is carried over unchanged.
    """
    return [
        OCRResult(
            text_lines=[
                TextLine(
                    polygon=block.polygon,
                    text=block.html,
                    chars=[],
                    confidence=block.confidence,
                )
                for block in page.blocks
            ],
            image_bbox=page.image_bbox,
        )
        for page in page_results
    ]

View File

@ -1,4 +1,4 @@
from typing import Any, List, Optional
from typing import List
from pydantic import BaseModel
@ -12,30 +12,8 @@ class BlockOCRResult(PolygonBox):
html: str = "" # block HTML (BLOCK_PROMPT output, "" if skipped)
skipped: bool = False # True if label was in SKIP_OCR_LABELS
error: bool = False
char_confidences: Optional[List[float]] = None # phase 2
raw_logprobs: Optional[List[Any]] = None # phase 2 debugging
# OCR output for a single page: the per-block results plus the page's image bbox.
class PageOCRResult(BaseModel):
blocks: List[BlockOCRResult]  # one entry per block on the page
image_bbox: List[float]  # bounding box of the page image
# ---- Back-compat shims for code paths that still expect text_lines ----
# These are intentionally minimal; downstream consumers should migrate to
# BlockOCRResult / PageOCRResult.
# Legacy per-character record: a piece of text and its confidence.
class TextChar(BaseModel):
text: str
confidence: float = 0.0  # defaults to 0.0 when no confidence is supplied
# Legacy line record: a PolygonBox that additionally carries text and optional chars.
class TextLine(PolygonBox):
text: str = ""
chars: List[TextChar] = []  # empty unless the caller supplies per-character data
# Legacy page-level result: a list of text lines plus the page's image bbox.
class OCRResult(BaseModel):
text_lines: List[TextLine]
image_bbox: List[float]

View File

@ -15,7 +15,6 @@ class Settings(BaseSettings):
IMAGE_DPI: int = 96 # used for layout, recognition, and table rec
IMAGE_DPI_HIGHRES: int = 192
IN_STREAMLIT: bool = False
FLATTEN_PDF: bool = True
DISABLE_TQDM: bool = False
S3_BASE_URL: str = "https://models.datalab.to"
PARALLEL_DOWNLOAD_WORKERS: int = 10
@ -23,7 +22,6 @@ class Settings(BaseSettings):
LOGLEVEL: str = "INFO"
# Paths
DATA_DIR: str = "data"
RESULT_DIR: str = "results"
BASE_DIR: str = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
FONT_DIR: str = os.path.join(BASE_DIR, "static", "fonts")
@ -39,7 +37,7 @@ class Settings(BaseSettings):
return "cpu"
# ---- Surya2 inference (VLM-backed: vllm | llamacpp) ---------------------
SURYA_MODEL_CHECKPOINT: str = "datalab-to/surya-2.1.2"
SURYA_MODEL_CHECKPOINT: str = "datalab-to/surya-2.1.2-mtp"
SURYA_GGUF_REPO: str = "datalab-to/surya-ocr-2-gguf"
SURYA_GGUF_MODEL_FILE: str = "surya-2.gguf"
SURYA_GGUF_MMPROJ_FILE: str = "surya-2-mmproj.gguf"
@ -90,7 +88,7 @@ class Settings(BaseSettings):
VLLM_MAX_MODEL_LEN: int = 18000
VLLM_GPU_MEMORY_UTILIZATION: float = 0.85
VLLM_ENABLE_MTP: bool = True
VLLM_MTP_TOKENS: int = 3
VLLM_MTP_TOKENS: int = 2
VLLM_EXTRA_ARGS: Optional[str] = None
DOCKER_HF_CACHE_PATH: str = "~/.cache/huggingface"
@ -103,7 +101,6 @@ class Settings(BaseSettings):
# ---- Detection (kept) ---------------------------------------------------
DETECTOR_BATCH_SIZE: Optional[int] = None
DETECTOR_MODEL_CHECKPOINT: str = "s3://text_detection/2025_05_07"
DETECTOR_BENCH_DATASET_NAME: str = "vikp/doclaynet_bench"
DETECTOR_IMAGE_CHUNK_HEIGHT: int = 1400
DETECTOR_TEXT_THRESHOLD: float = 0.6
DETECTOR_BLANK_THRESHOLD: float = 0.35
@ -115,11 +112,6 @@ class Settings(BaseSettings):
OCR_ERROR_MODEL_CHECKPOINT: str = "s3://ocr_error_detection/2025_02_18"
OCR_ERROR_BATCH_SIZE: Optional[int] = None
# ---- Layout / Recognition / Table-rec batch sizes (CLI scripts) -------
LAYOUT_BATCH_SIZE: Optional[int] = None
RECOGNITION_BATCH_SIZE: Optional[int] = None
TABLE_REC_BATCH_SIZE: Optional[int] = None
# ---- Debug / draw fonts (label rendering on annotated images) ----------
RECOGNITION_RENDER_FONTS: Dict[str, str] = {
"all": os.path.join(FONT_DIR, "GoNotoCurrent-Regular.ttf"),
@ -131,9 +123,6 @@ class Settings(BaseSettings):
"https://github.com/satbyy/go-noto-universal/releases/download/v7.0"
)
# ---- Tesseract ---------------------------------------------------------
TESSDATA_PREFIX: Optional[str] = None
@computed_field
def MODEL_DTYPE(self) -> torch.dtype:
if self.TORCH_DEVICE_MODEL == "cpu":