Skip to content

Commit

Permalink
aug by sites
Browse files: browse the repository at this point in the history
  • Loading branch information
ngxbac committed Jul 7, 2019
1 parent ea1b498 commit cd0a09c
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 27 deletions.
2 changes: 1 addition & 1 deletion bin/train.sh
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ export CUDA_VISIBLE_DEVICES=2,3
RUN_CONFIG=config.yml


-LOGDIR=/raid/bac/kaggle/logs/recursion_cell/test/c123_s12_smooth_nadam_/se_resnext50_32x4d/
+LOGDIR=/raid/bac/kaggle/logs/recursion_cell/test/c123_s1_smooth_nadam_rndsite/se_resnext50_32x4d/
catalyst-dl run \
--config=./configs/${RUN_CONFIG} \
--logdir=$LOGDIR \
Expand Down
4 changes: 2 additions & 2 deletions configs/config.yml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
model_params:
model: cell_senet
model_name: se_resnext50_32x4d
-n_channels: 6
+n_channels: 3
num_classes: 1108

args:
Expand Down Expand Up @@ -43,7 +43,7 @@ stages:
train_csv: "./csv/train_0.csv"
valid_csv: "./csv/valid_0.csv"
root: "/raid/data/kaggle/recursion-cellular-image-classification/"
-sites: [1, 2]
+sites: [1]
channels: [1, 2, 3]

stage1:
Expand Down
14 changes: 11 additions & 3 deletions src/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ def __init__(self,
self.plates = df['plate'].values
self.wells = df['well'].values

-if mode == 'train':
+if mode != 'test':
self.labels = df['sirna'].values
else:
self.labels = [0] * len(self.experiments)
Expand All @@ -243,7 +243,15 @@ def __getitem__(self, idx):

channel_paths = []

-for site in self.sites:
+if self.mode == 'train':
+if np.random.rand() < 0.5:
+sites = [1]
+else:
+sites = [2]
+else:
+sites = self.sites
+
+for site in sites:
for channel in self.channels:
path = image_path(
dataset=self.mode,
Expand All @@ -259,7 +267,7 @@ def __getitem__(self, idx):
std_arr = []
mean_arr = []

-for site in self.sites:
+for site in sites:
for channel in self.channels:
mean = self.stat_dict[experiment][plate][well][site][channel]["mean"]
std = self.stat_dict[experiment][plate][well][site][channel]["std"]
Expand Down
47 changes: 26 additions & 21 deletions src/make_submission.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,12 @@ def predict_all():
test_csv = '/raid/data/kaggle/recursion-cellular-image-classification/test.csv'
# test_csv = './csv/valid_0.csv'

-experiment = 'c123_s12_smooth_nadam_'
+experiment = 'c123_s1_smooth_nadam_rndsite'
model_name = 'se_resnext50_32x4d'

log_dir = f"/raid/bac/kaggle/logs/recursion_cell/test/{experiment}/{model_name}/"
root = "/raid/data/kaggle/recursion-cellular-image-classification/"
-sites = [1, 2]
+sites = [1]
channels = [1,2,3]

model = cell_senet(
Expand All @@ -55,26 +55,31 @@ def predict_all():
model.load_state_dict(checkpoint['model_state_dict'])
model = model.to(device)

-# Dataset
-dataset = RecursionCellularSite(
-csv_file=test_csv,
-root=root,
-transform=valid_aug(512),
-mode='test',
-sites=sites,
-channels=channels
-)
-
-loader = DataLoader(
-dataset=dataset,
-batch_size=128,
-shuffle=False,
-num_workers=4,
-)
-
-pred = predict(model, loader)
+preds = []

-all_preds = np.argmax(pred, axis=1)
+for site in [1, 2]:
+# Dataset
+dataset = RecursionCellularSite(
+csv_file=test_csv,
+root=root,
+transform=valid_aug(512),
+mode='test',
+sites=[site],
+channels=channels
+)
+
+loader = DataLoader(
+dataset=dataset,
+batch_size=128,
+shuffle=False,
+num_workers=4,
+)
+
+pred = predict(model, loader)
+preds.append(pred)
+
+preds = np.asarray(preds).mean(axis=0)
+all_preds = np.argmax(preds, axis=1)
df = pd.read_csv(test_csv)
submission = df.copy()
submission['sirna'] = all_preds.astype(int)
Expand Down

0 comments on commit cd0a09c

Please sign in to comment.