Better calculation of max image token count

This commit is contained in:
Tarun Menta 2025-08-01 18:39:06 -04:00
parent 3212707c49
commit fe8545cfc8
No known key found for this signature in database

View File

@ -7,6 +7,7 @@ from collections import deque
import cv2
import numpy as np
import torch
import math
from PIL import Image
from tqdm import tqdm
import torch.nn.functional as F
@ -453,9 +454,42 @@ class FoundationPredictor(BasePredictor):
return new_input, processed_outputs, idxs_to_merge
def get_max_image_token_count(self, images: list[np.ndarray]) -> int:
    """Return a fixed upper bound on the number of image tokens per input.

    The bound does not depend on ``images``; it is the sum of a constant
    2048-token budget, the processor's register tokens, and one extra token.

    Args:
        images: Input images. Unused by this computation.

    Returns:
        ``1 + self.processor.num_register_tokens + 2048``.
    """
    # Extra 1 to account for rotation token when present
    rotation_token = 1
    register_tokens = self.processor.num_register_tokens
    fixed_image_budget = 2048
    return rotation_token + register_tokens + fixed_image_budget
def get_max_image_token_count(self, images: list[np.ndarray], tasks: list["TaskNames"]) -> int:
    """Return an upper bound on the image tokens produced by any one input.

    Each image is paired with its task. The image's pixel area is clamped to
    the range between a 168x168 minimum and the task's ``img_size`` maximum,
    each side is rounded up to a multiple of ``patch_size * merge_size``, and
    the resulting patch grid is divided by ``merge_size ** 2`` to get the
    token count. The largest count over all images is returned, plus the
    processor's register tokens and a fixed slack of 10.

    Args:
        images: Arrays whose first two dimensions are (height, width).
        tasks: One task key per image, used to look up
            ``self.tasks[task]["img_size"]`` (unpacked as ``(max_W, max_H)``).
            Pairing is done with ``zip``, so extra entries on either side are
            ignored.

    Returns:
        ``10 + self.processor.num_register_tokens + max_tokens`` as an int.
        With no images, ``max_tokens`` is 0.
    """
    # NOTE(review): assumes patch_size and merge_size are positive ints -- confirm.
    patch_size = self.processor.patch_size
    merge_size = self.processor.merge_size
    factor = patch_size * merge_size

    def compute_scaled_size(H: int, W: int, max_size: tuple[int, int]) -> tuple[int, int]:
        # Rescale (H, W) so its area lies within [min_pixels, max_pixels],
        # keeping the aspect ratio.
        max_W, max_H = max_size
        min_W, min_H = (168, 168)
        current_pixels = H * W
        max_pixels = max_H * max_W
        min_pixels = min_H * min_W

        if current_pixels == 0:
            # A zero-area image has no aspect ratio to preserve and would
            # divide by zero below; budget for the minimum size instead.
            return min_H, min_W
        if current_pixels > max_pixels:
            scale = (max_pixels / current_pixels) ** 0.5
            return math.floor(H * scale), math.floor(W * scale)
        if current_pixels < min_pixels:
            scale = (min_pixels / current_pixels) ** 0.5
            return math.ceil(H * scale), math.ceil(W * scale)
        return H, W

    def get_tile_count(H: int, W: int) -> int:
        # Round each side up to a multiple of `factor`, then count patches.
        H_bar = math.ceil(H / factor) * factor
        W_bar = math.ceil(W / factor) * factor
        # Floor division on both axes keeps the count an int; the padded
        # sides are exact multiples of patch_size, so nothing is truncated.
        grid_h = H_bar // patch_size
        grid_w = W_bar // patch_size
        return grid_h * grid_w

    max_tokens = 0
    for image, task in zip(images, tasks):
        H, W = image.shape[:2]
        max_size = self.tasks[task]["img_size"]
        scaled_H, scaled_W = compute_scaled_size(H, W, max_size)
        # Each padded side holds a whole number of merge_size-wide patch
        # groups, so the tile count divides evenly by merge_size ** 2.
        token_count = get_tile_count(scaled_H, scaled_W) // (merge_size ** 2)
        max_tokens = max(max_tokens, token_count)

    # Extra 10 to account for EOS/BOS/Rotation token etc.
    return 10 + self.processor.num_register_tokens + max_tokens
def prediction_loop(
self,
@ -485,7 +519,7 @@ class FoundationPredictor(BasePredictor):
batch_size = min(len(images), batch_size)
current_inputs = None
max_image_tokens = self.get_max_image_token_count(images)
max_image_tokens = self.get_max_image_token_count(images, task_names)
if max_sliding_window is None:
max_sliding_window = self.model.config.sliding_window
self.setup_cache(batch_size, max_cache_len=max_image_tokens + max_sliding_window, max_sliding_window=max_sliding_window)