Skip to content

Commit

Permalink
Update xrv_test.py
Browse files Browse the repository at this point in the history
  • Loading branch information
etetteh committed Jul 23, 2021
1 parent 4781fe9 commit 7ace2e0
Showing 1 changed file with 9 additions and 6 deletions.
15 changes: 9 additions & 6 deletions xrv_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
parser.add_argument('--batch_size', type=int, default=64, help='')
parser.add_argument('--shuffle', type=bool, default=False, help='')
parser.add_argument('--num_workers', type=int, default=4, help='')
parser.add_argument('--num_batches', type=int, default=290, help='')
parser.add_argument('--num_batches', type=int, default=230, help='')

### Data Augmentation
parser.add_argument('--data_aug_rot', type=int, default=45, help='')
Expand Down Expand Up @@ -70,13 +70,15 @@ def tqdm(*args, **kwargs):

transforms = torchvision.transforms.Compose([datasets.XRayCenterCrop(), datasets.XRayResizer(112)])

datasets.default_pathologies = ['Cardiomegaly',
'Pneumonia',
datasets.default_pathologies = [
'Cardiomegaly',
#'Pneumonia',
'Effusion',
'Edema',
'Atelectasis',
#'Atelectasis',
'Consolidation',
'Pneumothorax']
#'Pneumothorax'
]


if "nih" in cfg.dataset_name:
Expand Down Expand Up @@ -238,7 +240,8 @@ def inference(name, model, device, data_loader, criterion, limit=cfg.num_batches
model=model,
device=device,
data_loader=test_loader,
criterion=criterion)
criterion=criterion,
limit=cfg.num_batches//2)

print(f"Average AUC for all pathologies {test_auc:4.4f}")
print(f"Test loss: {test_loss:4.4f}")
Expand Down

0 comments on commit 7ace2e0

Please sign in to comment.