Merge branch 'classification_application' into 'dev'
better tensorboard metrics, pylint fixes

See merge request CMIC/NiftyNet!194
eligibson committed Feb 6, 2018
2 parents 9344d88 + cc7350b commit 0410aa2
Showing 1 changed file with 66 additions and 2 deletions.
68 changes: 66 additions & 2 deletions niftynet/application/classification_application.py
@@ -1,3 +1,10 @@
"""
This module defines an image-level classification application
that maps from images to scalar, multi-class labels.
The application class is instantiated and initialized by the application_driver.
"""

import tensorflow as tf

from niftynet.application.base_application import BaseApplication
@@ -6,7 +13,8 @@
from niftynet.engine.application_variables import \
CONSOLE, NETWORK_OUTPUT, TF_SUMMARIES
from niftynet.engine.sampler_resize import ResizeSampler
from niftynet.engine.windows_aggregator_classifier import ClassifierSamplesAggregator
from niftynet.engine.windows_aggregator_classifier import \
ClassifierSamplesAggregator
from niftynet.io.image_reader import ImageReader
from niftynet.layer.discrete_label_normalisation import \
DiscreteLabelNormalisationLayer
@@ -25,6 +33,16 @@


class ClassificationApplication(BaseApplication):
"""This class defines an application for image-level classification
problems mapping from images to scalar labels.
This is the application class to be instantiated by the driver
and referred to in configuration files.
Although structurally similar to segmentation, this application
supports different samplers/aggregators (because patch-based
processing is not appropriate), and monitoring metrics."""

REQUIRED_CONFIG_SECTION = "CLASSIFICATION"

def __init__(self, net_param, action_param, is_training):
@@ -140,7 +158,7 @@ def initialise_resize_sampler(self):
batch_size=self.net_param.batch_size,
shuffle_buffer=self.is_training,
queue_length=self.net_param.queue_length) for reader in
self.readers]]
self.readers]]

def initialise_aggregator(self):
self.output_decoder = ClassifierSamplesAggregator(
@@ -177,6 +195,49 @@ def initialise_network(self):
b_regularizer=b_regularizer,
acti_func=self.net_param.activation_function)

def add_confusion_matrix_summaries_(self,
outputs_collector,
net_out,
data_dict):
""" This method defines several monitoring metrics that
are derived from the confusion matrix """
labels = tf.reshape(tf.cast(data_dict['label'], tf.int64), [-1])
prediction = tf.reshape(tf.argmax(net_out, -1), [-1])
num_classes = self.classification_param.num_classes
conf_mat = tf.contrib.metrics.confusion_matrix(labels,
prediction,
num_classes)
conf_mat = tf.to_float(conf_mat) / float(self.net_param.batch_size)
if self.classification_param.num_classes == 2:
outputs_collector.add_to_collection(
var=conf_mat[1][1], name='true_positives',
average_over_devices=True, summary_type='scalar',
collection=TF_SUMMARIES)
outputs_collector.add_to_collection(
var=conf_mat[1][0], name='false_negatives',
average_over_devices=True, summary_type='scalar',
collection=TF_SUMMARIES)
outputs_collector.add_to_collection(
var=conf_mat[0][1], name='false_positives',
average_over_devices=True, summary_type='scalar',
collection=TF_SUMMARIES)
outputs_collector.add_to_collection(
var=conf_mat[0][0], name='true_negatives',
average_over_devices=True, summary_type='scalar',
collection=TF_SUMMARIES)
else:
outputs_collector.add_to_collection(
var=conf_mat[tf.newaxis, :, :, tf.newaxis],
name='confusion_matrix',
average_over_devices=True, summary_type='image',
collection=TF_SUMMARIES)

outputs_collector.add_to_collection(
var=tf.trace(conf_mat), name='accuracy',
average_over_devices=True, summary_type='scalar',
collection=TF_SUMMARIES)


def connect_data_and_network(self,
outputs_collector=None,
gradients_collector=None):
@@ -225,6 +286,9 @@ def switch_sampler(for_training):
var=data_loss, name='data_loss',
average_over_devices=True, summary_type='scalar',
collection=TF_SUMMARIES)
self.add_confusion_matrix_summaries_(outputs_collector,
net_out,
data_dict)
else:
# converting logits into final output for
# classification probabilities or argmax classification labels
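As a worked illustration of the new add_confusion_matrix_summaries_ method, the short standalone NumPy sketch below (not NiftyNet code; the batch values are invented for illustration) reproduces the same arithmetic for the binary case: the per-batch confusion matrix is normalised by the batch size, the four scalar summaries are its entries, and accuracy is its trace.

import numpy as np

# Hypothetical batch of 8 samples for a binary classification problem.
batch_size = 8
num_classes = 2
labels = np.array([0, 1, 1, 0, 1, 0, 1, 1])        # ground-truth class indices
predictions = np.array([0, 1, 0, 0, 1, 1, 1, 1])   # argmax of the network output

# Rows index the true class and columns the predicted class, matching
# tf.contrib.metrics.confusion_matrix(labels, prediction, num_classes).
conf_mat = np.zeros((num_classes, num_classes))
for true_c, pred_c in zip(labels, predictions):
    conf_mat[true_c, pred_c] += 1
conf_mat /= batch_size            # entries now sum to 1 over the batch

true_positives = conf_mat[1, 1]   # 0.5
false_negatives = conf_mat[1, 0]  # 0.125
false_positives = conf_mat[0, 1]  # 0.125
true_negatives = conf_mat[0, 0]   # 0.25
accuracy = np.trace(conf_mat)     # 0.75, i.e. 6 of 8 samples correct

In the multi-class branch of the diff, the normalised matrix itself is written as an image summary, and the trace still serves as the accuracy scalar.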

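The diff relies on TensorFlow 1.x APIs (tf.contrib.metrics.confusion_matrix, tf.to_float, tf.trace) that were removed in later TensorFlow releases. Purely as a hedged sketch of the same computation against current TensorFlow ops (the helper name and signature below are invented for illustration and are not part of NiftyNet):

import tensorflow as tf

def confusion_matrix_metrics(labels, net_out, num_classes, batch_size):
    """Per-batch confusion-matrix metrics analogous to the method added above."""
    labels = tf.reshape(tf.cast(labels, tf.int64), [-1])
    predictions = tf.reshape(tf.argmax(net_out, -1), [-1])
    conf_mat = tf.math.confusion_matrix(labels, predictions, num_classes=num_classes)
    conf_mat = tf.cast(conf_mat, tf.float32) / float(batch_size)
    metrics = {'accuracy': tf.linalg.trace(conf_mat)}
    if num_classes == 2:
        metrics.update(true_positives=conf_mat[1, 1],
                       false_negatives=conf_mat[1, 0],
                       false_positives=conf_mat[0, 1],
                       true_negatives=conf_mat[0, 0])
    return conf_mat, metrics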