Skip to content

Commit

Permalink
- Smooth label loss
Browse files Browse the repository at this point in the history
- normalize per image
  • Loading branch information
ngxbac committed Jul 5, 2019
1 parent 4321488 commit d2d09b8
Show file tree
Hide file tree
Showing 4 changed files with 187 additions and 2 deletions.
9 changes: 8 additions & 1 deletion src/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,14 @@
from .experiment import Experiment
from .runner import ModelRunner as Runner
from models import *
from losses import *
from callbacks import *


registry.Model(cell_resnet)
registry.Model(cell_senet)
registry.Model(cell_densenet)
registry.Model(cell_densenet)

registry.Callback(LabelSmoothCriterionCallback)

registry.Criterion(LabelSmoothingCrossEntropy)
62 changes: 62 additions & 0 deletions src/callbacks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
from catalyst.dl.core import Callback, RunnerState
import torch.nn as nn


class LabelSmoothCriterionCallback(Callback):
def __init__(
self,
input_key: str = "targets",
output_key: str = "logits",
prefix: str = "loss",
criterion_key: str = None,
loss_key: str = None,
multiplier: float = 1.0
):
self.input_key = input_key
self.output_key = output_key
self.prefix = prefix
self.criterion_key = criterion_key
self.loss_key = loss_key
self.multiplier = multiplier

def _add_loss_to_state(self, state: RunnerState, loss):
if self.loss_key is None:
if state.loss is not None:
if isinstance(state.loss, list):
state.loss.append(loss)
else:
state.loss = [state.loss, loss]
else:
state.loss = loss
else:
if state.loss is not None:
assert isinstance(state.loss, dict)
state.loss[self.loss_key] = loss
else:
state.loss = {self.loss_key: loss}

def _compute_loss(self, state: RunnerState, criterion):
loss = criterion(
state.output[self.output_key],
state.input[self.input_key]
)
return loss

def on_stage_start(self, state: RunnerState):
assert state.criterion is not None

def on_batch_end(self, state: RunnerState):
if state.loader_name.startswith("train"):
criterion = state.get_key(
key="criterion", inner_key=self.criterion_key
)
else:
criterion = nn.CrossEntropyLoss()

loss = self._compute_loss(state, criterion) * self.multiplier

state.metrics.add_batch_value(metrics_dict={
self.prefix: loss.item(),
})

self._add_loss_to_state(state, loss)
92 changes: 91 additions & 1 deletion src/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,43 @@ def image_path(dataset,
"{}_s{}_w{}.png".format(address, site, channel))


def image_stats(pixel_stat,
experiment,
plate,
address,
site,
channel):
"""
Returns the path of a channel image.
Parameters
----------
dataset : str
what subset of the data: train, test
experiment : str
experiment name
plate : int
plate number
address : str
plate address
site : int
site number
channel : int
channel number
base_path : str
the base path of the raw images
Returns
-------
str the path of image
"""

channel_stat = pixel_stat[(pixel_stat.experiment == experiment)
& (pixel_stat.plate == plate)
& (pixel_stat.well == address)
& (pixel_stat.site == site)
& (pixel_stat.channel == channel)]

return channel_stat["mean"].values[0], channel_stat["std"].values[0]

# def load_image(image_path):
# with tf.io.gfile.GFile(image_path, 'rb') as f:
# return imread(f, format='png')
Expand Down Expand Up @@ -123,6 +160,21 @@ def convert_tensor_to_rgb(t, channels=DEFAULT_CHANNELS, vmax=255, rgb_map=RGB_MA
return im


def normalize(img, mean, std, max_pixel_value=255.0):
mean = np.array(mean, dtype=np.float32)
mean *= max_pixel_value

std = np.array(std, dtype=np.float32)
std *= max_pixel_value

denominator = np.reciprocal(std, dtype=np.float32)

img = img.astype(np.float32)
img -= mean
img *= denominator
return img


class RecursionCellularSite(Dataset):

def __init__(self,
Expand All @@ -134,6 +186,33 @@ def __init__(self,
channels=[1, 2, 3, 4, 5, 6],
):
df = pd.read_csv(csv_file, nrows=None)
self.pixel_stat = pd.read_csv(os.path.join(root, "pixel_stats.csv"))
self.stat_dict = {}
for experiment, plate, well, site, channel, mean, std in zip(self.pixel_stat.experiment,
self.pixel_stat.plate,
self.pixel_stat.well,
self.pixel_stat.site,
self.pixel_stat.channel,
self.pixel_stat["mean"],
self.pixel_stat["std"]):
if not experiment in self.stat_dict:
self.stat_dict[experiment] = {}

if not plate in self.stat_dict[experiment]:
self.stat_dict[experiment][plate] = {}

if not well in self.stat_dict[experiment][plate]:
self.stat_dict[experiment][plate][well] = {}

if not site in self.stat_dict[experiment][plate][well]:
self.stat_dict[experiment][plate][well][site] = {}

if not channel in self.stat_dict[experiment][plate][well][site]:
self.stat_dict[experiment][plate][well][channel] = {}

self.stat_dict[experiment][plate][well][channel]["mean"] = mean / 255
self.stat_dict[experiment][plate][well][channel]["std"] = std / 255


self.transform = transform
self.mode = mode
Expand Down Expand Up @@ -172,11 +251,22 @@ def __getitem__(self, idx):
) for channel in self.channels
]

std_arr = []
mean_arr = []

for channel in self.channels:
mean = self.stat_dict[experiment][plate][well][channel]["mean"]
std = self.stat_dict[experiment][plate][well][channel]["std"]
std_arr.append(std)
mean_arr.append(mean)

image = load_images_as_tensor(channel_paths, dtype=np.float32)
# image = convert_tensor_to_rgb(image)
image = image / 255
# image = image / 255
if self.transform:
image = self.transform(image=image)['image']

image = normalize(image, std=std_arr, mean=mean_arr, max_pixel_value=255)
image = np.transpose(image, (2, 0, 1)).astype(np.float32)

if self.mode == 'train':
Expand Down
26 changes: 26 additions & 0 deletions src/losses.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class LabelSmoothingCrossEntropy(nn.Module):
"""
NLL loss with label smoothing.
"""
def __init__(self, smoothing=0.1):
"""
Constructor for the LabelSmoothing module.
:param smoothing: label smoothing factor
"""
super(LabelSmoothingCrossEntropy, self).__init__()
assert smoothing < 1.0
self.smoothing = smoothing
self.confidence = 1. - smoothing

def forward(self, x, target):
logprobs = F.log_softmax(x, dim=-1)
nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1))
nll_loss = nll_loss.squeeze(1)
smooth_loss = -logprobs.mean(dim=-1)
loss = self.confidence * nll_loss + self.smoothing * smooth_loss
return loss.mean()

0 comments on commit d2d09b8

Please sign in to comment.