Skip to content

Commit

Permalink
Add resnet
Browse files Browse the repository at this point in the history
  • Loading branch information
ngxbac committed Jul 14, 2019
1 parent 2c4a5a4 commit eea231f
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 27 deletions.
2 changes: 1 addition & 1 deletion src/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@


# Register models
registry.Model(cell_resnet)
registry.Model(ResNet)
registry.Model(cell_senet)
registry.Model(cell_densenet)
registry.Model(EfficientNet)
Expand Down
2 changes: 1 addition & 1 deletion src/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .resnet import cell_resnet
from .resnet import ResNet
from .senet import cell_senet
from .densenet import cell_densenet
from .efficientnet import EfficientNet
59 changes: 34 additions & 25 deletions src/models/resnet.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,39 @@
import torch.nn as nn
import pretrainedmodels
from cnn_finetune import make_model
import timm
from .utils import *


def cell_resnet(model_name, num_classes=1108, n_channels=6):
model = make_model(
model_name=model_name,
num_classes=num_classes,
pretrained=True
)
conv1 = model._features[0]
model._features[0] = nn.Conv2d(in_channels=n_channels,
out_channels=conv1.out_channels,
kernel_size=conv1.kernel_size,
stride=conv1.stride,
padding=conv1.padding,
bias=conv1.bias)

# copy pretrained weights
model._features[0].weight.data[:,:3,:,:] = conv1.weight.data
model._features[0].weight.data[:,3:n_channels,:,:] = conv1.weight.data[:,:int(n_channels-3),:,:]
return model


if __name__ == '__main__':
import torch
model = cell_resnet(model_name='resnet18')
x = torch.randn((1, 6, 320, 320))
y = model(x)
class ResNet(nn.Module):
def __init__(self, model_name="resnet50",
num_classes=1108,
n_channels=6):
super(ResNet, self).__init__()

self.model = timm.create_model(model_name, pretrained=True, num_classes=num_classes)
conv1 = self.model.conv1
self.model.conv1 = nn.Conv2d(in_channels=n_channels,
out_channels=conv1.out_channels,
kernel_size=conv1.kernel_size,
stride=conv1.stride,
padding=conv1.padding,
bias=conv1.bias)

# copy pretrained weights
self.model.conv1.weight.data[:, :3, :, :] = conv1.weight.data
self.model.conv1.weight.data[:, 3:n_channels, :, :] = conv1.weight.data[:, :int(n_channels - 3), :, :]

def forward(self, x):
return self.model(x)

def freeze(self):
for param in self.model.parameters():
param.requires_grad = False

for param in self.model.fc.parameters():
param.requires_grad = True

def unfreeze(self):
for param in self.model.parameters():
param.requires_grad = True

0 comments on commit eea231f

Please sign in to comment.