From d2e5395b97a08efd1fbc826704c983b9391c7661 Mon Sep 17 00:00:00 2001
From: chenweihang
Date: Fri, 17 Aug 2018 03:18:27 +0000
Subject: [PATCH 1/6] feat: add sequence enumerate op

---
 .../fluid/operators/sequence_enumerate_op.cc       | 101 ++++++++++++++++++
 .../fluid/operators/sequence_enumerate_op.cu       |  75 +++++++++++++
 .../fluid/operators/sequence_enumerate_op.h        |  56 ++++++++++
 .../unittests/test_sequence_enumerate_op.py        |  79 ++++++++++++++
 4 files changed, 311 insertions(+)
 create mode 100644 paddle/fluid/operators/sequence_enumerate_op.cc
 create mode 100644 paddle/fluid/operators/sequence_enumerate_op.cu
 create mode 100644 paddle/fluid/operators/sequence_enumerate_op.h
 create mode 100644 python/paddle/fluid/tests/unittests/test_sequence_enumerate_op.py

diff --git a/paddle/fluid/operators/sequence_enumerate_op.cc b/paddle/fluid/operators/sequence_enumerate_op.cc
new file mode 100644
index 0000000000000..0d9fdf7d5cab1
--- /dev/null
+++ b/paddle/fluid/operators/sequence_enumerate_op.cc
@@ -0,0 +1,101 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/operators/sequence_enumerate_op.h"
+
+namespace paddle {
+namespace operators {
+
+class SequenceEnumerateOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    PADDLE_ENFORCE(
+        ctx->HasInput("X"),
+        "Input(X) of SequenceEnumerate operator should not be null.");
+    PADDLE_ENFORCE(
+        ctx->HasOutput("Out"),
+        "Output(X) of SequenceEnumerate operator should not be null.");
+
+    const auto x_dims = ctx->GetInputDim("X");
+    PADDLE_ENFORCE(
+        x_dims.size() == 2 && x_dims[1] == 1,
+        "Input(X) of SequenceEnumerate operator should be a 2-D LoDTensor "
+        "with the 2nd dimension equal to 1.");
+
+    const auto win_size = ctx->Attrs().Get("win_size");
+    PADDLE_ENFORCE(win_size <= x_dims[0],
+                   "The enumerate window size should be less than or equal to "
+                   "input sequence length.");
+    ctx->SetOutputDim("Out", {x_dims[0], win_size});
+    ctx->ShareLoD("X", "Out");
+  }
+};
+
+class SequenceEnumerateOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput("X",
+             "(2-D LoDTensor with the 2nd dimension equal to 1) "
+             "Input LoDTensor of SequenceEnumerate operator.");
+    AddOutput("Out",
+              "(2-D LoDTensor with the 2nd dimension equal to 1) "
+              "Output LoDTensor of SequenceEnumerate operator.");
+    AddAttr("win_size", "(int) The enumerate sequence window size.")
+        .AddCustomChecker([](const int& win_size) {
+          PADDLE_ENFORCE(win_size >= 2,
+                         "The window size should be greater than 2.");
+        });
+    AddAttr("pad_value", "(int) The enumerate sequence padding value.")
+        .SetDefault(0);
+    AddComment(R"DOC(
+Sequence Enumerate Operator.
+ +Sequence enumerate operator generate a new LoDTensor +with the same 1st dimension length as the original LoDTensor, +and with the 2nd dimension equal to the input window length, +the new sub-sequence on 2nd dimension is enumerated one by one on the original sequence. +The values of the last insufficient part areall filled with the input pad_value. + +Examples: +Case 1: + Input: + X.lod = [[0, 3, 5]] + X.data = [1, 2, 3, 4, 5] + X.dims = [5, 1] + Attrs: + win_size = 2 + pad_value = 0 + Output: + Out.lod = [[0, 3, 5]] + Out.data = [[1, 2], [2, 3], [3, 4], [4, 5], [0, 0]] + Out.dims = [5, 2] + + Currently, only 1-level LoDTensor is supported. + +)DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_WITHOUT_GRADIENT(sequence_enumerate, ops::SequenceEnumerateOp, + ops::SequenceEnumerateOpMaker); +REGISTER_OP_CPU_KERNEL( + sequence_enumerate, + ops::SequenceEnumerateKernel, + ops::SequenceEnumerateKernel); diff --git a/paddle/fluid/operators/sequence_enumerate_op.cu b/paddle/fluid/operators/sequence_enumerate_op.cu new file mode 100644 index 0000000000000..2e2356e7eca30 --- /dev/null +++ b/paddle/fluid/operators/sequence_enumerate_op.cu @@ -0,0 +1,75 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include "paddle/fluid/operators/sequence_enumerate_op.h" +#include "paddle/fluid/platform/cuda_primitives.h" + +namespace paddle { +namespace operators { +using platform::PADDLE_CUDA_NUM_THREADS; +using LoDTensor = framework::LoDTensor; + +template +__global__ void CalcOutPut(const T* in_data, const int64_t in_len, + const int64_t win_size, const int64_t pad_value, + T* out_data) { + int index = blockIdx.x * blockDim.x + threadIdx.x; + if (index < in_len) { + for (size_t i = 0; i < win_size; ++i) { + int word_pos = index + i; + out_data[index * win_size + i] = + word_pos < in_len ? 
in_data[word_pos] : pad_value; + } + } +} + +template +class SequenceEnumerateOpCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* in = context.Input("X"); + auto* out = context.Output("Out"); + int win_size = context.Attr("win_size"); + int pad_value = context.Attr("pad_value"); + + auto in_dims = in->dims(); + auto in_lod = in->lod(); + + PADDLE_ENFORCE_EQ(in_lod.size(), 1UL, + "Only support one level sequence now."); + PADDLE_ENFORCE_EQ( + static_cast(in_dims[0]), in_lod[0].back(), + "The actual input data's size mismatched with LoD information."); + + /* Generate enumerate sequence set */ + auto stream = context.cuda_device_context().stream(); + auto in_len = in->numel(); + auto in_data = in->data(); + auto out_data = out->mutable_data(context.GetPlace()); + // Calc output tensor + CalcOutPut<<<(in_len - 1) / PADDLE_CUDA_NUM_THREADS + 1, + PADDLE_CUDA_NUM_THREADS, 0, stream>>>( + in_data, in_len, win_size, pad_value, out_data); + } +}; + +} // namespace operators +} // namespace paddle + +REGISTER_OP_CUDA_KERNEL( + sequence_enumerate, + paddle::operators::SequenceEnumerateOpCUDAKernel, + paddle::operators::SequenceEnumerateOpCUDAKernel); diff --git a/paddle/fluid/operators/sequence_enumerate_op.h b/paddle/fluid/operators/sequence_enumerate_op.h new file mode 100644 index 0000000000000..8e9549508e971 --- /dev/null +++ b/paddle/fluid/operators/sequence_enumerate_op.h @@ -0,0 +1,56 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/framework/op_registry.h" + +namespace paddle { +namespace operators { +using LoDTensor = framework::LoDTensor; + +template +class SequenceEnumerateKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* in = context.Input("X"); + auto* out = context.Output("Out"); + int win_size = context.Attr("win_size"); + int pad_value = context.Attr("pad_value"); + + auto in_dims = in->dims(); + auto in_lod = in->lod(); + + PADDLE_ENFORCE_EQ(in_lod.size(), 1UL, + "Only support one level sequence now."); + PADDLE_ENFORCE_EQ( + static_cast(in_dims[0]), in_lod[0].back(), + "The actual input data's size mismatched with LoD information."); + + // Generate enumerate sequence set + auto seq_length = in_dims[0]; + auto in_data = in->data(); + auto out_data = out->mutable_data(context.GetPlace()); + for (int idx = 0; idx < seq_length; ++idx) { + for (int word_idx = 0; word_idx < win_size; ++word_idx) { + int word_pos = idx + word_idx; + out_data[win_size * idx + word_idx] = + word_pos < seq_length ? 
in_data[word_pos] : pad_value; + } + } + } +}; + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/fluid/tests/unittests/test_sequence_enumerate_op.py b/python/paddle/fluid/tests/unittests/test_sequence_enumerate_op.py new file mode 100644 index 0000000000000..18d91728fb950 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_sequence_enumerate_op.py @@ -0,0 +1,79 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +from op_test import OpTest + + +def sequence_enumerate(input_seq, lod0, win_size, pad_value): + out_seq = [] + for idx in range(0, len(input_seq)): + single_seq = [] + for word_idx in range(win_size): + word_pos = idx + word_idx + dat = input_seq[word_pos] if word_pos < len(input_seq) \ + else pad_value + single_seq.append(dat) + out_seq.append(single_seq) + return out_seq + + +class TestSequenceEnumerateOp(OpTest): + def setUp(self): + self.op_type = "sequence_enumerate" + self.init_test_case() + self.inputs = {'X': (self.in_seq, self.lod)} + self.attrs = {'win_size': self.win_size, 'pad_value': self.pad_value} + self.outputs = {'Out': (self.out_seq, self.lod)} + + def test_check_output(self): + self.check_output() + + def init_test_case(self): + self.in_seq = np.random.randint(0, 10, (30, 1)).astype("int32") + self.lod = [[9, 4, 11, 6]] + self.win_size = 2 + self.pad_value = 0 + out_seq = sequence_enumerate(self.in_seq, self.lod[0], self.win_size, + self.pad_value) + self.out_seq = np.array(out_seq).astype("int32") + + +class TesSequenceEnumerateOpInt64(TestSequenceEnumerateOp): + def init_test_case(self): + self.in_seq = np.random.randint(0, 10, (30, 1)).astype("int64") + self.lod = [[9, 4, 11, 6]] + self.win_size = 2 + self.pad_value = 0 + out_seq = sequence_enumerate(self.in_seq, self.lod[0], self.win_size, + self.pad_value) + self.out_seq = np.array(out_seq).astype("int64") + + +class TestSequenceEnumerateOpMaxWinSize(TestSequenceEnumerateOp): + def init_test_case(self): + self.in_seq = np.random.randint(0, 10, (30, 1)).astype("int32") + self.lod = [[9, 4, 11, 6]] + self.win_size = 30 + self.pad_value = 0 + out_seq = sequence_enumerate(self.in_seq, self.lod[0], self.win_size, + self.pad_value) + self.out_seq = np.array(out_seq).astype("int32") + + +if __name__ == "__main__": + unittest.main() From 219a2369da4c80318e75020acdcafb2971398143 Mon Sep 17 00:00:00 2001 From: chenweihang Date: Thu, 23 Aug 2018 16:09:45 +0000 Subject: [PATCH 2/6] feat: wrap sequence enumerate op --- paddle/fluid/API.spec | 1 + .../fluid/operators/sequence_enumerate_op.cc | 20 +-- .../fluid/operators/sequence_enumerate_op.cu | 2 - .../fluid/operators/sequence_enumerate_op.h | 2 - python/paddle/fluid/layers/nn.py | 136 ++++++++---------- .../fluid/tests/unittests/test_layers.py | 9 ++ 6 files changed, 85 insertions(+), 85 deletions(-) diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index e963902a50200..c2a08d2e53c3f 
100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -161,6 +161,7 @@ paddle.fluid.layers.crop ArgSpec(args=['x', 'shape', 'offsets', 'name'], varargs paddle.fluid.layers.rank_loss ArgSpec(args=['label', 'left', 'right', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.prelu ArgSpec(args=['x', 'mode', 'param_attr', 'name'], varargs=None, keywords=None, defaults=(None, None)) paddle.fluid.layers.flatten ArgSpec(args=['x', 'axis', 'name'], varargs=None, keywords=None, defaults=(1, None)) +paddle.fluid.layers.sequence_enumerate ArgSpec(args=['input', 'win_size', 'pad_value', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.data ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True)) paddle.fluid.layers.open_recordio_file ArgSpec(args=['filename', 'shapes', 'lod_levels', 'dtypes', 'pass_num', 'for_parallel'], varargs=None, keywords=None, defaults=(1, True)) paddle.fluid.layers.open_files ArgSpec(args=['filenames', 'shapes', 'lod_levels', 'dtypes', 'thread_num', 'buffer_size', 'pass_num', 'is_test'], varargs=None, keywords=None, defaults=(None, None, 1, None)) diff --git a/paddle/fluid/operators/sequence_enumerate_op.cc b/paddle/fluid/operators/sequence_enumerate_op.cc index 0d9fdf7d5cab1..cacbb0977714b 100644 --- a/paddle/fluid/operators/sequence_enumerate_op.cc +++ b/paddle/fluid/operators/sequence_enumerate_op.cc @@ -13,6 +13,7 @@ // limitations under the License. #include "paddle/fluid/operators/sequence_enumerate_op.h" +#include namespace paddle { namespace operators { @@ -30,16 +31,21 @@ class SequenceEnumerateOp : public framework::OperatorWithKernel { "Output(X) of SequenceEnumerate operator should not be null."); const auto x_dims = ctx->GetInputDim("X"); - PADDLE_ENFORCE( - x_dims.size() == 2 && x_dims[1] == 1, - "Input(X) of SequenceEnumerate operator should be a 2-D LoDTensor " - "with the 2nd dimension equal to 1."); + PADDLE_ENFORCE_EQ( + x_dims.size(), 2UL, + "Input(X) of SequenceEnumerate operator's rank should be 2."); const auto win_size = ctx->Attrs().Get("win_size"); - PADDLE_ENFORCE(win_size <= x_dims[0], + // TODO(chenweihang): unittest doesn't has batch size, but test_layers has + auto first_dim = x_dims[0] == -1 ? x_dims[1] : x_dims[0]; + PADDLE_ENFORCE(win_size <= first_dim, "The enumerate window size should be less than or equal to " "input sequence length."); - ctx->SetOutputDim("Out", {x_dims[0], win_size}); + + std::vector out_shape(x_dims.size() + 1, 0); + for (int i = 0; i < x_dims.size(); ++i) out_shape.emplace_back(x_dims[i]); + out_shape.emplace_back(win_size); + ctx->SetOutputDim("Out", framework::make_ddim(out_shape)); ctx->ShareLoD("X", "Out"); } }; @@ -83,8 +89,6 @@ Case 1: Out.data = [[1, 2], [2, 3], [3, 4], [4, 5], [0, 0]] Out.dims = [5, 2] - Currently, only 1-level LoDTensor is supported. 
- )DOC"); } }; diff --git a/paddle/fluid/operators/sequence_enumerate_op.cu b/paddle/fluid/operators/sequence_enumerate_op.cu index 2e2356e7eca30..e680174a2cf08 100644 --- a/paddle/fluid/operators/sequence_enumerate_op.cu +++ b/paddle/fluid/operators/sequence_enumerate_op.cu @@ -48,8 +48,6 @@ class SequenceEnumerateOpCUDAKernel : public framework::OpKernel { auto in_dims = in->dims(); auto in_lod = in->lod(); - PADDLE_ENFORCE_EQ(in_lod.size(), 1UL, - "Only support one level sequence now."); PADDLE_ENFORCE_EQ( static_cast(in_dims[0]), in_lod[0].back(), "The actual input data's size mismatched with LoD information."); diff --git a/paddle/fluid/operators/sequence_enumerate_op.h b/paddle/fluid/operators/sequence_enumerate_op.h index 8e9549508e971..8a30003b164b1 100644 --- a/paddle/fluid/operators/sequence_enumerate_op.h +++ b/paddle/fluid/operators/sequence_enumerate_op.h @@ -32,8 +32,6 @@ class SequenceEnumerateKernel : public framework::OpKernel { auto in_dims = in->dims(); auto in_lod = in->lod(); - PADDLE_ENFORCE_EQ(in_lod.size(), 1UL, - "Only support one level sequence now."); PADDLE_ENFORCE_EQ( static_cast(in_dims[0]), in_lod[0].back(), "The actual input data's size mismatched with LoD information."); diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index bd2b950cffe64..d4efb682d95d8 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -29,79 +29,22 @@ from functools import reduce __all__ = [ - 'fc', - 'embedding', - 'dynamic_lstm', - 'dynamic_lstmp', - 'dynamic_gru', - 'gru_unit', - 'linear_chain_crf', - 'crf_decoding', - 'cos_sim', - 'cross_entropy', - 'square_error_cost', - 'chunk_eval', - 'sequence_conv', - 'conv2d', - 'conv3d', - 'sequence_pool', - 'sequence_softmax', - 'softmax', - 'pool2d', - 'pool3d', - 'batch_norm', - 'beam_search_decode', - 'conv2d_transpose', - 'conv3d_transpose', - 'sequence_expand', - 'lstm_unit', - 'reduce_sum', - 'reduce_mean', - 'reduce_max', - 'reduce_min', - 'reduce_prod', - 'sequence_first_step', - 'sequence_last_step', - 'dropout', - 'split', - 'ctc_greedy_decoder', - 'edit_distance', - 'l2_normalize', - 'matmul', - 'topk', - 'warpctc', - 'sequence_reshape', - 'transpose', - 'im2sequence', - 'nce', - 'hsigmoid', - 'beam_search', - 'row_conv', - 'multiplex', - 'layer_norm', - 'softmax_with_cross_entropy', - 'smooth_l1', - 'one_hot', - 'autoincreased_step_counter', - 'reshape', - 'lod_reset', - 'lrn', - 'pad', - 'label_smooth', - 'roi_pool', - 'dice_loss', - 'image_resize', - 'image_resize_short', - 'resize_bilinear', - 'gather', - 'random_crop', - 'mean_iou', - 'relu', - 'log', - 'crop', - 'rank_loss', - 'prelu', - 'flatten', + 'fc', 'embedding', 'dynamic_lstm', 'dynamic_lstmp', 'dynamic_gru', + 'gru_unit', 'linear_chain_crf', 'crf_decoding', 'cos_sim', 'cross_entropy', + 'square_error_cost', 'chunk_eval', 'sequence_conv', 'conv2d', 'conv3d', + 'sequence_pool', 'sequence_softmax', 'softmax', 'pool2d', 'pool3d', + 'batch_norm', 'beam_search_decode', 'conv2d_transpose', 'conv3d_transpose', + 'sequence_expand', 'lstm_unit', 'reduce_sum', 'reduce_mean', 'reduce_max', + 'reduce_min', 'reduce_prod', 'sequence_first_step', 'sequence_last_step', + 'dropout', 'split', 'ctc_greedy_decoder', 'edit_distance', 'l2_normalize', + 'matmul', 'topk', 'warpctc', 'sequence_reshape', 'transpose', 'im2sequence', + 'nce', 'hsigmoid', 'beam_search', 'row_conv', 'multiplex', 'layer_norm', + 'softmax_with_cross_entropy', 'smooth_l1', 'one_hot', + 'autoincreased_step_counter', 'reshape', 'lod_reset', 'lrn', 
'pad',
+    'label_smooth', 'roi_pool', 'dice_loss', 'image_resize',
+    'image_resize_short', 'resize_bilinear', 'gather', 'random_crop',
+    'mean_iou', 'relu', 'log', 'crop', 'rank_loss', 'prelu', 'flatten',
+    'sequence_enumerate'
 ]
@@ -5475,3 +5418,50 @@ def flatten(x, axis=1, name=None):
         outputs={'Out': out},
         attrs={"axis": axis})
     return out
+
+
+def sequence_enumerate(input, win_size, pad_value, name=None):
+    """
+    Generate a new LoDTensor
+    with the same 1st dimension length as the original LoDTensor,
+    and with the 2nd dimension equal to the input window length,
+    the new sub-sequence on 2nd dimension is enumerated one by one on the original sequence.
+    The values of the last insufficient part areall filled with the input pad_value.
+
+    Examples:
+    Case 1:
+      Input:
+        X.lod = [[0, 3, 5]]
+        X.data = [1, 2, 3, 4, 5]
+        X.dims = [5, 1]
+      Attrs:
+        win_size = 2
+        pad_value = 0
+      Output:
+        Out.lod = [[0, 3, 5]]
+        Out.data = [[1, 2], [2, 3], [3, 4], [4, 5], [0, 0]]
+        Out.dims = [5, 2]
+
+    Args:
+        input (Variable): The input variable which is a LoDTensor
+        win_size (int): The enumerate sequence window size.
+        pad_value (int): The enumerate sequence padding value.
+
+    Returns:
+        Variable: The enumerate sequence variable which is a LoDTensor.
+
+    Examples:
+        .. code-block:: python
+
+            x = fluid.layers.data(name='x', shape=[30, 1], dtype='int32', lod_level=1)
+            out = fluid.layers.sequence_enumerate(input=x, win_size=3, pad_value=0)
+    """
+    helper = LayerHelper('sequence_enumerate', **locals())
+    out = helper.create_tmp_variable(helper.input_dtype())
+    helper.append_op(
+        type='sequence_enumerate',
+        inputs={'X': input},
+        outputs={'Out': out},
+        attrs={'win_size': win_size,
+               'pad_value': pad_value})
+    return out
diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py
index e833a7db482db..4994e11d1fb4f 100644
--- a/python/paddle/fluid/tests/unittests/test_layers.py
+++ b/python/paddle/fluid/tests/unittests/test_layers.py
@@ -500,6 +500,15 @@ def test_prelu(self):
             self.assertIsNotNone(out)
         print(str(program))
 
+    def test_sequence_enumerate(self):
+        program = Program()
+        with program_guard(program):
+            x = layers.data(
+                name="input", shape=[30], dtype='int32', lod_level=1)
+            out = layers.sequence_enumerate(input=x, win_size=2, pad_value=0)
+            self.assertIsNotNone(out)
+            print(str(program))
+
 
 if __name__ == '__main__':
     unittest.main()

From 733ea0d29bac96924e62c5714fa57ce07e2ff220 Mon Sep 17 00:00:00 2001
From: chenweihang
Date: Fri, 24 Aug 2018 05:51:56 +0000
Subject: [PATCH 3/6] adjust infershape details

---
 paddle/fluid/operators/sequence_enumerate_op.cc    | 15 ++++-----------
 python/paddle/fluid/layers/nn.py                   |  4 ++--
 .../paddle/fluid/tests/unittests/test_layers.py    |  4 +---
 .../tests/unittests/test_sequence_enumerate_op.py  | 11 ++++-------
 4 files changed, 11 insertions(+), 23 deletions(-)

diff --git a/paddle/fluid/operators/sequence_enumerate_op.cc b/paddle/fluid/operators/sequence_enumerate_op.cc
index cacbb0977714b..b8c8daf3f394b 100644
--- a/paddle/fluid/operators/sequence_enumerate_op.cc
+++ b/paddle/fluid/operators/sequence_enumerate_op.cc
@@ -13,7 +13,6 @@
 // limitations under the License.
#include "paddle/fluid/operators/sequence_enumerate_op.h" -#include namespace paddle { namespace operators { @@ -34,18 +33,12 @@ class SequenceEnumerateOp : public framework::OperatorWithKernel { PADDLE_ENFORCE_EQ( x_dims.size(), 2UL, "Input(X) of SequenceEnumerate operator's rank should be 2."); + PADDLE_ENFORCE_EQ( + x_dims[1], 1UL, + "Input(X) of SequenceEnumerate operator's 2nd dimension should be 1."); const auto win_size = ctx->Attrs().Get("win_size"); - // TODO(chenweihang): unittest doesn't has batch size, but test_layers has - auto first_dim = x_dims[0] == -1 ? x_dims[1] : x_dims[0]; - PADDLE_ENFORCE(win_size <= first_dim, - "The enumerate window size should be less than or equal to " - "input sequence length."); - - std::vector out_shape(x_dims.size() + 1, 0); - for (int i = 0; i < x_dims.size(); ++i) out_shape.emplace_back(x_dims[i]); - out_shape.emplace_back(win_size); - ctx->SetOutputDim("Out", framework::make_ddim(out_shape)); + ctx->SetOutputDim("Out", {x_dims[0], win_size}); ctx->ShareLoD("X", "Out"); } }; diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index c8e4c99f9eaf4..9411256c74fe4 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -5563,7 +5563,7 @@ def sequence_enumerate(input, win_size, pad_value, name=None): out = fluid.layers.sequence_enumerate(input=x, win_size=3, pad_value=0) """ helper = LayerHelper('sequence_enumerate', **locals()) - out = helper.create_tmp_variable(helper.input_dtype()) + out = helper.create_tmp_variable(helper.input_dtype(), stop_gradient=True) helper.append_op( type='sequence_enumerate', inputs={'X': input}, @@ -5571,7 +5571,7 @@ def sequence_enumerate(input, win_size, pad_value, name=None): attrs={'win_size': win_size, 'pad_value': pad_value}) - + def stack(x, axis=0): helper = LayerHelper('stack', **locals()) axis = 0 if axis is None else axis diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py index c45ccee4bd16e..351bcf790bf15 100644 --- a/python/paddle/fluid/tests/unittests/test_layers.py +++ b/python/paddle/fluid/tests/unittests/test_layers.py @@ -522,10 +522,8 @@ def test_prelu(self): def test_sequence_enumerate(self): program = Program() with program_guard(program): - x = layers.data( - name="input", shape=[30], dtype='int32', lod_level=1) + x = layers.data(name="input", shape=[1], dtype='int32', lod_level=1) out = layers.sequence_enumerate(input=x, win_size=2, pad_value=0) - self.assertIsNotNone(out) print(str(program)) diff --git a/python/paddle/fluid/tests/unittests/test_sequence_enumerate_op.py b/python/paddle/fluid/tests/unittests/test_sequence_enumerate_op.py index 18d91728fb950..41624da51266a 100644 --- a/python/paddle/fluid/tests/unittests/test_sequence_enumerate_op.py +++ b/python/paddle/fluid/tests/unittests/test_sequence_enumerate_op.py @@ -19,7 +19,7 @@ from op_test import OpTest -def sequence_enumerate(input_seq, lod0, win_size, pad_value): +def sequence_enumerate(input_seq, win_size, pad_value): out_seq = [] for idx in range(0, len(input_seq)): single_seq = [] @@ -48,8 +48,7 @@ def init_test_case(self): self.lod = [[9, 4, 11, 6]] self.win_size = 2 self.pad_value = 0 - out_seq = sequence_enumerate(self.in_seq, self.lod[0], self.win_size, - self.pad_value) + out_seq = sequence_enumerate(self.in_seq, self.win_size, self.pad_value) self.out_seq = np.array(out_seq).astype("int32") @@ -59,8 +58,7 @@ def init_test_case(self): self.lod = [[9, 4, 11, 6]] self.win_size = 2 self.pad_value = 0 - 
out_seq = sequence_enumerate(self.in_seq, self.lod[0], self.win_size, - self.pad_value) + out_seq = sequence_enumerate(self.in_seq, self.win_size, self.pad_value) self.out_seq = np.array(out_seq).astype("int64") @@ -70,8 +68,7 @@ def init_test_case(self): self.lod = [[9, 4, 11, 6]] self.win_size = 30 self.pad_value = 0 - out_seq = sequence_enumerate(self.in_seq, self.lod[0], self.win_size, - self.pad_value) + out_seq = sequence_enumerate(self.in_seq, self.win_size, self.pad_value) self.out_seq = np.array(out_seq).astype("int32") From 0c4697f8cdf7372a14c426901be828db9c42212f Mon Sep 17 00:00:00 2001 From: chenweihang Date: Mon, 27 Aug 2018 06:49:58 +0000 Subject: [PATCH 4/6] fix: change to enumerate by sentence --- .../fluid/operators/sequence_enumerate_op.cc | 4 +-- .../fluid/operators/sequence_enumerate_op.cu | 23 +++++++++++---- .../fluid/operators/sequence_enumerate_op.h | 14 +++++---- python/paddle/fluid/layers/nn.py | 6 ++-- .../unittests/test_sequence_enumerate_op.py | 29 ++++++++++++------- 5 files changed, 48 insertions(+), 28 deletions(-) diff --git a/paddle/fluid/operators/sequence_enumerate_op.cc b/paddle/fluid/operators/sequence_enumerate_op.cc index b8c8daf3f394b..307a4ece923cd 100644 --- a/paddle/fluid/operators/sequence_enumerate_op.cc +++ b/paddle/fluid/operators/sequence_enumerate_op.cc @@ -72,14 +72,14 @@ The values of the last insufficient part areall filled with the input pad_value. Case 1: Input: X.lod = [[0, 3, 5]] - X.data = [1, 2, 3, 4, 5] + X.data = [[1], [2], [3], [4], [5]] X.dims = [5, 1] Attrs: win_size = 2 pad_value = 0 Output: Out.lod = [[0, 3, 5]] - Out.data = [[1, 2], [2, 3], [3, 4], [4, 5], [0, 0]] + Out.data = [[1, 2], [2, 3], [3, 0], [4, 5], [5, 0]] Out.dims = [5, 2] )DOC"); diff --git a/paddle/fluid/operators/sequence_enumerate_op.cu b/paddle/fluid/operators/sequence_enumerate_op.cu index e680174a2cf08..bdc9a615aa9a1 100644 --- a/paddle/fluid/operators/sequence_enumerate_op.cu +++ b/paddle/fluid/operators/sequence_enumerate_op.cu @@ -23,15 +23,23 @@ using platform::PADDLE_CUDA_NUM_THREADS; using LoDTensor = framework::LoDTensor; template -__global__ void CalcOutPut(const T* in_data, const int64_t in_len, - const int64_t win_size, const int64_t pad_value, - T* out_data) { +__global__ void CalcOutPut(const T* in_data, const size_t* in_lod, + const size_t lod_len, const int64_t win_size, + const int64_t pad_value, T* out_data) { int index = blockIdx.x * blockDim.x + threadIdx.x; - if (index < in_len) { + if (index < in_lod[lod_len - 1]) { + int end_idx = 0; + // Get LoD interval of index + for (int i = 1; i < lod_len; ++i) { + if (index < in_lod[i]) { + end_idx = in_lod[i]; + break; + } + } for (size_t i = 0; i < win_size; ++i) { int word_pos = index + i; out_data[index * win_size + i] = - word_pos < in_len ? in_data[word_pos] : pad_value; + word_pos < end_idx ? 
in_data[word_pos] : pad_value; } } } @@ -54,13 +62,16 @@ class SequenceEnumerateOpCUDAKernel : public framework::OpKernel { /* Generate enumerate sequence set */ auto stream = context.cuda_device_context().stream(); + auto lod0 = in_lod[0]; auto in_len = in->numel(); auto in_data = in->data(); auto out_data = out->mutable_data(context.GetPlace()); + // Copy LoD to GPU + const size_t* dev_in_lod_ptr = lod0.CUDAData(context.GetPlace()); // Calc output tensor CalcOutPut<<<(in_len - 1) / PADDLE_CUDA_NUM_THREADS + 1, PADDLE_CUDA_NUM_THREADS, 0, stream>>>( - in_data, in_len, win_size, pad_value, out_data); + in_data, dev_in_lod_ptr, lod0.size(), win_size, pad_value, out_data); } }; diff --git a/paddle/fluid/operators/sequence_enumerate_op.h b/paddle/fluid/operators/sequence_enumerate_op.h index 8a30003b164b1..dc18d9b207130 100644 --- a/paddle/fluid/operators/sequence_enumerate_op.h +++ b/paddle/fluid/operators/sequence_enumerate_op.h @@ -37,14 +37,16 @@ class SequenceEnumerateKernel : public framework::OpKernel { "The actual input data's size mismatched with LoD information."); // Generate enumerate sequence set - auto seq_length = in_dims[0]; + auto lod0 = in_lod[0]; auto in_data = in->data(); auto out_data = out->mutable_data(context.GetPlace()); - for (int idx = 0; idx < seq_length; ++idx) { - for (int word_idx = 0; word_idx < win_size; ++word_idx) { - int word_pos = idx + word_idx; - out_data[win_size * idx + word_idx] = - word_pos < seq_length ? in_data[word_pos] : pad_value; + for (size_t i = 0; i < lod0.size() - 1; ++i) { + for (size_t idx = lod0[i]; idx < lod0[i + 1]; ++idx) { + for (int word_idx = 0; word_idx < win_size; ++word_idx) { + size_t word_pos = idx + word_idx; + out_data[win_size * idx + word_idx] = + word_pos < lod0[i + 1] ? in_data[word_pos] : pad_value; + } } } } diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index ecb651d6eb0f4..14ebe22b62048 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -5534,14 +5534,14 @@ def sequence_enumerate(input, win_size, pad_value, name=None): Case 1: Input: X.lod = [[0, 3, 5]] - X.data = [1, 2, 3, 4, 5] + X.data = [[1], [2], [3], [4], [5]] X.dims = [5, 1] Attrs: win_size = 2 pad_value = 0 Output: Out.lod = [[0, 3, 5]] - Out.data = [[1, 2], [2, 3], [3, 4], [4, 5], [0, 0]] + Out.data = [[1, 2], [2, 3], [3, 0], [4, 5], [5, 0]] Out.dims = [5, 2] Args: @@ -5567,7 +5567,7 @@ def sequence_enumerate(input, win_size, pad_value, name=None): attrs={'win_size': win_size, 'pad_value': pad_value}) - + def sequence_mask(x, maxlen=None, dtype='int64', name=None): """ **SequenceMask Layer** diff --git a/python/paddle/fluid/tests/unittests/test_sequence_enumerate_op.py b/python/paddle/fluid/tests/unittests/test_sequence_enumerate_op.py index 41624da51266a..f2e5844c7f7ba 100644 --- a/python/paddle/fluid/tests/unittests/test_sequence_enumerate_op.py +++ b/python/paddle/fluid/tests/unittests/test_sequence_enumerate_op.py @@ -19,16 +19,20 @@ from op_test import OpTest -def sequence_enumerate(input_seq, win_size, pad_value): +def sequence_enumerate(input_seq, in_lod, win_size, pad_value): + lod0 = [0] + for i in range(0, len(in_lod[0])): + lod0.append(lod0[i] + in_lod[0][i]) out_seq = [] - for idx in range(0, len(input_seq)): - single_seq = [] - for word_idx in range(win_size): - word_pos = idx + word_idx - dat = input_seq[word_pos] if word_pos < len(input_seq) \ + for i in range(0, len(lod0) - 1): + for idx in range(lod0[i], lod0[i + 1]): + single_seq = [] + for word_idx in range(win_size): + 
word_pos = idx + word_idx + dat = input_seq[word_pos] if word_pos < lod0[i+1] \ else pad_value - single_seq.append(dat) - out_seq.append(single_seq) + single_seq.append(dat) + out_seq.append(single_seq) return out_seq @@ -48,7 +52,8 @@ def init_test_case(self): self.lod = [[9, 4, 11, 6]] self.win_size = 2 self.pad_value = 0 - out_seq = sequence_enumerate(self.in_seq, self.win_size, self.pad_value) + out_seq = sequence_enumerate(self.in_seq, self.lod, self.win_size, + self.pad_value) self.out_seq = np.array(out_seq).astype("int32") @@ -58,7 +63,8 @@ def init_test_case(self): self.lod = [[9, 4, 11, 6]] self.win_size = 2 self.pad_value = 0 - out_seq = sequence_enumerate(self.in_seq, self.win_size, self.pad_value) + out_seq = sequence_enumerate(self.in_seq, self.lod, self.win_size, + self.pad_value) self.out_seq = np.array(out_seq).astype("int64") @@ -68,7 +74,8 @@ def init_test_case(self): self.lod = [[9, 4, 11, 6]] self.win_size = 30 self.pad_value = 0 - out_seq = sequence_enumerate(self.in_seq, self.win_size, self.pad_value) + out_seq = sequence_enumerate(self.in_seq, self.lod, self.win_size, + self.pad_value) self.out_seq = np.array(out_seq).astype("int32") From 0b7d82befba589da0a00168625f6d16c84c87290 Mon Sep 17 00:00:00 2001 From: chenweihang Date: Fri, 31 Aug 2018 12:48:32 +0000 Subject: [PATCH 5/6] doc: refine English description --- .../fluid/operators/sequence_enumerate_op.cc | 15 ++++++------- python/paddle/fluid/layers/nn.py | 17 +++++++------- .../unittests/test_sequence_enumerate_op.py | 22 +++++++++++++++++++ 3 files changed, 37 insertions(+), 17 deletions(-) diff --git a/paddle/fluid/operators/sequence_enumerate_op.cc b/paddle/fluid/operators/sequence_enumerate_op.cc index 307a4ece923cd..ae0241bd77a7d 100644 --- a/paddle/fluid/operators/sequence_enumerate_op.cc +++ b/paddle/fluid/operators/sequence_enumerate_op.cc @@ -50,24 +50,23 @@ class SequenceEnumerateOpMaker : public framework::OpProtoAndCheckerMaker { "(2-D LoDTensor with the 2nd dimension equal to 1) " "Input LoDTensor of SequenceEnumerate operator."); AddOutput("Out", - "(2-D LoDTensor with the 2nd dimension equal to 1) " + "(2-D LoDTensor with the 2nd dimension equal to win_size) " "Output LoDTensor of SequenceEnumerate operator."); AddAttr("win_size", "(int) The enumerate sequence window size.") .AddCustomChecker([](const int& win_size) { PADDLE_ENFORCE(win_size >= 2, - "The window size should be greater than 2."); + "The window size should be not less than 2."); }); AddAttr("pad_value", "(int) The enumerate sequence padding value.") .SetDefault(0); AddComment(R"DOC( Sequence Enumerate Operator. -Sequence enumerate operator generate a new LoDTensor -with the same 1st dimension length as the original LoDTensor, -and with the 2nd dimension equal to the input window length, -the new sub-sequence on 2nd dimension is enumerated one by one on the original sequence. -The values of the last insufficient part areall filled with the input pad_value. - +Generate a new sequence for the input index sequence, which enumerates all the +sub-sequences with length win_size of the input. +The enumerated sequence has the same 1st dimension with variable input, and +the 2nd dimension is win_size, padded by pad_value if necessary in generation. 
+ Examples: Case 1: Input: diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 14ebe22b62048..ffb9858108a68 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -5522,13 +5522,12 @@ def flatten(x, axis=1, name=None): return out -def sequence_enumerate(input, win_size, pad_value, name=None): +def sequence_enumerate(input, win_size, pad_value=0, name=None): """ - Generate a new LoDTensor - with the same 1st dimension length as the original LoDTensor, - and with the 2nd dimension equal to the input window length, - the new sub-sequence on 2nd dimension is enumerated one by one on the original sequence. - The values of the last insufficient part areall filled with the input pad_value. + Generate a new sequence for the input index sequence, which enumerates all the + sub-sequences with length win_size of the input. + The enumerated sequence has the same 1st dimension with variable input, and + the 2nd dimension is win_size, padded by pad_value if necessary in generation. Examples: Case 1: @@ -5545,9 +5544,9 @@ def sequence_enumerate(input, win_size, pad_value, name=None): Out.dims = [5, 2] Args: - input (Variable): The input variable which is a LoDTensor - win_size (int): The enumerate sequence window size. - pad_value (int): The enumerate sequence padding value. + input (Variable): The input variable which is a index sequence. + win_size (int): The window size for enumerating all sub-sequences. + pad_value (int): The padding value, default 0. Returns: Variable: The enumerate sequence variable which is a LoDTensor. diff --git a/python/paddle/fluid/tests/unittests/test_sequence_enumerate_op.py b/python/paddle/fluid/tests/unittests/test_sequence_enumerate_op.py index f2e5844c7f7ba..9814ec0a15e18 100644 --- a/python/paddle/fluid/tests/unittests/test_sequence_enumerate_op.py +++ b/python/paddle/fluid/tests/unittests/test_sequence_enumerate_op.py @@ -68,6 +68,17 @@ def init_test_case(self): self.out_seq = np.array(out_seq).astype("int64") +class TestSequenceEnumerateOpLargeWinSize(TestSequenceEnumerateOp): + def init_test_case(self): + self.in_seq = np.random.randint(0, 10, (30, 1)).astype("int32") + self.lod = [[9, 4, 11, 6]] + self.win_size = 5 + self.pad_value = 0 + out_seq = sequence_enumerate(self.in_seq, self.lod, self.win_size, + self.pad_value) + self.out_seq = np.array(out_seq).astype("int32") + + class TestSequenceEnumerateOpMaxWinSize(TestSequenceEnumerateOp): def init_test_case(self): self.in_seq = np.random.randint(0, 10, (30, 1)).astype("int32") @@ -79,5 +90,16 @@ def init_test_case(self): self.out_seq = np.array(out_seq).astype("int32") +class TestSequenceEnumerateOpLargePadValue(TestSequenceEnumerateOp): + def init_test_case(self): + self.in_seq = np.random.randint(0, 10, (30, 1)).astype("int32") + self.lod = [[9, 4, 11, 6]] + self.win_size = 5 + self.pad_value = 5 + out_seq = sequence_enumerate(self.in_seq, self.lod, self.win_size, + self.pad_value) + self.out_seq = np.array(out_seq).astype("int32") + + if __name__ == "__main__": unittest.main() From 7ddbbcb0b59846e5eb0e6d9b8d9d022379f2cb6e Mon Sep 17 00:00:00 2001 From: chenweihang Date: Sat, 1 Sep 2018 14:16:47 +0000 Subject: [PATCH 6/6] doc: refine API and doc --- paddle/fluid/API.spec | 2 +- paddle/fluid/operators/sequence_enumerate_op.cc | 6 +++--- python/paddle/fluid/layers/nn.py | 6 +++--- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index 601fa9cc0abaf..515ace3640489 100644 --- 
a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -164,10 +164,10 @@ paddle.fluid.layers.crop ArgSpec(args=['x', 'shape', 'offsets', 'name'], varargs paddle.fluid.layers.rank_loss ArgSpec(args=['label', 'left', 'right', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.prelu ArgSpec(args=['x', 'mode', 'param_attr', 'name'], varargs=None, keywords=None, defaults=(None, None)) paddle.fluid.layers.flatten ArgSpec(args=['x', 'axis', 'name'], varargs=None, keywords=None, defaults=(1, None)) -paddle.fluid.layers.sequence_enumerate ArgSpec(args=['input', 'win_size', 'pad_value', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.sequence_mask ArgSpec(args=['x', 'maxlen', 'dtype', 'name'], varargs=None, keywords=None, defaults=(None, 'int64', None)) paddle.fluid.layers.stack ArgSpec(args=['x', 'axis'], varargs=None, keywords=None, defaults=(0,)) paddle.fluid.layers.unstack ArgSpec(args=['x', 'axis', 'num'], varargs=None, keywords=None, defaults=(0, None)) +paddle.fluid.layers.sequence_enumerate ArgSpec(args=['input', 'win_size', 'pad_value', 'name'], varargs=None, keywords=None, defaults=(0, None)) paddle.fluid.layers.data ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True)) paddle.fluid.layers.open_recordio_file ArgSpec(args=['filename', 'shapes', 'lod_levels', 'dtypes', 'pass_num', 'for_parallel'], varargs=None, keywords=None, defaults=(1, True)) paddle.fluid.layers.open_files ArgSpec(args=['filenames', 'shapes', 'lod_levels', 'dtypes', 'thread_num', 'buffer_size', 'pass_num', 'is_test'], varargs=None, keywords=None, defaults=(None, None, 1, None)) diff --git a/paddle/fluid/operators/sequence_enumerate_op.cc b/paddle/fluid/operators/sequence_enumerate_op.cc index ae0241bd77a7d..58e48c228bb34 100644 --- a/paddle/fluid/operators/sequence_enumerate_op.cc +++ b/paddle/fluid/operators/sequence_enumerate_op.cc @@ -63,9 +63,9 @@ class SequenceEnumerateOpMaker : public framework::OpProtoAndCheckerMaker { Sequence Enumerate Operator. Generate a new sequence for the input index sequence, which enumerates all the -sub-sequences with length win_size of the input. -The enumerated sequence has the same 1st dimension with variable input, and -the 2nd dimension is win_size, padded by pad_value if necessary in generation. +sub-sequences with length `win_size` of the input. +The enumerated sequence has the same 1st dimension with variable `input`, and +the 2nd dimension is `win_size`, padded by `pad_value` if necessary in generation. Examples: Case 1: diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 5642c8c8ef99c..9862a3fe1b483 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -5653,9 +5653,9 @@ def flatten(x, axis=1, name=None): def sequence_enumerate(input, win_size, pad_value=0, name=None): """ Generate a new sequence for the input index sequence, which enumerates all the - sub-sequences with length win_size of the input. - The enumerated sequence has the same 1st dimension with variable input, and - the 2nd dimension is win_size, padded by pad_value if necessary in generation. + sub-sequences with length `win_size` of the input. + The enumerated sequence has the same 1st dimension with variable `input`, and + the 2nd dimension is `win_size`, padded by `pad_value` if necessary in generation. Examples: Case 1: