mirror of
https://github.com/VikParuchuri/surya.git
synced 2026-06-04 21:03:53 +08:00
This commit is contained in:
parent
907692714d
commit
4978be7066
@ -128,7 +128,6 @@ def parallel_get_lines(preds, orig_sizes):
|
||||
return result
|
||||
|
||||
|
||||
|
||||
def batch_text_detection(images: List, model, processor, batch_size=None) -> List[TextDetectionResult]:
|
||||
detection_generator = batch_detection(images, model, processor, batch_size=batch_size)
|
||||
|
||||
|
||||
@ -1,8 +1,131 @@
|
||||
from transformers import MBartConfig, DonutSwinConfig
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
|
||||
class MBartOrderConfig(MBartConfig):
|
||||
pass
|
||||
class SuryaOrderConfig(PretrainedConfig):
|
||||
model_type = "vision-encoder-decoder"
|
||||
is_composition = True
|
||||
|
||||
class VariableDonutSwinConfig(DonutSwinConfig):
|
||||
pass
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
encoder_config = kwargs.pop("encoder")
|
||||
decoder_config = kwargs.pop("decoder")
|
||||
|
||||
self.encoder = encoder_config
|
||||
self.decoder = decoder_config
|
||||
self.is_encoder_decoder = True
|
||||
|
||||
|
||||
class MBartOrderConfig(PretrainedConfig):
    """Configuration holder for the MBart-style order model.

    Every named hyperparameter is stored unchanged as an instance attribute.
    The special token ids, ``is_encoder_decoder``, ``forced_eos_token_id`` and
    any remaining keyword arguments are then forwarded to
    ``PretrainedConfig.__init__``.
    """

    model_type = "surya_order"
    keys_to_ignore_at_inference = ["past_key_values"]
    # Generic attribute names mapped onto the field names used by this config.
    attribute_map = {"num_attention_heads": "encoder_attention_heads", "hidden_size": "d_model"}

    def __init__(
        self,
        vocab_size=50265,
        max_position_embeddings=1024,
        encoder_layers=12,
        encoder_ffn_dim=4096,
        encoder_attention_heads=16,
        decoder_layers=12,
        decoder_ffn_dim=4096,
        decoder_attention_heads=16,
        encoder_layerdrop=0.0,
        decoder_layerdrop=0.0,
        use_cache=True,
        is_encoder_decoder=True,
        activation_function="gelu",
        d_model=1024,
        dropout=0.1,
        attention_dropout=0.0,
        activation_dropout=0.0,
        init_std=0.02,
        classifier_dropout=0.0,
        scale_embedding=False,
        pad_token_id=1,
        bos_token_id=0,
        eos_token_id=2,
        forced_eos_token_id=2,
        **kwargs,
    ):
        # Vocabulary / sequence geometry.
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.d_model = d_model

        # Encoder stack dimensions.
        self.encoder_ffn_dim = encoder_ffn_dim
        self.encoder_layers = encoder_layers
        self.encoder_attention_heads = encoder_attention_heads

        # Decoder stack dimensions.
        self.decoder_ffn_dim = decoder_ffn_dim
        self.decoder_layers = decoder_layers
        self.decoder_attention_heads = decoder_attention_heads

        # Regularisation, activation and initialisation settings.
        self.dropout = dropout
        self.attention_dropout = attention_dropout
        self.activation_dropout = activation_dropout
        self.activation_function = activation_function
        self.init_std = init_std
        self.encoder_layerdrop = encoder_layerdrop
        self.decoder_layerdrop = decoder_layerdrop
        self.classifier_dropout = classifier_dropout

        self.use_cache = use_cache
        # The generic layer count mirrors the encoder depth.
        self.num_hidden_layers = encoder_layers
        # Scale factor will be sqrt(d_model) if True.
        self.scale_embedding = scale_embedding

        # The base initializer runs last, after the fields above are set.
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            is_encoder_decoder=is_encoder_decoder,
            forced_eos_token_id=forced_eos_token_id,
            **kwargs,
        )
|
||||
|
||||
|
||||
class VariableDonutSwinConfig(PretrainedConfig):
    """Configuration holder for the variable-size Donut-Swin encoder.

    Remaining keyword arguments are forwarded to ``PretrainedConfig.__init__``
    first; the named hyperparameters are then stored as instance attributes.

    ``depths`` and ``num_heads`` may be any iterable. Each instance stores its
    own list copy, so mutating one config never changes the defaults seen by
    another (the previous list defaults were shared between instances).

    Derived attributes:
        num_layers: ``len(depths)``.
        hidden_size: ``int(embed_dim * 2 ** (len(depths) - 1))``.
    """

    model_type = "donut-swin"

    # Generic attribute names mapped onto the field names used by this config.
    attribute_map = {
        "num_attention_heads": "num_heads",
        "num_hidden_layers": "num_layers",
    }

    def __init__(
        self,
        image_size=224,
        patch_size=4,
        num_channels=3,
        embed_dim=96,
        depths=(2, 2, 6, 2),
        num_heads=(3, 6, 12, 24),
        window_size=7,
        mlp_ratio=4.0,
        qkv_bias=True,
        hidden_dropout_prob=0.0,
        attention_probs_dropout_prob=0.0,
        drop_path_rate=0.1,
        hidden_act="gelu",
        use_absolute_embeddings=False,
        initializer_range=0.02,
        layer_norm_eps=1e-5,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.embed_dim = embed_dim
        # Copy into fresh lists: avoids aliasing the (immutable) defaults or a
        # caller-owned sequence, and keeps the stored value a plain list.
        self.depths = list(depths)
        self.num_layers = len(self.depths)
        self.num_heads = list(num_heads)
        self.window_size = window_size
        self.mlp_ratio = mlp_ratio
        self.qkv_bias = qkv_bias
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.drop_path_rate = drop_path_rate
        self.hidden_act = hidden_act
        self.use_absolute_embeddings = use_absolute_embeddings
        self.layer_norm_eps = layer_norm_eps
        self.initializer_range = initializer_range
        # we set the hidden_size attribute in order to make Swin work with VisionEncoderDecoderModel
        # this indicates the channel dimension after the last stage of the model
        self.hidden_size = int(embed_dim * 2 ** (len(self.depths) - 1))
|
||||
@ -195,18 +195,12 @@ class MBartGQAttention(nn.Module):
|
||||
return attn_output, attn_weights_reshaped, past_key_value
|
||||
|
||||
|
||||
MBART_ATTENTION_CLASSES = {
|
||||
"eager": MBartGQAttention,
|
||||
"flash_attention_2": None
|
||||
}
|
||||
|
||||
|
||||
class MBartOrderDecoderLayer(MBartDecoderLayer):
|
||||
def __init__(self, config: MBartConfig):
|
||||
nn.Module.__init__(self)
|
||||
self.embed_dim = config.d_model
|
||||
|
||||
self.self_attn = MBART_ATTENTION_CLASSES[config._attn_implementation](
|
||||
self.self_attn = MBartGQAttention(
|
||||
embed_dim=self.embed_dim,
|
||||
num_heads=config.decoder_attention_heads,
|
||||
num_kv_heads=config.kv_heads,
|
||||
@ -220,7 +214,7 @@ class MBartOrderDecoderLayer(MBartDecoderLayer):
|
||||
self.activation_dropout = config.activation_dropout
|
||||
|
||||
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
|
||||
self.encoder_attn = MBART_ATTENTION_CLASSES[config._attn_implementation](
|
||||
self.encoder_attn = MBartGQAttention(
|
||||
self.embed_dim,
|
||||
config.decoder_attention_heads,
|
||||
num_kv_heads=config.kv_heads,
|
||||
|
||||
@ -1,11 +1,53 @@
|
||||
from typing import Optional, Union, Tuple, List
|
||||
|
||||
import torch
|
||||
from transformers import VisionEncoderDecoderModel, GenerationMixin
|
||||
from transformers import VisionEncoderDecoderModel, GenerationMixin, PreTrainedModel, PretrainedConfig
|
||||
from transformers.modeling_outputs import Seq2SeqLMOutput, BaseModelOutput
|
||||
|
||||
from surya.model.ordering.decoder import MBartOrder
|
||||
from surya.model.ordering.encoder import VariableDonutSwinModel
|
||||
|
||||
|
||||
class OrderVisionEncoderDecoderModel(PreTrainedModel, GenerationMixin):
|
||||
def __init__(
    self,
    config: Optional[PretrainedConfig] = None,
    encoder: Optional[PreTrainedModel] = None,
    decoder: Optional[PreTrainedModel] = None,
    text_encoder: Optional[PreTrainedModel] = None,
):
    """Build the encoder/decoder pair from a composite config.

    Args:
        config: Composite configuration exposing ``encoder`` and ``decoder``
            sub-configs. Required despite the ``None`` default, because it
            is read before anything else happens.
        encoder: Pre-built encoder; when ``None`` a ``VariableDonutSwinModel``
            is created from ``config.encoder``.
        decoder: Pre-built decoder; when ``None`` an ``MBartOrder`` is created
            from ``config.decoder``.
        text_encoder: Accepted for interface compatibility; not used by this
            constructor.

    Raises:
        ValueError: If ``config`` is ``None``.
    """
    if config is None:
        # Previously this surfaced as an opaque AttributeError on None.
        raise ValueError("OrderVisionEncoderDecoderModel requires a config; got None.")

    # Make sure input and output embeddings are not tied.
    config.tie_word_embeddings = False
    config.decoder.tie_word_embeddings = False
    super().__init__(config)

    if encoder is None:
        encoder = VariableDonutSwinModel(config.encoder)

    if decoder is None:
        decoder = MBartOrder(config.decoder, attn_implementation=config._attn_implementation)

    self.encoder = encoder
    self.decoder = decoder

    # Make sure that each sub-model's config refers to the shared config
    # so that updates to the config stay in sync.
    self.encoder.config = self.config.encoder
    self.decoder.config = self.config.decoder
|
||||
|
||||
def get_encoder(self):
    """Return the encoder module held by this model."""
    encoder = self.encoder
    return encoder
|
||||
|
||||
def get_decoder(self):
    """Return the decoder module held by this model."""
    decoder = self.decoder
    return decoder
|
||||
|
||||
def get_output_embeddings(self):
    """Return whatever the wrapped decoder reports as its output embeddings."""
    decoder = self.decoder
    return decoder.get_output_embeddings()
|
||||
|
||||
def set_output_embeddings(self, new_embeddings):
    """Forward ``new_embeddings`` to the wrapped decoder and return its result."""
    decoder = self.decoder
    result = decoder.set_output_embeddings(new_embeddings)
    return result
|
||||
|
||||
class OrderVisionEncoderDecoderModel(VisionEncoderDecoderModel, GenerationMixin):
|
||||
def forward(
|
||||
self,
|
||||
pixel_values: Optional[torch.FloatTensor] = None,
|
||||
@ -88,3 +130,13 @@ class OrderVisionEncoderDecoderModel(VisionEncoderDecoderModel, GenerationMixin)
|
||||
encoder_hidden_states=encoder_outputs.hidden_states,
|
||||
encoder_attentions=encoder_outputs.attentions,
|
||||
)
|
||||
|
||||
def resize_token_embeddings(self, *args, **kwargs):
    """Reject embedding resizing on the composite model.

    All arguments are ignored; the call always fails and the message points
    the caller at the wrapped decoder's own resize method.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError(
        # Fixed the missing space between the two sentences of the message.
        "Resizing the embedding layers via the VisionEncoderDecoderModel directly is not supported. Please use the"
        " respective methods of the wrapped decoder object (model.decoder.resize_token_embeddings(...))"
    )
|
||||
|
||||
def _reorder_cache(self, past_key_values, beam_idx):
    """Delegate cache reordering to the wrapped decoder and return its result."""
    decoder = self.decoder
    return decoder._reorder_cache(past_key_values, beam_idx)
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
from transformers import DetrConfig, BeitConfig, DetrImageProcessor, VisionEncoderDecoderConfig, AutoModelForCausalLM, \
|
||||
AutoModel
|
||||
from surya.model.ordering.config import MBartOrderConfig, VariableDonutSwinConfig
|
||||
from surya.model.ordering.config import MBartOrderConfig, VariableDonutSwinConfig, SuryaOrderConfig
|
||||
from surya.model.ordering.decoder import MBartOrder
|
||||
from surya.model.ordering.encoder import VariableDonutSwinModel
|
||||
from surya.model.ordering.encoderdecoder import OrderVisionEncoderDecoderModel
|
||||
@ -9,13 +9,13 @@ from surya.settings import settings
|
||||
|
||||
|
||||
def load_model(checkpoint=settings.ORDER_MODEL_CHECKPOINT, device=settings.TORCH_DEVICE_MODEL, dtype=settings.MODEL_DTYPE):
|
||||
config = VisionEncoderDecoderConfig.from_pretrained(checkpoint)
|
||||
config = SuryaOrderConfig.from_pretrained(checkpoint)
|
||||
|
||||
decoder_config = vars(config.decoder)
|
||||
decoder_config = config.decoder
|
||||
decoder = MBartOrderConfig(**decoder_config)
|
||||
config.decoder = decoder
|
||||
|
||||
encoder_config = vars(config.encoder)
|
||||
encoder_config = config.encoder
|
||||
encoder = VariableDonutSwinConfig(**encoder_config)
|
||||
config.encoder = encoder
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user