Skip to content

Commit

Permalink
testing image window dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
wyli committed Aug 8, 2018
1 parent 9d9b14a commit 762b7fd
Showing 1 changed file with 17 additions and 20 deletions.
37 changes: 17 additions & 20 deletions niftynet/engine/image_window_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class ImageWindowDataset(Layer):
"""

def __init__(self,
reader=None,
reader,
window_sizes=None,
batch_size=1,
windows_per_image=1,
Expand All @@ -48,12 +48,10 @@ def __init__(self,
name='image_dataset'):
Layer.__init__(self, name=name)

self.num_threads = 1

self.dataset = None
self.iterator = None
self.reader = reader

self.num_threads = 1
self.batch_size = batch_size
self.queue_length = int(max(queue_length, round(batch_size * 5.0)))
if self.queue_length > queue_length:
Expand All @@ -67,17 +65,14 @@ def __init__(self,
self.smaller_final_batch_mode = look_up_operations(
smaller_final_batch_mode.lower(), SMALLER_FINAL_BATCH_MODE)

self.n_subjects = 1
self.window = None
if reader is not None:
self.window = ImageWindow.from_data_reader_properties(
reader.input_sources,
reader.shapes,
reader.tf_dtypes,
window_sizes or (-1, -1, -1))
self.n_subjects = reader.num_subjects
self.window.n_samples = \
1 if self.from_generator else windows_per_image
self.reader = reader
self.window = ImageWindow.from_data_reader_properties(
reader.input_sources,
reader.shapes,
reader.tf_dtypes,
window_sizes or (-1, -1, -1))
self.window.n_samples = \
1 if self.from_generator else windows_per_image
# random seeds? (requires num_threads = 1)

@property
Expand Down Expand Up @@ -105,7 +100,7 @@ def tf_dtypes(self):
"""
returns a dictionary of sampler output tensorflow dtypes
"""
assert self.window, 'Unknown output shapes: self.window not initialised'
assert self.window, 'Unknown output dtypes: self.window not initialised'
return self.window.tf_dtypes

def layer_op(self, idx=None):
Expand All @@ -119,7 +114,7 @@ def layer_op(self, idx=None):
yield a dictionary
{
'image_name': a numpy array,
'image_name': a numpy array [h, w, d, chn],
'image_name_location': (image_id,
x_start, y_start, z_start,
x_end, y_end, z_end)
Expand All @@ -129,7 +124,7 @@ def layer_op(self, idx=None):
return a dictionary:
{
'image_name': a numpy array,
'image_name': a numpy array [n_samples, h, w, d, chn],
'image_name_location': [n_samples, 7]
}
Expand All @@ -153,6 +148,7 @@ def layer_op(self, idx=None):
assert self.window.n_samples == 1, \
'image_window_dataset.layer_op() requires: ' \
'windows_per_image should be 1.'

image_id, image_data, _ = self.reader(idx=idx)
for mod in list(image_data):
spatial_shape = image_data[mod].shape[:N_SPATIAL]
Expand Down Expand Up @@ -285,9 +281,10 @@ def _dataset_from_range(self):
:return: a `tf.data.Dataset`
"""
# dataset: a list of integers
dataset = tf.data.Dataset.range(self.n_subjects)
num_subjects = self.reader.num_subjects
dataset = tf.data.Dataset.range(num_subjects)
if self.shuffle:
dataset = dataset.shuffle(buffer_size=self.n_subjects, seed=None)
dataset = dataset.shuffle(buffer_size=num_subjects, seed=None)

# dataset: map each integer i to n windows sampled from subject i
def _tf_wrapper(idx):
Expand Down

0 comments on commit 762b7fd

Please sign in to comment.