diff --git a/niftynet/application/classification_application.py b/niftynet/application/classification_application.py index 5f526a2b..746bfef6 100755 --- a/niftynet/application/classification_application.py +++ b/niftynet/application/classification_application.py @@ -1,3 +1,10 @@ +""" +This module defines an image-level classification application +that maps from images to scalar, multi-class labels. + +This class is instantiated and initalized by the application_driver. +""" + import tensorflow as tf from niftynet.application.base_application import BaseApplication @@ -6,7 +13,8 @@ from niftynet.engine.application_variables import \ CONSOLE, NETWORK_OUTPUT, TF_SUMMARIES from niftynet.engine.sampler_resize import ResizeSampler -from niftynet.engine.windows_aggregator_classifier import ClassifierSamplesAggregator +from niftynet.engine.windows_aggregator_classifier import \ + ClassifierSamplesAggregator from niftynet.io.image_reader import ImageReader from niftynet.layer.discrete_label_normalisation import \ DiscreteLabelNormalisationLayer @@ -25,6 +33,16 @@ class ClassificationApplication(BaseApplication): + """This class defines an application for image-level classification + problems mapping from images to scalar labels. + + This is the application class to be instantiated by the driver + and referred to in configuration files. + + Although structurally similar to segmentation, this application + supports different samplers/aggregators (because patch-based + processing is not appropriate), and monitoring metrics.""" + REQUIRED_CONFIG_SECTION = "CLASSIFICATION" def __init__(self, net_param, action_param, is_training): @@ -140,7 +158,7 @@ def initialise_resize_sampler(self): batch_size=self.net_param.batch_size, shuffle_buffer=self.is_training, queue_length=self.net_param.queue_length) for reader in - self.readers]] + self.readers]] def initialise_aggregator(self): self.output_decoder = ClassifierSamplesAggregator( @@ -177,6 +195,49 @@ def initialise_network(self): b_regularizer=b_regularizer, acti_func=self.net_param.activation_function) + def add_confusion_matrix_summaries_(self, + outputs_collector, + net_out, + data_dict): + """ This method defines several monitoring metrics that + are derived from the confusion matrix """ + labels = tf.reshape(tf.cast(data_dict['label'], tf.int64), [-1]) + prediction = tf.reshape(tf.argmax(net_out, -1), [-1]) + num_classes = self.classification_param.num_classes + conf_mat = tf.contrib.metrics.confusion_matrix(labels, + prediction, + num_classes) + conf_mat = tf.to_float(conf_mat) / float(self.net_param.batch_size) + if self.classification_param.num_classes == 2: + outputs_collector.add_to_collection( + var=conf_mat[1][1], name='true_positives', + average_over_devices=True, summary_type='scalar', + collection=TF_SUMMARIES) + outputs_collector.add_to_collection( + var=conf_mat[1][0], name='false_negatives', + average_over_devices=True, summary_type='scalar', + collection=TF_SUMMARIES) + outputs_collector.add_to_collection( + var=conf_mat[0][1], name='false_positives', + average_over_devices=True, summary_type='scalar', + collection=TF_SUMMARIES) + outputs_collector.add_to_collection( + var=conf_mat[0][0], name='true_negatives', + average_over_devices=True, summary_type='scalar', + collection=TF_SUMMARIES) + else: + outputs_collector.add_to_collection( + var=conf_mat[tf.newaxis, :, :, tf.newaxis], + name='confusion_matrix', + average_over_devices=True, summary_type='image', + collection=TF_SUMMARIES) + + outputs_collector.add_to_collection( + var=tf.trace(conf_mat), name='accuracy', + average_over_devices=True, summary_type='scalar', + collection=TF_SUMMARIES) + + def connect_data_and_network(self, outputs_collector=None, gradients_collector=None): @@ -225,6 +286,9 @@ def switch_sampler(for_training): var=data_loss, name='data_loss', average_over_devices=True, summary_type='scalar', collection=TF_SUMMARIES) + self.add_confusion_matrix_summaries_(outputs_collector, + net_out, + data_dict) else: # converting logits into final output for # classification probabilities or argmax classification labels