Skip to content

Commit

Permalink
Added class SpatialSELayer
Browse files Browse the repository at this point in the history
  • Loading branch information
pritesh-mehta committed Aug 31, 2018
1 parent 7ad4388 commit 3ac6fb2
Showing 1 changed file with 25 additions and 3 deletions.
28 changes: 25 additions & 3 deletions niftynet/layer/squeeze_excitation_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from niftynet.layer.base_layer import Layer
from niftynet.layer.fully_connected import FullyConnectedLayer
from niftynet.layer.convolution import ConvolutionalLayer
from niftynet.utilities.util_common import look_up_operations

SUPPORTED_OP = set(['AVG', 'MAX'])
Expand All @@ -14,7 +15,7 @@ class ChannelSELayer(Layer):
def __init__(self,
func='AVG',
reduction_ratio=16,
name='squeeze_excitation'):
name='channel_squeeze_excitation'):
self.func = func.upper()
self.reduction_ratio = reduction_ratio
self.layer_name = '{}_{}'.format(self.func.lower(), name)
Expand All @@ -23,7 +24,7 @@ def __init__(self,
look_up_operations(self.func, SUPPORTED_OP)

def layer_op(self, input_tensor):
# squeeze: global information embedding
# spatial squeeze
input_rank = len(input_tensor.shape)
reduce_indices = list(range(input_rank))[1:-1]
if self.func == 'AVG':
Expand All @@ -33,7 +34,7 @@ def layer_op(self, input_tensor):
else:
raise NotImplementedError("pooling function not supported")

# excitation: adaptive recalibration
# channel excitation
num_channels = int(squeeze_tensor.shape[-1])
reduction_ratio = self.reduction_ratio
if num_channels % reduction_ratio != 0:
Expand Down Expand Up @@ -62,3 +63,24 @@ def layer_op(self, input_tensor):
output_tensor = tf.multiply(input_tensor, fc_out_2)

return output_tensor

class SpatialSELayer(Layer):
def __init__(self,
name='spatial_squeeze_excitation'):
self.layer_name = '{}_{}'.format(self.func.lower(), name)
super(SpatialSELayer, self).__init__(name=self.layer_name)

def layer_op(self, input_tensor):
# channel squeeze
conv = ConvolutionalLayer(n_output_chns=1,
kernel_size=1,
with_bn=False,
acti_func='sigmoid',
name="se_conv")

squeeze_tensor = conv(input_tensor)

# spatial excitation
output_tensor = tf.multiply(input_tensor, squeeze_tensor)

return output_tensor

0 comments on commit 3ac6fb2

Please sign in to comment.