-
Notifications
You must be signed in to change notification settings - Fork 11.8k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[MLIR][Linalg] Add pass to convert linalg.generic back to named ops (#…
…95656) Add a new mlir-opt pass `--linalg-specialize-generic-ops` which lifts generic, where possible, to linalg named ops. Much like `-linalg-generalize-named-ops` lowers named ops to linalg.generic . Also add patterns to recognize contractions which can be specialized from linalg.generic to named op: `linalg.{batch_}?matmul{_transpose_(a|b)}?`
- Loading branch information
1 parent
69d3793
commit 3efac5c
Showing
9 changed files
with
585 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
// The following test examples of linalg named ops lowered to linalg.generic and then | ||
// lifted back up to named op. | ||
// RUN: mlir-opt %s -linalg-generalize-named-ops | mlir-opt --linalg-specialize-generic-ops | FileCheck %s | ||
|
||
func.func @unary_exp(%A: memref<7x14x21xf32>, %Out: memref<7x14x21xf32>) { | ||
linalg.exp ins(%A : memref<7x14x21xf32>) outs(%Out : memref<7x14x21xf32>) | ||
return | ||
} | ||
|
||
// CHECK-LABEL: unary_exp | ||
// CHECK-SAME: %[[A:.+]]: memref<7x14x21xf32>, %[[Out:.+]]: memref<7x14x21xf32>) | ||
// CHECK-NOT: linalg.generic | ||
// CHECK: linalg.exp ins(%[[A]] : memref<7x14x21xf32>) outs(%[[Out]] : memref<7x14x21xf32>) | ||
|
||
// ----- | ||
|
||
func.func @binary_add(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %Out: tensor<?x?xf32>) -> tensor<?x?xf32> { | ||
%0 = linalg.add ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>) outs(%Out : tensor<?x?xf32>) -> tensor<?x?xf32> | ||
return %0 : tensor<?x?xf32> | ||
} | ||
|
||
// CHECK-LABEL: binary_add | ||
// CHECK-SAME: %[[A:.+]]: tensor<?x?xf32>, %[[B:.+]]: tensor<?x?xf32>, %[[Out:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32> | ||
// CHECK-NOT: linalg.generic | ||
// CHECK: linalg.add ins(%[[A]], %[[B]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[Out]] : tensor<?x?xf32>) -> tensor<?x?xf32> | ||
|
||
// ----- | ||
|
||
func.func @matmul(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %Out: tensor<?x?xf32>) -> tensor<?x?xf32> { | ||
%0 = linalg.matmul ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>) outs(%Out : tensor<?x?xf32>) -> tensor<?x?xf32> | ||
return %0 : tensor<?x?xf32> | ||
} | ||
|
||
// CHECK-LABEL: @matmul | ||
// CHECK-SAME: %[[A:.+]]: tensor<?x?xf32>, %[[B:.+]]: tensor<?x?xf32>, %[[Out:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32> | ||
// CHECK-NOT: linalg.generic | ||
// CHECK: linalg.matmul ins(%[[A]], %[[B]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[Out]] : tensor<?x?xf32>) -> tensor<?x?xf32> | ||
|
||
// ----- | ||
|
||
func.func @mixed_named_ops(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, | ||
%C: tensor<?x?xf32>, %Out: tensor<?x?xf32>) -> tensor<?x?xf32> { | ||
%AB = linalg.matmul ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>) outs(%Out : tensor<?x?xf32>) -> tensor<?x?xf32> | ||
%1 = linalg.add ins(%AB, %C : tensor<?x?xf32>, tensor<?x?xf32>) outs(%Out : tensor<?x?xf32>) -> tensor<?x?xf32> | ||
return %1 : tensor<?x?xf32> | ||
} | ||
|
||
// CHECK-LABEL: @mixed_named_ops | ||
// CHECK-SAME: %[[A:.+]]: tensor<?x?xf32>, %[[B:.+]]: tensor<?x?xf32>, %[[C:.+]]: tensor<?x?xf32>, %[[Out:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32> | ||
// CHECK-NOT: linalg.generic | ||
// CHECK: %[[AB:.+]] = linalg.matmul ins(%[[A]], %[[B]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[Out]] : tensor<?x?xf32>) -> tensor<?x?xf32> | ||
// CHECK: linalg.add ins(%[[AB]], %[[C]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[Out]] : tensor<?x?xf32>) -> tensor<?x?xf32> |
Oops, something went wrong.