Skip to content

Commit

Permalink
Fix inference
Browse files: browse the repository at this point in the history
  • Loading branch information
ngxbac committed Jul 13, 2019
1 parent 4e3b079 commit 529f319
Showing 1 changed file with 46 additions and 50 deletions.
96 changes: 46 additions & 50 deletions src/inference.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -38,8 +38,6 @@ def predict_all():
# test_csv = './csv/valid_0.csv'
model_name = 'se_resnext50_32x4d'

tta_dict = test_tta(image_size=512)

for channel_str in [
"[1,2,3,4]", "[1,2,3,5]", "[1,2,3,6]",
"[1,2,4,5]", "[1,2,4,6]", "[1,2,5,6]",
Expand All @@ -48,61 +46,59 @@ def predict_all():
"[2,3,5,6]", "[2,4,5,6]", "[3,4,5,6]"
]:

log_dir = f"/raid/bac/kaggle/logs/recursion_cell/search_channels/{channel_str}/{model_name}/"
log_dir = f"/raid/bac/kaggle/logs/recursion_cell/search_channels/fold_0/{channel_str}/{model_name}/"
root = "/raid/data/kaggle/recursion-cellular-image-classification/"
sites = [1]
channels = [int(i) for i in channel_str[1:-1].split(',')]

log_dir = log_dir.replace('[', '[[]')
# log_dir = log_dir.replace('[', '[[]')
# log_dir = log_dir.replace(']', '[]]')

ckp = os.path.join(log_dir, "checkpoints/best.pth")
model = cell_senet(
model_name="se_resnext50_32x4d",
num_classes=1108,
n_channels=len(channels) * len(sites)
)

checkpoint = torch.load(ckp)
model.load_state_dict(checkpoint['model_state_dict'])
model = model.to(device)
# model = nn.DataParallel(model)

print("*" * 50)
print(f"checkpoint: {ckp}")
print(f"Channel: {channel_str}")
preds = []
for site in [1, 2]:
# Dataset
dataset = RecursionCellularSite(
csv_file=test_csv,
root=root,
transform=valid_aug(512),
mode='test',
sites=[site],
channels=channels
)

all_checkpoints = glob.glob(f"{log_dir}/checkpoints/stage1.*")
for ckp_num, ckp in enumerate(all_checkpoints):
model = cell_senet(
model_name="se_resnext50_32x4d",
num_classes=1108,
n_channels=len(channels) * len(sites)
loader = DataLoader(
dataset=dataset,
batch_size=128,
shuffle=False,
num_workers=8,
)

checkpoint = torch.load(ckp)
model.load_state_dict(checkpoint['model_state_dict'])
model = model.to(device)
# model = nn.DataParallel(model)

for tta_key, tta_value in tta_dict.items():
print("*" * 50)
print(f"checkpoint: {ckp}")
print(f"Channel: {channel_str}")
print(f"TTA: {tta_key}")
preds = []
for site in [1, 2]:
# Dataset
dataset = RecursionCellularSite(
csv_file=test_csv,
root=root,
transform=tta_value,
mode='test',
sites=[site],
channels=channels
)

loader = DataLoader(
dataset=dataset,
batch_size=128,
shuffle=False,
num_workers=8,
)

pred = predict(model, loader)
preds.append(pred)

preds = np.asarray(preds).mean(axis=0)
all_preds = np.argmax(preds, axis=1)
df = pd.read_csv(test_csv)
submission = df.copy()
submission['sirna'] = all_preds.astype(int)
os.makedirs("./prediction/fold_0/", exist_ok=True)
submission.to_csv(f'./prediction/fold_0/{model_name}_{channel_str}_ckp{ckp_num}_tta_{tta_key}_test.csv', index=False, columns=['id_code', 'sirna'])
np.save(f"./prediction/fold_0/{model_name}_{channel_str}_ckp{ckp_num}_tta_{tta_key}_test.npy", preds)
pred = predict(model, loader)
preds.append(pred)

preds = np.asarray(preds).mean(axis=0)
all_preds = np.argmax(preds, axis=1)
df = pd.read_csv(test_csv)
submission = df.copy()
submission['sirna'] = all_preds.astype(int)
os.makedirs("./prediction/fold_0/", exist_ok=True)
submission.to_csv(f'./prediction/fold_0/{model_name}_{channel_str}_test.csv', index=False, columns=['id_code', 'sirna'])
np.save(f"./prediction/fold_0/{model_name}_{channel_str}_test.npy", preds)


if __name__ == '__main__':
Expand Down

0 comments on commit 529f319

Please sign in to comment.