mirror of
https://github.com/VikParuchuri/surya.git
synced 2026-06-04 21:03:53 +08:00
Better calculation of max image token count
This commit is contained in:
parent
3212707c49
commit
fe8545cfc8
@ -7,6 +7,7 @@ from collections import deque
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
import math
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
import torch.nn.functional as F
|
||||
@ -453,9 +454,42 @@ class FoundationPredictor(BasePredictor):
|
||||
|
||||
return new_input, processed_outputs, idxs_to_merge
|
||||
|
||||
def get_max_image_token_count(self, images: list[np.ndarray]) -> int:
|
||||
# Extra 1 to account for rotation token when present
|
||||
return 1 + self.processor.num_register_tokens + 2048
|
||||
def get_max_image_token_count(self, images: list[np.ndarray], tasks: List[TaskNames]) -> int:
|
||||
def compute_scaled_size(H: int, W: int, max_size: Tuple[int, int]) -> Tuple[int, int]:
|
||||
max_W, max_H = max_size
|
||||
min_W, min_H = (168, 168)
|
||||
|
||||
current_pixels = H * W
|
||||
max_pixels = max_H * max_W
|
||||
min_pixels = min_H * min_W
|
||||
|
||||
if current_pixels > max_pixels:
|
||||
scale = (max_pixels / current_pixels) ** 0.5
|
||||
return math.floor(H * scale), math.floor(W * scale)
|
||||
elif current_pixels < min_pixels:
|
||||
scale = (min_pixels / current_pixels) ** 0.5
|
||||
return math.ceil(H * scale), math.ceil(W * scale)
|
||||
return H, W
|
||||
|
||||
def get_tile_count(H: int, W: int, factor: int) -> int:
|
||||
H_bar = math.ceil(H / factor) * factor
|
||||
W_bar = math.ceil(W / factor) * factor
|
||||
grid_h = H_bar / self.processor.patch_size
|
||||
grid_w = W_bar // self.processor.patch_size
|
||||
return grid_h * grid_w
|
||||
|
||||
max_tokens = 0
|
||||
factor = self.processor.patch_size * self.processor.merge_size
|
||||
|
||||
for image, task in zip(images, tasks):
|
||||
H, W = image.shape[:2]
|
||||
max_size = self.tasks[task]["img_size"]
|
||||
scaled_H, scaled_W = compute_scaled_size(H, W, max_size)
|
||||
token_count = get_tile_count(scaled_H, scaled_W, factor) / (self.processor.merge_size ** 2)
|
||||
max_tokens = max(max_tokens, token_count)
|
||||
|
||||
# Extra 10 to account for EOS/BOS/Rotation token etc.
|
||||
return 10 + self.processor.num_register_tokens + int(max_tokens)
|
||||
|
||||
def prediction_loop(
|
||||
self,
|
||||
@ -485,7 +519,7 @@ class FoundationPredictor(BasePredictor):
|
||||
batch_size = min(len(images), batch_size)
|
||||
current_inputs = None
|
||||
|
||||
max_image_tokens = self.get_max_image_token_count(images)
|
||||
max_image_tokens = self.get_max_image_token_count(images, task_names)
|
||||
if max_sliding_window is None:
|
||||
max_sliding_window = self.model.config.sliding_window
|
||||
self.setup_cache(batch_size, max_cache_len=max_image_tokens + max_sliding_window, max_sliding_window=max_sliding_window)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user