Skip to content

Commit

Permalink
Added custom toimage function and updated code
Browse files Browse the repository at this point in the history
  • Loading branch information
sukkritsharmaofficial committed Apr 14, 2020
1 parent 1b7f656 commit ae258d8
Show file tree
Hide file tree
Showing 2 changed files with 131 additions and 4 deletions.
9 changes: 5 additions & 4 deletions test.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import torch.optim as optim

from model import SeeInDark
from toimage import toimage

input_dir = './dataset/Sony/short/'
gt_dir = './dataset/Sony/long/'
Expand Down Expand Up @@ -88,9 +89,9 @@ def pack_raw(raw):
origin_full = scale_full
scale_full = scale_full*np.mean(gt_full)/np.mean(scale_full) # scale the low-light image to the same mean of the groundtruth

scipy.misc.toimage(origin_full*255, high=255, low=0, cmin=0, cmax=255).save(result_dir + '%5d_00_%d_ori.png'%(test_id,ratio))
scipy.misc.toimage(output*255, high=255, low=0, cmin=0, cmax=255).save(result_dir + '%5d_00_%d_out.png'%(test_id,ratio))
scipy.misc.toimage(scale_full*255, high=255, low=0, cmin=0, cmax=255).save(result_dir + '%5d_00_%d_scale.png'%(test_id,ratio))
scipy.misc.toimage(gt_full*255, high=255, low=0, cmin=0, cmax=255).save(result_dir + '%5d_00_%d_gt.png'%(test_id,ratio))
toimage(origin_full*255, high=255, low=0, cmin=0, cmax=255).save(result_dir + '%5d_00_%d_ori.png'%(test_id,ratio))
toimage(output*255, high=255, low=0, cmin=0, cmax=255).save(result_dir + '%5d_00_%d_out.png'%(test_id,ratio))
toimage(scale_full*255, high=255, low=0, cmin=0, cmax=255).save(result_dir + '%5d_00_%d_scale.png'%(test_id,ratio))
toimage(gt_full*255, high=255, low=0, cmin=0, cmax=255).save(result_dir + '%5d_00_%d_gt.png'%(test_id,ratio))


126 changes: 126 additions & 0 deletions toimage.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
import numpy as np
from PIL import Image


_errstr = "Mode is unknown or incompatible with input array shape."


def bytescale(data, cmin=None, cmax=None, high=255, low=0):
if data.dtype == np.uint8:
return data

if high > 255:
raise ValueError("`high` should be less than or equal to 255.")
if low < 0:
raise ValueError("`low` should be greater than or equal to 0.")
if high < low:
raise ValueError("`high` should be greater than or equal to `low`.")

if cmin is None:
cmin = data.min()
if cmax is None:
cmax = data.max()

cscale = cmax - cmin
if cscale < 0:
raise ValueError("`cmax` should be larger than `cmin`.")
elif cscale == 0:
cscale = 1

scale = float(high - low) / cscale
bytedata = (data - cmin) * scale + low
return (bytedata.clip(low, high) + 0.5).astype(np.uint8)


def toimage(arr, high=255, low=0, cmin=None, cmax=None, pal=None,
mode=None, channel_axis=None):
data = np.asarray(arr)
if np.iscomplexobj(data):
raise ValueError("Cannot convert a complex-valued array.")
shape = list(data.shape)
valid = len(shape) == 2 or ((len(shape) == 3) and
((3 in shape) or (4 in shape)))
if not valid:
raise ValueError("'arr' does not have a suitable array shape for "
"any mode.")
if len(shape) == 2:
shape = (shape[1], shape[0]) # columns show up first
if mode == 'F':
data32 = data.astype(np.float32)
image = Image.frombytes(mode, shape, data32.tostring())
return image
if mode in [None, 'L', 'P']:
bytedata = bytescale(data, high=high, low=low,
cmin=cmin, cmax=cmax)
image = Image.frombytes('L', shape, bytedata.tostring())
if pal is not None:
image.putpalette(np.asarray(pal, dtype=np.uint8).tostring())
# Becomes a mode='P' automagically.
elif mode == 'P': # default gray-scale
pal = (np.arange(0, 256, 1, dtype=np.uint8)[:, np.newaxis] *
np.ones((3,), dtype=np.uint8)[np.newaxis, :])
image.putpalette(np.asarray(pal, dtype=np.uint8).tostring())
return image
if mode == '1': # high input gives threshold for 1
bytedata = (data > high)
image = Image.frombytes('1', shape, bytedata.tostring())
return image
if cmin is None:
cmin = np.amin(np.ravel(data))
if cmax is None:
cmax = np.amax(np.ravel(data))
data = (data*1.0 - cmin)*(high - low)/(cmax - cmin) + low
if mode == 'I':
data32 = data.astype(np.uint32)
image = Image.frombytes(mode, shape, data32.tostring())
else:
raise ValueError(_errstr)
return image

# if here then 3-d array with a 3 or a 4 in the shape length.
# Check for 3 in datacube shape --- 'RGB' or 'YCbCr'
if channel_axis is None:
if (3 in shape):
ca = np.flatnonzero(np.asarray(shape) == 3)[0]
else:
ca = np.flatnonzero(np.asarray(shape) == 4)
if len(ca):
ca = ca[0]
else:
raise ValueError("Could not find channel dimension.")
else:
ca = channel_axis

numch = shape[ca]
if numch not in [3, 4]:
raise ValueError("Channel axis dimension is not valid.")

bytedata = bytescale(data, high=high, low=low, cmin=cmin, cmax=cmax)
if ca == 2:
strdata = bytedata.tostring()
shape = (shape[1], shape[0])
elif ca == 1:
strdata = np.transpose(bytedata, (0, 2, 1)).tostring()
shape = (shape[2], shape[0])
elif ca == 0:
strdata = np.transpose(bytedata, (1, 2, 0)).tostring()
shape = (shape[2], shape[1])
if mode is None:
if numch == 3:
mode = 'RGB'
else:
mode = 'RGBA'

if mode not in ['RGB', 'RGBA', 'YCbCr', 'CMYK']:
raise ValueError(_errstr)

if mode in ['RGB', 'YCbCr']:
if numch != 3:
raise ValueError("Invalid array shape for mode.")
if mode in ['RGBA', 'CMYK']:
if numch != 4:
raise ValueError("Invalid array shape for mode.")

# Here we know data and mode is correct
image = Image.frombytes(mode, shape, strdata)
return image

0 comments on commit ae258d8

Please sign in to comment.