From eb57fe4ade733bc25b2d3b8948af1cedfb1ee31f Mon Sep 17 00:00:00 2001 From: Paul Berg Date: Sat, 27 Jul 2024 19:54:10 +0200 Subject: [PATCH] sink transposes in einsum (#105) --- src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp | 92 ++++++++++++++++++- .../jax/TransformOps/TransformOps.td | 8 ++ test/lit_tests/einsumtranspose.mlir | 16 ++++ 3 files changed, 114 insertions(+), 2 deletions(-) create mode 100644 test/lit_tests/einsumtranspose.mlir diff --git a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp index 5eb3e141..560e0a1b 100644 --- a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp +++ b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp @@ -3604,6 +3604,93 @@ struct TransposeDotReorder } }; +// transpose(einsum) -> einsum +struct TransposeEinsum : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(mlir::stablehlo::TransposeOp transpose, + PatternRewriter &rewriter) const final { + auto operand = transpose.getOperand(); + auto einsum = operand.getDefiningOp(); + if (!einsum || !llvm::hasSingleElement(operand.getUsers())) + return failure(); + + auto einsumConfig = einsum.getEinsumConfig(); + auto arrowPos = einsumConfig.find("->"); + + if (arrowPos == StringRef::npos) + return failure(); + + auto permutation = transpose.getPermutation(); + + if (einsumConfig.size() - (arrowPos + 2) < permutation.size()) + return failure(); + + auto newEinsumConfig = std::string(einsumConfig.str()); + for (int i = 0; i < permutation.size(); ++i) { + newEinsumConfig[arrowPos + 2 + i] = + einsumConfig[arrowPos + 2 + permutation[i]]; + } + + rewriter.modifyOpInPlace(einsum, [&einsum, &transpose, newEinsumConfig] { + einsum.setEinsumConfig( + StringAttr::get(einsum.getContext(), newEinsumConfig)); + einsum.getResult().setType(transpose.getType()); + }); + rewriter.replaceAllUsesWith(transpose.getResult(), einsum.getResult()); + + return success(); + } +}; + +// einsum(transpose(x), transpose(y)) -> einsum(x, y) +struct EinsumTranspose : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(mlir::stablehlo::EinsumOp einsum, + PatternRewriter &rewriter) const final { + llvm::StringRef einsumConfig = einsum.getEinsumConfig(); + + auto lhs_trans = + einsum.getLhs().getDefiningOp(); + auto rhs_trans = + einsum.getRhs().getDefiningOp(); + if (!lhs_trans && !rhs_trans) + return failure(); + + size_t commaPos = einsumConfig.find(","); + size_t arrowPos = einsumConfig.find("->"); + if (commaPos != einsum.getLhs().getType().getRank() || + einsumConfig.size() - commaPos < einsum.getRhs().getType().getRank() || + (arrowPos != StringRef::npos && + arrowPos - commaPos < einsum.getRhs().getType().getRank())) + return failure(); + + auto newEinsumConfig = std::string(einsumConfig.str()); + + if (lhs_trans) { + for (int i = 0; i < commaPos; ++i) { + newEinsumConfig[i] = einsumConfig[lhs_trans.getPermutation()[i]]; + } + } + + if (rhs_trans) { + int64_t rhsRank = einsum.getRhs().getType().getRank(); + for (int i = 0; i < rhsRank; ++i) { + newEinsumConfig[commaPos + 1 + i] = + einsumConfig[commaPos + 1 + rhs_trans.getPermutation()[i]]; + } + } + + rewriter.replaceOpWithNewOp( + einsum, einsum.getType(), + lhs_trans ? lhs_trans.getOperand() : einsum.getLhs(), + rhs_trans ? rhs_trans.getOperand() : einsum.getRhs(), + StringAttr::get(einsum.getContext(), newEinsumConfig)); + return success(); + } +}; + struct DotTranspose : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -6218,8 +6305,9 @@ struct EnzymeHLOOptPass : public EnzymeHLOOptPassBase { BinopPadToConcat, ConcatPad>(context); if (passses & 512) - patterns.add(context); + patterns.add(context); if (passses & 1024) patterns.add(context); diff --git a/src/enzyme_ad/jax/TransformOps/TransformOps.td b/src/enzyme_ad/jax/TransformOps/TransformOps.td index acc7f87e..c8d64236 100644 --- a/src/enzyme_ad/jax/TransformOps/TransformOps.td +++ b/src/enzyme_ad/jax/TransformOps/TransformOps.td @@ -211,6 +211,14 @@ def ApplyTransposeDotReorderPatterns : EnzymeHLOPatternOp< "transpose_dot_reorder"> { let patterns = ["TransposeDotReorder"]; } +def ApplyTransposeEinsumPatterns : EnzymeHLOPatternOp< + "transpose_einsum"> { + let patterns = ["TransposeEinsum"]; +} +def ApplyEinsumTransposePatterns : EnzymeHLOPatternOp< + "einsum_transpose"> { + let patterns = ["EinsumTranspose"]; +} def ApplyDotTransposePatterns : EnzymeHLOPatternOp< "dot_transpose"> { let patterns = ["DotTranspose"]; diff --git a/test/lit_tests/einsumtranspose.mlir b/test/lit_tests/einsumtranspose.mlir new file mode 100644 index 00000000..8020fe48 --- /dev/null +++ b/test/lit_tests/einsumtranspose.mlir @@ -0,0 +1,16 @@ +// RUN: enzymexlamlir-opt --pass-pipeline="builtin.module(enzyme-hlo-opt{max_constant_expansion=1})" %s | FileCheck %s + +module { + func.func @einsum(%arg0: tensor<2x3xf32>, %arg1: tensor<4x3x5xf32>) -> tensor<5x4x2xf32> { + %0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<2x3xf32>) -> tensor<3x2xf32> + %1 = stablehlo.transpose %arg1, dims = [2, 0, 1] : (tensor<4x3x5xf32>) -> tensor<5x4x3xf32> + %2 = stablehlo.einsum %0, %1, config = "ba,dbc->cad" : (tensor<3x2xf32>, tensor<5x4x3xf32>) -> tensor<4x2x5xf32> + %3 = stablehlo.transpose %2, dims = [2, 0, 1] : (tensor<4x2x5xf32>) -> tensor<5x4x2xf32> + func.return %3 : tensor<5x4x2xf32> + } +} + +// CHECK: func.func @einsum(%arg0: tensor<2x3xf32>, %arg1: tensor<4x3x5xf32>) -> tensor<5x4x2xf32> { +// CHECK-NEXT: %0 = stablehlo.einsum %arg0, %arg1, config = "ab,cdb->dca" : (tensor<2x3xf32>, tensor<4x3x5xf32>) -> tensor<5x4x2xf32> +// CHECK-NEXT: return %0 : tensor<5x4x2xf32> +// CHECK-NEXT: }