Skip to content

Commit

Permalink
Add do_thumbnail for consistency
Browse files Browse the repository at this point in the history
  • Loading branch information
Niels Rogge authored and Niels Rogge committed Aug 11, 2022
1 parent 28e6cc4 commit 8a000ac
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 8 deletions.
22 changes: 14 additions & 8 deletions src/transformers/models/donut/feature_extraction_donut.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,16 @@ class DonutFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin)
Args:
do_resize (`bool`, *optional*, defaults to `True`):
Whether to resize the shorter edge of the input to the minimum value of a certain `size`, and thumbnail the
input to the given `size`.
Whether to resize the shorter edge of the input to the minimum value of a certain `size`.
size (`Tuple(int)`, *optional*, defaults to [1920, 2560]):
Resize the shorter edge of the input to the minimum value of the given size. Should be a tuple of (width,
height). Only has an effect if `do_resize` is set to `True`.
resample (`int`, *optional*, defaults to `PIL.Image.BILINEAR`):
An optional resampling filter. This can be one of `PIL.Image.NEAREST`, `PIL.Image.BOX`,
`PIL.Image.BILINEAR`, `PIL.Image.HAMMING`, `PIL.Image.BICUBIC` or `PIL.Image.LANCZOS`. Only has an effect
if `do_resize` is set to `True`.
do_thumbnail (`bool`, *optional*, defaults to `True`):
Whether to thumbnail the input to the given `size`.
do_align_long_axis (`bool`, *optional*, defaults to `False`):
Whether to rotate the input if the height is greater than width.
do_pad (`bool`, *optional*, defaults to `True`):
Expand All @@ -71,6 +72,7 @@ def __init__(
do_resize=True,
size=[1920, 2560],
resample=Image.BILINEAR,
do_thumbnail=True,
do_align_long_axis=False,
do_pad=True,
do_normalize=True,
Expand All @@ -82,6 +84,7 @@ def __init__(
self.do_resize = do_resize
self.size = size
self.resample = resample
self.do_thumbnail = do_thumbnail
self.do_align_long_axis = do_align_long_axis
self.do_pad = do_pad
self.do_normalize = do_normalize
Expand All @@ -97,10 +100,10 @@ def rotate_image(self, image, size):

return image

def resize_and_thumbnail(self, image, size, resample):
# 1. resize the shorter edge of the image to `min(size)`
image = self.resize(image, size=min(size), resample=resample, default_to_square=False)
# 2. create a thumbnail
def thumbnail(self, image, size):
if not isinstance(image, Image.Image):
image = self.to_pil_image(image)

image.thumbnail((size[0], size[1]))

return image
Expand Down Expand Up @@ -183,13 +186,16 @@ def __call__(
if not is_batched:
images = [images]

# transformations (rotating + resizing + padding + normalization)
# transformations (rotating + resizing + thumbnailing + padding + normalization)
if self.do_align_long_axis:
images = [self.rotate_image(image, self.size) for image in images]
if self.do_resize and self.size is not None:
images = [
self.resize_and_thumbnail(image=image, size=self.size, resample=self.resample) for image in images
self.resize(image=image, size=min(self.size), resample=self.resample, default_to_square=False)
for image in images
]
if self.do_thumbnail and self.size is not None:
images = [self.thumbnail(image=image, size=self.size) for image in images]
if self.do_pad and self.size is not None:
images = [self.pad(image=image, size=self.size, random_padding=random_padding) for image in images]
if self.do_normalize:
Expand Down
4 changes: 4 additions & 0 deletions tests/models/donut/test_feature_extraction_donut.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def __init__(
max_resolution=400,
do_resize=True,
size=[20, 18],
do_thumbnail=True,
do_align_axis=False,
do_pad=True,
do_normalize=True,
Expand All @@ -58,6 +59,7 @@ def __init__(
self.max_resolution = max_resolution
self.do_resize = do_resize
self.size = size
self.do_thumbnail = do_thumbnail
self.do_align_axis = do_align_axis
self.do_pad = do_pad
self.do_normalize = do_normalize
Expand All @@ -68,6 +70,7 @@ def prepare_feat_extract_dict(self):
return {
"do_resize": self.do_resize,
"size": self.size,
"do_thumbnail": self.do_thumbnail,
"do_align_long_axis": self.do_align_axis,
"do_pad": self.do_pad,
"do_normalize": self.do_normalize,
Expand All @@ -93,6 +96,7 @@ def test_feat_extract_properties(self):
feature_extractor = self.feature_extraction_class(**self.feat_extract_dict)
self.assertTrue(hasattr(feature_extractor, "do_resize"))
self.assertTrue(hasattr(feature_extractor, "size"))
self.assertTrue(hasattr(feature_extractor, "do_thumbnail"))
self.assertTrue(hasattr(feature_extractor, "do_align_long_axis"))
self.assertTrue(hasattr(feature_extractor, "do_pad"))
self.assertTrue(hasattr(feature_extractor, "do_normalize"))
Expand Down

0 comments on commit 8a000ac

Please sign in to comment.