Skip to content

Commit

Permalink
Merge branch 'unit-tests-for-channel-sparse-convolution' into 'dev'
Browse files Browse the repository at this point in the history
Added shape tests for channel sparse convolution

See merge request CMIC/NiftyNet!211
  • Loading branch information
eligibson committed Feb 21, 2018
2 parents b2c7fcd + 368f759 commit 1f740b5
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 2 deletions.
11 changes: 9 additions & 2 deletions niftynet/layer/channel_sparse_convolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,10 +150,17 @@ def layer_op(self,input_tensor,input_mask,output_mask):
'w', shape=w_full_size,
initializer=self.initializers['w'],
regularizer=self.regularizers['w'])
assert(spatial_rank in [2,3])
if spatial_rank==2:
transpositions = [[3,2,1,0],[1,0,2,3],[3,2,0,1]]
else:
transpositions = [[4,3,2,1,0],[1,0,2,3,4],[4,3,2,0,1]]

sparse_kernel = tf.transpose(tf.boolean_mask(
tf.transpose(tf.boolean_mask(
tf.transpose(conv_kernel,[4,3,2,1,0]),
_output_mask),[1,0,2,3,4]),_input_mask),[4,3,2,0,1])
tf.transpose(conv_kernel,transpositions[0]),
_output_mask),transpositions[1]),_input_mask),
transpositions[2])
output_tensor = tf.nn.convolution(input=input_tensor,
filter=sparse_kernel,
strides=full_stride,
Expand Down
73 changes: 73 additions & 0 deletions tests/channel_sparse_convolution_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
from __future__ import absolute_import, print_function

import tensorflow as tf
import numpy as np

from niftynet.layer.channel_sparse_convolution import ChannelSparseConvolutionalLayer

class ChannelSparseConvolutionalLayerTest(tf.test.TestCase):
def test_3d_shape(self):
x = tf.random_normal(shape=[2,4,5,6,4])
conv1 = ChannelSparseConvolutionalLayer(4)
conv2 = ChannelSparseConvolutionalLayer(8,kernel_size=[1,1,3])
conv3 = ChannelSparseConvolutionalLayer(4, acti_func='relu')
conv4 = ChannelSparseConvolutionalLayer(8, with_bn=False)
conv5 = ChannelSparseConvolutionalLayer(4, with_bias=True)
x1, mask1=conv1(x, None, True, 1.)
x2, mask2=conv2(x1, mask1, True, 1.)
x3, mask3=conv3(x2, mask2, True, .5)
x4, mask4=conv4(x3, mask3, True, .75)
x5, mask5=conv5(x4, mask4, True, 1.)

with self.test_session() as sess:
sess.run(tf.global_variables_initializer())
out1, out2, out3, out4, out5 = sess.run([x1,x2,x3,x4,x5])
self.assertAllClose([2,4,5,6,4], out1.shape)
self.assertAllClose([2,4,5,6,8], out2.shape)
self.assertAllClose([2,4,5,6,2], out3.shape)
self.assertAllClose([2,4,5,6,6], out4.shape)
self.assertAllClose([2,4,5,6,4], out5.shape)

def test_2d_shape(self):
x = tf.random_normal(shape=[2,4,5,4])
conv1 = ChannelSparseConvolutionalLayer(4)
conv2 = ChannelSparseConvolutionalLayer(8,kernel_size=[1,1,3])
conv3 = ChannelSparseConvolutionalLayer(4, acti_func='relu')
conv4 = ChannelSparseConvolutionalLayer(8, with_bn=False)
conv5 = ChannelSparseConvolutionalLayer(4, with_bias=True)
x1, mask1=conv1(x, None, True, 1.)
x2, mask2=conv2(x1, mask1, True, 1.)
x3, mask3=conv3(x2, mask2, True, .5)
x4, mask4=conv4(x3, mask3, True, .75)
x5, mask5=conv5(x4, mask4, True, 1.)

with self.test_session() as sess:
sess.run(tf.global_variables_initializer())
out1, out2, out3, out4, out5 = sess.run([x1,x2,x3,x4,x5])
self.assertAllClose([2,4,5,4], out1.shape)
self.assertAllClose([2,4,5,8], out2.shape)
self.assertAllClose([2,4,5,2], out3.shape)
self.assertAllClose([2,4,5,6], out4.shape)
self.assertAllClose([2,4,5,4], out5.shape)

def test_masks(self):
x = tf.random_normal(shape=[2,4,5,4])
conv1 = ChannelSparseConvolutionalLayer(10)
conv2 = ChannelSparseConvolutionalLayer(10)
conv3 = ChannelSparseConvolutionalLayer(10)
conv4 = ChannelSparseConvolutionalLayer(10)
x1, mask1=conv1(x, None, True, 1.)
x2, mask2=conv2(x1, mask1, True, .5)
x3, mask3=conv3(x2, mask2, True, .2)
x4, mask4=conv4(x3, mask3, True, 1.)

with self.test_session() as sess:
sess.run(tf.global_variables_initializer())
out1, out2, out3, out4 = sess.run([mask1, mask2, mask3, mask4])
self.assertAllClose([10, 5, 2, 10], [np.sum(out1),
np.sum(out2),
np.sum(out3),
np.sum(out4)])

if __name__ == "__main__":
tf.test.main()

0 comments on commit 1f740b5

Please sign in to comment.