From 72c2c99a90d30ec2f7b18a1898988680a8ad0585 Mon Sep 17 00:00:00 2001
From: Nguyen Xuan Bac
Date: Mon, 29 Jul 2019 16:12:57 +0900
Subject: [PATCH] Deep supervision - InceptionV3

---
 bin/train_ds.sh                            |  19 +++
 configs/config_ds.yml                      |  99 ++++++++++++
 src/__init__.py                            |   3 +
 src/callbacks.py                           | 104 ++++++++++++-
 src/models/__init__.py                     |   3 +-
 src/models/deepsupervision/__init__.py     |   1 +
 src/models/deepsupervision/inception_v3.py | 170 +++++++++++++++++++
 7 files changed, 397 insertions(+), 2 deletions(-)
 create mode 100755 bin/train_ds.sh
 create mode 100644 configs/config_ds.yml
 create mode 100644 src/models/deepsupervision/__init__.py
 create mode 100644 src/models/deepsupervision/inception_v3.py

diff --git a/bin/train_ds.sh b/bin/train_ds.sh
new file mode 100755
index 0000000..c566b27
--- /dev/null
+++ b/bin/train_ds.sh
@@ -0,0 +1,19 @@
+#!/usr/bin/env bash
+
+export CUDA_VISIBLE_DEVICES=2,3
+RUN_CONFIG=config_ds.yml
+
+
+for channels in "[1,2,3,4,5]"; do
+    for fold in 0; do
+        LOGDIR=/raid/bac/kaggle/logs/recursion_cell/test/190729/fold_$fold/DSInceptionV3/
+        catalyst-dl run \
+            --config=./configs/${RUN_CONFIG} \
+            --logdir=$LOGDIR \
+            --out_dir=$LOGDIR:str \
+            --stages/data_params/channels=$channels:list \
+            --stages/data_params/train_csv=./csv/train_$fold.csv:str \
+            --stages/data_params/valid_csv=./csv/valid_$fold.csv:str \
+            --verbose
+    done
+done
\ No newline at end of file
diff --git a/configs/config_ds.yml b/configs/config_ds.yml
new file mode 100644
index 0000000..0821240
--- /dev/null
+++ b/configs/config_ds.yml
@@ -0,0 +1,99 @@
+model_params:
+  model: DSInceptionV3
+  pretrained: True
+  n_channels: 5
+  num_classes: 1108
+
+args:
+  expdir: "src"
+  logdir: &logdir "./logs/cell"
+  baselogdir: "./logs/cell"
+
+distributed_params:
+  opt_level: O1
+
+stages:
+
+  state_params:
+    main_metric: &reduce_metric acc_final
+    minimize_metric: False
+
+  criterion_params:
+#    criterion: CrossEntropyLoss
+    criterion: LabelSmoothingCrossEntropy
+
+  data_params:
+    batch_size: 64
+    num_workers: 16
+    drop_last: False
+    # drop_last: True
+
+    image_size: &image_size 512
+    train_csv: "./csv/train_0.csv"
+    valid_csv: "./csv/valid_0.csv"
+    root: "/raid/data/kaggle/recursion-cellular-image-classification/"
+    sites: [1]
+    channels: [1, 2, 3, 4]
+
+  stage0:
+
+    optimizer_params:
+      optimizer: Nadam
+      lr: 0.001
+
+    scheduler_params:
+      scheduler: MultiStepLR
+      milestones: [10]
+      gamma: 0.3
+
+    state_params:
+      num_epochs: 2
+
+    callbacks_params: &callback_params
+      loss:
+        callback: DSCriterionCallback
+        loss_weights: [0.001, 0.005, 0.01, 0.02, 0.02, 0.1, 1.0]
+      optimizer:
+        callback: OptimizerCallback
+        accumulation_steps: 2
+      accuracy:
+        callback: DSAccuracyCallback
+        logit_names: ["m2", "m4", "m6", "m8", "m9", "m10", "final"]
+      scheduler:
+        callback: SchedulerCallback
+        reduce_metric: *reduce_metric
+      saver:
+        callback: CheckpointCallback
+
+  stage1:
+
+    optimizer_params:
+      optimizer: Nadam
+      lr: 0.0001
+
+    scheduler_params:
+      scheduler: OneCycleLR
+      num_steps: 50
+      lr_range: [0.0005, 0.00001]
+      # lr_range: [0.0015, 0.00003]
+      warmup_steps: 5
+      momentum_range: [0.85, 0.95]
+
+    state_params:
+      num_epochs: 50
+
+    callbacks_params:
+      loss:
+        callback: DSCriterionCallback
+        loss_weights: [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
+      optimizer:
+        callback: OptimizerCallback
+        accumulation_steps: 2
+      accuracy:
+        callback: DSAccuracyCallback
+        logit_names: ["m2", "m4", "m6", "m8", "m9", "m10", "final"]
+      scheduler:
+        callback: SchedulerCallback
+        reduce_metric: *reduce_metric
+      saver:
+        callback: CheckpointCallback
\ No newline at end of file
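
For intuition, the stage0 loss_weights above blend the seven head losses into one scalar, heavily down-weighting the early heads during the short warm-up stage, while stage1 weights all heads equally. A minimal standalone sketch of that weighted sum (plain PyTorch, with illustrative shapes; this is not the Catalyst callback itself):

import torch
import torch.nn as nn

# Sketch of the weighted deep-supervision loss configured above.
# Shapes are made up for illustration: batch of 8, 1108 classes, 7 heads.
criterion = nn.CrossEntropyLoss()
loss_weights = [0.001, 0.005, 0.01, 0.02, 0.02, 0.1, 1.0]  # stage0 values
targets = torch.randint(0, 1108, (8,))
outputs = [torch.randn(8, 1108, requires_grad=True) for _ in loss_weights]  # m2 ... final

loss = sum(w * criterion(out, targets) for w, out in zip(loss_weights, outputs))
loss.backward()
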
diff --git a/src/__init__.py b/src/__init__.py
index b1e4f40..5bca301 100644
--- a/src/__init__.py
+++ b/src/__init__.py
@@ -16,10 +16,13 @@
 registry.Model(SENetTIMM)
 registry.Model(InceptionV3TIMM)
 registry.Model(GluonResnetTIMM)
+registry.Model(DSInceptionV3)
 
 # Register callbacks
 registry.Callback(LabelSmoothCriterionCallback)
 registry.Callback(SmoothMixupCallback)
+registry.Callback(DSAccuracyCallback)
+registry.Callback(DSCriterionCallback)
 
 # Register criterions
 registry.Criterion(LabelSmoothingCrossEntropy)
diff --git a/src/callbacks.py b/src/callbacks.py
index a301f58..2a19b13 100644
--- a/src/callbacks.py
+++ b/src/callbacks.py
@@ -1,5 +1,7 @@
 from catalyst.dl.core import Callback, RunnerState
 from catalyst.dl.callbacks import CriterionCallback
+from catalyst.dl.utils.criterion import accuracy
+from typing import List
 import torch
 import torch.nn as nn
 import numpy as np
@@ -138,4 +140,104 @@ def _compute_loss(self, state: RunnerState, criterion):
         loss = self.lam * criterion(pred, y_a) + \
             (1 - self.lam) * criterion(pred, y_b)
 
-        return loss
\ No newline at end of file
+        return loss
+
+
+class DSCriterionCallback(Callback):
+    def __init__(
+        self,
+        input_key: str = "targets",
+        output_key: str = "logits",
+        prefix: str = "loss",
+        criterion_key: str = None,
+        loss_key: str = None,
+        multiplier: float = 1.0,
+        loss_weights: List[float] = None,
+    ):
+        self.input_key = input_key
+        self.output_key = output_key
+        self.prefix = prefix
+        self.criterion_key = criterion_key
+        self.loss_key = loss_key
+        self.multiplier = multiplier
+        self.loss_weights = loss_weights
+
+    def _add_loss_to_state(self, state: RunnerState, loss):
+        if self.loss_key is None:
+            if state.loss is not None:
+                if isinstance(state.loss, list):
+                    state.loss.append(loss)
+                else:
+                    state.loss = [state.loss, loss]
+            else:
+                state.loss = loss
+        else:
+            if state.loss is not None:
+                assert isinstance(state.loss, dict)
+                state.loss[self.loss_key] = loss
+            else:
+                state.loss = {self.loss_key: loss}
+
+    def _compute_loss(self, state: RunnerState, criterion):
+        outputs = state.output[self.output_key]
+        targets = state.input[self.input_key]
+        assert len(self.loss_weights) == len(outputs)
+        loss = 0
+        for i, output in enumerate(outputs):
+            loss += criterion(output, targets) * self.loss_weights[i]
+        return loss
+
+    def on_stage_start(self, state: RunnerState):
+        assert state.criterion is not None
+
+    def on_batch_end(self, state: RunnerState):
+        if state.loader_name.startswith("train"):
+            criterion = state.get_key(
+                key="criterion", inner_key=self.criterion_key
+            )
+        else:
+            # validation is always scored with plain cross-entropy,
+            # even when training uses label smoothing
+            criterion = nn.CrossEntropyLoss()
+
+        loss = self._compute_loss(state, criterion) * self.multiplier
+
+        state.metrics.add_batch_value(metrics_dict={
+            self.prefix: loss.item(),
+        })
+
+        self._add_loss_to_state(state, loss)
+
+
+class DSAccuracyCallback(Callback):
+    """
+    Accuracy callback that logs one metric per deep-supervision head.
+    """
+
+    def __init__(
+        self,
+        input_key: str = "targets",
+        output_key: str = "logits",
+        prefix: str = "acc",
+        logit_names: List[str] = None,
+    ):
+        self.prefix = prefix
+        self.metric_fn = accuracy
+        self.input_key = input_key
+        self.output_key = output_key
+        self.logit_names = logit_names
+
+    def on_batch_end(self, state: RunnerState):
+        outputs = state.output[self.output_key]
+        targets = state.input[self.input_key]
+
+        assert len(outputs) == len(self.logit_names)
+
+        batch_metrics = {}
+
+        for logit_name, output in zip(self.logit_names, outputs):
+            metric = self.metric_fn(output, targets)
+            key = f"{self.prefix}_{logit_name}"
+            batch_metrics[key] = metric[0]
+
+        state.metrics.add_batch_value(metrics_dict=batch_metrics)
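
DSAccuracyCallback logs one accuracy per head (acc_m2 through acc_final), and main_metric: acc_final in config_ds.yml makes checkpointing track only the final head. A sketch of that per-head bookkeeping, substituting a plain top-1 accuracy for catalyst.dl.utils.criterion.accuracy:

import torch

def top1_accuracy(logits: torch.Tensor, targets: torch.Tensor) -> float:
    # plain-PyTorch stand-in for catalyst's accuracy helper
    return (logits.argmax(dim=1) == targets).float().mean().item() * 100.0

logit_names = ["m2", "m4", "m6", "m8", "m9", "m10", "final"]
targets = torch.randint(0, 1108, (8,))
outputs = [torch.randn(8, 1108) for _ in logit_names]  # one logit tensor per head

batch_metrics = {f"acc_{name}": top1_accuracy(out, targets)
                 for name, out in zip(logit_names, outputs)}
# batch_metrics["acc_final"] is the value tracked as main_metric
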
+ """ + + def __init__( + self, + input_key: str = "targets", + output_key: str = "logits", + prefix: str = "acc", + logit_names: List[str] = None, + ): + self.prefix = prefix + self.metric_fn = accuracy + self.input_key = input_key + self.output_key = output_key + self.logit_names = logit_names + + def on_batch_end(self, state: RunnerState): + outputs = state.output[self.output_key] + targets = state.input[self.input_key] + + assert len(outputs) == len(self.logit_names) + + batch_metrics = {} + + for logit_name, output in zip(self.logit_names, outputs): + metric = self.metric_fn(output, targets) + key = f"{self.prefix}_{logit_name}" + batch_metrics[key] = metric[0] + + state.metrics.add_batch_value(metrics_dict=batch_metrics) diff --git a/src/models/__init__.py b/src/models/__init__.py index ab7e2f5..72226d0 100644 --- a/src/models/__init__.py +++ b/src/models/__init__.py @@ -3,4 +3,5 @@ from .densenet import cell_densenet from .efficientnet import EfficientNet from .inceptionv3 import InceptionV3TIMM -from .gluon_resnet import GluonResnetTIMM \ No newline at end of file +from .gluon_resnet import GluonResnetTIMM +from .deepsupervision import DSInceptionV3 \ No newline at end of file diff --git a/src/models/deepsupervision/__init__.py b/src/models/deepsupervision/__init__.py new file mode 100644 index 0000000..253877c --- /dev/null +++ b/src/models/deepsupervision/__init__.py @@ -0,0 +1 @@ +from .inception_v3 import DSInceptionV3 \ No newline at end of file diff --git a/src/models/deepsupervision/inception_v3.py b/src/models/deepsupervision/inception_v3.py new file mode 100644 index 0000000..240f7cd --- /dev/null +++ b/src/models/deepsupervision/inception_v3.py @@ -0,0 +1,175 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from torchvision import models +from catalyst.contrib.modules.common import Flatten +from catalyst.contrib.modules.pooling import GlobalConcatPool2d + + +class DSInceptionV3(nn.Module): + def __init__( + self, + num_classes=6, + pretrained=True, + n_channels=4, + + ): + super(DSInceptionV3, self).__init__() + self.model = models.inception_v3( + pretrained=pretrained, + transform_input=False, + # aux_logits=False + ) + + # Adapt number of channels + conv1 = self.model.Conv2d_1a_3x3.conv + self.model.Conv2d_1a_3x3.conv = nn.Conv2d(in_channels=n_channels, + out_channels=conv1.out_channels, + kernel_size=conv1.kernel_size, + stride=conv1.stride, + padding=conv1.padding, + bias=conv1.bias) + + # copy pretrained weights + self.model.Conv2d_1a_3x3.conv.weight.data[:, :3, :, :] = conv1.weight.data + self.model.Conv2d_1a_3x3.conv.weight.data[:, 3:n_channels, :, :] = conv1.weight.data[:, :int(n_channels - 3), :, :] + + self.deepsuper_2 = nn.Sequential( + GlobalConcatPool2d(), + Flatten(), + nn.BatchNorm1d(288 * 2), + nn.Linear(288 * 2, num_classes) + ) + + self.deepsuper_4 = nn.Sequential( + GlobalConcatPool2d(), + Flatten(), + nn.BatchNorm1d(768 * 2), + nn.Linear(768 * 2, num_classes) + ) + + self.deepsuper_6 = nn.Sequential( + GlobalConcatPool2d(), + Flatten(), + nn.BatchNorm1d(768 * 2), + nn.Linear(768 * 2, num_classes) + ) + + self.deepsuper_8 = nn.Sequential( + GlobalConcatPool2d(), + Flatten(), + nn.BatchNorm1d(1280 * 2), + nn.Linear(1280 * 2, num_classes) + ) + + self.deepsuper_9 = nn.Sequential( + GlobalConcatPool2d(), + Flatten(), + nn.BatchNorm1d(2048 * 2), + nn.Linear(2048 * 2, num_classes) + ) + + self.deepsuper_10 = nn.Sequential( + GlobalConcatPool2d(), + Flatten(), + nn.BatchNorm1d(2048 * 2), + nn.Linear(2048 * 2, num_classes) + ) + + # 
diff --git a/src/models/deepsupervision/inception_v3.py b/src/models/deepsupervision/inception_v3.py
new file mode 100644
index 0000000..240f7cd
--- /dev/null
+++ b/src/models/deepsupervision/inception_v3.py
@@ -0,0 +1,170 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torchvision import models
+from catalyst.contrib.modules.common import Flatten
+from catalyst.contrib.modules.pooling import GlobalConcatPool2d
+
+
+class DSInceptionV3(nn.Module):
+    def __init__(
+        self,
+        num_classes=6,
+        pretrained=True,
+        n_channels=4,
+    ):
+        super(DSInceptionV3, self).__init__()
+        self.model = models.inception_v3(
+            pretrained=pretrained,
+            transform_input=False,
+            # aux_logits=False
+        )
+
+        # Adapt the stem conv to n_channels input channels
+        conv1 = self.model.Conv2d_1a_3x3.conv
+        self.model.Conv2d_1a_3x3.conv = nn.Conv2d(in_channels=n_channels,
+                                                  out_channels=conv1.out_channels,
+                                                  kernel_size=conv1.kernel_size,
+                                                  stride=conv1.stride,
+                                                  padding=conv1.padding,
+                                                  bias=conv1.bias is not None)
+
+        # Copy the pretrained RGB weights; recycle them for the extra channels
+        self.model.Conv2d_1a_3x3.conv.weight.data[:, :3, :, :] = conv1.weight.data
+        self.model.Conv2d_1a_3x3.conv.weight.data[:, 3:n_channels, :, :] = conv1.weight.data[:, :n_channels - 3, :, :]
+
+        self.deepsuper_2 = nn.Sequential(
+            GlobalConcatPool2d(),
+            Flatten(),
+            nn.BatchNorm1d(288 * 2),
+            nn.Linear(288 * 2, num_classes)
+        )
+
+        self.deepsuper_4 = nn.Sequential(
+            GlobalConcatPool2d(),
+            Flatten(),
+            nn.BatchNorm1d(768 * 2),
+            nn.Linear(768 * 2, num_classes)
+        )
+
+        self.deepsuper_6 = nn.Sequential(
+            GlobalConcatPool2d(),
+            Flatten(),
+            nn.BatchNorm1d(768 * 2),
+            nn.Linear(768 * 2, num_classes)
+        )
+
+        self.deepsuper_8 = nn.Sequential(
+            GlobalConcatPool2d(),
+            Flatten(),
+            nn.BatchNorm1d(1280 * 2),
+            nn.Linear(1280 * 2, num_classes)
+        )
+
+        self.deepsuper_9 = nn.Sequential(
+            GlobalConcatPool2d(),
+            Flatten(),
+            nn.BatchNorm1d(2048 * 2),
+            nn.Linear(2048 * 2, num_classes)
+        )
+
+        self.deepsuper_10 = nn.Sequential(
+            GlobalConcatPool2d(),
+            Flatten(),
+            nn.BatchNorm1d(2048 * 2),
+            nn.Linear(2048 * 2, num_classes)
+        )
+
+        # WARNING: the in_features of the first Linear layer must match the input image size
+        self.fc = nn.Sequential(
+            nn.Conv2d(in_channels=2048, out_channels=128, kernel_size=(1, 1)),
+            nn.ReLU(),
+            Flatten(),
+            nn.Linear(25088, 1024),  # 128 * 14 * 14 = 25088 for 512x512 inputs, 128 * 5 * 5 = 3200 for 224x224
+            nn.ReLU(),
+            nn.Dropout(0.3),
+            nn.Linear(1024, num_classes)
+        )
+
+        self.is_infer = False
+
+    def freeze_base(self):
+        for param in self.model.parameters():
+            param.requires_grad = False
+
+    def unfreeze_base(self):
+        for param in self.model.parameters():
+            param.requires_grad = True
+
+    def forward(self, x):
+        if self.model.transform_input:
+            x = x.clone()
+            x[:, 0] = x[:, 0] * (0.229 / 0.5) + (0.485 - 0.5) / 0.5
+            x[:, 1] = x[:, 1] * (0.224 / 0.5) + (0.456 - 0.5) / 0.5
+            x[:, 2] = x[:, 2] * (0.225 / 0.5) + (0.406 - 0.5) / 0.5
+        # 299 x 299 x 3
+        x = self.model.Conv2d_1a_3x3(x)
+        # 149 x 149 x 32
+        x = self.model.Conv2d_2a_3x3(x)
+        # 147 x 147 x 32
+        x = self.model.Conv2d_2b_3x3(x)
+        # 147 x 147 x 64
+        x = F.max_pool2d(x, kernel_size=3, stride=2)
+        # 73 x 73 x 64
+        x = self.model.Conv2d_3b_1x1(x)
+        # 73 x 73 x 80
+        x = self.model.Conv2d_4a_3x3(x)
+        # 71 x 71 x 192
+        x = F.max_pool2d(x, kernel_size=3, stride=2)  # => finish stem convs
+
+        # 35 x 35 x 192
+        x = self.model.Mixed_5b(x)  # => finish mixed 0
+        # 35 x 35 x 256
+        x = self.model.Mixed_5c(x)  # => finish mixed 1
+        # 35 x 35 x 288
+        x = self.model.Mixed_5d(x)  # => finish mixed 2
+        x_mix_2 = self.deepsuper_2(x)
+        # 35 x 35 x 288
+        x = self.model.Mixed_6a(x)  # => finish mixed 3
+        # 17 x 17 x 768
+        x = self.model.Mixed_6b(x)  # => finish mixed 4
+        x_mix_4 = self.deepsuper_4(x)
+        # 17 x 17 x 768
+        x = self.model.Mixed_6c(x)  # => finish mixed 5
+        # 17 x 17 x 768
+        x = self.model.Mixed_6d(x)  # => finish mixed 6
+        x_mix_6 = self.deepsuper_6(x)
+        # 17 x 17 x 768
+        x = self.model.Mixed_6e(x)  # => finish mixed 7
+        # 17 x 17 x 768
+        # if self.model.training and self.model.aux_logits:
+        #     aux = self.model.AuxLogits(x)
+        # 17 x 17 x 768
+        x = self.model.Mixed_7a(x)  # => finish mixed 8
+        x_mix_8 = self.deepsuper_8(x)
+        # 8 x 8 x 1280
+        x = self.model.Mixed_7b(x)  # => finish mixed 9
+        x_mix_9 = self.deepsuper_9(x)
+        # 8 x 8 x 2048
+        x = self.model.Mixed_7c(x)  # => finish mixed 10
+        # 8 x 8 x 2048
+
+        # model outputs: six auxiliary heads plus the final classifier
+        x_mix_10 = self.deepsuper_10(x)
+        x_final = self.fc(x)
+
+        return x_mix_2, x_mix_4, x_mix_6, x_mix_8, x_mix_9, x_mix_10, x_final
+
+    def freeze(self):
+        # Freeze the whole backbone
+        for param in self.model.parameters():
+            param.requires_grad = False
+
+    def unfreeze(self):
+        # Unfreeze the whole backbone
+        for param in self.model.parameters():
+            param.requires_grad = True
+
+
+if __name__ == '__main__':
+    model = DSInceptionV3()
\ No newline at end of file
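
A quick smoke test for the model above (assumes catalyst's contrib modules are installed and the repository root is on PYTHONPATH): with 512x512 inputs, the final 2048-channel feature map is 14x14, so 128 * 14 * 14 = 25088 matches the hard-coded Linear layer, and all seven outputs come back as (batch, num_classes).

import torch
from src.models.deepsupervision.inception_v3 import DSInceptionV3

model = DSInceptionV3(num_classes=1108, pretrained=False, n_channels=5)
model.eval()
with torch.no_grad():
    outputs = model(torch.randn(2, 5, 512, 512))

assert len(outputs) == 7  # m2, m4, m6, m8, m9, m10, final
for out in outputs:
    assert out.shape == (2, 1108)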