-
Notifications
You must be signed in to change notification settings - Fork 563
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #110 from VikParuchuri/dev
Processor improvements
- Loading branch information
Showing
12 changed files
with
337 additions
and
93 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,284 @@ | ||
import warnings | ||
from typing import Any, Dict, List, Optional, Union | ||
|
||
import numpy as np | ||
|
||
from transformers.image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict | ||
from transformers.image_transforms import to_channel_dimension_format | ||
from transformers.image_utils import ( | ||
IMAGENET_DEFAULT_MEAN, | ||
IMAGENET_DEFAULT_STD, | ||
ChannelDimension, | ||
ImageInput, | ||
PILImageResampling, | ||
infer_channel_dimension_format, | ||
make_list_of_images, | ||
) | ||
from transformers.utils import TensorType | ||
|
||
|
||
import PIL.Image | ||
import torch | ||
|
||
|
||
class SegformerImageProcessor(BaseImageProcessor): | ||
r""" | ||
Constructs a Segformer image processor. | ||
Args: | ||
do_resize (`bool`, *optional*, defaults to `True`): | ||
Whether to resize the image's (height, width) dimensions to the specified `(size["height"], | ||
size["width"])`. Can be overridden by the `do_resize` parameter in the `preprocess` method. | ||
size (`Dict[str, int]` *optional*, defaults to `{"height": 512, "width": 512}`): | ||
Size of the output image after resizing. Can be overridden by the `size` parameter in the `preprocess` | ||
method. | ||
resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`): | ||
Resampling filter to use if resizing the image. Can be overridden by the `resample` parameter in the | ||
`preprocess` method. | ||
do_rescale (`bool`, *optional*, defaults to `True`): | ||
Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the `do_rescale` | ||
parameter in the `preprocess` method. | ||
rescale_factor (`int` or `float`, *optional*, defaults to `1/255`): | ||
Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess` | ||
method. | ||
do_normalize (`bool`, *optional*, defaults to `True`): | ||
Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess` | ||
method. | ||
image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`): | ||
Mean to use if normalizing the image. This is a float or list of floats the length of the number of | ||
channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method. | ||
image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`): | ||
Standard deviation to use if normalizing the image. This is a float or list of floats the length of the | ||
number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method. | ||
do_reduce_labels (`bool`, *optional*, defaults to `False`): | ||
Whether or not to reduce all label values of segmentation maps by 1. Usually used for datasets where 0 is | ||
used for background, and background itself is not included in all classes of a dataset (e.g. ADE20k). The | ||
background label will be replaced by 255. Can be overridden by the `do_reduce_labels` parameter in the | ||
`preprocess` method. | ||
""" | ||
|
||
model_input_names = ["pixel_values"] | ||
|
||
def __init__( | ||
self, | ||
do_resize: bool = True, | ||
size: Dict[str, int] = None, | ||
resample: PILImageResampling = PILImageResampling.BILINEAR, | ||
do_rescale: bool = True, | ||
rescale_factor: Union[int, float] = 1 / 255, | ||
do_normalize: bool = True, | ||
image_mean: Optional[Union[float, List[float]]] = None, | ||
image_std: Optional[Union[float, List[float]]] = None, | ||
do_reduce_labels: bool = False, | ||
**kwargs, | ||
) -> None: | ||
if "reduce_labels" in kwargs: | ||
warnings.warn( | ||
"The `reduce_labels` parameter is deprecated and will be removed in a future version. Please use " | ||
"`do_reduce_labels` instead.", | ||
FutureWarning, | ||
) | ||
do_reduce_labels = kwargs.pop("reduce_labels") | ||
|
||
super().__init__(**kwargs) | ||
size = size if size is not None else {"height": 512, "width": 512} | ||
size = get_size_dict(size) | ||
self.do_resize = do_resize | ||
self.size = size | ||
self.resample = resample | ||
self.do_rescale = do_rescale | ||
self.rescale_factor = rescale_factor | ||
self.do_normalize = do_normalize | ||
self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN | ||
self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD | ||
self.do_reduce_labels = do_reduce_labels | ||
self._valid_processor_keys = [ | ||
"images", | ||
"segmentation_maps", | ||
"do_resize", | ||
"size", | ||
"resample", | ||
"do_rescale", | ||
"rescale_factor", | ||
"do_normalize", | ||
"image_mean", | ||
"image_std", | ||
"do_reduce_labels", | ||
"return_tensors", | ||
"data_format", | ||
"input_data_format", | ||
] | ||
|
||
@classmethod | ||
def from_dict(cls, image_processor_dict: Dict[str, Any], **kwargs): | ||
""" | ||
Overrides the `from_dict` method from the base class to make sure `do_reduce_labels` is updated if image | ||
processor is created using from_dict and kwargs e.g. `SegformerImageProcessor.from_pretrained(checkpoint, | ||
reduce_labels=True)` | ||
""" | ||
image_processor_dict = image_processor_dict.copy() | ||
if "reduce_labels" in kwargs: | ||
image_processor_dict["reduce_labels"] = kwargs.pop("reduce_labels") | ||
return super().from_dict(image_processor_dict, **kwargs) | ||
|
||
def _preprocess( | ||
self, | ||
image: ImageInput, | ||
do_resize: bool, | ||
do_rescale: bool, | ||
do_normalize: bool, | ||
size: Optional[Dict[str, int]] = None, | ||
resample: PILImageResampling = None, | ||
rescale_factor: Optional[float] = None, | ||
image_mean: Optional[Union[float, List[float]]] = None, | ||
image_std: Optional[Union[float, List[float]]] = None, | ||
input_data_format: Optional[Union[str, ChannelDimension]] = None, | ||
): | ||
|
||
if do_rescale: | ||
image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format) | ||
|
||
if do_normalize: | ||
image = self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format) | ||
|
||
return image | ||
|
||
def _preprocess_image( | ||
self, | ||
image: ImageInput, | ||
do_resize: bool = None, | ||
size: Dict[str, int] = None, | ||
resample: PILImageResampling = None, | ||
do_rescale: bool = None, | ||
rescale_factor: float = None, | ||
do_normalize: bool = None, | ||
image_mean: Optional[Union[float, List[float]]] = None, | ||
image_std: Optional[Union[float, List[float]]] = None, | ||
data_format: Optional[Union[str, ChannelDimension]] = None, | ||
input_data_format: Optional[Union[str, ChannelDimension]] = None, | ||
) -> np.ndarray: | ||
"""Preprocesses a single image.""" | ||
# All transformations expect numpy arrays. | ||
if input_data_format is None: | ||
input_data_format = infer_channel_dimension_format(image) | ||
|
||
image = self._preprocess( | ||
image=image, | ||
do_resize=do_resize, | ||
size=size, | ||
resample=resample, | ||
do_rescale=do_rescale, | ||
rescale_factor=rescale_factor, | ||
do_normalize=do_normalize, | ||
image_mean=image_mean, | ||
image_std=image_std, | ||
input_data_format=input_data_format, | ||
) | ||
if data_format is not None: | ||
image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) | ||
return image | ||
|
||
def __call__(self, images, segmentation_maps=None, **kwargs): | ||
""" | ||
Preprocesses a batch of images and optionally segmentation maps. | ||
Overrides the `__call__` method of the `Preprocessor` class so that both images and segmentation maps can be | ||
passed in as positional arguments. | ||
""" | ||
return super().__call__(images, segmentation_maps=segmentation_maps, **kwargs) | ||
|
||
def preprocess( | ||
self, | ||
images: ImageInput, | ||
segmentation_maps: Optional[ImageInput] = None, | ||
do_resize: Optional[bool] = None, | ||
size: Optional[Dict[str, int]] = None, | ||
resample: PILImageResampling = None, | ||
do_rescale: Optional[bool] = None, | ||
rescale_factor: Optional[float] = None, | ||
do_normalize: Optional[bool] = None, | ||
image_mean: Optional[Union[float, List[float]]] = None, | ||
image_std: Optional[Union[float, List[float]]] = None, | ||
do_reduce_labels: Optional[bool] = None, | ||
return_tensors: Optional[Union[str, TensorType]] = None, | ||
data_format: ChannelDimension = ChannelDimension.FIRST, | ||
input_data_format: Optional[Union[str, ChannelDimension]] = None, | ||
**kwargs, | ||
) -> PIL.Image.Image: | ||
""" | ||
Preprocess an image or batch of images. | ||
Args: | ||
images (`ImageInput`): | ||
Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If | ||
passing in images with pixel values between 0 and 1, set `do_rescale=False`. | ||
segmentation_maps (`ImageInput`, *optional*): | ||
Segmentation map to preprocess. | ||
do_resize (`bool`, *optional*, defaults to `self.do_resize`): | ||
Whether to resize the image. | ||
size (`Dict[str, int]`, *optional*, defaults to `self.size`): | ||
Size of the image after `resize` is applied. | ||
resample (`int`, *optional*, defaults to `self.resample`): | ||
Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`, Only | ||
has an effect if `do_resize` is set to `True`. | ||
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): | ||
Whether to rescale the image values between [0 - 1]. | ||
rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`): | ||
Rescale factor to rescale the image by if `do_rescale` is set to `True`. | ||
do_normalize (`bool`, *optional*, defaults to `self.do_normalize`): | ||
Whether to normalize the image. | ||
image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`): | ||
Image mean. | ||
image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`): | ||
Image standard deviation. | ||
do_reduce_labels (`bool`, *optional*, defaults to `self.do_reduce_labels`): | ||
Whether or not to reduce all label values of segmentation maps by 1. Usually used for datasets where 0 | ||
is used for background, and background itself is not included in all classes of a dataset (e.g. | ||
ADE20k). The background label will be replaced by 255. | ||
return_tensors (`str` or `TensorType`, *optional*): | ||
The type of tensors to return. Can be one of: | ||
- Unset: Return a list of `np.ndarray`. | ||
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. | ||
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. | ||
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. | ||
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. | ||
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): | ||
The channel dimension format for the output image. Can be one of: | ||
- `ChannelDimension.FIRST`: image in (num_channels, height, width) format. | ||
- `ChannelDimension.LAST`: image in (height, width, num_channels) format. | ||
input_data_format (`ChannelDimension` or `str`, *optional*): | ||
The channel dimension format for the input image. If unset, the channel dimension format is inferred | ||
from the input image. Can be one of: | ||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. | ||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. | ||
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format. | ||
""" | ||
do_resize = do_resize if do_resize is not None else self.do_resize | ||
do_rescale = do_rescale if do_rescale is not None else self.do_rescale | ||
do_normalize = do_normalize if do_normalize is not None else self.do_normalize | ||
resample = resample if resample is not None else self.resample | ||
size = size if size is not None else self.size | ||
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor | ||
image_mean = image_mean if image_mean is not None else self.image_mean | ||
image_std = image_std if image_std is not None else self.image_std | ||
|
||
images = make_list_of_images(images) | ||
images = [ | ||
self._preprocess_image( | ||
image=img, | ||
do_resize=do_resize, | ||
resample=resample, | ||
size=size, | ||
do_rescale=do_rescale, | ||
rescale_factor=rescale_factor, | ||
do_normalize=do_normalize, | ||
image_mean=image_mean, | ||
image_std=image_std, | ||
data_format=data_format, | ||
input_data_format=input_data_format, | ||
) | ||
for img in images | ||
] | ||
|
||
data = {"pixel_values": images} | ||
return BatchFeature(data=data, tensor_type=return_tensors) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.