Skip to content

Commit

Permalink
Add more augmentaions
Browse files Browse the repository at this point in the history
  • Loading branch information
ngxbac committed Jul 3, 2019
1 parent 48038a4 commit af1d11b
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 6 deletions.
2 changes: 1 addition & 1 deletion bin/train.sh
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ export CUDA_VISIBLE_DEVICES=2,3
RUN_CONFIG=config.yml


LOGDIR=/raid/bac/kaggle/logs/recursion_cell/test/rgb/se_resnext50_32x4d/
LOGDIR=/raid/bac/kaggle/logs/recursion_cell/test/rgb_more_augs_2/se_resnext50_32x4d/
catalyst-dl run \
--config=./configs/${RUN_CONFIG} \
--logdir=$LOGDIR \
Expand Down
11 changes: 8 additions & 3 deletions src/augmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,19 @@

def train_aug(image_size=224):
return Compose([
RandomCrop(image_size, image_size),
HorizontalFlip(),
RandomCrop(448, 448),
Resize(image_size, image_size),
RandomRotate90(),
Flip(),
Transpose(),
# HorizontalFlip(),
# Normalize(),
], p=1)


def valid_aug(image_size=224):
return Compose([
CenterCrop(image_size, image_size),
CenterCrop(448, 448),
Resize(320, 320)
# Normalize(),
], p=1)
4 changes: 2 additions & 2 deletions src/make_submission.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def predict(model, loader):

def predict_all():
test_csv = '/raid/data/kaggle/recursion-cellular-image-classification/test.csv'
log_dir = "/raid/bac/kaggle/logs/recursion_cell/test/rgb/se_resnext50_32x4d/"
log_dir = "/raid/bac/kaggle/logs/recursion_cell/test/rgb_more_augs_2/se_resnext50_32x4d/"
root = "/raid/data/kaggle/recursion-cellular-image-classification/"
site = 1
channels = [1,2,3]
Expand Down Expand Up @@ -74,7 +74,7 @@ def predict_all():
submission = df.copy()
submission['sirna'] = all_preds.astype(int)
os.makedirs("submission", exist_ok=True)
submission.to_csv('./submission/submission_se_resnext50_32x4d_rgb.csv', index=False, columns=['id_code', 'sirna'])
submission.to_csv('./submission/submission_se_resnext50_32x4d_rgb_more_aug_2.csv', index=False, columns=['id_code', 'sirna'])


if __name__ == '__main__':
Expand Down

0 comments on commit af1d11b

Please sign in to comment.