diff --git a/chest.py b/chest.py
index 7c277af..8ff1adc 100644
--- a/chest.py
+++ b/chest.py
@@ -41,7 +41,6 @@ parser.add_argument('--split', type=int, default=0, help='')
 parser.add_argument('--valid_data', type=str, default="mc", help='')
-parser.add_argument('--baseline', action='store_true', default=False)
 parser.add_argument('--pretrained', action='store_true', default=False, help='')
 parser.add_argument('--feat_extract', action='store_true', default=False, help='')
 parser.add_argument('--merge_train', action='store_true', default=False, help='')
@@ -117,18 +116,18 @@ def tqdm(*args, **kwargs):
     csvpath=cfg.dataset_dir + "/PC/PADCHEST_chest_x_ray_images_labels_160K_01.02.19.csv",
     transform=transforms, data_aug=data_aug, unique_patients=False)

-xrv.datasets.default_pathologies = ['Cardiomegaly',
+our_pathologies = ['Cardiomegaly',
                                     'Effusion',
                                     'Edema',
                                     'Consolidation',
                                     ]

-print(f"Common pathologies among all train and validation datasets: {xrv.datasets.default_pathologies}")
+print(f"Common pathologies among all train and validation datasets: {our_pathologies}")

-xrv.datasets.relabel_dataset(xrv.datasets.default_pathologies, NIH_dataset)
-xrv.datasets.relabel_dataset(xrv.datasets.default_pathologies, MIMIC_CH_dataset)
-xrv.datasets.relabel_dataset(xrv.datasets.default_pathologies, CHEX_dataset)
-xrv.datasets.relabel_dataset(xrv.datasets.default_pathologies, PC_dataset)
+xrv.datasets.relabel_dataset(our_pathologies, NIH_dataset)
+xrv.datasets.relabel_dataset(our_pathologies, MIMIC_CH_dataset)
+xrv.datasets.relabel_dataset(our_pathologies, CHEX_dataset)
+xrv.datasets.relabel_dataset(our_pathologies, PC_dataset)

 ################################### Dataset selection for Train, Validation and Inference ##################################

 ############### Split 0 #####################
@@ -212,28 +211,18 @@ dmerge = xrv.datasets.Merge_Dataset(train_datas)
 train_datas = [cmerge, dmerge]

-if cfg.baseline:
-    train_loader = [[{} for i in range(2*cfg.num_batches)]]
-    train_dataset = Merge_Dataset(datasets=train_datas, seed=cfg.seed, num_samples=cfg.batch_size*cfg.num_batches)
-    dataloader = DataLoader(train_dataset,
-                            batch_size=cfg.batch_size,
-                            shuffle=True,
-                            num_workers=cfg.num_workers,
-                            pin_memory=True,
-                            drop_last=True)
-    train_loader[0].insert(0, dataloader)
-else:
-    train_loader = [[{} for i in range(cfg.num_batches)] for i in range(len(train_datas))]
-    for dataloader in train_loader:
-        for data in train_datas:
-            if train_loader.index(dataloader) == train_datas.index(data):
-                tr_l = DataLoader(xrv.datasets.SubsetDataset(dataset=data, idxs=range(cfg.batch_size*cfg.num_batches)),
-                                  batch_size=cfg.batch_size,
-                                  shuffle=True,
-                                  num_workers=cfg.num_workers,
-                                  pin_memory=True,
-                                  drop_last=True)
-                dataloader.insert(0, tr_l)
+
+train_loader = [[{} for i in range(cfg.num_batches)] for i in range(len(train_datas))]
+for dataloader in train_loader:
+    for data in train_datas:
+        if train_loader.index(dataloader) == train_datas.index(data):
+            tr_l = DataLoader(xrv.datasets.SubsetDataset(dataset=data, idxs=range(cfg.batch_size*cfg.num_batches)),
+                              batch_size=cfg.batch_size,
+                              shuffle=True,
+                              num_workers=cfg.num_workers,
+                              pin_memory=True,
+                              drop_last=True)
+            dataloader.insert(0, tr_l)

 valid_loader = DataLoader(valid_data,
                           batch_size=cfg.batch_size,
@@ -280,10 +269,7 @@ def train_epoch(num_batches, epoch, model, device, train_loader, criterion, opti
     model.train()

     avg_loss = []
-    if cfg.baseline:
-        t = tqdm(range(1, 2*num_batches+1))
-    else:
-        t = tqdm(range(1, num_batches+1))
+    t = tqdm(range(1, num_batches+1))

     for step in t:
         for idx, dataloader in enumerate(train_loader):
@@ -297,24 +283,18 @@ def train_epoch(num_batches, epoch, model, device, train_loader, criterion, opti
             dataloader[step]["loss"] = compute_loss(outputs, target, dataloader[0], criterion, device)

-        if cfg.baseline:
-            train_nll = torch.tensor(train_loader[0][step]['loss'])
-            loss = 0.0
-            loss1 = train_loader[0][step]['loss']
-            loss += loss1
-        else:
-            train_nll = torch.stack([train_loader[0][step]['loss'], train_loader[1][step]['loss']]).mean()
+        train_nll = torch.stack([train_loader[0][step]['loss'], train_loader[1][step]['loss']]).mean()

-            weight_norm = torch.tensor(0.).to(device)
-            for w in model.parameters():
-                weight_norm += w.norm().pow(2)
+        weight_norm = torch.tensor(0.).to(device)
+        for w in model.parameters():
+            weight_norm += w.norm().pow(2)

-            loss1 = train_loader[0][step]['loss']
-            loss2 = train_loader[1][step]['loss']
+        loss1 = train_loader[0][step]['loss']
+        loss2 = train_loader[1][step]['loss']

-            loss = 0.0
-            loss += (loss1 + loss2)
-            loss += 1e-5 * weight_norm
+        loss = 0.0
+        loss += (loss1 + loss2)
+        loss += 1e-5 * weight_norm

         loss.backward()
         optimizer.step()
@@ -558,15 +538,11 @@ def get_model_inputimg(model_name, num_classes):
 model_zoo = sorted(['resnet50', 'densenet121',])

 for model_name in model_zoo:
-    model = get_model_inputimg(model_name, num_classes=len(xrv.datasets.default_pathologies))
+    model = get_model_inputimg(model_name, num_classes=len(our_pathologies))
     model = model.to(device)

-    if cfg.baseline:
-        print("\n Training Baseline Model \n")
-        output_dir = "baseline_split-" + str(cfg.split) + "_" + model_name + "_valid-" + cfg.valid_data + "/"
-    else:
-        print("\n Training REx Baseline Model \n")
-        output_dir = "bmbs_split-" + str(cfg.split) + "_" + model_name + "_valid-" + cfg.valid_data + "/"
+    print("\n Training Model \n")
+    output_dir = str(cfg.merge_train) + "split-" + str(cfg.split) + "_" + model_name + "_valid-" + cfg.valid_data + "/"

     metrics, best_metric, = main(model, model_name, output_dir, num_epochs=cfg.num_epochs)
     print(f"Best validation AUC: {best_metric:4.4f}")
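
For orientation: the training step this diff leaves in place draws one batch per environment loader each step, averages the per-environment losses into train_nll for logging, and optimizes their sum plus a small L2 penalty on the weights. Below is a minimal, self-contained sketch of that pattern; the toy tensors, the nn.Linear model, and the optimizer settings are illustrative stand-ins, not code from chest.py:

    import torch
    import torch.nn as nn
    from torch.utils.data import DataLoader, TensorDataset

    # Two toy "environments" standing in for the per-dataset loaders built
    # above; the real script wraps each dataset in xrv.datasets.SubsetDataset.
    envs = [TensorDataset(torch.randn(64, 10), torch.randint(0, 2, (64, 4)).float())
            for _ in range(2)]
    loaders = [DataLoader(d, batch_size=16, shuffle=True, drop_last=True) for d in envs]

    model = nn.Linear(10, 4)
    criterion = nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    for batches in zip(*loaders):                  # one batch per environment per step
        optimizer.zero_grad()
        env_losses = [criterion(model(x), y) for x, y in batches]
        train_nll = torch.stack(env_losses).mean()  # only logged, not backpropagated

        # L2 penalty over all parameters, matching the 1e-5 * weight_norm term
        weight_norm = torch.tensor(0.)
        for w in model.parameters():
            weight_norm = weight_norm + w.norm().pow(2)

        loss = env_losses[0] + env_losses[1] + 1e-5 * weight_norm
        loss.backward()
        optimizer.step()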