From 5f7fa5fe40c30c19d4d4272708ac211e9c9a5a02 Mon Sep 17 00:00:00 2001 From: Nguyen Xuan Bac Date: Fri, 6 Sep 2019 13:48:29 +0900 Subject: [PATCH] Refactor code --- src/models/densenet.py | 16 ++++++++++++++-- src/models/senet.py | 8 ++++---- 2 files changed, 18 insertions(+), 6 deletions(-) diff --git a/src/models/densenet.py b/src/models/densenet.py index ed158ba..cb41525 100644 --- a/src/models/densenet.py +++ b/src/models/densenet.py @@ -1,12 +1,13 @@ import torch.nn as nn +import torch import pretrainedmodels from cnn_finetune import make_model -def cell_densenet(model_name='densenet121', num_classes=1108, n_channels=6): +def cell_densenet(model_name='densenet121', num_classes=1108, n_channels=6, weight=None): model = make_model( model_name=model_name, - num_classes=num_classes, + num_classes=31, pretrained=True ) conv1 = model._features[0] @@ -20,6 +21,17 @@ def cell_densenet(model_name='densenet121', num_classes=1108, n_channels=6): # copy pretrained weights model._features[0].weight.data[:,:3,:,:] = conv1.weight.data model._features[0].weight.data[:,3:n_channels,:,:] = conv1.weight.data[:,:int(n_channels-3),:,:] + + if weight: + model_state_dict = torch.load(weight)['model_state_dict'] + model.load_state_dict(model_state_dict) + print(f"\n\n******************************* Loaded checkpoint {weight}") + + in_features = model._classifier.in_features + model._classifier = nn.Linear( + in_features=in_features, out_features=num_classes + ) + return model diff --git a/src/models/senet.py b/src/models/senet.py index 2e053df..16468e3 100644 --- a/src/models/senet.py +++ b/src/models/senet.py @@ -237,9 +237,9 @@ def cell_senet(model_name='se_resnext50_32x4d', num_classes=1108, n_channels=6, model_state_dict = torch.load(weight)['model_state_dict'] model.load_state_dict(model_state_dict) print(f"\n\n******************************* Loaded checkpoint {weight}") - in_features = model._classifier.in_features - model._classifier = nn.Linear( - in_features=in_features, out_features=num_classes - ) + in_features = model._classifier.in_features + model._classifier = nn.Linear( + in_features=in_features, out_features=num_classes + ) return model