Deepsupervision #1

Merged (7 commits, Aug 1, 2019)
Changes from 1 commit
DS for SENET
ngxbac committed Jul 29, 2019
commit 40393bf020381be63fee28e677a0ff00bde149fd
2 changes: 1 addition & 1 deletion bin/train_ds.sh
@@ -6,7 +6,7 @@ RUN_CONFIG=config_ds.yml

for channels in [1,2,3,4,5]; do
    for fold in 0; do
-       LOGDIR=/raid/bac/kaggle/logs/recursion_cell/test/190729/fold_$fold/DSInceptionV3/
+       LOGDIR=/raid/bac/kaggle/logs/recursion_cell/test/190729/fold_$fold/DSSENet/
        catalyst-dl run \
            --config=./configs/${RUN_CONFIG} \
            --logdir=$LOGDIR \
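Note: the bracketed value in the outer loop is a plain string in bash, not an array; judging by the parsing in src/inference.py, it is forwarded as-is and split into channel indices on the Python side. A small illustration (not part of the commit):

    # Illustration only: parsing a channel string such as "[1,2,3,4,5]",
    # mirroring the expression used in src/inference.py.
    channel_str = "[1,2,3,4,5]"
    channels = [int(i) for i in channel_str[1:-1].split(',')]  # -> [1, 2, 3, 4, 5]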
14 changes: 9 additions & 5 deletions configs/config_ds.yml
@@ -1,5 +1,5 @@
model_params:
-  model: DSInceptionV3
+  model: DSSENet
   pretrained: True
   n_channels: 5
   num_classes: 1108
@@ -52,13 +52,15 @@ stages:
  callbacks_params: &callback_params
    loss:
      callback: DSCriterionCallback
-      loss_weights: [0.001, 0.005, 0.01, 0.02, 0.02, 0.1, 1.0]
+      # loss_weights: [0.001, 0.005, 0.01, 0.02, 0.02, 0.1, 1.0]  # For DS InceptionV3
+      loss_weights: [0.02, 0.02, 0.1, 1.0]
    optimizer:
      callback: OptimizerCallback
      accumulation_steps: 2
    accuracy:
      callback: DSAccuracyCallback
-      logit_names: ["m2", "m4", "m6", "m8", "m9", "m10", "final"]
+      # logit_names: ["m2", "m4", "m6", "m8", "m9", "m10", "final"]  # For DS InceptionV3
+      logit_names: ["m1", "m2", "m3", "final"]
    scheduler:
      callback: SchedulerCallback
      reduce_metric: *reduce_metric
@@ -85,13 +87,15 @@ stages:
  callbacks_params:
    loss:
      callback: DSCriterionCallback
-      loss_weights: [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
+      # loss_weights: [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
+      loss_weights: [0.02, 0.02, 0.1, 1.0]
    optimizer:
      callback: OptimizerCallback
      accumulation_steps: 2
    accuracy:
      callback: DSAccuracyCallback
-      logit_names: ["m2", "m4", "m6", "m8", "m9", "m10", "final"]
+      # logit_names: ["m2", "m4", "m6", "m8", "m9", "m10", "final"]
+      logit_names: ["m1", "m2", "m3", "final"]
    scheduler:
      callback: SchedulerCallback
      reduce_metric: *reduce_metric
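For orientation: DSSENet exposes four outputs (m1, m2, m3, final), so loss_weights and logit_names shrink from seven entries to four. The internals of DSCriterionCallback are not part of this diff; the sketch below only illustrates the weighted sum that the loss_weights list implies, and the function name and signature are assumptions made for illustration:

    # Illustrative sketch only; the real DSCriterionCallback is defined elsewhere in this repo.
    import torch.nn as nn

    def weighted_ds_loss(head_logits, target, loss_weights=(0.02, 0.02, 0.1, 1.0)):
        # head_logits: tuple (m1, m2, m3, final), each of shape (batch, num_classes)
        criterion = nn.CrossEntropyLoss()
        return sum(w * criterion(logits, target) for w, logits in zip(loss_weights, head_logits))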
1 change: 1 addition & 0 deletions src/__init__.py
@@ -17,6 +17,7 @@
registry.Model(InceptionV3TIMM)
registry.Model(GluonResnetTIMM)
registry.Model(DSInceptionV3)
+registry.Model(DSSENet)

# Register callbacks
registry.Callback(LabelSmoothCriterionCallback)
82 changes: 81 additions & 1 deletion src/inference.py
@@ -33,6 +33,21 @@ def predict(model, loader):
    return preds


def predict_ds(model, loader):
    model.eval()
    preds = []
    with torch.no_grad():
        for dct in tqdm(loader, total=len(loader)):
            images = dct['images'].to(device)
            pred = model(images)
            pred = [p.detach().cpu().numpy() for p in pred]
            preds.append(pred)

    preds = np.concatenate(preds, axis=1)
    print(preds.shape)
    return preds


def predict_all():
    # test_csv = '/raid/data/kaggle/recursion-cellular-image-classification/test.csv'
    test_csv = './csv/valid_0.csv'
@@ -101,5 +116,70 @@ def predict_all():
    np.save(f"./prediction/fold_0/{model_name}_{channel_str}_valid.npy", preds)


def predict_deepsupervision():
    test_csv = '/raid/data/kaggle/recursion-cellular-image-classification/test.csv'
    # test_csv = './csv/valid_0.csv'
    model_name = 'DSInceptionV3'

    for channel_str in [
        "[1,2,3,4,5]",
    ]:

        log_dir = f"/raid/bac/kaggle/logs/recursion_cell/test/190729/fold_0/{model_name}/"
        root = "/raid/data/kaggle/recursion-cellular-image-classification/"
        sites = [1]
        channels = [int(i) for i in channel_str[1:-1].split(',')]

        # log_dir = log_dir.replace('[', '[[]')
        # log_dir = log_dir.replace(']', '[]]')

        ckp = os.path.join(log_dir, "checkpoints/stage1.50.pth")
        model = DSInceptionV3(
            num_classes=1108,
            n_channels=len(channels) * len(sites)
        )

        checkpoint = torch.load(ckp)
        model.load_state_dict(checkpoint['model_state_dict'])
        model = model.to(device)
        # model = nn.DataParallel(model)

        print("*" * 50)
        print(f"checkpoint: {ckp}")
        print(f"Channel: {channel_str}")
        preds = []
        for site in [1, 2]:
            # Dataset
            dataset = RecursionCellularSite(
                csv_file=test_csv,
                root=root,
                transform=valid_aug(512),
                mode='test',
                sites=[site],
                channels=channels
            )

            loader = DataLoader(
                dataset=dataset,
                batch_size=128,
                shuffle=False,
                num_workers=8,
            )

            pred = predict_ds(model, loader)
            preds.append(pred)

        preds = np.asarray(preds)  # .mean(axis=0)
        print(preds.shape)
        # all_preds = np.argmax(preds, axis=1)
        df = pd.read_csv(test_csv)
        submission = df.copy()
        # submission['sirna'] = all_preds.astype(int)
        os.makedirs("./prediction/DS/", exist_ok=True)
        # submission.to_csv(f'./prediction/DS/{model_name}_test.csv', index=False, columns=['id_code', 'sirna'])
        np.save(f"./prediction/DS/{model_name}_test.npy", preds)


if __name__ == '__main__':
-    predict_all()
+    # predict_all()
+    predict_deepsupervision()
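A note on the saved array: predict_ds returns predictions of shape (n_heads, n_samples, n_classes), and predict_deepsupervision stacks the two sites on top of that, so the .npy file should hold an array of shape (n_sites, n_heads, n_samples, n_classes). One possible way to reduce it to per-sample predictions (illustration only, not part of the commit):

    # Illustration only: average the two sites and read out the "final" head.
    import numpy as np

    preds = np.load("./prediction/DS/DSInceptionV3_test.npy")  # (n_sites, n_heads, n_samples, n_classes)
    site_mean = preds.mean(axis=0)       # -> (n_heads, n_samples, n_classes)
    final_logits = site_mean[-1]         # the "final" head
    sirna = final_logits.argmax(axis=1)  # predicted class index per sample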
2 changes: 1 addition & 1 deletion src/models/__init__.py
@@ -4,4 +4,4 @@
from .efficientnet import EfficientNet
from .inceptionv3 import InceptionV3TIMM
from .gluon_resnet import GluonResnetTIMM
-from .deepsupervision import DSInceptionV3
+from .deepsupervision import DSInceptionV3, DSSENet
3 changes: 2 additions & 1 deletion src/models/deepsupervision/__init__.py
@@ -1 +1,2 @@
-from .inception_v3 import DSInceptionV3
+from .inception_v3 import DSInceptionV3
+from .senet import DSSENet
110 changes: 110 additions & 0 deletions src/models/deepsupervision/senet.py
@@ -0,0 +1,110 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
from catalyst.contrib.modules.common import Flatten
from catalyst.contrib.modules.pooling import GlobalConcatPool2d
from cnn_finetune import make_model


class DSSENet(nn.Module):
    def __init__(
        self,
        model_name='se_resnext50_32x4d',
        num_classes=6,
        pretrained=True,
        n_channels=4,
    ):
        super(DSSENet, self).__init__()
        self.model = make_model(
            model_name=model_name,
            num_classes=num_classes,
            pretrained=pretrained,
            dropout_p=0.3
        )
        # print(self.model)

        # Swap the stem conv so the backbone accepts n_channels input planes
        conv1 = self.model._features[0].conv1
        self.model._features[0].conv1 = nn.Conv2d(in_channels=n_channels,
                                                  out_channels=conv1.out_channels,
                                                  kernel_size=conv1.kernel_size,
                                                  stride=conv1.stride,
                                                  padding=conv1.padding,
                                                  bias=conv1.bias is not None)

        # copy pretrained weights: first 3 channels as-is, remaining channels reuse the first ones
        self.model._features[0].conv1.weight.data[:, :3, :, :] = conv1.weight.data
        self.model._features[0].conv1.weight.data[:, 3:n_channels, :, :] = conv1.weight.data[:, :int(n_channels - 3), :, :]

        self.deepsuper_1 = nn.Sequential(
            GlobalConcatPool2d(),
            Flatten(),
            nn.BatchNorm1d(256 * 2),
            nn.Linear(256 * 2, num_classes)
        )

        self.deepsuper_2 = nn.Sequential(
            GlobalConcatPool2d(),
            Flatten(),
            nn.BatchNorm1d(512 * 2),
            nn.Linear(512 * 2, num_classes)
        )

        self.deepsuper_3 = nn.Sequential(
            GlobalConcatPool2d(),
            Flatten(),
            nn.BatchNorm1d(1024 * 2),
            nn.Linear(1024 * 2, num_classes)
        )

        # WARNING: the Linear layer below must be adapted to the input image size !!!
        self.fc = nn.Sequential(
            nn.Conv2d(in_channels=2048, out_channels=128, kernel_size=(1, 1)),
            nn.ReLU(),
            Flatten(),
            nn.Linear(32768, 1024),  # 128 * 16 * 16 = 32768 for 512x512 input (128 * 7 * 7 = 6272 for 224x224)
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(1024, num_classes)
        )

        self.is_infer = False

    def freeze_base(self):
        # pass
        for param in self.model.parameters():
            param.requires_grad = False

    def unfreeze_base(self):
        # pass
        for param in self.model.parameters():
            param.requires_grad = True

    def forward(self, x):
        x = self.model._features[0](x)
        x = self.model._features[1](x)
        x_1 = self.deepsuper_1(x)
        x = self.model._features[2](x)
        x_2 = self.deepsuper_2(x)
        x = self.model._features[3](x)
        x_3 = self.deepsuper_3(x)
        x = self.model._features[4](x)
        x_final = self.fc(x)

        return x_1, x_2, x_3, x_final

    def freeze(self):
        # Freeze all the backbone
        for param in self.model.parameters():
            param.requires_grad = False

    def unfreeze(self):
        # Unfreeze all the backbone
        for param in self.model.parameters():
            param.requires_grad = True


if __name__ == '__main__':
    x = torch.zeros((2, 4, 512, 512))
    model = DSSENet()
    out = model(x)
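For completeness, a quick shape check one might run against the new model (an assumed invocation, not part of the commit; it also relies on the pretrained argument being forwarded to make_model as fixed above):

    # Hypothetical sanity check: each head should emit (batch, num_classes) logits.
    import torch
    from src.models.deepsupervision.senet import DSSENet

    model = DSSENet(num_classes=1108, n_channels=5, pretrained=False).eval()
    with torch.no_grad():
        outs = model(torch.zeros(2, 5, 512, 512))
    for name, out in zip(["m1", "m2", "m3", "final"], outs):
        print(name, tuple(out.shape))  # each should be (2, 1108)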