Skip to content

Commit

Permalink
Improve style
Browse files Browse the repository at this point in the history
  • Loading branch information
fepegar committed Oct 30, 2018
1 parent bbbd67c commit 455bf75
Show file tree
Hide file tree
Showing 4 changed files with 333 additions and 306 deletions.
6 changes: 4 additions & 2 deletions demos/PyTorchNiftyNet/libs/dataset_niftynet.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
import torch
from torch.utils.data import Dataset


class DatasetNiftySampler(Dataset):
def __init__(self, sampler):
super(DatasetNiftySampler, self).__init__()
self.sampler = sampler

def __getitem__(self, index):
data = self.sampler(idx=index)
return torch.from_numpy(data['image'][..., 0, :]).float(),\
torch.from_numpy(data['label'][..., 0, :]).float()
image = torch.from_numpy(data['image'][..., 0, :]).float()
label = torch.from_numpy(data['label'][..., 0, :]).float()
return image, label

def __len__(self):
return len(self.sampler.reader.output_list)
18 changes: 4 additions & 14 deletions demos/PyTorchNiftyNet/libs/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,28 +10,18 @@ def __init__(self):
def forward(self, output, label):
probs = output.view(-1)
mask = label.view(-1)
#eps = 0.00000001
smooth = 1

intersection = torch.sum(probs * mask)

den1 = torch.sum(probs)
den2 = torch.sum(mask)

#soft_dice = ((2 * intersection) / (den1 + den2 + eps))

soft_dice = ((2 * intersection) + smooth) / (den1 + den2 + smooth)

soft_dice = (2 * intersection + smooth) / (den1 + den2 + smooth)
return -soft_dice


def dice(input, target):

eps = 0.00000001

def dice(input, target):
epsilon = 1e-8
iflat = input.view(-1)
tflat = target.view(-1)
intersection = (iflat * tflat).sum()

return ((2. * intersection) /
(iflat.sum() + tflat.sum() + eps))
return 2 * intersection / (iflat.sum() + tflat.sum() + epsilon)
Loading

0 comments on commit 455bf75

Please sign in to comment.