diff --git a/src/__init__.py b/src/__init__.py
index 27b7670..b1e4f40 100644
--- a/src/__init__.py
+++ b/src/__init__.py
@@ -13,6 +13,9 @@
 registry.Model(cell_senet)
 registry.Model(cell_densenet)
 registry.Model(EfficientNet)
+registry.Model(SENetTIMM)
+registry.Model(InceptionV3TIMM)
+registry.Model(GluonResnetTIMM)
 
 # Register callbacks
 registry.Callback(LabelSmoothCriterionCallback)
diff --git a/src/experiment.py b/src/experiment.py
index e9702e3..579b381 100644
--- a/src/experiment.py
+++ b/src/experiment.py
@@ -28,7 +28,7 @@ def _postprocess_model_for_stage(self, stage: str, model: nn.Module):
         else:
             for param in model_._features.parameters():
                 param.requires_grad = False
-            print("Freeze backbone model !!!")
+            print("Freeze backbone model !!!")
 
     else:
         if hasattr(model_, 'unfreeze'):
@@ -37,7 +37,7 @@ def _postprocess_model_for_stage(self, stage: str, model: nn.Module):
         else:
             for param in model_._features.parameters():
                 param.requires_grad = True
-            print("Freeze backbone model !!!")
+            print("Unfreeze backbone model !!!")
 
     return model_
 
diff --git a/src/models/__init__.py b/src/models/__init__.py
index ec8182d..ab7e2f5 100644
--- a/src/models/__init__.py
+++ b/src/models/__init__.py
@@ -1,4 +1,6 @@
 from .resnet import ResNet
-from .senet import cell_senet
+from .senet import cell_senet, SENetTIMM
 from .densenet import cell_densenet
-from .efficientnet import EfficientNet
\ No newline at end of file
+from .efficientnet import EfficientNet
+from .inceptionv3 import InceptionV3TIMM
+from .gluon_resnet import GluonResnetTIMM
\ No newline at end of file
diff --git a/src/models/gluon_resnet.py b/src/models/gluon_resnet.py
new file mode 100644
index 0000000..21f4164
--- /dev/null
+++ b/src/models/gluon_resnet.py
@@ -0,0 +1,38 @@
+import timm
+from .utils import *
+
+
+class GluonResnetTIMM(nn.Module):
+    def __init__(self, model_name="gluon_resnet50_v1d",
+                 num_classes=1108,
+                 n_channels=6):
+        super(GluonResnetTIMM, self).__init__()
+
+        self.model = timm.create_model(model_name, pretrained=True, num_classes=num_classes)
+
+        # Swap the stem conv for one that accepts n_channels inputs.
+        conv1 = self.model.conv1
+        self.model.conv1 = nn.Conv2d(in_channels=n_channels,
+                                     out_channels=conv1.out_channels,
+                                     kernel_size=conv1.kernel_size,
+                                     stride=conv1.stride,
+                                     padding=conv1.padding,
+                                     bias=conv1.bias is not None)
+
+        # Copy pretrained RGB filters; tile them into the extra channels.
+        self.model.conv1.weight.data[:, :3, :, :] = conv1.weight.data
+        self.model.conv1.weight.data[:, 3:n_channels, :, :] = conv1.weight.data[:, :n_channels - 3, :, :]
+
+    def forward(self, x):
+        return self.model(x)
+
+    def freeze(self):
+        for param in self.model.parameters():
+            param.requires_grad = False
+
+        for param in self.model.get_classifier().parameters():
+            param.requires_grad = True
+
+    def unfreeze(self):
+        for param in self.model.parameters():
+            param.requires_grad = True
\ No newline at end of file
diff --git a/src/models/inceptionv3.py b/src/models/inceptionv3.py
new file mode 100644
index 0000000..37d4093
--- /dev/null
+++ b/src/models/inceptionv3.py
@@ -0,0 +1,38 @@
+import timm
+from .utils import *
+
+
+class InceptionV3TIMM(nn.Module):
+    def __init__(self, model_name="gluon_inception_v3",
+                 num_classes=1108,
+                 n_channels=6):
+        super(InceptionV3TIMM, self).__init__()
+
+        self.model = timm.create_model(model_name, pretrained=True, num_classes=num_classes)
+
+        # Swap the stem conv for one that accepts n_channels inputs.
+        conv1 = self.model.Conv2d_1a_3x3.conv
+        self.model.Conv2d_1a_3x3.conv = nn.Conv2d(in_channels=n_channels,
+                                                  out_channels=conv1.out_channels,
+                                                  kernel_size=conv1.kernel_size,
+                                                  stride=conv1.stride,
+                                                  padding=conv1.padding,
+                                                  bias=conv1.bias is not None)
+
+        # Copy pretrained RGB filters; tile them into the extra channels.
+        self.model.Conv2d_1a_3x3.conv.weight.data[:, :3, :, :] = conv1.weight.data
+        self.model.Conv2d_1a_3x3.conv.weight.data[:, 3:n_channels, :, :] = conv1.weight.data[:, :n_channels - 3, :, :]
+
+    def forward(self, x):
+        return self.model(x)
+
+    def freeze(self):
+        for param in self.model.parameters():
+            param.requires_grad = False
+
+        for param in self.model.fc.parameters():
+            param.requires_grad = True
+
+    def unfreeze(self):
+        for param in self.model.parameters():
+            param.requires_grad = True
\ No newline at end of file
diff --git a/src/models/senet.py b/src/models/senet.py
index b9dc777..f282333 100644
--- a/src/models/senet.py
+++ b/src/models/senet.py
@@ -1,33 +1,43 @@
-import torch.nn as nn
-import pretrainedmodels
 from cnn_finetune import make_model
+import timm
 from .utils import *
 
 
-class SENet(nn.Module):
-    def __init__(self, model_name="se_resnext50_32x4d",
+class SENetTIMM(nn.Module):
+    def __init__(self, model_name="seresnext26_32x4d",
                  num_classes=1108,
                  n_channels=6):
-        super(SENet, self).__init__()
-
-        self.model = make_model(
-            model_name=model_name,
-            num_classes=num_classes,
-            pretrained=True,
-            # pool=GlobalConcatPool2d(),
-            # classifier_factory=make_classifier
-        )
-        self.conv = Conv2dSame(
-            in_channels=n_channels,
-            out_channels=3,
-            kernel_size=1,
-            stride=1,
-        )
+        super(SENetTIMM, self).__init__()
+
+        self.model = timm.create_model(model_name, pretrained=True, num_classes=num_classes)
+
+        # Swap the stem conv for one that accepts n_channels inputs.
+        conv1 = self.model.layer0.conv1
+        self.model.layer0.conv1 = nn.Conv2d(in_channels=n_channels,
+                                            out_channels=conv1.out_channels,
+                                            kernel_size=conv1.kernel_size,
+                                            stride=conv1.stride,
+                                            padding=conv1.padding,
+                                            bias=conv1.bias is not None)
+
+        # Copy pretrained RGB filters; tile them into the extra channels.
+        self.model.layer0.conv1.weight.data[:, :3, :, :] = conv1.weight.data
+        self.model.layer0.conv1.weight.data[:, 3:n_channels, :, :] = conv1.weight.data[:, :n_channels - 3, :, :]
 
     def forward(self, x):
-        x = self.conv(x)
         return self.model(x)
 
+    def freeze(self):
+        for param in self.model.parameters():
+            param.requires_grad = False
+
+        for param in self.model.get_classifier().parameters():
+            param.requires_grad = True
+
+    def unfreeze(self):
+        for param in self.model.parameters():
+            param.requires_grad = True
+
 
 def cell_senet(model_name='se_resnext50_32x4d', num_classes=1108, n_channels=6):
     model = make_model(
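
Note: all three new TIMM wrappers repeat the same stem-surgery pattern, so it can be read once in isolation. Below is a minimal standalone sketch of that pattern; inflate_stem_conv is a hypothetical helper name introduced here for illustration and is not part of this diff. It assumes, as the diff does, that the source conv has 3 input channels and that n_channels <= 6 so a single tiled copy fills the extra channels.

import timm
import torch
import torch.nn as nn


def inflate_stem_conv(conv: nn.Conv2d, n_channels: int = 6) -> nn.Conv2d:
    """Build a copy of a pretrained 3-channel conv that accepts n_channels inputs."""
    new_conv = nn.Conv2d(in_channels=n_channels,
                         out_channels=conv.out_channels,
                         kernel_size=conv.kernel_size,
                         stride=conv.stride,
                         padding=conv.padding,
                         bias=conv.bias is not None)
    with torch.no_grad():
        # Channels 0-2 keep the pretrained RGB filters; channels 3..n_channels-1
        # reuse the first (n_channels - 3) RGB filters as a warm start.
        new_conv.weight[:, :3] = conv.weight
        new_conv.weight[:, 3:n_channels] = conv.weight[:, :n_channels - 3]
    return new_conv


# Usage, mirroring what GluonResnetTIMM does in its __init__:
model = timm.create_model("gluon_resnet50_v1d", pretrained=True, num_classes=1108)
model.conv1 = inflate_stem_conv(model.conv1, n_channels=6)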