diff --git a/niftynet/network/resnet.py b/niftynet/network/resnet.py index c2cf5996..8133f259 100755 --- a/niftynet/network/resnet.py +++ b/niftynet/network/resnet.py @@ -7,16 +7,15 @@ import tensorflow as tf from niftynet.layer import layer_util -from niftynet.layer.bn import BNLayer -from niftynet.layer.fully_connected import FCLayer from niftynet.layer.base_layer import TrainableLayer +from niftynet.layer.bn import BNLayer from niftynet.layer.convolution import ConvolutionalLayer -from niftynet.layer.deconvolution import DeconvLayer -from niftynet.layer.elementwise import ElementwiseLayer +from niftynet.layer.fully_connected import FCLayer from niftynet.network.base_net import BaseNet -from niftynet.utilities.util_common import look_up_operations ResNetDesc = namedtuple('ResNetDesc', ['bn', 'fc', 'conv1', 'blocks']) + + class ResNet(BaseNet): """ ### Description @@ -37,11 +36,10 @@ class ResNet(BaseNet): ### Constraints """ - def __init__(self, num_classes, - n_features = [16, 64, 128, 256], - n_blocks_per_resolution = 10, + n_features=[16, 64, 128, 256], + n_blocks_per_resolution=10, w_initializer=None, w_regularizer=None, b_initializer=None, @@ -61,14 +59,13 @@ def __init__(self, :param name: layer name """ - super(ResNet, self).__init__( - num_classes=num_classes, - w_initializer=w_initializer, - w_regularizer=w_regularizer, - b_initializer=b_initializer, - b_regularizer=b_regularizer, - acti_func=acti_func, - name=name) + super(ResNet, self).__init__(num_classes=num_classes, + w_initializer=w_initializer, + w_regularizer=w_regularizer, + b_initializer=b_initializer, + b_regularizer=b_regularizer, + acti_func=acti_func, + name=name) self.n_features = n_features self.n_blocks_per_resolution = n_blocks_per_resolution @@ -85,14 +82,21 @@ def create(self): :return: tuple with batch norm layer, fully connected layer, first conv layer and all residual blocks """ - bn=BNLayer() - fc=FCLayer(self.num_classes) - conv1=self.Conv(self.n_features[0], acti_func=None, feature_normalization=None) - blocks=[] - blocks+=[DownResBlock(self.n_features[1], self.n_blocks_per_resolution, 1, self.Conv)] + bn = BNLayer() + fc = FCLayer(self.num_classes) + conv1 = self.Conv(self.n_features[0], + acti_func=None, + feature_normalization=None) + blocks = [] + blocks += [ + DownResBlock(self.n_features[1], self.n_blocks_per_resolution, 1, + self.Conv) + ] for n in self.n_features[2:]: - blocks+=[DownResBlock(n, self.n_blocks_per_resolution, 2, self.Conv)] - return ResNetDesc(bn=bn,fc=fc,conv1=conv1,blocks=blocks) + blocks += [ + DownResBlock(n, self.n_blocks_per_resolution, 2, self.Conv) + ] + return ResNetDesc(bn=bn, fc=fc, conv1=conv1, blocks=blocks) def layer_op(self, images, is_training=True, **unused_kwargs): """ @@ -106,13 +110,19 @@ def layer_op(self, images, is_training=True, **unused_kwargs): out = layers.conv1(images, is_training) for block in layers.blocks: out = block(out, is_training) - out = tf.reduce_mean(tf.nn.relu(layers.bn(out, is_training)),axis=[1,2,3]) + + spatial_rank = layer_util.infer_spatial_rank(out) + axis_to_avg = [dim + 1 for dim in range(spatial_rank)] + out = tf.reduce_mean(tf.nn.relu(layers.bn(out, is_training)), + axis=axis_to_avg) return layers.fc(out) - BottleneckBlockDesc1 = namedtuple('BottleneckBlockDesc1', ['conv']) -BottleneckBlockDesc2 = namedtuple('BottleneckBlockDesc2', ['common_bn', 'conv', 'conv_shortcut']) +BottleneckBlockDesc2 = namedtuple('BottleneckBlockDesc2', + ['common_bn', 'conv', 'conv_shortcut']) + + class BottleneckBlock(TrainableLayer): def __init__(self, n_output_chns, stride, Conv, name='bottleneck'): """ @@ -123,11 +133,11 @@ def __init__(self, n_output_chns, stride, Conv, name='bottleneck'): :param name: layer name """ self.n_output_chns = n_output_chns - self.stride=stride + self.stride = stride self.bottle_neck_chns = n_output_chns // 4 self.Conv = Conv super(BottleneckBlock, self).__init__(name=name) - + def create(self, input_chns): """ @@ -135,21 +145,29 @@ def create(self, input_chns): :return: tuple, with series of convolutional layers """ if self.n_output_chns == input_chns: - b1 = self.Conv(self.bottle_neck_chns, kernel_size=1, + b1 = self.Conv(self.bottle_neck_chns, + kernel_size=1, stride=self.stride) b2 = self.Conv(self.bottle_neck_chns, kernel_size=3) b3 = self.Conv(self.n_output_chns, 1) return BottleneckBlockDesc1(conv=[b1, b2, b3]) else: b1 = BNLayer() - b2 = self.Conv(self.bottle_neck_chns,kernel_size=1, - stride=self.stride, acti_func=None, feature_normalization=None) - b3 = self.Conv(self.bottle_neck_chns,kernel_size=3) - b4 = self.Conv(self.n_output_chns,kernel_size=1) - b5 = self.Conv(self.n_output_chns,kernel_size=1, - stride=self.stride, acti_func=None,feature_normalization=None) - return BottleneckBlockDesc2(common_bn=b1, conv=[b2, b3, b4], - conv_shortcut=b5) + b2 = self.Conv(self.bottle_neck_chns, + kernel_size=1, + stride=self.stride, + acti_func=None, + feature_normalization=None) + b3 = self.Conv(self.bottle_neck_chns, kernel_size=3) + b4 = self.Conv(self.n_output_chns, kernel_size=1) + b5 = self.Conv(self.n_output_chns, + kernel_size=1, + stride=self.stride, + acti_func=None, + feature_normalization=None) + return BottleneckBlockDesc2(common_bn=b1, + conv=[b2, b3, b4], + conv_shortcut=b5) def layer_op(self, images, is_training=True): """ @@ -160,20 +178,23 @@ def layer_op(self, images, is_training=True): """ layers = self.create(images.shape[-1]) if self.n_output_chns == images.shape[-1]: - out=layers.conv[0](images, is_training) - out=layers.conv[1](out, is_training) - out=layers.conv[2](out, is_training) - out = out+images + out = layers.conv[0](images, is_training) + out = layers.conv[1](out, is_training) + out = layers.conv[2](out, is_training) + out = out + images else: tmp = tf.nn.relu(layers.common_bn(images, is_training)) - out=layers.conv[0](tmp, is_training) - out=layers.conv[1](out, is_training) - out=layers.conv[2](out, is_training) + out = layers.conv[0](tmp, is_training) + out = layers.conv[1](out, is_training) + out = layers.conv[2](out, is_training) out = layers.conv_shortcut(tmp, is_training) + out print(out.shape) return out + DownResBlockDesc = namedtuple('DownResBlockDesc', ['blocks']) + + class DownResBlock(TrainableLayer): def __init__(self, n_output_chns, count, stride, Conv, name='downres'): """ @@ -187,20 +208,20 @@ def __init__(self, n_output_chns, count, stride, Conv, name='downres'): self.count = count self.stride = stride self.n_output_chns = n_output_chns - self.Conv=Conv + self.Conv = Conv super(DownResBlock, self).__init__(name=name) - + def create(self): """ :return: tuple, containing all the Bottleneck blocks composing the DownRes block """ - blocks=[] - blocks+=[BottleneckBlock(self.n_output_chns, self.stride, self.Conv)] - for it in range(1,self.count): - blocks+=[BottleneckBlock(self.n_output_chns, 1, self.Conv)] + blocks = [] + blocks += [BottleneckBlock(self.n_output_chns, self.stride, self.Conv)] + for it in range(1, self.count): + blocks += [BottleneckBlock(self.n_output_chns, 1, self.Conv)] return DownResBlockDesc(blocks=blocks) - + def layer_op(self, images, is_training): """ @@ -211,6 +232,5 @@ def layer_op(self, images, is_training): layers = self.create() out = images for l in layers.blocks: - out=l(out,is_training) + out = l(out, is_training) return out - diff --git a/tests/resnet_test.py b/tests/resnet_test.py new file mode 100755 index 00000000..7cbc960f --- /dev/null +++ b/tests/resnet_test.py @@ -0,0 +1,68 @@ +from __future__ import absolute_import, print_function + +import unittest + +import tensorflow as tf +from tensorflow.contrib.layers.python.layers import regularizers + +from niftynet.network.resnet import ResNet +from tests.niftynet_testcase import NiftyNetTestCase + +class ResNet3DTest(NiftyNetTestCase): + def test_3d_shape(self): + input_shape = (2, 8, 16, 32, 1) + x = tf.ones(input_shape) + + resnet_instance = ResNet(num_classes=160) + out = resnet_instance(x, is_training=True) + print(resnet_instance.num_trainable_params()) + + with self.cached_session() as sess: + sess.run(tf.global_variables_initializer()) + out = sess.run(out) + self.assertAllClose((2, 160), out.shape) + + def test_2d_shape(self): + input_shape = (2, 8, 16, 1) + x = tf.ones(input_shape) + + resnet_instance = ResNet(num_classes=160) + out = resnet_instance(x, is_training=True) + print(resnet_instance.num_trainable_params()) + + with self.cached_session() as sess: + sess.run(tf.global_variables_initializer()) + out = sess.run(out) + self.assertAllClose((2, 160), out.shape) + + def test_3d_reg_shape(self): + input_shape = (2, 8, 16, 24, 1) + x = tf.ones(input_shape) + + resnet_instance = ResNet(num_classes=160, + w_regularizer=regularizers.l2_regularizer(0.4)) + out = resnet_instance(x, is_training=True) + print(resnet_instance.num_trainable_params()) + + with self.cached_session() as sess: + sess.run(tf.global_variables_initializer()) + out = sess.run(out) + self.assertAllClose((2, 160), out.shape) + + def test_2d_reg_shape(self): + input_shape = (2, 8, 16, 1) + x = tf.ones(input_shape) + + resnet_instance = ResNet(num_classes=160, + w_regularizer=regularizers.l2_regularizer(0.4)) + out = resnet_instance(x, is_training=True) + print(resnet_instance.num_trainable_params()) + + with self.cached_session() as sess: + sess.run(tf.global_variables_initializer()) + out = sess.run(out) + self.assertAllClose((2, 160), out.shape) + + +if __name__ == "__main__": + tf.test.main()