Skip to content

Commit

Permalink
Simplify data loader
Browse files Browse the repository at this point in the history
  • Loading branch information
Dawars committed May 11, 2023
1 parent 6984eb4 commit 70f34ed
Showing 1 changed file with 10 additions and 66 deletions.
76 changes: 10 additions & 66 deletions data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import glob
from pathlib import Path

import imageio
import torch
from skimage import io, transform, color
import numpy as np
Expand All @@ -13,6 +14,7 @@
from torchvision import transforms, utils
from PIL import Image

imageio.plugins.freeimage.download()

# ==========================dataset load==========================
class RescaleT(object):
Expand Down Expand Up @@ -293,17 +295,6 @@ def __getitem__(self, idx):
else:
label_3 = io.imread(self.label_name_list[idx])

label = np.zeros(label_3.shape[0:2])
if (3 == len(label_3.shape)):
label = label_3[:, :, 0]
elif (2 == len(label_3.shape)):
label = label_3

if (3 == len(image.shape) and 2 == len(label.shape)):
label = label[:, :, np.newaxis]
elif (2 == len(image.shape) and 2 == len(label.shape)):
image = image[:, :, np.newaxis]
label = label[:, :, np.newaxis]

sample = {'imidx': imidx, 'image': image, 'label': label}

Expand All @@ -326,27 +317,10 @@ def __len__(self):
return len(self.image_name_list)

def __getitem__(self, idx):

image = io.imread(self.image_name_list[idx])
imname = self.image_name_list[idx]
imidx = np.array([idx])
image = np.atleast_3d(np.array(Image.open(self.image_name_list[idx])))
label = np.atleast_3d(np.array(Image.open(self.label_name_list[idx])))

if (0 == len(self.label_name_list)):
label_3 = np.zeros(image.shape)
else:
label_3 = io.imread(self.label_name_list[idx])

label = np.zeros(label_3.shape[0:2])
if (3 == len(label_3.shape)):
label = label_3[:, :, 0]
elif (2 == len(label_3.shape)):
label = label_3

if (3 == len(image.shape) and 2 == len(label.shape)):
label = label[:, :, np.newaxis]
elif (2 == len(image.shape) and 2 == len(label.shape)):
image = image[:, :, np.newaxis]
label = label[:, :, np.newaxis]
sample = {'imidx': imidx, 'image': image, 'label': 255 - label}

if self.transform:
Expand All @@ -368,27 +342,10 @@ def __len__(self):
return len(self.image_name_list)

def __getitem__(self, idx):

image = io.imread(self.image_name_list[idx])
imname = self.image_name_list[idx]
imidx = np.array([idx])
image = np.atleast_3d(np.array(Image.open(self.image_name_list[idx])))
label = np.atleast_3d(np.array(Image.open(self.label_name_list[idx])))

if (0 == len(self.label_name_list)):
label_3 = np.zeros(image.shape)
else:
label_3 = io.imread(self.label_name_list[idx])

label = np.zeros(label_3.shape[0:2])
if (3 == len(label_3.shape)):
label = label_3[:, :, 0]
elif (2 == len(label_3.shape)):
label = label_3

if (3 == len(image.shape) and 2 == len(label.shape)):
label = label[:, :, np.newaxis]
elif (2 == len(image.shape) and 2 == len(label.shape)):
image = image[:, :, np.newaxis]
label = label[:, :, np.newaxis]
sample = {'imidx': imidx, 'image': image, 'label': label}

if self.transform:
Expand All @@ -410,25 +367,12 @@ def __len__(self):
return len(self.image_name_list)

def __getitem__(self, idx):

image = io.imread(self.image_name_list[idx])
imname = self.image_name_list[idx]
imidx = np.array([idx])
image = np.atleast_3d(np.array(Image.open(self.image_name_list[idx]).convert("RGB")))
label = imageio.v3.imread(self.label_name_list[idx])

label_3 = io.imread(self.label_name_list[idx])

label = np.zeros(label_3.shape[0:2])
if (3 == len(label_3.shape)):
label = label_3[:, :, 0]
elif (2 == len(label_3.shape)):
label = label_3

if (3 == len(image.shape) and 2 == len(label.shape)):
label = label[:, :, np.newaxis]
elif (2 == len(image.shape) and 2 == len(label.shape)):
image = image[:, :, np.newaxis]
label = label[:, :, np.newaxis]
sample = {'imidx': imidx, 'image': image[..., :3], 'label': (label != 0.0).astype(int) * 255}
label = (label != 0.0).astype(int)[..., :1] * 255
sample = {'imidx': imidx, 'image': image[..., :3], 'label': label}

if self.transform:
sample = self.transform(sample)
Expand Down

0 comments on commit 70f34ed

Please sign in to comment.