Skip to content

Commit

Permalink
Refactor code
Browse files Browse the repository at this point in the history
  • Loading branch information
ngxbac committed Sep 6, 2019
1 parent 08a2a9d commit 5f7fa5f
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 6 deletions.
16 changes: 14 additions & 2 deletions src/models/densenet.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import torch.nn as nn
import torch
import pretrainedmodels
from cnn_finetune import make_model


def cell_densenet(model_name='densenet121', num_classes=1108, n_channels=6):
def cell_densenet(model_name='densenet121', num_classes=1108, n_channels=6, weight=None):
model = make_model(
model_name=model_name,
num_classes=num_classes,
num_classes=31,
pretrained=True
)
conv1 = model._features[0]
Expand All @@ -20,6 +21,17 @@ def cell_densenet(model_name='densenet121', num_classes=1108, n_channels=6):
# copy pretrained weights
model._features[0].weight.data[:,:3,:,:] = conv1.weight.data
model._features[0].weight.data[:,3:n_channels,:,:] = conv1.weight.data[:,:int(n_channels-3),:,:]

if weight:
model_state_dict = torch.load(weight)['model_state_dict']
model.load_state_dict(model_state_dict)
print(f"\n\n******************************* Loaded checkpoint {weight}")

in_features = model._classifier.in_features
model._classifier = nn.Linear(
in_features=in_features, out_features=num_classes
)

return model


Expand Down
8 changes: 4 additions & 4 deletions src/models/senet.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,9 +237,9 @@ def cell_senet(model_name='se_resnext50_32x4d', num_classes=1108, n_channels=6,
model_state_dict = torch.load(weight)['model_state_dict']
model.load_state_dict(model_state_dict)
print(f"\n\n******************************* Loaded checkpoint {weight}")
in_features = model._classifier.in_features
model._classifier = nn.Linear(
in_features=in_features, out_features=num_classes
)
in_features = model._classifier.in_features
model._classifier = nn.Linear(
in_features=in_features, out_features=num_classes
)

return model

0 comments on commit 5f7fa5f

Please sign in to comment.