diff --git a/chest.py b/chest.py index f32e8a8..7c277af 100644 --- a/chest.py +++ b/chest.py @@ -17,7 +17,6 @@ from os.path import exists, join from tqdm import tqdm as tqdm_base -import timm import torch import torch.nn as nn from torch.utils.data import DataLoader @@ -30,7 +29,7 @@ import torchxrayvision as xrv from merger import Merge_Dataset -from timm.optim import AdamP + parser = argparse.ArgumentParser(description='X-RAY Pathology Detection') parser.add_argument('--seed', type=int, default=0, help='') @@ -50,10 +49,10 @@ ### Data loader parser.add_argument('--cuda', type=bool, default=True, help='') -parser.add_argument('--batch_size', type=int, default=128, help='') +parser.add_argument('--batch_size', type=int, default=64, help='') parser.add_argument('--shuffle', type=bool, default=True, help='') parser.add_argument('--num_workers', type=int, default=0, help='') -parser.add_argument('--num_batches', type=int, default=215, help='') +parser.add_argument('--num_batches', type=int, default=430, help='') parser.add_argument('--num_epochs', type=int, default=200, help='') ### Data Augmentation @@ -64,20 +63,6 @@ cfg = parser.parse_args() print(cfg) -np.random.seed(cfg.seed) -random.seed(cfg.seed) -torch.manual_seed(cfg.seed) -g = torch.Generator() -g.manual_seed(cfg.seed) - -def seed_worker(worker_id): - worker_seed = torch.initial_seed() % 2**32 - np.random.seed(worker_seed) - random.seed(worker_seed) - - - - def tqdm(*args, **kwargs): if hasattr(tqdm_base, '_instances'): for instance in list(tqdm_base._instances): @@ -234,7 +219,6 @@ def tqdm(*args, **kwargs): batch_size=cfg.batch_size, shuffle=True, num_workers=cfg.num_workers, - worker_init_fn=seed_worker, pin_memory=True, drop_last=True) train_loader[0].insert(0, dataloader) @@ -247,7 +231,6 @@ def tqdm(*args, **kwargs): batch_size=cfg.batch_size, shuffle=True, num_workers=cfg.num_workers, - worker_init_fn=seed_worker, pin_memory=True, drop_last=True) dataloader.insert(0, tr_l) @@ -256,7 +239,6 @@ def tqdm(*args, **kwargs): batch_size=cfg.batch_size, shuffle=True, num_workers=cfg.num_workers, - worker_init_fn=seed_worker, pin_memory=True, drop_last=True) @@ -264,7 +246,6 @@ def tqdm(*args, **kwargs): batch_size=cfg.batch_size, shuffle=False, num_workers=cfg.num_workers, - worker_init_fn=seed_worker, pin_memory=True, drop_last=False) @@ -335,7 +316,7 @@ def train_epoch(num_batches, epoch, model, device, train_loader, criterion, opti loss += (loss1 + loss2) loss += 1e-5 * weight_norm - loss.backward(retain_graph=True) + loss.backward() optimizer.step() avg_loss.append(train_nll.detach().cpu().numpy()) @@ -585,7 +566,7 @@ def get_model_inputimg(model_name, num_classes): output_dir = "baseline_split-" + str(cfg.split) + "_" + model_name + "_valid-" + cfg.valid_data + "/" else: print("\n Training REx Baseline Model \n") - output_dir = "rex_baseline_split-" + str(cfg.split) + "_" + model_name + "_valid-" + cfg.valid_data + "/" + output_dir = "bmbs_split-" + str(cfg.split) + "_" + model_name + "_valid-" + cfg.valid_data + "/" metrics, best_metric, = main(model, model_name, output_dir, num_epochs=cfg.num_epochs) print(f"Best validation AUC: {best_metric:4.4f}")