forked from naisy/realtime_object_detection
-
Notifications
You must be signed in to change notification settings - Fork 0
/
load_graph_deeplab_v3.py
82 lines (73 loc) · 2.64 KB
/
load_graph_deeplab_v3.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import tensorflow as tf
from tensorflow.core.framework import graph_pb2
import copy
class LoadFrozenGraph():
"""
LOAD FROZEN GRAPH
"""
def __init__(self, cfg):
self.cfg = cfg
return
def load_graph(self):
print('Building Graph')
return self.load_frozen_graph_without_split()
def print_graph(self, graph):
"""
PRINT GRAPH OPERATIONS
"""
print("{:-^32}".format(" operations in graph "))
for op in graph.get_operations():
print(op.name,op.outputs)
return
def print_graph_def(self, graph_def):
"""
PRINT GRAPHDEF NODE NAMES
"""
print("{:-^32}".format(" nodes in graph_def "))
for node in graph_def.node:
print(node)
return
def print_graph_operation_by_name(self, graph, name):
"""
PRINT GRAPH OPERATION DETAILS
"""
op = graph.get_operation_by_name(name=name)
print("{:-^32}".format(" operations in graph "))
print("{:-^32}\n{}".format(" op ", op))
print("{:-^32}\n{}".format(" op.name ", op.name))
print("{:-^32}\n{}".format(" op.outputs ", op.outputs))
print("{:-^32}\n{}".format(" op.inputs ", op.inputs))
print("{:-^32}\n{}".format(" op.device ", op.device))
print("{:-^32}\n{}".format(" op.graph ", op.graph))
print("{:-^32}\n{}".format(" op.values ", op.values()))
print("{:-^32}\n{}".format(" op.op_def ", op.op_def))
print("{:-^32}\n{}".format(" op.colocation_groups ", op.colocation_groups))
print("{:-^32}\n{}".format(" op.get_attr ", op.get_attr("T")))
i = 0
for output in op.outputs:
op_tensor = output
tensor_shape = op_tensor.get_shape().as_list()
print("{:-^32}\n{}".format(" outputs["+str(i)+"] shape ", tensor_shape))
i += 1
return
def load_frozen_graph_without_split(self):
"""
Load frozen_graph.
"""
model_path = self.cfg['model_path']
tf.reset_default_graph()
graph_def = tf.GraphDef()
with tf.gfile.GFile(model_path, 'rb') as fid:
serialized_graph = fid.read()
graph_def.ParseFromString(serialized_graph)
# force CPU device placement for NMS ops
for node in graph_def.node:
if 'BatchMultiClassNonMaxSuppression_1' in node.name:
node.device = '/device:CPU:0'
else:
node.device = '/device:GPU:0'
tf.import_graph_def(graph_def, name='')
"""
return
"""
return tf.get_default_graph()