[MLIR][Linalg] Add pass to convert linalg.generic back to named ops (#95656)

Add a new mlir-opt pass, `--linalg-specialize-generic-ops`, which lifts linalg.generic ops
back to linalg named ops where possible, much as `-linalg-generalize-named-ops` lowers
named ops to linalg.generic.
Also add patterns to recognize contractions that can be specialized from
linalg.generic to the named ops `linalg.{batch_}?matmul{_transpose_(a|b)}?`.
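
For illustration, here is a minimal sketch of the rewrite the pass performs (the function name and static shapes are invented for this example, not taken from the commit): a linalg.generic with matmul semantics is lifted back to the named linalg.matmul op.

```mlir
#mapA = affine_map<(d0, d1, d2) -> (d0, d2)>
#mapB = affine_map<(d0, d1, d2) -> (d2, d1)>
#mapC = affine_map<(d0, d1, d2) -> (d0, d1)>
func.func @generic_matmul(%A: tensor<16x8xf32>, %B: tensor<8x32xf32>,
                          %C: tensor<16x32xf32>) -> tensor<16x32xf32> {
  // A scalar multiply-accumulate body over (parallel, parallel, reduction)
  // loops is exactly the contraction form the new patterns recognize.
  %0 = linalg.generic
         {indexing_maps = [#mapA, #mapB, #mapC],
          iterator_types = ["parallel", "parallel", "reduction"]}
         ins(%A, %B : tensor<16x8xf32>, tensor<8x32xf32>)
         outs(%C : tensor<16x32xf32>) {
  ^bb0(%a: f32, %b: f32, %c: f32):
    %mul = arith.mulf %a, %b : f32
    %add = arith.addf %c, %mul : f32
    linalg.yield %add : f32
  } -> tensor<16x32xf32>
  return %0 : tensor<16x32xf32>
}

// After --linalg-specialize-generic-ops the generic becomes, roughly:
//   %0 = linalg.matmul ins(%A, %B : tensor<16x8xf32>, tensor<8x32xf32>)
//                      outs(%C : tensor<16x32xf32>) -> tensor<16x32xf32>
```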
javedabsar1 authored Jun 30, 2024
1 parent 69d3793 commit 3efac5c
Showing 9 changed files with 585 additions and 2 deletions.
5 changes: 5 additions & 0 deletions mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -94,6 +94,11 @@ def LinalgGeneralizeNamedOpsPass : Pass<"linalg-generalize-named-ops"> {
let dependentDialects = ["linalg::LinalgDialect"];
}

def LinalgSpecializeGenericOpsPass : Pass<"linalg-specialize-generic-ops"> {
let summary = "Convert generic ops back to named ops";
let dependentDialects = ["linalg::LinalgDialect"];
}

def LinalgDetensorizePass : InterfacePass<"linalg-detensorize", "FunctionOpInterface"> {
let summary = "Detensorize linalg ops";
let dependentDialects = [];
23 changes: 23 additions & 0 deletions mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1416,6 +1416,20 @@ struct LinalgGeneralizationPattern
}
};

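/// Rewrite pattern wrapping `specializeGenericOp`: rewrites a linalg.generic
/// into an equivalent linalg named op when one can be recognized.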
struct LinalgSpecializationPattern : public OpRewritePattern<GenericOp> {
using OpRewritePattern<GenericOp>::OpRewritePattern;

FailureOr<GenericOp>
returningMatchAndRewrite(GenericOp op, PatternRewriter &rewriter) const {
return specializeGenericOp(rewriter, op);
}

LogicalResult matchAndRewrite(GenericOp op,
PatternRewriter &rewriter) const override {
return returningMatchAndRewrite(op, rewriter);
}
};

/// Vectorization pattern for memref::CopyOp.
struct CopyVectorizationPattern : public OpRewritePattern<memref::CopyOp> {
using OpRewritePattern<memref::CopyOp>::OpRewritePattern;
@@ -1567,6 +1581,15 @@ void populateLinalgTilingCanonicalizationPatterns(RewritePatternSet &patterns);
/// linalg.generic ops.
void populateLinalgNamedOpsGeneralizationPatterns(RewritePatternSet &patterns);

/// Populates `patterns` with patterns to convert linalg.generic ops to named
/// ops where possible. A linalg.generic can represent a wide range of complex
/// computations for which an equivalent linalg named op may not exist, e.g.
/// a linalg.generic that takes a tensor and computes a polynomial such as:
///     p(x) = a_n*x^n + ... + a_1*x + a_0
/// There is no equivalent named op to convert to; many such cases exist.
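/// For example (illustrative IR, not from this patch), a generic that
/// evaluates p(x) = x*x + 2*x + 3 elementwise has no named-op equivalent
/// and is left unchanged by these patterns:
///   #id = affine_map<(d0) -> (d0)>
///   %0 = linalg.generic {indexing_maps = [#id, #id],
///                        iterator_types = ["parallel"]}
///          ins(%X : tensor<64xf32>) outs(%Out : tensor<64xf32>) {
///   ^bb0(%x: f32, %out: f32):
///     %c2 = arith.constant 2.0 : f32
///     %c3 = arith.constant 3.0 : f32
///     %xx = arith.mulf %x, %x : f32
///     %tx = arith.mulf %x, %c2 : f32
///     %p1 = arith.addf %xx, %tx : f32
///     %p  = arith.addf %p1, %c3 : f32
///     linalg.yield %p : f32
///   } -> tensor<64xf32>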
void populateLinalgGenericOpsSpecializationPatterns(
RewritePatternSet &patterns);

/// Linalg decompose convolutions patterns

/// Populates patterns to decompose high-D convolution ops into low-D ones.
4 changes: 2 additions & 2 deletions mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -105,9 +105,9 @@ std::optional<Value> linalg::isaFillOpInterface(GenericOp genericOp) {
static bool
isaElemwiseSingleUnaryOrBinaryOpInterface(linalg::GenericOp genericOp,
unsigned arity) {
// Check all loops are parallel, and have only tensor semantics.
// Check all loops are parallel.
if (genericOp.getNumParallelLoops() != genericOp.getNumLoops() ||
genericOp.getNumLoops() < 1 || !genericOp.hasPureTensorSemantics())
genericOp.getNumLoops() < 1)
return false;

// Check there are arity-inputs, 1-output and all are identity-maps.
229 changes: 229 additions & 0 deletions mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -11,12 +11,22 @@
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/TypeID.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/Debug.h"

namespace mlir {
#define GEN_PASS_DEF_LINALGSPECIALIZEGENERICOPSPASS
#include "mlir/Dialect/Linalg/Passes.h.inc"
} // namespace mlir

#define DEBUG_TYPE "linalg-specialization"

#define REPLACE_BINARY_OP(NEWOP, OPERANDS_SWAP) \
@@ -58,6 +68,197 @@ static bool areBinOpsSwapped(GenericOp genericOp) {
return swapped;
}

//===----------------------------------------------------------------------===//
// Specialize linalg generic to matmul variants.
//===----------------------------------------------------------------------===//
/// Identifies a linalg.generic that is essentially a named op of the form:
//   `linalg.{batch_}?matmul{_transpose_a | _transpose_b}?`
//
// It is possible that a linalg.generic implements a matmul but not in a
// straightforward way, e.g. below is a matrix multiply over some slice:
// ```
// %0 = linalg.generic {
// indexing_maps = [affine_map<(d0, d1, d2) -> (3, d1, d0)>,
// affine_map<(d0, d1, d2) -> (d0, 5, d2)>,
// affine_map<(d0, d1, d2) -> (d2, d1, 13)>],
// iterator_types = ["parallel", "parallel", "parallel"]}
// ins(%A, %B : tensor<20x20x20xf32>, tensor<20x20x20xf32>)
// outs(%C : tensor<20x20x20xf32>) {
// ^bb0(%a: f32, %b: f32, %c : f32):
// %mul = arith.mulf %a, %b : f32
// %add = arith.addf %mul, %c : f32
// linalg.yield %add : f32
// } -> tensor<20x20x20xf32>
// ```
// It is not possible to represent the above as a named op:
// e.g. linalg.batch_matmul(%A, %B : tensor<20x20x20xf32>, ...) is
// not the same as the linalg.generic above.
namespace {
enum class IndexMatchResult {
Match = 0, // identity map.
Transposed, // transposed map.
Mismatch // none of the above.
};

// Checks whether the input affine `map` contains two consecutive dims that
// can be interpreted as accessing a 2D matrix. It is assumed that the row
// and column dimensions are adjacent axes (in this order) and start at
// `rowDimIdx` in the input map.
//
// E.g. consider the A matrix in `C[M,N] = A[M,K] * B[K,N]`. We check whether
// the map of A is identity (match), transposed, or something completely
// different (mismatch). Similarly for B and C.
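//
// For intuition (illustrative maps, using the inferred m, n, k dims): in a
// transposed-A matmul, C[m, n] = sum_k A[k, m] * B[k, n], the operand maps
// classify as:
//   A: (m, n, k) -> (k, m)   row/col swapped    => Transposed
//   B: (m, n, k) -> (k, n)   expected (k, n)    => Match
//   C: (m, n, k) -> (m, n)   expected (m, n)    => Match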
static IndexMatchResult matchOperandMap(AffineMap map, unsigned rowDimIdx,
unsigned expectedPosOfRowDim,
unsigned expectedPosOfColDim) {
// Get the matrix multiply indices. They are past the batch indices.
auto exprOfRowDim = map.getResults()[rowDimIdx];
auto exprOfColDim = map.getResults()[rowDimIdx + 1];

// They should be pure dimension ids.
if (exprOfRowDim.getKind() != AffineExprKind::DimId ||
exprOfColDim.getKind() != AffineExprKind::DimId)
return IndexMatchResult::Mismatch;

auto posRowDim = cast<AffineDimExpr>(exprOfRowDim).getPosition();
auto posColDim = cast<AffineDimExpr>(exprOfColDim).getPosition();

if (expectedPosOfRowDim == posRowDim && expectedPosOfColDim == posColDim)
return IndexMatchResult::Match;

if (expectedPosOfRowDim == posColDim && expectedPosOfColDim == posRowDim)
return IndexMatchResult::Transposed;

return IndexMatchResult::Mismatch;
}

// Replaces genericOp with the `NamedOpTy` op, supplied as a template arg.
// All the variants expressed by the pseudo regular expression
// `linalg.{batch_}?matmul{_transpose_a | _transpose_b}?`
// have the same number of ins/outs, so it is easy to stamp out the
// different versions.
template <typename NamedOpTy>
static LinalgOp replaceWithMatmulVariant(RewriterBase &rewriter, GenericOp op) {
LinalgOp namedOp = rewriter.replaceOpWithNewOp<NamedOpTy>(
op, ValueRange{op.getDpsInputs()[0], op.getDpsInputs()[1]},
ValueRange{op.getDpsInits()[0]});
return namedOp;
}

// Converts linalg.generic to named linalg.*matmul* where possible.
static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
GenericOp genericOp) {
if (genericOp.getNumDpsInputs() != 2 || genericOp.getNumDpsInits() != 1)
return failure();

// Early exit if not projected permutations.
auto mapRange = genericOp.getIndexingMapsArray();
if (llvm::any_of(mapRange,
[](AffineMap m) { return !m.isProjectedPermutation(); }))
return failure();

// A linalg generic contraction can be across multiple axes, e.g.
// ```
// linalg.generic
// {indexing_maps = [affine_map<(m, n, k1, k2) -> (m, k1, k2)>,
// affine_map<(m, n, k1, k2) -> (k2, k1, n)>,
// affine_map<(m, n, k1, k2) -> (m, n)>],
// iterator_types = ["parallel", "parallel",
// "reduction", "reduction"]}
// ins(%A, %B : tensor<10x20x30xf32>, tensor<30x20x40xf32>)
// outs(%C : tensor<10x40xf32>) {
// ^bb0(%a: f32, %b: f32, %c: f32):
// %1 = arith.mulf %a, %b : f32
// %2 = arith.addf %c, %1 : f32
// linalg.yield %2 : f32
// } -> tensor<10x40xf32>
// ```
// In the above contraction there are two reduction dimensions {k1, k2};
// although it is a valid linalg contraction, it is not of the named-op
// matrix-multiply kind. Therefore, reject multi-dim reductions.
auto res = inferContractionDims(genericOp);
if (!succeeded(res))
return failure();
auto dims = *res;
if (dims.m.size() != 1 || dims.n.size() != 1 || dims.k.size() != 1)
return failure();

if (!mlir::linalg::detail::isContractionBody(
*genericOp.getBlock(), [](Operation *first, Operation *second) {
if ((isa<arith::MulFOp>(first) && isa<arith::AddFOp>(second)) ||
(isa<arith::MulIOp>(first) && isa<arith::AddIOp>(second)) ||
(isa<complex::MulOp>(first) && isa<complex::AddOp>(second)))
return true;
return false;
}))
return failure();

// Check the rank of the operands.
auto indexingMaps = genericOp.getIndexingMapsArray();
if (llvm::any_of(indexingMaps, [&dims](AffineMap m) {
return m.getResults().size() !=
dims.batch.size() + 2 /* any two of {m,n,k} */;
}))
return failure();

auto numOfBatchDims = dims.batch.size();
if (indexingMaps[0].getNumDims() != numOfBatchDims + 3)
return failure();

if (numOfBatchDims) {
// Each operand in a linalg generic contraction could express a different
// permutation of its batch dimensions. But for a named op they must be the
// identity, since separate maps are not specified.
if (llvm::any_of(indexingMaps, [numOfBatchDims](AffineMap m) {
for (unsigned i = 0; i < numOfBatchDims; ++i) {
auto expr = m.getResults()[i];
if (expr.getKind() != AffineExprKind::DimId ||
cast<AffineDimExpr>(expr).getPosition() != i)
return true;
}
return false;
}))
return failure();
}

auto a =
matchOperandMap(indexingMaps[0], numOfBatchDims, dims.m[0], dims.k[0]);
auto b =
matchOperandMap(indexingMaps[1], numOfBatchDims, dims.k[0], dims.n[0]);
auto c =
matchOperandMap(indexingMaps[2], numOfBatchDims, dims.m[0], dims.n[0]);

if (llvm::any_of(ArrayRef<IndexMatchResult>{a, b, c}, [](IndexMatchResult r) {
return r == IndexMatchResult::Mismatch;
}))
return failure();

if (c != IndexMatchResult::Match ||
(a == IndexMatchResult::Transposed && b == IndexMatchResult::Transposed))
return failure();

/// Codegen the different matmul variants.
if (numOfBatchDims) {
if (a == IndexMatchResult::Transposed)
return replaceWithMatmulVariant<BatchMatmulTransposeAOp>(rewriter,
genericOp);
if (b == IndexMatchResult::Transposed)
return replaceWithMatmulVariant<BatchMatmulTransposeBOp>(rewriter,
genericOp);
return replaceWithMatmulVariant<BatchMatmulOp>(rewriter, genericOp);
}

if (a == IndexMatchResult::Transposed)
return replaceWithMatmulVariant<MatmulTransposeAOp>(rewriter, genericOp);
if (b == IndexMatchResult::Transposed)
return replaceWithMatmulVariant<MatmulTransposeBOp>(rewriter, genericOp);
return replaceWithMatmulVariant<MatmulOp>(rewriter, genericOp);
}

} // namespace

//===----------------------------------------------------------------------===//
// Categorize linalg generic to named op where possible.
//===----------------------------------------------------------------------===//
FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(RewriterBase &rewriter,
GenericOp genericOp) {
if (isaCopyOpInterface(genericOp)) {
@@ -100,5 +301,33 @@ FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(RewriterBase &rewriter,
return namedOp;
}
}

if (isaContractionOpInterface(genericOp)) {
return specializeLinalgContractions(rewriter, genericOp);
}
return failure();
}

namespace {
struct LinalgSpecializeGenericOpsPass
: public impl::LinalgSpecializeGenericOpsPassBase<
LinalgSpecializeGenericOpsPass> {

using impl::LinalgSpecializeGenericOpsPassBase<
LinalgSpecializeGenericOpsPass>::LinalgSpecializeGenericOpsPassBase;
void runOnOperation() override;
};
} // namespace

void LinalgSpecializeGenericOpsPass::runOnOperation() {
RewritePatternSet patterns(&getContext());
populateLinalgGenericOpsSpecializationPatterns(patterns);

if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
signalPassFailure();
}

void mlir::linalg::populateLinalgGenericOpsSpecializationPatterns(
RewritePatternSet &patterns) {
patterns.add<LinalgSpecializationPattern>(patterns.getContext());
}
52 changes: 52 additions & 0 deletions mlir/test/Dialect/Linalg/roundtrip-linalg-named-ops.mlir
@@ -0,0 +1,52 @@
// The following tests are examples of linalg named ops lowered to linalg.generic and then
// lifted back up to named ops.
// RUN: mlir-opt %s -linalg-generalize-named-ops | mlir-opt --linalg-specialize-generic-ops | FileCheck %s

func.func @unary_exp(%A: memref<7x14x21xf32>, %Out: memref<7x14x21xf32>) {
linalg.exp ins(%A : memref<7x14x21xf32>) outs(%Out : memref<7x14x21xf32>)
return
}

// CHECK-LABEL: unary_exp
// CHECK-SAME: %[[A:.+]]: memref<7x14x21xf32>, %[[Out:.+]]: memref<7x14x21xf32>)
// CHECK-NOT: linalg.generic
// CHECK: linalg.exp ins(%[[A]] : memref<7x14x21xf32>) outs(%[[Out]] : memref<7x14x21xf32>)

// -----

func.func @binary_add(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %Out: tensor<?x?xf32>) -> tensor<?x?xf32> {
%0 = linalg.add ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>) outs(%Out : tensor<?x?xf32>) -> tensor<?x?xf32>
return %0 : tensor<?x?xf32>
}

// CHECK-LABEL: binary_add
// CHECK-SAME: %[[A:.+]]: tensor<?x?xf32>, %[[B:.+]]: tensor<?x?xf32>, %[[Out:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK-NOT: linalg.generic
// CHECK: linalg.add ins(%[[A]], %[[B]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[Out]] : tensor<?x?xf32>) -> tensor<?x?xf32>

// -----

func.func @matmul(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %Out: tensor<?x?xf32>) -> tensor<?x?xf32> {
%0 = linalg.matmul ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>) outs(%Out : tensor<?x?xf32>) -> tensor<?x?xf32>
return %0 : tensor<?x?xf32>
}

// CHECK-LABEL: @matmul
// CHECK-SAME: %[[A:.+]]: tensor<?x?xf32>, %[[B:.+]]: tensor<?x?xf32>, %[[Out:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK-NOT: linalg.generic
// CHECK: linalg.matmul ins(%[[A]], %[[B]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[Out]] : tensor<?x?xf32>) -> tensor<?x?xf32>

// -----

func.func @mixed_named_ops(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>,
%C: tensor<?x?xf32>, %Out: tensor<?x?xf32>) -> tensor<?x?xf32> {
%AB = linalg.matmul ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>) outs(%Out : tensor<?x?xf32>) -> tensor<?x?xf32>
%1 = linalg.add ins(%AB, %C : tensor<?x?xf32>, tensor<?x?xf32>) outs(%Out : tensor<?x?xf32>) -> tensor<?x?xf32>
return %1 : tensor<?x?xf32>
}

// CHECK-LABEL: @mixed_named_ops
// CHECK-SAME: %[[A:.+]]: tensor<?x?xf32>, %[[B:.+]]: tensor<?x?xf32>, %[[C:.+]]: tensor<?x?xf32>, %[[Out:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK-NOT: linalg.generic
// CHECK: %[[AB:.+]] = linalg.matmul ins(%[[A]], %[[B]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[Out]] : tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: linalg.add ins(%[[AB]], %[[C]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[Out]] : tensor<?x?xf32>) -> tensor<?x?xf32>
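
// -----

// Illustrative sketch (hypothetical extra case, not in this commit): a
// transposed variant is expected to round-trip the same way, assuming the
// contraction-specialization patterns recognize linalg.matmul_transpose_a.

func.func @matmul_transpose_a(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %Out: tensor<?x?xf32>) -> tensor<?x?xf32> {
  %0 = linalg.matmul_transpose_a ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>) outs(%Out : tensor<?x?xf32>) -> tensor<?x?xf32>
  return %0 : tensor<?x?xf32>
}

// CHECK-LABEL: @matmul_transpose_a
// CHECK-NOT: linalg.generic
// CHECK: linalg.matmul_transpose_a ins(%{{.*}}, %{{.*}} : tensor<?x?xf32>, tensor<?x?xf32>) outs(%{{.*}} : tensor<?x?xf32>) -> tensor<?x?xf32>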