diff --git a/src/__init__.py b/src/__init__.py
index 4ed3770..4932882 100644
--- a/src/__init__.py
+++ b/src/__init__.py
@@ -13,6 +13,7 @@
 registry.Model(ResNet)
 registry.Model(cell_senet)
 registry.Model(cell_densenet)
+registry.Model(SENetGrouplevel)
 registry.Model(EfficientNet)
 registry.Model(SENetTIMM)
 registry.Model(InceptionV3TIMM)
@@ -21,6 +22,7 @@
 registry.Model(DSSENet)
 registry.Model(ResNet50CutMix)
 registry.Model(Fishnet)
+registry.Model(SENetCellType)
 
 # Register callbacks
 registry.Callback(LabelSmoothCriterionCallback)
@@ -28,6 +30,7 @@
 registry.Callback(DSAccuracyCallback)
 registry.Callback(DSCriterionCallback)
 registry.Callback(SlackLogger)
+registry.Callback(TwoHeadsCriterionCallback)
 
 # Register criterions
 registry.Criterion(LabelSmoothingCrossEntropy)
diff --git a/src/callbacks.py b/src/callbacks.py
index 86378d0..a35ceda 100644
--- a/src/callbacks.py
+++ b/src/callbacks.py
@@ -186,8 +186,6 @@ def _compute_loss(self, state: RunnerState, criterion):
         return loss
 
 
-
-
 class DSCriterionCallback(Callback):
     def __init__(
         self,
@@ -252,6 +250,73 @@ def on_batch_end(self, state: RunnerState):
         self._add_loss_to_state(state, loss)
 
 
+class TwoHeadsCriterionCallback(Callback):
+    def __init__(
+        self,
+        input_key: str = "targets",
+        output_key: str = "logits",
+        prefix: str = "loss",
+        criterion_key: str = None,
+        loss_key: str = None,
+        multiplier: float = 1.0,
+        loss_weights: List[float] = None,
+    ):
+        self.input_key = input_key
+        self.output_key = output_key
+        self.prefix = prefix
+        self.criterion_key = criterion_key
+        self.loss_key = loss_key
+        self.multiplier = multiplier
+        self.loss_weights = loss_weights
+
+    def _add_loss_to_state(self, state: RunnerState, loss):
+        if self.loss_key is None:
+            if state.loss is not None:
+                if isinstance(state.loss, list):
+                    state.loss.append(loss)
+                else:
+                    state.loss = [state.loss, loss]
+            else:
+                state.loss = loss
+        else:
+            if state.loss is not None:
+                assert isinstance(state.loss, dict)
+                state.loss[self.loss_key] = loss
+            else:
+                state.loss = {self.loss_key: loss}
+
+    def _compute_loss(self, state: RunnerState, criterion):
+        outputs = state.output[self.output_key]
+        outputs1 = state.output["logits1"]
+        input_sirna = state.input[self.input_key]
+        input_cell = state.input["cell_type"]
+
+        # configured criterion for the sirna head, plain CE for the cell-type head
+        loss = criterion(outputs, input_sirna)
+        loss += nn.CrossEntropyLoss()(outputs1, input_cell)
+
+        return loss
+
+    def on_stage_start(self, state: RunnerState):
+        assert state.criterion is not None
+
+    def on_batch_end(self, state: RunnerState):
+        if state.loader_name.startswith("train"):
+            criterion = state.get_key(
+                key="criterion", inner_key=self.criterion_key
+            )
+        else:
+            criterion = nn.CrossEntropyLoss()
+
+        loss = self._compute_loss(state, criterion) * self.multiplier
+
+        state.metrics.add_batch_value(metrics_dict={
+            self.prefix: loss.item(),
+        })
+
+        self._add_loss_to_state(state, loss)
+
+
 class DSAccuracyCallback(Callback):
     """
     Accuracy metric callback.
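As a sanity check on the new callback's contract, here is a minimal sketch of the per-batch loss it computes, outside of Catalyst. The shapes (batch of 2, 1108 sirna classes, 4 cell types) and the plain CrossEntropyLoss standing in for the configured criterion are illustrative assumptions, not part of this diff:

    import torch
    import torch.nn as nn

    logits = torch.randn(2, 1108)           # state.output["logits"], sirna head
    logits1 = torch.randn(2, 4)             # state.output["logits1"], cell-type head
    targets = torch.randint(0, 1108, (2,))  # state.input["targets"]
    cell_type = torch.randint(0, 4, (2,))   # state.input["cell_type"]

    criterion = nn.CrossEntropyLoss()       # stand-in for the configured criterion
    loss = criterion(logits, targets) + nn.CrossEntropyLoss()(logits1, cell_type)
    print(loss.item())

Note that the two losses are summed with equal weight; `loss_weights` is accepted for interface parity with DSCriterionCallback but is not used in the code above.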
diff --git a/src/models/__init__.py b/src/models/__init__.py
index ffab74a..2bd6794 100644
--- a/src/models/__init__.py
+++ b/src/models/__init__.py
@@ -1,5 +1,5 @@
 from .resnet import ResNet, ResNet50CutMix
-from .senet import cell_senet, SENetTIMM
+from .senet import cell_senet, SENetTIMM, SENetGrouplevel, SENetCellType
 from .densenet import cell_densenet
 from .efficientnet import EfficientNet
 from .inceptionv3 import InceptionV3TIMM
diff --git a/src/models/senet.py b/src/models/senet.py
index f282333..c5c1b21 100644
--- a/src/models/senet.py
+++ b/src/models/senet.py
@@ -1,3 +1,4 @@
+import torch
 from cnn_finetune import make_model
 import timm
 from .utils import *
@@ -37,7 +38,116 @@ def unfreeze(self):
         param.requires_grad = True
 
 
-def cell_senet(model_name='se_resnext50_32x4d', num_classes=1108, n_channels=6):
+class SENetGrouplevel(nn.Module):
+    def __init__(self, model_name="seresnext26_32x4d",
+                 num_classes=1108,
+                 n_channels=6):
+        super(SENetGrouplevel, self).__init__()
+
+        self.model = make_model(
+            model_name=model_name,
+            num_classes=num_classes,
+            pretrained=True,
+            dropout_p=0.3
+        )
+        print("*" * 100)
+        print("SENetGrouplevel")
+        # swap the stem conv for an n_channels-input one; nn.Conv2d
+        # expects a bool for `bias`, not the bias tensor itself
+        conv1 = self.model._features[0].conv1
+        self.model._features[0].conv1 = nn.Conv2d(in_channels=n_channels,
+                                                  out_channels=conv1.out_channels,
+                                                  kernel_size=conv1.kernel_size,
+                                                  stride=conv1.stride,
+                                                  padding=conv1.padding,
+                                                  bias=conv1.bias is not None)
+
+        # copy pretrained weights: RGB filters first, then reused for the extra channels
+        self.model._features[0].conv1.weight.data[:, :3, :, :] = conv1.weight.data
+        self.model._features[0].conv1.weight.data[:, 3:n_channels, :, :] = conv1.weight.data[:, :n_channels - 3, :, :]
+
+        self.group_label_embedding = nn.Embedding(num_embeddings=4, embedding_dim=8)
+
+        in_features = self.model._classifier.in_features
+        self.final_fc = nn.Linear(
+            in_features=in_features + 8, out_features=num_classes
+        )
+
+    def forward(self, x, group_label):
+        features = self.model._features(x)
+        features = self.model.pool(features)
+        features = features.view(features.size(0), -1)
+
+        # concatenate the group embedding onto the pooled CNN features
+        group_embedding = self.group_label_embedding(group_label)
+        features = torch.cat([
+            features, group_embedding
+        ], 1)
+
+        return self.final_fc(features)
+
+    def freeze(self):
+        for param in self.model.parameters():
+            param.requires_grad = False
+
+    def unfreeze(self):
+        for param in self.model.parameters():
+            param.requires_grad = True
+
+
+class SENetCellType(nn.Module):
+    def __init__(self, model_name="seresnext26_32x4d",
+                 num_classes=1108,
+                 n_channels=6):
+        super(SENetCellType, self).__init__()
+
+        self.model = make_model(
+            model_name=model_name,
+            num_classes=num_classes,
+            pretrained=True,
+            dropout_p=0.3
+        )
+        print("*" * 100)
+        print("SENetCellType")
+        conv1 = self.model._features[0].conv1
+        self.model._features[0].conv1 = nn.Conv2d(in_channels=n_channels,
+                                                  out_channels=conv1.out_channels,
+                                                  kernel_size=conv1.kernel_size,
+                                                  stride=conv1.stride,
+                                                  padding=conv1.padding,
+                                                  bias=conv1.bias is not None)
+
+        # copy pretrained weights: RGB filters first, then reused for the extra channels
+        self.model._features[0].conv1.weight.data[:, :3, :, :] = conv1.weight.data
+        self.model._features[0].conv1.weight.data[:, 3:n_channels, :, :] = conv1.weight.data[:, :n_channels - 3, :, :]
+
+        in_features = self.model._classifier.in_features
+        # two heads: sirna classification and 4-way cell-type classification
+        self.final_sirna = nn.Linear(
+            in_features=in_features, out_features=num_classes
+        )
+        self.final_cell_type = nn.Linear(
+            in_features=in_features, out_features=4
+        )
+
+    def forward(self, x):
+        features = self.model._features(x)
+        features = self.model.pool(features)
+        features = features.view(features.size(0), -1)
+
+        return self.final_sirna(features), self.final_cell_type(features)
+
+    def freeze(self):
+        for param in self.model.parameters():
+            param.requires_grad = False
+
+    def unfreeze(self):
+        for param in self.model.parameters():
+            param.requires_grad = True
+
+
+def cell_senet(model_name='se_resnext50_32x4d', num_classes=1108, n_channels=6, weight=None):
     model = make_model(
         model_name=model_name,
         num_classes=num_classes,
@@ -58,12 +164,15 @@ def cell_senet(model_name='se_resnext50_32x4d', num_classes=1108, n_channels=6)
     # copy pretrained weights
     model._features[0].conv1.weight.data[:,:3,:,:] = conv1.weight.data
     model._features[0].conv1.weight.data[:,3:n_channels,:,:] = conv1.weight.data[:,:int(n_channels-3),:,:]
-    # model = SENet(
-    #     model_name=model_name,
-    #     num_classes=num_classes,
-    #     n_channels=n_channels
-    # )
-    #
-    # for param in model.parameters():
-    #     param.requires_grad = True
+
+    if weight:
+        # warm-start from a checkpoint, then re-initialise the classifier head
+        model_state_dict = torch.load(weight)['model_state_dict']
+        model.load_state_dict(model_state_dict)
+        print(f"\n\n******************************* Loaded checkpoint {weight}")
+        in_features = model._classifier.in_features
+        model._classifier = nn.Linear(
+            in_features=in_features, out_features=num_classes
+        )
+
     return model
diff --git a/src/runner.py b/src/runner.py
index a4c54d1..87eb37b 100644
--- a/src/runner.py
+++ b/src/runner.py
@@ -6,7 +6,11 @@ class ModelRunner(Runner):
     def predict_batch(self, batch: Mapping[str, Any]):
         # import pdb
         # pdb.set_trace()
-        output = self.model(batch["images"])
+        # group-level models take the experiment group as a second input
+        if 'group_labels' in batch:
+            output = self.model(batch["images"], batch['group_labels'])
+        else:
+            output = self.model(batch["images"])
         return {
             "logits": output
         }
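For reference, a quick smoke test of the two new models and the runner's dispatch. It assumes cnn_finetune can download the pretrained se_resnext50_32x4d weights; the batch and image sizes are illustrative only:

    import torch
    from src.models import SENetGrouplevel, SENetCellType

    images = torch.randn(2, 6, 224, 224)      # six-channel cell images
    group_labels = torch.randint(0, 4, (2,))  # experiment group id per sample

    # group-level model: ModelRunner passes batch["group_labels"] as the second input
    model = SENetGrouplevel(model_name="se_resnext50_32x4d", num_classes=1108)
    print(model(images, group_labels).shape)  # torch.Size([2, 1108])

    # two-head model: forward returns (sirna logits, cell-type logits)
    two_head = SENetCellType(model_name="se_resnext50_32x4d", num_classes=1108)
    sirna_logits, cell_logits = two_head(images)
    print(sirna_logits.shape, cell_logits.shape)  # [2, 1108], [2, 4]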