Skip to content

Commit

Permalink
code refactored
Browse files Browse the repository at this point in the history
  • Loading branch information
ecemlago committed Nov 20, 2021
1 parent 6e2656a commit 989ef4d
Showing 1 changed file with 17 additions and 15 deletions.
32 changes: 17 additions & 15 deletions process.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,20 +27,21 @@
execute_in_docker = True

class Noduledetection(DetectionAlgorithm):
def __init__(self, train=False, retrain=False, retest=False):
def __init__(self, input_dir, output_dir, train=False, retrain=False, retest=False):
super().__init__(
validators=dict(
input_image=(
UniqueImagesValidator(),
UniquePathIndicesValidator(),
)
),
input_path = Path("/input/") if execute_in_docker else Path("./test/"),
output_file = Path("/output/nodules.json") if execute_in_docker else Path("./output/nodules.json")
input_path = Path(input_dir),
output_file = Path(os.path.join(output_dir,'nodules.json'))
)

#------------------------------- LOAD the model here ---------------------------------
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.input_path, self.output_path = input_dir, output_dir
print('using the device ', self.device)
self.model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=False, pretrained_backbone=False)
num_classes = 2 # 1 class (nodule) + background
Expand All @@ -49,19 +50,19 @@ def __init__(self, train=False, retrain=False, retest=False):

if not (train or retest):
# retrain or test phase
print('loading the model from container with model file:')
print('loading the model.pth file :')
self.model.load_state_dict(
torch.load(
"model.pth",
Path("/opt/algorithm/model.pth") if execute_in_docker else Path("model.pth"),
map_location=self.device,
)
)

if retest:
print('loading the retrained model for retest phase')
print('loading the retrained model_retrained.pth file')
self.model.load_state_dict(
torch.load(
Path("/input/model_retrained.pth") if execute_in_docker else Path("./output/model_retrained.pth"),
Path(os.path.join(self.input_path,'model_retrained.pth')),
map_location=self.device,
)
)
Expand Down Expand Up @@ -94,7 +95,7 @@ def process_case(self, *, idx, case):


#--------------------Write your retrain function here ------------
def train(self, input_dir, output_dir, num_epochs = 1):
def train(self, num_epochs = 1):
'''
input_dir: Input directory containing all the images to train with
output_dir: output_dir to write model to.
Expand All @@ -104,6 +105,7 @@ def train(self, input_dir, output_dir, num_epochs = 1):

# create training dataset and defined transformations
self.model.train()
input_dir = self.input_path
dataset = CXRNoduleDataset(input_dir, os.path.join(input_dir, 'metadata.csv'), get_transform(train=True))
print('training starts ')
# define training and validation data loaders
Expand All @@ -126,9 +128,9 @@ def train(self, input_dir, output_dir, num_epochs = 1):
print('epoch ', str(epoch),' is running')
# evaluate on the test dataset

# save retrained version frequently.
#IMPORTANT: save retrained version frequently.
print('saving the model')
torch.save(self.model.state_dict(), Path("/output/model_retrained.pth") if execute_in_docker else Path("./output/model_retrained.pth"))
torch.save(self.model.state_dict(), os.path.join(self.output_path, 'model_retrained.pth'))


def format_to_GC(self, np_prediction, spacing) -> Dict:
Expand Down Expand Up @@ -219,11 +221,11 @@ def predict(self, *, input_image: SimpleITK.Image) -> DataFrame:
parser.add_argument('--retrain', action='store_true', help = "Algorithm on retrain mode (loading previous weights).")
parser.add_argument('--retest', action='store_true', help = "Algorithm on evaluate mode after retraining.")

parsed_args = parser.parse_args()
if (parsed_args.train or parsed_args.retrain):
Noduledetection(parsed_args.train, parsed_args.retrain, parsed_args.retest).train(parsed_args.input_dir, parsed_args.output_dir)
else:
Noduledetection().process()
parsed_args = parser.parse_args()
if (parsed_args.train or parsed_args.retrain):# train mode: retrain or train
Noduledetection(parsed_args.input_dir, parsed_args.output_dir, parsed_args.train, parsed_args.retrain, parsed_args.retest).train()
else:# test mode (test or retest)
Noduledetection(parsed_args.input_dir, parsed_args.output_dir, retest=parsed_args.retest).process()



Expand Down

0 comments on commit 989ef4d

Please sign in to comment.