-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Added custom toimage function and updated code
- Loading branch information
1 parent
1b7f656
commit ae258d8
Showing
2 changed files
with
131 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,126 @@ | ||
import numpy as np | ||
from PIL import Image | ||
|
||
|
||
_errstr = "Mode is unknown or incompatible with input array shape." | ||
|
||
|
||
def bytescale(data, cmin=None, cmax=None, high=255, low=0): | ||
if data.dtype == np.uint8: | ||
return data | ||
|
||
if high > 255: | ||
raise ValueError("`high` should be less than or equal to 255.") | ||
if low < 0: | ||
raise ValueError("`low` should be greater than or equal to 0.") | ||
if high < low: | ||
raise ValueError("`high` should be greater than or equal to `low`.") | ||
|
||
if cmin is None: | ||
cmin = data.min() | ||
if cmax is None: | ||
cmax = data.max() | ||
|
||
cscale = cmax - cmin | ||
if cscale < 0: | ||
raise ValueError("`cmax` should be larger than `cmin`.") | ||
elif cscale == 0: | ||
cscale = 1 | ||
|
||
scale = float(high - low) / cscale | ||
bytedata = (data - cmin) * scale + low | ||
return (bytedata.clip(low, high) + 0.5).astype(np.uint8) | ||
|
||
|
||
def toimage(arr, high=255, low=0, cmin=None, cmax=None, pal=None, | ||
mode=None, channel_axis=None): | ||
data = np.asarray(arr) | ||
if np.iscomplexobj(data): | ||
raise ValueError("Cannot convert a complex-valued array.") | ||
shape = list(data.shape) | ||
valid = len(shape) == 2 or ((len(shape) == 3) and | ||
((3 in shape) or (4 in shape))) | ||
if not valid: | ||
raise ValueError("'arr' does not have a suitable array shape for " | ||
"any mode.") | ||
if len(shape) == 2: | ||
shape = (shape[1], shape[0]) # columns show up first | ||
if mode == 'F': | ||
data32 = data.astype(np.float32) | ||
image = Image.frombytes(mode, shape, data32.tostring()) | ||
return image | ||
if mode in [None, 'L', 'P']: | ||
bytedata = bytescale(data, high=high, low=low, | ||
cmin=cmin, cmax=cmax) | ||
image = Image.frombytes('L', shape, bytedata.tostring()) | ||
if pal is not None: | ||
image.putpalette(np.asarray(pal, dtype=np.uint8).tostring()) | ||
# Becomes a mode='P' automagically. | ||
elif mode == 'P': # default gray-scale | ||
pal = (np.arange(0, 256, 1, dtype=np.uint8)[:, np.newaxis] * | ||
np.ones((3,), dtype=np.uint8)[np.newaxis, :]) | ||
image.putpalette(np.asarray(pal, dtype=np.uint8).tostring()) | ||
return image | ||
if mode == '1': # high input gives threshold for 1 | ||
bytedata = (data > high) | ||
image = Image.frombytes('1', shape, bytedata.tostring()) | ||
return image | ||
if cmin is None: | ||
cmin = np.amin(np.ravel(data)) | ||
if cmax is None: | ||
cmax = np.amax(np.ravel(data)) | ||
data = (data*1.0 - cmin)*(high - low)/(cmax - cmin) + low | ||
if mode == 'I': | ||
data32 = data.astype(np.uint32) | ||
image = Image.frombytes(mode, shape, data32.tostring()) | ||
else: | ||
raise ValueError(_errstr) | ||
return image | ||
|
||
# if here then 3-d array with a 3 or a 4 in the shape length. | ||
# Check for 3 in datacube shape --- 'RGB' or 'YCbCr' | ||
if channel_axis is None: | ||
if (3 in shape): | ||
ca = np.flatnonzero(np.asarray(shape) == 3)[0] | ||
else: | ||
ca = np.flatnonzero(np.asarray(shape) == 4) | ||
if len(ca): | ||
ca = ca[0] | ||
else: | ||
raise ValueError("Could not find channel dimension.") | ||
else: | ||
ca = channel_axis | ||
|
||
numch = shape[ca] | ||
if numch not in [3, 4]: | ||
raise ValueError("Channel axis dimension is not valid.") | ||
|
||
bytedata = bytescale(data, high=high, low=low, cmin=cmin, cmax=cmax) | ||
if ca == 2: | ||
strdata = bytedata.tostring() | ||
shape = (shape[1], shape[0]) | ||
elif ca == 1: | ||
strdata = np.transpose(bytedata, (0, 2, 1)).tostring() | ||
shape = (shape[2], shape[0]) | ||
elif ca == 0: | ||
strdata = np.transpose(bytedata, (1, 2, 0)).tostring() | ||
shape = (shape[2], shape[1]) | ||
if mode is None: | ||
if numch == 3: | ||
mode = 'RGB' | ||
else: | ||
mode = 'RGBA' | ||
|
||
if mode not in ['RGB', 'RGBA', 'YCbCr', 'CMYK']: | ||
raise ValueError(_errstr) | ||
|
||
if mode in ['RGB', 'YCbCr']: | ||
if numch != 3: | ||
raise ValueError("Invalid array shape for mode.") | ||
if mode in ['RGBA', 'CMYK']: | ||
if numch != 4: | ||
raise ValueError("Invalid array shape for mode.") | ||
|
||
# Here we know data and mode is correct | ||
image = Image.frombytes(mode, shape, strdata) | ||
return image |