diff --git a/retrieval/retrieval_data.py b/retrieval/retrieval_data.py index 4a8699c..ec57d4b 100644 --- a/retrieval/retrieval_data.py +++ b/retrieval/retrieval_data.py @@ -58,7 +58,6 @@ def decode(self, serialized_example): # Convert from a scalar string tensor to a float32 tensor with shape image_decoded = tf.image.decode_png(features['image/encoded'], channels=3) image = tf.image.convert_image_dtype(image_decoded, dtype=tf.float32) - image = tf.image.resize(image, [self.resize_h, self.resize_w]) # Convert label from a scalar uint8 tensor to an int32 scalar. label = tf.cast(features['image/class/label'], tf.int64) @@ -83,8 +82,12 @@ def eval(self, filename, image, label): def tencrop(self, filename, image, label): """Prepare one image for TenCrop """ - flip_mode = random.randint(0, 1) + # Resize the image to the height and width. + image = tf.expand_dims(image, 0) + image = tf.image.resize(image, [self.resize_h, self.resize_w]) + image = tf.squeeze(image, [0]) + flip_mode = random.randint(0, 1) images = [] for i in range(5): image = tf.random_crop(image, [RANDOM_CROP_SIZE, RANDOM_CROP_SIZE, 3]) diff --git a/val_data.py b/val_data.py index 4a8699c..ec57d4b 100644 --- a/val_data.py +++ b/val_data.py @@ -58,7 +58,6 @@ def decode(self, serialized_example): # Convert from a scalar string tensor to a float32 tensor with shape image_decoded = tf.image.decode_png(features['image/encoded'], channels=3) image = tf.image.convert_image_dtype(image_decoded, dtype=tf.float32) - image = tf.image.resize(image, [self.resize_h, self.resize_w]) # Convert label from a scalar uint8 tensor to an int32 scalar. label = tf.cast(features['image/class/label'], tf.int64) @@ -83,8 +82,12 @@ def eval(self, filename, image, label): def tencrop(self, filename, image, label): """Prepare one image for TenCrop """ - flip_mode = random.randint(0, 1) + # Resize the image to the height and width. + image = tf.expand_dims(image, 0) + image = tf.image.resize(image, [self.resize_h, self.resize_w]) + image = tf.squeeze(image, [0]) + flip_mode = random.randint(0, 1) images = [] for i in range(5): image = tf.random_crop(image, [RANDOM_CROP_SIZE, RANDOM_CROP_SIZE, 3])