diff --git a/src/__init__.py b/src/__init__.py index cd034ce..27b7670 100644 --- a/src/__init__.py +++ b/src/__init__.py @@ -9,7 +9,7 @@ # Register models -registry.Model(cell_resnet) +registry.Model(ResNet) registry.Model(cell_senet) registry.Model(cell_densenet) registry.Model(EfficientNet) diff --git a/src/models/__init__.py b/src/models/__init__.py index 1a0a12c..ec8182d 100644 --- a/src/models/__init__.py +++ b/src/models/__init__.py @@ -1,4 +1,4 @@ -from .resnet import cell_resnet +from .resnet import ResNet from .senet import cell_senet from .densenet import cell_densenet from .efficientnet import EfficientNet \ No newline at end of file diff --git a/src/models/resnet.py b/src/models/resnet.py index edcc52a..243f0ba 100644 --- a/src/models/resnet.py +++ b/src/models/resnet.py @@ -1,30 +1,39 @@ import torch.nn as nn import pretrainedmodels from cnn_finetune import make_model +import timm +from .utils import * -def cell_resnet(model_name, num_classes=1108, n_channels=6): - model = make_model( - model_name=model_name, - num_classes=num_classes, - pretrained=True - ) - conv1 = model._features[0] - model._features[0] = nn.Conv2d(in_channels=n_channels, - out_channels=conv1.out_channels, - kernel_size=conv1.kernel_size, - stride=conv1.stride, - padding=conv1.padding, - bias=conv1.bias) - - # copy pretrained weights - model._features[0].weight.data[:,:3,:,:] = conv1.weight.data - model._features[0].weight.data[:,3:n_channels,:,:] = conv1.weight.data[:,:int(n_channels-3),:,:] - return model - - -if __name__ == '__main__': - import torch - model = cell_resnet(model_name='resnet18') - x = torch.randn((1, 6, 320, 320)) - y = model(x) \ No newline at end of file +class ResNet(nn.Module): + def __init__(self, model_name="resnet50", + num_classes=1108, + n_channels=6): + super(ResNet, self).__init__() + + self.model = timm.create_model(model_name, pretrained=True, num_classes=num_classes) + conv1 = self.model.conv1 + self.model.conv1 = nn.Conv2d(in_channels=n_channels, + out_channels=conv1.out_channels, + kernel_size=conv1.kernel_size, + stride=conv1.stride, + padding=conv1.padding, + bias=conv1.bias) + + # copy pretrained weights + self.model.conv1.weight.data[:, :3, :, :] = conv1.weight.data + self.model.conv1.weight.data[:, 3:n_channels, :, :] = conv1.weight.data[:, :int(n_channels - 3), :, :] + + def forward(self, x): + return self.model(x) + + def freeze(self): + for param in self.model.parameters(): + param.requires_grad = False + + for param in self.model.fc.parameters(): + param.requires_grad = True + + def unfreeze(self): + for param in self.model.parameters(): + param.requires_grad = True