Skip to content

Commit

Permalink
Train resnet50 with cut mix
Browse files Browse the repository at this point in the history
  • Loading branch information
ngxbac committed Jul 30, 2019
1 parent ec4b398 commit e83ac86
Show file tree
Hide file tree
Showing 5 changed files with 77 additions and 52 deletions.
21 changes: 14 additions & 7 deletions bin/train.sh
Original file line number Diff line number Diff line change
@@ -1,12 +1,19 @@
#!/usr/bin/env bash

export CUDA_VISIBLE_DEVICES=0,1,2,3
export CUDA_VISIBLE_DEVICES=2,3
RUN_CONFIG=config.yml


LOGDIR=./bin/30_epochs/
catalyst-dl run \
--config=./configs/${RUN_CONFIG} \
--logdir=$LOGDIR \
--out_dir=$LOGDIR:str \
--verbose
for channels in [1,2,3,4,5]; do
for fold in 0; do
LOGDIR=/raid/bac/kaggle/logs/recursion_cell/test/190730/cutmix/fold_$fold/resnet50/
catalyst-dl run \
--config=./configs/${RUN_CONFIG} \
--logdir=$LOGDIR \
--out_dir=$LOGDIR:str \
--stages/data_params/channels=$channels:list \
--stages/data_params/train_csv=./csv/train_$fold.csv:str \
--stages/data_params/valid_csv=./csv/valid_$fold.csv:str \
--verbose
done
done
53 changes: 9 additions & 44 deletions configs/config.yml
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
model_params:
model: cell_senet
model_name: se_resnext50_32x4d
n_channels: 4
model: ResNet50CutMix
n_channels: 5
num_classes: 1108
pretrained: "/raid/bac/pretrained_models/pytorch/ResNet50_CutMix_v2.pth"

args:
expdir: "src"
Expand All @@ -24,17 +24,17 @@ stages:

data_params:
batch_size: 64
num_workers: 16
num_workers: 8
drop_last: False
# drop_last: True

image_size: &image_size 512
train_csv: "./csv/train_0.csv"
valid_csv: "./csv/valid_0.csv"
root: "./data/"
root: "/raid/data/kaggle/recursion-cellular-image-classification/"
sites: [1]
channels: [1, 2, 3, 4]
##################################################
channels: [1, 2, 3, 4, 5]

stage0:

optimizer_params:
Expand Down Expand Up @@ -64,7 +64,7 @@ stages:
reduce_metric: *reduce_metric
saver:
callback: CheckpointCallback
##########################################################

stage1:

optimizer_params:
Expand All @@ -82,39 +82,4 @@ stages:
state_params:
num_epochs: 50

callbacks_params: *callback_params
######################################################
# stage2:

# optimizer_params:
# optimizer: Nadam
# lr: 0.00001

# scheduler_params:
# scheduler: ReduceLROnPlateau
# patience: 2
# # num_steps: 30
# # lr_range: [0.0007, 0.00001]
# # warmup_steps: 0
# # momentum_range: [0.85, 0.95]

# state_params:
# num_epochs: 10

# criterion_params:
# criterion: LabelSmoothingCrossEntropy

# callbacks_params:
# loss:
# callback: LabelSmoothCriterionCallback
# optimizer:
# callback: OptimizerCallback
# accumulation_steps: 2
# accuracy:
# callback: AccuracyCallback
# accuracy_args: [1]
# scheduler:
# callback: SchedulerCallback
# reduce_metric: *reduce_metric
# saver:
# callback: CheckpointCallback
callbacks_params: *callback_params
1 change: 1 addition & 0 deletions src/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
registry.Model(GluonResnetTIMM)
registry.Model(DSInceptionV3)
registry.Model(DSSENet)
registry.Model(ResNet50CutMix)

# Register callbacks
registry.Callback(LabelSmoothCriterionCallback)
Expand Down
2 changes: 1 addition & 1 deletion src/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .resnet import ResNet
from .resnet import ResNet, ResNet50CutMix
from .senet import cell_senet, SENetTIMM
from .densenet import cell_densenet
from .efficientnet import EfficientNet
Expand Down
52 changes: 52 additions & 0 deletions src/models/resnet.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import torch.nn as nn
import pretrainedmodels
import torch
from torchvision import models
from cnn_finetune import make_model
import timm
from .utils import *
Expand Down Expand Up @@ -37,3 +39,53 @@ def freeze(self):
def unfreeze(self):
for param in self.model.parameters():
param.requires_grad = True


class ResNet50CutMix(nn.Module):
    """ResNet-50 backbone initialized from CutMix-pretrained ImageNet weights,
    adapted to multi-channel image input (e.g. 5/6-channel microscopy data).

    The first conv layer is rebuilt with ``n_channels`` input channels; the
    RGB pretrained filters are copied into the first three channels and
    recycled for the extra channels. The classifier head is replaced with a
    fresh ``num_classes``-way linear layer.

    Args:
        num_classes: size of the output classification layer.
        pretrained: optional path to a CutMix checkpoint; expected to be a
            dict with a ``'model'`` entry whose keys are prefixed ``module.``
            (DataParallel-style). Missing keys are reported, not fatal.
        n_channels: number of input channels; pretrained filter recycling
            supports 1..6 channels.
    """

    def __init__(self, num_classes=1108,
                 pretrained=None,
                 n_channels=6):
        super(ResNet50CutMix, self).__init__()

        self.model = models.resnet50(pretrained=False)
        if pretrained:
            checkpoint = torch.load(pretrained)['model']

            # Copy matching weights; checkpoint keys carry a 'module.'
            # prefix from DataParallel training.
            model_dict = self.model.state_dict()
            for k in model_dict.keys():
                prefixed = 'module.' + k
                if prefixed in checkpoint:
                    model_dict[k] = checkpoint[prefixed]
                else:
                    print("{} is not in dict !".format(k))

            self.model.load_state_dict(model_dict)
            print("Loaded checkpoint: ", pretrained)

        conv1 = self.model.conv1
        # Rebuild the stem conv for n_channels inputs. NOTE: nn.Conv2d's
        # `bias` argument is a bool — the original code passed the bias
        # *tensor* (None for resnet50), which only worked because None is
        # falsy. Pass an explicit bool instead.
        self.model.conv1 = nn.Conv2d(in_channels=n_channels,
                                     out_channels=conv1.out_channels,
                                     kernel_size=conv1.kernel_size,
                                     stride=conv1.stride,
                                     padding=conv1.padding,
                                     bias=conv1.bias is not None)

        # Copy pretrained RGB filters into the first channels; for
        # n_channels > 3 reuse the leading RGB filters for the extras.
        # min(...) makes n_channels <= 3 work too (the original indexing
        # crashed on shape mismatch in that case); supports up to 6.
        n_copy = min(n_channels, 3)
        self.model.conv1.weight.data[:, :n_copy, :, :] = \
            conv1.weight.data[:, :n_copy, :, :]
        if n_channels > 3:
            self.model.conv1.weight.data[:, 3:n_channels, :, :] = \
                conv1.weight.data[:, :n_channels - 3, :, :]

        # Replace the ImageNet head with a task-specific classifier.
        dim_feats = self.model.fc.in_features
        self.model.fc = nn.Linear(dim_feats, num_classes)

    def forward(self, x):
        """Run a forward pass through the wrapped ResNet-50."""
        return self.model(x)

    def freeze(self):
        """Freeze the backbone; keep only the classifier head trainable."""
        for param in self.model.parameters():
            param.requires_grad = False

        for param in self.model.fc.parameters():
            param.requires_grad = True

    def unfreeze(self):
        """Make every parameter trainable again."""
        for param in self.model.parameters():
            param.requires_grad = True

0 comments on commit e83ac86

Please sign in to comment.