diff --git a/benchmark/ordering.py b/benchmark/ordering.py
index fd301ee..d03d025 100644
--- a/benchmark/ordering.py
+++ b/benchmark/ordering.py
@@ -3,16 +3,13 @@
 import copy
 import json
 
-from surya.benchmark.metrics import precision_recall
 from surya.model.ordering.model import load_model
 from surya.model.ordering.processor import load_processor
-from surya.postprocessing.heatmap import draw_bboxes_on_image
 from surya.ordering import batch_ordering
 from surya.settings import settings
 from surya.benchmark.metrics import rank_accuracy
 import os
 import time
-from tabulate import tabulate
 
 import datasets
 
diff --git a/pyproject.toml b/pyproject.toml
index 8c76bf9..8dc8d57 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "surya-ocr"
-version = "0.4.7"
+version = "0.4.8"
 description = "OCR, layout, reading order, and line detection in 90+ languages"
 authors = ["Vik Paruchuri "]
 readme = "README.md"
diff --git a/surya/detection.py b/surya/detection.py
index 90c3fcb..be16689 100644
--- a/surya/detection.py
+++ b/surya/detection.py
@@ -7,7 +7,7 @@
 from surya.model.detection.segformer import SegformerForRegressionMask
 from surya.postprocessing.heatmap import get_and_clean_boxes
 from surya.postprocessing.affinity import get_vertical_lines
-from surya.input.processing import prepare_image, split_image, get_total_splits
+from surya.input.processing import prepare_image_detection, split_image, get_total_splits
 from surya.schema import TextDetectionResult
 from surya.settings import settings
 from tqdm import tqdm
@@ -62,7 +62,7 @@ def batch_detection(images: List, model: SegformerForRegressionMask, processor,
         split_index.extend([image_idx] * len(image_parts))
         split_heights.extend(split_height)
 
-    image_splits = [prepare_image(image, processor) for image in image_splits]
+    image_splits = [prepare_image_detection(image, processor) for image in image_splits]
 
     # Batch images in dim 0
     batch = torch.stack(image_splits, dim=0).to(model.dtype).to(model.device)
diff --git a/surya/input/processing.py b/surya/input/processing.py
index feb6021..17ce4ab 100644
--- a/surya/input/processing.py
+++ b/surya/input/processing.py
@@ -45,7 +45,7 @@ def split_image(img, processor):
     return [img.copy()], [img_height]
 
 
-def prepare_image(img, processor):
+def prepare_image_detection(img, processor):
     new_size = (processor.size["width"], processor.size["height"])
     img.thumbnail(new_size, Image.Resampling.LANCZOS) # Shrink largest dimension to fit new size
 
diff --git a/surya/layout.py b/surya/layout.py
index 68ed61b..89f2a65 100644
--- a/surya/layout.py
+++ b/surya/layout.py
@@ -10,7 +10,7 @@
 from surya.settings import settings
 
 
-def get_regions_from_detection_result(detection_result: TextDetectionResult, heatmaps: List[Image.Image], orig_size, id2label, segment_assignment, vertical_line_width=20) -> List[LayoutBox]:
+def get_regions_from_detection_result(detection_result: TextDetectionResult, heatmaps: List[np.ndarray], orig_size, id2label, segment_assignment, vertical_line_width=20) -> List[LayoutBox]:
     logits = np.stack(heatmaps, axis=0)
     vertical_line_bboxes = [line for line in detection_result.vertical_lines]
     line_bboxes = detection_result.bboxes
@@ -126,7 +126,7 @@ def get_regions_from_detection_result(detection_result: TextDetectionResult, hea
             new_boxes.append(LayoutBox(polygon=bbox.polygon, label="Text", confidence=.5))
 
     for bbox in new_boxes:
-        bbox.rescale(list(reversed(heatmap.shape)), orig_size)
+        bbox.rescale(list(reversed(heatmaps[0].shape)), orig_size)
 
     detected_boxes = [bbox for bbox in new_boxes if bbox.area > 16]
 
@@ -145,7 +145,7 @@ def get_regions_from_detection_result(detection_result: TextDetectionResult, hea
     return detected_boxes
 
 
-def get_regions(heatmaps: List[Image.Image], orig_size, id2label, segment_assignment) -> List[LayoutBox]:
+def get_regions(heatmaps: List[np.ndarray], orig_size, id2label, segment_assignment) -> List[LayoutBox]:
     bboxes = []
     for i in range(1, len(id2label)): # Skip the blank class
         heatmap = heatmaps[i]
@@ -160,7 +160,7 @@ def get_regions(heatmaps: List[Image.Image], orig_size, id2label, segment_assign
     return bboxes
 
 
-def parallel_get_regions(heatmaps: List[Image.Image], orig_size, id2label, detection_results=None) -> List[LayoutResult]:
+def parallel_get_regions(heatmaps: List[np.ndarray], orig_size, id2label, detection_results=None) -> LayoutResult:
     logits = np.stack(heatmaps, axis=0)
     segment_assignment = logits.argmax(axis=0)
     if detection_results is not None:
diff --git a/surya/model/detection/processor.py b/surya/model/detection/processor.py
new file mode 100644
index 0000000..822d7d1
--- /dev/null
+++ b/surya/model/detection/processor.py
@@ -0,0 +1,284 @@
+import warnings
+from typing import Any, Dict, List, Optional, Union
+
+import numpy as np
+
+from transformers.image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
+from transformers.image_transforms import to_channel_dimension_format
+from transformers.image_utils import (
+    IMAGENET_DEFAULT_MEAN,
+    IMAGENET_DEFAULT_STD,
+    ChannelDimension,
+    ImageInput,
+    PILImageResampling,
+    infer_channel_dimension_format,
+    make_list_of_images,
+)
+from transformers.utils import TensorType
+
+
+import PIL.Image
+import torch
+
+
+class SegformerImageProcessor(BaseImageProcessor):
+    r"""
+    Constructs a Segformer image processor.
+
+    Args:
+        do_resize (`bool`, *optional*, defaults to `True`):
+            Whether to resize the image's (height, width) dimensions to the specified `(size["height"],
+            size["width"])`. Can be overridden by the `do_resize` parameter in the `preprocess` method.
+        size (`Dict[str, int]` *optional*, defaults to `{"height": 512, "width": 512}`):
+            Size of the output image after resizing. Can be overridden by the `size` parameter in the `preprocess`
+            method.
+        resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`):
+            Resampling filter to use if resizing the image. Can be overridden by the `resample` parameter in the
+            `preprocess` method.
+        do_rescale (`bool`, *optional*, defaults to `True`):
+            Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the `do_rescale`
+            parameter in the `preprocess` method.
+        rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
+            Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
+            method.
+        do_normalize (`bool`, *optional*, defaults to `True`):
+            Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
+            method.
+        image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
+            Mean to use if normalizing the image. This is a float or list of floats the length of the number of
+            channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
+        image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
+            Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
+            number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
+        do_reduce_labels (`bool`, *optional*, defaults to `False`):
+            Whether or not to reduce all label values of segmentation maps by 1. Usually used for datasets where 0 is
+            used for background, and background itself is not included in all classes of a dataset (e.g. ADE20k). The
+            background label will be replaced by 255. Can be overridden by the `do_reduce_labels` parameter in the
+            `preprocess` method.
+    """
+
+    model_input_names = ["pixel_values"]
+
+    def __init__(
+        self,
+        do_resize: bool = True,
+        size: Dict[str, int] = None,
+        resample: PILImageResampling = PILImageResampling.BILINEAR,
+        do_rescale: bool = True,
+        rescale_factor: Union[int, float] = 1 / 255,
+        do_normalize: bool = True,
+        image_mean: Optional[Union[float, List[float]]] = None,
+        image_std: Optional[Union[float, List[float]]] = None,
+        do_reduce_labels: bool = False,
+        **kwargs,
+    ) -> None:
+        if "reduce_labels" in kwargs:
+            warnings.warn(
+                "The `reduce_labels` parameter is deprecated and will be removed in a future version. Please use "
+                "`do_reduce_labels` instead.",
+                FutureWarning,
+            )
+            do_reduce_labels = kwargs.pop("reduce_labels")
+
+        super().__init__(**kwargs)
+        size = size if size is not None else {"height": 512, "width": 512}
+        size = get_size_dict(size)
+        self.do_resize = do_resize
+        self.size = size
+        self.resample = resample
+        self.do_rescale = do_rescale
+        self.rescale_factor = rescale_factor
+        self.do_normalize = do_normalize
+        self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN
+        self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD
+        self.do_reduce_labels = do_reduce_labels
+        self._valid_processor_keys = [
+            "images",
+            "segmentation_maps",
+            "do_resize",
+            "size",
+            "resample",
+            "do_rescale",
+            "rescale_factor",
+            "do_normalize",
+            "image_mean",
+            "image_std",
+            "do_reduce_labels",
+            "return_tensors",
+            "data_format",
+            "input_data_format",
+        ]
+
+    @classmethod
+    def from_dict(cls, image_processor_dict: Dict[str, Any], **kwargs):
+        """
+        Overrides the `from_dict` method from the base class to make sure `do_reduce_labels` is updated if image
+        processor is created using from_dict and kwargs e.g. `SegformerImageProcessor.from_pretrained(checkpoint,
+        reduce_labels=True)`
+        """
+        image_processor_dict = image_processor_dict.copy()
+        if "reduce_labels" in kwargs:
+            image_processor_dict["reduce_labels"] = kwargs.pop("reduce_labels")
+        return super().from_dict(image_processor_dict, **kwargs)
+
+    def _preprocess(
+        self,
+        image: ImageInput,
+        do_resize: bool,
+        do_rescale: bool,
+        do_normalize: bool,
+        size: Optional[Dict[str, int]] = None,
+        resample: PILImageResampling = None,
+        rescale_factor: Optional[float] = None,
+        image_mean: Optional[Union[float, List[float]]] = None,
+        image_std: Optional[Union[float, List[float]]] = None,
+        input_data_format: Optional[Union[str, ChannelDimension]] = None,
+    ):
+
+        if do_rescale:
+            image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
+
+        if do_normalize:
+            image = self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
+
+        return image
+
+    def _preprocess_image(
+        self,
+        image: ImageInput,
+        do_resize: bool = None,
+        size: Dict[str, int] = None,
+        resample: PILImageResampling = None,
+        do_rescale: bool = None,
+        rescale_factor: float = None,
+        do_normalize: bool = None,
+        image_mean: Optional[Union[float, List[float]]] = None,
+        image_std: Optional[Union[float, List[float]]] = None,
+        data_format: Optional[Union[str, ChannelDimension]] = None,
+        input_data_format: Optional[Union[str, ChannelDimension]] = None,
+    ) -> np.ndarray:
+        """Preprocesses a single image."""
+        # All transformations expect numpy arrays.
+        if input_data_format is None:
+            input_data_format = infer_channel_dimension_format(image)
+
+        image = self._preprocess(
+            image=image,
+            do_resize=do_resize,
+            size=size,
+            resample=resample,
+            do_rescale=do_rescale,
+            rescale_factor=rescale_factor,
+            do_normalize=do_normalize,
+            image_mean=image_mean,
+            image_std=image_std,
+            input_data_format=input_data_format,
+        )
+        if data_format is not None:
+            image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
+        return image
+
+    def __call__(self, images, segmentation_maps=None, **kwargs):
+        """
+        Preprocesses a batch of images and optionally segmentation maps.
+
+        Overrides the `__call__` method of the `Preprocessor` class so that both images and segmentation maps can be
+        passed in as positional arguments.
+        """
+        return super().__call__(images, segmentation_maps=segmentation_maps, **kwargs)
+
+    def preprocess(
+        self,
+        images: ImageInput,
+        segmentation_maps: Optional[ImageInput] = None,
+        do_resize: Optional[bool] = None,
+        size: Optional[Dict[str, int]] = None,
+        resample: PILImageResampling = None,
+        do_rescale: Optional[bool] = None,
+        rescale_factor: Optional[float] = None,
+        do_normalize: Optional[bool] = None,
+        image_mean: Optional[Union[float, List[float]]] = None,
+        image_std: Optional[Union[float, List[float]]] = None,
+        do_reduce_labels: Optional[bool] = None,
+        return_tensors: Optional[Union[str, TensorType]] = None,
+        data_format: ChannelDimension = ChannelDimension.FIRST,
+        input_data_format: Optional[Union[str, ChannelDimension]] = None,
+        **kwargs,
+    ) -> PIL.Image.Image:
+        """
+        Preprocess an image or batch of images.
+
+        Args:
+            images (`ImageInput`):
+                Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
+                passing in images with pixel values between 0 and 1, set `do_rescale=False`.
+            segmentation_maps (`ImageInput`, *optional*):
+                Segmentation map to preprocess.
+            do_resize (`bool`, *optional*, defaults to `self.do_resize`):
+                Whether to resize the image.
+            size (`Dict[str, int]`, *optional*, defaults to `self.size`):
+                Size of the image after `resize` is applied.
+            resample (`int`, *optional*, defaults to `self.resample`):
+                Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`, Only
+                has an effect if `do_resize` is set to `True`.
+            do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
+                Whether to rescale the image values between [0 - 1].
+            rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
+                Rescale factor to rescale the image by if `do_rescale` is set to `True`.
+            do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
+                Whether to normalize the image.
+            image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
+                Image mean.
+            image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
+                Image standard deviation.
+            do_reduce_labels (`bool`, *optional*, defaults to `self.do_reduce_labels`):
+                Whether or not to reduce all label values of segmentation maps by 1. Usually used for datasets where 0
+                is used for background, and background itself is not included in all classes of a dataset (e.g.
+                ADE20k). The background label will be replaced by 255.
+            return_tensors (`str` or `TensorType`, *optional*):
+                The type of tensors to return. Can be one of:
+                    - Unset: Return a list of `np.ndarray`.
+                    - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
+                    - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
+                    - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
+                    - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
+            data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
+                The channel dimension format for the output image. Can be one of:
+                    - `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+                    - `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+            input_data_format (`ChannelDimension` or `str`, *optional*):
+                The channel dimension format for the input image. If unset, the channel dimension format is inferred
+                from the input image. Can be one of:
+                    - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+                    - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+                    - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
+        """
+        do_resize = do_resize if do_resize is not None else self.do_resize
+        do_rescale = do_rescale if do_rescale is not None else self.do_rescale
+        do_normalize = do_normalize if do_normalize is not None else self.do_normalize
+        resample = resample if resample is not None else self.resample
+        size = size if size is not None else self.size
+        rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
+        image_mean = image_mean if image_mean is not None else self.image_mean
+        image_std = image_std if image_std is not None else self.image_std
+
+        images = make_list_of_images(images)
+        images = [
+            self._preprocess_image(
+                image=img,
+                do_resize=do_resize,
+                resample=resample,
+                size=size,
+                do_rescale=do_rescale,
+                rescale_factor=rescale_factor,
+                do_normalize=do_normalize,
+                image_mean=image_mean,
+                image_std=image_std,
+                data_format=data_format,
+                input_data_format=input_data_format,
+            )
+            for img in images
+        ]
+
+        data = {"pixel_values": images}
+        return BatchFeature(data=data, tensor_type=return_tensors)
\ No newline at end of file
diff --git a/surya/model/detection/segformer.py b/surya/model/detection/segformer.py
index 87634a4..a3a822b 100644
--- a/surya/model/detection/segformer.py
+++ b/surya/model/detection/segformer.py
@@ -5,8 +5,8 @@
 import math
 from typing import Optional, Tuple, Union
 
-from transformers import SegformerConfig, SegformerForSemanticSegmentation, SegformerImageProcessor, \
-    SegformerDecodeHead, SegformerModel
+from transformers import SegformerConfig, SegformerForSemanticSegmentation, SegformerDecodeHead, SegformerModel
+from surya.model.detection.processor import SegformerImageProcessor
 
 import torch
 from torch import nn
diff --git a/surya/model/ordering/processor.py b/surya/model/ordering/processor.py
index 3262682..c6f463b 100644
--- a/surya/model/ordering/processor.py
+++ b/surya/model/ordering/processor.py
@@ -31,34 +31,25 @@ def __init__(self, *args, **kwargs):
 
         self.patch_size = kwargs.get("patch_size", (4, 4))
 
-    def process_inner(self, images: List[List]):
-        # This will be in list of lists format, with height x width x channel
-        assert isinstance(images[0], (list, np.ndarray))
+    def process_inner(self, images: List[np.ndarray]):
+        images = [img.transpose(2, 0, 1) for img in images] # convert to CHW format
 
-        # convert list of lists format to array
-        if isinstance(images[0], list):
-            # numpy unit8 needed for augmentation
-            np_images = [np.array(img, dtype=np.uint8) for img in images]
-        else:
-            np_images = [img.astype(np.uint8) for img in images]
-        np_images = [img.transpose(2, 0, 1) for img in np_images] # convert to CHW format
-
-        assert np_images[0].shape[0] == 3 # RGB input images, channel dim last
+        assert images[0].shape[0] == 3 # RGB input images, channel dim last
 
         # Convert to float32 for rescale/normalize
-        np_images = [img.astype(np.float32) for img in np_images]
+        images = [img.astype(np.float32) for img in images]
 
         # Rescale and normalize
-        np_images = [
+        images = [
            self.rescale(img, scale=self.rescale_factor, input_data_format=ChannelDimension.FIRST)
-            for img in np_images
+            for img in images
         ]
-        np_images = [
+        images = [
            self.normalize(img, mean=self.image_mean, std=self.image_std, input_data_format=ChannelDimension.FIRST)
-            for img in np_images
+            for img in images
         ]
 
-        return np_images
+        return images
 
     def process_boxes(self, boxes):
         padded_boxes = []
@@ -152,7 +143,7 @@ def preprocess(
         boxes = new_boxes
 
        # Convert to numpy for later processing steps
-        images = [to_numpy_array(image) for image in images]
+        images = [np.array(image) for image in images]
         images = self.process_inner(images)
         boxes, box_mask, box_counts = self.process_boxes(boxes)
 
diff --git a/surya/model/recognition/processor.py b/surya/model/recognition/processor.py
index 5d42ac3..645197a 100644
--- a/surya/model/recognition/processor.py
+++ b/surya/model/recognition/processor.py
@@ -1,5 +1,6 @@
 from typing import Dict, Union, Optional, List, Tuple
 
+import cv2
 from torch import TensorType
 from transformers import DonutImageProcessor, DonutProcessor, AutoImageProcessor, DonutSwinConfig
 from transformers.image_processing_utils import BaseImageProcessor, get_size_dict, BatchFeature
@@ -29,84 +30,64 @@ def __init__(self, *args, max_size=None, train=False, **kwargs):
         self.max_size = max_size
         self.train = train
 
-    def numpy_resize(self, image: np.ndarray, size, resample):
-        image = PIL.Image.fromarray(image)
-        resized = self.pil_resize(image, size, resample)
-        resized = np.array(resized, dtype=np.uint8)
-        resized_image = resized.transpose(2, 0, 1)
-
-        return resized_image
-
-    def pil_resize(self, image: PIL.Image.Image, size, resample):
-        width, height = image.size
+    def numpy_resize(self, image: np.ndarray, size, interpolation=cv2.INTER_LANCZOS4):
+        height, width = image.shape[:2]
         max_width, max_height = size["width"], size["height"]
-        if width != max_width or height != max_height:
-            # Shrink to fit within dimensions
-            width_scale = max_width / width
-            height_scale = max_height / height
-            scale = min(width_scale, height_scale)
-            new_width = min(int(width * scale), max_width)
-            new_height = min(int(height * scale), max_height)
-            image = image.resize((new_width, new_height), resample)
+        if (height == max_height and width <= max_width) or (width == max_width and height <= max_height):
+            return image
 
-        image.thumbnail((max_width, max_height), resample)
+        scale = min(max_width / width, max_height / height)
 
-        assert image.width <= max_width and image.height <= max_height
-
-        return image
+        new_width = int(width * scale)
+        new_height = int(height * scale)
 
-    def process_inner(self, images: List[List], train=False):
-        # This will be in list of lists format, with height x width x channel
-        assert isinstance(images[0], (list, np.ndarray))
+        resized_image = cv2.resize(image, (new_width, new_height), interpolation=interpolation)
+        resized_image = resized_image.transpose(2, 0, 1)
 
-        # convert list of lists format to array
-        if isinstance(images[0], list):
-            # numpy unit8 needed for augmentation
-            np_images = [np.array(img, dtype=np.uint8) for img in images]
-        else:
-            np_images = [img.astype(np.uint8) for img in images]
+        return resized_image
 
-        assert np_images[0].shape[2] == 3 # RGB input images, channel dim last
+    def process_inner(self, images: List[np.ndarray], train=False):
+        assert images[0].shape[2] == 3 # RGB input images, channel dim last
 
         # Rotate if the bbox is wider than it is tall
-        np_images = [self.align_long_axis(image, size=self.max_size, input_data_format=ChannelDimension.LAST) for image in np_images]
+        images = [self.align_long_axis(image, size=self.max_size, input_data_format=ChannelDimension.LAST) for image in images]
 
         # Verify that the image is wider than it is tall
-        for img in np_images:
+        for img in images:
             assert img.shape[1] >= img.shape[0]
 
         # This also applies the right channel dim format, to channel x height x width
-        np_images = [self.numpy_resize(img, self.max_size, self.resample) for img in np_images]
-        assert np_images[0].shape[0] == 3 # RGB input images, channel dim first
+        images = [self.numpy_resize(img, self.max_size, self.resample) for img in images]
+        assert images[0].shape[0] == 3 # RGB input images, channel dim first
 
         # Convert to float32 for rescale/normalize
-        np_images = [img.astype(np.float32) for img in np_images]
+        images = [img.astype(np.float32) for img in images]
 
         # Pads with 255 (whitespace)
         # Pad to max size to improve performance
         max_size = self.max_size
-        np_images = [
+        images = [
             self.pad_image(
                 image=image,
                 size=max_size,
                 random_padding=train, # Change amount of padding randomly during training
                 input_data_format=ChannelDimension.FIRST,
-                pad_value=255.0
+                pad_value=settings.RECOGNITION_PAD_VALUE
             )
-            for image in np_images
+            for image in images
         ]
 
         # Rescale and normalize
-        np_images = [
+        images = [
             self.rescale(img, scale=self.rescale_factor, input_data_format=ChannelDimension.FIRST)
-            for img in np_images
+            for img in images
         ]
-        np_images = [
+        images = [
             self.normalize(img, mean=self.image_mean, std=self.image_std, input_data_format=ChannelDimension.FIRST)
-            for img in np_images
+            for img in images
         ]
 
-        return np_images
+        return images
 
 
     def preprocess(
@@ -131,15 +112,8 @@
     ) -> PIL.Image.Image:
         images = make_list_of_images(images)
 
-        if not valid_images(images):
-            raise ValueError(
-                "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
-                "torch.Tensor, tf.Tensor or jax.ndarray."
-            )
-
         # Convert to numpy for later processing steps
-        images = [to_numpy_array(image) for image in images]
-
+        images = [np.array(img) for img in images]
         images = self.process_inner(images, train=self.train)
         data = {"pixel_values": images}
         return BatchFeature(data=data, tensor_type=return_tensors)
diff --git a/surya/ocr.py b/surya/ocr.py
index 3b072bc..0847762 100644
--- a/surya/ocr.py
+++ b/surya/ocr.py
@@ -1,17 +1,11 @@
-from collections import defaultdict
-from concurrent.futures import ProcessPoolExecutor
 from typing import List
-from tqdm import tqdm
-
-import torch
 from PIL import Image
 
 from surya.detection import batch_text_detection
 from surya.input.processing import slice_polys_from_image, slice_bboxes_from_image
-from surya.postprocessing.text import truncate_repetitions, sort_text_lines
+from surya.postprocessing.text import sort_text_lines
 from surya.recognition import batch_recognition
 from surya.schema import TextLine, OCRResult
-from surya.settings import settings
 
 
 def run_recognition(images: List[Image.Image], langs: List[List[str]], rec_model, rec_processor, bboxes: List[List[List[int]]] = None, polygons: List[List[List[List[int]]]] = None, batch_size=None) -> List[OCRResult]:
diff --git a/surya/postprocessing/heatmap.py b/surya/postprocessing/heatmap.py
index c850d18..7562826 100644
--- a/surya/postprocessing/heatmap.py
+++ b/surya/postprocessing/heatmap.py
@@ -91,6 +91,8 @@ def detect_boxes(linemap, text_threshold, low_text):
     det = []
     confidences = []
     max_confidence = 0
+    mask = np.zeros_like(linemap, dtype=np.uint8)
+
     for k in range(1, label_count):
         # size filtering
         size = stats[k, cv2.CC_STAT_AREA]
@@ -140,7 +142,7 @@ def detect_boxes(linemap, text_threshold, low_text):
         box = np.roll(box, 4-startidx, 0)
         box = np.array(box)
 
-        mask = np.zeros_like(linemap, dtype=np.uint8)
+        mask.fill(0)
         cv2.fillPoly(mask, [np.int32(box)], 1)
 
         roi = np.where(mask == 1, linemap, 0)
diff --git a/surya/recognition.py b/surya/recognition.py
index 57b5ca6..6122cb0 100644
--- a/surya/recognition.py
+++ b/surya/recognition.py
@@ -1,6 +1,7 @@
 from typing import List
 import torch
 from PIL import Image
+from transformers import GenerationConfig
 
 from surya.postprocessing.math.latex import fix_math, contains_math
 from surya.postprocessing.text import truncate_repetitions
@@ -52,6 +53,7 @@ def batch_recognition(images: List, languages: List[List[str]], model, processor
             decoder_input_ids=batch_decoder_input,
             decoder_langs=batch_langs,
             eos_token_id=processor.tokenizer.eos_id,
+            pad_token_id=processor.tokenizer.pad_token_id,
             max_new_tokens=settings.RECOGNITION_MAX_TOKENS,
             output_scores=True,
             return_dict_in_generate=True