Commit

baseline with se_resnext50_32x4d
ngxbac committed Jul 3, 2019
1 parent 9733f52 commit 27a7a96
Showing 6 changed files with 347 additions and 19 deletions.
2 changes: 1 addition & 1 deletion bin/train.sh
@@ -4,7 +4,7 @@ export CUDA_VISIBLE_DEVICES=2,3
RUN_CONFIG=config.yml


-LOGDIR=/raid/bac/kaggle/logs/recursion_cell/test/
+LOGDIR=/raid/bac/kaggle/logs/recursion_cell/se_resnext50_32x4d/
catalyst-dl run \
    --config=./configs/${RUN_CONFIG} \
    --logdir=$LOGDIR \
4 changes: 2 additions & 2 deletions configs/config.yml
@@ -1,6 +1,6 @@
model_params:
-  model: cell_densenet
-  model_name: densenet121
+  model: cell_senet
+  model_name: se_resnext50_32x4d
  n_channels: 6
  num_classes: 1108
7 changes: 4 additions & 3 deletions src/augmentation.py
@@ -4,12 +4,13 @@
def train_aug(image_size=224):
    return Compose([
        RandomCrop(image_size, image_size),
-        Flip(),
+        RandomRotate90(),
+        HorizontalFlip(),
        # Normalize(),
    ], p=1)


def valid_aug(image_size=224):
    return Compose([
-        Resize(image_size, image_size),
+        CenterCrop(image_size, image_size),
        # Normalize(),
    ], p=1)
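These helpers follow the albumentations Compose API, so each returns a callable applied with keyword arguments. Below is a minimal usage sketch, not part of the commit; the albumentations import path and the dummy 512x512x6 array are assumptions chosen to mirror a 6-channel site image.

import numpy as np
from albumentations import Compose, RandomCrop, RandomRotate90, HorizontalFlip  # assumed import source

# assumed: one site image as an HxWxC uint8 array with 6 fluorescence channels
image = np.random.randint(0, 256, size=(512, 512, 6), dtype=np.uint8)

aug = Compose([
    RandomCrop(320, 320),
    RandomRotate90(),
    HorizontalFlip(),
], p=1)

augmented = aug(image=image)['image']   # albumentations returns a dict keyed by target name
print(augmented.shape)                  # (320, 320, 6): all 6 channels are preserved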
81 changes: 81 additions & 0 deletions src/make_submission.py
@@ -0,0 +1,81 @@
import pandas as pd
import numpy as np

import torch
import torch.nn.functional as Ftorch
from torch.utils.data import DataLoader
import os
import glob
import click
from tqdm import *

from models import *
from augmentation import *
from dataset import RecursionCellularSite


device = torch.device('cuda')


def predict(model, loader):
model.eval()
preds = []
with torch.no_grad():
for dct in tqdm(loader, total=len(loader)):
images = dct['images'].to(device)
pred = model(images)
pred = Ftorch.softmax(pred, dim=1)
pred = pred.detach().cpu().numpy()
preds.append(pred)

preds = np.concatenate(preds, axis=0)
return preds


def predict_all():
test_csv = '/raid/data/kaggle/recursion-cellular-image-classification/test.csv'
log_dir = "/raid/bac/kaggle/logs/recursion_cell/se_resnext50_32x4d/"
root = "/raid/data/kaggle/recursion-cellular-image-classification/"
site = 1
channels = [1,2,3,4,5,6]

model = cell_senet(
model_name="se_resnext50_32x4d",
num_classes=1108,
n_channels=6
)

checkpoint = f"{log_dir}/checkpoints/best.pth"
checkpoint = torch.load(checkpoint)
model.load_state_dict(checkpoint['model_state_dict'])
model = model.to(device)

# Dataset
dataset = RecursionCellularSite(
csv_file=test_csv,
root=root,
transform=valid_aug(320),
mode='test',
site=site,
channels=channels
)

loader = DataLoader(
dataset=dataset,
batch_size=128,
shuffle=False,
num_workers=4,
)

pred = predict(model, loader)

all_preds = np.argmax(pred, axis=1)
df = pd.read_csv(test_csv)
submission = df.copy()
submission['sirna'] = all_preds.astype(int)
os.makedirs("submission", exist_ok=True)
submission.to_csv('./submission/submission_se_resnext50_32x4d.csv', index=False, columns=['id_code', 'sirna'])


if __name__ == '__main__':
predict_all()
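The script writes one sirna prediction per test id_code for site 1 only. A hedged sanity-check sketch (not part of the commit) for the generated file, using the output path and the 1108-class range defined above:

import pandas as pd

# hypothetical check, run after make_submission.py has finished
sub = pd.read_csv('./submission/submission_se_resnext50_32x4d.csv')
assert list(sub.columns) == ['id_code', 'sirna']
assert sub['sirna'].between(0, 1107).all()   # 1108 classes, labelled 0..1107
print(sub.shape, sub['sirna'].nunique())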
23 changes: 10 additions & 13 deletions src/models/senet.py
@@ -10,19 +10,16 @@ def cell_senet(model_name='se_resnext50', num_classes=1108, n_channels=6):
        num_classes=num_classes,
        pretrained=True
    )
-    conv1 = model.layer0.conv1
-    model.layer0.conv1 = nn.Conv2d(in_channels=n_channels,
-                                   out_channels=conv1.out_channels,
-                                   kernel_size=conv1.kernel_size,
-                                   stride=conv1.stride,
-                                   padding=conv1.padding,
-                                   bias=conv1.bias)
-    print(model)
+    conv1 = model._features[0].conv1
+    model._features[0].conv1 = nn.Conv2d(in_channels=n_channels,
+                                         out_channels=conv1.out_channels,
+                                         kernel_size=conv1.kernel_size,
+                                         stride=conv1.stride,
+                                         padding=conv1.padding,
+                                         bias=conv1.bias)

    # copy pretrained weights
-    model.layer0.conv1.weight.data[:, :3, :, :] = conv1.weight.data
-    model.layer0.conv1.weight.data[:, 3:n_channels, :, :] = conv1.weight.data[:, :int(n_channels-3), :, :]
-
-    model.avgpool = nn.AdaptiveAvgPool2d(1)
-    in_features = model.last_linear.in_features
-    model.last_linear = nn.Linear(in_features, num_classes)
+    model._features[0].conv1.weight.data[:,:3,:,:] = conv1.weight.data
+    model._features[0].conv1.weight.data[:,3:n_channels,:,:] = conv1.weight.data[:,:int(n_channels-3),:,:]
    return model
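The rewired first convolution accepts 6 input channels, with the pretrained 3-channel weights copied into channels 0-2 and repeated into channels 3-5. A short smoke-test sketch, assuming the repo's src/ layout is importable (the import path and the 320-pixel dummy input mirror make_submission.py; none of this is part of the commit):

import torch
from models import cell_senet  # assumed import, matching make_submission.py

model = cell_senet(model_name="se_resnext50_32x4d", num_classes=1108, n_channels=6)
model.eval()

x = torch.randn(2, 6, 320, 320)      # batch of two 6-channel crops
with torch.no_grad():
    logits = model(x)
print(logits.shape)                  # expected: torch.Size([2, 1108])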
249 changes: 249 additions & 0 deletions src/rxrxio.py
@@ -0,0 +1,249 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

import numpy as np
from skimage.io import imread
import pandas as pd

import tensorflow as tf

DEFAULT_BASE_PATH = 'gs://rxrx1-us-central1'
DEFAULT_METADATA_BASE_PATH = os.path.join(DEFAULT_BASE_PATH, 'metadata')
DEFAULT_IMAGES_BASE_PATH = os.path.join(DEFAULT_BASE_PATH, 'images')
DEFAULT_CHANNELS = (1, 2, 3, 4, 5, 6)
RGB_MAP = {
1: {
'rgb': np.array([19, 0, 249]),
'range': [0, 51]
},
2: {
'rgb': np.array([42, 255, 31]),
'range': [0, 107]
},
3: {
'rgb': np.array([255, 0, 25]),
'range': [0, 64]
},
4: {
'rgb': np.array([45, 255, 252]),
'range': [0, 191]
},
5: {
'rgb': np.array([250, 0, 253]),
'range': [0, 89]
},
6: {
'rgb': np.array([254, 255, 40]),
'range': [0, 191]
}
}


def load_image(image_path):
with tf.io.gfile.GFile(image_path, 'rb') as f:
return imread(f, format='png')


def load_images_as_tensor(image_paths, dtype=np.uint8):
n_channels = len(image_paths)

data = np.ndarray(shape=(512, 512, n_channels), dtype=dtype)

for ix, img_path in enumerate(image_paths):
data[:, :, ix] = load_image(img_path)

return data


def convert_tensor_to_rgb(t, channels=DEFAULT_CHANNELS, vmax=255, rgb_map=RGB_MAP):
"""
Converts and returns the image data as RGB image
Parameters
----------
t : np.ndarray
original image data
channels : list of int
channels to include
vmax : int
the max value used for scaling
rgb_map : dict
the color mapping for each channel
See rxrx.io.RGB_MAP to see what the defaults are.
Returns
-------
np.ndarray the image data of the site as RGB channels
"""
colored_channels = []
for i, channel in enumerate(channels):
x = (t[:, :, i] / vmax) / \
((rgb_map[channel]['range'][1] - rgb_map[channel]['range'][0]) / 255) + \
rgb_map[channel]['range'][0] / 255
x = np.where(x > 1., 1., x)
x_rgb = np.array(
np.outer(x, rgb_map[channel]['rgb']).reshape(512, 512, 3),
dtype=int)
colored_channels.append(x_rgb)
im = np.array(np.array(colored_channels).sum(axis=0), dtype=int)
im = np.where(im > 255, 255, im)
return im


def image_path(dataset,
experiment,
plate,
address,
site,
channel,
base_path=DEFAULT_IMAGES_BASE_PATH):
"""
Returns the path of a channel image.
Parameters
----------
dataset : str
what subset of the data: train, test
experiment : str
experiment name
plate : int
plate number
address : str
plate address
site : int
site number
channel : int
channel number
base_path : str
the base path of the raw images
Returns
-------
str the path of image
"""
return os.path.join(base_path, dataset, experiment, "Plate{}".format(plate),
"{}_s{}_w{}.png".format(address, site, channel))


def load_site(dataset,
experiment,
plate,
well,
site,
channels=DEFAULT_CHANNELS,
base_path=DEFAULT_IMAGES_BASE_PATH):
"""
Returns the image data of a site
Parameters
----------
dataset : str
what subset of the data: train, test
experiment : str
experiment name
plate : int
plate number
well : str
well address
site : int
site number
channels : list of int
channels to include
base_path : str
the base path of the raw images
Returns
-------
np.ndarray the image data of the site
"""
channel_paths = [
image_path(
dataset, experiment, plate, well, site, c, base_path=base_path)
for c in channels
]
return load_images_as_tensor(channel_paths)


def load_site_as_rgb(dataset,
experiment,
plate,
well,
site,
channels=DEFAULT_CHANNELS,
base_path=DEFAULT_IMAGES_BASE_PATH,
rgb_map=RGB_MAP):
"""
Loads and returns the image data as RGB image
Parameters
----------
dataset : str
what subset of the data: train, test
experiment : str
experiment name
plate : int
plate number
well : str
well address
site : int
site number
channels : list of int
channels to include
base_path : str
the base path of the raw images
rgb_map : dict
the color mapping for each channel
See rxrx.io.RGB_MAP to see what the defaults are.
Returns
-------
np.ndarray the image data of the site as RGB channels
"""
x = load_site(dataset, experiment, plate, well, site, channels, base_path)
return convert_tensor_to_rgb(x, channels, rgb_map=rgb_map)


def _tf_read_csv(path):
with tf.io.gfile.GFile(path, 'rb') as f:
return pd.read_csv(f)


def _load_dataset(base_path, dataset, include_controls=True):
df = _tf_read_csv(os.path.join(base_path, dataset + '.csv'))
if include_controls:
controls = _tf_read_csv(
os.path.join(base_path, dataset + '_controls.csv'))
df['well_type'] = 'treatment'
df = pd.concat([controls, df], sort=True)
df['cell_type'] = df.experiment.str.split("-").apply(lambda a: a[0])
df['dataset'] = dataset
dfs = []
for site in (1, 2):
df = df.copy()
df['site'] = site
dfs.append(df)
res = pd.concat(dfs).sort_values(
by=['id_code', 'site']).set_index('id_code')
return res


def combine_metadata(base_path=DEFAULT_METADATA_BASE_PATH,
include_controls=True):
"""
Combines all metadata files into a single dataframe and
expands it to include sites, not just wells.
Note, that the dtype of sirna is a float due to the missing
test values but it should be treated as an int.
Parameters
----------
base_path : str
where the metadata files from Kaggle live
include_controls : bool
indicate if you want the controls included in the dataframe
Returns
-------
pandas.DataFrame the combined metadata
"""
df = pd.concat(
[
_load_dataset(
base_path, dataset, include_controls=include_controls)
for dataset in ['test', 'train']
],
sort=True)
return df
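These helpers mirror Recursion's rxrx1 utilities and read the raw images and metadata straight from the public GCS bucket through tf.io.gfile. A hedged usage sketch follows; the module path and the experiment/plate/well identifiers are illustrative assumptions, not values taken from this commit.

from rxrxio import load_site, convert_tensor_to_rgb, combine_metadata  # assumed module path (src/rxrxio.py)

# illustrative identifiers; any valid train experiment/plate/well would do
t = load_site('train', 'HUVEC-01', 1, 'B02', 1)    # (512, 512, 6) uint8 array fetched from GCS
rgb = convert_tensor_to_rgb(t)                     # (512, 512, 3) int array suitable for display

md = combine_metadata()                            # one row per (id_code, site)
print(t.shape, rgb.shape, len(md))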
