Skip to content

Commit

Permalink
Change augmentations. Accum 2, bs 64
Browse files Browse the repository at this point in the history
  • Loading branch information
ngxbac committed Jul 4, 2019
1 parent af1d11b commit 5f080ec
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 10 deletions.
2 changes: 1 addition & 1 deletion bin/train.sh
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ export CUDA_VISIBLE_DEVICES=2,3
RUN_CONFIG=config.yml


LOGDIR=/raid/bac/kaggle/logs/recursion_cell/test/rgb_more_augs_2/se_resnext50_32x4d/
LOGDIR=/raid/bac/kaggle/logs/recursion_cell/test/rgb_no_crop_512_accum2/se_resnext50_32x4d/
catalyst-dl run \
--config=./configs/${RUN_CONFIG} \
--logdir=$LOGDIR \
Expand Down
5 changes: 3 additions & 2 deletions configs/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,11 @@ stages:
gamma: 0.5

data_params:
batch_size: 128
batch_size: 64
num_workers: 4
drop_last: False

image_size: &image_size 320
image_size: &image_size 512
train_csv: "./csv/train_0.csv"
valid_csv: "./csv/valid_0.csv"
root: "/raid/data/kaggle/recursion-cellular-image-classification/"
Expand All @@ -53,6 +53,7 @@ stages:
callback: CriterionCallback
optimizer:
callback: OptimizerCallback
accumulation_steps: 2
accuracy:
callback: AccuracyCallback
accuracy_args: [1]
Expand Down
6 changes: 3 additions & 3 deletions src/augmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

def train_aug(image_size=224):
return Compose([
RandomCrop(448, 448),
# RandomCrop(448, 448),
Resize(image_size, image_size),
RandomRotate90(),
Flip(),
Expand All @@ -15,7 +15,7 @@ def train_aug(image_size=224):

def valid_aug(image_size=224):
return Compose([
CenterCrop(448, 448),
Resize(320, 320)
# CenterCrop(448, 448),
Resize(image_size, image_size)
# Normalize(),
], p=1)
10 changes: 6 additions & 4 deletions src/make_submission.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ def predict(model, loader):

def predict_all():
test_csv = '/raid/data/kaggle/recursion-cellular-image-classification/test.csv'
log_dir = "/raid/bac/kaggle/logs/recursion_cell/test/rgb_more_augs_2/se_resnext50_32x4d/"
# test_csv = './csv/valid_0.csv'
log_dir = "/raid/bac/kaggle/logs/recursion_cell/test/rgb_no_crop_512/se_resnext50_32x4d/"
root = "/raid/data/kaggle/recursion-cellular-image-classification/"
site = 1
channels = [1,2,3]
Expand All @@ -54,7 +55,7 @@ def predict_all():
dataset = RecursionCellularSite(
csv_file=test_csv,
root=root,
transform=valid_aug(320),
transform=valid_aug(512),
mode='test',
site=site,
channels=channels
Expand All @@ -74,8 +75,9 @@ def predict_all():
submission = df.copy()
submission['sirna'] = all_preds.astype(int)
os.makedirs("submission", exist_ok=True)
submission.to_csv('./submission/submission_se_resnext50_32x4d_rgb_more_aug_2.csv', index=False, columns=['id_code', 'sirna'])
submission.to_csv('./submission/se_resnext50_32x4d_no_crop_512_test.csv', index=False, columns=['id_code', 'sirna'])
np.save("./submission/se_resnext50_32x4d_no_crop_512_test.npy", pred)


if __name__ == '__main__':
predict_all()
predict_all()

0 comments on commit 5f080ec

Please sign in to comment.