update aggregator to the latest version
wyli committed Oct 7, 2019
1 parent e102051 commit f4054a5
Showing 4 changed files with 118 additions and 208 deletions.
185 changes: 68 additions & 117 deletions niftynet/engine/windows_aggregator_grid.py
@@ -3,13 +3,14 @@
windows aggregator decode sampling grid coordinates and image id from
batch data, forms image level output and write to hard drive.
"""
from __future__ import absolute_import, print_function, division
from __future__ import absolute_import, division, print_function

import os
from collections import OrderedDict

import numpy as np
import pandas as pd

import tensorflow as tf
# pylint: disable=too-many-nested-blocks
# pylint: disable=too-many-branches
import niftynet.io.misc_io as misc_io
@@ -25,6 +26,7 @@ class GridSamplesAggregator(ImageWindowsAggregator):
initialised as all zeros, and the values are replaced
by image window data decoded from batch.
"""

def __init__(self,
image_reader,
name='image',
@@ -44,134 +46,92 @@ def __init__(self,
self.fill_constant = fill_constant

def decode_batch(self, window, location):
'''
"""
Function used to save multiple outputs listed in the window
dictionary. For the fields that have the keyword 'window' in the
dictionary key, it will be saved as image. The rest will be saved as
csv. CSV files will contain at saving a first line of 0 (to be
changed into the header by the user), the first column being the
index of the window, followed by the list of output and the location
array for each considered window
:param window: dictionary of output
:param location: location of the input
:return:
'''
"""
n_samples = location.shape[0]
location_init = np.copy(location)
init_ones = None
for i in window:
if 'window' in i: # all outputs to be created as images should
location_cropped = {}
for key in window:
if 'window' in key: # all outputs to be created as images should
# contained the keyword "window"
init_ones = np.ones_like(window[i])
window[i], _ = self.crop_batch(window[i], location_init,
self.window_border)
location_init = np.copy(location)
print(i, np.sum(window[i]), np.max(window[i]))
_, location = self.crop_batch(init_ones, location_init,
self.window_border)
window[key], location_cropped[key] = self.crop_batch(
window[key], location, self.window_border)

for batch_id in range(n_samples):
image_id, x_start, y_start, z_start, x_end, y_end, z_end = \
location[batch_id, :]
image_id = location[batch_id, 0]
if image_id != self.image_id:
# image name changed:
# save current image and create an empty image
# save current result and create an empty result file
self._save_current_image()
self._save_current_csv()
if self._is_stopping_signal(location[batch_id]):
return False
self.image_out = {}
self.csv_out = {}
for i in window:
if 'window' in i: # check that we want to have an image
# and initialise accordingly
self.image_out[i] = self._initialise_empty_image(

self.image_out, self.csv_out = {}, {}
for key in window:
if 'window' in key:
# to be saved as image
self.image_out[key] = self._initialise_empty_image(
image_id=image_id,
n_channels=window[i].shape[-1],
dtype=window[i].dtype)
print("for output shape is ", self.image_out[i].shape)
n_channels=window[key].shape[-1],
dtype=window[key].dtype)
else:
if not isinstance(window[i], (list, tuple, np.ndarray)):
self.csv_out[i] = self._initialise_empty_csv(
1 + location_init[0, :].shape[-1])
else:
window[i] = np.asarray(window[i])
if n_samples > 1 and np.asarray(window[i]).ndim < 2:
window[i] = np.expand_dims(window[i], 1)
elif n_samples == 1 and np.asarray(
window[i]).shape[0] != n_samples:
window[i] = np.expand_dims(window[i], 0)
window_save = np.asarray(np.squeeze(
window[i][batch_id, ...]))
try:
assert window_save.ndim <= 2
except (TypeError, AssertionError):
tf.logging.error(
"The output you are trying to "
"save as csv is more than "
"bidimensional. Did you want "
"to save an image instead? "
"Put the keyword window "
"in the output dictionary"
" in your application file")
if window_save.ndim < 2:
window_save = np.expand_dims(window_save, 0)
self.csv_out[i] = self._initialise_empty_csv(
n_channel=window_save.shape[-1] + location_init
[0, :].shape[-1])
for i in window:
if 'window' in i:
self.image_out[i][
# to be saved as csv file
n_elements = np.int64(
np.asarray(window[key]).size / n_samples)
table_header = [
'{}_{}'.format(key, idx)
for idx in range(n_elements)
] if n_elements > 1 else ['{}'.format(key)]
table_header += [
'coord_{}'.format(idx)
for idx in range(location.shape[-1])
]
self.csv_out[key] = self._initialise_empty_csv(
key_names=table_header)

for key in window:
if 'window' in key:
x_start, y_start, z_start, x_end, y_end, z_end = \
location_cropped[key][batch_id, 1:]
self.image_out[key][
x_start:x_end, y_start:y_end, z_start:z_end, ...] = \
window[i][batch_id, ...]
window[key][batch_id, ...]
else:
if isinstance(window[i], (list, tuple, np.ndarray)):
window[i] = np.asarray(window[i])
if n_samples > 1 and window[i].ndim < 2:
window[i] = np.expand_dims(window[i], 1)
elif n_samples == 1 and window[i].shape[0] != n_samples:
window[i] = np.expand_dims(window[i], 0)
print(batch_id, "is batch_id ", window[i].shape)
window_save = np.squeeze(np.asarray(
window[i][batch_id, ...]))
try:
assert window_save.ndim <= 2
except (TypeError, AssertionError):
tf.logging.error(
"The output you are trying to "
"save as csv is more than "
"bidimensional. Did you want "
"to save an image instead? "
"Put the keyword window "
"in the output dictionary"
" in your application file")
while window_save.ndim < 2:
window_save = np.expand_dims(window_save, 0)
window_save = np.asarray(window_save)

window_loc = np.concatenate([
window_save, np.tile(
location_init[batch_id, ...],
[window_save.shape[0], 1])], 1)
else:
window_loc = np.concatenate([
np.reshape(window[i], [1, 1]), np.tile(
location_init[batch_id, ...], [1, 1])], 1)
self.csv_out[i] = np.concatenate([self.csv_out[i],
window_loc], 0)
window[key] = np.asarray(window[key]).reshape(
[n_samples, -1])
window_save = window[key][batch_id:batch_id + 1, :]
window_loc = location[batch_id:batch_id + 1, :]
csv_row = np.concatenate([window_save, window_loc], 1)
csv_row = csv_row.ravel()
key_names = self.csv_out[key].columns
self.csv_out[key] = self.csv_out[key].append(
OrderedDict(zip(key_names, csv_row)),
ignore_index=True)
return True

def _initialise_empty_image(self, image_id, n_channels, dtype=np.float):
'''
"""
Initialise an empty image in which to populate the output
:param image_id: image_id to be used in the reader
:param n_channels: numbers of channels of the saved output (for
multimodal output)
:param dtype: datatype used for the saving
:return: the initialised empty image
'''
"""
self.image_id = image_id
spatial_shape = self.input_image[self.name].shape[:3]
output_image_shape = spatial_shape + (n_channels,)
output_image_shape = spatial_shape + (n_channels, )
empty_image = np.zeros(output_image_shape, dtype=dtype)
for layer in self.reader.preprocessors:
if isinstance(layer, PadLayer):
@@ -182,26 +142,23 @@ def _initialise_empty_image(self, image_id, n_channels, dtype=np.float):

return empty_image

def _initialise_empty_csv(self, n_channel):
'''
def _initialise_empty_csv(self, key_names):
"""
Initialise a csv output file with a first line of zeros
:param n_channel: number of saved fields
:return: empty first line of the array to be saved as csv
'''
return np.zeros([1, n_channel])
"""
return pd.DataFrame(columns=key_names)

def _save_current_image(self):
'''
"""
For all the outputs to be saved as images, go through the dictionary
and save the resulting output after reversing the initial preprocessing
:return:
'''
"""
if self.input_image is None:
return
for i in self.image_out:
print(np.sum(self.image_out[i]), " is sum of image out %s before"
% i)
print("for output shape is now ", self.image_out[i].shape)
for layer in reversed(self.reader.preprocessors):
if isinstance(layer, PadLayer):
for i in self.image_out:
@@ -210,32 +167,26 @@ def _save_current_image(self):
for i in self.image_out:
self.image_out[i], _ = layer.inverse_op(self.image_out[i])
subject_name = self.reader.get_subject_id(self.image_id)
for i in self.image_out:
print(np.sum(self.image_out[i]), " is sum of image out %s after"
% i)
for i in self.image_out:
filename = "{}_{}_{}.nii.gz".format(i, subject_name, self.postfix)
source_image_obj = self.input_image[self.name]
misc_io.save_data_array(self.output_path,
filename,
self.image_out[i],
source_image_obj,
misc_io.save_data_array(self.output_path, filename,
self.image_out[i], source_image_obj,
self.output_interp_order)
self.log_inferred(subject_name, filename)
return

def _save_current_csv(self):
'''
"""
For all output to be saved as csv, loop through the dictionary of
output and create the csv
:return:
'''
"""
if self.input_image is None:
return
subject_name = self.reader.get_subject_id(self.image_id)
for i in self.csv_out:
filename = "{}_{}_{}.csv".format(i, subject_name, self.postfix)
misc_io.save_csv_array(self.output_path, filename, self.csv_out[
i][1:, :])
misc_io.save_csv_array(self.output_path, filename, self.csv_out[i])
self.log_inferred(subject_name, filename)
return
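
The substance of this commit is the CSV path of GridSamplesAggregator: per-window outputs that are not images are now collected in a pandas DataFrame whose columns are named from the output key and the window coordinates, replacing the previous numpy array that was seeded with a dummy row of zeros and trimmed on save. Below is a minimal standalone sketch of that pattern, not part of the commit; the variable names and example values are illustrative only, and pd.concat stands in for the 2019-era DataFrame.append used in the diff (append was removed in pandas 2.x).

    import numpy as np
    import pandas as pd

    # One scalar output per window plus the 7-element location vector
    # (image id, x/y/z start, x/y/z end), mirroring the table_header
    # construction in decode_batch.
    columns = ['score'] + ['coord_{}'.format(idx) for idx in range(7)]
    csv_out = pd.DataFrame(columns=columns)          # as in _initialise_empty_csv
    window_value = np.array([[0.83]])                # one network output
    location = np.array([[2, 0, 0, 0, 16, 16, 16]])  # one window location

    # Concatenate the output with its location and add it as a named row.
    csv_row = np.concatenate([window_value, location], 1).ravel()
    csv_out = pd.concat(
        [csv_out, pd.DataFrame([dict(zip(columns, csv_row))])],
        ignore_index=True)
    print(csv_out)

Each aggregated window contributes one row of the table, and the coord_* columns record where that window sits in the subject volume.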