Skip to content

Commit

Permalink
Deep supervision - InceptionV3
Browse files Browse the repository at this point in the history
  • Loading branch information
ngxbac committed Jul 29, 2019
1 parent 13e5c7a commit 72c2c99
Show file tree
Hide file tree
Showing 7 changed files with 401 additions and 2 deletions.
19 changes: 19 additions & 0 deletions bin/train_ds.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
#!/usr/bin/env bash
# Train the deep-supervision InceptionV3 model (DSInceptionV3) with catalyst-dl,
# overriding the data params of configs/config_ds.yml per fold/channel set.

export CUDA_VISIBLE_DEVICES=2,3
RUN_CONFIG=config_ds.yml


# NOTE(review): bash does NOT expand [1,2,3,4,5] into five items — this outer
# loop runs exactly once with channels='[1,2,3,4,5]'. That literal is then
# handed to catalyst, which parses it as a Python list via the ':list' suffix
# below. Presumably intentional, but confirm; a real sweep would need e.g.
# `for channels in "[1,2,3]" "[1,2,3,4,5]"`.
for channels in [1,2,3,4,5]; do
for fold in 0; do
# Log/checkpoint directory is keyed by fold and model name.
LOGDIR=/raid/bac/kaggle/logs/recursion_cell/test/190729/fold_$fold/DSInceptionV3/
catalyst-dl run \
--config=./configs/${RUN_CONFIG} \
--logdir=$LOGDIR \
--out_dir=$LOGDIR:str \
--stages/data_params/channels=$channels:list \
--stages/data_params/train_csv=./csv/train_$fold.csv:str \
--stages/data_params/valid_csv=./csv/valid_$fold.csv:str \
--verbose
done
done
99 changes: 99 additions & 0 deletions configs/config_ds.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
# Catalyst training config for the deep-supervision InceptionV3 model.
# NOTE(review): indentation was lost when this file was scraped; the hierarchy
# below is reconstructed from the key semantics of catalyst configs — verify
# against the original commit before use.
model_params:
  model: DSInceptionV3
  pretrained: True
  n_channels: 5        # input image channels (cell imaging sites)
  num_classes: 1108    # sirna classes in the recursion-cellular competition

args:
  expdir: "src"
  logdir: &logdir "./logs/cell"
  baselogdir: "./logs/cell"

distributed_params:
  opt_level: O1        # NVIDIA apex mixed-precision level

stages:

  state_params:
    # Track top-1 accuracy of the final (deepest) head; higher is better.
    main_metric: &reduce_metric acc_final
    minimize_metric: False

  criterion_params:
    # criterion: CrossEntropyLoss
    criterion: LabelSmoothingCrossEntropy

  data_params:
    batch_size: 64
    num_workers: 16
    drop_last: False
    # drop_last: True

    image_size: &image_size 512
    # train_csv/valid_csv/channels are overridden per run by bin/train_ds.sh.
    train_csv: "./csv/train_0.csv"
    valid_csv: "./csv/valid_0.csv"
    root: "/raid/data/kaggle/recursion-cellular-image-classification/"
    sites: [1]
    channels: [1, 2, 3, 4]

  # Warm-up stage: short, high LR, auxiliary heads weighted lightly.
  stage0:

    optimizer_params:
      optimizer: Nadam
      lr: 0.001

    scheduler_params:
      scheduler: MultiStepLR
      milestones: [10]
      gamma: 0.3

    state_params:
      num_epochs: 2

    callbacks_params: &callback_params
      loss:
        callback: DSCriterionCallback
        # One weight per supervised head, shallow -> deep; the final head
        # dominates during warm-up.
        loss_weights: [0.001, 0.005, 0.01, 0.02, 0.02, 0.1, 1.0]
      optimizer:
        callback: OptimizerCallback
        accumulation_steps: 2
      accuracy:
        callback: DSAccuracyCallback
        logit_names: ["m2", "m4", "m6", "m8", "m9", "m10", "final"]
      scheduler:
        callback: SchedulerCallback
        reduce_metric: *reduce_metric
      saver:
        callback: CheckpointCallback

  # Main stage: long one-cycle schedule with all heads weighted equally.
  stage1:

    optimizer_params:
      optimizer: Nadam
      lr: 0.0001

    scheduler_params:
      scheduler: OneCycleLR
      num_steps: 50
      lr_range: [0.0005, 0.00001]
      # lr_range: [0.0015, 0.00003]
      warmup_steps: 5
      momentum_range: [0.85, 0.95]

    state_params:
      num_epochs: 50

    callbacks_params:
      loss:
        callback: DSCriterionCallback
        loss_weights: [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
      optimizer:
        callback: OptimizerCallback
        accumulation_steps: 2
      accuracy:
        callback: DSAccuracyCallback
        logit_names: ["m2", "m4", "m6", "m8", "m9", "m10", "final"]
      scheduler:
        callback: SchedulerCallback
        reduce_metric: *reduce_metric
      saver:
        callback: CheckpointCallback
3 changes: 3 additions & 0 deletions src/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,13 @@
# Register models so catalyst can resolve them by name from the YAML config.
registry.Model(SENetTIMM)
registry.Model(InceptionV3TIMM)
registry.Model(GluonResnetTIMM)
registry.Model(DSInceptionV3)

# Register callbacks
registry.Callback(LabelSmoothCriterionCallback)
registry.Callback(SmoothMixupCallback)
registry.Callback(DSAccuracyCallback)
registry.Callback(DSCriterionCallback)

# Register criterions
registry.Criterion(LabelSmoothingCrossEntropy)
Expand Down
103 changes: 102 additions & 1 deletion src/callbacks.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from catalyst.dl.core import Callback, RunnerState
from catalyst.dl.callbacks import CriterionCallback
from catalyst.dl.utils.criterion import accuracy
import torch
import torch.nn as nn
import numpy as np
Expand Down Expand Up @@ -138,4 +139,104 @@ def _compute_loss(self, state: RunnerState, criterion):

loss = self.lam * criterion(pred, y_a) + \
(1 - self.lam) * criterion(pred, y_b)
return loss
return loss




class DSCriterionCallback(Callback):
    """
    Criterion callback for deep-supervision models.

    The model is expected to emit a sequence of logits (one per supervised
    head) under ``output_key``; the loss is the weighted sum of the
    per-head criterion values.

    Args:
        input_key: batch key holding the targets.
        output_key: batch key holding the sequence of per-head logits.
        prefix: metric name used when logging the loss.
        criterion_key: which criterion to fetch from the runner state
            (``None`` means the default one).
        loss_key: key under which the loss is stored in ``state.loss``
            when several losses are accumulated as a dict.
        multiplier: scalar applied to the total loss.
        loss_weights: one weight per supervised head, shallow to deep.
            ``None`` means equal weights of 1.0.
    """

    def __init__(
        self,
        input_key: str = "targets",
        output_key: str = "logits",
        prefix: str = "loss",
        criterion_key: str = None,
        loss_key: str = None,
        multiplier: float = 1.0,
        loss_weights: List[float] = None,
    ):
        self.input_key = input_key
        self.output_key = output_key
        self.prefix = prefix
        self.criterion_key = criterion_key
        self.loss_key = loss_key
        self.multiplier = multiplier
        self.loss_weights = loss_weights

    def _add_loss_to_state(self, state: RunnerState, loss):
        # Mirror catalyst's CriterionCallback accumulation protocol:
        # without a loss_key, losses accumulate into a scalar or list;
        # with one, they accumulate into a dict keyed by loss_key.
        if self.loss_key is None:
            if state.loss is not None:
                if isinstance(state.loss, list):
                    state.loss.append(loss)
                else:
                    state.loss = [state.loss, loss]
            else:
                state.loss = loss
        else:
            if state.loss is not None:
                assert isinstance(state.loss, dict)
                state.loss[self.loss_key] = loss
            else:
                state.loss = {self.loss_key: loss}

    def _compute_loss(self, state: RunnerState, criterion):
        outputs = state.output[self.output_key]
        targets = state.input[self.input_key]
        # Fall back to equal weighting when no explicit weights were
        # configured (the original crashed on len(None) here).
        weights = (
            self.loss_weights
            if self.loss_weights is not None
            else [1.0] * len(outputs)
        )
        assert len(weights) == len(outputs)
        loss = 0
        for weight, output in zip(weights, outputs):
            loss = loss + criterion(output, targets) * weight
        return loss

    def on_stage_start(self, state: RunnerState):
        assert state.criterion is not None

    def on_batch_end(self, state: RunnerState):
        if state.loader_name.startswith("train"):
            criterion = state.get_key(
                key="criterion", inner_key=self.criterion_key
            )
        else:
            # Validation always scores with plain cross-entropy so the
            # reported loss is comparable across criterion choices.
            criterion = nn.CrossEntropyLoss()

        loss = self._compute_loss(state, criterion) * self.multiplier

        state.metrics.add_batch_value(metrics_dict={
            self.prefix: loss.item(),
        })

        self._add_loss_to_state(state, loss)


class DSAccuracyCallback(Callback):
    """
    Accuracy metric callback for deep-supervision models.

    Computes top-1 accuracy separately for every supervised head and logs
    each one as ``{prefix}_{logit_name}`` (e.g. ``acc_final``, which the
    config uses as the main metric).

    Args:
        input_key: batch key holding the targets.
        output_key: batch key holding the sequence of per-head logits.
        prefix: metric name prefix.
        logit_names: one name per head, same order as the model's outputs.
            ``None`` falls back to positional names ("0", "1", ...).
    """

    def __init__(
        self,
        input_key: str = "targets",
        output_key: str = "logits",
        prefix: str = "acc",
        logit_names: List[str] = None,
    ):
        self.prefix = prefix
        self.metric_fn = accuracy
        self.input_key = input_key
        self.output_key = output_key
        self.logit_names = logit_names

    def on_batch_end(self, state: RunnerState):
        outputs = state.output[self.output_key]
        targets = state.input[self.input_key]

        # Default to positional head names when none were configured
        # (the original crashed on len(None) here).
        logit_names = (
            self.logit_names
            if self.logit_names is not None
            else [str(i) for i in range(len(outputs))]
        )
        assert len(outputs) == len(logit_names)

        batch_metrics = {}
        for logit_name, output in zip(logit_names, outputs):
            metric = self.metric_fn(output, targets)
            # catalyst's accuracy() returns a list of top-k values;
            # index 0 is top-1.
            key = f"{self.prefix}_{logit_name}"
            batch_metrics[key] = metric[0]

        state.metrics.add_batch_value(metrics_dict=batch_metrics)
3 changes: 2 additions & 1 deletion src/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,5 @@
from .densenet import cell_densenet
from .efficientnet import EfficientNet
from .inceptionv3 import InceptionV3TIMM
from .gluon_resnet import GluonResnetTIMM
from .gluon_resnet import GluonResnetTIMM
from .deepsupervision import DSInceptionV3
1 change: 1 addition & 0 deletions src/models/deepsupervision/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .inception_v3 import DSInceptionV3
Loading

0 comments on commit 72c2c99

Please sign in to comment.