mirror of
https://github.com/VikParuchuri/surya.git
synced 2026-06-04 21:03:53 +08:00
Add in GQA
This commit is contained in:
parent
e1376dba6e
commit
f3d9de9400
43
ocr_text.py
43
ocr_text.py
@ -1,5 +1,4 @@
|
||||
import argparse
|
||||
import copy
|
||||
import json
|
||||
from collections import defaultdict
|
||||
|
||||
@ -8,10 +7,10 @@ from surya.input.processing import slice_polys_from_image
|
||||
from surya.model.detection.segformer import load_model as load_detection_model, load_processor as load_detection_processor
|
||||
from surya.model.recognition.model import load_model as load_recognition_model
|
||||
from surya.model.recognition.processor import load_processor as load_recognition_processor
|
||||
from surya.model.recognition.tokenizer import _tokenize
|
||||
from surya.postprocessing.text import draw_text_on_image
|
||||
from surya.detection import batch_detection
|
||||
from surya.recognition import batch_recognition
|
||||
from surya.postprocessing.affinity import draw_lines_on_image
|
||||
from surya.postprocessing.heatmap import draw_polys_on_image
|
||||
from surya.settings import settings
|
||||
import os
|
||||
|
||||
@ -23,13 +22,12 @@ def main():
|
||||
parser.add_argument("--max", type=int, help="Maximum number of pages to process.", default=None)
|
||||
parser.add_argument("--start_page", type=int, help="Page to start processing at.", default=0)
|
||||
parser.add_argument("--images", action="store_true", help="Save images of detected bboxes.", default=False)
|
||||
parser.add_argument("--debug", action="store_true", help="Run in debug mode.", default=False)
|
||||
parser.add_argument("--lang", type=str, help="Language to use for OCR. Comma separate for multiple.", default="en")
|
||||
args = parser.parse_args()
|
||||
|
||||
langs = args.lang.split(",")
|
||||
detection_model = load_detection_model()
|
||||
detection_processor = load_detection_processor()
|
||||
detection_model = load_detection_model()
|
||||
|
||||
if os.path.isdir(args.input_path):
|
||||
images, names = load_from_folder(args.input_path, args.max, args.start_page)
|
||||
@ -38,27 +36,52 @@ def main():
|
||||
images, names = load_from_file(args.input_path, args.max, args.start_page)
|
||||
folder_name = os.path.basename(args.input_path).split(".")[0]
|
||||
|
||||
predictions = batch_detection(images, detection_model, detection_processor)
|
||||
det_predictions = batch_detection(images, detection_model, detection_processor)
|
||||
result_path = os.path.join(args.results_dir, folder_name)
|
||||
os.makedirs(result_path, exist_ok=True)
|
||||
|
||||
del detection_processor
|
||||
del detection_model
|
||||
|
||||
recognition_model = load_recognition_model()
|
||||
_, lang_tokens = _tokenize("", langs)
|
||||
recognition_model = load_recognition_model(langs=lang_tokens) # Prune model moes to only include languages we need
|
||||
recognition_processor = load_recognition_processor()
|
||||
|
||||
slice_map = []
|
||||
all_slices = []
|
||||
all_langs = []
|
||||
for idx, (image, pred, name) in enumerate(zip(images, predictions, names)):
|
||||
for idx, (image, pred, name) in enumerate(zip(images, det_predictions, names)):
|
||||
slices = slice_polys_from_image(image, pred["polygons"])
|
||||
slice_map.append(len(slices))
|
||||
all_slices.extend(slices)
|
||||
all_langs.extend([langs] * len(slices))
|
||||
|
||||
predictions = batch_recognition(all_slices, all_langs, recognition_model, recognition_processor)
|
||||
print(predictions)
|
||||
rec_predictions = batch_recognition(all_slices, all_langs, recognition_model, recognition_processor)
|
||||
|
||||
predictions_by_page = defaultdict(list)
|
||||
slice_start = 0
|
||||
for idx, (image, det_pred, name) in enumerate(zip(images, det_predictions, names)):
|
||||
slice_end = slice_start + slice_map[idx]
|
||||
image_lines = rec_predictions[slice_start:slice_end]
|
||||
slice_start = slice_end
|
||||
|
||||
assert len(image_lines) == len(det_pred["polygons"]) == len(det_pred["bboxes"])
|
||||
predictions_by_page[name].append({
|
||||
"lines": image_lines,
|
||||
"polys": det_pred["polygons"],
|
||||
"bboxes": det_pred["bboxes"],
|
||||
"name": name,
|
||||
"page_number": len(predictions_by_page[name]) + 1
|
||||
})
|
||||
|
||||
if args.images:
|
||||
page_image = draw_text_on_image(det_pred["bboxes"], image_lines, image.size)
|
||||
page_image.save(os.path.join(result_path, f"{name}_{idx}_text.png"))
|
||||
|
||||
with open(os.path.join(result_path, "results.json"), "w+") as f:
|
||||
json.dump(predictions_by_page, f)
|
||||
|
||||
print(f"Wrote results to {result_path}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
2
static/fonts/.gitignore
vendored
Normal file
2
static/fonts/.gitignore
vendored
Normal file
@ -0,0 +1,2 @@
|
||||
*
|
||||
!.gitignore
|
||||
@ -8,7 +8,7 @@ from transformers.activations import ACT2FN
|
||||
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_attention_mask
|
||||
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions, BaseModelOutputWithPastAndCrossAttentions
|
||||
from transformers.models.mbart.modeling_mbart import MBartPreTrainedModel, MBartDecoder, \
|
||||
MBartLearnedPositionalEmbedding, MBART_ATTENTION_CLASSES
|
||||
MBartLearnedPositionalEmbedding
|
||||
from surya.model.recognition.config import MBartMoEConfig
|
||||
import torch
|
||||
import math
|
||||
@ -84,6 +84,195 @@ class MBartExpertLayer(nn.Module):
|
||||
return final_hidden_states
|
||||
|
||||
|
||||
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
||||
"""
|
||||
From llama
|
||||
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
|
||||
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
|
||||
"""
|
||||
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
||||
if n_rep == 1:
|
||||
return hidden_states
|
||||
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
|
||||
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
||||
|
||||
|
||||
class MBartGQAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
embed_dim: int,
|
||||
num_heads: int,
|
||||
num_kv_heads: int,
|
||||
dropout: float = 0.0,
|
||||
is_decoder: bool = False,
|
||||
bias: bool = True,
|
||||
is_causal: bool = False,
|
||||
config: Optional[MBartConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.embed_dim = embed_dim
|
||||
self.num_heads = num_heads
|
||||
self.num_kv_heads = num_kv_heads
|
||||
self.num_kv_groups = self.num_heads // self.num_kv_heads
|
||||
|
||||
assert self.num_heads % self.num_kv_heads == 0, f"num_heads ({self.num_heads}) must be divisible by num_kv_heads ({self.num_kv_heads})"
|
||||
assert embed_dim % self.num_kv_heads == 0, f"embed_dim ({self.embed_dim}) must be divisible by num_kv_heads ({self.num_kv_heads})"
|
||||
|
||||
self.dropout = dropout
|
||||
self.head_dim = embed_dim // num_heads
|
||||
self.config = config
|
||||
|
||||
if (self.head_dim * num_heads) != self.embed_dim:
|
||||
raise ValueError(
|
||||
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
|
||||
f" and `num_heads`: {num_heads})."
|
||||
)
|
||||
self.scaling = self.head_dim**-0.5
|
||||
self.is_decoder = is_decoder
|
||||
self.is_causal = is_causal
|
||||
|
||||
self.k_proj = nn.Linear(embed_dim, self.num_kv_heads * self.head_dim, bias=bias)
|
||||
self.v_proj = nn.Linear(embed_dim, self.num_kv_heads * self.head_dim, bias=bias)
|
||||
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
||||
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
||||
|
||||
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
||||
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
|
||||
|
||||
def _shape_key_value(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
||||
return tensor.view(bsz, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2).contiguous()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
key_value_states: Optional[torch.Tensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
layer_head_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: bool = False,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
"""Input shape: Batch x Time x Channel"""
|
||||
|
||||
# if key_value_states are provided this layer is used as a cross-attention layer
|
||||
# for the decoder
|
||||
is_cross_attention = key_value_states is not None
|
||||
|
||||
bsz, tgt_len, _ = hidden_states.size()
|
||||
|
||||
# get query proj
|
||||
query_states = self.q_proj(hidden_states) * self.scaling
|
||||
# get key, value proj
|
||||
# `past_key_value[0].shape[2] == key_value_states.shape[1]`
|
||||
# is checking that the `sequence_length` of the `past_key_value` is the same as
|
||||
# the provided `key_value_states` to support prefix tuning
|
||||
if (
|
||||
is_cross_attention
|
||||
and past_key_value is not None
|
||||
and past_key_value[0].shape[2] == key_value_states.shape[1]
|
||||
):
|
||||
# reuse k,v, cross_attentions
|
||||
key_states = past_key_value[0]
|
||||
value_states = past_key_value[1]
|
||||
elif is_cross_attention:
|
||||
# cross_attentions
|
||||
key_states = self._shape_key_value(self.k_proj(key_value_states), -1, bsz)
|
||||
value_states = self._shape_key_value(self.v_proj(key_value_states), -1, bsz)
|
||||
elif past_key_value is not None:
|
||||
# reuse k, v, self_attention
|
||||
key_states = self._shape_key_value(self.k_proj(hidden_states), -1, bsz)
|
||||
value_states = self._shape_key_value(self.v_proj(hidden_states), -1, bsz)
|
||||
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
||||
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
||||
else:
|
||||
# self_attention
|
||||
key_states = self._shape_key_value(self.k_proj(hidden_states), -1, bsz)
|
||||
value_states = self._shape_key_value(self.v_proj(hidden_states), -1, bsz)
|
||||
|
||||
if self.is_decoder:
|
||||
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
|
||||
# Further calls to cross_attention layer can then reuse all cross-attention
|
||||
# key/value_states (first "if" case)
|
||||
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
|
||||
# all previous decoder key/value_states. Further calls to uni-directional self-attention
|
||||
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
|
||||
# if encoder bi-directional self-attention `past_key_value` is always `None`
|
||||
past_key_value = (key_states, value_states)
|
||||
|
||||
proj_shape = (bsz * self.num_heads, -1, self.head_dim)
|
||||
query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
|
||||
|
||||
# Expand kv heads, then match query shape
|
||||
key_states = repeat_kv(key_states, self.num_kv_groups)
|
||||
value_states = repeat_kv(value_states, self.num_kv_groups)
|
||||
key_states = key_states.reshape(*proj_shape)
|
||||
value_states = value_states.reshape(*proj_shape)
|
||||
|
||||
src_len = key_states.size(1)
|
||||
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
|
||||
|
||||
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
|
||||
raise ValueError(
|
||||
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
|
||||
f" {attn_weights.size()}"
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
if attention_mask.size() != (bsz, 1, tgt_len, src_len):
|
||||
raise ValueError(
|
||||
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
|
||||
)
|
||||
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
|
||||
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
||||
|
||||
if layer_head_mask is not None:
|
||||
if layer_head_mask.size() != (self.num_heads,):
|
||||
raise ValueError(
|
||||
f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
|
||||
f" {layer_head_mask.size()}"
|
||||
)
|
||||
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
||||
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
||||
|
||||
if output_attentions:
|
||||
# this operation is a bit awkward, but it's required to
|
||||
# make sure that attn_weights keeps its gradient.
|
||||
# In order to do so, attn_weights have to be reshaped
|
||||
# twice and have to be reused in the following
|
||||
attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
||||
attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
|
||||
else:
|
||||
attn_weights_reshaped = None
|
||||
|
||||
attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
|
||||
|
||||
attn_output = torch.bmm(attn_probs, value_states)
|
||||
|
||||
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
|
||||
raise ValueError(
|
||||
f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
|
||||
f" {attn_output.size()}"
|
||||
)
|
||||
|
||||
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
|
||||
attn_output = attn_output.transpose(1, 2)
|
||||
|
||||
# Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
|
||||
# partitioned across GPUs when using tensor-parallelism.
|
||||
attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
|
||||
|
||||
attn_output = self.out_proj(attn_output)
|
||||
|
||||
return attn_output, attn_weights_reshaped, past_key_value
|
||||
|
||||
|
||||
MBART_ATTENTION_CLASSES = {
|
||||
"eager": MBartGQAttention,
|
||||
"flash_attention_2": None
|
||||
}
|
||||
|
||||
|
||||
class MBartMoEDecoderLayer(nn.Module):
|
||||
def __init__(self, config: MBartConfig, has_moe=False):
|
||||
super().__init__()
|
||||
@ -92,6 +281,7 @@ class MBartMoEDecoderLayer(nn.Module):
|
||||
self.self_attn = MBART_ATTENTION_CLASSES[config._attn_implementation](
|
||||
embed_dim=self.embed_dim,
|
||||
num_heads=config.decoder_attention_heads,
|
||||
num_kv_heads=config.kv_heads,
|
||||
dropout=config.attention_dropout,
|
||||
is_decoder=True,
|
||||
is_causal=True,
|
||||
@ -105,6 +295,7 @@ class MBartMoEDecoderLayer(nn.Module):
|
||||
self.encoder_attn = MBART_ATTENTION_CLASSES[config._attn_implementation](
|
||||
self.embed_dim,
|
||||
config.decoder_attention_heads,
|
||||
num_kv_heads=config.kv_heads,
|
||||
dropout=config.attention_dropout,
|
||||
is_decoder=True,
|
||||
config=config,
|
||||
@ -498,4 +689,15 @@ class MBartMoE(MBartForCausalLM):
|
||||
"past_key_values": past_key_values,
|
||||
"use_cache": use_cache,
|
||||
"langs": langs
|
||||
}
|
||||
}
|
||||
|
||||
def prune_moe_experts(self, keep_keys: List[int]):
|
||||
keep_keys = [str(key) for key in keep_keys]
|
||||
for layer in self.model.decoder.layers:
|
||||
if not layer.has_moe:
|
||||
continue
|
||||
|
||||
lang_keys = list(layer.moe.experts.keys())
|
||||
for lang in lang_keys:
|
||||
if lang not in keep_keys:
|
||||
layer.moe.experts.pop(lang)
|
||||
@ -1,3 +1,4 @@
|
||||
from typing import List, Optional
|
||||
from transformers import VisionEncoderDecoderModel, VisionEncoderDecoderConfig, AutoModel, AutoModelForCausalLM
|
||||
from surya.model.recognition.config import MBartMoEConfig, VariableDonutSwinConfig
|
||||
from surya.model.recognition.encoder import VariableDonutSwinModel
|
||||
@ -5,7 +6,7 @@ from surya.model.recognition.decoder import MBartMoE
|
||||
from surya.settings import settings
|
||||
|
||||
|
||||
def load_model(checkpoint=settings.RECOGNITION_MODEL_CHECKPOINT, device=settings.TORCH_DEVICE_MODEL, dtype=settings.MODEL_DTYPE):
|
||||
def load_model(checkpoint=settings.RECOGNITION_MODEL_CHECKPOINT, device=settings.TORCH_DEVICE_MODEL, dtype=settings.MODEL_DTYPE, langs: Optional[List[int]] = None):
|
||||
config = VisionEncoderDecoderConfig.from_pretrained(checkpoint)
|
||||
|
||||
decoder_config = vars(config.decoder)
|
||||
@ -25,6 +26,10 @@ def load_model(checkpoint=settings.RECOGNITION_MODEL_CHECKPOINT, device=settings
|
||||
assert isinstance(model.decoder, MBartMoE)
|
||||
assert isinstance(model.encoder, VariableDonutSwinModel)
|
||||
|
||||
# Prune moe experts that are not needed
|
||||
if langs is not None:
|
||||
model.decoder.prune_moe_experts(langs)
|
||||
|
||||
model = model.to(device)
|
||||
model = model.eval()
|
||||
print(f"Loading recognition model {checkpoint} on device {device} with dtype {dtype}")
|
||||
|
||||
@ -16,7 +16,7 @@ from surya.settings import settings
|
||||
def load_processor():
|
||||
processor = SuryaProcessor()
|
||||
processor.image_processor.train = False
|
||||
processor.image_processor.max_size = {"height": 180, "width": 900}
|
||||
processor.image_processor.max_size = settings.RECOGNITION_IMAGE_SIZE
|
||||
processor.tokenizer.model_max_length = settings.RECOGNITION_MAX_TOKENS
|
||||
return processor
|
||||
|
||||
|
||||
35
surya/postprocessing/text.py
Normal file
35
surya/postprocessing/text.py
Normal file
@ -0,0 +1,35 @@
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
from surya.settings import settings
|
||||
|
||||
|
||||
def get_text_size(text, font):
|
||||
im = Image.new(mode="P", size=(0, 0))
|
||||
draw = ImageDraw.Draw(im)
|
||||
_, _, width, height = draw.textbbox((0, 0), text=text, font=font)
|
||||
return width, height
|
||||
|
||||
|
||||
def draw_text_on_image(bboxes, texts, image_size=(1024, 1024), font_path=settings.RECOGNITION_RENDER_FONT, font_size=14):
|
||||
image = Image.new('RGB', image_size, color='white')
|
||||
draw = ImageDraw.Draw(image)
|
||||
|
||||
for bbox, text in zip(bboxes, texts):
|
||||
bbox_width = bbox[2] - bbox[0]
|
||||
bbox_height = bbox[3] - bbox[1]
|
||||
|
||||
# Shrink the text to fit in the bbox if needed
|
||||
font = ImageFont.truetype(font_path, font_size)
|
||||
text_width, text_height = get_text_size(text, font)
|
||||
while (text_width > bbox_width or text_height > bbox_height) and font_size > 6:
|
||||
font_size -= 1
|
||||
font = ImageFont.truetype(font_path, font_size)
|
||||
text_width, text_height = get_text_size(text, font)
|
||||
|
||||
# Calculate text position (centered in bbox)
|
||||
text_width, text_height = get_text_size(text, font)
|
||||
x = bbox[0] + (bbox_width - text_width) / 2
|
||||
y = bbox[1] + (bbox_height - text_height) / 2
|
||||
|
||||
draw.text((x, y), text, fill="black", font=font)
|
||||
|
||||
return image
|
||||
@ -3,6 +3,7 @@ import torch
|
||||
from PIL import Image
|
||||
from surya.settings import settings
|
||||
from tqdm import tqdm
|
||||
import numpy as np
|
||||
|
||||
|
||||
def batch_recognition(images: List, languages: List[List[str]], model, processor):
|
||||
@ -17,9 +18,9 @@ def batch_recognition(images: List, languages: List[List[str]], model, processor
|
||||
batch_pixel_values = model_inputs["pixel_values"][i:i+settings.RECOGNITION_BATCH_SIZE]
|
||||
batch_decoder_input = [[model.config.decoder_start_token_id] + lang for lang in batch_langs]
|
||||
|
||||
batch_langs = torch.tensor(batch_langs, dtype=torch.long).to(model.device)
|
||||
batch_pixel_values = torch.tensor(batch_pixel_values, dtype=model.dtype).to(model.device)
|
||||
batch_decoder_input = torch.tensor(batch_decoder_input, dtype=torch.long).to(model.device)
|
||||
batch_langs = torch.from_numpy(np.array(batch_langs, dtype=np.int64)).to(model.device)
|
||||
batch_pixel_values = torch.tensor(np.array(batch_pixel_values), dtype=model.dtype).to(model.device)
|
||||
batch_decoder_input = torch.from_numpy(np.array(batch_decoder_input, dtype=np.int64)).to(model.device)
|
||||
|
||||
with torch.inference_mode():
|
||||
generated_ids = model.generate(
|
||||
|
||||
@ -12,6 +12,12 @@ class Settings(BaseSettings):
|
||||
TORCH_DEVICE: Optional[str] = None
|
||||
IMAGE_DPI: int = 96
|
||||
|
||||
# Paths
|
||||
DATA_DIR: str = "data"
|
||||
RESULT_DIR: str = "results"
|
||||
BASE_DIR: str = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
FONT_DIR: str = os.path.join(BASE_DIR, "static", "fonts")
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def TORCH_DEVICE_MODEL(self) -> str:
|
||||
@ -47,13 +53,11 @@ class Settings(BaseSettings):
|
||||
DETECTOR_NMS_THRESHOLD: float = 0.35 # Threshold for non-maximum suppression
|
||||
|
||||
# Text recognition
|
||||
RECOGNITION_MODEL_CHECKPOINT: str = "vikp/rec_test"
|
||||
RECOGNITION_MODEL_CHECKPOINT: str = "vikp/rec_test_gqa"
|
||||
RECOGNITION_MAX_TOKENS: int = 512
|
||||
RECOGNITION_BATCH_SIZE: int = 128 if TORCH_DEVICE_MODEL == "cuda" else 8
|
||||
|
||||
# Paths
|
||||
DATA_DIR: str = "data"
|
||||
RESULT_DIR: str = "results"
|
||||
RECOGNITION_BATCH_SIZE: int = 8 if TORCH_DEVICE_MODEL in ["cpu", "mps"] else 128
|
||||
RECOGNITION_IMAGE_SIZE: Dict = {"height": 196, "width": 896}
|
||||
RECOGNITION_RENDER_FONT: str = os.path.join(FONT_DIR, "GoNotoKurrent-Regular.ttf")
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
|
||||
Loading…
Reference in New Issue
Block a user