diff --git a/gunpowder/nodes/pad.py b/gunpowder/nodes/pad.py index 6bbfdc58..14bc644a 100644 --- a/gunpowder/nodes/pad.py +++ b/gunpowder/nodes/pad.py @@ -27,15 +27,22 @@ class Pad(BatchFilter): a coordinate, this amount will be added to the ROI in the positive and negative direction. + mode (string): + + One of 'constant' or 'reflect'. + Default is 'constant' + value (scalar or ``None``): The value to report inside the padding. If not given, 0 is used. + Only used in case of 'constant' mode. Only used for :class:`Array`. """ - def __init__(self, key, size, value=None): + def __init__(self, key, size, mode="constant", value=None): self.key = key self.size = size + self.mode = mode self.value = value def setup(self): @@ -119,18 +126,38 @@ def __expand(self, a, from_roi, to_roi, value): num_channels = len(a.shape) - from_roi.dims channel_shapes = a.shape[:num_channels] - b = np.zeros(channel_shapes + to_roi.shape, dtype=a.dtype) - if value != 0: - b[:] = value - shift = -to_roi.offset + if self.mode == "constant": + if value != 0: + b[:] = value + elif self.mode == "reflect": + if a.shape == b.shape: + pass # handled later + else: + diff = Coordinate(b.shape) - Coordinate(a.shape) + if diff.ndim == 3: # (C Y X) + b[:, : diff[1], diff[2] :] = a[:, : diff[1], :][:, ::-1, :] # Y + b[:, diff[1] :, : diff[2]] = a[:, :, : diff[2]][:, :, ::-1] # X + b[:, : diff[1], : diff[2]] = a[:, : diff[1], : diff[2]][ + :, ::-1, ::-1 + ] + elif diff.ndim == 4: # (C Z Y X) + b[:, : diff[1], diff[2] :, diff[3] :] = a[:, : diff[1], :, :][ + :, ::-1, :, : + ] # Z + b[:, diff[1] :, : diff[2], diff[3] :] = a[:, :, : diff[2], :][ + :, :, ::-1, : + ] # Y + b[:, diff[1] :, diff[2] :, : diff[3]] = a[:, :, :, : diff[3]][ + :, :, :, ::-1 + ] # X + b[:, : diff[1], : diff[2], : diff[3]] = a[ + :, : diff[1], : diff[2], : diff[3] + ][:, ::-1, ::-1, ::-1] logger.debug("shifting 'from' by " + str(shift)) a_in_b = from_roi.shift(shift).to_slices() - logger.debug("target shape is " + str(b.shape)) logger.debug("target slice is " + str(a_in_b)) - b[(slice(None),) * num_channels + a_in_b] = a - return b