Fix colors

This commit is contained in:
Vik Paruchuri 2024-05-23 16:09:23 -07:00
parent 06a9a8b5a2
commit 28e95a20a6
4 changed files with 6 additions and 3 deletions

View File

@ -81,7 +81,7 @@ def slice_bboxes_from_image(image: Image.Image, bboxes):
def slice_polys_from_image(image: Image.Image, polys):
image_array = np.array(image)
image_array = np.array(image, dtype=np.uint8)
lines = []
for idx, poly in enumerate(polys):
lines.append(slice_and_pad_poly(image_array, poly))
@ -98,8 +98,9 @@ def slice_and_pad_poly(image_array: np.array, coordinates):
coordinates = [(x - bbox[0], y - bbox[1]) for x, y in coordinates]
# Pad the area outside the polygon with the pad value
mask = np.zeros_like(cropped_polygon, dtype=np.uint8)
mask = np.zeros(cropped_polygon.shape[:2], dtype=np.uint8)
cv2.fillPoly(mask, [np.int32(coordinates)], 1)
mask = np.stack([mask] * 3, axis=-1)
cropped_polygon[mask == 0] = settings.RECOGNITION_PAD_VALUE
rectangle_image = Image.fromarray(cropped_polygon)

View File

@ -155,7 +155,7 @@ class SuryaImageProcessor(DonutImageProcessor):
data_format: Optional[Union[str, ChannelDimension]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> np.ndarray:
input_height, input_width = get_image_size(image, channel_dim=input_data_format)
input_height, input_width = image.shape[:2]
output_height, output_width = size["height"], size["width"]
if (output_width < output_height and input_width > input_height) or (

View File

@ -62,6 +62,7 @@ def run_ocr(images: List[Image.Image], langs: List[List[str]], det_model, det_pr
all_slices = []
slice_map = []
all_langs = []
for idx, (det_pred, image, lang) in enumerate(zip(det_predictions, images, langs)):
polygons = [p.polygon for p in det_pred.bboxes]
slices = slice_polys_from_image(image, polygons)

View File

@ -37,6 +37,7 @@ def batch_recognition(images: List, languages: List[List[str]], model, processor
has_math = ["_math" in lang for lang in batch_langs]
batch_images = images[i:i+batch_size]
batch_images = [image.convert("RGB") for image in batch_images]
model_inputs = processor(text=[""] * len(batch_langs), images=batch_images, lang=batch_langs)
batch_pixel_values = model_inputs["pixel_values"]