From fe8545cfc8b4d784de6204061cfb14003e1f169d Mon Sep 17 00:00:00 2001 From: Tarun Menta Date: Fri, 1 Aug 2025 18:39:06 -0400 Subject: [PATCH] Better calculation of max image token count --- surya/foundation/__init__.py | 42 ++++++++++++++++++++++++++++++++---- 1 file changed, 38 insertions(+), 4 deletions(-) diff --git a/surya/foundation/__init__.py b/surya/foundation/__init__.py index 47bc036..5cfea2a 100644 --- a/surya/foundation/__init__.py +++ b/surya/foundation/__init__.py @@ -7,6 +7,7 @@ from collections import deque import cv2 import numpy as np import torch +import math from PIL import Image from tqdm import tqdm import torch.nn.functional as F @@ -453,9 +454,42 @@ class FoundationPredictor(BasePredictor): return new_input, processed_outputs, idxs_to_merge - def get_max_image_token_count(self, images: list[np.ndarray]) -> int: - # Extra 1 to account for rotation token when present - return 1 + self.processor.num_register_tokens + 2048 + def get_max_image_token_count(self, images: list[np.ndarray], tasks: List[TaskNames]) -> int: + def compute_scaled_size(H: int, W: int, max_size: Tuple[int, int]) -> Tuple[int, int]: + max_W, max_H = max_size + min_W, min_H = (168, 168) + + current_pixels = H * W + max_pixels = max_H * max_W + min_pixels = min_H * min_W + + if current_pixels > max_pixels: + scale = (max_pixels / current_pixels) ** 0.5 + return math.floor(H * scale), math.floor(W * scale) + elif current_pixels < min_pixels: + scale = (min_pixels / current_pixels) ** 0.5 + return math.ceil(H * scale), math.ceil(W * scale) + return H, W + + def get_tile_count(H: int, W: int, factor: int) -> int: + H_bar = math.ceil(H / factor) * factor + W_bar = math.ceil(W / factor) * factor + grid_h = H_bar / self.processor.patch_size + grid_w = W_bar // self.processor.patch_size + return grid_h * grid_w + + max_tokens = 0 + factor = self.processor.patch_size * self.processor.merge_size + + for image, task in zip(images, tasks): + H, W = image.shape[:2] + max_size = self.tasks[task]["img_size"] + scaled_H, scaled_W = compute_scaled_size(H, W, max_size) + token_count = get_tile_count(scaled_H, scaled_W, factor) / (self.processor.merge_size ** 2) + max_tokens = max(max_tokens, token_count) + + # Extra 10 to account for EOS/BOS/Rotation token etc. + return 10 + self.processor.num_register_tokens + int(max_tokens) def prediction_loop( self, @@ -485,7 +519,7 @@ class FoundationPredictor(BasePredictor): batch_size = min(len(images), batch_size) current_inputs = None - max_image_tokens = self.get_max_image_token_count(images) + max_image_tokens = self.get_max_image_token_count(images, task_names) if max_sliding_window is None: max_sliding_window = self.model.config.sliding_window self.setup_cache(batch_size, max_cache_len=max_image_tokens + max_sliding_window, max_sliding_window=max_sliding_window)