Transformers 4.46 compat
Some checks failed
Integration test / build (push) Has been cancelled

This commit is contained in:
Vik Paruchuri 2024-10-24 14:27:30 -04:00
parent 907692714d
commit 4978be7066
5 changed files with 188 additions and 20 deletions

View File

@ -128,7 +128,6 @@ def parallel_get_lines(preds, orig_sizes):
return result
def batch_text_detection(images: List, model, processor, batch_size=None) -> List[TextDetectionResult]:
detection_generator = batch_detection(images, model, processor, batch_size=batch_size)

View File

@ -1,8 +1,131 @@
from transformers import MBartConfig, DonutSwinConfig
from transformers import PretrainedConfig
class MBartOrderConfig(MBartConfig):
pass
class SuryaOrderConfig(PretrainedConfig):
model_type = "vision-encoder-decoder"
is_composition = True
class VariableDonutSwinConfig(DonutSwinConfig):
pass
def __init__(self, **kwargs):
    """Initialize the composite config from keyword arguments.

    All keyword arguments are first forwarded to the parent initializer.
    ``kwargs`` must contain both an ``encoder`` and a ``decoder`` entry;
    a missing entry raises ``KeyError``.
    """
    super().__init__(**kwargs)

    # Both entries are popped (encoder first) before either attribute is
    # bound, so a missing key fails before anything is assigned here.
    self.encoder, self.decoder = kwargs.pop("encoder"), kwargs.pop("decoder")
    self.is_encoder_decoder = True
class MBartOrderConfig(PretrainedConfig):
    """Configuration holding the MBart-style hyperparameters of the order decoder.

    The named hyperparameters are stored as same-named attributes on the
    instance. The token ids, ``is_encoder_decoder`` and any extra keyword
    arguments are forwarded to the parent initializer.
    """

    model_type = "surya_order"
    keys_to_ignore_at_inference = ["past_key_values"]
    # Generic attribute names resolved onto the MBart-specific ones.
    attribute_map = {
        "num_attention_heads": "encoder_attention_heads",
        "hidden_size": "d_model",
    }

    def __init__(
        self,
        vocab_size: int = 50265,
        max_position_embeddings: int = 1024,
        encoder_layers: int = 12,
        encoder_ffn_dim: int = 4096,
        encoder_attention_heads: int = 16,
        decoder_layers: int = 12,
        decoder_ffn_dim: int = 4096,
        decoder_attention_heads: int = 16,
        encoder_layerdrop: float = 0.0,
        decoder_layerdrop: float = 0.0,
        use_cache: bool = True,
        is_encoder_decoder: bool = True,
        activation_function: str = "gelu",
        d_model: int = 1024,
        dropout: float = 0.1,
        attention_dropout: float = 0.0,
        activation_dropout: float = 0.0,
        init_std: float = 0.02,
        classifier_dropout: float = 0.0,
        scale_embedding: bool = False,
        pad_token_id: int = 1,
        bos_token_id: int = 0,
        eos_token_id: int = 2,
        forced_eos_token_id: int = 2,
        **kwargs,
    ):
        """Record the hyperparameters, then delegate the rest to the parent.

        The attributes are assigned before the parent initializer runs;
        that ordering is kept exactly as in the original implementation.
        """
        # Vocabulary / sequence geometry.
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.d_model = d_model

        # Encoder stack.
        self.encoder_ffn_dim = encoder_ffn_dim
        self.encoder_layers = encoder_layers
        self.encoder_attention_heads = encoder_attention_heads

        # Decoder stack.
        self.decoder_ffn_dim = decoder_ffn_dim
        self.decoder_layers = decoder_layers
        self.decoder_attention_heads = decoder_attention_heads

        # Regularisation, activation and initialisation.
        self.dropout = dropout
        self.attention_dropout = attention_dropout
        self.activation_dropout = activation_dropout
        self.activation_function = activation_function
        self.init_std = init_std
        self.encoder_layerdrop = encoder_layerdrop
        self.decoder_layerdrop = decoder_layerdrop
        self.classifier_dropout = classifier_dropout

        self.use_cache = use_cache
        # ``num_hidden_layers`` mirrors the encoder depth.
        self.num_hidden_layers = encoder_layers
        # Scale factor will be sqrt(d_model) if True.
        self.scale_embedding = scale_embedding

        parent_kwargs = dict(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            is_encoder_decoder=is_encoder_decoder,
            forced_eos_token_id=forced_eos_token_id,
        )
        super().__init__(**parent_kwargs, **kwargs)
class VariableDonutSwinConfig(PretrainedConfig):
    """Configuration holding the Swin-style hyperparameters of the encoder.

    Besides storing the constructor arguments, the initializer derives
    ``num_layers`` (the number of entries in ``depths``) and ``hidden_size``
    (``embed_dim`` doubled once per stage after the first).
    """

    model_type = "donut-swin"

    # Generic attribute names resolved onto the Swin-specific ones.
    attribute_map = {
        "num_attention_heads": "num_heads",
        "num_hidden_layers": "num_layers",
    }

    def __init__(
        self,
        image_size=224,
        patch_size=4,
        num_channels=3,
        embed_dim=96,
        depths=None,
        num_heads=None,
        window_size=7,
        mlp_ratio=4.0,
        qkv_bias=True,
        hidden_dropout_prob=0.0,
        attention_probs_dropout_prob=0.0,
        drop_path_rate=0.1,
        hidden_act="gelu",
        use_absolute_embeddings=False,
        initializer_range=0.02,
        layer_norm_eps=1e-5,
        **kwargs,
    ):
        """Store the hyperparameters and derive ``num_layers`` / ``hidden_size``.

        Args:
            depths: Per-stage depths. ``None`` (the default) means
                ``[2, 2, 6, 2]``.
            num_heads: Per-stage attention head counts. ``None`` (the
                default) means ``[3, 6, 12, 24]``.
            **kwargs: Forwarded unchanged to the parent initializer.

        The remaining arguments are stored as same-named attributes.
        """
        super().__init__(**kwargs)

        # The list defaults are created per call: a list literal in the
        # signature would be a single object shared by every config built
        # with defaults, so an in-place edit of ``config.depths`` on one
        # instance would silently change all the others.
        if depths is None:
            depths = [2, 2, 6, 2]
        if num_heads is None:
            num_heads = [3, 6, 12, 24]

        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.embed_dim = embed_dim
        self.depths = depths
        self.num_layers = len(depths)
        self.num_heads = num_heads
        self.window_size = window_size
        self.mlp_ratio = mlp_ratio
        self.qkv_bias = qkv_bias
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.drop_path_rate = drop_path_rate
        self.hidden_act = hidden_act
        self.use_absolute_embeddings = use_absolute_embeddings
        self.layer_norm_eps = layer_norm_eps
        self.initializer_range = initializer_range
        # we set the hidden_size attribute in order to make Swin work with VisionEncoderDecoderModel
        # this indicates the channel dimension after the last stage of the model
        self.hidden_size = int(embed_dim * 2 ** (len(depths) - 1))

View File

@ -195,18 +195,12 @@ class MBartGQAttention(nn.Module):
return attn_output, attn_weights_reshaped, past_key_value
MBART_ATTENTION_CLASSES = {
"eager": MBartGQAttention,
"flash_attention_2": None
}
class MBartOrderDecoderLayer(MBartDecoderLayer):
def __init__(self, config: MBartConfig):
nn.Module.__init__(self)
self.embed_dim = config.d_model
self.self_attn = MBART_ATTENTION_CLASSES[config._attn_implementation](
self.self_attn = MBartGQAttention(
embed_dim=self.embed_dim,
num_heads=config.decoder_attention_heads,
num_kv_heads=config.kv_heads,
@ -220,7 +214,7 @@ class MBartOrderDecoderLayer(MBartDecoderLayer):
self.activation_dropout = config.activation_dropout
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.encoder_attn = MBART_ATTENTION_CLASSES[config._attn_implementation](
self.encoder_attn = MBartGQAttention(
self.embed_dim,
config.decoder_attention_heads,
num_kv_heads=config.kv_heads,

View File

@ -1,11 +1,53 @@
from typing import Optional, Union, Tuple, List
import torch
from transformers import VisionEncoderDecoderModel, GenerationMixin
from transformers import VisionEncoderDecoderModel, GenerationMixin, PreTrainedModel, PretrainedConfig
from transformers.modeling_outputs import Seq2SeqLMOutput, BaseModelOutput
from surya.model.ordering.decoder import MBartOrder
from surya.model.ordering.encoder import VariableDonutSwinModel
class OrderVisionEncoderDecoderModel(PreTrainedModel, GenerationMixin):
def __init__(
    self,
    config: Optional[PretrainedConfig] = None,
    encoder: Optional[PreTrainedModel] = None,
    decoder: Optional[PreTrainedModel] = None,
    text_encoder: Optional[PreTrainedModel] = None,
):
    """Assemble the encoder/decoder pair around a shared composite config.

    Args:
        config: Composite config exposing ``encoder`` and ``decoder``
            sub-configs. Required, even though the signature keeps the
            ``None`` default for backward compatibility.
        encoder: Pre-built encoder. When ``None``, a
            ``VariableDonutSwinModel`` is built from ``config.encoder``.
        decoder: Pre-built decoder. When ``None``, an ``MBartOrder`` is
            built from ``config.decoder``.
        text_encoder: Accepted but not used by this initializer.

    Raises:
        ValueError: If ``config`` is ``None``.
    """
    # The config is dereferenced immediately below; fail with a clear
    # message instead of an AttributeError on ``None``.
    if config is None:
        raise ValueError(
            "OrderVisionEncoderDecoderModel requires a config with `encoder` and `decoder` sub-configs."
        )

    # initialize with config
    # make sure input & output embeddings is not tied
    config.tie_word_embeddings = False
    config.decoder.tie_word_embeddings = False
    super().__init__(config)

    if encoder is None:
        encoder = VariableDonutSwinModel(config.encoder)

    if decoder is None:
        decoder = MBartOrder(config.decoder, attn_implementation=config._attn_implementation)

    self.encoder = encoder
    self.decoder = decoder

    # make sure that the individual model's config refers to the shared config
    # so that the updates to the config will be synced
    self.encoder.config = self.config.encoder
    self.decoder.config = self.config.decoder
def get_encoder(self):
    """Return the encoder sub-model held by this wrapper."""
    encoder = self.encoder
    return encoder
def get_decoder(self):
    """Return the decoder sub-model held by this wrapper."""
    decoder = self.decoder
    return decoder
def get_output_embeddings(self):
    """Return whatever the decoder's own ``get_output_embeddings()`` returns."""
    decoder = self.decoder
    return decoder.get_output_embeddings()
def set_output_embeddings(self, new_embeddings):
    """Forward ``new_embeddings`` to the decoder's ``set_output_embeddings``.

    The decoder's return value is passed back to the caller unchanged.
    """
    decoder = self.decoder
    return decoder.set_output_embeddings(new_embeddings)
class OrderVisionEncoderDecoderModel(VisionEncoderDecoderModel, GenerationMixin):
def forward(
self,
pixel_values: Optional[torch.FloatTensor] = None,
@ -88,3 +130,13 @@ class OrderVisionEncoderDecoderModel(VisionEncoderDecoderModel, GenerationMixin)
encoder_hidden_states=encoder_outputs.hidden_states,
encoder_attentions=encoder_outputs.attentions,
)
def resize_token_embeddings(self, *args, **kwargs):
    """Reject resizing through the wrapper model.

    All arguments are ignored. The error message directs callers to
    ``model.decoder.resize_token_embeddings(...)`` instead.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError(
        "Resizing the embedding layers via the VisionEncoderDecoderModel directly is not supported. Please use the"
        " respective methods of the wrapped decoder object (model.decoder.resize_token_embeddings(...))"
    )
def _reorder_cache(self, past_key_values, beam_idx):
    """Delegate cache reordering to the decoder.

    Both arguments are passed through unchanged to the decoder's own
    ``_reorder_cache`` and its result is returned as-is.
    """
    decoder = self.decoder
    return decoder._reorder_cache(past_key_values, beam_idx)

View File

@ -1,6 +1,6 @@
from transformers import DetrConfig, BeitConfig, DetrImageProcessor, VisionEncoderDecoderConfig, AutoModelForCausalLM, \
AutoModel
from surya.model.ordering.config import MBartOrderConfig, VariableDonutSwinConfig
from surya.model.ordering.config import MBartOrderConfig, VariableDonutSwinConfig, SuryaOrderConfig
from surya.model.ordering.decoder import MBartOrder
from surya.model.ordering.encoder import VariableDonutSwinModel
from surya.model.ordering.encoderdecoder import OrderVisionEncoderDecoderModel
@ -9,13 +9,13 @@ from surya.settings import settings
def load_model(checkpoint=settings.ORDER_MODEL_CHECKPOINT, device=settings.TORCH_DEVICE_MODEL, dtype=settings.MODEL_DTYPE):
config = VisionEncoderDecoderConfig.from_pretrained(checkpoint)
config = SuryaOrderConfig.from_pretrained(checkpoint)
decoder_config = vars(config.decoder)
decoder_config = config.decoder
decoder = MBartOrderConfig(**decoder_config)
config.decoder = decoder
encoder_config = vars(config.encoder)
encoder_config = config.encoder
encoder = VariableDonutSwinConfig(**encoder_config)
config.encoder = encoder