Skip to content

Commit

Permalink
Update pad.py to include reflective padding
Browse files Browse the repository at this point in the history
  • Loading branch information
lmanan authored and pattonw committed Dec 19, 2023
1 parent b6c425f commit a7503d7
Showing 1 changed file with 35 additions and 8 deletions.
43 changes: 35 additions & 8 deletions gunpowder/nodes/pad.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,22 @@ class Pad(BatchFilter):
a coordinate, this amount will be added to the ROI in the positive
and negative direction.
mode (string):
One of 'constant' or 'reflect'.
Default is 'constant'
value (scalar or ``None``):
The value to report inside the padding. If not given, 0 is used.
Only used in case of 'constant' mode.
Only used for :class:`Array<Arrays>`.
"""

def __init__(self, key, size, value=None):
def __init__(self, key, size, mode="constant", value=None):
self.key = key
self.size = size
self.mode = mode
self.value = value

def setup(self):
Expand Down Expand Up @@ -119,18 +126,38 @@ def __expand(self, a, from_roi, to_roi, value):

num_channels = len(a.shape) - from_roi.dims
channel_shapes = a.shape[:num_channels]

b = np.zeros(channel_shapes + to_roi.shape, dtype=a.dtype)
if value != 0:
b[:] = value

shift = -to_roi.offset
if self.mode == "constant":
if value != 0:
b[:] = value
elif self.mode == "reflect":
if a.shape == b.shape:
pass # handled later
else:
diff = Coordinate(b.shape) - Coordinate(a.shape)
if diff.ndim == 3: # (C Y X)
b[:, : diff[1], diff[2] :] = a[:, : diff[1], :][:, ::-1, :] # Y
b[:, diff[1] :, : diff[2]] = a[:, :, : diff[2]][:, :, ::-1] # X
b[:, : diff[1], : diff[2]] = a[:, : diff[1], : diff[2]][
:, ::-1, ::-1
]
elif diff.ndim == 4: # (C Z Y X)
b[:, : diff[1], diff[2] :, diff[3] :] = a[:, : diff[1], :, :][
:, ::-1, :, :
] # Z
b[:, diff[1] :, : diff[2], diff[3] :] = a[:, :, : diff[2], :][
:, :, ::-1, :
] # Y
b[:, diff[1] :, diff[2] :, : diff[3]] = a[:, :, :, : diff[3]][
:, :, :, ::-1
] # X
b[:, : diff[1], : diff[2], : diff[3]] = a[
:, : diff[1], : diff[2], : diff[3]
][:, ::-1, ::-1, ::-1]
logger.debug("shifting 'from' by " + str(shift))
a_in_b = from_roi.shift(shift).to_slices()

logger.debug("target shape is " + str(b.shape))
logger.debug("target slice is " + str(a_in_b))

b[(slice(None),) * num_channels + a_in_b] = a

return b

0 comments on commit a7503d7

Please sign in to comment.