-
Notifications
You must be signed in to change notification settings - Fork 4
/
demo.py
123 lines (101 loc) · 4.86 KB
/
demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import os
import torch
from collections import OrderedDict
from PIL import Image
import numpy as np
from matplotlib import pyplot as plt
from torch.nn.functional import upsample
import networks.deeplab_resnet as resnet
from mypath import Path
from dataloaders import helpers as helpers
# Name of the DEXTR checkpoint file (without ".pth") under Path.models_dir().
modelName = 'dextr_pascal-sbd'
# Pixels of context added around the extreme-point bounding box when cropping;
# the same value is used as the `relax` margin when pasting the mask back.
pad = 50
# Threshold applied to the sigmoid output to decide foreground pixels.
thres = 0.8
gpu_id = 0
device = torch.device("cuda:" + str(gpu_id) if torch.cuda.is_available() else "cpu")

# img2txt: a cell's mean gray level (0-255) is scaled to an index into this
# string, so brighter cells pick characters further to the right.
# Raw string: the backslashes before "|" and "<" are literal characters, not
# (invalid) escape sequences, so this avoids a SyntaxWarning on Python 3.12+
# while keeping exactly the same value.
CHAR_LIST = r" ;:,^`'.*~oahkbdpqwmZO0QLCJUYXzcvunxrjft/\|()1{}[]?-_+\<>i!lI$@B%8&WM#"
num_chars = len(CHAR_LIST)
# Number of character columns in each ASCII-art output file.
num_cols = 50
# Create the network and load the weights.
# 4 input channels: the 3 RGB channels of the crop plus one extreme-point heatmap.
net = resnet.resnet101(1, nInputChannels=4, classifier='psp')
weights_path = os.path.join(Path.models_dir(), modelName + '.pth')
print("Initializing weights from: {}".format(weights_path))
# Keep every tensor on the storage it was deserialized to (CPU), so the
# checkpoint loads on machines without a GPU; the model is moved to `device`
# after the weights are in place.
state_dict_checkpoint = torch.load(weights_path,
                                   map_location=lambda storage, loc: storage)

# A model trained under DataParallel saves keys as "module.<name>". Strip that
# prefix (only where it is actually a prefix) so the keys match the bare net.
dp_prefix = 'module.'
if any(key.startswith(dp_prefix) for key in state_dict_checkpoint):
    new_state_dict = OrderedDict(
        (key[len(dp_prefix):] if key.startswith(dp_prefix) else key, value)
        for key, value in state_dict_checkpoint.items()
    )
else:
    # Pass the checkpoint through unchanged so any attached metadata is kept.
    new_state_dict = state_dict_checkpoint

net.load_state_dict(new_state_dict)
net.eval()
net.to(device)
# Read the input image and display it so the user can click extreme points.
# convert('RGB') guarantees exactly three channels: the network expects
# RGB + one heatmap channel (nInputChannels=4), and the mask copy into an
# H x W x 3 output assumes RGB, so grayscale/RGBA/palette files would break.
# The `with` block closes the underlying file once the pixels are copied.
with Image.open('ims/dog-cat.jpg') as pil_image:
    image = np.array(pil_image.convert('RGB'))
plt.ion()
plt.axis('off')
plt.imshow(image)
plt.title('Click the four extreme points of the objects\nHit enter when done (do not close the window)')
# Boolean full-image masks, one per segmented object, for the overlay plot.
results = []
# Running counter used to number the out<idx>.txt / out<idx>.png files.
idx = 0
def _to_ascii_art(gray, n_cols):
    """Render a 2-D uint8 grayscale image as ASCII art.

    Each character covers a cell ``width / n_cols`` pixels wide and twice as
    tall. The cell's mean gray level is scaled into an index of CHAR_LIST
    (clamped to the last entry). If ``n_cols`` exceeds the image width or the
    resulting row count exceeds the height, a notice is printed and fixed
    6x12-pixel cells are used instead.

    Returns the text with one "\\n"-terminated line per character row.
    """
    height, width = gray.shape[:2]
    cell_width = width / n_cols
    cell_height = 2 * cell_width
    n_rows = int(height / cell_height)
    if n_cols > width or n_rows > height:
        print("Too many columns or rows. Use default setting")
        cell_width = 6
        cell_height = 12
        n_cols = int(width / cell_width)
        n_rows = int(height / cell_height)

    lines = []
    for i in range(n_rows):
        top = int(i * cell_height)
        bottom = min(int((i + 1) * cell_height), height)
        row_chars = []
        for j in range(n_cols):
            left = int(j * cell_width)
            right = min(int((j + 1) * cell_width), width)
            brightness = np.mean(gray[top:bottom, left:right])
            row_chars.append(CHAR_LIST[min(int(brightness * num_chars / 255), num_chars - 1)])
        lines.append("".join(row_chars) + "\n")
    return "".join(lines)


with torch.no_grad():
    while True:
        clicked = plt.ginput(4, timeout=0)
        # Fewer than four points means the user pressed Enter to finish;
        # stop here instead of crashing on the point indexing below.
        if len(clicked) < 4:
            print('Exiting...')
            break
        # `int` replaces the `np.int` alias removed in NumPy 1.24 (same type).
        extreme_points_ori = np.array(clicked).astype(int)

        # Crop image to the bounding box from the extreme points and resize
        bbox = helpers.get_bbox(image, points=extreme_points_ori, pad=pad, zero_pad=True)
        crop_image = helpers.crop_from_bbox(image, bbox, zero_pad=True)
        resize_image = helpers.fixed_resize(crop_image, (512, 512)).astype(np.float32)

        # Generate extreme point heat map normalized to image values:
        # shift the points into crop coordinates, then scale them to 512x512.
        extreme_points = extreme_points_ori - [np.min(extreme_points_ori[:, 0]),
                                               np.min(extreme_points_ori[:, 1])] + [pad, pad]
        extreme_points = (512 * extreme_points * [1 / crop_image.shape[1],
                                                  1 / crop_image.shape[0]]).astype(int)
        extreme_heatmap = helpers.make_gt(resize_image, extreme_points, sigma=10)
        extreme_heatmap = helpers.cstm_normalize(extreme_heatmap, 255)

        # Concatenate inputs (H x W x C -> 1 x C x H x W) and convert to tensor
        input_dextr = np.concatenate((resize_image, extreme_heatmap[:, :, np.newaxis]), axis=2)
        inputs = torch.from_numpy(input_dextr.transpose((2, 0, 1))[np.newaxis, ...])

        # Run a forward pass; `interpolate` is the non-deprecated form of `upsample`.
        inputs = inputs.to(device)
        outputs = net.forward(inputs)
        outputs = torch.nn.functional.interpolate(outputs, size=(512, 512), mode='bilinear',
                                                  align_corners=True)
        outputs = outputs.to(torch.device('cpu'))
        pred = np.transpose(outputs.data.numpy()[0, ...], (1, 2, 0))
        pred = 1 / (1 + np.exp(-pred))  # sigmoid -> per-pixel probability
        pred = np.squeeze(pred)
        # Paste the crop-level prediction back into a full-image boolean mask.
        result = helpers.crop2fullmask(pred, bbox, im_size=image.shape[:2],
                                       zero_pad=True, relax=pad) > thres

        # Save the cut-out (object pixels kept, everything else black) and an
        # ASCII-art rendering of it.
        height, width = image.shape[:2]
        out_img = np.zeros((height, width, 3), np.uint8)
        out_img[result] = image[result]
        # Weighted RGB -> gray with coefficients 0.299 / 0.587 / 0.114.
        out_gray = np.dot(out_img[..., :3], [0.299, 0.587, 0.114]).astype(np.uint8)
        ascii_art = _to_ascii_art(out_gray, num_cols)

        idx += 1
        with open('out%d.txt' % idx, 'w') as output_file:
            output_file.write(ascii_art)
        Image.fromarray(out_img).save("out%d.png" % idx)
        results.append(result)

        # Plot the results
        plt.imshow(helpers.overlay_masks(image / 255, results))
        plt.plot(extreme_points_ori[:, 0], extreme_points_ori[:, 1], 'gx')