Skip to content

Commit

Permalink
Add files via upload
Browse files Browse the repository at this point in the history
  • Loading branch information
etetteh committed Sep 15, 2021
1 parent ecf9fd4 commit 52ad456
Showing 1 changed file with 5 additions and 24 deletions.
29 changes: 5 additions & 24 deletions chest.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from os.path import exists, join
from tqdm import tqdm as tqdm_base

import timm
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
Expand All @@ -30,7 +29,7 @@

import torchxrayvision as xrv
from merger import Merge_Dataset
from timm.optim import AdamP


parser = argparse.ArgumentParser(description='X-RAY Pathology Detection')
parser.add_argument('--seed', type=int, default=0, help='')
Expand All @@ -50,10 +49,10 @@

### Data loader
parser.add_argument('--cuda', type=bool, default=True, help='')
parser.add_argument('--batch_size', type=int, default=128, help='')
parser.add_argument('--batch_size', type=int, default=64, help='')
parser.add_argument('--shuffle', type=bool, default=True, help='')
parser.add_argument('--num_workers', type=int, default=0, help='')
parser.add_argument('--num_batches', type=int, default=215, help='')
parser.add_argument('--num_batches', type=int, default=430, help='')
parser.add_argument('--num_epochs', type=int, default=200, help='')

### Data Augmentation
Expand All @@ -64,20 +63,6 @@
cfg = parser.parse_args()
print(cfg)

np.random.seed(cfg.seed)
random.seed(cfg.seed)
torch.manual_seed(cfg.seed)
g = torch.Generator()
g.manual_seed(cfg.seed)

def seed_worker(worker_id):
worker_seed = torch.initial_seed() % 2**32
np.random.seed(worker_seed)
random.seed(worker_seed)




def tqdm(*args, **kwargs):
if hasattr(tqdm_base, '_instances'):
for instance in list(tqdm_base._instances):
Expand Down Expand Up @@ -234,7 +219,6 @@ def tqdm(*args, **kwargs):
batch_size=cfg.batch_size,
shuffle=True,
num_workers=cfg.num_workers,
worker_init_fn=seed_worker,
pin_memory=True,
drop_last=True)
train_loader[0].insert(0, dataloader)
Expand All @@ -247,7 +231,6 @@ def tqdm(*args, **kwargs):
batch_size=cfg.batch_size,
shuffle=True,
num_workers=cfg.num_workers,
worker_init_fn=seed_worker,
pin_memory=True,
drop_last=True)
dataloader.insert(0, tr_l)
Expand All @@ -256,15 +239,13 @@ def tqdm(*args, **kwargs):
batch_size=cfg.batch_size,
shuffle=True,
num_workers=cfg.num_workers,
worker_init_fn=seed_worker,
pin_memory=True,
drop_last=True)

test_loader = DataLoader(test_data,
batch_size=cfg.batch_size,
shuffle=False,
num_workers=cfg.num_workers,
worker_init_fn=seed_worker,
pin_memory=True,
drop_last=False)

Expand Down Expand Up @@ -335,7 +316,7 @@ def train_epoch(num_batches, epoch, model, device, train_loader, criterion, opti
loss += (loss1 + loss2)
loss += 1e-5 * weight_norm

loss.backward(retain_graph=True)
loss.backward()
optimizer.step()

avg_loss.append(train_nll.detach().cpu().numpy())
Expand Down Expand Up @@ -585,7 +566,7 @@ def get_model_inputimg(model_name, num_classes):
output_dir = "baseline_split-" + str(cfg.split) + "_" + model_name + "_valid-" + cfg.valid_data + "/"
else:
print("\n Training REx Baseline Model \n")
output_dir = "rex_baseline_split-" + str(cfg.split) + "_" + model_name + "_valid-" + cfg.valid_data + "/"
output_dir = "bmbs_split-" + str(cfg.split) + "_" + model_name + "_valid-" + cfg.valid_data + "/"

metrics, best_metric, = main(model, model_name, output_dir, num_epochs=cfg.num_epochs)
print(f"Best validation AUC: {best_metric:4.4f}")
Expand Down

0 comments on commit 52ad456

Please sign in to comment.