mirror of
https://github.com/VikParuchuri/surya.git
synced 2026-06-04 21:03:53 +08:00
feat: new unified tokenizer
This commit is contained in:
parent
e1aa09d3bc
commit
d6f3515009
@ -1,12 +1,14 @@
|
||||
import html
|
||||
import re
|
||||
from typing import List, Union, Dict
|
||||
from typing import List, Union, Dict, Optional, Tuple, Iterable
|
||||
import numpy as np
|
||||
import torch
|
||||
from tokenizers import AddedToken
|
||||
|
||||
import json
|
||||
import os
|
||||
from transformers import PreTrainedTokenizer, Qwen2Tokenizer as Qwen2OriginalTokenizer
|
||||
|
||||
|
||||
from surya.common.s3 import S3DownloaderMixin
|
||||
from surya.common.surya.schema import TASK_NAMES, TaskNames
|
||||
from surya.logging import get_logger
|
||||
@ -245,6 +247,423 @@ class InnerOCRTokenizer:
|
||||
class Qwen2Tokenizer(S3DownloaderMixin, Qwen2OriginalTokenizer):
|
||||
pass
|
||||
|
||||
class GreedyMathUTF16Tokenizer(PreTrainedTokenizer):
|
||||
"""
|
||||
HuggingFace slow tokenizer implementing:
|
||||
- UTF-16 code units as the base [0..65535]
|
||||
- Math tokens as greedy-longest-match ids after UTF-16
|
||||
- Literal special tokens after math tokens
|
||||
Absolute ID layout:
|
||||
[0 .. 65535] : UTF-16 units
|
||||
[65536 .. 65536+M-1] : math tokens
|
||||
[65536+M .. 65536+M+S-1] : special tokens
|
||||
"""
|
||||
|
||||
vocab_files_names = {
|
||||
"vocab_file": "vocab_math.json", # {"\\frac": 0, "\\alpha": 1, ...} raw contiguous ids 0..M-1
|
||||
"specials_file": "specials.json", # [flat list for legacy]
|
||||
"specials_dict_file": "specials_dict.json", # category dict (preferred)
|
||||
}
|
||||
model_input_names = ["input_ids", "attention_mask"]
|
||||
is_fast = False
|
||||
|
||||
# ---------- helpers ----------
|
||||
@staticmethod
|
||||
def _to_utf16_units(s: str) -> List[int]:
|
||||
b = s.encode("utf-16le")
|
||||
return [int.from_bytes(b[i : i + 2], "little") for i in range(0, len(b), 2)]
|
||||
|
||||
@staticmethod
|
||||
def _from_utf16_units(units: List[int]) -> str:
|
||||
b = bytearray()
|
||||
for u in units:
|
||||
b += int(u).to_bytes(2, "little")
|
||||
return b.decode("utf-16le", errors="strict")
|
||||
|
||||
class _TrieNode:
|
||||
__slots__ = ("child", "id", "leaf")
|
||||
|
||||
def __init__(self):
|
||||
self.child: Dict[str, "GreedyMathUTF16Tokenizer._TrieNode"] = {}
|
||||
self.id: Optional[int] = None
|
||||
self.leaf: bool = False
|
||||
|
||||
@classmethod
|
||||
def _build_trie(
|
||||
cls, token_to_id: Dict[str, int]
|
||||
) -> "GreedyMathUTF16Tokenizer._TrieNode":
|
||||
root = cls._TrieNode()
|
||||
for tok, tid in token_to_id.items():
|
||||
node = root
|
||||
for ch in tok:
|
||||
node = node.child.setdefault(ch, cls._TrieNode())
|
||||
node.leaf = True
|
||||
node.id = tid
|
||||
return root
|
||||
|
||||
@classmethod
|
||||
def _encode_math_greedy(
|
||||
cls,
|
||||
s: str,
|
||||
trie: "GreedyMathUTF16Tokenizer._TrieNode",
|
||||
math_base: int,
|
||||
debug: bool = False,
|
||||
) -> List[int]:
|
||||
i, n = 0, len(s)
|
||||
out: List[int] = []
|
||||
while i < n:
|
||||
node = trie
|
||||
j = i
|
||||
last_id = None
|
||||
last_j = i
|
||||
while j < n and (ch := s[j]) in node.child:
|
||||
node = node.child[ch]
|
||||
j += 1
|
||||
if node.leaf:
|
||||
last_id, last_j = node.id, j
|
||||
if last_id is not None:
|
||||
if debug:
|
||||
print(f"[MATH] matched {s[i:last_j]!r} -> {last_id}")
|
||||
out.append(math_base + last_id)
|
||||
i = last_j
|
||||
else:
|
||||
units = cls._to_utf16_units(s[i])
|
||||
if debug:
|
||||
print(f"[MATH] fallback {s[i]!r} -> utf16 {units}")
|
||||
out.extend(units)
|
||||
i += 1
|
||||
return out
|
||||
|
||||
# ---------- init ----------
|
||||
def __init__(
|
||||
self,
|
||||
vocab_file: Optional[str] = None,
|
||||
specials_file: Optional[str] = None,
|
||||
specials_dict_file: Optional[str] = None,
|
||||
*,
|
||||
# You can also pass programmatically instead of files:
|
||||
math_vocab: Optional[Dict[str, int]] = None,
|
||||
special_tokens: Optional[List[str]] = None,
|
||||
special_tokens_dict: Optional[Dict[str, List[str]]] = None,
|
||||
debug: bool = False,
|
||||
# Standard HF special token kwargs:
|
||||
bos_token: Optional[str] = None,
|
||||
eos_token: Optional[str] = None,
|
||||
pad_token: Optional[str] = None,
|
||||
unk_token: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
# Load math vocab
|
||||
if vocab_file and os.path.isfile(vocab_file):
|
||||
with open(vocab_file, "r", encoding="utf-8") as f:
|
||||
mv = json.load(f)
|
||||
else:
|
||||
mv = math_vocab or {}
|
||||
|
||||
# Make math ids contiguous if needed
|
||||
if mv:
|
||||
max_id = max(mv.values())
|
||||
if set(mv.values()) != set(range(max_id + 1)):
|
||||
items = sorted(mv.items(), key=lambda kv: kv[1])
|
||||
mv = {tok: i for i, (tok, _) in enumerate(items)}
|
||||
|
||||
# Load special tokens (prefer category dict; fallback to flat list or defaults)
|
||||
sp_dict = None
|
||||
if specials_dict_file and os.path.isfile(specials_dict_file):
|
||||
with open(specials_dict_file, "r", encoding="utf-8") as f:
|
||||
sp_dict = json.load(f)
|
||||
elif special_tokens_dict is not None:
|
||||
sp_dict = dict(special_tokens_dict)
|
||||
|
||||
if sp_dict is None:
|
||||
# Legacy path: flat list from file or provided/default list
|
||||
if specials_file and os.path.isfile(specials_file):
|
||||
with open(specials_file, "r", encoding="utf-8") as f:
|
||||
sp_list_flat = json.load(f)
|
||||
else:
|
||||
sp_list_flat = special_tokens or SPECIAL_TOKENS
|
||||
sp_dict = {"all": list(sp_list_flat)}
|
||||
|
||||
# Ensure "all" exists and is unique/preserved in order.
|
||||
if "all" not in sp_dict or not isinstance(sp_dict["all"], list):
|
||||
order = [
|
||||
"system",
|
||||
"formatting",
|
||||
"math_external",
|
||||
"script",
|
||||
"layout",
|
||||
"reasoning",
|
||||
"table_structure",
|
||||
"reserved",
|
||||
]
|
||||
seen = set()
|
||||
all_tokens: List[str] = []
|
||||
for k in order:
|
||||
if k in sp_dict and isinstance(sp_dict[k], list):
|
||||
for t in sp_dict[k]:
|
||||
if t not in seen:
|
||||
all_tokens.append(t)
|
||||
seen.add(t)
|
||||
sp_dict["all"] = all_tokens
|
||||
|
||||
# Keep a copy of categories (if present) for downstream processor logic.
|
||||
self.special_tokens = sp_dict
|
||||
sp_list = list(sp_dict.get("all", []))
|
||||
# Regex list should favor longest-first to avoid partial matches.
|
||||
specials_for_regex = sorted(sp_list, key=len, reverse=True)
|
||||
|
||||
self.debug = debug
|
||||
self.UTF16_SPACE = 65536
|
||||
self.math_token_to_rawid = dict(mv) # 0..M-1
|
||||
self.math_vocab_size = len(self.math_token_to_rawid)
|
||||
self.MATH_BASE = self.UTF16_SPACE
|
||||
self.SPECIAL_BASE = self.UTF16_SPACE + self.math_vocab_size
|
||||
|
||||
# Maps
|
||||
self.math_absid_to_token = {
|
||||
self.MATH_BASE + rid: tok for tok, rid in self.math_token_to_rawid.items()
|
||||
}
|
||||
self.special_tokens_list = sp_list # ID assignment order
|
||||
self.special_to_absid = {
|
||||
tok: self.SPECIAL_BASE + i for i, tok in enumerate(self.special_tokens_list)
|
||||
}
|
||||
self.absid_to_special = {v: k for k, v in self.special_to_absid.items()}
|
||||
|
||||
# Public attributes for legacy/processor:
|
||||
# All specials mapping (token -> absolute id)
|
||||
self.SPECIAL_TOKEN_MAPPING: Dict[str, int] = dict(self.special_to_absid)
|
||||
# Subset used heavily by processor for quick access
|
||||
self.reverse_special_token_mapping = {
|
||||
v: k for k, v in self.SPECIAL_TOKEN_MAPPING.items()
|
||||
}
|
||||
self.LAYOUT_LABEL2ID = {
|
||||
k: v
|
||||
for k, v in self.SPECIAL_TOKEN_MAPPING.items()
|
||||
if k in self.special_tokens["layout"]
|
||||
}
|
||||
self.TABLE_STRUCTURE_LABEL2ID = {
|
||||
k: v
|
||||
for k, v in self.SPECIAL_TOKEN_MAPPING.items()
|
||||
if k in self.special_tokens["table_structure"]
|
||||
}
|
||||
if not self.special_tokens.get("system", []):
|
||||
print("Warning: No system tokens found in special_tokens")
|
||||
|
||||
self.MATH_TAG_START = "<math"
|
||||
self.MATH_END_TAG = "</math>"
|
||||
|
||||
sys_list = self.special_tokens.get("system", [])
|
||||
self.system_tokens: Dict[str, int] = {
|
||||
t: self.special_to_absid[t] for t in sys_list if t in self.special_to_absid
|
||||
}
|
||||
|
||||
# Regex for literal specials
|
||||
self.specials_pattern = (
|
||||
re.compile(r"(" + "|".join(re.escape(k) for k in specials_for_regex) + r")")
|
||||
if specials_for_regex
|
||||
else None
|
||||
)
|
||||
|
||||
# Trie for math greedy match
|
||||
self.trie = self._build_trie(self.math_token_to_rawid)
|
||||
|
||||
# Tell HF about special tokens (metadata)
|
||||
kwargs.setdefault("bos_token", bos_token)
|
||||
kwargs.setdefault("eos_token", eos_token or "</S>")
|
||||
kwargs.setdefault("pad_token", pad_token or "<PAD>")
|
||||
kwargs.setdefault("unk_token", unk_token)
|
||||
|
||||
super().__init__(
|
||||
vocab_file=vocab_file,
|
||||
specials_file=specials_file,
|
||||
specials_dict_file=specials_dict_file,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# ---------- required HF surface ----------
|
||||
@property
|
||||
def vocab_size(self) -> int:
|
||||
return self.UTF16_SPACE + self.math_vocab_size + len(self.special_tokens_list)
|
||||
|
||||
def get_vocab(self) -> Dict[str, int]:
|
||||
# Compact vocab: just math+specials with ABSOLUTE ids.
|
||||
v = {tok: self.MATH_BASE + rid for tok, rid in self.math_token_to_rawid.items()}
|
||||
v.update(self.special_to_absid)
|
||||
return v
|
||||
|
||||
def __len__(self) -> int:
|
||||
return self.vocab_size
|
||||
|
||||
# Core encode/decode on ABSOLUTE ids
|
||||
def _encode_core(self, text: str) -> List[int]:
|
||||
text = html.unescape(text)
|
||||
ids: List[int] = []
|
||||
in_math = False
|
||||
chunks = self.specials_pattern.split(text) if self.specials_pattern else [text]
|
||||
for chunk in chunks:
|
||||
if chunk in self.special_to_absid:
|
||||
ids.append(self.special_to_absid[chunk])
|
||||
if chunk.startswith("<math"):
|
||||
in_math = True
|
||||
elif chunk.startswith("</math>"):
|
||||
in_math = False
|
||||
if self.debug:
|
||||
print(f"[TAG] {chunk!r} -> {self.special_to_absid[chunk]}")
|
||||
continue
|
||||
|
||||
if in_math:
|
||||
ids.extend(
|
||||
self._encode_math_greedy(
|
||||
chunk, self.trie, self.MATH_BASE, debug=self.debug
|
||||
)
|
||||
)
|
||||
else:
|
||||
units = self._to_utf16_units(chunk)
|
||||
if self.debug and units:
|
||||
print(
|
||||
f"[TEXT] utf16 {chunk[:32]!r} -> {units[:8]}{'...' if len(units) > 8 else ''}"
|
||||
)
|
||||
ids.extend(units)
|
||||
return ids
|
||||
|
||||
def _decode_core(self, ids: Iterable[int]) -> str:
|
||||
out: List[str] = []
|
||||
buf: List[int] = []
|
||||
|
||||
def flush():
|
||||
if buf:
|
||||
out.append(self._from_utf16_units(buf))
|
||||
buf.clear()
|
||||
|
||||
for tid in ids:
|
||||
if tid >= self.MATH_BASE and tid < self.SPECIAL_BASE:
|
||||
flush()
|
||||
out.append(self.math_absid_to_token.get(tid, ""))
|
||||
elif tid >= self.SPECIAL_BASE:
|
||||
flush()
|
||||
out.append(self.absid_to_special.get(tid, ""))
|
||||
else:
|
||||
buf.append(int(tid))
|
||||
flush()
|
||||
return "".join(out)
|
||||
|
||||
# ---- Tokenizer interface ----
|
||||
def _tokenize(self, text: str, **kwargs) -> List[str]:
|
||||
ids = self._encode_core(text)
|
||||
toks: List[str] = []
|
||||
for i in ids:
|
||||
if i < self.MATH_BASE:
|
||||
toks.append(f"<U+{i:04X}>")
|
||||
elif i < self.SPECIAL_BASE:
|
||||
toks.append(self.math_absid_to_token.get(i, "<UNK_MATH>"))
|
||||
else:
|
||||
toks.append(self.absid_to_special.get(i, "<UNK_SPECIAL>"))
|
||||
return toks
|
||||
|
||||
def _convert_token_to_id(self, token: str) -> int:
|
||||
if token.startswith("<U+") and token.endswith(">"):
|
||||
try:
|
||||
return int(token[3:-1], 16) # UTF-16 unit
|
||||
except Exception:
|
||||
return self.unk_token_id if self.unk_token_id is not None else 0
|
||||
# math or specials
|
||||
if token in self.math_token_to_rawid:
|
||||
return self.MATH_BASE + self.math_token_to_rawid[token]
|
||||
if token in self.special_to_absid:
|
||||
return self.special_to_absid[token]
|
||||
# rare path: single-char token -> its UTF-16 unit
|
||||
if len(token) == 1:
|
||||
u = self._to_utf16_units(token)
|
||||
if len(u) == 1:
|
||||
return u[0]
|
||||
return self.unk_token_id if self.unk_token_id is not None else 0
|
||||
|
||||
def _convert_id_to_token(self, index: int) -> str:
|
||||
if index < self.MATH_BASE:
|
||||
return f"<U+{index:04X}>"
|
||||
if index < self.SPECIAL_BASE:
|
||||
return self.math_absid_to_token.get(index, "<UNK_MATH>")
|
||||
return self.absid_to_special.get(index, "<UNK_SPECIAL>")
|
||||
|
||||
def convert_tokens_to_string(self, tokens: List[str]) -> str:
|
||||
ids = [self._convert_token_to_id(t) for t in tokens]
|
||||
return self._decode_core(ids)
|
||||
|
||||
def decode(self, token_ids, skip_special_tokens: bool = False, **kwargs) -> str:
|
||||
# Accept int, list, tuple, numpy, torch
|
||||
if hasattr(token_ids, "tolist"):
|
||||
token_ids = token_ids.tolist()
|
||||
elif isinstance(token_ids, int):
|
||||
token_ids = [token_ids]
|
||||
else:
|
||||
token_ids = list(token_ids)
|
||||
token_ids = [int(i) for i in token_ids] # normalize early
|
||||
|
||||
if skip_special_tokens:
|
||||
token_ids = [i for i in token_ids if i < self.SPECIAL_BASE]
|
||||
return self._decode_core(token_ids)
|
||||
|
||||
# HF plumbing
|
||||
def build_inputs_with_special_tokens(
|
||||
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
|
||||
) -> List[int]:
|
||||
out = (
|
||||
list(token_ids_0)
|
||||
if token_ids_1 is None
|
||||
else list(token_ids_0) + list(token_ids_1)
|
||||
)
|
||||
# if self.eos_token_id is not None and (not out or out[-1] != self.eos_token_id):
|
||||
# out.append(self.eos_token_id)
|
||||
return out
|
||||
|
||||
def get_special_tokens_mask(
|
||||
self,
|
||||
token_ids_0: List[int],
|
||||
token_ids_1: Optional[List[int]] = None,
|
||||
already_has_special_tokens: bool = False,
|
||||
) -> List[int]:
|
||||
def mask(seq: List[int]) -> List[int]:
|
||||
return [1 if i >= self.SPECIAL_BASE else 0 for i in seq]
|
||||
|
||||
return (
|
||||
mask(token_ids_0)
|
||||
if token_ids_1 is None
|
||||
else mask(token_ids_0) + mask(token_ids_1)
|
||||
)
|
||||
|
||||
def create_token_type_ids_from_sequences(
|
||||
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
|
||||
) -> List[int]:
|
||||
return [0] * (
|
||||
len(token_ids_0)
|
||||
if token_ids_1 is None
|
||||
else len(token_ids_0) + len(token_ids_1)
|
||||
)
|
||||
|
||||
# Save/load raw assets
|
||||
def save_vocabulary(
|
||||
self, save_directory: str, filename_prefix: Optional[str] = None
|
||||
) -> Tuple[str, str]:
|
||||
os.makedirs(save_directory, exist_ok=True)
|
||||
pre = (filename_prefix + "-") if filename_prefix else ""
|
||||
vocab_path = os.path.join(
|
||||
save_directory, pre + self.vocab_files_names["vocab_file"]
|
||||
)
|
||||
specials_path = os.path.join(
|
||||
save_directory, pre + self.vocab_files_names["specials_file"]
|
||||
)
|
||||
specials_dict_path = os.path.join(
|
||||
save_directory, pre + self.vocab_files_names["specials_dict_file"]
|
||||
)
|
||||
with open(vocab_path, "w", encoding="utf-8") as f:
|
||||
json.dump(self.math_token_to_rawid, f, ensure_ascii=False, indent=2)
|
||||
# Save both the flat list ("all") and the category dict (preferred)
|
||||
with open(specials_path, "w", encoding="utf-8") as f:
|
||||
json.dump(self.special_tokens_list, f, ensure_ascii=False, indent=2)
|
||||
with open(specials_dict_path, "w", encoding="utf-8") as f:
|
||||
json.dump(self.special_tokens, f, ensure_ascii=False, indent=2)
|
||||
return (vocab_path, specials_path)
|
||||
|
||||
|
||||
class SuryaOCRTokenizer(S3DownloaderMixin, PreTrainedTokenizer):
|
||||
def __init__(
|
||||
@ -258,52 +677,38 @@ class SuryaOCRTokenizer(S3DownloaderMixin, PreTrainedTokenizer):
|
||||
|
||||
self.special_tokens = special_tokens
|
||||
|
||||
self.qwen_tokenizer = Qwen2Tokenizer.from_pretrained(model_checkpoint)
|
||||
self.ocr_tokenizer = InnerOCRTokenizer(
|
||||
special_tokens=special_tokens, qwen_tokenizer=self.qwen_tokenizer
|
||||
self.ocr_tokenizer = GreedyMathUTF16Tokenizer.from_pretrained(
|
||||
model_checkpoint,
|
||||
)
|
||||
|
||||
self.system_tokens = {
|
||||
v: self.ocr_tokenizer._tokenize(v)[0]
|
||||
v: self.ocr_tokenizer(v)["input_ids"][0]
|
||||
for v in special_tokens.get("system", [])
|
||||
}
|
||||
self.SPECIAL_TOKEN_MAPPING = self.ocr_tokenizer.SPECIAL_TOKEN_MAPPING
|
||||
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.qwen_offset = len(self.qwen_tokenizer)
|
||||
self.special_token_offset = (
|
||||
self.qwen_offset + self.ocr_tokenizer.SPECIAL_TOKEN_OFFSET
|
||||
)
|
||||
|
||||
def get_vocab(self) -> Dict[str, int]:
|
||||
return self.qwen_tokenizer.get_vocab()
|
||||
return self.ocr_tokenizer.get_vocab()
|
||||
|
||||
def _add_tokens(
|
||||
self,
|
||||
new_tokens: Union[List[str], List[AddedToken]],
|
||||
special_tokens: bool = False,
|
||||
) -> int:
|
||||
return self.qwen_tokenizer._add_tokens(
|
||||
return self.ocr_tokenizer._add_tokens(
|
||||
new_tokens, special_tokens=special_tokens
|
||||
)
|
||||
|
||||
@property
|
||||
def vocab_size(self):
|
||||
return self.ocr_tokenizer.vocab_size + self.qwen_offset
|
||||
return self.ocr_tokenizer.vocab_size
|
||||
|
||||
def _tokenize(self, text: str, **kwargs):
|
||||
task = kwargs.get("task", TaskNames.ocr_with_boxes)
|
||||
assert task in TASK_NAMES, f"Invalid task: {task}"
|
||||
# task = kwargs.get("task", TaskNames.ocr_with_boxes)
|
||||
# assert task in TASK_NAMES, f"Invalid task: {task}"
|
||||
|
||||
if task in [
|
||||
TaskNames.ocr_with_boxes,
|
||||
TaskNames.ocr_without_boxes,
|
||||
TaskNames.layout,
|
||||
]:
|
||||
tokens = self.ocr_tokenizer._tokenize(text)
|
||||
else:
|
||||
tokens = self.qwen_tokenizer(text)["input_ids"]
|
||||
tokens = self.ocr_tokenizer(text)["input_ids"]
|
||||
|
||||
return tokens
|
||||
|
||||
@ -331,20 +736,12 @@ class SuryaOCRTokenizer(S3DownloaderMixin, PreTrainedTokenizer):
|
||||
return {"input_ids": tokenized}
|
||||
|
||||
def decode(self, token_ids, **kwargs):
|
||||
task_name = kwargs.get("task")
|
||||
assert task_name in TASK_NAMES, f"Invalid task: {task_name}"
|
||||
|
||||
if isinstance(token_ids, (np.ndarray, torch.Tensor)):
|
||||
token_ids = token_ids.tolist()
|
||||
|
||||
if task_name in [
|
||||
TaskNames.ocr_with_boxes,
|
||||
TaskNames.ocr_without_boxes,
|
||||
TaskNames.layout,
|
||||
TaskNames.table_structure,
|
||||
]:
|
||||
decoded_text = self.ocr_tokenizer.decode(token_ids)
|
||||
else:
|
||||
decoded_text = self.qwen_tokenizer.decode(token_ids)
|
||||
|
||||
decoded_text = self.ocr_tokenizer.decode(token_ids, skip_special_tokens=False)
|
||||
# replace all <SCRIPT-...> tokens with empty strings
|
||||
decoded_text = re.sub(r"<SCRIPT-.*?>", "", decoded_text)
|
||||
# replace </S> with empty string
|
||||
decoded_text = re.sub(r"</S>", "", decoded_text)
|
||||
return decoded_text
|
||||
|
||||
@ -219,11 +219,8 @@ class RecognitionPredictor(BasePredictor):
|
||||
|
||||
detokenize_sequences = []
|
||||
detokenize_sequence = []
|
||||
past_char_qwen_token = False
|
||||
|
||||
def _add_detokenize_sequence(
|
||||
qwen_token: bool,
|
||||
past_char_qwen_token: bool,
|
||||
special_token: bool,
|
||||
past_special_token: bool,
|
||||
force: bool = False,
|
||||
@ -231,18 +228,15 @@ class RecognitionPredictor(BasePredictor):
|
||||
nonlocal detokenize_sequence, detokenize_sequences
|
||||
|
||||
if (
|
||||
qwen_token != past_char_qwen_token
|
||||
or force
|
||||
or special_token
|
||||
special_token
|
||||
or past_special_token
|
||||
or force
|
||||
) and detokenize_sequence:
|
||||
chars = [dt[0] for dt in detokenize_sequence]
|
||||
scores = [dt[1] for dt in detokenize_sequence]
|
||||
bboxes = [dt[2] for dt in detokenize_sequence]
|
||||
|
||||
if past_char_qwen_token:
|
||||
detokenize_sequences.append((chars, scores, None, "qwen"))
|
||||
elif past_special_token:
|
||||
if past_special_token:
|
||||
detokenize_sequences.append((chars, scores, None, "special"))
|
||||
else:
|
||||
detokenize_sequences.append((chars, scores, bboxes, "ocr"))
|
||||
@ -258,21 +252,17 @@ class RecognitionPredictor(BasePredictor):
|
||||
]:
|
||||
break
|
||||
|
||||
qwen_token = char_id < self.processor.ocr_tokenizer.qwen_offset
|
||||
special_token = (
|
||||
self.processor.ocr_tokenizer.qwen_offset
|
||||
<= char_id
|
||||
< self.processor.ocr_tokenizer.special_token_offset
|
||||
char_id >= self.processor.ocr_tokenizer.ocr_tokenizer.SPECIAL_BASE
|
||||
)
|
||||
_add_detokenize_sequence(
|
||||
qwen_token, past_char_qwen_token, special_token, past_special_token
|
||||
special_token, past_special_token
|
||||
)
|
||||
detokenize_sequence.append((char_id, score, bbox))
|
||||
past_char_qwen_token = qwen_token
|
||||
past_special_token = special_token
|
||||
|
||||
_add_detokenize_sequence(
|
||||
False, past_char_qwen_token, False, past_special_token, force=True
|
||||
False, past_special_token, force=True
|
||||
)
|
||||
|
||||
img_chars = []
|
||||
|
||||
Loading…
Reference in New Issue
Block a user