
Commit d4a88dd

removed baseline flow to make it clear what is changing

ieee8023 committed Sep 17, 2021
1 parent b1e2746
Showing 1 changed file with 31 additions and 55 deletions.
chest.py: 86 changes (31 additions & 55 deletions)
@@ -41,7 +41,6 @@
parser.add_argument('--split', type=int, default=0, help='')
parser.add_argument('--valid_data', type=str, default="mc", help='')

- parser.add_argument('--baseline', action='store_true', default=False)
parser.add_argument('--pretrained', action='store_true', default=False, help='')
parser.add_argument('--feat_extract', action='store_true', default=False, help='')
parser.add_argument('--merge_train', action='store_true', default=False, help='')
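
With the --baseline switch gone there is only one training path left, steered by the remaining flags above. A hypothetical invocation and its immediate consequence (flag values illustrative, not part of the commit):

    # e.g.  python chest.py --split 0 --valid_data mc --pretrained --merge_train
    cfg = parser.parse_args()
    # cfg.baseline no longer exists, so any surviving `if cfg.baseline:` branch would
    # raise AttributeError; that is why the later hunks delete those branches too.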
@@ -117,18 +116,18 @@ def tqdm(*args, **kwargs):
csvpath=cfg.dataset_dir + "/PC/PADCHEST_chest_x_ray_images_labels_160K_01.02.19.csv",
transform=transforms, data_aug=data_aug, unique_patients=False)

- xrv.datasets.default_pathologies = ['Cardiomegaly',
+ our_pathologies = ['Cardiomegaly',
'Effusion',
'Edema',
'Consolidation',
]

print(f"Common pathologies among all train and validation datasets: {xrv.datasets.default_pathologies}")
print(f"Common pathologies among all train and validation datasets: {our_pathologies}")

- xrv.datasets.relabel_dataset(xrv.datasets.default_pathologies, NIH_dataset)
- xrv.datasets.relabel_dataset(xrv.datasets.default_pathologies, MIMIC_CH_dataset)
- xrv.datasets.relabel_dataset(xrv.datasets.default_pathologies, CHEX_dataset)
- xrv.datasets.relabel_dataset(xrv.datasets.default_pathologies, PC_dataset)
+ xrv.datasets.relabel_dataset(our_pathologies, NIH_dataset)
+ xrv.datasets.relabel_dataset(our_pathologies, MIMIC_CH_dataset)
+ xrv.datasets.relabel_dataset(our_pathologies, CHEX_dataset)
+ xrv.datasets.relabel_dataset(our_pathologies, PC_dataset)

################################### Dataset selection for Train, Validation and Inference ##################################
############### Split 0 #####################
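
Relabelling is what makes the four sources interchangeable during training: in torchxrayvision, relabel_dataset aligns each dataset's label matrix to the pathology list it is given. A minimal sanity check along those lines (a hedged sketch, not part of the commit; the pathologies/labels attribute names follow the torchxrayvision datasets API):

    for d in [NIH_dataset, MIMIC_CH_dataset, CHEX_dataset, PC_dataset]:
        assert list(d.pathologies) == our_pathologies       # same four labels, same order,
        assert d.labels.shape[1] == len(our_pathologies)    # so one 4-output head fits every dataset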
@@ -212,28 +211,18 @@ def tqdm(*args, **kwargs):
dmerge = xrv.datasets.Merge_Dataset(train_datas)
train_datas = [cmerge, dmerge]

- if cfg.baseline:
-     train_loader = [[{} for i in range(2*cfg.num_batches)]]
-     train_dataset = Merge_Dataset(datasets=train_datas, seed=cfg.seed, num_samples=cfg.batch_size*cfg.num_batches)
-     dataloader = DataLoader(train_dataset,
-                             batch_size=cfg.batch_size,
-                             shuffle=True,
-                             num_workers=cfg.num_workers,
-                             pin_memory=True,
-                             drop_last=True)
-     train_loader[0].insert(0, dataloader)
- else:
-     train_loader = [[{} for i in range(cfg.num_batches)] for i in range(len(train_datas))]
-     for dataloader in train_loader:
-         for data in train_datas:
-             if train_loader.index(dataloader) == train_datas.index(data):
-                 tr_l = DataLoader(xrv.datasets.SubsetDataset(dataset=data, idxs=range(cfg.batch_size*cfg.num_batches)),
-                                   batch_size=cfg.batch_size,
-                                   shuffle=True,
-                                   num_workers=cfg.num_workers,
-                                   pin_memory=True,
-                                   drop_last=True)
-                 dataloader.insert(0, tr_l)
+
+ train_loader = [[{} for i in range(cfg.num_batches)] for i in range(len(train_datas))]
+ for dataloader in train_loader:
+     for data in train_datas:
+         if train_loader.index(dataloader) == train_datas.index(data):
+             tr_l = DataLoader(xrv.datasets.SubsetDataset(dataset=data, idxs=range(cfg.batch_size*cfg.num_batches)),
+                               batch_size=cfg.batch_size,
+                               shuffle=True,
+                               num_workers=cfg.num_workers,
+                               pin_memory=True,
+                               drop_last=True)
+             dataloader.insert(0, tr_l)
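
The retained block keeps one sub-list per training environment: slot 0 holds that environment's DataLoader and the next cfg.num_batches slots are empty dicts that train_epoch later fills with per-step losses. A small shape check, assuming the variables defined above:

    assert len(train_loader) == len(train_datas)      # one sub-list per training environment
    for env in train_loader:
        assert isinstance(env[0], DataLoader)         # the environment's loader sits at slot 0
        assert len(env) == cfg.num_batches + 1        # plus one dict per training step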

valid_loader = DataLoader(valid_data,
batch_size=cfg.batch_size,
@@ -280,10 +269,7 @@ def train_epoch(num_batches, epoch, model, device, train_loader, criterion, opti
model.train()
avg_loss = []

- if cfg.baseline:
-     t = tqdm(range(1, 2*num_batches+1))
- else:
-     t = tqdm(range(1, num_batches+1))
+ t = tqdm(range(1, num_batches+1))

for step in t:
for idx, dataloader in enumerate(train_loader):
@@ -297,24 +283,18 @@

dataloader[step]["loss"] = compute_loss(outputs, target, dataloader[0], criterion, device)

- if cfg.baseline:
-     train_nll = torch.tensor(train_loader[0][step]['loss'])
-     loss = 0.0
-     loss1 = train_loader[0][step]['loss']
-     loss += loss1
- else:
-     train_nll = torch.stack([train_loader[0][step]['loss'], train_loader[1][step]['loss']]).mean()
+ train_nll = torch.stack([train_loader[0][step]['loss'], train_loader[1][step]['loss']]).mean()

-     weight_norm = torch.tensor(0.).to(device)
-     for w in model.parameters():
-         weight_norm += w.norm().pow(2)
+ weight_norm = torch.tensor(0.).to(device)
+ for w in model.parameters():
+     weight_norm += w.norm().pow(2)

-     loss1 = train_loader[0][step]['loss']
-     loss2 = train_loader[1][step]['loss']
+ loss1 = train_loader[0][step]['loss']
+ loss2 = train_loader[1][step]['loss']

-     loss = 0.0
-     loss += (loss1 + loss2)
-     loss += 1e-5 * weight_norm
+ loss = 0.0
+ loss += (loss1 + loss2)
+ loss += 1e-5 * weight_norm

loss.backward()
optimizer.step()
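
With the baseline branch removed, every step averages the two environment losses into train_nll and backpropagates their sum plus a small L2 penalty on the weights. A worked sketch with illustrative numbers (the scalars stand in for the per-step losses stored in train_loader):

    import torch
    loss1 = torch.tensor(0.70)                        # environment 0 loss at this step (illustrative)
    loss2 = torch.tensor(0.50)                        # environment 1 loss at this step (illustrative)
    train_nll = torch.stack([loss1, loss2]).mean()    # tensor(0.6000), the averaged per-environment loss
    weight_norm = torch.tensor(0.30)                  # stands in for the summed squared parameter norms
    loss = loss1 + loss2 + 1e-5 * weight_norm         # tensor(1.2000), the value that is backpropagated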
@@ -558,15 +538,11 @@ def get_model_inputimg(model_name, num_classes):
model_zoo = sorted(['resnet50', 'densenet121',])

for model_name in model_zoo:
- model = get_model_inputimg(model_name, num_classes=len(xrv.datasets.default_pathologies))
+ model = get_model_inputimg(model_name, num_classes=len(our_pathologies))
model = model.to(device)

- if cfg.baseline:
-     print("\n Training Baseline Model \n")
-     output_dir = "baseline_split-" + str(cfg.split) + "_" + model_name + "_valid-" + cfg.valid_data + "/"
- else:
-     print("\n Training REx Baseline Model \n")
-     output_dir = "bmbs_split-" + str(cfg.split) + "_" + model_name + "_valid-" + cfg.valid_data + "/"
+ print("\n Training Model \n")
+ output_dir = str(cfg.merge_train) + "split-" + str(cfg.split) + "_" + model_name + "_valid-" + cfg.valid_data + "/"

metrics, best_metric, = main(model, model_name, output_dir, num_epochs=cfg.num_epochs)
print(f"Best validation AUC: {best_metric:4.4f}")
