Change EmbedLayerNormalization mask index output to optional (#15526)
### Description
This PR makes the mask_index output of an EmbedLayerNormalization node optional when the mask input is not provided.



### Motivation and Context
The documentation for EmbedLayerNormalization states:
```
The last input mask is optional. If mask is provided, mask index (that is position of first 0 in mask, or number of words) will be calculated.
```
However, if the mask input is not provided, the mask_index output is still calculated and required.
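
As an illustrative sketch (not part of this PR), a graph can now declare an EmbedLayerNormalization node with no mask input and no mask_index output. The tensor names, epsilon value, and use of the new mask_index_type attribute below are assumptions for the example, not code from this change.

```python
# Minimal sketch: EmbedLayerNormalization node with the optional mask input
# omitted and the mask_index output skipped (empty output name).
from onnx import helper

node = helper.make_node(
    "EmbedLayerNormalization",
    inputs=[
        "input_ids",           # (batch_size, sequence_length)
        "",                    # segment_ids omitted
        "word_embedding",
        "position_embedding",
        "",                    # segment_embedding omitted
        "gamma",
        "beta",
        # mask (input 7) omitted entirely, so no mask index needs to be computed
    ],
    outputs=[
        "output",              # (batch_size, sequence_length, hidden_size)
        "",                    # mask_index skipped -- optional after this change
        "embedding_sum",       # later optional output can still be requested
    ],
    domain="com.microsoft",
    epsilon=1e-12,
    mask_index_type=0,         # new attribute: 0 = no 1D mask_index output for shape inference
)
```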
kunal-vaishnavi committed Apr 27, 2023
1 parent d471432 commit 39d6d70
Showing 11 changed files with 116 additions and 31 deletions.
6 changes: 4 additions & 2 deletions docs/ContribOperators.md
mode change 100644 → 100755
@@ -1527,6 +1527,8 @@ This version of the operator has been available since version 1 of the 'com.micr
<dl>
<dt><tt>epsilon</tt> : float</dt>
<dd>The epsilon value to use to avoid division by zero.</dd>
<dt><tt>mask_index_type</tt> : int</dt>
<dd>The mask index tensor type for shape inference (0: None, 1: 1D mask_index)</dd>
</dl>

#### Inputs (7 - 9)
@@ -1552,12 +1554,12 @@ This version of the operator has been available since version 1 of the 'com.micr
<dd>2D position ids with shape (batch_size, sequence_length) or (1, sequence_length)</dd>
</dl>

#### Outputs (2 - 3)
#### Outputs (1 - 3)

<dl>
<dt><tt>output</tt> : T</dt>
<dd>3D output tensor with shape (batch_size, sequence_length, hidden_size)</dd>
<dt><tt>mask_index</tt> : T1</dt>
<dt><tt>mask_index</tt> (optional) : T1</dt>
<dd>1D mask_index tensor with shape (batch_size)</dd>
<dt><tt>embedding_sum</tt> (optional) : T</dt>
<dd>sum of word_embedding and position_embedding without layer normalization</dd>
4 changes: 2 additions & 2 deletions onnxruntime/contrib_ops/cpu/bert/embed_layer_norm.cc
@@ -149,7 +149,7 @@ Status EmbedLayerNorm<T>::Compute(OpKernelContext* context) const {
}

// Calculate mask
if (nullptr != mask) {
if (nullptr != mask && nullptr != mask_index) {
const int32_t* mask_data = mask->Data<int32_t>();
int32_t* mask_index_data = mask_index->MutableData<int32_t>();
for (int b = 0; b < batch_size; b++) {
@@ -162,7 +162,7 @@ Status EmbedLayerNorm<T>::Compute(OpKernelContext* context) const {
}
mask_index_data[b] = cur_sum;
}
} else {
} else if (mask_index != nullptr) {
memset(mask_index->MutableData<int32_t>(), 0, batch_size * sizeof(int32_t));
}

4 changes: 2 additions & 2 deletions onnxruntime/contrib_ops/cpu/quantization/qembed_layer_norm.cc
@@ -167,7 +167,7 @@ Status ComputeInternal(OpKernelContext* context, float epsilon) {
}

// Calculate mask
if (nullptr != mask) {
if (nullptr != mask && nullptr != mask_index) {
const int32_t* mask_data = mask->Data<int32_t>();
int32_t* mask_index_data = mask_index->MutableData<int32_t>();
for (int b = 0; b < batch_size; b++) {
@@ -180,7 +180,7 @@ }
}
mask_index_data[b] = cur_sum;
}
} else {
} else if (mask_index != nullptr) {
memset(mask_index->MutableData<int32_t>(), 0, batch_size * sizeof(int32_t));
}
return Status::OK();
2 changes: 1 addition & 1 deletion onnxruntime/contrib_ops/cuda/bert/embed_layer_norm.cc
@@ -66,7 +66,7 @@ Status EmbedLayerNorm<T>::ComputeInternal(OpKernelContext* context) const {
return LaunchEmbedLayerNormKernel(
Stream(context),
output->MutableData<T>(),
mask_index->MutableData<int32_t>(),
nullptr == mask_index ? nullptr : mask_index->MutableData<int32_t>(),
input_ids->Data<int32_t>(),
nullptr == segment_ids ? nullptr : segment_ids->Data<int32_t>(),
nullptr == mask ? nullptr : mask->Data<int32_t>(),
13 changes: 8 additions & 5 deletions onnxruntime/contrib_ops/cuda/bert/embed_layer_norm_impl.cu
@@ -212,11 +212,14 @@ Status LaunchEmbedLayerNormKernel(
void* embedding_sum,
const int* position_ids,
const bool broadcast_position_ids) {
if (nullptr == input_mask) {
CUDA_RETURN_IF_ERROR(cudaMemsetAsync(mask_index, 0, sizeof(int) * batch_size, stream));
} else {
ORT_RETURN_IF_ERROR(
ComputeMaskIndex(stream, sequence_length, batch_size, input_mask, static_cast<int*>(mask_index)));

if (mask_index != nullptr) {
if (nullptr == input_mask) {
CUDA_RETURN_IF_ERROR(cudaMemsetAsync(mask_index, 0, sizeof(int) * batch_size, stream));
} else {
ORT_RETURN_IF_ERROR(
ComputeMaskIndex(stream, sequence_length, batch_size, input_mask, static_cast<int*>(mask_index)));
}
}

if (element_size == 2) {
3 changes: 2 additions & 1 deletion onnxruntime/core/graph/contrib_ops/bert_defs.cc
@@ -855,6 +855,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
OpSchema()
.SetDoc(EmbedLayerNormalization_ver1_doc)
.Attr("epsilon", "The epsilon value to use to avoid division by zero.", AttributeProto::FLOAT, kDefaultEmbedLayerNormEpsilon)
.Attr("mask_index_type", "The mask index tensor type for shape inference (0: None, 1: 1D mask_index)", AttributeProto::INT, OPTIONAL_VALUE)
.Input(0, "input_ids", "2D words IDs with shape (batch_size, sequence_length)", "T1")
.Input(1, "segment_ids", "2D segment IDs with shape (batch_size, sequence_length)", "T1", OpSchema::Optional)
.Input(2, "word_embedding", "2D with shape (,hidden_size)", "T")
@@ -865,7 +866,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
.Input(7, "mask", "2D attention mask with shape (batch_size, sequence_length)", "T1", OpSchema::Optional)
.Input(8, "position_ids", "2D position ids with shape (batch_size, sequence_length) or (1, sequence_length)", "T1", OpSchema::Optional)
.Output(0, "output", "3D output tensor with shape (batch_size, sequence_length, hidden_size)", "T")
.Output(1, "mask_index", "1D mask_index tensor with shape (batch_size)", "T1")
.Output(1, "mask_index", "1D mask_index tensor with shape (batch_size)", "T1", OpSchema::Optional)
.Output(2, "embedding_sum", "sum of word_embedding and position_embedding without layer normalization", "T", OpSchema::Optional)
.TypeConstraint("T1", {"tensor(int32)"}, "Constrain input and output integer tensors types")
.TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, "Constrain input and output float tensors types.")
16 changes: 11 additions & 5 deletions onnxruntime/core/graph/contrib_ops/shape_inference_functions.cc
@@ -3,12 +3,16 @@

#include "core/graph/contrib_ops/shape_inference_functions.h"
#include <onnx/defs/shape_inference.h>
#include <iostream>

namespace onnxruntime {
namespace contrib {
void EmbedLayerNormalizationShapeInference(::ONNX_NAMESPACE::InferenceContext& ctx) {
propagateElemTypeFromInputToOutput(ctx, 2, 0);
propagateElemTypeFromInputToOutput(ctx, 0, 1);
auto mask_index_type = getAttribute(ctx, "mask_index_type", 1);
if (mask_index_type > 0) {
propagateElemTypeFromInputToOutput(ctx, 0, 1);
}
if (!hasInputShape(ctx, 0)) {
// TODO(kreeger): In this case update the output to (?, ?, hidden_size).
return;
@@ -97,11 +101,13 @@ void EmbedLayerNormalizationShapeInference(::ONNX_NAMESPACE::InferenceContext& c
updateOutputShape(ctx, 0, output_shape);

// mask_index shape is (batch_size)
ONNX_NAMESPACE::TensorShapeProto mask_index_shape;
*mask_index_shape.add_dim() = input_ids_dims[0];
updateOutputShape(ctx, 1, mask_index_shape);
if (mask_index_type > 0) {
ONNX_NAMESPACE::TensorShapeProto mask_index_shape;
*mask_index_shape.add_dim() = input_ids_dims[0];
updateOutputShape(ctx, 1, mask_index_shape);
}

if (ctx.getNumOutputs() > 2) {
if (ctx.getNumOutputs() == 3 || (ctx.getNumOutputs() == 2 && mask_index_type == 0)) {
updateOutputShape(ctx, 2, output_shape);
propagateElemTypeFromInputToOutput(ctx, 0, 2);
}
@@ -459,7 +459,7 @@ class DmlOperatorEmbedLayerNormalization : public DmlOperator
maskIndexOutputEdge.FromNodeOutputIndex = 0;
outputEdges.push_back(std::move(maskIndexOutputEdge));
}
else
else if (maskIndexDesc.Desc)
{
// Insert the edge feeding into the MaskIndex output
DML_OUTPUT_GRAPH_EDGE_DESC maskIndexOutputEdge = {};
7 changes: 4 additions & 3 deletions onnxruntime/python/tools/symbolic_shape_infer.py
@@ -2262,9 +2262,10 @@ def _infer_EmbedLayerNormalization(self, node): # noqa: N802
vi = self.known_vi_[node.output[0]]
vi.CopyFrom(helper.make_tensor_value_info(node.output[0], word_embedding_dtype, output_shape))

mask_index_shape = [input_ids_shape[0]]
vi = self.known_vi_[node.output[1]]
vi.CopyFrom(helper.make_tensor_value_info(node.output[1], onnx.TensorProto.INT32, mask_index_shape))
if len(node.output) > 1 and node.output[1]:
mask_index_shape = [input_ids_shape[0]]
vi = self.known_vi_[node.output[1]]
vi.CopyFrom(helper.make_tensor_value_info(node.output[1], onnx.TensorProto.INT32, mask_index_shape))

if len(node.output) > 2:
# Optional output of add before layer normalization is done
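
A hedged usage sketch of the updated symbolic shape inference tool (the model paths below are placeholders): when the node leaves its second output name empty, the tool now skips producing a (batch_size,) mask_index value_info instead of assuming that output exists.

```python
# Sketch: run ORT's symbolic shape inference over a model whose
# EmbedLayerNormalization node declares output[1] (mask_index) as "".
import onnx
from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference

model = onnx.load("model.onnx")  # placeholder path
inferred = SymbolicShapeInference.infer_shapes(model, auto_merge=True)

# With this change, value_info only gains a (batch_size,) mask_index entry
# when the node actually names that output.
onnx.save(inferred, "model.inferred.onnx")
```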
18 changes: 15 additions & 3 deletions onnxruntime/test/contrib_ops/embed_layer_norm_op_test.cc
@@ -93,7 +93,7 @@ static void RunTest(const embedlayernorm::OpData& data,
ToFloat16(data.beta_data),
/*is_initializer=*/true);
tester.AddAttribute("epsilon", data.epsilon);
if (data.has_mask) {
if (data.has_mask && data.mask_data.size()) {
tester.AddInput<int32_t>("mask", mask_dims, data.mask_data);
}
tester.AddOutput<MLFloat16>("output", output_dims, ToFloat16(data.output_data));
@@ -117,12 +117,17 @@ static void RunTest(const embedlayernorm::OpData& data,
tester.AddInput<float>("gamma", gamma_dims, data.gamma_data, /*is_initializer=*/true);
tester.AddInput<float>("beta", beta_dims, data.beta_data, /*is_initializer=*/true);
tester.AddAttribute("epsilon", data.epsilon);
if (data.has_mask) {
if (data.has_mask && data.mask_data.size()) {
tester.AddInput<int32_t>("mask", mask_dims, data.mask_data);
}
tester.AddOutput<float>("output", output_dims, data.output_data);
}
tester.AddOutput<int32_t>("mask_index", mask_index_dims, data.mask_index_data);
tester.AddAttribute("mask_index_type", static_cast<int64_t>(data.mask_index_type));
if (data.mask_index_data.size()) {
tester.AddOutput<int32_t>("mask_index", mask_index_dims, data.mask_index_data);
} else {
tester.AddOptionalOutputEdge<int32_t>();
}
if (sum_output) {
std::vector<int64_t> embedding_sum_output_dims = output_dims;
if (use_float16) {
@@ -188,6 +193,13 @@ TEST(EmbedLayerNormTest, EmbedLayerNormBatch1_EmbeddingSum) {
TEST(EmbedLayerNormTest, EmbedLayerNormBatch1_EmbeddingSum_Float16) {
RunTest(embedlayernorm::EmbedLayerNormBatch1_EmbeddingSum(), true, true);
}

TEST(EmbedLayerNormTest, EmbedLayerNormBatch1_EmbeddingSum_NoMaskIndex) {
RunTest(embedlayernorm::EmbedLayerNormBatch1_EmbeddingSum_NoMaskIndex(),
/* use_float16 = */ false,
/* sum_output = */ true);
}

TEST(EmbedLayerNormTest, EmbedLayerNormBatch2) {
RunTest(embedlayernorm::EmbedLayerNormBatch2());
}
72 changes: 66 additions & 6 deletions onnxruntime/test/contrib_ops/embed_layer_norm_test_vectors.h
@@ -31,11 +31,12 @@ class OpData {
const std::vector<float>& output_data,
const std::vector<int32_t>& mask_index_data,
float epsilon = kEpsilon,
int mask_index_type = 1,
bool has_mask = true,
bool has_segment = true,
const std::vector<float>& embedding_sum_data = {},
const std::vector<int32_t>& position_ids_data = {})
: batch_size(batch_size), sequence_size(sequence_size), hidden_size(hidden_size), input_ids_data(input_ids_data), segment_ids_data(segment_ids_data), mask_data(mask_data), word_embedding_data(word_embedding_data), position_embedding_data(position_embedding_data), segment_embedding_data(segment_embedding_data), gamma_data(gamma_data), beta_data(beta_data), output_data(output_data), mask_index_data(mask_index_data), epsilon(epsilon), has_mask(has_mask), has_segment(has_segment), embedding_sum_data(embedding_sum_data), position_ids_data(position_ids_data) {}
: batch_size(batch_size), sequence_size(sequence_size), hidden_size(hidden_size), input_ids_data(input_ids_data), segment_ids_data(segment_ids_data), mask_data(mask_data), word_embedding_data(word_embedding_data), position_embedding_data(position_embedding_data), segment_embedding_data(segment_embedding_data), gamma_data(gamma_data), beta_data(beta_data), output_data(output_data), mask_index_data(mask_index_data), epsilon(epsilon), mask_index_type(mask_index_type), has_mask(has_mask), has_segment(has_segment), embedding_sum_data(embedding_sum_data), position_ids_data(position_ids_data) {}

const int batch_size;
const int sequence_size;
Expand All @@ -51,6 +52,7 @@ class OpData {
const std::vector<float> output_data;
const std::vector<int32_t> mask_index_data;
const float epsilon;
const int mask_index_type;
const bool has_mask = true;
const bool has_segment = true;
const std::vector<float> embedding_sum_data;
@@ -110,6 +112,7 @@ inline OpData EmbedLayerNormBatch2(bool has_mask = true) {
int batch_size = 3;
int sequence_size = 2;
int hidden_size = 4;
int mask_index_type = 1;

std::vector<int32_t> input_ids_data = {
1, 3,
@@ -169,7 +172,7 @@

return OpData(batch_size, sequence_size, hidden_size, input_ids_data, segment_ids_data,
mask_data, word_embedding_data, position_embedding_data, segment_embedding_data,
gamma_data, beta_data, output_data, mask_index_data, kEpsilon, has_mask);
gamma_data, beta_data, output_data, mask_index_data, kEpsilon, mask_index_type, has_mask);
}

inline OpData EmbedLayerNormLargeBatchSmallHiddenSize() {
@@ -245,6 +248,7 @@ inline OpData EmbedLayerNormBatch_Distill() {
int batch_size = 3;
int sequence_size = 2;
int hidden_size = 4;
int mask_index_type = 1;

std::vector<int32_t> input_ids_data = {
1, 3,
@@ -292,7 +296,7 @@

return OpData(batch_size, sequence_size, hidden_size, input_ids_data, segment_ids_data,
mask_data, word_embedding_data, position_embedding_data, segment_embedding_data,
gamma_data, beta_data, output_data, mask_index_data, kEpsilon,
gamma_data, beta_data, output_data, mask_index_data, kEpsilon, mask_index_type,
/*has_mask=*/true,
/*has_segment=*/false);
}
Expand All @@ -301,6 +305,7 @@ inline OpData EmbedLayerNormBatch1_PositionIds(bool diff_order = false) {
int batch_size = 1;
int sequence_size = 2;
int hidden_size = 4;
int mask_index_type = 1;

std::vector<int32_t> input_ids_data = {
1, 3};
@@ -356,7 +361,7 @@

return OpData(batch_size, sequence_size, hidden_size, input_ids_data, segment_ids_data,
mask_data, word_embedding_data, position_embedding_data, segment_embedding_data,
gamma_data, beta_data, output_data, mask_index_data, kEpsilon,
gamma_data, beta_data, output_data, mask_index_data, kEpsilon, mask_index_type,
/*has_mask=*/true,
/*has_segment=*/false,
embedding_sum_output_data,
@@ -367,6 +372,7 @@ inline OpData EmbedLayerNormBatch3_PositionIds_BroadCast() {
int batch_size = 3;
int sequence_size = 2;
int hidden_size = 4;
int mask_index_type = 1;

std::vector<int32_t> input_ids_data = {
1, 3, 1, 3, 1, 3};
@@ -416,7 +422,7 @@

return OpData(batch_size, sequence_size, hidden_size, input_ids_data, segment_ids_data,
mask_data, word_embedding_data, position_embedding_data, segment_embedding_data,
gamma_data, beta_data, output_data, mask_index_data, kEpsilon,
gamma_data, beta_data, output_data, mask_index_data, kEpsilon, mask_index_type,
/*has_mask=*/true,
/*has_segment=*/false,
embedding_sum_output_data,
@@ -427,6 +433,7 @@ inline OpData EmbedLayerNormBatch1_EmbeddingSum() {
int batch_size = 1;
int sequence_size = 2;
int hidden_size = 4;
int mask_index_type = 1;

std::vector<int32_t> input_ids_data = {
1, 3};
@@ -470,11 +477,64 @@

return OpData(batch_size, sequence_size, hidden_size, input_ids_data, segment_ids_data,
mask_data, word_embedding_data, position_embedding_data, segment_embedding_data,
gamma_data, beta_data, output_data, mask_index_data, kEpsilon,
gamma_data, beta_data, output_data, mask_index_data, kEpsilon, mask_index_type,
/*has_mask=*/true,
/*has_segment=*/false,
embedding_sum_data);
}

inline OpData EmbedLayerNormBatch1_EmbeddingSum_NoMaskIndex() {
int batch_size = 1;
int sequence_size = 2;
int hidden_size = 4;
int mask_index_type = 0;

std::vector<int32_t> input_ids_data = {
1, 3};

std::vector<int32_t> segment_ids_data = {};

std::vector<int32_t> mask_data = {};

std::vector<float> word_embedding_data = {
0.2f, 0.1f, 0.4f, -0.6f,
0.3f, 0.2f, 0.5f, 0.6f,
0.6f, 0.7f, 0.0f, -0.1f,
0.8f, 0.6f, 0.9f, 1.2f,
0.1f, 0.3f, 0.5f, 0.9f,
1.0f, -2.0f, 1.1f, 0.8f};

std::vector<float> position_embedding_data = {
0.1f, 0.1f, 0.4f, 0.6f,
0.6f, 0.0f, 0.8f, 0.6f,
0.3f, 0.9f, -2.0f, 0.8f};

std::vector<float> segment_embedding_data = {};

std::vector<float> gamma_data = {
0.25f, 0.15f, 0.45f, -0.66f};

std::vector<float> beta_data = {
0.6f, 0.2f, 0.5f, -0.6f};

std::vector<float> output_data = {
0.39587587118148804, 0.03670068085193634, 0.7449488639831543, -1.4981462955474854,
0.61326867341995239, -0.046796366572380066, 0.81048583984375, -1.1954958438873291};

std::vector<int32_t> mask_index_data = {};

std::vector<float> embedding_sum_data = {
0.40000000596046448, 0.30000001192092896, 0.89999997615814209, 1.2000000476837158,
1.4000000953674316, 0.60000002384185791, 1.7000000476837158, 1.8000000715255737};

return OpData(batch_size, sequence_size, hidden_size, input_ids_data, segment_ids_data,
mask_data, word_embedding_data, position_embedding_data, segment_embedding_data,
gamma_data, beta_data, output_data, mask_index_data, kEpsilon, mask_index_type,
/*has_mask=*/true,
/*has_segment=*/false,
embedding_sum_data);
}

} // namespace embedlayernorm
} // namespace test
} // namespace onnxruntime
