diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index f8866e04d75e0e..1083d908b79c80 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -521,6 +521,8 @@ title: Utilities for Trainer - local: internal/generation_utils title: Utilities for Generation + - local: internal/image_processing_utils + title: Utilities for Image Processors - local: internal/file_utils title: General Utilities title: Internal Helpers diff --git a/docs/source/en/internal/image_processing_utils.mdx b/docs/source/en/internal/image_processing_utils.mdx new file mode 100644 index 00000000000000..1ec890e9e1f786 --- /dev/null +++ b/docs/source/en/internal/image_processing_utils.mdx @@ -0,0 +1,30 @@ + + +# Utilities for Image Processors + +This page lists all the utility functions used by the image processors, mainly the functional +transformations used to process the images. + +Most of those are only useful if you are studying the code of the image processors in the library. + +## Image Transformations + +[[autodoc]] image_transforms.rescale + +[[autodoc]] image_transforms.resize + +[[autodoc]] image_transforms.to_pil_image + +## ImageProcessorMixin + +[[autodoc]] image_processing_utils.ImageProcessorMixin diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 0e69839f0ec17e..263a9a27cc22ca 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -680,6 +680,8 @@ name for name in dir(dummy_vision_objects) if not name.startswith("_") ] else: + _import_structure["image_processing_utils"] = ["ImageProcessorMixin"] + _import_structure["image_transforms"] = ["rescale", "resize", "to_pil_image"] _import_structure["image_utils"] = ["ImageFeatureExtractionMixin"] _import_structure["models.beit"].append("BeitFeatureExtractor") _import_structure["models.clip"].append("CLIPFeatureExtractor") @@ -3648,6 +3650,8 @@ except OptionalDependencyNotAvailable: from .utils.dummy_vision_objects import * else: + from .image_processing_utils import ImageProcessorMixin + from .image_transforms import rescale, resize, to_pil_image from .image_utils import ImageFeatureExtractionMixin from .models.beit import BeitFeatureExtractor from .models.clip import CLIPFeatureExtractor diff --git a/src/transformers/image_processing_utils.py b/src/transformers/image_processing_utils.py new file mode 100644 index 00000000000000..ba9d3c0962e3f6 --- /dev/null +++ b/src/transformers/image_processing_utils.py @@ -0,0 +1,54 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .feature_extraction_utils import BatchFeature as BaseBatchFeature +from .feature_extraction_utils import FeatureExtractionMixin +from .utils import logging + + +logger = logging.get_logger(__name__) + + +# TODO: Move BatchFeature to be imported by both feature_extraction_utils and image_processing_utils +# We override the class string here, but logic is the same. 
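+# Illustrative usage of the BatchFeature subclass defined below (placeholder values;
+# `np` assumed imported by the caller — same dict-style access as the base class):
+#
+#     feats = BatchFeature(data={"pixel_values": [np.zeros((3, 32, 32))]}, tensor_type="np")
+#     feats["pixel_values"].shape  # -> (1, 3, 32, 32)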
+class BatchFeature(BaseBatchFeature):
+    r"""
+    Holds the output of the image processor's `__call__` methods.
+
+    This class is derived from a python dictionary and can be used as a dictionary.
+
+    Args:
+        data (`dict`):
+            Dictionary of lists/arrays/tensors returned by the `__call__` method ('pixel_values', etc.).
+        tensor_type (`Union[None, str, TensorType]`, *optional*):
+            You can give a `tensor_type` here to convert the lists of integers to PyTorch/TensorFlow/NumPy tensors at
+            initialization.
+    """
+
+
+# We alias ImageProcessorMixin to FeatureExtractionMixin while the old API is phased out. Once feature extractors
+# for vision models are deprecated, a standalone ImageProcessorMixin will be implemented and any shared logic
+# will be abstracted out.
+ImageProcessorMixin = FeatureExtractionMixin
+
+
+class BaseImageProcessor(ImageProcessorMixin):
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+
+    def __call__(self, images, **kwargs) -> BatchFeature:
+        return self.preprocess(images, **kwargs)
+
+    def preprocess(self, images, **kwargs) -> BatchFeature:
+        raise NotImplementedError("Each image processor must implement its own preprocess method")
diff --git a/src/transformers/image_transforms.py b/src/transformers/image_transforms.py
new file mode 100644
index 00000000000000..024b46911a750a
--- /dev/null
+++ b/src/transformers/image_transforms.py
@@ -0,0 +1,259 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING, List, Optional, Tuple, Union
+
+import numpy as np
+
+from transformers.utils.import_utils import is_flax_available, is_tf_available, is_torch_available, is_vision_available
+
+
+if is_vision_available():
+    import PIL
+
+    from .image_utils import (
+        ChannelDimension,
+        get_image_size,
+        infer_channel_dimension_format,
+        is_jax_tensor,
+        is_tf_tensor,
+        is_torch_tensor,
+    )
+
+
+if TYPE_CHECKING:
+    if is_torch_available():
+        import torch
+    if is_tf_available():
+        import tensorflow as tf
+    if is_flax_available():
+        import jax.numpy as jnp
+
+
+def to_channel_dimension_format(image: np.ndarray, channel_dim: Union[ChannelDimension, str]) -> np.ndarray:
+    """
+    Converts `image` to the channel dimension format specified by `channel_dim`.
+
+    Args:
+        image (`np.ndarray`):
+            The image to have its channel dimension set.
+        channel_dim (`ChannelDimension` or `str`):
+            The channel dimension format to use.
+
+    Returns:
+        `np.ndarray`: The image with the channel dimension set to `channel_dim`.
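+
+    Example (illustrative):
+        >>> import numpy as np
+        >>> to_channel_dimension_format(np.zeros((3, 4, 5)), ChannelDimension.LAST).shape
+        (4, 5, 3)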
+ """ + if not isinstance(image, np.ndarray): + raise ValueError(f"Input image must be of type np.ndarray, got {type(image)}") + + current_channel_dim = infer_channel_dimension_format(image) + target_channel_dim = ChannelDimension(channel_dim) + if current_channel_dim == target_channel_dim: + return image + + if target_channel_dim == ChannelDimension.FIRST: + image = image.transpose((2, 0, 1)) + elif target_channel_dim == ChannelDimension.LAST: + image = image.transpose((1, 2, 0)) + else: + raise ValueError("Unsupported channel dimension format: {}".format(channel_dim)) + + return image + + +def rescale( + image: np.ndarray, scale: float, data_format: Optional[ChannelDimension] = None, dtype=np.float32 +) -> np.ndarray: + """ + Rescales `image` by `scale`. + + Args: + image (`np.ndarray`): + The image to rescale. + scale (`float`): + The scale to use for rescaling the image. + data_format (`ChannelDimension`, *optional*): + The channel dimension format of the image. If not provided, it will be the same as the input image. + dtype (`np.dtype`, *optional*, defaults to `np.float32`): + The dtype of the output image. Defaults to `np.float32`. Used for backwards compatibility with feature + extractors. + + Returns: + `np.ndarray`: The rescaled image. + """ + if not isinstance(image, np.ndarray): + raise ValueError(f"Input image must be of type np.ndarray, got {type(image)}") + + rescaled_image = image * scale + if data_format is not None: + rescaled_image = to_channel_dimension_format(rescaled_image, data_format) + rescaled_image = rescaled_image.astype(dtype) + return rescaled_image + + +def to_pil_image( + image: Union[np.ndarray, PIL.Image.Image, "torch.Tensor", "tf.Tensor", "jnp.Tensor"], + do_rescale: Optional[bool] = None, +) -> PIL.Image.Image: + """ + Converts `image` to a PIL Image. Optionally rescales it and puts the channel dimension back as the last axis if + needed. + + Args: + image (`PIL.Image.Image` or `numpy.ndarray` or `torch.Tensor` or `tf.Tensor`): + The image to convert to the `PIL.Image` format. + do_rescale (`bool`, *optional*): + Whether or not to apply the scaling factor (to make pixel values integers between 0 and 255). Will default + to `True` if the image type is a floating type, `False` otherwise. + + Returns: + `PIL.Image.Image`: The converted image. + """ + if isinstance(image, PIL.Image.Image): + return image + + # Convert all tensors to numpy arrays before converting to PIL image + if is_torch_tensor(image) or is_tf_tensor(image): + image = image.numpy() + elif is_jax_tensor(image): + image = np.array(image) + elif not isinstance(image, np.ndarray): + raise ValueError("Input image type not supported: {}".format(type(image))) + + # If the channel as been moved to first dim, we put it back at the end. + image = to_channel_dimension_format(image, ChannelDimension.LAST) + + # PIL.Image can only store uint8 values, so we rescale the image to be between 0 and 255 if needed. + do_rescale = isinstance(image.flat[0], float) if do_rescale is None else do_rescale + if do_rescale: + image = rescale(image, 255) + image = image.astype(np.uint8) + return PIL.Image.fromarray(image) + + +def get_resize_output_image_size( + input_image: np.ndarray, + size: Union[int, Tuple[int, int], List[int], Tuple[int]], + default_to_square: bool = True, + max_size: Optional[int] = None, +) -> tuple: + """ + Find the target (height, width) dimension of the output image after resizing given the input image and the desired + size. + + Args: + input_image (`np.ndarray`): + The image to resize. 
+        size (`int` or `Tuple[int, int]` or `List[int]` or `Tuple[int]`):
+            The size to use for resizing the image. If `size` is a sequence like (h, w), the output size will be
+            matched to this.
+
+            If `size` is an int and `default_to_square` is `True`, then the image will be resized to (size, size). If
+            `size` is an int and `default_to_square` is `False`, then the smaller edge of the image will be matched to
+            this number, i.e., if height > width, then the image will be rescaled to (size * height / width, size).
+        default_to_square (`bool`, *optional*, defaults to `True`):
+            How to convert `size` when it is a single int. If set to `True`, the `size` will be converted to a square
+            (`size`, `size`). If set to `False`, will replicate
+            [`torchvision.transforms.Resize`](https://pytorch.org/vision/stable/transforms.html#torchvision.transforms.Resize)
+            with support for resizing only the smallest edge and providing an optional `max_size`.
+        max_size (`int`, *optional*):
+            The maximum allowed for the longer edge of the resized image: if the longer edge of the image is greater
+            than `max_size` after being resized according to `size`, then the image is resized again so that the
+            longer edge is equal to `max_size`. As a result, `size` might be overruled, i.e. the smaller edge may be
+            shorter than `size`. Only used if `default_to_square` is `False`.
+
+    Returns:
+        `tuple`: The target (height, width) dimension of the output image after resizing.
+    """
+    if isinstance(size, (tuple, list)):
+        if len(size) == 2:
+            return tuple(size)
+        elif len(size) == 1:
+            # Perform same logic as if size was an int
+            size = size[0]
+        else:
+            raise ValueError("size must have 1 or 2 elements if it is a list or tuple")
+
+    if default_to_square:
+        return (size, size)
+
+    height, width = get_image_size(input_image)
+    short, long = (width, height) if width <= height else (height, width)
+    requested_new_short = size
+
+    if short == requested_new_short:
+        return (height, width)
+
+    new_short, new_long = requested_new_short, int(requested_new_short * long / short)
+
+    if max_size is not None:
+        if max_size <= requested_new_short:
+            raise ValueError(
+                f"max_size = {max_size} must be strictly greater than the requested "
+                f"size for the smaller edge size = {size}"
+            )
+        if new_long > max_size:
+            new_short, new_long = int(max_size * new_short / new_long), max_size
+
+    return (new_long, new_short) if width <= height else (new_short, new_long)
+
+
+def resize(
+    image,
+    size: Tuple[int, int],
+    resample=PIL.Image.BILINEAR,
+    data_format: Optional[ChannelDimension] = None,
+    return_numpy: bool = True,
+) -> np.ndarray:
+    """
+    Resizes `image` to (h, w) specified by `size` using the PIL library.
+
+    Args:
+        image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
+            The image to resize.
+        size (`Tuple[int, int]`):
+            The size to use for resizing the image.
+        resample (`int`, *optional*, defaults to `PIL.Image.BILINEAR`):
+            The filter to use for resampling.
+        data_format (`ChannelDimension`, *optional*):
+            The channel dimension format of the output image. If `None`, will use the inferred format from the input.
+        return_numpy (`bool`, *optional*, defaults to `True`):
+            Whether or not to return the resized image as a numpy array. If `False`, a `PIL.Image.Image` object is
+            returned.
+
+    Returns:
+        `np.ndarray`: The resized image.
+    """
+    if not len(size) == 2:
+        raise ValueError("size must have 2 elements")
+
+    # For all transformations, we want to keep the same data format as the input image unless otherwise specified.
+ # The resized image from PIL will always have channels last, so find the input format first. + data_format = infer_channel_dimension_format(image) if data_format is None else data_format + + # To maintain backwards compatibility with the resizing done in previous image feature extractors, we use + # the pillow library to resize the image and then convert back to numpy + if not isinstance(image, PIL.Image.Image): + # PIL expects image to have channels last + image = to_channel_dimension_format(image, ChannelDimension.LAST) + image = to_pil_image(image) + height, width = size + # PIL images are in the format (width, height) + resized_image = image.resize((width, height), resample=resample) + + if return_numpy: + resized_image = np.array(resized_image) + resized_image = to_channel_dimension_format(resized_image, data_format) + return resized_image diff --git a/src/transformers/image_utils.py b/src/transformers/image_utils.py index 437e7c5685586b..0ba86d14b7975d 100644 --- a/src/transformers/image_utils.py +++ b/src/transformers/image_utils.py @@ -14,33 +14,128 @@ # limitations under the License. import os -from typing import List, Union +from typing import TYPE_CHECKING, List, Tuple, Union import numpy as np -import PIL.Image -import PIL.ImageOps import requests -from .utils import is_torch_available +from .utils import is_flax_available, is_tf_available, is_torch_available, is_vision_available from .utils.constants import ( # noqa: F401 IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_STANDARD_MEAN, IMAGENET_STANDARD_STD, ) -from .utils.generic import _is_torch +from .utils.generic import ExplicitEnum, _is_jax, _is_tensorflow, _is_torch, to_numpy + + +if is_vision_available(): + import PIL.Image + import PIL.ImageOps + + +if TYPE_CHECKING: + if is_torch_available(): + import torch ImageInput = Union[ - PIL.Image.Image, np.ndarray, "torch.Tensor", List[PIL.Image.Image], List[np.ndarray], List["torch.Tensor"] # noqa -] + "PIL.Image.Image", np.ndarray, "torch.Tensor", List["PIL.Image.Image"], List[np.ndarray], List["torch.Tensor"] +] # noqa + + +class ChannelDimension(ExplicitEnum): + FIRST = "channels_first" + LAST = "channels_last" def is_torch_tensor(obj): return _is_torch(obj) if is_torch_available() else False +def is_tf_tensor(obj): + return _is_tensorflow(obj) if is_tf_available() else False + + +def is_jax_tensor(obj): + return _is_jax(obj) if is_flax_available() else False + + +def is_valid_image(img): + return ( + isinstance(img, (PIL.Image.Image, np.ndarray)) + or is_torch_tensor(img) + or is_tf_tensor(img) + or is_jax_tensor(img) + ) + + +def valid_images(imgs): + return all(is_valid_image(img) for img in imgs) + + +def is_batched(img): + if isinstance(img, (list, tuple)): + return is_valid_image(img[0]) + return False + + +def to_numpy_array(img) -> np.ndarray: + if isinstance(img, PIL.Image.Image): + return np.array(img) + return to_numpy(img) + + +def infer_channel_dimension_format(image: np.ndarray) -> ChannelDimension: + """ + Infers the channel dimension format of `image`. + + Args: + image (`np.ndarray`): + The image to infer the channel dimension of. + + Returns: + The channel dimension of the image. 
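+
+    Example (illustrative):
+        >>> import numpy as np
+        >>> infer_channel_dimension_format(np.zeros((3, 32, 64)))
+        <ChannelDimension.FIRST: 'channels_first'>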
+ """ + if image.ndim == 3: + first_dim, last_dim = 0, 2 + elif image.ndim == 4: + first_dim, last_dim = 1, 3 + else: + raise ValueError(f"Unsupported number of image dimensions: {image.ndim}") + + if image.shape[first_dim] in (1, 3): + return ChannelDimension.FIRST + elif image.shape[last_dim] in (1, 3): + return ChannelDimension.LAST + raise ValueError("Unable to infer channel dimension format") + + +def get_image_size(image: np.ndarray, channel_dim: ChannelDimension = None) -> Tuple[int, int]: + """ + Returns the (height, width) dimensions of the image. + + Args: + image (`np.ndarray`): + The image to get the dimensions of. + channel_dim (`ChannelDimension`, *optional*): + Which dimension the channel dimension is in. If `None`, will infer the channel dimension from the image. + + Returns: + A tuple of the image's height and width. + """ + if channel_dim is None: + channel_dim = infer_channel_dimension_format(image) + + if channel_dim == ChannelDimension.FIRST: + return image.shape[-2], image.shape[-1] + elif channel_dim == ChannelDimension.LAST: + return image.shape[-3], image.shape[-2] + else: + raise ValueError(f"Unsupported data format: {channel_dim}") + + def load_image(image: Union[str, "PIL.Image.Image"]) -> "PIL.Image.Image": """ Loads `image` to a PIL Image. @@ -236,7 +331,7 @@ def normalize(self, image, mean, std, rescale=False): else: return (image - mean) / std - def resize(self, image, size, resample=PIL.Image.BILINEAR, default_to_square=True, max_size=None): + def resize(self, image, size, resample=None, default_to_square=True, max_size=None): """ Resizes `image`. Enforces conversion of input to PIL.Image. @@ -266,6 +361,8 @@ def resize(self, image, size, resample=PIL.Image.BILINEAR, default_to_square=Tru Returns: image: A resized `PIL.Image.Image`. """ + resample = resample if resample is not None else PIL.Image.BILINEAR + self._ensure_format_supported(image) if not isinstance(image, PIL.Image.Image): @@ -393,7 +490,7 @@ def flip_channel_order(self, image): return image[::-1, :, :] - def rotate(self, image, angle, resample=PIL.Image.NEAREST, expand=0, center=None, translate=None, fillcolor=None): + def rotate(self, image, angle, resample=None, expand=0, center=None, translate=None, fillcolor=None): """ Returns a rotated copy of `image`. This method returns a copy of `image`, rotated the given number of degrees counter clockwise around its centre. @@ -406,6 +503,8 @@ def rotate(self, image, angle, resample=PIL.Image.NEAREST, expand=0, center=None Returns: image: A rotated `PIL.Image.Image`. """ + resample = resample if resample is not None else PIL.Image.NEAREST + self._ensure_format_supported(image) if not isinstance(image, PIL.Image.Image): diff --git a/src/transformers/models/glpn/feature_extraction_glpn.py b/src/transformers/models/glpn/feature_extraction_glpn.py index 2694d56b898bec..fe63276c4798a6 100644 --- a/src/transformers/models/glpn/feature_extraction_glpn.py +++ b/src/transformers/models/glpn/feature_extraction_glpn.py @@ -14,126 +14,11 @@ # limitations under the License. 
"""Feature extractor class for GLPN.""" -from typing import Optional, Union - -import numpy as np -from PIL import Image - -from ...feature_extraction_utils import BatchFeature, FeatureExtractionMixin -from ...image_utils import ImageFeatureExtractionMixin, ImageInput, is_torch_tensor -from ...utils import TensorType, logging +from ...utils import logging +from .image_processing_glpn import GLPNImageProcessor logger = logging.get_logger(__name__) - -class GLPNFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin): - r""" - Constructs a GLPN feature extractor. - - This feature extractor inherits from [`FeatureExtractionMixin`] which contains most of the main methods. Users - should refer to this superclass for more information regarding those methods. - - Args: - do_resize (`bool`, *optional*, defaults to `True`): - Whether to resize the input based on certain `size_divisor`. - size_divisor (`int` or `Tuple(int)`, *optional*, defaults to 32): - Make sure the input is divisible by this value. Only has an effect if `do_resize` is set to `True`. - resample (`int`, *optional*, defaults to `PIL.Image.BILINEAR`): - An optional resampling filter. This can be one of `PIL.Image.NEAREST`, `PIL.Image.BOX`, - `PIL.Image.BILINEAR`, `PIL.Image.HAMMING`, `PIL.Image.BICUBIC` or `PIL.Image.LANCZOS`. Only has an effect - if `do_resize` is set to `True`. - do_rescale (`bool`, *optional*, defaults to `True`): - Whether or not to apply the scaling factor (to make pixel values floats between 0. and 1.). - """ - - model_input_names = ["pixel_values"] - - def __init__(self, do_resize=True, size_divisor=32, resample=Image.BILINEAR, do_rescale=True, **kwargs): - super().__init__(**kwargs) - self.do_resize = do_resize - self.size_divisor = size_divisor - self.resample = resample - self.do_rescale = do_rescale - - def _resize(self, image, size_divisor, resample): - if not isinstance(image, Image.Image): - image = self.to_pil_image(image) - - width, height = image.size - new_h, new_w = height // size_divisor * size_divisor, width // size_divisor * size_divisor - - image = self.resize(image, size=(new_w, new_h), resample=resample) - - return image - - def __call__( - self, images: ImageInput, return_tensors: Optional[Union[str, TensorType]] = None, **kwargs - ) -> BatchFeature: - """ - Main method to prepare for the model one or several image(s). - - - - NumPy arrays and PyTorch tensors are converted to PIL images when resizing, so the most efficient is to pass - PIL images. - - - - Args: - images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`): - The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch - tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a - number of channels, H and W are image height and width. - - return_tensors (`str` or [`~utils.TensorType`], *optional*, defaults to `'np'`): - If set, will return tensors of a particular framework. Acceptable values are: - - - `'tf'`: Return TensorFlow `tf.constant` objects. - - `'pt'`: Return PyTorch `torch.Tensor` objects. - - `'np'`: Return NumPy `np.ndarray` objects. - - `'jax'`: Return JAX `jnp.ndarray` objects. - - Returns: - [`BatchFeature`]: A [`BatchFeature`] with the following fields: - - - **pixel_values** -- Pixel values to be fed to a model, of shape (batch_size, num_channels, height, - width). 
- """ - # Input type checking for clearer error - valid_images = False - - # Check that images has a valid type - if isinstance(images, (Image.Image, np.ndarray)) or is_torch_tensor(images): - valid_images = True - elif isinstance(images, (list, tuple)): - if len(images) == 0 or isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]): - valid_images = True - - if not valid_images: - raise ValueError( - "Images must of type `PIL.Image.Image`, `np.ndarray` or `torch.Tensor` (single example), " - "`List[PIL.Image.Image]`, `List[np.ndarray]` or `List[torch.Tensor]` (batch of examples)." - ) - - is_batched = bool( - isinstance(images, (list, tuple)) - and (isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0])) - ) - - if not is_batched: - images = [images] - - # transformations (resizing + rescaling) - if self.do_resize and self.size_divisor is not None: - images = [ - self._resize(image=image, size_divisor=self.size_divisor, resample=self.resample) for image in images - ] - if self.do_rescale: - images = [self.to_numpy_array(image=image) for image in images] - - # return as BatchFeature - data = {"pixel_values": images} - encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors) - - return encoded_inputs +# Feature extractor for GLPN is being replaced by image processor +GLPNFeatureExtractor = GLPNImageProcessor diff --git a/src/transformers/models/glpn/image_processing_glpn.py b/src/transformers/models/glpn/image_processing_glpn.py new file mode 100644 index 00000000000000..98ae1d53f73d60 --- /dev/null +++ b/src/transformers/models/glpn/image_processing_glpn.py @@ -0,0 +1,186 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Image processor class for GLPN.""" + +from typing import List, Optional, Union + +import numpy as np +import PIL.Image + +from transformers.utils.generic import TensorType + +from ...image_processing_utils import BaseImageProcessor, BatchFeature +from ...image_transforms import rescale, resize, to_channel_dimension_format +from ...image_utils import ChannelDimension, get_image_size, is_batched, to_numpy_array, valid_images +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class GLPNImageProcessor(BaseImageProcessor): + r""" + Constructs a GLPN image processor. + + Args: + do_resize (`bool`, *optional*, defaults to `True`): + Set the class default for the `do_resize` parameter. Controls whether to resize the image's (height, width) + dimensions, rounding them down to the closest multiple of `size_divisor`. + size_divisor (`int`, *optional*, defaults to 32): + Set the class default for the `size_divisor` parameter. When `do_resize` is `True`, images are resized so + their height and width are rounded down to the closest multiple of `size_divisor`. + resample (`PIL.Image` resampling filter, *optional*, defaults to `PIL.Image.BILINEAR`): + Set the class default for `resample`. 
Defines the resampling filter to use if resizing the image. + do_rescale (`bool`, *optional*, defaults to `True`): + Set the class default for the `do_rescale` parameter. Controls whether or not to apply the scaling factor + (to make pixel values floats between 0. and 1.). + """ + + model_input_names = ["pixel_values"] + + def __init__( + self, + do_resize: bool = True, + size_divisor: int = 32, + resample=PIL.Image.BILINEAR, + do_rescale: bool = True, + **kwargs + ) -> None: + self.do_resize = do_resize + self.do_rescale = do_rescale + self.size_divisor = size_divisor + self.resample = resample + super().__init__(**kwargs) + + def resize( + self, image: np.ndarray, size_divisor: int, resample, data_format: Optional[ChannelDimension] = None, **kwargs + ) -> np.ndarray: + """ + Resize the image, rounding the (height, width) dimensions down to the closest multiple of size_divisor. + + If the image is of dimension (3, 260, 170) and size_divisor is 32, the image will be resized to (3, 256, 160). + + Args: + image (`np.ndarray`): + The image to resize. + size_divisor (`int`): + The image is resized so its height and width are rounded down to the closest multiple of + `size_divisor`. + resample: + `PIL.Image` resampling filter to use when resizing the image e.g. `PIL.Image.BILINEAR`. + data_format (`ChannelDimension`, *optional*): + The channel dimension format for the output image. If `None`, the channel dimension format of the input + image is used. Can be one of: + - `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `ChannelDimension.LAST`: image in (height, width, num_channels) format. + + Returns: + `np.ndarray`: The resized image. + """ + height, width = get_image_size(image) + # Rounds the height and width down to the closest multiple of size_divisor + new_h = height // size_divisor * size_divisor + new_w = width // size_divisor * size_divisor + image = resize(image, (new_h, new_w), resample=resample, data_format=data_format, **kwargs) + return image + + def rescale( + self, image: np.ndarray, scale: float, data_format: Optional[ChannelDimension] = None, **kwargs + ) -> np.ndarray: + """ + Rescale the image by the given scaling factor `scale`. + + Args: + image (`np.ndarray`): + The image to rescale. + scale (`float`): + The scaling factor to rescale pixel values by. + data_format (`ChannelDimension`, *optional*): + The channel dimension format for the output image. If `None`, the channel dimension format of the input + image is used. Can be one of: + - `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `ChannelDimension.LAST`: image in (height, width, num_channels) format. + + Returns: + `np.ndarray`: The rescaled image. + """ + return rescale(image=image, scale=scale, data_format=data_format, **kwargs) + + def preprocess( + self, + images: Union["PIL.Image.Image", TensorType, List["PIL.Image.Image"], List[TensorType]], + do_resize: Optional[bool] = None, + size_divisor: Optional[int] = None, + resample=None, + do_rescale: Optional[bool] = None, + return_tensors: Optional[Union[TensorType, str]] = None, + data_format: ChannelDimension = ChannelDimension.FIRST, + **kwargs + ) -> BatchFeature: + """ + Preprocess the given images. + + Args: + images (`PIL.Image.Image` or `TensorType` or `List[np.ndarray]` or `List[TensorType]`): + The image or images to preprocess. + do_resize (`bool`, *optional*, defaults to `self.do_resize`): + Whether to resize the input such that the (height, width) dimensions are a multiple of `size_divisor`. 
+            size_divisor (`int`, *optional*, defaults to `self.size_divisor`):
+                When `do_resize` is `True`, images are resized so their height and width are rounded down to the
+                closest multiple of `size_divisor`.
+            resample (`PIL.Image` resampling filter, *optional*, defaults to `self.resample`):
+                `PIL.Image` resampling filter to use if resizing the image e.g. `PIL.Image.BILINEAR`. Only has an
+                effect if `do_resize` is set to `True`.
+            do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
+                Whether or not to apply the scaling factor (to make pixel values floats between 0. and 1.).
+            return_tensors (`str` or `TensorType`, *optional*):
+                The type of tensors to return. Can be one of:
+                    - `None`: Return a list of `np.ndarray`.
+                    - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
+                    - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
+                    - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
+                    - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
+            data_format (`ChannelDimension`, *optional*, defaults to `ChannelDimension.FIRST`):
+                The channel dimension format for the output image. Can be one of:
+                    - `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+                    - `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+        """
+        do_resize = do_resize if do_resize is not None else self.do_resize
+        do_rescale = do_rescale if do_rescale is not None else self.do_rescale
+        size_divisor = size_divisor if size_divisor is not None else self.size_divisor
+        resample = resample if resample is not None else self.resample
+
+        if do_resize and size_divisor is None:
+            raise ValueError("size_divisor is required for resizing")
+
+        if not is_batched(images):
+            images = [images]
+
+        if not valid_images(images):
+            raise ValueError(
+                "Invalid image(s). Each image must be a PIL.Image.Image, numpy.ndarray, torch.Tensor, tf.Tensor "
+                "or jax.numpy.ndarray."
+            )
+
+        # All transformations expect numpy arrays.
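+        # to_numpy_array also handles PIL images and torch/tf/jax tensors, so from here on every
+        # image is a plain numpy array.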
+ images = [to_numpy_array(img) for img in images] + + if do_resize: + images = [self.resize(image, size_divisor=size_divisor, resample=resample) for image in images] + + if do_rescale: + images = [self.rescale(image, scale=1 / 255) for image in images] + + images = [to_channel_dimension_format(image, data_format) for image in images] + + data = {"pixel_values": images} + return BatchFeature(data=data, tensor_type=return_tensors) diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py index 2269f225485820..97e013fee50450 100644 --- a/src/transformers/utils/__init__.py +++ b/src/transformers/utils/__init__.py @@ -164,6 +164,7 @@ SAFE_WEIGHTS_INDEX_NAME = "model.safetensors.index.json" CONFIG_NAME = "config.json" FEATURE_EXTRACTOR_NAME = "preprocessor_config.json" +IMAGE_PROCESSOR_NAME = FEATURE_EXTRACTOR_NAME MODEL_CARD_NAME = "modelcard.json" SENTENCEPIECE_UNDERLINE = "▁" diff --git a/src/transformers/utils/dummy_vision_objects.py b/src/transformers/utils/dummy_vision_objects.py index 2d1f0a88cd0ebf..a3112c4454b4bb 100644 --- a/src/transformers/utils/dummy_vision_objects.py +++ b/src/transformers/utils/dummy_vision_objects.py @@ -3,6 +3,25 @@ from ..utils import DummyObject, requires_backends +class ImageProcessorMixin(metaclass=DummyObject): + _backends = ["vision"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["vision"]) + + +def rescale(*args, **kwargs): + requires_backends(rescale, ["vision"]) + + +def resize(*args, **kwargs): + requires_backends(resize, ["vision"]) + + +def to_pil_image(*args, **kwargs): + requires_backends(to_pil_image, ["vision"]) + + class ImageFeatureExtractionMixin(metaclass=DummyObject): _backends = ["vision"] diff --git a/tests/test_image_transforms.py b/tests/test_image_transforms.py new file mode 100644 index 00000000000000..69e6de1587b8d6 --- /dev/null +++ b/tests/test_image_transforms.py @@ -0,0 +1,174 @@ +# coding=utf-8 +# Copyright 2022 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest + +import numpy as np + +from parameterized import parameterized +from transformers.testing_utils import require_flax, require_tf, require_torch, require_vision +from transformers.utils.import_utils import is_flax_available, is_tf_available, is_torch_available, is_vision_available + + +if is_torch_available(): + import torch + +if is_tf_available(): + import tensorflow as tf + +if is_flax_available(): + import jax + +if is_vision_available(): + import PIL.Image + + from transformers.image_transforms import ( + get_resize_output_image_size, + resize, + to_channel_dimension_format, + to_pil_image, + ) + + +def get_random_image(height, width, num_channels=3, channels_first=True): + shape = (num_channels, height, width) if channels_first else (height, width, num_channels) + random_array = np.random.randint(0, 256, shape, dtype=np.uint8) + return random_array + + +@require_vision +class ImageTransformsTester(unittest.TestCase): + @parameterized.expand( + [ + ("numpy_float_channels_first", (3, 4, 5), np.float32), + ("numpy_float_channels_last", (4, 5, 3), np.float32), + ("numpy_int_channels_first", (3, 4, 5), np.int32), + ("numpy_uint_channels_first", (3, 4, 5), np.uint8), + ] + ) + @require_vision + def test_to_pil_image(self, name, image_shape, dtype): + image = np.random.randint(0, 256, image_shape).astype(dtype) + pil_image = to_pil_image(image) + self.assertIsInstance(pil_image, PIL.Image.Image) + self.assertEqual(pil_image.size, (5, 4)) + + @require_tf + def test_to_pil_image_from_tensorflow(self): + # channels_first + image = tf.random.uniform((3, 4, 5)) + pil_image = to_pil_image(image) + self.assertIsInstance(pil_image, PIL.Image.Image) + self.assertEqual(pil_image.size, (5, 4)) + + # channels_last + image = tf.random.uniform((4, 5, 3)) + pil_image = to_pil_image(image) + self.assertIsInstance(pil_image, PIL.Image.Image) + self.assertEqual(pil_image.size, (5, 4)) + + @require_torch + def test_to_pil_image_from_torch(self): + # channels first + image = torch.rand((3, 4, 5)) + pil_image = to_pil_image(image) + self.assertIsInstance(pil_image, PIL.Image.Image) + self.assertEqual(pil_image.size, (5, 4)) + + # channels last + image = torch.rand((4, 5, 3)) + pil_image = to_pil_image(image) + self.assertIsInstance(pil_image, PIL.Image.Image) + self.assertEqual(pil_image.size, (5, 4)) + + @require_flax + def test_to_pil_image_from_jax(self): + key = jax.random.PRNGKey(0) + # channel first + image = jax.random.uniform(key, (3, 4, 5)) + pil_image = to_pil_image(image) + self.assertIsInstance(pil_image, PIL.Image.Image) + self.assertEqual(pil_image.size, (5, 4)) + + # channel last + image = jax.random.uniform(key, (4, 5, 3)) + pil_image = to_pil_image(image) + self.assertIsInstance(pil_image, PIL.Image.Image) + self.assertEqual(pil_image.size, (5, 4)) + + def test_to_channel_dimension_format(self): + # Test that function doesn't reorder if channel dim matches the input. + image = np.random.rand(3, 4, 5) + image = to_channel_dimension_format(image, "channels_first") + self.assertEqual(image.shape, (3, 4, 5)) + + image = np.random.rand(4, 5, 3) + image = to_channel_dimension_format(image, "channels_last") + self.assertEqual(image.shape, (4, 5, 3)) + + # Test that function reorders if channel dim doesn't match the input. 
+        image = np.random.rand(3, 4, 5)
+        image = to_channel_dimension_format(image, "channels_last")
+        self.assertEqual(image.shape, (4, 5, 3))
+
+        image = np.random.rand(4, 5, 3)
+        image = to_channel_dimension_format(image, "channels_first")
+        self.assertEqual(image.shape, (3, 4, 5))
+
+    def test_get_resize_output_image_size(self):
+        image = np.random.randint(0, 256, (3, 224, 224))
+
+        # Test the output size defaults to (x, x) if an int is given.
+        self.assertEqual(get_resize_output_image_size(image, 10), (10, 10))
+        self.assertEqual(get_resize_output_image_size(image, [10]), (10, 10))
+        self.assertEqual(get_resize_output_image_size(image, (10,)), (10, 10))
+
+        # Test the output size is the same as the input if a two element tuple/list is given.
+        self.assertEqual(get_resize_output_image_size(image, (10, 20)), (10, 20))
+        self.assertEqual(get_resize_output_image_size(image, [10, 20]), (10, 20))
+        self.assertEqual(get_resize_output_image_size(image, (10, 20), default_to_square=True), (10, 20))
+        # To match pytorch behaviour, max_size is only relevant if size is an int
+        self.assertEqual(get_resize_output_image_size(image, (10, 20), max_size=5), (10, 20))
+
+        # Test output size = (int(size * height / width), size) if size is an int and height > width
+        image = np.random.randint(0, 256, (3, 50, 40))
+        self.assertEqual(get_resize_output_image_size(image, 20, default_to_square=False), (25, 20))
+
+        # Test output size = (size, int(size * width / height)) if size is an int and width <= height
+        image = np.random.randint(0, 256, (3, 40, 50))
+        self.assertEqual(get_resize_output_image_size(image, 20, default_to_square=False), (20, 25))
+
+        # Test the image is resized again if the longer edge would exceed max_size
+        image = np.random.randint(0, 256, (3, 50, 40))
+        self.assertEqual(get_resize_output_image_size(image, 20, default_to_square=False, max_size=22), (22, 17))
+
+    def test_resize(self):
+        image = np.random.randint(0, 256, (3, 224, 224))
+
+        # Check the channel order is the same by default
+        resized_image = resize(image, (30, 40))
+        self.assertIsInstance(resized_image, np.ndarray)
+        self.assertEqual(resized_image.shape, (3, 30, 40))
+
+        # Check channel order is changed if specified
+        resized_image = resize(image, (30, 40), data_format="channels_last")
+        self.assertIsInstance(resized_image, np.ndarray)
+        self.assertEqual(resized_image.shape, (30, 40, 3))
+
+        # Check a PIL.Image.Image is returned if return_numpy=False
+        resized_image = resize(image, (30, 40), return_numpy=False)
+        self.assertIsInstance(resized_image, PIL.Image.Image)
+        # PIL size is in (width, height) order
+        self.assertEqual(resized_image.size, (40, 30))
diff --git a/tests/utils/test_image_utils.py b/tests/utils/test_image_utils.py
index 3c1be7102c1abc..0ae5d78fb2dc0a 100644
--- a/tests/utils/test_image_utils.py
+++ b/tests/utils/test_image_utils.py
@@ -17,8 +17,10 @@
 import datasets
 import numpy as np
+import pytest

 from transformers import is_torch_available, is_vision_available
+from transformers.image_utils import ChannelDimension
 from transformers.testing_utils import require_torch, require_vision

@@ -29,7 +31,7 @@
     import PIL.Image

     from transformers import ImageFeatureExtractionMixin
-    from transformers.image_utils import load_image
+    from transformers.image_utils import get_image_size, infer_channel_dimension_format, load_image


 def get_random_image(height, width):
@@ -485,3 +487,51 @@ def test_load_img_exif_transpose(self):
             img_arr_with_exif_transpose.shape,
             (500, 333, 3),
         )
+
+
+class UtilFunctionTester(unittest.TestCase):
+    def test_get_image_size(self):
+        # Test we can infer the size and channel dimension of an image.
+        image = np.random.randint(0, 256, (32, 64, 3))
+        self.assertEqual(get_image_size(image), (32, 64))
+
+        image = np.random.randint(0, 256, (3, 32, 64))
+        self.assertEqual(get_image_size(image), (32, 64))
+
+        # Test the channel dimension can be overridden
+        image = np.random.randint(0, 256, (3, 32, 64))
+        self.assertEqual(get_image_size(image, channel_dim=ChannelDimension.LAST), (3, 32))
+
+    def test_infer_channel_dimension(self):
+        # Test we fail with invalid input
+        with pytest.raises(ValueError):
+            infer_channel_dimension_format(np.random.randint(0, 256, (10, 10)))
+
+        with pytest.raises(ValueError):
+            infer_channel_dimension_format(np.random.randint(0, 256, (10, 10, 10, 10, 10)))
+
+        # Test we fail if neither the first nor the last dimension is of size 3 or 1
+        with pytest.raises(ValueError):
+            infer_channel_dimension_format(np.random.randint(0, 256, (10, 1, 50)))
+
+        # Test we correctly identify the channel dimension
+        image = np.random.randint(0, 256, (3, 4, 5))
+        inferred_dim = infer_channel_dimension_format(image)
+        self.assertEqual(inferred_dim, ChannelDimension.FIRST)
+
+        image = np.random.randint(0, 256, (1, 4, 5))
+        inferred_dim = infer_channel_dimension_format(image)
+        self.assertEqual(inferred_dim, ChannelDimension.FIRST)
+
+        image = np.random.randint(0, 256, (4, 5, 3))
+        inferred_dim = infer_channel_dimension_format(image)
+        self.assertEqual(inferred_dim, ChannelDimension.LAST)
+
+        image = np.random.randint(0, 256, (4, 5, 1))
+        inferred_dim = infer_channel_dimension_format(image)
+        self.assertEqual(inferred_dim, ChannelDimension.LAST)
+
+        # We can take a batched array of images and find the channel dimension
+        image = np.random.randint(0, 256, (1, 3, 4, 5))
+        inferred_dim = infer_channel_dimension_format(image)
+        self.assertEqual(inferred_dim, ChannelDimension.FIRST)
diff --git a/utils/tests_fetcher.py b/utils/tests_fetcher.py
index f9e0f86af5bc91..06a20263092112 100644
--- a/utils/tests_fetcher.py
+++ b/utils/tests_fetcher.py
@@ -353,6 +353,7 @@ def create_reverse_dependency_map():
     "feature_extraction_sequence_utils.py": "test_sequence_feature_extraction_common.py",
     "feature_extraction_utils.py": "test_feature_extraction_common.py",
     "file_utils.py": ["utils/test_file_utils.py", "utils/test_model_output.py"],
+    "image_transforms.py": "test_image_transforms.py",
     "utils/generic.py": ["utils/test_file_utils.py", "utils/test_model_output.py", "utils/test_generic.py"],
     "utils/hub.py": "utils/test_hub_utils.py",
     "modelcard.py": "utils/test_model_card.py",
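Usage sketch (illustrative, not part of the diff; `example.jpg` is a placeholder path). Existing call sites keep working because `GLPNFeatureExtractor` is now an alias of `GLPNImageProcessor`:

    from PIL import Image
    from transformers import GLPNFeatureExtractor  # aliased to GLPNImageProcessor by this change

    processor = GLPNFeatureExtractor(do_resize=True, size_divisor=32, do_rescale=True)
    inputs = processor(Image.open("example.jpg"), return_tensors="np")
    # Height and width are rounded down to multiples of 32, channels first, values in [0, 1].
    print(inputs["pixel_values"].shape)  # (1, 3, H, W)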