-
Notifications
You must be signed in to change notification settings - Fork 19
/
data_utils.py
117 lines (92 loc) · 4.82 KB
/
data_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
from os import listdir
from os.path import join
import torch
from PIL import Image
from torch.utils.data.dataset import Dataset
from torchvision.transforms import Compose, RandomCrop, ToTensor, ToPILImage, CenterCrop, Resize,Grayscale
from imagecrop import FusionRandomCrop
from torchvision.transforms import functional as F
def is_image_file(filename):
    """Return True if *filename* has a supported image extension.

    The comparison is case-insensitive, so ``photo.TIF`` and ``photo.BMP``
    are accepted as well — the original list covered upper-case variants
    only for png/jpg/jpeg, which silently skipped upper-case tif/bmp files.
    """
    return filename.lower().endswith(('.tif', '.bmp', '.png', '.jpg', '.jpeg'))
def calculate_valid_crop_size(crop_size, upscale_factor):
    """Largest size <= *crop_size* that is an exact multiple of *upscale_factor*."""
    return (crop_size // upscale_factor) * upscale_factor
def train_hr_transform(crop_size):
    """Build the HR-side training transform: a single fusion-aware random crop."""
    steps = [FusionRandomCrop(crop_size)]
    return Compose(steps)
def train_vis_ir_transform():
    """Grayscale conversion (replicated to 3 channels) followed by ToTensor."""
    steps = [
        Grayscale(num_output_channels=3),
        ToTensor(),
    ]
    return Compose(steps)
def train_lr_transform(crop_size, upscale_factor):
    """Downscale an HR tensor by *upscale_factor* using a bicubic PIL resize."""
    lr_size = crop_size // upscale_factor
    steps = [
        ToPILImage(),
        Resize(lr_size, interpolation=Image.BICUBIC),
        ToTensor(),
    ]
    return Compose(steps)
def display_transform():
    """Produce a 400x400 center-cropped single-channel tensor for visualization."""
    steps = [
        ToPILImage(),
        Resize(400),
        CenterCrop(400),
        Grayscale(),
        ToTensor(),
    ]
    return Compose(steps)
class TrainDatasetFromFolder(Dataset):
    """Paired visible/infrared training dataset.

    Expects ``dataset_dir`` to contain ``vi/`` (visible) and ``ir/``
    (infrared) subdirectories whose filenames match except for a leading
    ``V``/``I`` prefix (e.g. ``vi/V_001.png`` <-> ``ir/I_001.png``) —
    TODO confirm this naming convention against the actual data layout.
    """
    def __init__(self, dataset_dir, crop_size, upscale_factor):
        super(TrainDatasetFromFolder, self).__init__()
        # Visible images drive the listing; each infrared path is derived
        # by substituting 'vi/V' -> 'ir/I' in the visible path.
        self.visible_image_filenames = [join(dataset_dir+'vi/', x) for x in listdir(dataset_dir+'vi/') if is_image_file(x)]
        self.infrared_image_filenames = [x.replace('vi/V','ir/I') for x in self.visible_image_filenames]
        # Shrink the crop so it divides evenly by the upscale factor.
        crop_size = calculate_valid_crop_size(crop_size, upscale_factor)
        self.hr_transform = train_hr_transform(crop_size)
        self.vis_ir_transform = train_vis_ir_transform()
        self.lr_transform = train_lr_transform(crop_size, upscale_factor)
    def __getitem__(self, index):
        visible_image = Image.open(self.visible_image_filenames[index])
        infrared_image = Image.open(self.infrared_image_filenames[index])
        # NOTE(review): FusionRandomCrop appears to return crop PARAMETERS
        # (indexed [0..3] below) rather than a cropped image — presumably so
        # the same random crop is applied to both modalities; confirm against
        # the imagecrop module.
        crop_size = self.hr_transform(visible_image)#, infrared_image)
        # Apply the identical crop window to both images via the functional API
        # (F.crop takes top, left, height, width).
        visible_image, infrared_image = F.crop(visible_image,crop_size[0],crop_size[1],crop_size[2],crop_size[3])\
            ,F.crop(infrared_image, crop_size[0],crop_size[1],crop_size[2],crop_size[3])
        visible_image = self.vis_ir_transform(visible_image)
        infrared_image = self.vis_ir_transform(infrared_image)
        # Stack channel 0 of the low-res IR and low-res visible tensors into a
        # 2-channel network input (IR first, visible second).
        data = torch.cat((self.lr_transform(infrared_image)[0].unsqueeze(0),self.lr_transform(visible_image)[0].unsqueeze(0)))
        return data, infrared_image, visible_image
    def __len__(self):
        return len(self.visible_image_filenames)
class ValDatasetFromFolder(Dataset):
    """Paired visible/infrared validation dataset.

    Mirrors TrainDatasetFromFolder's directory layout: ``dataset_dir``
    holds ``vi/`` and ``ir/`` subdirectories whose filenames match except
    for a leading ``V``/``I`` prefix.
    """
    def __init__(self, dataset_dir, upscale_factor):
        super(ValDatasetFromFolder, self).__init__()
        self.upscale_factor = upscale_factor
        self.visible_image_filenames = [join(dataset_dir+'vi/', x) for x in listdir(dataset_dir+'vi/') if is_image_file(x)]
        self.infrared_image_filenames = [x.replace('vi/V','ir/I') for x in self.visible_image_filenames]
    def __getitem__(self, index):
        # Context managers close the image files promptly; the original
        # leaked file descriptors until garbage collection (ResourceWarning
        # under PIL's lazy loading).
        with Image.open(self.visible_image_filenames[index]) as visible_image, \
                Image.open(self.infrared_image_filenames[index]) as infrared_image:
            w, h = visible_image.size
            # Center-crop to the largest square that divides evenly by the
            # upscale factor.
            crop_size = calculate_valid_crop_size(min(w, h), self.upscale_factor)
            visible_image1 = CenterCrop(crop_size)(visible_image)
            infrared_image1 = CenterCrop(crop_size)(infrared_image)
            visible_image1 = ToTensor()(Grayscale(num_output_channels=3)(visible_image1))
            infrared_image1 = ToTensor()(Grayscale(num_output_channels=3)(infrared_image1))
        # 2-channel input: IR luminance channel first, visible second.
        data = torch.cat((infrared_image1[0].unsqueeze(0), visible_image1[0].unsqueeze(0)))
        return data, infrared_image1, visible_image1
    def __len__(self):
        return len(self.infrared_image_filenames)
class TestDatasetFromFolder(Dataset):
    """Paired visible/infrared test dataset over ``dataset_dir/tmp/``.

    Visible images live in ``tmp/vi/`` (filenames prefixed ``V_``) and the
    matching infrared images in ``tmp/ir/`` (prefixed ``I_``). The listing
    is sorted for a deterministic index order.
    """
    def __init__(self, dataset_dir, upscale_factor):
        super(TestDatasetFromFolder, self).__init__()
        self.upscale_factor = upscale_factor
        imagelist = listdir(dataset_dir+'tmp/vi/')
        imagelist.sort()
        self.visible_image_filenames = [join(dataset_dir+'tmp/vi/', x) for x in imagelist if is_image_file(x)]
        self.infrared_image_filenames = [x.replace('tmp/vi/V_','tmp/ir/I_') for x in self.visible_image_filenames]
    def __getitem__(self, index):
        # Context managers close the image files promptly; the original
        # leaked file descriptors until garbage collection.
        with Image.open(self.visible_image_filenames[index]) as vis, \
                Image.open(self.infrared_image_filenames[index]) as ir:
            # Full-size (no crop) grayscale tensors, replicated to 3 channels.
            visible_image = ToTensor()(Grayscale(num_output_channels=3)(vis))
            infrared_image = ToTensor()(Grayscale(num_output_channels=3)(ir))
        # 2-channel input: IR luminance channel first, visible second.
        data = torch.cat((infrared_image[0].unsqueeze(0), visible_image[0].unsqueeze(0)))
        return data, infrared_image, visible_image
    def __len__(self):
        return len(self.infrared_image_filenames)