mirror of
https://github.com/VikParuchuri/surya.git
synced 2026-06-04 21:03:53 +08:00
Fix colors
This commit is contained in:
parent
06a9a8b5a2
commit
28e95a20a6
@ -81,7 +81,7 @@ def slice_bboxes_from_image(image: Image.Image, bboxes):
|
||||
|
||||
|
||||
def slice_polys_from_image(image: Image.Image, polys):
|
||||
image_array = np.array(image)
|
||||
image_array = np.array(image, dtype=np.uint8)
|
||||
lines = []
|
||||
for idx, poly in enumerate(polys):
|
||||
lines.append(slice_and_pad_poly(image_array, poly))
|
||||
@ -98,8 +98,9 @@ def slice_and_pad_poly(image_array: np.array, coordinates):
|
||||
coordinates = [(x - bbox[0], y - bbox[1]) for x, y in coordinates]
|
||||
|
||||
# Pad the area outside the polygon with the pad value
|
||||
mask = np.zeros_like(cropped_polygon, dtype=np.uint8)
|
||||
mask = np.zeros(cropped_polygon.shape[:2], dtype=np.uint8)
|
||||
cv2.fillPoly(mask, [np.int32(coordinates)], 1)
|
||||
mask = np.stack([mask] * 3, axis=-1)
|
||||
|
||||
cropped_polygon[mask == 0] = settings.RECOGNITION_PAD_VALUE
|
||||
rectangle_image = Image.fromarray(cropped_polygon)
|
||||
|
||||
@ -155,7 +155,7 @@ class SuryaImageProcessor(DonutImageProcessor):
|
||||
data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
) -> np.ndarray:
|
||||
input_height, input_width = get_image_size(image, channel_dim=input_data_format)
|
||||
input_height, input_width = image.shape[:2]
|
||||
output_height, output_width = size["height"], size["width"]
|
||||
|
||||
if (output_width < output_height and input_width > input_height) or (
|
||||
|
||||
@ -62,6 +62,7 @@ def run_ocr(images: List[Image.Image], langs: List[List[str]], det_model, det_pr
|
||||
all_slices = []
|
||||
slice_map = []
|
||||
all_langs = []
|
||||
|
||||
for idx, (det_pred, image, lang) in enumerate(zip(det_predictions, images, langs)):
|
||||
polygons = [p.polygon for p in det_pred.bboxes]
|
||||
slices = slice_polys_from_image(image, polygons)
|
||||
|
||||
@ -37,6 +37,7 @@ def batch_recognition(images: List, languages: List[List[str]], model, processor
|
||||
has_math = ["_math" in lang for lang in batch_langs]
|
||||
batch_images = images[i:i+batch_size]
|
||||
batch_images = [image.convert("RGB") for image in batch_images]
|
||||
|
||||
model_inputs = processor(text=[""] * len(batch_langs), images=batch_images, lang=batch_langs)
|
||||
|
||||
batch_pixel_values = model_inputs["pixel_values"]
|
||||
|
||||
Loading…
Reference in New Issue
Block a user