Add in GQA

This commit is contained in:
Vik Paruchuri 2024-01-26 13:36:22 -08:00
parent e1376dba6e
commit f3d9de9400
8 changed files with 295 additions and 23 deletions

View File

@ -1,5 +1,4 @@
import argparse
import copy
import json
from collections import defaultdict
@ -8,10 +7,10 @@ from surya.input.processing import slice_polys_from_image
from surya.model.detection.segformer import load_model as load_detection_model, load_processor as load_detection_processor
from surya.model.recognition.model import load_model as load_recognition_model
from surya.model.recognition.processor import load_processor as load_recognition_processor
from surya.model.recognition.tokenizer import _tokenize
from surya.postprocessing.text import draw_text_on_image
from surya.detection import batch_detection
from surya.recognition import batch_recognition
from surya.postprocessing.affinity import draw_lines_on_image
from surya.postprocessing.heatmap import draw_polys_on_image
from surya.settings import settings
import os
@ -23,13 +22,12 @@ def main():
parser.add_argument("--max", type=int, help="Maximum number of pages to process.", default=None)
parser.add_argument("--start_page", type=int, help="Page to start processing at.", default=0)
parser.add_argument("--images", action="store_true", help="Save images of detected bboxes.", default=False)
parser.add_argument("--debug", action="store_true", help="Run in debug mode.", default=False)
parser.add_argument("--lang", type=str, help="Language to use for OCR. Comma separate for multiple.", default="en")
args = parser.parse_args()
langs = args.lang.split(",")
detection_model = load_detection_model()
detection_processor = load_detection_processor()
detection_model = load_detection_model()
if os.path.isdir(args.input_path):
images, names = load_from_folder(args.input_path, args.max, args.start_page)
@ -38,27 +36,52 @@ def main():
images, names = load_from_file(args.input_path, args.max, args.start_page)
folder_name = os.path.basename(args.input_path).split(".")[0]
predictions = batch_detection(images, detection_model, detection_processor)
det_predictions = batch_detection(images, detection_model, detection_processor)
result_path = os.path.join(args.results_dir, folder_name)
os.makedirs(result_path, exist_ok=True)
del detection_processor
del detection_model
recognition_model = load_recognition_model()
_, lang_tokens = _tokenize("", langs)
recognition_model = load_recognition_model(langs=lang_tokens) # Prune model moes to only include languages we need
recognition_processor = load_recognition_processor()
slice_map = []
all_slices = []
all_langs = []
for idx, (image, pred, name) in enumerate(zip(images, predictions, names)):
for idx, (image, pred, name) in enumerate(zip(images, det_predictions, names)):
slices = slice_polys_from_image(image, pred["polygons"])
slice_map.append(len(slices))
all_slices.extend(slices)
all_langs.extend([langs] * len(slices))
predictions = batch_recognition(all_slices, all_langs, recognition_model, recognition_processor)
print(predictions)
rec_predictions = batch_recognition(all_slices, all_langs, recognition_model, recognition_processor)
predictions_by_page = defaultdict(list)
slice_start = 0
for idx, (image, det_pred, name) in enumerate(zip(images, det_predictions, names)):
slice_end = slice_start + slice_map[idx]
image_lines = rec_predictions[slice_start:slice_end]
slice_start = slice_end
assert len(image_lines) == len(det_pred["polygons"]) == len(det_pred["bboxes"])
predictions_by_page[name].append({
"lines": image_lines,
"polys": det_pred["polygons"],
"bboxes": det_pred["bboxes"],
"name": name,
"page_number": len(predictions_by_page[name]) + 1
})
if args.images:
page_image = draw_text_on_image(det_pred["bboxes"], image_lines, image.size)
page_image.save(os.path.join(result_path, f"{name}_{idx}_text.png"))
with open(os.path.join(result_path, "results.json"), "w+") as f:
json.dump(predictions_by_page, f)
print(f"Wrote results to {result_path}")
if __name__ == "__main__":

2
static/fonts/.gitignore vendored Normal file
View File

@ -0,0 +1,2 @@
*
!.gitignore

View File

@ -8,7 +8,7 @@ from transformers.activations import ACT2FN
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_attention_mask
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions, BaseModelOutputWithPastAndCrossAttentions
from transformers.models.mbart.modeling_mbart import MBartPreTrainedModel, MBartDecoder, \
MBartLearnedPositionalEmbedding, MBART_ATTENTION_CLASSES
MBartLearnedPositionalEmbedding
from surya.model.recognition.config import MBartMoEConfig
import torch
import math
@ -84,6 +84,195 @@ class MBartExpertLayer(nn.Module):
return final_hidden_states
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Repeat every key/value head ``n_rep`` times along the head axis (adapted from llama).

    Produces the same result as ``torch.repeat_interleave(x, dim=1, repeats=n_rep)``:
    the input of shape (batch, num_key_value_heads, seqlen, head_dim) becomes
    (batch, num_key_value_heads * n_rep, seqlen, head_dim). When ``n_rep`` is 1 the
    input tensor is returned unchanged.
    """
    # Unpacking first means a tensor that is not 4-D is rejected even when n_rep == 1.
    bsz, kv_heads, seq_len, dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # Add a singleton "repeat" axis after the head axis, broadcast it to n_rep,
    # then fold it into the head axis so copies of one head end up adjacent.
    widened = hidden_states.unsqueeze(2).expand(bsz, kv_heads, n_rep, seq_len, dim)
    return widened.reshape(bsz, kv_heads * n_rep, seq_len, dim)
class MBartGQAttention(nn.Module):
    """Multi-headed attention with grouped-query key/value heads.

    Queries are projected to ``num_heads`` heads, while keys and values are projected
    to only ``num_kv_heads`` heads. Before the attention product each key/value head is
    repeated ``num_heads // num_kv_heads`` times with ``repeat_kv`` so that it lines up
    with the query heads. The cached ``past_key_value`` stores the un-repeated
    key/value tensors, i.e. with ``num_kv_heads`` heads.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        num_kv_heads: int,
        dropout: float = 0.0,
        is_decoder: bool = False,
        bias: bool = True,
        is_causal: bool = False,
        config: Optional["MBartConfig"] = None,
    ):
        """
        Args:
            embed_dim: Size of the input and output feature dimension.
            num_heads: Number of query heads; must divide ``embed_dim``.
            num_kv_heads: Number of key/value heads; must divide ``num_heads``.
            dropout: Dropout probability applied to the attention probabilities.
            is_decoder: When True, ``forward`` returns the key/value states for caching.
            bias: Whether the four linear projections carry a bias term.
            is_causal: Stored on the module; not read by ``forward``.
            config: Optional model config, stored on the module.

        Raises:
            ValueError: If the head counts are not positive or do not divide evenly.
        """
        super().__init__()

        # Validate with exceptions rather than `assert` so the checks survive `python -O`,
        # and before any division so a zero head count gives a clear error.
        if num_kv_heads <= 0 or num_heads % num_kv_heads != 0:
            raise ValueError(f"num_heads ({num_heads}) must be divisible by num_kv_heads ({num_kv_heads})")
        if embed_dim % num_kv_heads != 0:
            raise ValueError(f"embed_dim ({embed_dim}) must be divisible by num_kv_heads ({num_kv_heads})")
        if num_heads <= 0 or embed_dim % num_heads != 0:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {embed_dim}"
                f" and `num_heads`: {num_heads})."
            )

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        # How many query heads share each key/value head.
        self.num_kv_groups = self.num_heads // self.num_kv_heads

        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        self.config = config

        self.scaling = self.head_dim**-0.5
        self.is_decoder = is_decoder
        self.is_causal = is_causal

        # Keys/values are projected to fewer heads than the queries.
        self.k_proj = nn.Linear(embed_dim, self.num_kv_heads * self.head_dim, bias=bias)
        self.v_proj = nn.Linear(embed_dim, self.num_kv_heads * self.head_dim, bias=bias)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        """Reshape a query projection to (bsz, num_heads, seq_len, head_dim)."""
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def _shape_key_value(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        """Reshape a key/value projection to (bsz, num_kv_heads, seq_len, head_dim)."""
        return tensor.view(bsz, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        key_value_states: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Input shape: Batch x Time x Channel

        Returns a tuple ``(attn_output, attn_weights, past_key_value)``. ``attn_weights``
        is None unless ``output_attentions`` is set; ``past_key_value`` is replaced by the
        current (un-repeated) key/value states when the module is a decoder and is
        otherwise returned as passed in.

        Raises:
            ValueError: If the attention weights, attention mask, head mask or attention
                output do not have the expected size.
        """

        # if key_value_states are provided this layer is used as a cross-attention layer
        # for the decoder
        is_cross_attention = key_value_states is not None

        bsz, tgt_len, _ = hidden_states.size()

        # get query proj
        query_states = self.q_proj(hidden_states) * self.scaling
        # get key, value proj
        # `past_key_value[0].shape[2] == key_value_states.shape[1]`
        # is checking that the `sequence_length` of the `past_key_value` is the same as
        # the provided `key_value_states` to support prefix tuning
        if (
            is_cross_attention
            and past_key_value is not None
            and past_key_value[0].shape[2] == key_value_states.shape[1]
        ):
            # reuse k,v, cross_attentions
            key_states = past_key_value[0]
            value_states = past_key_value[1]
        elif is_cross_attention:
            # cross_attentions
            key_states = self._shape_key_value(self.k_proj(key_value_states), -1, bsz)
            value_states = self._shape_key_value(self.v_proj(key_value_states), -1, bsz)
        elif past_key_value is not None:
            # reuse k, v, self_attention
            key_states = self._shape_key_value(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape_key_value(self.v_proj(hidden_states), -1, bsz)
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)
        else:
            # self_attention
            key_states = self._shape_key_value(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape_key_value(self.v_proj(hidden_states), -1, bsz)

        if self.is_decoder:
            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
            # Further calls to cross_attention layer can then reuse all cross-attention
            # key/value_states (first "if" case)
            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
            # all previous decoder key/value_states. Further calls to uni-directional self-attention
            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
            # if encoder bi-directional self-attention `past_key_value` is always `None`
            past_key_value = (key_states, value_states)

        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
        query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)

        # Expand kv heads, then match query shape
        key_states = repeat_kv(key_states, self.num_kv_groups)
        value_states = repeat_kv(value_states, self.num_kv_groups)
        key_states = key_states.reshape(*proj_shape)
        value_states = value_states.reshape(*proj_shape)

        src_len = key_states.size(1)
        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))

        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        if layer_head_mask is not None:
            if layer_head_mask.size() != (self.num_heads,):
                raise ValueError(
                    f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
                    f" {layer_head_mask.size()}"
                )
            attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        if output_attentions:
            # this operation is a bit awkward, but it's required to
            # make sure that attn_weights keeps its gradient.
            # In order to do so, attn_weights have to be reshaped
            # twice and have to be reused in the following
            attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
        else:
            attn_weights_reshaped = None

        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)

        attn_output = torch.bmm(attn_probs, value_states)

        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
        attn_output = attn_output.transpose(1, 2)

        # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
        # partitioned across GPUs when using tensor-parallelism.
        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)

        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights_reshaped, past_key_value
# Lookup table from `config._attn_implementation` to the attention class that the
# decoder layers instantiate for self- and cross-attention.
# NOTE(review): "flash_attention_2" maps to None, which is not callable, so selecting
# that implementation fails with a TypeError when a decoder layer calls the looked-up
# value. Any key not listed here raises KeyError at the lookup.
MBART_ATTENTION_CLASSES = {
"eager": MBartGQAttention,
"flash_attention_2": None
}
class MBartMoEDecoderLayer(nn.Module):
def __init__(self, config: MBartConfig, has_moe=False):
super().__init__()
@ -92,6 +281,7 @@ class MBartMoEDecoderLayer(nn.Module):
self.self_attn = MBART_ATTENTION_CLASSES[config._attn_implementation](
embed_dim=self.embed_dim,
num_heads=config.decoder_attention_heads,
num_kv_heads=config.kv_heads,
dropout=config.attention_dropout,
is_decoder=True,
is_causal=True,
@ -105,6 +295,7 @@ class MBartMoEDecoderLayer(nn.Module):
self.encoder_attn = MBART_ATTENTION_CLASSES[config._attn_implementation](
self.embed_dim,
config.decoder_attention_heads,
num_kv_heads=config.kv_heads,
dropout=config.attention_dropout,
is_decoder=True,
config=config,
@ -498,4 +689,15 @@ class MBartMoE(MBartForCausalLM):
"past_key_values": past_key_values,
"use_cache": use_cache,
"langs": langs
}
}
def prune_moe_experts(self, keep_keys: List[int]):
    """Remove every MoE expert whose key is not listed in ``keep_keys``.

    Expert keys are compared as strings, so each entry of ``keep_keys`` is converted
    with ``str`` first. Decoder layers whose ``has_moe`` flag is false are left
    untouched. The experts container of each remaining layer is modified in place.
    """
    wanted = {str(key) for key in keep_keys}
    moe_layers = (layer for layer in self.model.decoder.layers if layer.has_moe)
    for layer in moe_layers:
        experts = layer.moe.experts
        # Collect the names first: the container must not change size while iterated.
        stale = [name for name in experts.keys() if name not in wanted]
        for name in stale:
            experts.pop(name)

View File

@ -1,3 +1,4 @@
from typing import List, Optional
from transformers import VisionEncoderDecoderModel, VisionEncoderDecoderConfig, AutoModel, AutoModelForCausalLM
from surya.model.recognition.config import MBartMoEConfig, VariableDonutSwinConfig
from surya.model.recognition.encoder import VariableDonutSwinModel
@ -5,7 +6,7 @@ from surya.model.recognition.decoder import MBartMoE
from surya.settings import settings
def load_model(checkpoint=settings.RECOGNITION_MODEL_CHECKPOINT, device=settings.TORCH_DEVICE_MODEL, dtype=settings.MODEL_DTYPE):
def load_model(checkpoint=settings.RECOGNITION_MODEL_CHECKPOINT, device=settings.TORCH_DEVICE_MODEL, dtype=settings.MODEL_DTYPE, langs: Optional[List[int]] = None):
config = VisionEncoderDecoderConfig.from_pretrained(checkpoint)
decoder_config = vars(config.decoder)
@ -25,6 +26,10 @@ def load_model(checkpoint=settings.RECOGNITION_MODEL_CHECKPOINT, device=settings
assert isinstance(model.decoder, MBartMoE)
assert isinstance(model.encoder, VariableDonutSwinModel)
# Prune moe experts that are not needed
if langs is not None:
model.decoder.prune_moe_experts(langs)
model = model.to(device)
model = model.eval()
print(f"Loading recognition model {checkpoint} on device {device} with dtype {dtype}")

View File

@ -16,7 +16,7 @@ from surya.settings import settings
def load_processor():
    """Create a ``SuryaProcessor`` configured for inference.

    The image processor is switched out of training mode and its maximum image size
    and the tokenizer's maximum length are taken from ``settings``.

    Returns:
        The configured ``SuryaProcessor`` instance.
    """
    processor = SuryaProcessor()
    processor.image_processor.train = False
    # The size comes from settings only; the former hard-coded
    # {"height": 180, "width": 900} value was overwritten immediately and is dropped.
    processor.image_processor.max_size = settings.RECOGNITION_IMAGE_SIZE
    processor.tokenizer.model_max_length = settings.RECOGNITION_MAX_TOKENS
    return processor

View File

@ -0,0 +1,35 @@
from PIL import Image, ImageDraw, ImageFont
from surya.settings import settings
def get_text_size(text, font):
    """Measure ``text`` rendered with ``font``.

    The text is measured with ``ImageDraw.textbbox`` anchored at (0, 0) on a throwaway
    zero-sized palette image; nothing is actually drawn.

    Returns:
        A ``(width, height)`` tuple taken from the right and bottom edges of the
        reported bounding box.
    """
    scratch = Image.new(mode="P", size=(0, 0))
    bounds = ImageDraw.Draw(scratch).textbbox((0, 0), text=text, font=font)
    right, bottom = bounds[2], bounds[3]
    return right, bottom
def draw_text_on_image(bboxes, texts, image_size=(1024, 1024), font_path=settings.RECOGNITION_RENDER_FONT, font_size=14):
    """Render each text centered inside its bounding box on a blank white image.

    Args:
        bboxes: Iterable of boxes indexed as ``(x0, y0, x1, y1)``.
        texts: Iterable of strings, paired with ``bboxes`` by position.
        image_size: Size passed to ``Image.new`` for the output image.
        font_path: Path of the TrueType font used for rendering.
        font_size: Starting font size for every box. A box whose text does not fit is
            rendered with a smaller size, down to a minimum of 6.

    Returns:
        The new RGB ``PIL.Image`` with all texts drawn in black.
    """
    min_font_size = 6
    image = Image.new('RGB', image_size, color='white')
    draw = ImageDraw.Draw(image)

    # Fonts are loaded from disk at most once per size.
    fonts_by_size = {}

    def font_for(size):
        if size not in fonts_by_size:
            fonts_by_size[size] = ImageFont.truetype(font_path, size)
        return fonts_by_size[size]

    for bbox, text in zip(bboxes, texts):
        bbox_width = bbox[2] - bbox[0]
        bbox_height = bbox[3] - bbox[1]

        # Start every box from the requested size. Shrinking the shared `font_size`
        # would make one tight box permanently shrink the text of all later boxes.
        box_font_size = font_size
        font = font_for(box_font_size)
        text_width, text_height = get_text_size(text, font)

        # Shrink the text to fit in the bbox if needed
        while (text_width > bbox_width or text_height > bbox_height) and box_font_size > min_font_size:
            box_font_size -= 1
            font = font_for(box_font_size)
            text_width, text_height = get_text_size(text, font)

        # Calculate text position (centered in bbox)
        x = bbox[0] + (bbox_width - text_width) / 2
        y = bbox[1] + (bbox_height - text_height) / 2

        draw.text((x, y), text, fill="black", font=font)

    return image

View File

@ -3,6 +3,7 @@ import torch
from PIL import Image
from surya.settings import settings
from tqdm import tqdm
import numpy as np
def batch_recognition(images: List, languages: List[List[str]], model, processor):
@ -17,9 +18,9 @@ def batch_recognition(images: List, languages: List[List[str]], model, processor
batch_pixel_values = model_inputs["pixel_values"][i:i+settings.RECOGNITION_BATCH_SIZE]
batch_decoder_input = [[model.config.decoder_start_token_id] + lang for lang in batch_langs]
batch_langs = torch.tensor(batch_langs, dtype=torch.long).to(model.device)
batch_pixel_values = torch.tensor(batch_pixel_values, dtype=model.dtype).to(model.device)
batch_decoder_input = torch.tensor(batch_decoder_input, dtype=torch.long).to(model.device)
batch_langs = torch.from_numpy(np.array(batch_langs, dtype=np.int64)).to(model.device)
batch_pixel_values = torch.tensor(np.array(batch_pixel_values), dtype=model.dtype).to(model.device)
batch_decoder_input = torch.from_numpy(np.array(batch_decoder_input, dtype=np.int64)).to(model.device)
with torch.inference_mode():
generated_ids = model.generate(

View File

@ -12,6 +12,12 @@ class Settings(BaseSettings):
TORCH_DEVICE: Optional[str] = None
IMAGE_DPI: int = 96
# Paths
DATA_DIR: str = "data"
RESULT_DIR: str = "results"
BASE_DIR: str = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
FONT_DIR: str = os.path.join(BASE_DIR, "static", "fonts")
@computed_field
@property
def TORCH_DEVICE_MODEL(self) -> str:
@ -47,13 +53,11 @@ class Settings(BaseSettings):
DETECTOR_NMS_THRESHOLD: float = 0.35 # Threshold for non-maximum suppression
# Text recognition
RECOGNITION_MODEL_CHECKPOINT: str = "vikp/rec_test"
RECOGNITION_MODEL_CHECKPOINT: str = "vikp/rec_test_gqa"
RECOGNITION_MAX_TOKENS: int = 512
RECOGNITION_BATCH_SIZE: int = 128 if TORCH_DEVICE_MODEL == "cuda" else 8
# Paths
DATA_DIR: str = "data"
RESULT_DIR: str = "results"
RECOGNITION_BATCH_SIZE: int = 8 if TORCH_DEVICE_MODEL in ["cpu", "mps"] else 128
RECOGNITION_IMAGE_SIZE: Dict = {"height": 196, "width": 896}
RECOGNITION_RENDER_FONT: str = os.path.join(FONT_DIR, "GoNotoKurrent-Regular.ttf")
@computed_field
@property