From 79246df8378e300d4e8eceb43a3ed41e61d02425 Mon Sep 17 00:00:00 2001 From: Vik Paruchuri Date: Wed, 27 May 2026 10:36:53 -0400 Subject: [PATCH] Add screenshot app --- pyproject.toml | 2 + surya/inference/backends/spawn.py | 13 +- surya/scripts/config.py | 55 ++- surya/scripts/screenshot_app.py | 226 ++++++++++++ surya/scripts/streamlit_app.py | 66 +++- surya/scripts/templates/surya_screenshot.html | 331 ++++++++++++++++++ surya/settings.py | 4 + uv.lock | 31 ++ 8 files changed, 712 insertions(+), 16 deletions(-) create mode 100644 surya/scripts/screenshot_app.py create mode 100644 surya/scripts/templates/surya_screenshot.html diff --git a/pyproject.toml b/pyproject.toml index 3a45867..8253adf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,6 +37,7 @@ surya_ocr = "surya.scripts.ocr_text:ocr_text_cli" surya_layout = "surya.scripts.detect_layout:detect_layout_cli" surya_gui = "surya.scripts.run_streamlit_app:streamlit_app_cli" surya_table = "surya.scripts.table_recognition:table_recognition_cli" +surya_screenshot = "surya.scripts.screenshot_app:main" [dependency-groups] dev = [ @@ -48,6 +49,7 @@ dev = [ "pytest>=8.3.4", "pdftext>=0.5.1", "tabulate>=0.9.0", + "flask>=3.0.0", ] [build-system] diff --git a/surya/inference/backends/spawn.py b/surya/inference/backends/spawn.py index 6776d04..0ece794 100644 --- a/surya/inference/backends/spawn.py +++ b/surya/inference/backends/spawn.py @@ -265,7 +265,10 @@ def attach_or_spawn( }, ) - # 5. Register atexit cleanup (only spawner) + # 5. Register atexit cleanup (only spawner). Skipped when keep-alive is + # set so the server outlives this process and later commands attach to + # it via the sentinel. (_cleanup is still callable below on startup + # failure, where we always tear a half-started server down.) def _cleanup(): try: if spawn_handle.cleanup_kind == "docker": @@ -276,7 +279,13 @@ def attach_or_spawn( finally: _delete_sentinel(backend) - atexit.register(_cleanup) + if settings.SURYA_INFERENCE_KEEP_ALIVE: + logger.info( + f"keep-alive: {backend} server on port {port} will stay up " + f"after exit (cleanup_id={spawn_handle.cleanup_id!r})" + ) + else: + atexit.register(_cleanup) # 6. Wait for health health_url = health_url_for(port) diff --git a/surya/scripts/config.py b/surya/scripts/config.py index 4a40d05..921627a 100644 --- a/surya/scripts/config.py +++ b/surya/scripts/config.py @@ -17,15 +17,45 @@ class CLILoader: self.debug = cli_options.get("debug", False) self.output_dir = cli_options.get("output_dir") + # Opt in to leaving the inference server up so later commands reuse it. + if cli_options.get("keep_server"): + settings.SURYA_INFERENCE_KEEP_ALIVE = True + self.load(highres) @staticmethod def common_options(fn): - fn = click.argument("input_path", type=click.Path(exists=True), required=True)(fn) - fn = click.option("--output_dir", type=click.Path(exists=False), required=False, default=os.path.join(settings.RESULT_DIR, "surya"), help="Directory to save output.")(fn) - fn = click.option("--page_range", type=str, default=None, help="Page range to convert, specify comma separated page numbers or ranges. Example: 0,5-10,20")(fn) - fn = click.option("--images", is_flag=True, help="Save images of detected bboxes.", default=False)(fn) - fn = click.option('--debug', '-d', is_flag=True, help='Enable debug mode.', default=False)(fn) + fn = click.argument("input_path", type=click.Path(exists=True), required=True)( + fn + ) + fn = click.option( + "--output_dir", + type=click.Path(exists=False), + required=False, + default=os.path.join(settings.RESULT_DIR, "surya"), + help="Directory to save output.", + )(fn) + fn = click.option( + "--page_range", + type=str, + default=None, + help="Page range to convert, specify comma separated page numbers or ranges. Example: 0,5-10,20", + )(fn) + fn = click.option( + "--images", + is_flag=True, + help="Save images of detected bboxes.", + default=False, + )(fn) + fn = click.option( + "--debug", "-d", is_flag=True, help="Enable debug mode.", default=False + )(fn) + fn = click.option( + "--keep_server", + is_flag=True, + default=False, + help="Keep the inference server (vllm/llama.cpp) running after this command exits so later commands reuse it instead of re-spawning.", + )(fn) return fn def load(self, highres: bool = False): @@ -34,13 +64,16 @@ class CLILoader: images, names = load_from_folder(self.filepath, self.page_range) folder_name = os.path.basename(self.filepath) if highres: - highres_images, _ = load_from_folder(self.filepath, self.page_range, settings.IMAGE_DPI_HIGHRES) + highres_images, _ = load_from_folder( + self.filepath, self.page_range, settings.IMAGE_DPI_HIGHRES + ) else: images, names = load_from_file(self.filepath, self.page_range) folder_name = os.path.basename(self.filepath).split(".")[0] if highres: - highres_images, _ = load_from_file(self.filepath, self.page_range, settings.IMAGE_DPI_HIGHRES) - + highres_images, _ = load_from_file( + self.filepath, self.page_range, settings.IMAGE_DPI_HIGHRES + ) self.images = images self.highres_images = highres_images @@ -59,5 +92,7 @@ class CLILoader: page_lst += list(range(int(start), int(end) + 1)) else: page_lst.append(int(i)) - page_lst = sorted(list(set(page_lst))) # Deduplicate page numbers and sort in order - return page_lst \ No newline at end of file + page_lst = sorted( + list(set(page_lst)) + ) # Deduplicate page numbers and sort in order + return page_lst diff --git a/surya/scripts/screenshot_app.py b/surya/scripts/screenshot_app.py new file mode 100644 index 0000000..34b88ab --- /dev/null +++ b/surya/scripts/screenshot_app.py @@ -0,0 +1,226 @@ +"""Screenshot-friendly Surya viewer. + +Shows a PDF/image page on the left and full-page OCR output on the right, side +by side, for clean screenshots. You can scroll through pages and preview them +before running OCR, then export the side-by-side view as a PNG. + +Run with `surya_screenshot`, then open http://localhost:8504. +""" + +from __future__ import annotations + +import base64 +import io +import os +import tempfile +import uuid +from typing import List, Optional + +import pypdfium2 +from flask import Flask, jsonify, render_template, request +from PIL import Image +from werkzeug.utils import secure_filename + +from surya.inference import SuryaInferenceManager +from surya.logging import configure_logging, get_logger +from surya.recognition import RecognitionPredictor +from surya.recognition.schema import PageOCRResult +from surya.settings import settings + +configure_logging() +logger = get_logger() + +app = Flask(__name__) + +ALLOWED_EXT = {".pdf", ".png", ".jpg", ".jpeg", ".gif", ".webp"} +UPLOAD_DIR = os.path.join(tempfile.gettempdir(), "surya_screenshot") +os.makedirs(UPLOAD_DIR, exist_ok=True) + +_rec: Optional[RecognitionPredictor] = None + + +def get_rec() -> RecognitionPredictor: + """Lazily build the recognition predictor (shared inference manager).""" + global _rec + if _rec is None: + _rec = RecognitionPredictor(SuryaInferenceManager()) + return _rec + + +# Datalab-flavored palette for layout block overlays, keyed by canonical label. +LABEL_COLORS = { + "Text": "#2563eb", + "SectionHeader": "#0ea5e9", + "PageHeader": "#7c3aed", + "PageFooter": "#7c3aed", + "Caption": "#c026d3", + "Footnote": "#64748b", + "Equation": "#9333ea", + "Table": "#f59e0b", + "TableOfContents": "#f59e0b", + "Form": "#ea580c", + "ListGroup": "#10b981", + "Picture": "#db2777", + "Figure": "#db2777", + "Diagram": "#db2777", + "Code": "#0d9488", + "default": "#ef4444", +} + + +def _logo_data_url() -> str: + path = os.path.join(settings.BASE_DIR, "static", "datalab-logo.png") + try: + with open(path, "rb") as f: + return "data:image/png;base64," + base64.b64encode(f.read()).decode() + except Exception: + return "" + + +def _pil_to_data_url(img: Image.Image, fmt: str = "PNG") -> str: + buf = io.BytesIO() + img.save(buf, format=fmt) + return ( + f"data:image/{fmt.lower()};base64," + base64.b64encode(buf.getvalue()).decode() + ) + + +def _is_pdf(path: str) -> bool: + return path.lower().endswith(".pdf") + + +def _page_count(path: str) -> int: + if _is_pdf(path): + doc = pypdfium2.PdfDocument(path) + n = len(doc) + doc.close() + return n + return 1 + + +def _render_page(path: str, page: int, dpi: int) -> Image.Image: + """Render a 0-indexed page of a PDF (or load an image file) as RGB.""" + if _is_pdf(path): + doc = pypdfium2.PdfDocument(path) + try: + pil = doc[page].render(scale=dpi / 72).to_pil().convert("RGB") + finally: + doc.close() + return pil + return Image.open(path).convert("RGB") + + +def _assemble_page_html(page: PageOCRResult) -> str: + """Whole-page HTML from a PageOCRResult (math stays in tags).""" + parts: List[str] = [] + for blk in page.blocks: + if blk.skipped: + continue + x0, y0, x1, y1 = (int(c) for c in blk.bbox) + parts.append( + f'
{blk.html or ""}
' + ) + return "\n".join(parts) + + +@app.route("/") +def index(): + return render_template("surya_screenshot.html", logo=_logo_data_url()) + + +@app.route("/info", methods=["POST"]) +def info(): + path = (request.json or {}).get("file_path", "").strip() + if not path: + return jsonify({"error": "file_path is required"}), 400 + if not os.path.exists(path): + return jsonify({"error": f"File not found: {path}"}), 400 + try: + return jsonify({"page_count": _page_count(path)}) + except Exception as e: + return jsonify({"error": str(e)}), 500 + + +@app.route("/upload", methods=["POST"]) +def upload(): + """Accept a drag/drop (or browsed) file, save to a temp path, return it.""" + f = request.files.get("file") + if f is None or not f.filename: + return jsonify({"error": "no file uploaded"}), 400 + ext = os.path.splitext(f.filename)[1].lower() + if ext not in ALLOWED_EXT: + return jsonify({"error": f"unsupported file type: {ext or '(none)'}"}), 400 + safe = secure_filename(f.filename) or f"upload{ext}" + dest = os.path.join(UPLOAD_DIR, f"{uuid.uuid4().hex}_{safe}") + f.save(dest) + try: + return jsonify( + {"file_path": dest, "page_count": _page_count(dest), "name": f.filename} + ) + except Exception as e: + return jsonify({"error": str(e)}), 500 + + +@app.route("/page", methods=["POST"]) +def page(): + """Render a single page for preview (no OCR).""" + data = request.json or {} + path = data.get("file_path", "").strip() + page_num = int(data.get("page", 0)) + if not path or not os.path.exists(path): + return jsonify({"error": "valid file_path is required"}), 400 + try: + img = _render_page(path, page_num, settings.IMAGE_DPI_HIGHRES) + return jsonify( + { + "image_base64": _pil_to_data_url(img), + "width": img.size[0], + "height": img.size[1], + } + ) + except Exception as e: + return jsonify({"error": str(e)}), 500 + + +@app.route("/process", methods=["POST"]) +def process(): + """Run full-page OCR on one page; return the page image + OCR HTML + blocks.""" + data = request.json or {} + path = data.get("file_path", "").strip() + page_num = int(data.get("page", 0)) + if not path or not os.path.exists(path): + return jsonify({"error": "valid file_path is required"}), 400 + try: + img = _render_page(path, page_num, settings.IMAGE_DPI_HIGHRES) + page_result = get_rec()([img], full_page=True)[0] + blocks = [ + { + "bbox": [int(c) for c in blk.bbox], + "label": blk.label, + "color": LABEL_COLORS.get(blk.label, LABEL_COLORS["default"]), + } + for blk in page_result.blocks + if not blk.skipped + ] + return jsonify( + { + "image_base64": _pil_to_data_url(img), + "width": img.size[0], + "height": img.size[1], + "html": _assemble_page_html(page_result), + "blocks": blocks, + "n_blocks": len(page_result.blocks), + } + ) + except Exception as e: + logger.exception("Full-page OCR failed") + return jsonify({"error": str(e)}), 500 + + +def main(): + app.run(host="0.0.0.0", port=8504) + + +if __name__ == "__main__": + main() diff --git a/surya/scripts/streamlit_app.py b/surya/scripts/streamlit_app.py index 26d0142..24d0cc9 100644 --- a/surya/scripts/streamlit_app.py +++ b/surya/scripts/streamlit_app.py @@ -4,12 +4,14 @@ inference manager. Detection + OCR-error stay in their own torch paths.""" from __future__ import annotations import io +import re import tempfile import time from typing import List import pypdfium2 import streamlit as st +import streamlit.components.v1 as components from PIL import Image, ImageDraw from surya.debug.draw import draw_polys_on_image, draw_bboxes_on_image @@ -24,6 +26,61 @@ from surya.table_rec import TableRecPredictor from surya.table_rec.schema import TableResult +# KaTeX-enabled HTML wrapper. The OCR HTML wraps math in ... +# (KaTeX-compatible LaTeX inside), which a browser would otherwise show as +# raw text. We convert those tags to \( \) / \[ \] delimiters and let KaTeX +# auto-render typeset them inside an iframe component. +_KATEX_HEAD = r""" + + + + + +""" + +_KATEX_TAIL = r""" + +""" + +_MATH_RE = re.compile(r"]*)>(.*?)", re.DOTALL | re.IGNORECASE) + + +def _math_to_katex(html_str: str) -> str: + """Rewrite ... tags into KaTeX \\( \\) / \\[ \\] delimiters.""" + + def repl(m: "re.Match") -> str: + attrs, inner = m.group(1), m.group(2) + if re.search(r"""display\s*=\s*["']block["']""", attrs): + return "\\[" + inner + "\\]" + return "\\(" + inner + "\\)" + + return _MATH_RE.sub(repl, html_str or "") + + +def render_ocr_html(html_str: str, height: int = 400) -> None: + """Render OCR HTML with math typeset by KaTeX (iframe component).""" + components.html( + _KATEX_HEAD + _math_to_katex(html_str) + _KATEX_TAIL, + height=height, + scrolling=True, + ) + + def _assemble_page_html(page: PageOCRResult) -> str: """Reconstruct a div-block whole-page HTML from a PageOCRResult.""" parts: List[str] = [] @@ -334,7 +391,7 @@ if run_block_ocr: ) full_html = _assemble_page_html(page) with st.expander("Full page HTML (rendered)", expanded=False): - st.markdown(full_html, unsafe_allow_html=True) + render_ocr_html(full_html, height=600) with st.expander("Full page HTML (source)", expanded=False): st.code(full_html, language="html") for blk in page.blocks: @@ -366,6 +423,7 @@ if run_block_ocr: elif blk.error: st.error("Block OCR errored") else: + render_ocr_html(blk.html, height=160) st.code(blk.html, language="html") @@ -382,7 +440,7 @@ if run_full_page_ocr: ) full_html = _assemble_page_html(page) with st.expander("Full page HTML (rendered)", expanded=False): - st.markdown(full_html, unsafe_allow_html=True) + render_ocr_html(full_html, height=600) with st.expander("Full page HTML (source)", expanded=False): st.code(full_html, language="html") for blk in page.blocks: @@ -394,7 +452,7 @@ if run_full_page_ocr: elif blk.error: st.error("Block OCR errored") else: - st.markdown(blk.html, unsafe_allow_html=True) + render_ocr_html(blk.html, height=160) st.code(blk.html, language="html") @@ -412,7 +470,7 @@ if run_table_rec: for pred in preds: if pred.mode == "full" and pred.html: with st.expander("Table HTML"): - st.markdown(pred.html, unsafe_allow_html=True) + render_ocr_html(pred.html, height=400) st.code(pred.html, language="html") else: st.json(pred.model_dump(), expanded=False) diff --git a/surya/scripts/templates/surya_screenshot.html b/surya/scripts/templates/surya_screenshot.html new file mode 100644 index 0000000..c472f16 --- /dev/null +++ b/surya/scripts/templates/surya_screenshot.html @@ -0,0 +1,331 @@ + + + + + + Surya · Full-Page OCR + + + + + + +
+
+ {% if logo %}Datalab{% endif %} + Surya + Full-Page OCR +
+
+ + + + +
+ + + +
+ + + + + +
+
+ +
+
+
PDF Page
+
+
+
+
Full-Page OCR
+
Load a file, scroll to a page, then run full-page OCR.
+
+
+ +
Drop a PDF or image to load
+ + + + diff --git a/surya/settings.py b/surya/settings.py index b051eae..f10fbac 100644 --- a/surya/settings.py +++ b/surya/settings.py @@ -49,6 +49,10 @@ class Settings(BaseSettings): SURYA_INFERENCE_BACKEND: Optional[str] = None # "vllm" | "llamacpp" | None (auto) SURYA_INFERENCE_URL: Optional[str] = None # external server, skip spawn SURYA_INFERENCE_AUTOSTART: bool = True + # Leave an auto-spawned server running after the process exits so later + # commands attach to it instead of re-spawning (avoids repeated startup / + # model-load cost). Stop it manually when done — see `surya/inference`. + SURYA_INFERENCE_KEEP_ALIVE: bool = False SURYA_INFERENCE_HOST: str = "127.0.0.1" SURYA_INFERENCE_PORT: Optional[int] = None # None = pick a free port SURYA_INFERENCE_PARALLEL: int = 8 diff --git a/uv.lock b/uv.lock index d19141a..a0ac0b4 100644 --- a/uv.lock +++ b/uv.lock @@ -823,6 +823,23 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/18/79/1b8fa1bb3568781e84c9200f951c735f3f157429f44be0495da55894d620/filetype-1.2.0-py2.py3-none-any.whl", hash = "sha256:7ce71b6880181241cf7ac8697a2f1eb6a8bd9b429f7ad6d27b8db9ba5f1c2d25", size = 19970 }, ] +[[package]] +name = "flask" +version = "3.1.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "blinker" }, + { name = "click" }, + { name = "itsdangerous" }, + { name = "jinja2" }, + { name = "markupsafe" }, + { name = "werkzeug" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/26/00/35d85dcce6c57fdc871f3867d465d780f302a175ea360f62533f12b27e2b/flask-3.1.3.tar.gz", hash = "sha256:0ef0e52b8a9cd932855379197dd8f94047b359ca0a78695144304cb45f87c9eb", size = 759004 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7f/9c/34f6962f9b9e9c71f6e5ed806e0d0ff03c9d1b0b2340088a0cf4bce09b18/flask-3.1.3-py3-none-any.whl", hash = "sha256:f4bcbefc124291925f1a26446da31a5178f9483862233b23c0c96a20701f670c", size = 103424 }, +] + [[package]] name = "fqdn" version = "1.5.1" @@ -3944,6 +3961,7 @@ dependencies = [ [package.dev-dependencies] dev = [ { name = "datasets" }, + { name = "flask" }, { name = "jupyter" }, { name = "pdftext" }, { name = "pre-commit" }, @@ -3976,6 +3994,7 @@ requires-dist = [ [package.metadata.requires-dev] dev = [ { name = "datasets", specifier = ">=2.16.1" }, + { name = "flask", specifier = ">=3.0.0" }, { name = "jupyter", specifier = ">=1.0.0" }, { name = "pdftext", specifier = ">=0.5.1" }, { name = "pre-commit", specifier = ">=4.2.0" }, @@ -4468,6 +4487,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/6f/28/258ebab549c2bf3e64d2b0217b973467394a9cea8c42f70418ca2c5d0d2e/websockets-16.0-py3-none-any.whl", hash = "sha256:1637db62fad1dc833276dded54215f2c7fa46912301a24bd94d45d46a011ceec", size = 171598 }, ] +[[package]] +name = "werkzeug" +version = "3.1.8" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "markupsafe" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/dd/b2/381be8cfdee792dd117872481b6e378f85c957dd7c5bca38897b08f765fd/werkzeug-3.1.8.tar.gz", hash = "sha256:9bad61a4268dac112f1c5cd4630a56ede601b6ed420300677a869083d70a4c44", size = 875852 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/93/8c/2e650f2afeb7ee576912636c23ddb621c91ac6a98e66dc8d29c3c69446e1/werkzeug-3.1.8-py3-none-any.whl", hash = "sha256:63a77fb8892bf28ebc3178683445222aa500e48ebad5ec77b0ad80f8726b1f50", size = 226459 }, +] + [[package]] name = "widgetsnbextension" version = "4.0.15"