
Use FusedMatMul When Transpose is Between First Dim and Contiguous Batch Dims #9734

Merged
merged 6 commits on Dec 27, 2021

Conversation

centwang (Contributor)

Currently FusedMatMul supports Transpose only on the last 2 dims. When the 2-D matrices for MatMul are formed from the 1st and last dims, and the batch dims are contiguous in the original tensor, we can also use GemmStridedBatched to compute the result without performing the Transpose. The perm pattern of such a Transpose is [1,2,0,3] or [1,2,3,0]. This PR supports these cases using FusedMatMul.

For a perf comparison using a module with Add+EinSum("ks,ksm->sm")+MSELoss, with K = 16, S = 7840, M = 2048: before the changes each step takes ~7ms; after the changes it takes ~4.5ms, which is similar perf to PyTorch.

Using ULR-XL (16 layers) for a perf test: before the changes, the execution graph has 195 Transpose nodes, 16 MatMul nodes and 306 FusedMatMul nodes; after the changes the numbers are 131 Transpose nodes and 322 FusedMatMul nodes. From nvvp profiling, per-step execution time drops from ~913ms to ~882ms, a ~4% improvement. The gain comes from the reduced Transpose compute, and the new fused FusedMatMul nodes use GemmStridedBatched, which has comparable perf to the original MatMul nodes.
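A minimal CPU sketch of why the Transpose can be skipped (illustrative shapes and names; this code is not part of the PR): when A has shape [M, B1, B2, K] and is transposed with perm=[1,2,0,3] before MatMul, each batch matrix is already a strided view of the original buffer, with a constant batch stride of K and a constant row stride of B1*B2*K, which is the kind of strideA / leading-dimension information GemmStridedBatched consumes.

// stride_trick_demo.cc -- editorial sketch, not part of the PR diff.
#include <cassert>
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  const int M = 2, B1 = 3, B2 = 4, K = 5, N = 6;
  // A: [M, B1, B2, K], B: [B1, B2, K, N], both contiguous row-major.
  std::vector<float> A(M * B1 * B2 * K), B(B1 * B2 * K * N);
  for (size_t i = 0; i < A.size(); ++i) A[i] = 0.01f * static_cast<float>(i);
  for (size_t i = 0; i < B.size(); ++i) B[i] = 0.02f * static_cast<float>(i);

  // Reference: materialize At = Transpose(A, perm=[1,2,0,3]) with shape [B1, B2, M, K],
  // then run an ordinary batched MatMul over the B1*B2 batches.
  std::vector<float> At(A.size()), C_ref(B1 * B2 * M * N, 0.f), C_fused(B1 * B2 * M * N, 0.f);
  for (int m = 0; m < M; ++m)
    for (int b1 = 0; b1 < B1; ++b1)
      for (int b2 = 0; b2 < B2; ++b2)
        for (int k = 0; k < K; ++k)
          At[((b1 * B2 + b2) * M + m) * K + k] = A[((m * B1 + b1) * B2 + b2) * K + k];
  for (int b = 0; b < B1 * B2; ++b)
    for (int m = 0; m < M; ++m)
      for (int n = 0; n < N; ++n)
        for (int k = 0; k < K; ++k)
          C_ref[(b * M + m) * N + n] += At[(b * M + m) * K + k] * B[(b * K + k) * N + n];

  // Fused view: no Transpose. Batch b of the transposed A starts at offset b*K in the
  // original buffer, and consecutive rows of that batch matrix are B1*B2*K apart.
  // A constant batch stride plus a constant row stride is exactly what
  // GemmStridedBatched needs.
  const int batch_stride = K, row_stride = B1 * B2 * K;
  for (int b = 0; b < B1 * B2; ++b)
    for (int m = 0; m < M; ++m)
      for (int n = 0; n < N; ++n)
        for (int k = 0; k < K; ++k)
          C_fused[(b * M + m) * N + n] +=
              A[b * batch_stride + m * row_stride + k] * B[(b * K + k) * N + n];

  for (size_t i = 0; i < C_ref.size(); ++i) assert(std::fabs(C_ref[i] - C_fused[i]) < 1e-4f);
  std::printf("strided view of A matches explicit Transpose + MatMul\n");
  return 0;
}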

centwang added the training label (issues related to ONNX Runtime training; typically submitted using template) on Nov 11, 2021
pengwa (Contributor) commented Dec 13, 2021

This is a nice change!! Some side notes here FYI: APEX and other libs I investigated last week also use this trick; it applies to models whose self-attention input has shape [seq, batch, num_head, head_dim]. We would remove at least two Transposes plus a scaling multiplication (sqrt(num_head)) for the BERT-large case. @iK1D @SherlockNoMad
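A concrete shape trace of where these perms show up (an illustrative sketch of the usual scaled-dot-product attention; not something stated in this PR):

  Q, K: [seq, batch, num_head, head_dim]
  Q --Transpose(perm=[1,2,0,3])--> [batch, num_head, seq, head_dim]
  K --Transpose(perm=[1,2,3,0])--> [batch, num_head, head_dim, seq]
  scores = MatMul(Q', K') * scale  -> [batch, num_head, seq, seq]

With this PR both Transposes can fold into a single FusedMatMul over batch*num_head strided batches, and the scaling can presumably ride along as FusedMatMul's alpha attribute, which matches the "two transposes + a scaling" saving mentioned above.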

static Node* GetTransposeNodeFromOutput(Graph& graph, NodeArg& node_arg) {
// is_trans is whether to transpose the 2 dims used to MatMul.
// is_trans_batch is whether to transpose 1st dim and batch dims (dim-1 to dim-rank-2).
// For example:
Contributor

It would be nice if we could give more descriptive comments covering exactly which cases we target to fuse.

An example FYI

/* Here we check input and mask dimensions are as expected:

Contributor

Also, we need a definition of 'batch' here.

Contributor

I think it is better to use a different word than "batch" because that term is used with respect to the training batch. Maybe something like "range" would be okay.

satyajandhyala (Contributor) Dec 17, 2021

may be "circular permutation" is more clear.
1->0, 2->1, ..,r->r-1, 0->r.

centwang (Contributor, Author)

CUDA's APIs (GemmBatched, GemmStridedBatched) use the same name, and our MatMul code also calls them batches. I think we should still call it batch here, but add more comments to explain.
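For reference, this is the cuBLAS entry point in question (single-precision variant, shown only to illustrate where the batch/stride terminology comes from):

cublasStatus_t cublasSgemmStridedBatched(cublasHandle_t handle,
                                         cublasOperation_t transa, cublasOperation_t transb,
                                         int m, int n, int k,
                                         const float* alpha,
                                         const float* A, int lda, long long int strideA,
                                         const float* B, int ldb, long long int strideB,
                                         const float* beta,
                                         float* C, int ldc, long long int strideC,
                                         int batchCount);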

}

if (!is_trans_on_last_two_dims) {
return nullptr;
// Transpose node can be fused to MatMul when the batch dimensions have same order before and after transpose.
Contributor

nit: change to "the batch dims keep same relative orders before and after transpose"?

satyajandhyala (Contributor) Dec 16, 2021

Introducing the notion of "circular permutation" is really helpful for understanding the code here.

// is_trans is whether to transpose the 2 dims used to MatMul.
// is_trans_batch is whether to transpose 1st dim and batch dims (dim-1 to dim-rank-2).
// For example:
// is_trans=False, is_trans_batch=False: [0,1,2,3]
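The excerpt above is cut off after the first example line; reconstructed from the PR description (so the lines below are an assumption, not a copy of the diff), the remaining flag combinations for a rank-4 input are presumably:

// is_trans=True,  is_trans_batch=False: [0,1,3,2]
// is_trans=False, is_trans_batch=True:  [1,2,0,3]
// is_trans=True,  is_trans_batch=True:  [1,2,3,0]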
Contributor

We should not do the fusion for the [0,1,2,3,...] case, right?

left_ld_factor_ = right_ld_factor_ = 1;

if (trans_batch_a || trans_batch_b) {
ORT_ENFORCE(left_num_dims > 2 && left_num_dims == right_num_dims, "Two input should have same rank and rank >= 3 if transBatchA or transBatchB is true");
Contributor

Suggested change
ORT_ENFORCE(left_num_dims > 2 && left_num_dims == right_num_dims, "Two input should have same rank and rank >= 3 if transBatchA or transBatchB is true");
ORT_ENFORCE(left_num_dims > 2 && left_num_dims == right_num_dims, "Two inputs should have same rank and rank >= 3 if transBatchA or transBatchB is true");

pengwa (Contributor) commented Dec 17, 2021

The change looks great overall! There are a few things I need your help to confirm:

  1. In your measured case, how many Transposes get eliminated per layer? From 195 to 131 for 16 layers, so is it 4 Transposes per layer?
  2. Have we covered the backward pass for the fused MatMul?

centwang (Contributor, Author)

I didn't check the big graph carefully, but from the numbers, yes, it's 4 per layer. From the code, the fusion is added to both the training and inference transformer lists, so ideally the backward pass is also covered. But we build the gradient graph after the training transformers run, and we use FusedMatMul instead of MatMul in the backward graph, so I think it's rare to have such a case in the backward pass that we can fuse.

satyajandhyala (Contributor) commented Dec 17, 2021

This is good.
Going by the definition at https://mathworld.wolfram.com/CyclicPermutation.html, this is applying a cyclic permutation to the left by 1. Is it possible to extend the idea further and generalize to any shift r less than the number of dimensions?

centwang (Contributor, Author)

I don't quite get the idea. Could you please give an example? I.e., what would the 'perm' attribute of the Transpose nodes be?

satyajandhyala (Contributor) commented Dec 20, 2021

This change supports [1,2,0,3] or [1,2,3,0]. In the future, not in this PR, could we also consider permutations like [2,0,1,3] or [2,3,0,1]?

centwang (Contributor, Author)

I have the comments below in the code to explain which cases we can fuse. For [2,0,1,3] or [2,3,0,1], it's not possible to derive strideA, strideB, lda, ldb for the GemmStridedBatched parameters, so we cannot fuse such cases.

// Transpose node can be fused to MatMul when the batch dims keep same relative orders before and after transpose.
// But if they are not contiguous, after the fusion, we can only use GemmBatched instead of GemmStridedBatched,
// which may have perf issue. To keep it simple, we will fuse only when batch dimensions are contiguous.
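A worked offset check (an editorial sketch, not taken from the PR) makes the contiguity requirement concrete. For an input of shape [d0,d1,d2,d3] stored row-major, element (i0,i1,i2,i3) lives at offset i0*d1*d2*d3 + i1*d2*d3 + i2*d3 + i3.

// perm = [1,2,0,3]: batch dims are (d1,d2), matrix dims are (d0,d3).
//   The start of batch b = i1*d2 + i2 is i1*d2*d3 + i2*d3 = b*d3, a single constant
//   stride, and matrix rows are d1*d2*d3 apart, so GemmStridedBatched applies.
// perm = [2,0,1,3]: batch dims are (d2,d0), matrix dims are (d1,d3).
//   The start of batch b = i2*d0 + i0 is i0*d1*d2*d3 + i2*d3, which is not b times any
//   single constant because d0 and d2 are not adjacent in the original layout; only a
//   per-batch pointer array (GemmBatched) would work, hence no fusion.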

pengwa (Contributor) left a comment

Sorry for the late response! LGTM!! :)
