Skip to content

Commit

Permalink
Initial Commit
Browse files Browse the repository at this point in the history
  • Loading branch information
sukkritsharmaofficial committed Apr 14, 2020
1 parent 649b42a commit 45c06e8
Show file tree
Hide file tree
Showing 13 changed files with 263 additions and 0 deletions.
Binary file added assets/output1.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added dataset/long/00001_00_10s.ARW
Binary file not shown.
Binary file added dataset/long/20120_00_30s.ARW
Binary file not shown.
Binary file added dataset/short/00001_00_0.1s.ARW
Binary file not shown.
Binary file added dataset/short/20120_00_0.1s.ARW
Binary file not shown.
Binary file added saved_model/checkpoint_sony_e4000.pth
Binary file not shown.
1 change: 1 addition & 0 deletions saved_model/model-here
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
put your model in this folder
96 changes: 96 additions & 0 deletions test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
import os,time,scipy.io

import numpy as np
import rawpy
import glob

import torch
import torch.nn as nn
import torch.optim as optim

from model import SeeInDark

input_dir = './dataset/Sony/short/'
gt_dir = './dataset/Sony/long/'
m_path = './saved_model/'
m_name = 'checkpoint_sony_e4000.pth'
result_dir = './test_result_Sony/'

device = torch.device('cpu')
#get test IDs
test_fns = glob.glob(gt_dir + '*.ARW')
test_ids = []
for i in range(len(test_fns)):
_, test_fn = os.path.split(test_fns[i])
test_ids.append(int(test_fn[0:5]))



def pack_raw(raw):
#pack Bayer image to 4 channels
im = np.maximum(raw - 512,0)/ (16383 - 512) #subtract the black level

im = np.expand_dims(im,axis=2)
img_shape = im.shape
H = img_shape[0]
W = img_shape[1]

out = np.concatenate((im[0:H:2,0:W:2,:],
im[0:H:2,1:W:2,:],
im[1:H:2,1:W:2,:],
im[1:H:2,0:W:2,:]), axis=2)
return out



model = SeeInDark()
model.load_state_dict(torch.load( m_path + m_name ,map_location={'cuda:1':'cuda:0'}))
model = model.to(device)
if not os.path.isdir(result_dir):
os.makedirs(result_dir)

for test_id in test_ids:
#test the first image in each sequence
in_files = glob.glob(input_dir + '%05d_00*.ARW'%test_id)
for k in range(len(in_files)):
in_path = in_files[k]
_, in_fn = os.path.split(in_path)
print(in_fn)
gt_files = glob.glob(gt_dir + '%05d_00*.ARW'%test_id)
gt_path = gt_files[0]
_, gt_fn = os.path.split(gt_path)
in_exposure = float(in_fn[9:-5])
gt_exposure = float(gt_fn[9:-5])
ratio = min(gt_exposure/in_exposure,300)

raw = rawpy.imread(in_path)
im = raw.raw_image_visible.astype(np.float32)
input_full = np.expand_dims(pack_raw(im),axis=0) *ratio

im = raw.postprocess(use_camera_wb=True, half_size=False, no_auto_bright=True, output_bps=16)
scale_full = np.expand_dims(np.float32(im/65535.0),axis = 0)

gt_raw = rawpy.imread(gt_path)
im = gt_raw.postprocess(use_camera_wb=True, half_size=False, no_auto_bright=True, output_bps=16)
gt_full = np.expand_dims(np.float32(im/65535.0),axis = 0)

input_full = np.minimum(input_full,1.0)

in_img = torch.from_numpy(input_full).permute(0,3,1,2).to(device)
out_img = model(in_img)
output = out_img.permute(0, 2, 3, 1).cpu().data.numpy()

output = np.minimum(np.maximum(output,0),1)

output = output[0,:,:,:]
gt_full = gt_full[0,:,:,:]
scale_full = scale_full[0,:,:,:]
origin_full = scale_full
scale_full = scale_full*np.mean(gt_full)/np.mean(scale_full) # scale the low-light image to the same mean of the groundtruth

scipy.misc.toimage(origin_full*255, high=255, low=0, cmin=0, cmax=255).save(result_dir + '%5d_00_%d_ori.png'%(test_id,ratio))
scipy.misc.toimage(output*255, high=255, low=0, cmin=0, cmax=255).save(result_dir + '%5d_00_%d_out.png'%(test_id,ratio))
scipy.misc.toimage(scale_full*255, high=255, low=0, cmin=0, cmax=255).save(result_dir + '%5d_00_%d_scale.png'%(test_id,ratio))
scipy.misc.toimage(gt_full*255, high=255, low=0, cmin=0, cmax=255).save(result_dir + '%5d_00_%d_gt.png'%(test_id,ratio))


Binary file added test_result_Sony/_ 1_00_100_gt.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added test_result_Sony/_ 1_00_100_ori.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added test_result_Sony/_ 1_00_100_out.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added test_result_Sony/_ 1_00_100_scale.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
166 changes: 166 additions & 0 deletions train.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
import os,time,scipy.io

import numpy as np
import rawpy
import glob

import torch
import torch.nn as nn
import torch.optim as optim

from model import SeeInDark

input_dir = './dataset/Sony/short/'
gt_dir = './dataset/Sony/long/'
result_dir = './result_Sony/'
model_dir = './saved_model/'

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
#get train and test IDs
train_fns = glob.glob(gt_dir + '0*.ARW')
train_ids = []
for i in range(len(train_fns)):
_, train_fn = os.path.split(train_fns[i])
train_ids.append(int(train_fn[0:5]))

test_fns = glob.glob(gt_dir + '/1*.ARW')
test_ids = []
for i in range(len(test_fns)):
_, test_fn = os.path.split(test_fns[i])
test_ids.append(int(test_fn[0:5]))



ps = 512 #patch size for training
save_freq = 100

DEBUG = 0
if DEBUG == 1:
save_freq = 100
train_ids = train_ids[0:5]
test_ids = test_ids[0:5]

def pack_raw(raw):
#pack Bayer image to 4 channels
im = raw.raw_image_visible.astype(np.float32)
im = np.maximum(im - 512,0)/ (16383 - 512) #subtract the black level

im = np.expand_dims(im,axis=2)
img_shape = im.shape
H = img_shape[0]
W = img_shape[1]

out = np.concatenate((im[0:H:2,0:W:2,:],
im[0:H:2,1:W:2,:],
im[1:H:2,1:W:2,:],
im[1:H:2,0:W:2,:]), axis=2)
return out

def reduce_mean(out_im, gt_im):
return torch.abs(out_im - gt_im).mean()


#Raw data takes long time to load. Keep them in memory after loaded.
gt_images=[None]*6000
input_images = {}
input_images['300'] = [None]*len(train_ids)
input_images['250'] = [None]*len(train_ids)
input_images['100'] = [None]*len(train_ids)

g_loss = np.zeros((5000,1))



allfolders = glob.glob('./result/*0')
lastepoch = 0
for folder in allfolders:
lastepoch = np.maximum(lastepoch, int(folder[-4:]))

learning_rate = 1e-4
model = SeeInDark().to(device)
model._initialize_weights()
opt = optim.Adam(model.parameters(), lr = learning_rate)
for epoch in range(lastepoch,4001):
if os.path.isdir("result/%04d"%epoch):
continue
cnt=0
if epoch > 2000:
for g in opt.param_groups:
g['lr'] = 1e-5


for ind in np.random.permutation(len(train_ids)):
# get the path from image id
train_id = train_ids[ind]
in_files = glob.glob(input_dir + '%05d_00*.ARW'%train_id)
in_path = in_files[np.random.random_integers(0,len(in_files)-1)]
_, in_fn = os.path.split(in_path)

gt_files = glob.glob(gt_dir + '%05d_00*.ARW'%train_id)
gt_path = gt_files[0]
_, gt_fn = os.path.split(gt_path)
in_exposure = float(in_fn[9:-5])
gt_exposure = float(gt_fn[9:-5])
ratio = min(gt_exposure/in_exposure,300)

st=time.time()
cnt+=1

if input_images[str(ratio)[0:3]][ind] is None:
raw = rawpy.imread(in_path)
input_images[str(ratio)[0:3]][ind] = np.expand_dims(pack_raw(raw),axis=0) *ratio

gt_raw = rawpy.imread(gt_path)
im = gt_raw.postprocess(use_camera_wb=True, half_size=False, no_auto_bright=True, output_bps=16)
gt_images[ind] = np.expand_dims(np.float32(im/65535.0),axis = 0)


#crop
H = input_images[str(ratio)[0:3]][ind].shape[1]
W = input_images[str(ratio)[0:3]][ind].shape[2]

xx = np.random.randint(0,W-ps)
yy = np.random.randint(0,H-ps)
input_patch = input_images[str(ratio)[0:3]][ind][:,yy:yy+ps,xx:xx+ps,:]
gt_patch = gt_images[ind][:,yy*2:yy*2+ps*2,xx*2:xx*2+ps*2,:]


if np.random.randint(2,size=1)[0] == 1: # random flip
input_patch = np.flip(input_patch, axis=1)
gt_patch = np.flip(gt_patch, axis=1)
if np.random.randint(2,size=1)[0] == 1:
input_patch = np.flip(input_patch, axis=0)
gt_patch = np.flip(gt_patch, axis=0)
if np.random.randint(2,size=1)[0] == 1: # random transpose
input_patch = np.transpose(input_patch, (0,2,1,3))
gt_patch = np.transpose(gt_patch, (0,2,1,3))


input_patch = np.minimum(input_patch,1.0)
gt_patch = np.maximum(gt_patch, 0.0)

in_img = torch.from_numpy(input_patch).permute(0,3,1,2).to(device)
gt_img = torch.from_numpy(gt_patch).permute(0,3,1,2).to(device)


model.zero_grad()
out_img = model(in_img)

loss = reduce_mean(out_img, gt_img)
loss.backward()

opt.step()
g_loss[ind]=loss.data

#print("%d %d Loss=%.3f Time=%.3f"%(epoch,cnt,np.mean(g_loss[np.where(g_loss)]),time.time()-st))

if epoch%save_freq==0:
if not os.path.isdir(result_dir + '%04d'%epoch):
os.makedirs(result_dir + '%04d'%epoch)
output = out_img.permute(0, 2, 3, 1).cpu().data.numpy()
output = np.minimum(np.maximum(output,0),1)

temp = np.concatenate((gt_patch[0,:,:,:], output[0,:,:,:]),axis=1)
scipy.misc.toimage(temp*255, high=255, low=0, cmin=0, cmax=255).save(result_dir + '%04d/%05d_00_train_%d.jpg'%(epoch,train_id,ratio))
torch.save(model.state_dict(), model_dir+'checkpoint_sony_e%04d.pth'%epoch)

0 comments on commit 45c06e8

Please sign in to comment.