diff --git a/bin/train.sh b/bin/train.sh
index 4172765..ba30cb6 100755
--- a/bin/train.sh
+++ b/bin/train.sh
@@ -1,12 +1,19 @@
 #!/usr/bin/env bash
-export CUDA_VISIBLE_DEVICES=0,1,2,3
+export CUDA_VISIBLE_DEVICES=2,3
 RUN_CONFIG=config.yml
 
 
 
-LOGDIR=./bin/30_epochs/
-catalyst-dl run \
-    --config=./configs/${RUN_CONFIG} \
-    --logdir=$LOGDIR \
-    --out_dir=$LOGDIR:str \
-    --verbose
\ No newline at end of file
+for channels in [1,2,3,4,5]; do
+    for fold in 0; do
+        LOGDIR=/raid/bac/kaggle/logs/recursion_cell/test/190730/cutmix/fold_$fold/resnet50/
+        catalyst-dl run \
+            --config=./configs/${RUN_CONFIG} \
+            --logdir=$LOGDIR \
+            --out_dir=$LOGDIR:str \
+            --stages/data_params/channels=$channels:list \
+            --stages/data_params/train_csv=./csv/train_$fold.csv:str \
+            --stages/data_params/valid_csv=./csv/valid_$fold.csv:str \
+            --verbose
+    done
+done
\ No newline at end of file
diff --git a/configs/config.yml b/configs/config.yml
index a6a7968..03b84c8 100644
--- a/configs/config.yml
+++ b/configs/config.yml
@@ -1,8 +1,8 @@
 model_params:
-  model: cell_senet
-  model_name: se_resnext50_32x4d
-  n_channels: 4
+  model: ResNet50CutMix
+  n_channels: 5
   num_classes: 1108
+  pretrained: "/raid/bac/pretrained_models/pytorch/ResNet50_CutMix_v2.pth"
 
 args:
   expdir: "src"
@@ -24,17 +24,17 @@
 
   data_params:
     batch_size: 64
-    num_workers: 16
+    num_workers: 8
     drop_last: False
     # drop_last: True
 
     image_size: &image_size 512
     train_csv: "./csv/train_0.csv"
     valid_csv: "./csv/valid_0.csv"
-    root: "./data/"
+    root: "/raid/data/kaggle/recursion-cellular-image-classification/"
    sites: [1]
-    channels: [1, 2, 3, 4]
-##################################################
+    channels: [1, 2, 3, 4, 5]
+
 
 
   stage0:
@@ -64,7 +64,7 @@
         reduce_metric: *reduce_metric
       saver:
         callback: CheckpointCallback
-##########################################################
+
 
 
   stage1:
@@ -82,39 +82,4 @@
 
     state_params:
       num_epochs: 50
-    callbacks_params: *callback_params
-######################################################
-  # stage2:
-
-  #   optimizer_params:
-  #     optimizer: Nadam
-  #     lr: 0.00001
-
-  #   scheduler_params:
-  #     scheduler: ReduceLROnPlateau
-  #     patience: 2
-  #     # num_steps: 30
-  #     # lr_range: [0.0007, 0.00001]
-  #     # warmup_steps: 0
-  #     # momentum_range: [0.85, 0.95]
-
-  #   state_params:
-  #     num_epochs: 10
-
-  #   criterion_params:
-  #     criterion: LabelSmoothingCrossEntropy
-
-  #   callbacks_params:
-  #     loss:
-  #       callback: LabelSmoothCriterionCallback
-  #     optimizer:
-  #       callback: OptimizerCallback
-  #       accumulation_steps: 2
-  #     accuracy:
-  #       callback: AccuracyCallback
-  #       accuracy_args: [1]
-  #     scheduler:
-  #       callback: SchedulerCallback
-  #       reduce_metric: *reduce_metric
-  #     saver:
-  #       callback: CheckpointCallback
+    callbacks_params: *callback_params
\ No newline at end of file
diff --git a/src/__init__.py b/src/__init__.py
index 44de322..83ac81f 100644
--- a/src/__init__.py
+++ b/src/__init__.py
@@ -18,6 +18,7 @@
 registry.Model(GluonResnetTIMM)
 registry.Model(DSInceptionV3)
 registry.Model(DSSENet)
+registry.Model(ResNet50CutMix)
 
 # Register callbacks
 registry.Callback(LabelSmoothCriterionCallback)
diff --git a/src/models/__init__.py b/src/models/__init__.py
index 654810a..82d7990 100644
--- a/src/models/__init__.py
+++ b/src/models/__init__.py
@@ -1,4 +1,4 @@
-from .resnet import ResNet
+from .resnet import ResNet, ResNet50CutMix
 from .senet import cell_senet, SENetTIMM
 from .densenet import cell_densenet
 from .efficientnet import EfficientNet
diff --git a/src/models/resnet.py b/src/models/resnet.py
index 243f0ba..07f5f5b 100644
--- a/src/models/resnet.py
+++ b/src/models/resnet.py
@@ -1,5 +1,7 @@
 import torch.nn as nn
 import pretrainedmodels
+import torch
+from torchvision import models
 from cnn_finetune import make_model
 import timm
 from .utils import *
@@ -37,3 +39,53 @@ def freeze(self):
     def unfreeze(self):
         for param in self.model.parameters():
             param.requires_grad = True
+
+
+class ResNet50CutMix(nn.Module):
+    def __init__(self, num_classes=1108,
+                 pretrained=None,
+                 n_channels=6):
+        super(ResNet50CutMix, self).__init__()
+
+        self.model = models.resnet50(pretrained=False)
+        if pretrained:
+            checkpoint = torch.load(pretrained)['model']
+
+            model_dict = self.model.state_dict()
+            for k in model_dict.keys():
+                if (('module.' + k) in checkpoint.keys()):
+                    model_dict[k] = checkpoint.get(('module.' + k))
+                else:
+                    print("{} is not in dict !".format(k))
+
+            self.model.load_state_dict(model_dict)
+            print("Loaded checkpoint: ", pretrained)
+
+        conv1 = self.model.conv1
+        self.model.conv1 = nn.Conv2d(in_channels=n_channels,
+                                     out_channels=conv1.out_channels,
+                                     kernel_size=conv1.kernel_size,
+                                     stride=conv1.stride,
+                                     padding=conv1.padding,
+                                     bias=conv1.bias)
+
+        # copy pretrained weights
+        self.model.conv1.weight.data[:, :3, :, :] = conv1.weight.data
+        self.model.conv1.weight.data[:, 3:n_channels, :, :] = conv1.weight.data[:, :int(n_channels - 3), :, :]
+
+        dim_feats = self.model.fc.in_features
+        self.model.fc = nn.Linear(dim_feats, num_classes)
+
+    def forward(self, x):
+        return self.model(x)
+
+    def freeze(self):
+        for param in self.model.parameters():
+            param.requires_grad = False
+
+        for param in self.model.fc.parameters():
+            param.requires_grad = True
+
+    def unfreeze(self):
+        for param in self.model.parameters():
+            param.requires_grad = True
\ No newline at end of file
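
The new ResNet50CutMix wrapper loads a 3-channel CutMix-pretrained checkpoint, swaps conv1 for an n-channel version, tiles the first pretrained filter channels into the extra input channels, and replaces the classifier head with a 1108-class fc layer. Below is a minimal usage sketch of the class outside the catalyst registry; it assumes the src/ package from this diff (and its dependencies such as timm, cnn_finetune and pretrainedmodels) is importable, and it passes pretrained=None, so conv1 starts from random weights instead of the tiled checkpoint filters.

# Usage sketch, assuming the repository layout introduced in this diff.
# pretrained=None skips loading ResNet50_CutMix_v2.pth.
import torch

from src.models import ResNet50CutMix

model = ResNet50CutMix(num_classes=1108, pretrained=None, n_channels=5)
model.freeze()                    # freeze the backbone, keep fc trainable
model.eval()

x = torch.randn(2, 5, 512, 512)   # two 5-channel 512x512 images, as in config.yml
with torch.no_grad():
    logits = model(x)
print(logits.shape)               # torch.Size([2, 1108])

During training the same model is built by catalyst-dl from model_params in configs/config.yml, where pretrained points at the ResNet50_CutMix_v2.pth checkpoint and n_channels is 5 to match channels: [1, 2, 3, 4, 5] in data_params.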