forked from PaddlePaddle/Paddle
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request PaddlePaddle#12887 from chenwhql/sequence_enumerate_op
Feat: add sequence enumerate op
- Loading branch information
Showing
7 changed files
with
397 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#include "paddle/fluid/operators/sequence_enumerate_op.h" | ||
|
||
namespace paddle { | ||
namespace operators { | ||
|
||
class SequenceEnumerateOp : public framework::OperatorWithKernel { | ||
public: | ||
using framework::OperatorWithKernel::OperatorWithKernel; | ||
|
||
void InferShape(framework::InferShapeContext* ctx) const override { | ||
PADDLE_ENFORCE( | ||
ctx->HasInput("X"), | ||
"Input(X) of SequecceEnumerate operator should not be null."); | ||
PADDLE_ENFORCE( | ||
ctx->HasOutput("Out"), | ||
"Output(X) of SequenceEnumerate operator should not be null."); | ||
|
||
const auto x_dims = ctx->GetInputDim("X"); | ||
PADDLE_ENFORCE_EQ( | ||
x_dims.size(), 2UL, | ||
"Input(X) of SequenceEnumerate operator's rank should be 2."); | ||
PADDLE_ENFORCE_EQ( | ||
x_dims[1], 1UL, | ||
"Input(X) of SequenceEnumerate operator's 2nd dimension should be 1."); | ||
|
||
const auto win_size = ctx->Attrs().Get<int>("win_size"); | ||
ctx->SetOutputDim("Out", {x_dims[0], win_size}); | ||
ctx->ShareLoD("X", "Out"); | ||
} | ||
}; | ||
|
||
// Declares the input, output, attributes and user-facing documentation of the
// sequence_enumerate operator.
class SequenceEnumerateOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X",
             "(2-D LoDTensor with the 2nd dimension "
             "equal to 1) Input LoDTensor of "
             "SequenceEnumerate operator.");
    AddOutput("Out",
              "(2-D LoDTensor with the 2nd dimension "
              "equal to win_size) Output LoDTensor of "
              "SequenceEnumerate operator.");

    // win_size is rejected when it is smaller than 2.
    auto check_win_size = [](const int& win_size) {
      PADDLE_ENFORCE(win_size >= 2,
                     "The window size should be not less than 2.");
    };
    AddAttr<int>("win_size", "(int) The enumerate sequence window size.")
        .AddCustomChecker(check_win_size);

    // pad_value defaults to 0.
    AddAttr<int>("pad_value", "(int) The enumerate sequence padding value.")
        .SetDefault(0);

    AddComment(R"DOC(
Sequence Enumerate Operator.
Generate a new sequence for the input index sequence, which enumerates all the
sub-sequences with length `win_size` of the input.
The enumerated sequence has the same 1st dimension with variable `input`, and
the 2nd dimension is `win_size`, padded by `pad_value` if necessary in generation.
Examples:
Case 1:
Input:
X.lod = [[0, 3, 5]]
X.data = [[1], [2], [3], [4], [5]]
X.dims = [5, 1]
Attrs:
win_size = 2
pad_value = 0
Output:
Out.lod = [[0, 3, 5]]
Out.data = [[1, 2], [2, 3], [3, 0], [4, 5], [5, 0]]
Out.dims = [5, 2]
)DOC");
  }
};
|
||
} // namespace operators | ||
} // namespace paddle | ||
|
||
namespace ops = paddle::operators; | ||
REGISTER_OP_WITHOUT_GRADIENT(sequence_enumerate, ops::SequenceEnumerateOp, | ||
ops::SequenceEnumerateOpMaker); | ||
REGISTER_OP_CPU_KERNEL( | ||
sequence_enumerate, | ||
ops::SequenceEnumerateKernel<paddle::platform::CPUDeviceContext, int32_t>, | ||
ops::SequenceEnumerateKernel<paddle::platform::CPUDeviceContext, int64_t>); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#include <thrust/device_vector.h> | ||
#include <thrust/host_vector.h> | ||
#include "paddle/fluid/operators/sequence_enumerate_op.h" | ||
#include "paddle/fluid/platform/cuda_primitives.h" | ||
|
||
namespace paddle { | ||
namespace operators { | ||
using platform::PADDLE_CUDA_NUM_THREADS; | ||
using LoDTensor = framework::LoDTensor; | ||
|
||
template <typename T> | ||
__global__ void CalcOutPut(const T* in_data, const size_t* in_lod, | ||
const size_t lod_len, const int64_t win_size, | ||
const int64_t pad_value, T* out_data) { | ||
int index = blockIdx.x * blockDim.x + threadIdx.x; | ||
if (index < in_lod[lod_len - 1]) { | ||
int end_idx = 0; | ||
// Get LoD interval of index | ||
for (int i = 1; i < lod_len; ++i) { | ||
if (index < in_lod[i]) { | ||
end_idx = in_lod[i]; | ||
break; | ||
} | ||
} | ||
for (size_t i = 0; i < win_size; ++i) { | ||
int word_pos = index + i; | ||
out_data[index * win_size + i] = | ||
word_pos < end_idx ? in_data[word_pos] : pad_value; | ||
} | ||
} | ||
} | ||
|
||
template <typename T> | ||
class SequenceEnumerateOpCUDAKernel : public framework::OpKernel<T> { | ||
public: | ||
void Compute(const framework::ExecutionContext& context) const override { | ||
auto* in = context.Input<LoDTensor>("X"); | ||
auto* out = context.Output<LoDTensor>("Out"); | ||
int win_size = context.Attr<int>("win_size"); | ||
int pad_value = context.Attr<int>("pad_value"); | ||
|
||
auto in_dims = in->dims(); | ||
auto in_lod = in->lod(); | ||
|
||
PADDLE_ENFORCE_EQ( | ||
static_cast<uint64_t>(in_dims[0]), in_lod[0].back(), | ||
"The actual input data's size mismatched with LoD information."); | ||
|
||
/* Generate enumerate sequence set */ | ||
auto stream = context.cuda_device_context().stream(); | ||
auto lod0 = in_lod[0]; | ||
auto in_len = in->numel(); | ||
auto in_data = in->data<T>(); | ||
auto out_data = out->mutable_data<T>(context.GetPlace()); | ||
// Copy LoD to GPU | ||
const size_t* dev_in_lod_ptr = lod0.CUDAData(context.GetPlace()); | ||
// Calc output tensor | ||
CalcOutPut<<<(in_len - 1) / PADDLE_CUDA_NUM_THREADS + 1, | ||
PADDLE_CUDA_NUM_THREADS, 0, stream>>>( | ||
in_data, dev_in_lod_ptr, lod0.size(), win_size, pad_value, out_data); | ||
} | ||
}; | ||
|
||
} // namespace operators | ||
} // namespace paddle | ||
|
||
REGISTER_OP_CUDA_KERNEL( | ||
sequence_enumerate, | ||
paddle::operators::SequenceEnumerateOpCUDAKernel<int32_t>, | ||
paddle::operators::SequenceEnumerateOpCUDAKernel<int64_t>); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#pragma once | ||
|
||
#include "paddle/fluid/framework/op_registry.h" | ||
|
||
namespace paddle { | ||
namespace operators { | ||
using LoDTensor = framework::LoDTensor; | ||
|
||
template <typename DeviceContext, typename T> | ||
class SequenceEnumerateKernel : public framework::OpKernel<T> { | ||
public: | ||
void Compute(const framework::ExecutionContext& context) const override { | ||
auto* in = context.Input<LoDTensor>("X"); | ||
auto* out = context.Output<LoDTensor>("Out"); | ||
int win_size = context.Attr<int>("win_size"); | ||
int pad_value = context.Attr<int>("pad_value"); | ||
|
||
auto in_dims = in->dims(); | ||
auto in_lod = in->lod(); | ||
|
||
PADDLE_ENFORCE_EQ( | ||
static_cast<uint64_t>(in_dims[0]), in_lod[0].back(), | ||
"The actual input data's size mismatched with LoD information."); | ||
|
||
// Generate enumerate sequence set | ||
auto lod0 = in_lod[0]; | ||
auto in_data = in->data<T>(); | ||
auto out_data = out->mutable_data<T>(context.GetPlace()); | ||
for (size_t i = 0; i < lod0.size() - 1; ++i) { | ||
for (size_t idx = lod0[i]; idx < lod0[i + 1]; ++idx) { | ||
for (int word_idx = 0; word_idx < win_size; ++word_idx) { | ||
size_t word_pos = idx + word_idx; | ||
out_data[win_size * idx + word_idx] = | ||
word_pos < lod0[i + 1] ? in_data[word_pos] : pad_value; | ||
} | ||
} | ||
} | ||
} | ||
}; | ||
|
||
} // namespace operators | ||
} // namespace paddle |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.