Commit

Merge pull request #110 from VikParuchuri/dev
Processor improvements
VikParuchuri committed May 23, 2024
2 parents 74e8c0c + 06a9a8b commit 2a5e542
Showing 12 changed files with 337 additions and 93 deletions.
3 changes: 0 additions & 3 deletions benchmark/ordering.py
@@ -3,16 +3,13 @@
import copy
import json

-from surya.benchmark.metrics import precision_recall
from surya.model.ordering.model import load_model
from surya.model.ordering.processor import load_processor
-from surya.postprocessing.heatmap import draw_bboxes_on_image
from surya.ordering import batch_ordering
from surya.settings import settings
from surya.benchmark.metrics import rank_accuracy
import os
import time
-from tabulate import tabulate
import datasets


2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "surya-ocr"
-version = "0.4.7"
+version = "0.4.8"
description = "OCR, layout, reading order, and line detection in 90+ languages"
authors = ["Vik Paruchuri <vik.paruchuri@gmail.com>"]
readme = "README.md"
4 changes: 2 additions & 2 deletions surya/detection.py
@@ -7,7 +7,7 @@
from surya.model.detection.segformer import SegformerForRegressionMask
from surya.postprocessing.heatmap import get_and_clean_boxes
from surya.postprocessing.affinity import get_vertical_lines
-from surya.input.processing import prepare_image, split_image, get_total_splits
+from surya.input.processing import prepare_image_detection, split_image, get_total_splits
from surya.schema import TextDetectionResult
from surya.settings import settings
from tqdm import tqdm
@@ -62,7 +62,7 @@ def batch_detection(images: List, model: SegformerForRegressionMask, processor,
            split_index.extend([image_idx] * len(image_parts))
            split_heights.extend(split_height)

-        image_splits = [prepare_image(image, processor) for image in image_splits]
+        image_splits = [prepare_image_detection(image, processor) for image in image_splits]
        # Batch images in dim 0
        batch = torch.stack(image_splits, dim=0).to(model.dtype).to(model.device)

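The renamed helper feeds a plain torch.stack batching step. A minimal sketch of that idiom, assuming prepare_image_detection returns one CHW tensor per image split (the shapes, dtype, and device below are illustrative, not part of the diff):

import torch

# Stand-ins for prepare_image_detection output: one CHW tensor per image split.
image_splits = [torch.rand(3, 512, 512) for _ in range(4)]

# Batch in dim 0, then match the model's dtype and device, as batch_detection does.
batch = torch.stack(image_splits, dim=0)  # shape: (4, 3, 512, 512)
batch = batch.to(torch.float32).to("cuda" if torch.cuda.is_available() else "cpu")
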
2 changes: 1 addition & 1 deletion surya/input/processing.py
@@ -45,7 +45,7 @@ def split_image(img, processor):
    return [img.copy()], [img_height]


-def prepare_image(img, processor):
+def prepare_image_detection(img, processor):
    new_size = (processor.size["width"], processor.size["height"])

    img.thumbnail(new_size, Image.Resampling.LANCZOS)  # Shrink largest dimension to fit new size
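For context, Image.thumbnail resizes in place and preserves aspect ratio, so the prepared image fits within the processor's target size but may be smaller along one axis. A minimal sketch of that behavior, with a hypothetical 1024x1024 target:

from PIL import Image

img = Image.new("RGB", (2000, 1000))  # stand-in for a document page
new_size = (1024, 1024)  # processor.size["width"], processor.size["height"]
img.thumbnail(new_size, Image.Resampling.LANCZOS)  # shrinks the largest dimension to fit
print(img.size)  # (1024, 512): aspect ratio preserved
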
8 changes: 4 additions & 4 deletions surya/layout.py
@@ -10,7 +10,7 @@
from surya.settings import settings


-def get_regions_from_detection_result(detection_result: TextDetectionResult, heatmaps: List[Image.Image], orig_size, id2label, segment_assignment, vertical_line_width=20) -> List[LayoutBox]:
+def get_regions_from_detection_result(detection_result: TextDetectionResult, heatmaps: List[np.ndarray], orig_size, id2label, segment_assignment, vertical_line_width=20) -> List[LayoutBox]:
    logits = np.stack(heatmaps, axis=0)
    vertical_line_bboxes = [line for line in detection_result.vertical_lines]
    line_bboxes = detection_result.bboxes
@@ -126,7 +126,7 @@ def get_regions_from_detection_result(detection_result: TextDetectionResult, hea
            new_boxes.append(LayoutBox(polygon=bbox.polygon, label="Text", confidence=.5))

    for bbox in new_boxes:
-        bbox.rescale(list(reversed(heatmap.shape)), orig_size)
+        bbox.rescale(list(reversed(heatmaps[0].shape)), orig_size)

    detected_boxes = [bbox for bbox in new_boxes if bbox.area > 16]

@@ -145,7 +145,7 @@ def get_regions_from_detection_result(detection_result: TextDetectionResult, hea
    return detected_boxes


-def get_regions(heatmaps: List[Image.Image], orig_size, id2label, segment_assignment) -> List[LayoutBox]:
+def get_regions(heatmaps: List[np.ndarray], orig_size, id2label, segment_assignment) -> List[LayoutBox]:
    bboxes = []
    for i in range(1, len(id2label)):  # Skip the blank class
        heatmap = heatmaps[i]
@@ -160,7 +160,7 @@ def get_regions(heatmaps: List[Image.Image], orig_size, id2label, segment_assign
    return bboxes


-def parallel_get_regions(heatmaps: List[Image.Image], orig_size, id2label, detection_results=None) -> List[LayoutResult]:
+def parallel_get_regions(heatmaps: List[np.ndarray], orig_size, id2label, detection_results=None) -> LayoutResult:
    logits = np.stack(heatmaps, axis=0)
    segment_assignment = logits.argmax(axis=0)
    if detection_results is not None:
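The layout functions now take heatmaps as numpy arrays rather than PIL images, and the rescale call reads the size from heatmaps[0] instead of a leftover heatmap variable. A minimal sketch of the resulting contract (the shapes and class count are illustrative):

import numpy as np

# One (height, width) float heatmap per layout class.
heatmaps = [np.zeros((192, 256), dtype=np.float32) for _ in range(5)]

# numpy shapes are (height, width); bbox rescaling expects (width, height).
heatmap_size = list(reversed(heatmaps[0].shape))  # [256, 192]
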
284 changes: 284 additions & 0 deletions surya/model/detection/processor.py
@@ -0,0 +1,284 @@
import warnings
from typing import Any, Dict, List, Optional, Union

import numpy as np

from transformers.image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from transformers.image_transforms import to_channel_dimension_format
from transformers.image_utils import (
IMAGENET_DEFAULT_MEAN,
IMAGENET_DEFAULT_STD,
ChannelDimension,
ImageInput,
PILImageResampling,
infer_channel_dimension_format,
make_list_of_images,
)
from transformers.utils import TensorType


import PIL.Image
import torch


class SegformerImageProcessor(BaseImageProcessor):
    r"""
    Constructs a Segformer image processor.

    Args:
        do_resize (`bool`, *optional*, defaults to `True`):
            Whether to resize the image's (height, width) dimensions to the specified `(size["height"],
            size["width"])`. Can be overridden by the `do_resize` parameter in the `preprocess` method.
        size (`Dict[str, int]`, *optional*, defaults to `{"height": 512, "width": 512}`):
            Size of the output image after resizing. Can be overridden by the `size` parameter in the `preprocess`
            method.
        resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`):
            Resampling filter to use if resizing the image. Can be overridden by the `resample` parameter in the
            `preprocess` method.
        do_rescale (`bool`, *optional*, defaults to `True`):
            Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the `do_rescale`
            parameter in the `preprocess` method.
        rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
            Scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter in the
            `preprocess` method.
        do_normalize (`bool`, *optional*, defaults to `True`):
            Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
            method.
        image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_DEFAULT_MEAN`):
            Mean to use if normalizing the image. This is a float or list of floats the length of the number of
            channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
        image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_DEFAULT_STD`):
            Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
            number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
        do_reduce_labels (`bool`, *optional*, defaults to `False`):
            Whether or not to reduce all label values of segmentation maps by 1. Usually used for datasets where 0 is
            used for background, and background itself is not included in all classes of a dataset (e.g. ADE20k). The
            background label will be replaced by 255. Can be overridden by the `do_reduce_labels` parameter in the
            `preprocess` method.
    """

    model_input_names = ["pixel_values"]

    def __init__(
        self,
        do_resize: bool = True,
        size: Dict[str, int] = None,
        resample: PILImageResampling = PILImageResampling.BILINEAR,
        do_rescale: bool = True,
        rescale_factor: Union[int, float] = 1 / 255,
        do_normalize: bool = True,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        do_reduce_labels: bool = False,
        **kwargs,
    ) -> None:
        if "reduce_labels" in kwargs:
            warnings.warn(
                "The `reduce_labels` parameter is deprecated and will be removed in a future version. Please use "
                "`do_reduce_labels` instead.",
                FutureWarning,
            )
            do_reduce_labels = kwargs.pop("reduce_labels")

        super().__init__(**kwargs)
        size = size if size is not None else {"height": 512, "width": 512}
        size = get_size_dict(size)
        self.do_resize = do_resize
        self.size = size
        self.resample = resample
        self.do_rescale = do_rescale
        self.rescale_factor = rescale_factor
        self.do_normalize = do_normalize
        self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN
        self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD
        self.do_reduce_labels = do_reduce_labels
        self._valid_processor_keys = [
            "images",
            "segmentation_maps",
            "do_resize",
            "size",
            "resample",
            "do_rescale",
            "rescale_factor",
            "do_normalize",
            "image_mean",
            "image_std",
            "do_reduce_labels",
            "return_tensors",
            "data_format",
            "input_data_format",
        ]

    @classmethod
    def from_dict(cls, image_processor_dict: Dict[str, Any], **kwargs):
        """
        Overrides the `from_dict` method from the base class to make sure `do_reduce_labels` is updated if the image
        processor is created using `from_dict` and kwargs, e.g. `SegformerImageProcessor.from_pretrained(checkpoint,
        reduce_labels=True)`.
        """
        image_processor_dict = image_processor_dict.copy()
        if "reduce_labels" in kwargs:
            image_processor_dict["reduce_labels"] = kwargs.pop("reduce_labels")
        return super().from_dict(image_processor_dict, **kwargs)

    def _preprocess(
        self,
        image: ImageInput,
        do_resize: bool,
        do_rescale: bool,
        do_normalize: bool,
        size: Optional[Dict[str, int]] = None,
        resample: PILImageResampling = None,
        rescale_factor: Optional[float] = None,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
    ):
        if do_rescale:
            image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)

        if do_normalize:
            image = self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)

        return image

    def _preprocess_image(
        self,
        image: ImageInput,
        do_resize: bool = None,
        size: Dict[str, int] = None,
        resample: PILImageResampling = None,
        do_rescale: bool = None,
        rescale_factor: float = None,
        do_normalize: bool = None,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        data_format: Optional[Union[str, ChannelDimension]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
    ) -> np.ndarray:
        """Preprocesses a single image."""
        # All transformations expect numpy arrays.
        if input_data_format is None:
            input_data_format = infer_channel_dimension_format(image)

        image = self._preprocess(
            image=image,
            do_resize=do_resize,
            size=size,
            resample=resample,
            do_rescale=do_rescale,
            rescale_factor=rescale_factor,
            do_normalize=do_normalize,
            image_mean=image_mean,
            image_std=image_std,
            input_data_format=input_data_format,
        )
        if data_format is not None:
            image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
        return image

    def __call__(self, images, segmentation_maps=None, **kwargs):
        """
        Preprocesses a batch of images and optionally segmentation maps.

        Overrides the `__call__` method of the `Preprocessor` class so that both images and segmentation maps can be
        passed in as positional arguments.
        """
        return super().__call__(images, segmentation_maps=segmentation_maps, **kwargs)

    def preprocess(
        self,
        images: ImageInput,
        segmentation_maps: Optional[ImageInput] = None,
        do_resize: Optional[bool] = None,
        size: Optional[Dict[str, int]] = None,
        resample: PILImageResampling = None,
        do_rescale: Optional[bool] = None,
        rescale_factor: Optional[float] = None,
        do_normalize: Optional[bool] = None,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        do_reduce_labels: Optional[bool] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        data_format: ChannelDimension = ChannelDimension.FIRST,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs,
    ) -> BatchFeature:
        """
        Preprocess an image or batch of images.

        Args:
            images (`ImageInput`):
                Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
                passing in images with pixel values between 0 and 1, set `do_rescale=False`.
            segmentation_maps (`ImageInput`, *optional*):
                Segmentation map to preprocess.
            do_resize (`bool`, *optional*, defaults to `self.do_resize`):
                Whether to resize the image.
            size (`Dict[str, int]`, *optional*, defaults to `self.size`):
                Size of the image after `resize` is applied.
            resample (`int`, *optional*, defaults to `self.resample`):
                Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
                has an effect if `do_resize` is set to `True`.
            do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
                Whether to rescale the image values to [0, 1].
            rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
                Rescale factor to rescale the image by if `do_rescale` is set to `True`.
            do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
                Whether to normalize the image.
            image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
                Image mean.
            image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
                Image standard deviation.
            do_reduce_labels (`bool`, *optional*, defaults to `self.do_reduce_labels`):
                Whether or not to reduce all label values of segmentation maps by 1. Usually used for datasets where 0
                is used for background, and background itself is not included in all classes of a dataset (e.g.
                ADE20k). The background label will be replaced by 255.
            return_tensors (`str` or `TensorType`, *optional*):
                The type of tensors to return. Can be one of:
                    - Unset: Return a list of `np.ndarray`.
                    - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
                    - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
                    - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
                    - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
            data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
                The channel dimension format for the output image. Can be one of:
                    - `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                    - `ChannelDimension.LAST`: image in (height, width, num_channels) format.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format for the input image. If unset, the channel dimension format is inferred
                from the input image. Can be one of:
                    - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                    - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                    - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
        """
        do_resize = do_resize if do_resize is not None else self.do_resize
        do_rescale = do_rescale if do_rescale is not None else self.do_rescale
        do_normalize = do_normalize if do_normalize is not None else self.do_normalize
        resample = resample if resample is not None else self.resample
        size = size if size is not None else self.size
        rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
        image_mean = image_mean if image_mean is not None else self.image_mean
        image_std = image_std if image_std is not None else self.image_std

        images = make_list_of_images(images)
        images = [
            self._preprocess_image(
                image=img,
                do_resize=do_resize,
                resample=resample,
                size=size,
                do_rescale=do_rescale,
                rescale_factor=rescale_factor,
                do_normalize=do_normalize,
                image_mean=image_mean,
                image_std=image_std,
                data_format=data_format,
                input_data_format=input_data_format,
            )
            for img in images
        ]

        data = {"pixel_values": images}
        return BatchFeature(data=data, tensor_type=return_tensors)
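
The vendored processor keeps the Hugging Face rescale and normalize steps but, unlike the upstream SegformerImageProcessor, performs no resizing in _preprocess; sizing is handled earlier by prepare_image_detection. A minimal usage sketch under that reading (the input shape is illustrative):

import numpy as np

from surya.model.detection.processor import SegformerImageProcessor

processor = SegformerImageProcessor(size={"height": 512, "width": 512})
image = np.random.randint(0, 256, (512, 512, 3), dtype=np.uint8)  # HWC uint8 page crop

batch = processor(image, return_tensors="pt")
print(batch["pixel_values"].shape)  # torch.Size([1, 3, 512, 512]): rescaled, normalized, channels-first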
4 changes: 2 additions & 2 deletions surya/model/detection/segformer.py
@@ -5,8 +5,8 @@
import math
from typing import Optional, Tuple, Union

-from transformers import SegformerConfig, SegformerForSemanticSegmentation, SegformerImageProcessor, \
-    SegformerDecodeHead, SegformerModel
+from transformers import SegformerConfig, SegformerForSemanticSegmentation, SegformerDecodeHead, SegformerModel
+from surya.model.detection.processor import SegformerImageProcessor
import torch
from torch import nn

