Skip to content

Commit

Permalink
Update plot_figure.py
Browse files Browse the repository at this point in the history
  • Loading branch information
KaijieMo1 committed Dec 21, 2020
1 parent 00d6a2c commit 666432a
Showing 1 changed file with 14 additions and 13 deletions.
27 changes: 14 additions & 13 deletions plot/plot_figure.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,8 +138,6 @@ def plot_mosaic(config, mask, slice_dim=2, colormap=None, vspace=2, hspace=2, co
origin_image= origin_image[0: shape_[0],0: shape_[1],0:shape_[2]]
mask = mask[0: shape_[0], 0: shape_[1], 0:shape_[2]]



if slice_dim == 2:
h, w, slices = mask_shape[0], mask_shape[1], mask_shape[2]
elif slice_dim == 1:
Expand All @@ -166,12 +164,14 @@ def plot_mosaic(config, mask, slice_dim=2, colormap=None, vspace=2, hspace=2, co
indice_ = tuple(indice)

color_image = colormap[mask[indice_].astype(int) % num_category].astype('uint8')

if rotate_k:
color_image= np.rot90(color_image, k=rotate_k)

if flip_axis:
color_image=np.flip(color_image,axis=flip_axis)


im = Image.fromarray(color_image)
im = im.convert("RGBA")

Expand All @@ -185,18 +185,17 @@ def plot_mosaic(config, mask, slice_dim=2, colormap=None, vspace=2, hspace=2, co
if flip_axis:
origin_image_slice = np.flip(origin_image_slice, axis=flip_axis)

origin_image_slice = Image.fromarray(origin_image_slice).convert("RGBA")


origin_image_slice = Image.fromarray(origin_image_slice).convert("RGBA")

im = Image.blend(im, origin_image_slice, alpha=alpha_origin)

# Draw slice index on the left top of the image
# Draw slice index on the left top of the image
draw = ImageDraw.Draw(im)
# In Linux, "arial.ttf" should be changed
#font = ImageFont.truetype("arial.ttf", 14)
#draw.text((7, 7), str(slice_index+1), font=font)
#figure.paste(im, (col_index * (w + vspace), row_index * (h + hspace)))
# In Linux, "arial.ttf" should be changed
font = ImageFont.truetype("arial.ttf", 14)
draw.text((7, 7), str(slice_index+1), font=font)
figure.paste(im, (col_index * (w + vspace), row_index * (h + hspace)))

dir_figures = config['result_rootdir'] + '/' + config['model'] + '/figures/plot_mosaic/' + dataset + '/' + name_ID
if client_save_rootdir is not None:
Expand Down Expand Up @@ -537,7 +536,6 @@ def plot_combine(config,image, heatmap, alpha=0.6, display=False, save_path=None

plt.rcParams["font.family"] = "Times New Roman"
if image.shape!=heatmap.shape:
print(image.shape,heatmap.shape)
image = cv.resize(image, dsize=heatmap.shape[::-1], interpolation=cv.INTER_CUBIC)

aspect = 0.1
Expand All @@ -550,9 +548,8 @@ def plot_combine(config,image, heatmap, alpha=0.6, display=False, save_path=None
image = np.rot90(image, k=1, axes=(1, 0))
heatmap = np.rot90(heatmap, k=1, axes=(1, 0))


# Discrete color scheme
cMap = ListedColormap([[0.1,0.2,0.3], [0.1,0.3,0.4],[0.3,0.3,0.3], [0.6,0.4,0.4],[0.4,0.5,0.8], [0.7,0.2,0.5],])
cMap = ListedColormap(config['colormap'])
# Display
fig, ax = plt.subplots()
image = ax.pcolormesh(image)
Expand All @@ -561,7 +558,11 @@ def plot_combine(config,image, heatmap, alpha=0.6, display=False, save_path=None
cbar = plt.colorbar(heatmap)

landmark_dict = {0: 'Wrists', 1: 'Shoulders', 2: 'Liver_dome', 3: 'Hips', 4: 'Heels', 5: 'Below'}
cbar.ax.set_ylabel('Class',fontsize=20)
if 'plot_fontsize'in config:
font_size=config['plot_fontsize']
else:
font_size=20
cbar.ax.set_ylabel('Class',fontsize=font_size)
cbar.ax.get_yaxis().set_ticks([])
for j, lab in enumerate(landmark_dict):
#cbar.ax.text(1.7, (2 * j +2) / 2.4-0.4, landmark_dict[j],rotation=90, ha='center', va='center',fontsize=18) #rotation=90 ,
Expand Down

0 comments on commit 666432a

Please sign in to comment.