Skip to content

Commit

Permalink
Merge branch '192-model_zoo_tests' into 'dev'
Browse files Browse the repository at this point in the history
"Create tests for demos and model zoo entries"

Closes NifTK#192

See merge request CMIC/NiftyNet!227
  • Loading branch information
wyli committed Mar 26, 2018
2 parents a773efe + edd6f79 commit 1c5b5d0
Show file tree
Hide file tree
Showing 11 changed files with 249 additions and 40 deletions.
4 changes: 4 additions & 0 deletions .gitlab-ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ testjob:
- dev
- tags
- 195-add-evaluation-action-3
- 192-model_zoo_tests
script:
# !!kill coverage in case of hanging processes
- if pgrep coverage; then pkill -f coverage; fi
Expand Down Expand Up @@ -120,6 +121,7 @@ testjob:
- python net_autoencoder.py inference -c config/vae_config.ini --inference_type encode --save_seg_dir output/vae_demo_features
- python net_autoencoder.py inference -c config/vae_config.ini --inference_type encode-decode

- python -m tests.test_model_zoo
- python -m unittest discover -s "tests" -p "*_test.py"

# deactivate virtual environment
Expand Down Expand Up @@ -174,6 +176,7 @@ testjob:

- coverage run -a --source . net_download.py dense_vnet_abdominal_ct_model_zoo -r

- coverage run -a --source . -m tests.test_model_zoo
- coverage run -a --source . -m unittest discover -s "tests" -p "*_test.py"
- coverage report -m
- echo 'finished test'
Expand All @@ -200,6 +203,7 @@ quicktest:
- 221-add-changelog-entry-for-version-0.2.2
- 223-put-bug-fixes-under-fixed-header-in-changelog
- 195-add-evaluation-action-3
- 192-model_zoo_tests
script:
# print system info
- which nvidia-smi
Expand Down
16 changes: 11 additions & 5 deletions niftynet/application/base_application.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,22 @@
EVAL = "evaluation"


class SingletonApplication(type):
_instances = None

application_singleton_instance = None # global so it can be reset
class SingletonApplication(type):
def __call__(cls, *args, **kwargs):
if cls._instances is None:
cls._instances = \
global application_singleton_instance
if application_singleton_instance is None:
application_singleton_instance = \
super(SingletonApplication, cls).__call__(*args, **kwargs)
# else:
# raise RuntimeError('application instance already started.')
return cls._instances
return application_singleton_instance

@classmethod
def clear(cls):
global application_singleton_instance
application_singleton_instance = None


class BaseApplication(with_metaclass(SingletonApplication, object)):
Expand Down
4 changes: 2 additions & 2 deletions niftynet/engine/application_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -507,11 +507,11 @@ def _device_string(self, device_id=0, is_worker=True):
# in training: use gpu only for workers whenever n_local_gpus
device = 'gpu' if (is_worker and n_local_gpus > 0) else 'cpu'
if device == 'gpu' and device_id >= n_local_gpus:
tf.logging.fatal(
tf.logging.warning(
'trying to use gpu id %s, but only has %s GPU(s), '
'please set num_gpus to %s at most',
device_id, n_local_gpus, n_local_gpus)
raise ValueError
#raise ValueError
return '/{}:{}'.format(device, device_id)
# in inference: use gpu for everything whenever n_local_gpus
return '/gpu:0' if n_local_gpus > 0 else '/cpu:0'
Expand Down
20 changes: 20 additions & 0 deletions niftynet/engine/windows_aggregator_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def __init__(self, image_reader=None, output_path='.'):
self._image_id = None
self.prefix = ''
self.output_path = os.path.abspath(output_path)
self.inferred_cleared = False

@property
def input_image(self):
Expand Down Expand Up @@ -133,3 +134,22 @@ def crop_batch(window, location, border):
' spatial dims are: %s', window_shape, spatial_shape)
raise NotImplementedError
return window, location

def log_inferred(self, subject_name, filename):
"""
This function writes out a csv of inferred files
:param subject_name: subject name corresponding to output
:param filename: filename of output
:return:
"""
inferred_csv = os.path.join(self.output_path, 'inferred.csv')
if not self.inferred_cleared:
if os.path.exists(inferred_csv):
os.remove(inferred_csv)
self.inferred_cleared = True
if not os.path.exists(self.output_path):
os.makedirs(self.output_path)
with open(inferred_csv, 'a+') as csv_file:
filename = os.path.join(self.output_path, filename)
csv_file.write('{},{}\n'.format(subject_name, filename))
8 changes: 1 addition & 7 deletions niftynet/engine/windows_aggregator_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,7 @@ def __init__(self,
self.name = name
self.output_interp_order = 0
self.prefix = prefix
self.output_path = os.path.abspath(output_path)
self.inferred_csv = os.path.join(self.output_path, 'inferred.csv')
self.csv_path = os.path.join(self.output_path, self.prefix+'.csv')
if os.path.exists(self.inferred_csv):
os.remove(self.inferred_csv)
if os.path.exists(self.csv_path):
os.remove(self.csv_path)

Expand Down Expand Up @@ -72,7 +68,5 @@ def _save_current_image(self, image_out):
with open(self.csv_path, 'a') as csv_file:
data_str = ','.join([str(i) for i in image_out[0, 0, 0, 0, :]])
csv_file.write(subject_name+','+data_str+'\n')
with open(self.inferred_csv, 'a') as csv_file:
filename = os.path.join(self.output_path, filename)
csv_file.write('{},{}\n'.format(subject_name, filename))
self.log_inferred(subject_name, filename)
return
8 changes: 1 addition & 7 deletions niftynet/engine/windows_aggregator_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,9 @@ def __init__(self,
self, image_reader=image_reader, output_path=output_path)
self.name = name
self.image_out = None
self.output_path = os.path.abspath(output_path)
self.inferred_csv = os.path.join(self.output_path, 'inferred.csv')
self.window_border = window_border
self.output_interp_order = interp_order
self.prefix = prefix
if os.path.exists(self.inferred_csv):
os.remove(self.inferred_csv)

def decode_batch(self, window, location):
n_samples = location.shape[0]
Expand Down Expand Up @@ -91,7 +87,5 @@ def _save_current_image(self):
self.image_out,
source_image_obj,
self.output_interp_order)
with open(self.inferred_csv, 'a') as csv_file:
filename = os.path.join(self.output_path, filename)
csv_file.write('{},{}\n'.format(subject_name, filename))
self.log_inferred(subject_name, filename)
return
8 changes: 1 addition & 7 deletions niftynet/engine/windows_aggregator_resize.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,9 @@ def __init__(self,
ImageWindowsAggregator.__init__(
self, image_reader=image_reader, output_path=output_path)
self.name = name
self.output_path = os.path.abspath(output_path)
self.inferred_csv = os.path.join(self.output_path, 'inferred.csv')
self.window_border = window_border
self.output_interp_order = interp_order
self.prefix = prefix
if os.path.exists(self.inferred_csv):
os.remove(self.inferred_csv)

def decode_batch(self, window, location):
"""
Expand Down Expand Up @@ -103,7 +99,5 @@ def _save_current_image(self, image_out, resize_to):
image_out,
source_image_obj,
self.output_interp_order)
with open(self.inferred_csv, 'a') as csv_file:
filename = os.path.join(self.output_path, filename)
csv_file.write('{},{}\n'.format(subject_name, filename))
self.log_inferred(subject_name, filename)
return
10 changes: 6 additions & 4 deletions niftynet/layer/resampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,10 +466,12 @@ def _binary_neighbour_ids(spatial_rank):
for i in range(2 ** spatial_rank)]


@tf.RegisterGradient('FloorMod')
def _floormod_grad(op, grad):
return [None, None]

try: # Some tf versions have this defined already
@tf.RegisterGradient('FloorMod')
def _floormod_grad(op, grad):
return [None, None]
except:
pass

SUPPORTED_INTERPOLATION = {'BSPLINE', 'LINEAR', 'NEAREST', 'IDW'}

Expand Down
20 changes: 14 additions & 6 deletions niftynet/utilities/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,13 @@ def download(example_ids,
A list of identifiers for the samples to download
:param download_if_already_existing:
If true, data will always be downloaded
:param verbose:
If true, download info will be printed
"""

global_config = NiftyNetGlobalConfig()

config_store = ConfigStore(global_config)
config_store = ConfigStore(global_config, verbose=verbose)

# If a single id is specified, convert to a list
example_ids = [example_ids] \
Expand Down Expand Up @@ -111,7 +113,7 @@ def download_file(url, download_path):
move(downloaded_file, destination_path)


def download_and_decompress(url, download_path):
def download_and_decompress(url, download_path, verbose=True):
"""
Download an archive from a resource URL and
decompresses/unarchives to the given location
Expand All @@ -132,7 +134,10 @@ def download_and_decompress(url, download_path):
downloaded_file = os.path.join(tempfile.gettempdir(), filename)

# Download the file
urlretrieve(url, downloaded_file, reporthook=progress_bar_wrapper)
if verbose:
urlretrieve(url, downloaded_file, reporthook=progress_bar_wrapper)
else:
urlretrieve(url, downloaded_file)

# Decompress and extract all files to the specified local path
tar = tarfile.open(downloaded_file, "r")
Expand All @@ -149,13 +154,14 @@ class ConfigStore:
remote repository with local caching
"""

def __init__(self, global_config):
def __init__(self, global_config, verbose=True):
self._download_folder = global_config.get_niftynet_home_folder()
self._config_folder = global_config.get_niftynet_config_folder()
self._local = ConfigStoreCache(
os.path.join(self._config_folder, '.downloads_local_config_cache'))
self._remote = RemoteProxy(self._config_folder,
global_config.get_download_server_url())
self._verbose = verbose

def exists(self, example_id):
"""
Expand Down Expand Up @@ -258,8 +264,10 @@ def _download(self, remote_config_sections, example_id):
'configuration file')
local_download_path = self._get_local_download_path(
config_params, example_id)
download_and_decompress(url=config_params['url'],
download_path=local_download_path)
download_and_decompress(
url=config_params['url'],
download_path=local_download_path,
verbose=self._verbose)
print('{} -- {}: OK.'.format(example_id, section_name))
print("Downloaded data to " + local_download_path)
else:
Expand Down
2 changes: 0 additions & 2 deletions tests/rand_flip_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@

from niftynet.layer.rand_flip import RandomFlipLayer

os.environ["CUDA_VISIBLE_DEVICES"] = '-1'


class RandFlipTest(tf.test.TestCase):
def test_1d_flip(self):
Expand Down
Loading

0 comments on commit 1c5b5d0

Please sign in to comment.