Skip to content

Commit

Permalink
add placeholder for simnet-tf-freeze model
Browse files Browse the repository at this point in the history
  • Loading branch information
yinweichong committed Jul 10, 2019
1 parent c7fe5ec commit b0edc58
Showing 1 changed file with 24 additions and 8 deletions.
32 changes: 24 additions & 8 deletions tools/simnet/train/tf/tf_simnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def train(conf_dict):
datafeed = datafeeds.TFPairwisePaddingData(conf_dict)
input_l, input_r, neg_input = datafeed.ops()
pos_score = net.predict(input_l, input_r)
output_prob = tf.identity(pos_score, name="output_preb")
output_prob = tf.identity(pos_score, name="output_prob")
neg_score = net.predict(input_l, neg_input)
loss_layer = utility.import_object(
conf_dict["loss_py"], conf_dict["loss_class"])(conf_dict)
Expand Down Expand Up @@ -105,13 +105,29 @@ def freeze(conf_dict):
"""
model_path = conf_dict["save_path"]
freeze_path = conf_dict["freeze_path"]
saver = tf.train.import_meta_graph(model_path + '.meta')
with tf.Session() as sess:
saver.restore(sess, model_path)
var_graph_def = tf.get_default_graph().as_graph_def()
const_graph_def = graph_util.convert_variables_to_constants(sess, var_graph_def, ["output_prob"])
with tf.gfile.GFile(freeze_path, "wb") as f:
f.write(const_graph_def.SerializeToString())
training_mode = conf_dict["training_mode"]

graph = tf.Graph()
with graph.as_default():
net = utility.import_object(
conf_dict["net_py"], conf_dict["net_class"])(conf_dict)
test_l = dict([(u, tf.placeholder(tf.int32, [None, v], name=u))
for (u, v) in dict(conf_dict["left_slots"]).iteritems()])
test_r = dict([(u, tf.placeholder(tf.int32, [None, v], name=u))
for (u, v) in dict(conf_dict["right_slots"]).iteritems()])
pred = net.predict(test_l, test_r)
if training_mode == "pointwise":
output_prob = tf.nn.softmax(pred, -1, name="output_prob")
elif training_mode == "pairwise":
output_prob = tf.identity(pred, name="output_prob")

restore_saver = tf.train.Saver()
with tf.Session(graph=graph) as sess:
sess.run(tf.global_variables_initializer())
restore_saver.restore(sess, model_path)
output_graph_def = tf.graph_util.\
convert_variables_to_constants(sess, sess.graph_def, ["output_prob"])
tf.train.write_graph(output_graph_def, '.', freeze_path, as_text=False)

def convert(conf_dict):
"""
Expand Down

0 comments on commit b0edc58

Please sign in to comment.