From 35b590bce52ec2c9374d5f4f53efff3dca594e6f Mon Sep 17 00:00:00 2001 From: amyeroberts <22614925+amyeroberts@users.noreply.github.com> Date: Mon, 29 Aug 2022 17:48:24 +0100 Subject: [PATCH] Revert to `rescale` and safely handle `rescale` flag in owlvit config (#18750) --- src/transformers/image_utils.py | 8 ++++---- .../models/owlvit/feature_extraction_owlvit.py | 7 +++++++ 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/src/transformers/image_utils.py b/src/transformers/image_utils.py index 120d7b3c1bd26c..437e7c5685586b 100644 --- a/src/transformers/image_utils.py +++ b/src/transformers/image_utils.py @@ -131,7 +131,7 @@ def convert_rgb(self, image): return image.convert("RGB") - def rescale_image(self, image: np.ndarray, scale: Union[float, int]) -> np.ndarray: + def rescale(self, image: np.ndarray, scale: Union[float, int]) -> np.ndarray: """ Rescale a numpy image by scale amount """ @@ -163,7 +163,7 @@ def to_numpy_array(self, image, rescale=None, channel_first=True): rescale = isinstance(image.flat[0], np.integer) if rescale is None else rescale if rescale: - image = self.rescale_image(image.astype(np.float32), 1 / 255.0) + image = self.rescale(image.astype(np.float32), 1 / 255.0) if channel_first and image.ndim == 3: image = image.transpose(2, 0, 1) @@ -214,9 +214,9 @@ def normalize(self, image, mean, std, rescale=False): # type it may need rescaling. 
elif rescale: if isinstance(image, np.ndarray): - image = self.rescale_image(image.astype(np.float32), 1 / 255.0) + image = self.rescale(image.astype(np.float32), 1 / 255.0) elif is_torch_tensor(image): - image = self.rescale_image(image.float(), 1 / 255.0) + image = self.rescale(image.float(), 1 / 255.0) if isinstance(image, np.ndarray): if not isinstance(mean, np.ndarray): diff --git a/src/transformers/models/owlvit/feature_extraction_owlvit.py b/src/transformers/models/owlvit/feature_extraction_owlvit.py index f8a45706835d8f..0af33eccaef044 100644 --- a/src/transformers/models/owlvit/feature_extraction_owlvit.py +++ b/src/transformers/models/owlvit/feature_extraction_owlvit.py @@ -85,6 +85,13 @@ def __init__( image_std=None, **kwargs ): + # Early versions of the OWL-ViT config on the hub had "rescale" as a flag. This clashes with the + # vision feature extractor method `rescale` as it would be set as an attribute during the super().__init__ + # call. This is for backwards compatibility. + if "rescale" in kwargs: + rescale_val = kwargs.pop("rescale") + kwargs["do_rescale"] = rescale_val + super().__init__(**kwargs) self.size = size self.resample = resample