Commit 2979d8c ("update")
ngxbac committed Aug 10, 2019
1 parent 8c6511c
Showing 5 changed files with 192 additions and 13 deletions.
3 changes: 3 additions & 0 deletions src/__init__.py
@@ -13,6 +13,7 @@
registry.Model(ResNet)
registry.Model(cell_senet)
registry.Model(cell_densenet)
registry.Model(SENetGrouplevel)
registry.Model(EfficientNet)
registry.Model(SENetTIMM)
registry.Model(InceptionV3TIMM)
@@ -21,13 +22,15 @@
registry.Model(DSSENet)
registry.Model(ResNet50CutMix)
registry.Model(Fishnet)
registry.Model(SENetCellType)

# Register callbacks
registry.Callback(LabelSmoothCriterionCallback)
registry.Callback(SmoothMixupCallback)
registry.Callback(DSAccuracyCallback)
registry.Callback(DSCriterionCallback)
registry.Callback(SlackLogger)
registry.Callback(TwoHeadsCriterionCallback)

# Register criterions
registry.Criterion(LabelSmoothingCrossEntropy)
69 changes: 67 additions & 2 deletions src/callbacks.py
@@ -186,8 +186,6 @@ def _compute_loss(self, state: RunnerState, criterion):
        return loss




class DSCriterionCallback(Callback):
    def __init__(
        self,
@@ -252,6 +250,73 @@ def on_batch_end(self, state: RunnerState):
        self._add_loss_to_state(state, loss)


class TwoHeadsCriterionCallback(Callback):
    """
    Criterion callback for a two-head model: the main head ("logits")
    predicts the sirna class and the auxiliary head ("logits1") predicts
    the cell type; the two cross-entropy losses are summed.
    """
    def __init__(
        self,
        input_key: str = "targets",
        output_key: str = "logits",
        prefix: str = "loss",
        criterion_key: str = None,
        loss_key: str = None,
        multiplier: float = 1.0,
        loss_weights: List[float] = None,
    ):
        self.input_key = input_key
        self.output_key = output_key
        self.prefix = prefix
        self.criterion_key = criterion_key
        self.loss_key = loss_key
        self.multiplier = multiplier
        self.loss_weights = loss_weights

    def _add_loss_to_state(self, state: RunnerState, loss):
        if self.loss_key is None:
            if state.loss is not None:
                if isinstance(state.loss, list):
                    state.loss.append(loss)
                else:
                    state.loss = [state.loss, loss]
            else:
                state.loss = loss
        else:
            if state.loss is not None:
                assert isinstance(state.loss, dict)
                state.loss[self.loss_key] = loss
            else:
                state.loss = {self.loss_key: loss}

    def _compute_loss(self, state: RunnerState, criterion):
        outputs = state.output[self.output_key]
        outputs1 = state.output["logits1"]
        input_sirna = state.input[self.input_key]
        input_cell = state.input['cell_type']

        # Sum the main (sirna) loss and the auxiliary cell-type loss
        loss = criterion(outputs, input_sirna)
        loss += nn.CrossEntropyLoss()(outputs1, input_cell)
        return loss

    def on_stage_start(self, state: RunnerState):
        assert state.criterion is not None

    def on_batch_end(self, state: RunnerState):
        if state.loader_name.startswith("train"):
            criterion = state.get_key(
                key="criterion", inner_key=self.criterion_key
            )
        else:
            criterion = nn.CrossEntropyLoss()

        loss = self._compute_loss(state, criterion) * self.multiplier

        state.metrics.add_batch_value(metrics_dict={
            self.prefix: loss.item(),
        })

        self._add_loss_to_state(state, loss)
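
For reference, the combined loss in TwoHeadsCriterionCallback reduces to a sum of two cross-entropies, one per head. A minimal sketch with dummy tensors (the 1108 sirna classes and 4 cell types come from this diff; batch size and tensor names are illustrative):

import torch
import torch.nn as nn

sirna_logits = torch.randn(8, 1108)            # "logits" head
cell_logits = torch.randn(8, 4)                # "logits1" head
sirna_targets = torch.randint(0, 1108, (8,))
cell_targets = torch.randint(0, 4, (8,))

criterion = nn.CrossEntropyLoss()              # stand-in for state.criterion
loss = criterion(sirna_logits, sirna_targets)
loss += nn.CrossEntropyLoss()(cell_logits, cell_targets)
print(loss.item())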


class DSAccuracyCallback(Callback):
    """
    Accuracy metric callback.
2 changes: 1 addition & 1 deletion src/models/__init__.py
@@ -1,5 +1,5 @@
from .resnet import ResNet, ResNet50CutMix
from .senet import cell_senet, SENetTIMM
from .senet import cell_senet, SENetTIMM, SENetGrouplevel, SENetCellType
from .densenet import cell_densenet
from .efficientnet import EfficientNet
from .inceptionv3 import InceptionV3TIMM
Expand Down
126 changes: 117 additions & 9 deletions src/models/senet.py
@@ -1,3 +1,4 @@
import torch
from cnn_finetune import make_model
import timm
from .utils import *
@@ -37,7 +38,112 @@ def unfreeze(self):
            param.requires_grad = True


def cell_senet(model_name='se_resnext50_32x4d', num_classes=1108, n_channels=6):
class SENetGrouplevel(nn.Module):
    def __init__(self, model_name="seresnext26_32x4d",
                 num_classes=1108,
                 n_channels=6):
        super(SENetGrouplevel, self).__init__()

        self.model = make_model(
            model_name=model_name,
            num_classes=num_classes,
            pretrained=True,
            dropout_p=0.3
        )
        print("*" * 100)
        print("SENetGrouplevel")
        # Replace the stem conv so the network accepts n_channels inputs
        conv1 = self.model._features[0].conv1
        self.model._features[0].conv1 = nn.Conv2d(
            in_channels=n_channels,
            out_channels=conv1.out_channels,
            kernel_size=conv1.kernel_size,
            stride=conv1.stride,
            padding=conv1.padding,
            bias=conv1.bias is not None,  # nn.Conv2d expects a bool here
        )

        # copy pretrained weights: RGB filters fill the first 3 channels,
        # the remaining channels reuse the first (n_channels - 3) filters
        self.model._features[0].conv1.weight.data[:, :3, :, :] = conv1.weight.data
        self.model._features[0].conv1.weight.data[:, 3:n_channels, :, :] = conv1.weight.data[:, :n_channels - 3, :, :]

        self.group_label_embedding = nn.Embedding(num_embeddings=4, embedding_dim=8)

        in_features = self.model._classifier.in_features
        self.final_fc = nn.Linear(
            in_features=in_features + 8, out_features=num_classes
        )

    def forward(self, x, group_label):
        features = self.model._features(x)
        features = self.model.pool(features)
        features = features.view(features.size(0), -1)

        # Concatenate pooled features with the group-label embedding
        group_embedding = self.group_label_embedding(group_label)
        features = torch.cat([
            features, group_embedding
        ], 1)

        return self.final_fc(features)

    def freeze(self):
        for param in self.model.parameters():
            param.requires_grad = False

    def unfreeze(self):
        for param in self.model.parameters():
            param.requires_grad = True
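
The group-level head concatenates an 8-dim embedding of the group label with the pooled backbone features before the final linear layer. A toy dimension check (the 2048 feature size and the variable names are illustrative):

import torch
import torch.nn as nn

emb = nn.Embedding(num_embeddings=4, embedding_dim=8)
features = torch.randn(2, 2048)                # pooled backbone features
group_embedding = emb(torch.tensor([0, 3]))    # shape (2, 8)
fused = torch.cat([features, group_embedding], 1)
fc = nn.Linear(2048 + 8, 1108)
print(fc(fused).shape)                         # torch.Size([2, 1108])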


class SENetCellType(nn.Module):
    def __init__(self, model_name="seresnext26_32x4d",
                 num_classes=1108,
                 n_channels=6):
        super(SENetCellType, self).__init__()

        self.model = make_model(
            model_name=model_name,
            num_classes=num_classes,
            pretrained=True,
            dropout_p=0.3
        )
        print("*" * 100)
        print("SENetCellType")
        # Replace the stem conv so the network accepts n_channels inputs
        conv1 = self.model._features[0].conv1
        self.model._features[0].conv1 = nn.Conv2d(
            in_channels=n_channels,
            out_channels=conv1.out_channels,
            kernel_size=conv1.kernel_size,
            stride=conv1.stride,
            padding=conv1.padding,
            bias=conv1.bias is not None,  # nn.Conv2d expects a bool here
        )

        # copy pretrained weights: RGB filters fill the first 3 channels,
        # the remaining channels reuse the first (n_channels - 3) filters
        self.model._features[0].conv1.weight.data[:, :3, :, :] = conv1.weight.data
        self.model._features[0].conv1.weight.data[:, 3:n_channels, :, :] = conv1.weight.data[:, :n_channels - 3, :, :]

        # Two heads: sirna classification and 4-way cell-type classification
        in_features = self.model._classifier.in_features
        self.final_sirna = nn.Linear(
            in_features=in_features, out_features=num_classes
        )
        self.final_cell_type = nn.Linear(
            in_features=in_features, out_features=4
        )

    def forward(self, x):
        features = self.model._features(x)
        features = self.model.pool(features)
        features = features.view(features.size(0), -1)

        return self.final_sirna(features), self.final_cell_type(features)

    def freeze(self):
        for param in self.model.parameters():
            param.requires_grad = False

    def unfreeze(self):
        for param in self.model.parameters():
            param.requires_grad = True


def cell_senet(model_name='se_resnext50_32x4d', num_classes=1108, n_channels=6, weight=None):
    model = make_model(
        model_name=model_name,
        num_classes=num_classes,
@@ -58,12 +164,14 @@ def cell_senet(model_name='se_resnext50_32x4d', num_classes=1108, n_channels=6):
    # copy pretrained weights
    model._features[0].conv1.weight.data[:, :3, :, :] = conv1.weight.data
    model._features[0].conv1.weight.data[:, 3:n_channels, :, :] = conv1.weight.data[:, :n_channels - 3, :, :]
    # model = SENet(
    #     model_name=model_name,
    #     num_classes=num_classes,
    #     n_channels=n_channels
    # )
    #
    # for param in model.parameters():
    #     param.requires_grad = True

    if weight:
        model_state_dict = torch.load(weight)['model_state_dict']
        model.load_state_dict(model_state_dict)
        print(f"\n\n******************************* Loaded checkpoint {weight}")
    in_features = model._classifier.in_features
    model._classifier = nn.Linear(
        in_features=in_features, out_features=num_classes
    )

    return model
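
As a usage sketch, the two new model heads can be exercised as follows (a minimal sketch assuming the signatures above; se_resnext50_32x4d is passed explicitly since make_model resolves that name via pretrainedmodels, and pretrained weights are downloaded on first use):

import torch
from src.models.senet import SENetGrouplevel, SENetCellType

x = torch.randn(2, 6, 224, 224)                     # 6-channel cell images

two_head = SENetCellType(model_name="se_resnext50_32x4d",
                         num_classes=1108, n_channels=6)
sirna_logits, cell_logits = two_head(x)             # (2, 1108) and (2, 4)

group_model = SENetGrouplevel(model_name="se_resnext50_32x4d",
                              num_classes=1108, n_channels=6)
group_labels = torch.randint(0, 4, (2,))            # one of 4 groups each
logits = group_model(x, group_labels)               # (2, 1108)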
5 changes: 4 additions & 1 deletion src/runner.py
@@ -6,7 +6,10 @@ class ModelRunner(Runner):
    def predict_batch(self, batch: Mapping[str, Any]):
        # import pdb
        # pdb.set_trace()
        if 'group_labels' in batch:
            output = self.model(batch["images"], batch['group_labels'])
        else:
            output = self.model(batch["images"])
        return {
            "logits": output
        }
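
Note that predict_batch returns a single "logits" entry, while TwoHeadsCriterionCallback reads state.output["logits1"]. A hedged sketch of how a tuple-returning model such as SENetCellType could be mapped onto both keys (this wiring is an assumption and is not part of this commit):

def predict_batch(self, batch):
    if "group_labels" in batch:
        output = self.model(batch["images"], batch["group_labels"])
    else:
        output = self.model(batch["images"])
    # Assumed extension: SENetCellType returns (sirna_logits, cell_logits);
    # expose the second head as "logits1" for TwoHeadsCriterionCallback.
    if isinstance(output, tuple):
        return {"logits": output[0], "logits1": output[1]}
    return {"logits": output}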
