Commit

triducnguyentang committed Jul 19, 2019
2 parents 059ba34 + 2933135 commit 0294553
Showing 6 changed files with 112 additions and 24 deletions.
3 changes: 3 additions & 0 deletions src/__init__.py
@@ -13,6 +13,9 @@
registry.Model(cell_senet)
registry.Model(cell_densenet)
registry.Model(EfficientNet)
registry.Model(SENetTIMM)
registry.Model(InceptionV3TIMM)
registry.Model(GluonResnetTIMM)

# Register callbacks
registry.Callback(LabelSmoothCriterionCallback)
4 changes: 2 additions & 2 deletions src/experiment.py
@@ -28,7 +28,7 @@ def _postprocess_model_for_stage(self, stage: str, model: nn.Module):
else:
for param in model_._features.parameters():
param.requires_grad = False
print("Freeze backbone model !!!")
print("Freeze backbone model !!!")

else:
if hasattr(model_, 'unfreeze'):
@@ -37,7 +37,7 @@ def _postprocess_model_for_stage(self, stage: str, model: nn.Module):
else:
for param in model_._features.parameters():
param.requires_grad = True
print("Freeze backbone model !!!")
print("Freeze backbone model !!!")

return model_

6 changes: 4 additions & 2 deletions src/models/__init__.py
@@ -1,4 +1,6 @@
from .resnet import ResNet
from .senet import cell_senet
from .senet import cell_senet, SENetTIMM
from .densenet import cell_densenet
from .efficientnet import EfficientNet
from .efficientnet import EfficientNet
from .inceptionv3 import InceptionV3TIMM
from .gluon_resnet import GluonResnetTIMM
38 changes: 38 additions & 0 deletions src/models/gluon_resnet.py
@@ -0,0 +1,38 @@
from cnn_finetune import make_model
import timm
from .utils import *


class GluonResnetTIMM(nn.Module):
    def __init__(self, model_name="gluon_resnet50_v1d",
                 num_classes=1108,
                 n_channels=6):
        super(GluonResnetTIMM, self).__init__()

        self.model = timm.create_model(model_name, pretrained=True, num_classes=num_classes)
        print(self.model)
        # replace the stem conv with one that accepts n_channels input channels
        conv1 = self.model.conv1
        self.model.conv1 = nn.Conv2d(in_channels=n_channels,
                                     out_channels=conv1.out_channels,
                                     kernel_size=conv1.kernel_size,
                                     stride=conv1.stride,
                                     padding=conv1.padding,
                                     bias=conv1.bias is not None)

        # copy pretrained weights: first 3 channels verbatim, extra channels replicated
        self.model.conv1.weight.data[:, :3, :, :] = conv1.weight.data
        self.model.conv1.weight.data[:, 3:n_channels, :, :] = conv1.weight.data[:, :int(n_channels - 3), :, :]

    def forward(self, x):
        return self.model(x)

    def freeze(self):
        # freeze the backbone; keep only the classifier head trainable
        for param in self.model.parameters():
            param.requires_grad = False

        for param in self.model.get_classifier().parameters():
            param.requires_grad = True

    def unfreeze(self):
        # make the whole network trainable again
        for param in self.model.parameters():
            param.requires_grad = True
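
A minimal sketch (not part of this commit) of the stem-conv replacement performed above, using a plain torch Conv2d as a stand-in for the pretrained 3-channel stem; the 64-filter 7x7 shape and the 6-channel input are illustrative assumptions:

import torch
import torch.nn as nn

n_channels = 6
old_conv = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)  # stand-in for a pretrained stem conv

new_conv = nn.Conv2d(in_channels=n_channels,
                     out_channels=old_conv.out_channels,
                     kernel_size=old_conv.kernel_size,
                     stride=old_conv.stride,
                     padding=old_conv.padding,
                     bias=old_conv.bias is not None)

# first 3 input channels reuse the pretrained filters, the extra channels replicate them
new_conv.weight.data[:, :3] = old_conv.weight.data
new_conv.weight.data[:, 3:n_channels] = old_conv.weight.data[:, :n_channels - 3]

print(new_conv(torch.randn(2, n_channels, 224, 224)).shape)  # torch.Size([2, 64, 112, 112])
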
37 changes: 37 additions & 0 deletions src/models/inceptionv3.py
@@ -0,0 +1,37 @@
from cnn_finetune import make_model
import timm
from .utils import *


class InceptionV3TIMM(nn.Module):
    def __init__(self, model_name="gluon_inception_v3",
                 num_classes=1108,
                 n_channels=6):
        super(InceptionV3TIMM, self).__init__()

        self.model = timm.create_model(model_name, pretrained=True, num_classes=num_classes)
        # replace the stem conv with one that accepts n_channels input channels
        conv1 = self.model.Conv2d_1a_3x3.conv
        self.model.Conv2d_1a_3x3.conv = nn.Conv2d(in_channels=n_channels,
                                                  out_channels=conv1.out_channels,
                                                  kernel_size=conv1.kernel_size,
                                                  stride=conv1.stride,
                                                  padding=conv1.padding,
                                                  bias=conv1.bias is not None)

        # copy pretrained weights: first 3 channels verbatim, extra channels replicated
        self.model.Conv2d_1a_3x3.conv.weight.data[:, :3, :, :] = conv1.weight.data
        self.model.Conv2d_1a_3x3.conv.weight.data[:, 3:n_channels, :, :] = conv1.weight.data[:, :int(n_channels - 3), :, :]

    def forward(self, x):
        return self.model(x)

    def freeze(self):
        # freeze the backbone; keep only the classifier head trainable
        for param in self.model.parameters():
            param.requires_grad = False

        for param in self.model.fc.parameters():
            param.requires_grad = True

    def unfreeze(self):
        # make the whole network trainable again
        for param in self.model.parameters():
            param.requires_grad = True
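
A small self-contained sketch of the freeze/unfreeze pattern used by these wrappers, with a toy backbone-plus-head module standing in for the real timm model (an assumption for illustration); freeze() leaves only the head trainable, unfreeze() re-enables everything:

import torch.nn as nn

class ToyWrapper(nn.Module):
    # stand-in with a "backbone" and an "fc" head, mirroring the wrappers above
    def __init__(self):
        super().__init__()
        self.backbone = nn.Sequential(nn.Conv2d(6, 8, kernel_size=3), nn.ReLU())
        self.fc = nn.Linear(8, 1108)

    def freeze(self):
        for param in self.parameters():
            param.requires_grad = False
        for param in self.fc.parameters():
            param.requires_grad = True

    def unfreeze(self):
        for param in self.parameters():
            param.requires_grad = True

m = ToyWrapper()
m.freeze()
print(sum(p.requires_grad for p in m.parameters()))  # 2: only the fc weight and bias stay trainable
m.unfreeze()
print(sum(p.requires_grad for p in m.parameters()))  # 4: all parameters trainable again
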
48 changes: 28 additions & 20 deletions src/models/senet.py
@@ -1,33 +1,41 @@
import torch.nn as nn
import pretrainedmodels
from cnn_finetune import make_model
import timm
from .utils import *


class SENet(nn.Module):
def __init__(self, model_name="se_resnext50_32x4d",
class SENetTIMM(nn.Module):
def __init__(self, model_name="seresnext26_32x4d",
num_classes=1108,
n_channels=6):
super(SENet, self).__init__()

self.model = make_model(
model_name=model_name,
num_classes=num_classes,
pretrained=True,
# pool=GlobalConcatPool2d(),
# classifier_factory=make_classifier
)
self.conv = Conv2dSame(
in_channels=n_channels,
out_channels=3,
kernel_size=1,
stride=1,
)
super(SENetTIMM, self).__init__()

self.model = timm.create_model(model_name, pretrained=True, num_classes=num_classes)
conv1 = self.model.layer0.conv1
self.model.layer0.conv1 = nn.Conv2d(in_channels=n_channels,
out_channels=conv1.out_channels,
kernel_size=conv1.kernel_size,
stride=conv1.stride,
padding=conv1.padding,
bias=conv1.bias is not None)

# copy pretrained weights
self.model.layer0.conv1.weight.data[:, :3, :, :] = conv1.weight.data
self.model.layer0.conv1.weight.data[:, 3:n_channels, :, :] = conv1.weight.data[:, :int(n_channels - 3), :, :]

def forward(self, x):
x = self.conv(x)
return self.model(x)

def freeze(self):
for param in self.model.parameters():
param.requires_grad = False

for param in self.model.get_classifier().parameters():
param.requires_grad = True

def unfreeze(self):
for param in self.model.parameters():
param.requires_grad = True


def cell_senet(model_name='se_resnext50_32x4d', num_classes=1108, n_channels=6):
model = make_model(
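
For reference, a hedged usage sketch (not part of the commit) of one of the new wrappers, assuming timm is installed, the "seresnext26_32x4d" name resolves in the installed timm version, and pretrained weights can be downloaded:

import torch
from src.models import SENetTIMM  # import path assumes the repository layout shown above

model = SENetTIMM(model_name="seresnext26_32x4d", num_classes=1108, n_channels=6)
model.freeze()  # train only the classifier head at first

x = torch.randn(2, 6, 224, 224)  # a batch of two 6-channel images
logits = model(x)
print(logits.shape)  # expected: torch.Size([2, 1108])
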
