diff --git a/imgaug/augmenters/blur.py b/imgaug/augmenters/blur.py index afaffa88b..e48922e6b 100644 --- a/imgaug/augmenters/blur.py +++ b/imgaug/augmenters/blur.py @@ -295,7 +295,12 @@ def _augment_images(self, images, random_state, parents, hooks): ki = samples[i] if ki > 1: ki = ki + 1 if ki % 2 == 0 else ki - result[i] = cv2.medianBlur(result[i], ki) + image_aug = cv2.medianBlur(result[i], ki) + # cv2.medianBlur() removes channel axis for single-channel + # images + if image_aug.ndim == 2: + image_aug = image_aug[..., np.newaxis] + result[i] = image_aug return result def _augment_keypoints(self, keypoints_on_images, random_state, parents, hooks):