diff --git a/data/PROMISE12/setup.py b/data/PROMISE12/setup.py index 2eb8a7de..b2c63ad6 100644 --- a/data/PROMISE12/setup.py +++ b/data/PROMISE12/setup.py @@ -6,19 +6,3 @@ zip_ref = zipfile.ZipFile(os.path.join(zip_dir,zip_filename), 'r') zip_ref.extractall(target_dir) zip_ref.close() - -import SimpleITK -import glob -for fn in glob.glob('*.mhd'): - if 'segmentation' in fn: - fn_out=fn[:-4]+'.nii.gz' - else: - fn_out=fn[:-4]+'_T2.nii.gz' - SimpleITK.WriteImage(SimpleITK.ReadImage(fn),fn_out) - try: - os.remove(fn) - os.remove(fn[:-4]+'.raw') - except: - pass - - diff --git a/demos/PROMISE12/promise12_demo_train_config.txt b/demos/PROMISE12/promise12_demo_train_config.txt index df241bd9..f99860fb 100644 --- a/demos/PROMISE12/promise12_demo_train_config.txt +++ b/demos/PROMISE12/promise12_demo_train_config.txt @@ -1,17 +1,18 @@ [image modality 1] path_to_search = ./data/PROMISE12 -filename_contains = T2 -filename_not_contains = Case2 +filename_contains = Case,mhd +filename_not_contains = Case2,segmentation [label modality 1] path_to_search = ./data/PROMISE12 -filename_contains = segmentation +filename_contains = Case,segmentation,mhd filename_not_contains = Case2 [settings] cuda_devices = "" model_dir = ./models/model_vnet net_name = vnet +activation_function = prelu # preprocessing threads parameters queue_length = 8 @@ -56,6 +57,9 @@ max_angle = 10.0 spatial_scaling = True min_percentage = -10.0 max_percentage = 10.0 +min_numb_labels = 2 +min_sampling_ratio = 0.000001 +window_sampling = uniform # ** training only ** gradient descent and loss parameters lr = 0.0001 diff --git a/utilities/misc_io.py b/utilities/misc_io.py index c596e407..b08c545d 100755 --- a/utilities/misc_io.py +++ b/utilities/misc_io.py @@ -8,6 +8,14 @@ import numpy as np import scipy.ndimage +image_loaders = [nib.load] +try: + import utilities.simple_itk_as_nibabel + image_loaders.append(utilities.simple_itk_as_nibabel.SimpleITKAsNibabel) +except ImportError: + raise ImportWarning('SimpleITK adapter failed to load, reducing the supported file formats.') + + warnings.simplefilter("ignore", UserWarning) FILE_EXTENSIONS = [".nii.gz", ".tar.gz"] @@ -23,10 +31,29 @@ def create_affine_pixdim(affine, pixdim): np.expand_dims(np.append(np.asarray(pixdim), 1), axis=1), [1, 4]) return np.multiply(np.divide(affine, to_divide.T), to_multiply.T) - +def load_image(filename): + # load an image from a supported filetype and return an object + # that matches nibabel's spatialimages interface + for image_loader in image_loaders: + try: + img=image_loader(filename) + img = correct_image_if_necessary(img) + return img + except nib.filebasedimages.ImageFileError: # if the image_loader cannot handle the type continue to next loader + pass + raise nib.filebasedimages.ImageFileError('No loader could load the file') # Throw last error + +def correct_image_if_necessary(img): + # Check that affine matches zooms + pixdim = img.header.get_zooms() + if not np.array_equal(np.sqrt(np.sum(np.square(img.affine[0:3, 0:3]), 0)), np.asarray(pixdim)): + if img.hasattr('get_sform'): + # assume it is a malformed NIfTI and try to fix it + img=rectify_header_sform_qform(img) + return img + def rectify_header_sform_qform(img_nii): # TODO: check img_nii is a nibabel object - pixdim = img_nii.header.get_zooms() sform = img_nii.get_sform() qform = img_nii.get_qform() norm_sform = np.sqrt(np.sum(np.square(sform[0:3, 0:3]), 0)) @@ -155,7 +182,7 @@ def csv_cell_to_volume_5d(csv_cell): data_array[t][m] = expand_to_5d(np.zeros(dimensions)) continue # load a 3d volume - img_nii = nib.load(csv_cell()[t][m]) + img_nii = load_image(csv_cell()[t][m]) img_data_shape = img_nii.header.get_data_shape() assert np.prod(img_data_shape) > 1 diff --git a/utilities/simple_itk_as_nibabel.py b/utilities/simple_itk_as_nibabel.py new file mode 100644 index 00000000..de5e6044 --- /dev/null +++ b/utilities/simple_itk_as_nibabel.py @@ -0,0 +1,27 @@ +import SimpleITK as sitk +import nibabel +import numpy as np + + +class SimpleITKAsNibabel(nibabel.spatialimages.SpatialImage): + ''' Minimal interface to use a SimpleITK image as if it were + a nibabel object. Currently only supports the subset of the + interface used by NiftyNet and is read only + ''' + def __init__(self, filename): + try: + self._SimpleITKImage = sitk.ReadImage(filename) + except RuntimeError as err: + if 'Unable to determine ImageIO reader' in str(err): + nibabel.filebasedimages.ImageFileError(str(err)) + else: + raise + self._header = SimpleITKAsNibabelHeader(self._SimpleITKImage) + # get affine transform + c=np.array([self._SimpleITKImage.TransformContinuousIndexToPhysicalPoint(p) for p in ((1,0,0),(0,1,0),(0,0,1),(0,0,0))]) + affine = np.transpose(np.concatenate([np.concatenate([c[0:3]-c[3:],c[3:]],0),[[0.],[0.],[0.],[1.]]],1)) + super(SimpleITKAsNibabel,self).__init__(sitk.GetArrayFromImage(self._SimpleITKImage), affine) +class SimpleITKAsNibabelHeader(nibabel.spatialimages.SpatialHeader): + def __init__(self, image_reference): + super(SimpleITKAsNibabelHeader,self).__init__(data_dtype=sitk.GetArrayViewFromImage(image_reference).dtype,shape=sitk.GetArrayViewFromImage(image_reference).shape, + zooms=image_reference.GetSpacing()) diff --git a/utilities/subject.py b/utilities/subject.py index e52831b9..7968dd57 100755 --- a/utilities/subject.py +++ b/utilities/subject.py @@ -132,7 +132,7 @@ def _read_original_affine(self): and update the corresponding field if not done yet """ img_object = self.__find_first_nibabel_object() - util.rectify_header_sform_qform(img_object) + util.correct_image_if_necessary(img_object) return img_object.affine @CacheFunctionOutput @@ -159,9 +159,10 @@ def __find_first_nibabel_object(self): list_files = [item for sublist in input_image_files for item in sublist] for filename in list_files: if not filename == '' and os.path.exists(filename): - path, name, ext = util.split_filename(filename) - if 'nii' in ext: - return nib.load(filename) + try: + return util.load_image(filename) + except: + pass # ignore failures here return None def __reorient_to_stand(self, data_5d):