-
Notifications
You must be signed in to change notification settings - Fork 0
/
run_inference.py
123 lines (106 loc) · 8.27 KB
/
run_inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import SimpleITK as sitk
import os
from datasets.cyclegan import CycleGANDataset
from torchvision.transforms import transforms
from torch.utils.data import DataLoader, Dataset
from models.unet import UNet
import torch
import numpy as np
import segmentation_models_pytorch as smp
from preprocessing import nifti_to_2d_slices
import nibabel as nib
from tqdm import tqdm
import argparse
from nnUNet.nnunet.paths import default_plans_identifier, network_training_output_dir, default_cascade_trainer, default_trainer
from nnUNet.nnunet.inference import predict_simple
def build_arg_parser(trainer_default, cascade_trainer_default, plans_default):
    """Create the command-line parser for this inference script.

    The defaults are passed in explicitly (instead of being read from module
    globals) so the parser can be built and inspected on its own.

    Args:
        trainer_default: default for ``-tr/--trainer_class_name``; it is also
            interpolated into that option's help text.
        cascade_trainer_default: default for
            ``-ctr/--cascade_trainer_class_name``; also interpolated into that
            option's help text.
        plans_default: default for ``-p/--plans_identifier``.

    Returns:
        argparse.ArgumentParser: a parser with every option registered.
    """
    parser = argparse.ArgumentParser()

    # --- required I/O and task selection ---------------------------------
    parser.add_argument("-i", '--input_folder', required=True,
                        help="Must contain all modalities for each patient in the correct"
                             " order (same as training). Files must be named "
                             "CASENAME_XXXX.nii.gz where XXXX is the modality "
                             "identifier (0000, 0001, etc)")
    parser.add_argument('-o', "--output_folder", required=True,
                        help="folder for saving predictions")
    # No default here: the option is required, so a default could never be
    # used (the original supplied the plans identifier, which is not a task).
    parser.add_argument('-t', '--task_name', required=True,
                        help='task name or task ID, required.')

    # --- model / trainer selection ---------------------------------------
    parser.add_argument('-tr', '--trainer_class_name', required=False,
                        default=trainer_default,
                        help=('Name of the nnUNetTrainer used for 2D U-Net, full resolution 3D U-Net and low resolution '
                              'U-Net. The default is %s. If you are running inference with the cascade and the folder '
                              'pointed to by --lowres_segmentations does not contain the segmentation maps generated by '
                              'the low resolution U-Net then the low resolution segmentation maps will be automatically '
                              'generated. For this case, make sure to set the trainer class here that matches your '
                              '--cascade_trainer_class_name (this part can be ignored if defaults are used).'
                              % trainer_default))
    parser.add_argument('-ctr', '--cascade_trainer_class_name', required=False,
                        default=cascade_trainer_default,
                        help=("Trainer class name used for predicting the 3D full resolution U-Net part of the cascade. "
                              "Default is %s" % cascade_trainer_default))
    parser.add_argument('-m', '--model', required=False, default="3d_fullres",
                        help="2d, 3d_lowres, 3d_fullres or 3d_cascade_fullres. Default: 3d_fullres")
    parser.add_argument('-p', '--plans_identifier', required=False,
                        default=plans_default,
                        help='do not touch this unless you know what you are doing')
    # The literal string 'None' (not the None object) is the "unset" marker
    # for the string-valued options below; it is passed through unchanged.
    parser.add_argument('-f', '--folds', nargs='+', default='None',
                        help="folds to use for prediction. Default is None which means that folds will be detected "
                             "automatically in the model output folder")
    parser.add_argument('-z', '--save_npz', required=False, action='store_true',
                        help="use this if you want to ensemble these predictions with those of other models. Softmax "
                             "probabilities will be saved as compressed numpy arrays in output_folder and can be "
                             "merged between output_folders with nnUNet_ensemble_predictions")
    parser.add_argument('-l', '--lowres_segmentations', required=False, default='None',
                        help="if model is the highres stage of the cascade then you can use this folder to provide "
                             "predictions from the low resolution 3D U-Net. If this is left at default, the "
                             "predictions will be generated automatically (provided that the 3D low resolution U-Net "
                             "network weights are present)")

    # --- parallelisation and resource limits -----------------------------
    parser.add_argument("--part_id", type=int, required=False, default=0,
                        help="Used to parallelize the prediction of "
                             "the folder over several GPUs. If you "
                             "want to use n GPUs to predict this "
                             "folder you need to run this command "
                             "n times with --part_id=0, ... n-1 and "
                             "--num_parts=n (each with a different "
                             "GPU (for example via "
                             "CUDA_VISIBLE_DEVICES=X)")
    parser.add_argument("--num_parts", type=int, required=False, default=1,
                        help="Used to parallelize the prediction of "
                             "the folder over several GPUs. If you "
                             "want to use n GPUs to predict this "
                             "folder you need to run this command "
                             "n times with --part_id=0, ... n-1 and "
                             "--num_parts=n (each with a different "
                             "GPU (via "
                             "CUDA_VISIBLE_DEVICES=X)")
    parser.add_argument("--num_threads_preprocessing", required=False, default=6, type=int,
                        help="Determines how many background processes will be used for data preprocessing. "
                             "Reduce this if you run into out of memory (RAM) problems. Default: 6")
    parser.add_argument("--num_threads_nifti_save", required=False, default=2, type=int,
                        help="Determines how many background processes will be used for segmentation export. "
                             "Reduce this if you run into out of memory (RAM) problems. Default: 2")

    # --- prediction behaviour --------------------------------------------
    parser.add_argument("--disable_tta", required=False, default=False, action="store_true",
                        help="set this flag to disable test time data augmentation via mirroring. Speeds up inference "
                             "by roughly factor 4 (2D) or 8 (3D)")
    parser.add_argument("--overwrite_existing", required=False, default=False, action="store_true",
                        help="Set this flag if the target folder contains predictions that you would like to overwrite")
    parser.add_argument("--mode", type=str, default="normal", required=False, help="Hands off!")
    parser.add_argument("--all_in_gpu", type=str, default="None", required=False,
                        help="can be None, False or True. "
                             "Do not touch.")
    parser.add_argument("--step_size", type=float, default=0.5, required=False, help="don't touch")
    # parser.add_argument("--interp_order", required=False, default=3, type=int,
    #                     help="order of interpolation for segmentations, has no effect if mode=fastest. Do not touch this.")
    # parser.add_argument("--interp_order_z", required=False, default=0, type=int,
    #                     help="order of interpolation along z is z is done differently. Do not touch this.")
    # parser.add_argument("--force_separate_z", required=False, default="None", type=str,
    #                     help="force_separate_z resampling. Can be None, True or False, has no effect if mode=fastest. "
    #                          "Do not touch this.")
    parser.add_argument('-chk', required=False, default='model_final_checkpoint',
                        help='checkpoint name, default: model_final_checkpoint')
    parser.add_argument('--disable_mixed_precision', default=False, action='store_true', required=False,
                        help='Predictions are done with mixed precision by default. This improves speed and reduces '
                             'the required vram. If you want to disable mixed precision you can set this flag. Note '
                             'that this is not recommended (mixed precision is ~2x faster!)')
    return parser


if __name__ == '__main__':
    # Parse the command line using the defaults exported by nnUNet's paths
    # module, then hand the parsed namespace to the prediction entry point.
    args = build_arg_parser(default_trainer,
                            default_cascade_trainer,
                            default_plans_identifier).parse_args()
    # NOTE(review): assumes this project's predict_simple.main accepts the
    # parsed namespace as its only argument -- confirm against that module.
    predict_simple.main(args)
    print("Finish..")