Skip to content

Commit

Permalink
Code to enable NN to read other file types with support for SimpleITK
Browse files Browse the repository at this point in the history
loading if installed
  • Loading branch information
Eli GIBSON committed Jun 27, 2017
1 parent fd717f8 commit 2ca006a
Show file tree
Hide file tree
Showing 5 changed files with 69 additions and 26 deletions.
16 changes: 0 additions & 16 deletions data/PROMISE12/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,3 @@
zip_ref = zipfile.ZipFile(os.path.join(zip_dir,zip_filename), 'r')
zip_ref.extractall(target_dir)
zip_ref.close()

import SimpleITK
import glob
for fn in glob.glob('*.mhd'):
if 'segmentation' in fn:
fn_out=fn[:-4]+'.nii.gz'
else:
fn_out=fn[:-4]+'_T2.nii.gz'
SimpleITK.WriteImage(SimpleITK.ReadImage(fn),fn_out)
try:
os.remove(fn)
os.remove(fn[:-4]+'.raw')
except:
pass


10 changes: 7 additions & 3 deletions demos/PROMISE12/promise12_demo_train_config.txt
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
[image modality 1]
path_to_search = ./data/PROMISE12
filename_contains = T2
filename_not_contains = Case2
filename_contains = Case,mhd
filename_not_contains = Case2,segmentation

[label modality 1]
path_to_search = ./data/PROMISE12
filename_contains = segmentation
filename_contains = Case,segmentation,mhd
filename_not_contains = Case2

[settings]
cuda_devices = ""
model_dir = ./models/model_vnet
net_name = vnet
activation_function = prelu

# preprocessing threads parameters
queue_length = 8
Expand Down Expand Up @@ -56,6 +57,9 @@ max_angle = 10.0
spatial_scaling = True
min_percentage = -10.0
max_percentage = 10.0
min_numb_labels = 2
min_sampling_ratio = 0.000001
window_sampling = uniform

# ** training only ** gradient descent and loss parameters
lr = 0.0001
Expand Down
33 changes: 30 additions & 3 deletions utilities/misc_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,14 @@
import numpy as np
import scipy.ndimage

image_loaders = [nib.load]
try:
import utilities.simple_itk_as_nibabel
image_loaders.append(utilities.simple_itk_as_nibabel.SimpleITKAsNibabel)
except ImportError:
raise ImportWarning('SimpleITK adapter failed to load, reducing the supported file formats.')


warnings.simplefilter("ignore", UserWarning)

FILE_EXTENSIONS = [".nii.gz", ".tar.gz"]
Expand All @@ -23,10 +31,29 @@ def create_affine_pixdim(affine, pixdim):
np.expand_dims(np.append(np.asarray(pixdim), 1), axis=1), [1, 4])
return np.multiply(np.divide(affine, to_divide.T), to_multiply.T)


def load_image(filename):
# load an image from a supported filetype and return an object
# that matches nibabel's spatialimages interface
for image_loader in image_loaders:
try:
img=image_loader(filename)
img = correct_image_if_necessary(img)
return img
except nib.filebasedimages.ImageFileError: # if the image_loader cannot handle the type continue to next loader
pass
raise nib.filebasedimages.ImageFileError('No loader could load the file') # Throw last error

def correct_image_if_necessary(img):
# Check that affine matches zooms
pixdim = img.header.get_zooms()
if not np.array_equal(np.sqrt(np.sum(np.square(img.affine[0:3, 0:3]), 0)), np.asarray(pixdim)):
if img.hasattr('get_sform'):
# assume it is a malformed NIfTI and try to fix it
img=rectify_header_sform_qform(img)
return img

def rectify_header_sform_qform(img_nii):
# TODO: check img_nii is a nibabel object
pixdim = img_nii.header.get_zooms()
sform = img_nii.get_sform()
qform = img_nii.get_qform()
norm_sform = np.sqrt(np.sum(np.square(sform[0:3, 0:3]), 0))
Expand Down Expand Up @@ -155,7 +182,7 @@ def csv_cell_to_volume_5d(csv_cell):
data_array[t][m] = expand_to_5d(np.zeros(dimensions))
continue
# load a 3d volume
img_nii = nib.load(csv_cell()[t][m])
img_nii = load_image(csv_cell()[t][m])
img_data_shape = img_nii.header.get_data_shape()
assert np.prod(img_data_shape) > 1

Expand Down
27 changes: 27 additions & 0 deletions utilities/simple_itk_as_nibabel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import SimpleITK as sitk
import nibabel
import numpy as np


class SimpleITKAsNibabel(nibabel.spatialimages.SpatialImage):
''' Minimal interface to use a SimpleITK image as if it were
a nibabel object. Currently only supports the subset of the
interface used by NiftyNet and is read only
'''
def __init__(self, filename):
try:
self._SimpleITKImage = sitk.ReadImage(filename)
except RuntimeError as err:
if 'Unable to determine ImageIO reader' in str(err):
nibabel.filebasedimages.ImageFileError(str(err))
else:
raise
self._header = SimpleITKAsNibabelHeader(self._SimpleITKImage)
# get affine transform
c=np.array([self._SimpleITKImage.TransformContinuousIndexToPhysicalPoint(p) for p in ((1,0,0),(0,1,0),(0,0,1),(0,0,0))])
affine = np.transpose(np.concatenate([np.concatenate([c[0:3]-c[3:],c[3:]],0),[[0.],[0.],[0.],[1.]]],1))
super(SimpleITKAsNibabel,self).__init__(sitk.GetArrayFromImage(self._SimpleITKImage), affine)
class SimpleITKAsNibabelHeader(nibabel.spatialimages.SpatialHeader):
def __init__(self, image_reference):
super(SimpleITKAsNibabelHeader,self).__init__(data_dtype=sitk.GetArrayViewFromImage(image_reference).dtype,shape=sitk.GetArrayViewFromImage(image_reference).shape,
zooms=image_reference.GetSpacing())
9 changes: 5 additions & 4 deletions utilities/subject.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def _read_original_affine(self):
and update the corresponding field if not done yet
"""
img_object = self.__find_first_nibabel_object()
util.rectify_header_sform_qform(img_object)
util.correct_image_if_necessary(img_object)
return img_object.affine

@CacheFunctionOutput
Expand All @@ -159,9 +159,10 @@ def __find_first_nibabel_object(self):
list_files = [item for sublist in input_image_files for item in sublist]
for filename in list_files:
if not filename == '' and os.path.exists(filename):
path, name, ext = util.split_filename(filename)
if 'nii' in ext:
return nib.load(filename)
try:
return util.load_image(filename)
except:
pass # ignore failures here
return None

def __reorient_to_stand(self, data_5d):
Expand Down

0 comments on commit 2ca006a

Please sign in to comment.