Skip to content

Commit

Permalink
Train with multi channels
Browse files Browse the repository at this point in the history
  • Loading branch information
ngxbac committed Jul 9, 2019
1 parent c6f5c1b commit 0d6360a
Show file tree
Hide file tree
Showing 5 changed files with 119 additions and 9 deletions.
4 changes: 2 additions & 2 deletions bin/train.sh
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
#!/usr/bin/env bash

export CUDA_VISIBLE_DEVICES=2
export CUDA_VISIBLE_DEVICES=2,3
RUN_CONFIG=config.yml


LOGDIR=/raid/bac/kaggle/logs/recursion_cell/test/c1234_s1_smooth_nadam_rndsite/se_resnext50_32x4d/
LOGDIR=/raid/bac/kaggle/logs/recursion_cell/test/c1234_s1_smooth_nadam_rndsite_64/se_resnext50_32x4d/
catalyst-dl run \
--config=./configs/${RUN_CONFIG} \
--logdir=$LOGDIR \
Expand Down
15 changes: 15 additions & 0 deletions bin/train_multi_channels.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
#!/usr/bin/env bash
# Launch one catalyst-dl training run per 4-channel subset listed below.
# Every run shares the same config file and writes to its own log directory,
# named after the channel subset.

export CUDA_VISIBLE_DEVICES=2,3
RUN_CONFIG=config.yml


# Each subset is quoted on purpose: an unquoted [1,2,3,5] is a glob bracket
# expression, so a file named "1", "2", "," ... in the working directory
# would silently replace the word. Note: [1,2,3,4] is not part of this list.
for channels in \
    "[1,2,3,5]" "[1,2,3,6]" "[1,2,4,5]" "[1,2,4,6]" "[1,2,5,6]" \
    "[1,3,4,5]" "[1,3,4,6]" "[1,3,5,6]" "[1,4,5,6]" "[2,3,4,5]" \
    "[2,3,4,6]" "[2,3,5,6]" "[2,4,5,6]" "[3,4,5,6]"; do
    LOGDIR="/raid/bac/kaggle/logs/recursion_cell/search_channels/${channels}/se_resnext50_32x4d/"
    # Expansions are quoted so the brackets inside the path/value reach
    # catalyst-dl verbatim instead of being subject to pathname expansion.
    catalyst-dl run \
        --config="./configs/${RUN_CONFIG}" \
        --logdir="${LOGDIR}" \
        --out_dir="${LOGDIR}:str" \
        --stages/data_params/channels="${channels}:list" \
        --verbose
done
4 changes: 2 additions & 2 deletions configs/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ stages:
momentum_range: [0.85, 0.95]

data_params:
batch_size: 32
batch_size: 64
num_workers: 8
drop_last: False

Expand All @@ -57,7 +57,7 @@ stages:
callback: LabelSmoothCriterionCallback
optimizer:
callback: OptimizerCallback
accumulation_steps: 4
accumulation_steps: 2
accuracy:
callback: AccuracyCallback
accuracy_args: [1]
Expand Down
94 changes: 94 additions & 0 deletions src/inference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
import pandas as pd
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as Ftorch
from torch.utils.data import DataLoader
import os
import glob
import click
from tqdm import *

from models import *
from augmentation import *
from dataset import RecursionCellularSite


# Prefer the GPU, but fall back to the CPU so the script does not depend on
# CUDA being present on the machine that runs it.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def predict(model, loader):
    """Run ``model`` over every batch of ``loader`` and return probabilities.

    Args:
        model: Callable module invoked as ``model(images)``. It is switched
            to eval mode before inference; its output is treated as logits.
        loader: Sized iterable of dict batches, each with an ``'images'``
            entry that supports ``.to(device)``.

    Returns:
        numpy.ndarray: Per-batch softmax outputs concatenated along axis 0.

    Raises:
        ValueError: From ``np.concatenate`` if ``loader`` yields no batches.
    """
    model.eval()
    preds = []
    with torch.no_grad():
        for dct in tqdm(loader, total=len(loader)):
            images = dct['images'].to(device)
            logits = model(images)
            # Pass the class axis explicitly: calling softmax without `dim`
            # is deprecated and relies on an implicit, rank-dependent choice.
            # The caller takes argmax over axis 1, so axis 1 is the class axis.
            probs = Ftorch.softmax(logits, dim=1)
            # No .detach() needed: nothing tracks gradients under no_grad().
            preds.append(probs.cpu().numpy())

    return np.concatenate(preds, axis=0)


def predict_all():
    """Predict the test set with one checkpoint per channel subset.

    For every channel subset in the hard-coded list, the function:

    1. builds a ``cell_senet`` model and loads ``best.pth`` from that
       subset's log directory,
    2. runs ``predict`` once for site 1 and once for site 2,
    3. averages the two probability arrays, takes the argmax per row, and
    4. writes ``./prediction/<model>_<channels>.csv`` (columns ``id_code``,
       ``sirna``) and the averaged probabilities as a ``.npy`` file.

    Returns:
        None. Results are written to the ``prediction`` directory.
    """
    test_csv = '/raid/data/kaggle/recursion-cellular-image-classification/test.csv'
    # Alternative input kept from the original script:
    # test_csv = './csv/valid_0.csv'
    model_name = 'se_resnext50_32x4d'
    root = "/raid/data/kaggle/recursion-cellular-image-classification/"
    # Only the length of this list is used: it scales the model's input
    # channel count. Inference itself iterates over sites 1 and 2 below.
    sites = [1]

    # Loop-invariant work, done once instead of once per channel subset.
    test_df = pd.read_csv(test_csv)
    os.makedirs("prediction", exist_ok=True)

    for channel_str in ["[1,2,4,5]", "[1,2,3,5]", "[1,2,5,6]", "[1,3,4,5]"]:
        log_dir = f"/raid/bac/kaggle/logs/recursion_cell/search_channels/{channel_str}/{model_name}/"
        # "[1,2,4,5]" -> [1, 2, 4, 5]
        channels = [int(i) for i in channel_str[1:-1].split(',')]

        model = cell_senet(
            model_name=model_name,
            num_classes=1108,
            n_channels=len(channels) * len(sites),
        )

        checkpoint_path = f"{log_dir}/checkpoints/best.pth"
        # Load onto the CPU first so the checkpoint does not need the GPU
        # index it was saved from to be visible here; the weights are moved
        # to `device` together with the model right after.
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        model.load_state_dict(checkpoint['model_state_dict'])
        model = model.to(device)
        model = nn.DataParallel(model)

        site_preds = []
        for site in [1, 2]:
            dataset = RecursionCellularSite(
                csv_file=test_csv,
                root=root,
                transform=valid_aug(512),
                mode='test',
                sites=[site],
                channels=channels,
            )

            loader = DataLoader(
                dataset=dataset,
                batch_size=8,
                shuffle=False,  # keep row order aligned with test_df
                num_workers=4,
            )

            site_preds.append(predict(model, loader))

        # Average the per-site probabilities, then pick the top class per row.
        preds = np.asarray(site_preds).mean(axis=0)
        all_preds = np.argmax(preds, axis=1)

        submission = test_df.copy()
        submission['sirna'] = all_preds.astype(int)
        submission.to_csv(f'./prediction/{model_name}_{channel_str}.csv', index=False, columns=['id_code', 'sirna'])
        np.save(f"./prediction/{model_name}_{channel_str}.npy", preds)


# Script entry point: run the full prediction pipeline when executed directly.
if __name__ == '__main__':
    predict_all()
11 changes: 6 additions & 5 deletions src/make_submission.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as Ftorch
from torch.utils.data import DataLoader
import os
Expand Down Expand Up @@ -36,14 +37,15 @@ def predict_all():
test_csv = '/raid/data/kaggle/recursion-cellular-image-classification/test.csv'
# test_csv = './csv/valid_0.csv'

experiment = 'c1234_s1_smooth_nadam_rndsite'
experiment = 'c1234_s1_smooth_nadam_rndsite_64'
model_name = 'se_resnext50_32x4d'

log_dir = f"/raid/bac/kaggle/logs/recursion_cell/test/{experiment}/{model_name}/"
root = "/raid/data/kaggle/recursion-cellular-image-classification/"
sites = [1]
channels = [1,2,3,4]

preds = []
model = cell_senet(
model_name="se_resnext50_32x4d",
num_classes=1108,
Expand All @@ -54,8 +56,7 @@ def predict_all():
checkpoint = torch.load(checkpoint)
model.load_state_dict(checkpoint['model_state_dict'])
model = model.to(device)

preds = []
model = nn.DataParallel(model)

for site in [1, 2]:
# Dataset
Expand Down Expand Up @@ -84,8 +85,8 @@ def predict_all():
submission = df.copy()
submission['sirna'] = all_preds.astype(int)
os.makedirs("submission", exist_ok=True)
submission.to_csv(f'./submission/{model_name}_{experiment}.csv', index=False, columns=['id_code', 'sirna'])
np.save(f"./submission/{model_name}_{experiment}.npy", pred)
submission.to_csv(f'./submission/{model_name}_{experiment}_3ckps.csv', index=False, columns=['id_code', 'sirna'])
np.save(f"./submission/{model_name}_{experiment}_3ckps.npy", preds)


if __name__ == '__main__':
Expand Down

0 comments on commit 0d6360a

Please sign in to comment.