Skip to content

Commit

Permalink
Test for training and inference for dense_vnet_abdominal_ct
Browse files Browse the repository at this point in the history
  • Loading branch information
xygorn committed Mar 22, 2018
1 parent 9a40a82 commit 1df9ec4
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 9 deletions.
2 changes: 2 additions & 0 deletions .gitlab-ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ testjob:
- dev
- tags
- 195-add-evaluation-action-3
- 192-model_zoo_tests
script:
# !!kill coverage in case of hanging processes
- if pgrep coverage; then pkill -f coverage; fi
Expand Down Expand Up @@ -200,6 +201,7 @@ quicktest:
- 221-add-changelog-entry-for-version-0.2.2
- 223-put-bug-fixes-under-fixed-header-in-changelog
- 195-add-evaluation-action-3
- 192-model_zoo_tests
script:
# print system info
- which nvidia-smi
Expand Down
16 changes: 11 additions & 5 deletions niftynet/application/base_application.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,22 @@
EVAL = "evaluation"


class SingletonApplication(type):
_instances = None

application_singleton_instance = None # global so it can be reset
class SingletonApplication(type):
def __call__(cls, *args, **kwargs):
if cls._instances is None:
cls._instances = \
global application_singleton_instance
if application_singleton_instance is None:
application_singleton_instance = \
super(SingletonApplication, cls).__call__(*args, **kwargs)
# else:
# raise RuntimeError('application instance already started.')
return cls._instances
return application_singleton_instance

@classmethod
def clear(cls):
global application_singleton_instance
application_singleton_instance = None


class BaseApplication(with_metaclass(SingletonApplication, object)):
Expand Down
10 changes: 6 additions & 4 deletions niftynet/layer/resampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,10 +466,12 @@ def _binary_neighbour_ids(spatial_rank):
for i in range(2 ** spatial_rank)]


@tf.RegisterGradient('FloorMod')
def _floormod_grad(op, grad):
return [None, None]

try: # Some tf versions have this defined already
@tf.RegisterGradient('FloorMod')
def _floormod_grad(op, grad):
return [None, None]
except:
pass

SUPPORTED_INTERPOLATION = {'BSPLINE', 'LINEAR', 'NEAREST', 'IDW'}

Expand Down
60 changes: 60 additions & 0 deletions tests/model_zoo_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# -*- coding: utf-8 -*-
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os, sys, unittest

import tensorflow as tf

from niftynet.utilities.download import download
from niftynet.utilities.niftynet_global_config import NiftyNetGlobalConfig
from niftynet import main as niftynet_main
from niftynet.application.base_application import SingletonApplication

MODEL_HOME = NiftyNetGlobalConfig().get_niftynet_home_folder()

def net_run_with_sys_argv(argv):
cache = sys.argv
sys.argv = argv
niftynet_main()
sys.argv = cache

class ModelZooTestMixin(object):
def test_train(self):
self.assertTrue(download('dense_vnet_abdominal_ct_model_zoo', True))
SingletonApplication.clear()
net_run_with_sys_argv(['net_run', '-a', self.application, '-c',
os.path.join(MODEL_HOME, 'extensions', self.location, self.config),
'train', '--max_iter', '2'])
checkpoint = os.path.join(MODEL_HOME, 'models', self.location, 'models', 'model.ckpt-1.index')
self.assertTrue(os.path.exists(checkpoint))
self.check_train()

def test_inference(self):
self.assertTrue(download('dense_vnet_abdominal_ct_model_zoo', True))
SingletonApplication.clear()
net_run_with_sys_argv(['net_run', '-a', self.application, '-c',
os.path.join(MODEL_HOME, 'extensions', self.location, self.config),
'inference'])
output = os.path.join(MODEL_HOME, 'models', self.location, self.expected_output)
self.assertTrue(os.path.exists(output))
self.check_inference()

def check_inference(self):
pass

def check_train(self):
pass


@unittest.skipIf(os.environ.get('QUICKTEST', "").lower() == "true", 'Skipping slow tests')
class DenseVNetAbdominalCTModelZooTest(tf.test.TestCase, ModelZooTestMixin):
id = 'dense_vnet_abdominal_ct_model_zoo'
location = 'dense_vnet_abdominal_ct'
config = 'config.ini'
application = 'net_segment'
expected_output = os.path.join('segmentation_output','100__niftynet_out.nii.gz')

if __name__ == "__main__":
tf.test.main()

0 comments on commit 1df9ec4

Please sign in to comment.