Skip to content

Commit

Permalink
Merge branch 'fix_densevnet' into 'dev'
Browse files Browse the repository at this point in the history
Fix densevnet

See merge request CMIC/NiftyNet!200
  • Loading branch information
eligibson committed Feb 8, 2018
2 parents 1382547 + 67d2f92 commit e429b23
Show file tree
Hide file tree
Showing 2 changed files with 130 additions and 58 deletions.
146 changes: 88 additions & 58 deletions niftynet/network/dense_vnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@
from __future__ import absolute_import, print_function

from collections import namedtuple
import abc

import tensorflow as tf

from niftynet.layer import layer_util
from niftynet.layer.base_layer import TrainableLayer
from niftynet.layer.convolution import ConvolutionalLayer
from niftynet.layer.channel_sparse_convolution import ChannelSparseConvolutionalLayer
from niftynet.layer.channel_sparse_convolution \
import ChannelSparseConvolutionalLayer
from niftynet.layer.bn import BNLayer
from niftynet.layer.spatial_transformer import ResamplerLayer
from niftynet.layer.grid_warper import AffineGridWarperLayer
Expand Down Expand Up @@ -116,27 +118,27 @@ def __init__(self,
'Image coordinate augmentation is not yet implemented')

def create_network(self):
hp = self.hyperparameters
hyper = self.hyperparameters

# Initial Convolution
net_initial_conv = ConvolutionalLayer(
hp['n_input_channels'][0],
hyper['n_input_channels'][0],
kernel_size=5, stride=2
)

# Dense Block Params
downsample_channels = list(hp['n_input_channels'][1:]) + [None]
num_blocks = len(hp["n_dense_channels"])
downsample_channels = list(hyper['n_input_channels'][1:]) + [None]
num_blocks = len(hyper["n_dense_channels"])
use_bdo = self.architecture_parameters['use_bdo']

# Create DenseBlocks
net_dense_vblocks = []

for idx in range(num_blocks):
dense_ch = hp["n_dense_channels"][idx] # Number or dense channels
seg_ch = hp["n_seg_channels"][idx] # Number of segmentation ch
down_ch = downsample_channels[idx] # Number of downsampling ch
dil_rate = hp["dilation_rates"][idx] # Dilation rate
dense_ch = hyper["n_dense_channels"][idx] # Num dense channels
seg_ch = hyper["n_seg_channels"][idx] # Num segmentation ch
down_ch = downsample_channels[idx] # Num of downsampling ch
dil_rate = hyper["dilation_rates"][idx] # Dilation rate

# Dense feature block
dblock = DenseFeatureStackBlockWithSkipAndDownsample(
Expand All @@ -148,7 +150,7 @@ def create_network(self):

# Segmentation
net_seg_layer = ConvolutionalLayer(
self.num_classes, kernel_size=hp['final_kernel'],
self.num_classes, kernel_size=hyper['final_kernel'],
with_bn=False, with_bias=True
)

Expand All @@ -163,10 +165,16 @@ def downsample_input(self, input_tensor, n_spatial_dims):
d_size2 = (1,) + (2,) * n_spatial_dims + (1,)

# Downsample input
return tf.nn.avg_pool(input_tensor, d_size1, d_size2, 'SAME')
if n_spatial_dims == 2:
return tf.nn.avg_pool(input_tensor, d_size1, d_size2, 'SAME')
elif n_spatial_dims == 3:
return tf.nn.avg_pool3d(input_tensor, d_size1, d_size2, 'SAME')
else:
raise NotImplementedError(
'Downsampling only supports 2D and 3D images')

def layer_op(self, input_tensor, is_training, layer_id=-1):
hp = self.hyperparameters
hyper = self.hyperparameters

# Initialize DenseVNet network layers
net = self.create_network()
Expand All @@ -182,10 +190,10 @@ def layer_op(self, input_tensor, is_training, layer_id=-1):
n_spatial_dims = input_tensor.shape.ndims - 2

# Quick access to hyperparams
pkeep = hp['p_channels_selected']
pkeep = hyper['p_channels_selected']

# Validate input dimension with dilation rates
modulo = 2 ** (len(hp['dilation_rates']))
modulo = 2 ** (len(hyper['dilation_rates']))
assert layer_util.check_spatial_dims(input_tensor,
lambda x: x % modulo == 0)

Expand All @@ -194,16 +202,16 @@ def layer_op(self, input_tensor, is_training, layer_id=-1):
#

# On the fly data augmentation
if is_training and hp['augmentation_scale'] > 0:
if is_training and hyper['augmentation_scale'] > 0:
if n_spatial_dims == 2:
augmentation_class = Affine2DAugmentationLayer
elif n_spatial_dims == 3
elif n_spatial_dims == 3:
augmentation_class = Affine3DAugmentationLayer
else:
raise NotImplementedError(
'Affine augmentation only supports 2D and 3D images')

augment_layer = augmentation_class(hp['augmentation_scale'],
augment_layer = augmentation_class(hyper['augmentation_scale'],
'LINEAR', 'ZERO')
input_tensor = augment_layer(input_tensor)

Expand Down Expand Up @@ -255,7 +263,7 @@ def layer_op(self, input_tensor, is_training, layer_id=-1):
seg_output += xyz_prior

# Invert augmentation if any
if is_training and hp['augmentation_scale'] > 0:
if is_training and hyper['augmentation_scale'] > 0:
inverse_aug = augment_layer.inverse()
seg_output = inverse_aug(seg_output)

Expand All @@ -267,13 +275,22 @@ def layer_op(self, input_tensor, is_training, layer_id=-1):
seg_summary = seg_argmax * (255. / self.num_classes - 1)

# Image Summary
m, v = tf.nn.moments(input_tensor, axes=[1, 2, 3], keep_dims=True)
timg = (tf.to_float(input_tensor - m) / (tf.sqrt(v) * 2.) + 1.) * 127.
img_summary = tf.minimum(255., tf.maximum(0., timg))

# Show summaries
image3_axial('imgseg', tf.concat([img_summary, seg_summary], 1),
5, [tf.GraphKeys.SUMMARIES])
norm_axes = list(range(1, n_spatial_dims+1))
mean, var = tf.nn.moments(input_tensor, axes=norm_axes, keep_dims=True)
timg = tf.to_float(input_tensor - mean) / (tf.sqrt(var) * 2.)
timg = (timg + 1.) * 127.
single_channel = tf.reduce_mean(timg, axis=-1, keep_dims=True)
img_summary = tf.minimum(255., tf.maximum(0., single_channel))
if n_spatial_dims == 2:
tf.summary.image('imgseg', tf.concat([img_summary, seg_summary], 1),
5, [tf.GraphKeys.SUMMARIES])
elif n_spatial_dims == 3:
# Show summaries
image3_axial('imgseg', tf.concat([img_summary, seg_summary], 1),
5, [tf.GraphKeys.SUMMARIES])
else:
raise NotImplementedError(
'Image Summary only supports 2D and 3D images')

return seg_output

Expand All @@ -287,7 +304,8 @@ def image_resize(image, output_size):
return tf.image.resize_images(image, output_size)
first_reshape = tf.reshape(image, input_size[0:3] + [-1])
first_resize = tf.image.resize_images(first_reshape, output_size[0:2])
second_shape = input_size[:1] + [output_size[0] * output_size[1], input_size[3], -1]
second_shape = input_size[:1] + [output_size[0] * output_size[1],
input_size[3], -1]
second_reshape = tf.reshape(first_resize, second_shape)
second_resize = tf.image.resize_images(second_reshape,
[second_shape[1], output_size[2]])
Expand Down Expand Up @@ -473,18 +491,19 @@ def layer_op(self, input_tensor, is_training=None, keep_prob=None):


class AffineAugmentationLayer(TrainableLayer):
""" This layer applies a small random (per-iteration) affine
""" This layer applies a small random (per-iteration) affine
transformation to an image. The distribution of transformations
generally results in scaling the image up, with minimal sampling
outside the original image. """
outside the original image."""
__metaclass__ = abc.ABCMeta

def __init__(self, scale, interpolation,
boundary, transform_func=None,
name='AffineAugmentation'):
""""
scale denotes how extreme the perturbation is, with 1. meaning
no perturbation and 0.5 giving larger perturbations.
interpolation denotes the image value interpolation used by
interpolation denotes the image value interpolation used by
the resampling
boundary denotes the boundary handling used by the resampling
transform_func should be a function returning a relative
Expand All @@ -502,12 +521,12 @@ def __init__(self, scale, interpolation,

def random_transform(self, batch_size):
if self._transform is None:
corners = self.get_corners()
corners_ = self.get_corners()

_batch_ones = tf.ones([batch_size, len(corners[0], 1])
_batch_ones = tf.ones([batch_size, len(corners_[0]), 1])

corners = tf.tile(corners, [batch_size, 1, 1])
random_size = [batch_size, len(corners[0], len(corners[0][0]]
corners = tf.tile(corners_, [batch_size, 1, 1])
random_size = [batch_size, len(corners_[0]), len(corners_[0][0])]
random_scale = tf.random_uniform(random_size, 0, self.scale)
corners2 = corners * (1 - random_scale)
corners_homog = tf.concat([corners, _batch_ones], 2)
Expand All @@ -522,21 +541,22 @@ def inverse_transform(self, batch_size):
return tf.matrix_inverse(self.transform_func(batch_size))

def layer_op(self, input_tensor):
sz = input_tensor.shape.as_list()
grid_warper = AffineGridWarperLayer(sz[1:-1],
sz[1:-1])
size = input_tensor.shape.as_list()
grid_warper = AffineGridWarperLayer(size[1:-1],
size[1:-1])

resampler = ResamplerLayer(interpolation=self.interpolation,
boundary=self.boundary)

relative_transform = self.transform_func(sz[0])
to_relative = tf.tile(self.get_transform_to_relative, [sz[0], 1, 1])
relative_transform = self.transform_func(size[0])
to_relative = tf.tile(self.get_tfm_to_relative(size), [size[0], 1, 1])

from_relative = tf.matrix_inverse(to_relative)
voxel_transform = tf.matmul(from_relative,
tf.matmul(relative_transform, to_relative))
warp_parameters = tf.reshape(voxel_transform[:, 0:3, 0:4],
[sz[0], 12])
dims = self.spatial_dims
warp_parameters = tf.reshape(voxel_transform[:, 0:dims, 0:dims + 1],
[size[0], dims * (dims + 1)])
grid = grid_warper(warp_parameters)
return resampler(input_tensor, grid)

Expand All @@ -551,36 +571,46 @@ def inverse(self, interpolation=None, boundary=None):
boundary,
self.inverse_transform)

class Affine2DAugmentationLayer(TrainableLayer):
@abc.abstractproperty
def spatial_dims(self):
raise NotImplementedError

@abc.abstractmethod
def get_corners(self):
raise NotImplementedError

@abc.abstractmethod
def get_tfm_to_relative(self):
raise NotImplementedError

class Affine2DAugmentationLayer(AffineAugmentationLayer):
""" Specialization of AffineAugmentationLayer for 2D coordinates """
spatial_dims = 2
def get_corners(self):
return [
[[-1., -1.],
return [[[-1., -1.],
[-1., 1.],
[1., -1.],
[1., 1.]]
]
[1., 1.]]]

def get_transform_to_relative(self, sz):
return [[[2./(sz[1]-1), 0., -1.],
[0., 2. / (sz[2] - 1), -1.],
def get_tfm_to_relative(self, size):
return [[[2./(size[1]-1), 0., -1.],
[0., 2. / (size[2] - 1), -1.],
[0., 0., 1.]]]

class Affine3DAugmentationLayer(TrainableLayer):
class Affine3DAugmentationLayer(AffineAugmentationLayer):
""" Specialization of AffineAugmentationLayer for 3D coordinates """
spatial_dims = 3
def get_corners(self):
return [
[[-1., -1., -1.],
return [[[-1., -1., -1.],
[-1., -1., 1.],
[-1., 1., -1.],
[-1., 1., 1.],
[1., -1., -1.],
[1., -1., 1.],
[1., 1., -1.],
[1., 1., 1.]]
]
def get_transform_to_relative(self, sz):
return [[[2./(sz[1]-1), 0., 0., -1.],
[0., 2. / (sz[2] - 1), 0., -1.],
[0., 0., 2. / (sz[3] - 1), -1.],
[0., 0., 0., 1.]]]
[1., 1., 1.]]]
def get_tfm_to_relative(self, size):
return [[[2./(size[1]-1), 0., 0., -1.],
[0., 2. / (size[2] - 1), 0., -1.],
[0., 0., 2. / (size[3] - 1), -1.],
[0., 0., 0., 1.]]]
42 changes: 42 additions & 0 deletions tests/dense_vnet_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from __future__ import absolute_import, print_function

import unittest

import os
import tensorflow as tf
from tensorflow.contrib.layers.python.layers import regularizers

from niftynet.network.dense_vnet import DenseVNet

@unittest.skipIf(os.environ.get('QUICKTEST', "").lower() == "true", 'Skipping slow tests')
class DenseVNetTest(tf.test.TestCase):
def test_3d_shape(self):
input_shape = (2, 72, 72, 72, 3)
x = tf.ones(input_shape)

dense_vnet_instance = DenseVNet(
num_classes=2)
out = dense_vnet_instance(x, is_training=True)
# print(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))

with self.test_session() as sess:
sess.run(tf.global_variables_initializer())
out = sess.run(out)
self.assertAllClose((2, 72, 72, 72, 2), out.shape)

def test_2d_shape(self):
input_shape = (2, 72, 72, 3)
x = tf.ones(input_shape)

dense_vnet_instance = DenseVNet(
num_classes=2)
out = dense_vnet_instance(x, is_training=True)
# print(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))

with self.test_session() as sess:
sess.run(tf.global_variables_initializer())
out = sess.run(out)
self.assertAllClose((2, 72, 72, 2), out.shape)

if __name__ == "__main__":
tf.test.main()

0 comments on commit e429b23

Please sign in to comment.