Skip to content

Commit

Permalink
Add the paddle fluid config for mm_dnn
Browse files Browse the repository at this point in the history
  • Loading branch information
Yibing Liu committed Sep 14, 2018
1 parent 61d68e9 commit a14d526
Show file tree
Hide file tree
Showing 5 changed files with 208 additions and 10 deletions.
40 changes: 40 additions & 0 deletions tools/simnet/train/paddle/examples/mmdnn-pointwise.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
{
"net": {
"module_name": "mm_dnn",
"class_name": "MMDNN",
"embedding_dim": 128,
"num_filters": 256,
"lstm_dim": 128,
"hidden_size": 128,
"window_size_left": 3,
"window_size_right": 3,
"dpool_size_left": 2,
"dpool_size_right": 2
},
"loss": {
"module_name": "softmax_cross_entropy_loss",
"class_name": "SoftmaxCrossEntropyLoss"
},
"optimizer": {
"class_name": "AdamOptimizer",
"learning_rate": 0.001,
"beta1": 0.9,
"beta2": 0.999,
"epsilon": 1e-08
},
"use_cuda": 1,
"dict_size": 3,
"max_len_left": 32,
"max_len_right": 32,
"n_class": 2,
"task_mode": "pointwise",
"match_mask" : 1,
"train_file_path": "data/train_pointwise_data",
"test_file_path": "data/test_pointwise_data",
"result_file_path": "result_mm_dnn_pointwise",
"epoch_num": 1,
"model_path": "models/mm_dnn_pointwise",
"use_epoch": 0,
"batch_size": 64,
"num_threads": 6
}
17 changes: 17 additions & 0 deletions tools/simnet/train/paddle/layers/paddle_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,23 @@ def ops(self, input, label):
loss = fluid.layers.cross_entropy(input=input, label=label)
return loss

class SoftmaxWithCrossEntropyLayer(object):
"""
Softmax with Cross Entropy Calculate Layer
"""
def __init__(self, name="softmax_with_cross_entropy"):
"""
initialize
"""
pass

def ops(self, input, label):
"""
operation
"""
loss = fluid.layers.softmax_with_cross_entropy(logits=input, label=label)
return loss


class CosSimLayer(object):
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

class SoftmaxCrossEntropyLoss(object):
"""
Softmax Cross Entropy Loss Calculate
Softmax with Cross Entropy Loss Calculate
"""
def __init__(self, conf_dict):
"""
Expand All @@ -31,7 +31,8 @@ def compute(self, input, label):
"""
compute loss
"""
cross_entropy = layers.CrossEntropyLayer()
softmax_with_cross_entropy = layers.SoftmaxWithCrossEntropyLayer()
reduce_mean = layers.ReduceMeanLayer()
loss = reduce_mean.ops(cross_entropy.ops(input, label))
return loss
cost = softmax_with_cross_entropy.ops(input, label)
avg_cost = reduce_mean.ops(cost)
return avg_cost
129 changes: 129 additions & 0 deletions tools/simnet/train/paddle/nets/mm_dnn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
import numpy as np
import paddle.fluid as fluid


class MMDNN(object):
def __init__(self, config):
self.vocab_size = int(config['dict_size'])
self.emb_size = int(config['net']['embedding_dim'])
self.lstm_dim = int(config['net']['lstm_dim'])
self.kernel_size = int(config['net']['num_filters'])
self.win_size1 = int(config['net']['window_size_left'])
self.win_size2 = int(config['net']['window_size_right'])
self.dpool_size1 = int(config['net']['dpool_size_left'])
self.dpool_size2 = int(config['net']['dpool_size_right'])
self.hidden_size = int(config['net']['hidden_size'])
self.seq_len1 = int(config['max_len_left'])
self.seq_len2 = int(config['max_len_right'])
self.task_mode = config['task_mode']

if 'match_mask' in config and config['match_mask'] != 0:
self.match_mask = True
else:
self.match_mask = False

if self.task_mode == "pointwise":
self.n_class = int(config['n_class'])
self.out_size = self.n_class
elif self.task_mode == "pairwise":
self.out_size = 1
else:
logging.error("training mode not supported")

def bi_dynamic_lstm(self, input, hidden_size):
fw_in_proj = fluid.layers.fc(input=input,
size=4 * hidden_size,
bias_attr=False)
forward, _ = fluid.layers.dynamic_lstm(
input=fw_in_proj, size=4 * hidden_size, is_reverse=False)

rv_in_proj = fluid.layers.fc(input=input,
size=4 * hidden_size,
bias_attr=False)
reverse, _ = fluid.layers.dynamic_lstm(
input=rv_in_proj, size=4 * hidden_size, is_reverse=True)
return [forward, reverse]

def conv_pool_relu_layer(self, input, mask=None):
# data format NCHW
emb_expanded = fluid.layers.unsqueeze(input=input, axes=[1])
# same padding
conv = fluid.layers.conv2d(
input=emb_expanded,
num_filters=self.kernel_size,
stride=1,
padding=16,
filter_size=[self.seq_len1, self.seq_len2],
bias_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(0.1)))

if mask is not None:
cross_mask = fluid.layers.stack(x=[mask] * self.kernel_size, axis=1)
conv = cross_mask * conv + (1 - cross_mask) * (-2**32 + 1)
#valid padding
pool = fluid.layers.pool2d(
input=conv,
pool_size=[
self.seq_len1 / self.dpool_size1,
self.seq_len2 / self.dpool_size2
],
pool_stride=[
self.seq_len1 / self.dpool_size1,
self.seq_len2 / self.dpool_size2
],
pool_type="max", )

relu = fluid.layers.relu(pool)
return relu

def get_cross_mask(self, left_lens, right_lens):
mask1 = fluid.layers.sequence_mask(
x=left_lens, dtype='float32', maxlen=self.seq_len1 + 1)
mask2 = fluid.layers.sequence_mask(
x=right_lens, dtype='float32', maxlen=self.seq_len2 + 1)

mask1 = fluid.layers.transpose(x=mask1, perm=[0, 2, 1])
cross_mask = fluid.layers.matmul(x=mask1, y=mask2)
return cross_mask

def predict(self, left, right):
left_emb = fluid.layers.embedding(
input=left,
size=[self.vocab_size, self.emb_size],
is_sparse=True,
param_attr=fluid.ParamAttr(name="word_embedding"))
right_emb = fluid.layers.embedding(
input=right,
size=[self.vocab_size, self.emb_size],
is_sparse=True,
param_attr=fluid.ParamAttr(name="word_embedding"))

bi_left_outputs = self.bi_dynamic_lstm(
input=left_emb, hidden_size=self.lstm_dim)
left_seq_encoder = fluid.layers.concat(input=bi_left_outputs, axis=1)

bi_right_outputs = self.bi_dynamic_lstm(
input=right_emb, hidden_size=self.lstm_dim)
right_seq_encoder = fluid.layers.concat(input=bi_right_outputs, axis=1)

pad_value = fluid.layers.assign(input=np.array([0]).astype("float32"))
left_seq_encoder, left_lens = fluid.layers.sequence_pad(
x=left_seq_encoder, pad_value=pad_value, maxlen=self.seq_len1)
right_seq_encoder, right_lens = fluid.layers.sequence_pad(
x=right_seq_encoder, pad_value=pad_value, maxlen=self.seq_len2)

cross = fluid.layers.matmul(
left_seq_encoder, right_seq_encoder, transpose_y=True)
if self.match_mask:
cross_mask = self.get_cross_mask(left_lens, right_lens)
else:
cross_mask = None

conv_pool_relu = self.conv_pool_relu_layer(input=cross, mask=cross_mask)
relu_hid1 = fluid.layers.fc(input=conv_pool_relu,
act="tanh",
size=self.hidden_size)

pred = fluid.layers.fc(input=relu_hid1, size=self.out_size)

return left_seq_encoder, pred
23 changes: 17 additions & 6 deletions tools/simnet/train/paddle/paddle_simnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,13 @@ def train(conf_dict):
# Load Optimization method
optimizer = utils.import_class(
"optimizers", "paddle_optimizers", conf_dict["optimizer"]["class_name"])(conf_dict)

# Get service
place = fluid.core.CPUPlace()
if "use_cuda" in conf_dict and conf_dict["use_cuda"] == 1:
place = fluid.core.CUDAPlace(0)
else:
place = fluid.core.CPUPlace()

if conf_dict["task_mode"] == "pairwise":
# Build network
left = data.ops(name="left", shape=[1], dtype="int64", lod_level=1)
Expand All @@ -73,6 +77,7 @@ def train(conf_dict):
label = data.ops(name="label", shape=[1], dtype="int64", lod_level=0)
left_feat, pred = net.predict(left, right)
avg_cost = loss.compute(pred, label)
avg_cost.persistable = True
# Get Feeder and Reader
feeder = fluid.DataFeeder(place=place, feed_list=[
left.name, right.name, label.name])
Expand All @@ -87,7 +92,8 @@ def train(conf_dict):
executor.run(fluid.default_startup_program())
# Get and run executor
parallel_executor = fluid.ParallelExecutor(
use_cuda=False, loss_name=avg_cost.name,
use_cuda="use_cuda" in conf_dict and conf_dict["use_cuda"] == 1,
loss_name=avg_cost.name,
main_program=fluid.default_main_program())
# Get device number
device_count = parallel_executor.device_count
Expand All @@ -104,8 +110,9 @@ def train(conf_dict):
continue
avg_loss = parallel_executor.run(
[avg_cost.name], feed=feeder.feed(data))
print("epoch: %d, iter: %d, loss: %f" %
(epoch_id, iter, np.mean(avg_loss[0])))
if iter % 100 == 0:
print("epoch: %d, iter: %d, loss: %f" %
(epoch_id, iter, np.mean(avg_loss[0])))
losses.append(np.mean(avg_loss[0]))
end_time = time.time()
print("epoch: %d, loss: %f, used time: %d sec" %
Expand Down Expand Up @@ -134,7 +141,10 @@ def predict(conf_dict):
model_save_dir = conf_dict["model_path"]
model_path = os.path.join(model_save_dir, str(conf_dict["use_epoch"]))
# Get device
place = fluid.core.CPUPlace()
if "use_cuda" in conf_dict and conf_dict["use_cuda"] == 1:
place = fluid.core.CUDAPlace(0)
else:
place = fluid.core.CPUPlace()
# Get executor
executor = fluid.Executor(place=place)
# Load model
Expand Down Expand Up @@ -174,6 +184,7 @@ def predict(conf_dict):
"--conf_file_path", default="examples/cnn_pointwise.json", help="config file path")
args = parser.parse_args()
conf_dict = utils.parse_json(args.conf_file_path)
print(conf_dict)
if args.task_type == "train":
train(conf_dict)
else:
Expand Down

0 comments on commit a14d526

Please sign in to comment.