Skip to content

Commit

Permalink
sink transposes in einsum (#105)
Browse files Browse the repository at this point in the history
  • Loading branch information
Pangoraw authored Jul 27, 2024
1 parent f91f838 commit eb57fe4
Show file tree
Hide file tree
Showing 3 changed files with 114 additions and 2 deletions.
92 changes: 90 additions & 2 deletions src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3604,6 +3604,93 @@ struct TransposeDotReorder
}
};

// transpose(einsum) -> einsum
//
// Folds a transpose applied to an einsum result into the einsum itself by
// permuting the output labels (the part of the config after "->"). Result
// dimension `i` of the transpose is dimension `permutation[i]` of the einsum
// result, so the new output label at position `i` is the old label at
// position `permutation[i]`.
//
// The pattern only fires when:
//   * the einsum result is used by this transpose alone, so no other user
//     observes the un-transposed value;
//   * the config has an explicit "->" output section;
//   * that section is exactly one alphabetic label per transposed dimension
//     (configs containing whitespace or "..." are left untouched instead of
//     being rewritten into a malformed string).
struct TransposeEinsum : public OpRewritePattern<mlir::stablehlo::TransposeOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(mlir::stablehlo::TransposeOp transpose,
                                PatternRewriter &rewriter) const final {
    auto operand = transpose.getOperand();
    auto einsum = operand.getDefiningOp<mlir::stablehlo::EinsumOp>();
    if (!einsum || !llvm::hasSingleElement(operand.getUsers()))
      return failure();

    llvm::StringRef einsumConfig = einsum.getEinsumConfig();
    size_t arrowPos = einsumConfig.find("->");
    if (arrowPos == StringRef::npos)
      return failure();

    auto permutation = transpose.getPermutation();

    // Everything after "->" must be exactly one label per output dimension.
    size_t outBegin = arrowPos + 2;
    llvm::StringRef outLabels = einsumConfig.drop_front(outBegin);
    auto isLabel = [](char c) {
      return (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z');
    };
    if (outLabels.size() != permutation.size() ||
        !llvm::all_of(outLabels, isLabel))
      return failure();

    std::string newEinsumConfig = einsumConfig.str();
    for (size_t i = 0, e = permutation.size(); i < e; ++i)
      newEinsumConfig[outBegin + i] = outLabels[permutation[i]];

    // Build a fresh einsum that directly produces the transposed type rather
    // than mutating the existing op's result type while the transpose still
    // consumes it; that would leave the IR type-inconsistent until the dead
    // transpose is cleaned up.
    rewriter.replaceOpWithNewOp<mlir::stablehlo::EinsumOp>(
        transpose, transpose.getType(), einsum.getLhs(), einsum.getRhs(),
        StringAttr::get(einsum.getContext(), newEinsumConfig));
    // The transpose was the only user, so the original einsum is now dead.
    rewriter.eraseOp(einsum);

    return success();
  }
};

// einsum(transpose(x), transpose(y)) -> einsum(x, y)
//
// Removes transposes feeding an einsum by re-labelling the corresponding
// operand in the einsum config so that it refers to the un-transposed value.
//
// For `t = transpose(x, perm)`, dimension `i` of `t` is dimension `perm[i]` of
// `x`. A label sitting at position `i` for `t` therefore has to move to
// position `perm[i]` for `x` (i.e. the labels are scattered through `perm`,
// which is the inverse of gathering through it). Gathering is only correct
// for self-inverse permutations such as [1, 0].
//
// The pattern only fires when the config has the plain form
// "<lhs>,<rhs>" or "<lhs>,<rhs>-><out>" with exactly one alphabetic label per
// operand dimension; anything else (whitespace, "...", missing comma) is left
// untouched.
struct EinsumTranspose : public OpRewritePattern<mlir::stablehlo::EinsumOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(mlir::stablehlo::EinsumOp einsum,
                                PatternRewriter &rewriter) const final {
    auto lhsTranspose =
        einsum.getLhs().getDefiningOp<mlir::stablehlo::TransposeOp>();
    auto rhsTranspose =
        einsum.getRhs().getDefiningOp<mlir::stablehlo::TransposeOp>();
    if (!lhsTranspose && !rhsTranspose)
      return failure();

    llvm::StringRef einsumConfig = einsum.getEinsumConfig();

    size_t commaPos = einsumConfig.find(',');
    if (commaPos == StringRef::npos)
      return failure();

    // The rhs labels run from just after the comma up to the arrow, or to the
    // end of the config when the output is implicit.
    size_t arrowPos = einsumConfig.find("->");
    size_t rhsBegin = commaPos + 1;
    size_t rhsEnd =
        arrowPos == StringRef::npos ? einsumConfig.size() : arrowPos;
    if (rhsEnd < rhsBegin)
      return failure();

    llvm::StringRef lhsLabels = einsumConfig.take_front(commaPos);
    llvm::StringRef rhsLabels = einsumConfig.slice(rhsBegin, rhsEnd);

    auto isLabel = [](char c) {
      return (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z');
    };
    size_t lhsRank = static_cast<size_t>(einsum.getLhs().getType().getRank());
    size_t rhsRank = static_cast<size_t>(einsum.getRhs().getType().getRank());
    if (lhsLabels.size() != lhsRank || rhsLabels.size() != rhsRank ||
        !llvm::all_of(lhsLabels, isLabel) || !llvm::all_of(rhsLabels, isLabel))
      return failure();

    std::string newEinsumConfig = einsumConfig.str();

    if (lhsTranspose) {
      auto perm = lhsTranspose.getPermutation();
      for (size_t i = 0; i < lhsRank; ++i)
        newEinsumConfig[perm[i]] = lhsLabels[i];
    }

    if (rhsTranspose) {
      auto perm = rhsTranspose.getPermutation();
      for (size_t i = 0; i < rhsRank; ++i)
        newEinsumConfig[rhsBegin + perm[i]] = rhsLabels[i];
    }

    rewriter.replaceOpWithNewOp<mlir::stablehlo::EinsumOp>(
        einsum, einsum.getType(),
        lhsTranspose ? lhsTranspose.getOperand() : einsum.getLhs(),
        rhsTranspose ? rhsTranspose.getOperand() : einsum.getRhs(),
        StringAttr::get(einsum.getContext(), newEinsumConfig));
    return success();
  }
};

struct DotTranspose : public OpRewritePattern<mlir::stablehlo::DotGeneralOp> {
using OpRewritePattern::OpRewritePattern;

Expand Down Expand Up @@ -6218,8 +6305,9 @@ struct EnzymeHLOOptPass : public EnzymeHLOOptPassBase<EnzymeHLOOptPass> {
BinopPadToConcat<stablehlo::MulOp>, ConcatPad>(context);

if (passses & 512)
patterns.add<TransposeDotReorder, DotTranspose, ConvertConvertFloat,
ConcatToPad, ConcatAppendingReshape, ReshapeIota>(context);
patterns.add<TransposeDotReorder, DotTranspose, EinsumTranspose,
TransposeEinsum, ConvertConvertFloat, ConcatToPad,
ConcatAppendingReshape, ReshapeIota>(context);

if (passses & 1024)
patterns.add<FullReduceReshapeOrTranspose>(context);
Expand Down
8 changes: 8 additions & 0 deletions src/enzyme_ad/jax/TransformOps/TransformOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,14 @@ def ApplyTransposeDotReorderPatterns : EnzymeHLOPatternOp<
"transpose_dot_reorder"> {
let patterns = ["TransposeDotReorder"];
}
// Exposes the C++ rewrite pattern `TransposeEinsum` through
// `EnzymeHLOPatternOp` under the name "transpose_einsum" (presumably the
// transform op mnemonic -- confirm against the EnzymeHLOPatternOp definition).
def ApplyTransposeEinsumPatterns : EnzymeHLOPatternOp<
"transpose_einsum"> {
let patterns = ["TransposeEinsum"];
}
// Exposes the C++ rewrite pattern `EinsumTranspose` through
// `EnzymeHLOPatternOp` under the name "einsum_transpose" (presumably the
// transform op mnemonic -- confirm against the EnzymeHLOPatternOp definition).
def ApplyEinsumTransposePatterns : EnzymeHLOPatternOp<
"einsum_transpose"> {
let patterns = ["EinsumTranspose"];
}
def ApplyDotTransposePatterns : EnzymeHLOPatternOp<
"dot_transpose"> {
let patterns = ["DotTranspose"];
Expand Down
16 changes: 16 additions & 0 deletions test/lit_tests/einsumtranspose.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
// RUN: enzymexlamlir-opt --pass-pipeline="builtin.module(enzyme-hlo-opt{max_constant_expansion=1})" %s | FileCheck %s

module {
// Input for the einsum/transpose folding test: both einsum operands are
// produced by transposes and the einsum result is transposed again, so the
// expected output (see the CHECK lines) is a single einsum on %arg0/%arg1.
//
// NOTE(review): the einsum labels do not agree with the tensor types below.
// "ba" on tensor<3x2> binds b=3 while "dbc" on tensor<5x4x3> binds b=4, and
// "cad" would give a 3x2x5 result rather than the declared 4x2x5. The config
// may have been meant to be "ba,dcb->cad" -- confirm, since a
// shape-inconsistent einsum cannot validate the direction in which operand
// permutations are applied.
func.func @einsum(%arg0: tensor<2x3xf32>, %arg1: tensor<4x3x5xf32>) -> tensor<5x4x2xf32> {
// Operand transposes: 2x3 -> 3x2 and 4x3x5 -> 5x4x3.
%0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<2x3xf32>) -> tensor<3x2xf32>
%1 = stablehlo.transpose %arg1, dims = [2, 0, 1] : (tensor<4x3x5xf32>) -> tensor<5x4x3xf32>
%2 = stablehlo.einsum %0, %1, config = "ba,dbc->cad" : (tensor<3x2xf32>, tensor<5x4x3xf32>) -> tensor<4x2x5xf32>
// Result transpose: 4x2x5 -> 5x4x2.
%3 = stablehlo.transpose %2, dims = [2, 0, 1] : (tensor<4x2x5xf32>) -> tensor<5x4x2xf32>
func.return %3 : tensor<5x4x2xf32>
}
}

// CHECK: func.func @einsum(%arg0: tensor<2x3xf32>, %arg1: tensor<4x3x5xf32>) -> tensor<5x4x2xf32> {
// CHECK-NEXT: %0 = stablehlo.einsum %arg0, %arg1, config = "ab,cdb->dca" : (tensor<2x3xf32>, tensor<4x3x5xf32>) -> tensor<5x4x2xf32>
// CHECK-NEXT: return %0 : tensor<5x4x2xf32>
// CHECK-NEXT: }

0 comments on commit eb57fe4

Please sign in to comment.