forked from NifTK/NiftyNet
Merge branch 'dev' into 142-csv_reader-nnhack
Showing 217 changed files with 22,273 additions and 1,067 deletions.
demos/Learning_Rate_Decay/Demo_applications/decay_lr_comparison_application.py (87 additions, 0 deletions)
@@ -0,0 +1,87 @@
import tensorflow as tf

from niftynet.application.segmentation_application import \
    SegmentationApplication
from niftynet.engine.application_factory import OptimiserFactory
from niftynet.engine.application_variables import CONSOLE
from niftynet.engine.application_variables import TF_SUMMARIES
from niftynet.layer.loss_segmentation import LossFunction

SUPPORTED_INPUT = set(['image', 'label', 'weight'])


class DecayLearningRateApplication(SegmentationApplication):
    REQUIRED_CONFIG_SECTION = "SEGMENTATION"

    def __init__(self, net_param, action_param, is_training):
        SegmentationApplication.__init__(
            self, net_param, action_param, is_training)
        tf.logging.info('starting decay learning segmentation application')
        self.learning_rate = None
        self.current_lr = action_param.lr
        if self.action_param.validation_every_n > 0:
            raise NotImplementedError("validation process is not implemented "
                                      "in this demo.")

    def connect_data_and_network(self,
                                 outputs_collector=None,
                                 gradients_collector=None):
        data_dict = self.get_sampler()[0][0].pop_batch_op()
        image = tf.cast(data_dict['image'], tf.float32)
        net_out = self.net(image, self.is_training)

        if self.is_training:
            with tf.name_scope('Optimiser'):
                # the learning rate is a placeholder so that a new value
                # can be fed at each iteration (see set_iteration_update)
                self.learning_rate = tf.placeholder(tf.float32, shape=[])
                optimiser_class = OptimiserFactory.create(
                    name=self.action_param.optimiser)
                self.optimiser = optimiser_class.get_instance(
                    learning_rate=self.learning_rate)
            loss_func = LossFunction(
                n_class=self.segmentation_param.num_classes,
                loss_type=self.action_param.loss_type)
            data_loss = loss_func(
                prediction=net_out,
                ground_truth=data_dict.get('label', None),
                weight_map=data_dict.get('weight', None))

            loss = data_loss
            reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)

            if self.net_param.decay > 0.0 and reg_losses:
                reg_loss = tf.reduce_mean(
                    [tf.reduce_mean(reg_loss) for reg_loss in reg_losses])
                loss = data_loss + reg_loss
            grads = self.optimiser.compute_gradients(loss)
            # collecting gradients variables
            gradients_collector.add_to_collection([grads])
            # collecting output variables
            outputs_collector.add_to_collection(
                var=data_loss, name='loss',
                average_over_devices=False, collection=CONSOLE)
            outputs_collector.add_to_collection(
                var=self.learning_rate, name='lr',
                average_over_devices=False, collection=CONSOLE)
            outputs_collector.add_to_collection(
                var=data_loss, name='loss',
                average_over_devices=True, summary_type='scalar',
                collection=TF_SUMMARIES)
        else:
            # converting logits into final output for
            # classification probabilities or argmax classification labels
            SegmentationApplication.connect_data_and_network(
                self, outputs_collector, gradients_collector)

    def set_iteration_update(self, iteration_message):
        """
        This function will be called by the application engine at each
        iteration.
        """
        current_iter = iteration_message.current_iter
        if iteration_message.is_training:
            # divide the learning rate by 1.02 once every 5 iterations
            if current_iter > 0 and current_iter % 5 == 0:
                self.current_lr = self.current_lr / 1.02
            iteration_message.data_feed_dict[self.is_validation] = False
        elif iteration_message.is_validation:
            iteration_message.data_feed_dict[self.is_validation] = True
        iteration_message.data_feed_dict[self.learning_rate] = self.current_lr
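The schedule above divides the learning rate by 1.02 once every five training iterations. As a quick sanity check, here is a minimal standalone sketch of the resulting schedule (plain Python, independent of NiftyNet; `initial_lr` matches the `lr` value in the demo config further down):
```
# Sketch of the learning-rate schedule implemented above:
# the rate is divided by 1.02 once every 5 training iterations.
initial_lr = 0.0001  # the `lr` value from the demo config

def decayed_lr(current_iter, lr0=initial_lr):
    # number of completed 5-iteration blocks; iteration 0 does not decay
    return lr0 / (1.02 ** (current_iter // 5))

for it in (0, 5, 100, 500):
    print(it, decayed_lr(it))
# after max_iter = 500 iterations the rate is lr0 / 1.02**100,
# roughly 14% of its initial value: a gentle, not aggressive, decay
```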
demos/Learning_Rate_Decay/Demo_applications/no_decay_lr_comparison_application.py (85 additions, 0 deletions)
@@ -0,0 +1,85 @@
import tensorflow as tf

from niftynet.application.segmentation_application import \
    SegmentationApplication
from niftynet.engine.application_factory import OptimiserFactory
from niftynet.engine.application_variables import CONSOLE
from niftynet.engine.application_variables import TF_SUMMARIES
from niftynet.layer.loss_segmentation import LossFunction

SUPPORTED_INPUT = set(['image', 'label', 'weight'])


class DecayLearningRateApplication(SegmentationApplication):
    REQUIRED_CONFIG_SECTION = "SEGMENTATION"

    def __init__(self, net_param, action_param, is_training):
        SegmentationApplication.__init__(
            self, net_param, action_param, is_training)
        tf.logging.info('starting decay learning segmentation application')
        self.learning_rate = None
        self.current_lr = action_param.lr
        if self.action_param.validation_every_n > 0:
            raise NotImplementedError("validation process is not implemented "
                                      "in this demo.")

    def connect_data_and_network(self,
                                 outputs_collector=None,
                                 gradients_collector=None):
        data_dict = self.get_sampler()[0][0].pop_batch_op()
        image = tf.cast(data_dict['image'], tf.float32)
        net_out = self.net(image, self.is_training)

        if self.is_training:
            with tf.name_scope('Optimiser'):
                self.learning_rate = tf.placeholder(tf.float32, shape=[])
                optimiser_class = OptimiserFactory.create(
                    name=self.action_param.optimiser)
                self.optimiser = optimiser_class.get_instance(
                    learning_rate=self.learning_rate)
            loss_func = LossFunction(
                n_class=self.segmentation_param.num_classes,
                loss_type=self.action_param.loss_type)
            data_loss = loss_func(
                prediction=net_out,
                ground_truth=data_dict.get('label', None),
                weight_map=data_dict.get('weight', None))

            loss = data_loss
            reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)

            if self.net_param.decay > 0.0 and reg_losses:
                reg_loss = tf.reduce_mean(
                    [tf.reduce_mean(reg_loss) for reg_loss in reg_losses])
                loss = data_loss + reg_loss
            grads = self.optimiser.compute_gradients(loss)
            # collecting gradients variables
            gradients_collector.add_to_collection([grads])
            # collecting output variables
            outputs_collector.add_to_collection(
                var=data_loss, name='dice_loss',
                average_over_devices=False, collection=CONSOLE)
            outputs_collector.add_to_collection(
                var=self.learning_rate, name='lr',
                average_over_devices=False, collection=CONSOLE)
            outputs_collector.add_to_collection(
                var=data_loss, name='dice_loss',
                average_over_devices=True, summary_type='scalar',
                collection=TF_SUMMARIES)
        else:
            # converting logits into final output for
            # classification probabilities or argmax classification labels
            SegmentationApplication.connect_data_and_network(
                self, outputs_collector, gradients_collector)

    def set_iteration_update(self, iteration_message):
        """
        This function will be called by the application engine at each
        iteration.
        """
        current_iter = iteration_message.current_iter
        if iteration_message.is_training:
            # unlike the decay variant, no schedule is applied here:
            # the learning rate stays constant throughout training
            iteration_message.data_feed_dict[self.is_validation] = False
        elif iteration_message.is_validation:
            iteration_message.data_feed_dict[self.is_validation] = True
        iteration_message.data_feed_dict[self.learning_rate] = self.current_lr
demos/Learning_Rate_Decay/Demo_for_learning_rate_decay_application.ipynb (334 additions, 0 deletions)
Large diffs are not rendered by default.
@@ -0,0 +1,29 @@
# Learning rate decay application

This application implements a simple learning rate schedule for segmentation
applications: in the demo applications shown above, the learning rate is
divided by 1.02 once every 5 training iterations.

The concept is general and could be used for other types of application. A
brief demo is provided, which can be run entirely from a Jupyter notebook,
provided a working installation of NiftyNet exists on your system.

The core function is implemented by:

1) Adding a `self.learning_rate` placeholder and connecting it to the network
in the `connect_data_and_network` function.

2) Adding a `self.current_lr` variable to keep track of the current learning rate.

3) Overriding the default `set_iteration_update` function provided in
`BaseApplication`, so that `self.current_lr` is changed according to
`current_iter`.

4) Feeding the `self.current_lr` value to the network: the data feeding
dictionary is updated within the customised `set_iteration_update` function, by
```
iteration_message.data_feed_dict[self.learning_rate] = self.current_lr
```
`iteration_message.data_feed_dict` will be used in
`tf.Session.run(..., feed_dict=iteration_message.data_feed_dict)` by the
engine at each iteration.
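As a minimal illustration of this feeding mechanism, here is a standalone TensorFlow 1.x sketch (this is not NiftyNet code; `lr_placeholder`, `loss` and `train_op` are illustrative names made up for the example):
```
import tensorflow as tf

# Standalone TF 1.x sketch of driving a learning rate through feed_dict;
# the NiftyNet engine does the equivalent with
# iteration_message.data_feed_dict at each iteration.
x = tf.Variable(5.0)
loss = tf.square(x)
lr_placeholder = tf.placeholder(tf.float32, shape=[])  # as in the demo
train_op = tf.train.GradientDescentOptimizer(lr_placeholder).minimize(loss)

current_lr = 0.1
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for step in range(20):
        if step > 0 and step % 5 == 0:
            current_lr /= 1.02  # same decay rule as the demo application
        sess.run(train_op, feed_dict={lr_placeholder: current_lr})
```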
*This demo only supports NiftyNet cloned from [GitHub](https://github.com/NifTK/NiftyNet).*
Further demos and trained models can be found at the [NiftyNet model zoo](https://github.com/NifTK/NiftyNetModelZoo/blob/master/dense_vnet_abdominal_ct_model_zoo.md).
(file name not rendered in this view: 87 additions, 0 deletions)
@@ -0,0 +1,87 @@
import tensorflow as tf

from niftynet.application.segmentation_application import \
    SegmentationApplication
from niftynet.engine.application_factory import OptimiserFactory
from niftynet.engine.application_variables import CONSOLE
from niftynet.engine.application_variables import TF_SUMMARIES
from niftynet.layer.loss_segmentation import LossFunction

SUPPORTED_INPUT = set(['image', 'label', 'weight'])


class DecayLearningRateApplication(SegmentationApplication):
    REQUIRED_CONFIG_SECTION = "SEGMENTATION"

    def __init__(self, net_param, action_param, is_training):
        SegmentationApplication.__init__(
            self, net_param, action_param, is_training)
        tf.logging.info('starting decay learning segmentation application')
        self.learning_rate = None
        self.current_lr = action_param.lr
        if self.action_param.validation_every_n > 0:
            raise NotImplementedError("validation process is not implemented "
                                      "in this demo.")

    def connect_data_and_network(self,
                                 outputs_collector=None,
                                 gradients_collector=None):
        data_dict = self.get_sampler()[0][0].pop_batch_op()
        image = tf.cast(data_dict['image'], tf.float32)
        net_out = self.net(image, self.is_training)

        if self.is_training:
            with tf.name_scope('Optimiser'):
                # the learning rate is a placeholder so that a new value
                # can be fed at each iteration (see set_iteration_update)
                self.learning_rate = tf.placeholder(tf.float32, shape=[])
                optimiser_class = OptimiserFactory.create(
                    name=self.action_param.optimiser)
                self.optimiser = optimiser_class.get_instance(
                    learning_rate=self.learning_rate)
            loss_func = LossFunction(
                n_class=self.segmentation_param.num_classes,
                loss_type=self.action_param.loss_type)
            data_loss = loss_func(
                prediction=net_out,
                ground_truth=data_dict.get('label', None),
                weight_map=data_dict.get('weight', None))

            loss = data_loss
            reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)

            if self.net_param.decay > 0.0 and reg_losses:
                reg_loss = tf.reduce_mean(
                    [tf.reduce_mean(reg_loss) for reg_loss in reg_losses])
                loss = data_loss + reg_loss
            grads = self.optimiser.compute_gradients(loss)
            # collecting gradients variables
            gradients_collector.add_to_collection([grads])
            # collecting output variables
            outputs_collector.add_to_collection(
                var=data_loss, name='loss',
                average_over_devices=False, collection=CONSOLE)
            outputs_collector.add_to_collection(
                var=self.learning_rate, name='lr',
                average_over_devices=False, collection=CONSOLE)
            outputs_collector.add_to_collection(
                var=data_loss, name='loss',
                average_over_devices=True, summary_type='scalar',
                collection=TF_SUMMARIES)
        else:
            # converting logits into final output for
            # classification probabilities or argmax classification labels
            SegmentationApplication.connect_data_and_network(
                self, outputs_collector, gradients_collector)

    def set_iteration_update(self, iteration_message):
        """
        This function will be called by the application engine at each
        iteration.
        """
        current_iter = iteration_message.current_iter
        if iteration_message.is_training:
            # divide the learning rate by 1.02 once every 5 iterations
            if current_iter > 0 and current_iter % 5 == 0:
                self.current_lr = self.current_lr / 1.02
            iteration_message.data_feed_dict[self.is_validation] = False
        elif iteration_message.is_validation:
            iteration_message.data_feed_dict[self.is_validation] = True
        iteration_message.data_feed_dict[self.learning_rate] = self.current_lr
demos/Learning_Rate_Decay/learning_rate_demo_train_config.ini (63 additions, 0 deletions)
@@ -0,0 +1,63 @@
############################ input configuration sections
[images] # Name this as you see fit
path_to_search = ./data/decathlon_hippocampus
filename_contains = img_hippocampus_
filename_not_contains = ._
spatial_window_size = (24, 24, 24)
interp_order = 3

[label]
path_to_search = ./data/decathlon_hippocampus
filename_contains = label_hippocampus_
filename_not_contains = ._
spatial_window_size = (24, 24, 24)
interp_order = 0

############################## system configuration sections
[SYSTEM]
cuda_devices = ""
num_threads = 6
num_gpus = 1
model_dir = ./models/model_multimodal_toy
queue_length = 20

[NETWORK]
name = highres3dnet
activation_function = prelu
batch_size = 1
decay = 0
reg_type = L2

# Volume level pre-processing
volume_padding_size = 0
# Normalisation
whitening = True
normalise_foreground_only = False

[TRAINING]
sample_per_volume = 1
optimiser = gradientdescent
# rotation_angle = (-10.0, 10.0)
# scaling_percentage = (-10.0, 10.0)
# random_flipping_axes = 1
lr = 0.0001
loss_type = CrossEntropy
starting_iter = 0
save_every_n = 100
max_iter = 500
max_checkpoints = 20

[INFERENCE]
border = 5
# inference_iter = 10
save_seg_dir = ./output/toy
output_interp_order = 0
spatial_window_size = (64, 64, 64)

############################ custom configuration sections
[SEGMENTATION]
image = images
label = label
output_prob = False
num_classes = 3
label_normalisation = False
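For a quick programmatic look at this file, here is a minimal sketch using Python's standard `configparser` (the path is taken from the diff header above; `inline_comment_prefixes` handles the inline `#` comments):
```
import configparser

# Minimal sketch: read the demo configuration and print the values the
# application consumes at start-up (path taken from the diff header above).
config = configparser.ConfigParser(inline_comment_prefixes=('#',))
config.read('demos/Learning_Rate_Decay/learning_rate_demo_train_config.ini')

print(config['TRAINING'].getfloat('lr'))            # 0.0001 -> action_param.lr
print(config['TRAINING'].getint('max_iter'))        # 500 training iterations
print(config['SEGMENTATION'].getint('num_classes')) # 3
```
NiftyNet's own launcher performs a richer version of this parsing before handing the `[TRAINING]` values to the application as `action_param`, which is where `self.current_lr = action_param.lr` in the applications above gets its initial value.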