diff --git a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp
index 8aa3b4af..9001be49 100644
--- a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp
+++ b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp
@@ -46,6 +46,101 @@ template <typename T> Attribute makeAttr(mlir::Type elemType, T val) {
 
 namespace {
 
+// Maps result dimensions of a reshape to the operand dimensions they are
+// derived from. "Left" is a result dim, "right" an operand dim; a dimension
+// that is split or merged appears in several pairs.
+class ReshapeDimMapping {
+public:
+  void addMapping(int64_t left, int64_t right) {
+    mapping.push_back(std::make_pair(left, right));
+  }
+
+  SmallVector<int64_t> getMappingFromResultDim(int64_t dim) const {
+    SmallVector<int64_t> result;
+    for (auto &[left, right] : mapping) {
+      if (left == dim)
+        result.push_back(right);
+    }
+    return result;
+  }
+
+  SmallVector<int64_t> getMappingFromOperandDim(int64_t dim) const {
+    SmallVector<int64_t> result;
+    for (auto &[left, right] : mapping) {
+      if (right == dim)
+        result.push_back(left);
+    }
+    return result;
+  }
+
+  // True iff no result dim is mapped twice, i.e. the reshape never merges
+  // operand dimensions.
+  bool isOnlySplitting() const {
+    llvm::SmallDenseSet<int64_t> keys;
+    for (auto &[left, right] : mapping) {
+      if (!std::get<1>(keys.insert(left)))
+        return false;
+    }
+    return true;
+  }
+
+  void dump() const {
+    for (auto &[left, right] : mapping) {
+      llvm::outs() << left << " -> " << right << "\n";
+    }
+  }
+
+private:
+  // Left is result dim, right is operand dim.
+  SmallVector<std::pair<int64_t, int64_t>> mapping;
+};
+
+// Analyze if a reshape is clearly merging or splitting dimensions.
+std::optional<ReshapeDimMapping>
+tryFindReshapeDimMapping(stablehlo::ReshapeOp op) {
+  ReshapeDimMapping mapping;
+  int64_t lhsPos = 0; // position in the result shape
+  int64_t rhsPos = 0; // position in the operand shape
+  auto rhsShape = op.getOperand().getType().cast<RankedTensorType>().getShape();
+  auto lhsShape = op.getResult().getType().cast<RankedTensorType>().getShape();
+  while (lhsPos < (int64_t)lhsShape.size() && rhsPos < (int64_t)rhsShape.size()) {
+    if (lhsShape[lhsPos] == rhsShape[rhsPos]) {
+      // Nice 1-to-1 mapping.
+      mapping.addMapping(lhsPos, rhsPos);
+    } else if (lhsShape[lhsPos] < rhsShape[rhsPos]) {
+      // Potential many-to-one mapping: operand dim split across result dims.
+      int64_t product = lhsShape[lhsPos];
+      mapping.addMapping(lhsPos, rhsPos);
+      while (product < rhsShape[rhsPos]) {
+        if (++lhsPos >= (int64_t)lhsShape.size())
+          break;
+        product *= lhsShape[lhsPos];
+        mapping.addMapping(lhsPos, rhsPos);
+      }
+      if (product != rhsShape[rhsPos])
+        return std::nullopt;
+    } else {
+      // Potential one-to-many mapping: operand dims merged into a result dim.
+      assert(lhsShape[lhsPos] > rhsShape[rhsPos]);
+      int64_t product = rhsShape[rhsPos];
+      mapping.addMapping(lhsPos, rhsPos);
+      while (product < lhsShape[lhsPos]) {
+        if (++rhsPos >= (int64_t)rhsShape.size())
+          break;
+        product *= rhsShape[rhsPos];
+        mapping.addMapping(lhsPos, rhsPos);
+      }
+      // Compare against the result dim being assembled; indexing with rhsPos
+      // here would read the wrong entry (and can run past the end of lhsShape).
+      if (product != lhsShape[lhsPos])
+        return std::nullopt;
+    }
+    ++lhsPos;
+    ++rhsPos;
+  }
+  return mapping;
+}
+
 struct NoopSlice final : OpRewritePattern<stablehlo::SliceOp> {
   using OpRewritePattern::OpRewritePattern;
 
@@ -3025,6 +3120,74 @@ struct PadDotGeneral : public OpRewritePattern<stablehlo::DotGeneralOp> {
   }
 };
 
+// Rewrites slice(reshape(x)) as reshape(slice(x)) when the reshape only
+// splits dimensions and none of the split dimensions is actually sliced,
+// so the slice can be performed on the lower-rank operand.
+struct ReshapeToSlice : public OpRewritePattern<stablehlo::SliceOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(stablehlo::SliceOp op,
+                                PatternRewriter &rewriter) const final {
+    auto reshape = op.getOperand().getDefiningOp<stablehlo::ReshapeOp>();
+    if (!reshape) {
+      return rewriter.notifyMatchFailure(op, "defining op is not a reshape");
+    }
+    std::optional<ReshapeDimMapping> mapping =
+        tryFindReshapeDimMapping(reshape);
+    if (!mapping) {
+      return rewriter.notifyMatchFailure(
+          reshape, "reshape is not clearly merging or splitting dimensions");
+    }
+    if (!mapping->isOnlySplitting()) {
+      // TODO: it may still be possible to handle this depending on the slice
+      // configuration.
+      return rewriter.notifyMatchFailure(reshape,
+                                         "reshape is merging dimensions");
+    }
+
+    auto sliceOperandType = op.getOperand().getType().cast<RankedTensorType>();
+    SmallVector<bool> notSlicedDims;
+    notSlicedDims.reserve(sliceOperandType.getRank());
+    for (auto [start, limit, stride, dim] :
+         llvm::zip(op.getStartIndices(), op.getLimitIndices(), op.getStrides(),
+                   sliceOperandType.getShape())) {
+      notSlicedDims.push_back(start == 0 && limit == dim && stride == 1);
+    }
+
+    auto reshapeOperandType = reshape.getOperand().getType().cast<RankedTensorType>();
+    SmallVector<int64_t> starts, limits, strides;
+    for (auto [i, dim] : llvm::enumerate(reshapeOperandType.getShape())) {
+      SmallVector<int64_t> resultDims = mapping->getMappingFromOperandDim(i);
+      if (llvm::hasSingleElement(resultDims)) {
+        // Keep existing.
+        starts.push_back(op.getStartIndices()[resultDims[0]]);
+        limits.push_back(op.getLimitIndices()[resultDims[0]]);
+        strides.push_back(op.getStrides()[resultDims[0]]);
+        continue;
+      }
+
+      if (!llvm::all_of(resultDims,
+                        [&](int64_t dim) { return notSlicedDims[dim]; })) {
+        return rewriter.notifyMatchFailure(reshape,
+                                           "split dimension is also sliced");
+      }
+
+      // It's a full slice of the original dimension.
+      starts.push_back(0);
+      limits.push_back(reshapeOperandType.getDimSize(i));
+      strides.push_back(1);
+    }
+
+    auto newSlice = rewriter.create<stablehlo::SliceOp>(
+        op->getLoc(), reshape.getOperand(), starts, limits, strides);
+    auto newReshape = rewriter.create<stablehlo::ReshapeOp>(
+        reshape->getLoc(), op.getResult().getType(), newSlice.getResult());
+    rewriter.replaceOp(op, newReshape);
+
+    return success();
+  }
+};
+
 struct EnzymeHLOOptPass : public EnzymeHLOOptPassBase<EnzymeHLOOptPass> {
 
   void runOnOperation() override {
@@ -3052,7 +3215,8 @@ struct EnzymeHLOOptPass : public EnzymeHLOOptPassBase<EnzymeHLOOptPass> {
              BinBroadcastSplat<stablehlo::AddOp>,
              BinBroadcastSplat<stablehlo::SubtractOp>,
              BinBroadcastSplat<stablehlo::MulOp>, TransposeTranspose,
-             TransposeConvert, BroadcastReduce, PadDotGeneral>(context);
+             TransposeConvert, BroadcastReduce, PadDotGeneral, ReshapeToSlice>(
+        context);
     patterns.add<IotaSimplify, BroadcastInDimSimplify>(max_constant_expansion,
                                                        context);
     if (all_finite)
diff --git a/test/lit_tests/reshapeslice.mlir b/test/lit_tests/reshapeslice.mlir
new file mode 100644
index 00000000..23223da0
--- /dev/null
+++ b/test/lit_tests/reshapeslice.mlir
@@ -0,0 +1,14 @@
+// RUN: enzymexlamlir-opt --enzyme-hlo-opt %s | FileCheck %s
+
+// CHECK-LABEL: @reshape_slice
+// CHECK-SAME: %[[ARG0:.+]]: tensor<1x3x2048x4xbf16>
+// CHECK: %[[S0:.+]] = stablehlo.slice %[[ARG0]] [0:1, 1:3, 0:1024, 0:4]
+// CHECK: stablehlo.reshape %[[S0]] : (tensor<1x2x1024x4xbf16>) -> tensor<1x2x1024x1x4xbf16>
+// CHECK: %[[S1:.+]] = stablehlo.slice %arg0 [0:1, 2:3, 1024:2048, 0:4]
+// CHECK: stablehlo.reshape %[[S1]] : (tensor<1x1x1024x4xbf16>) -> tensor<1x1x1024x1x4xbf16>
+func.func @reshape_slice(%7: tensor<1x3x2048x4xbf16>) -> (tensor<1x2x1024x1x4xbf16>, tensor<1x1x1024x1x4xbf16>) {
+  %8 = stablehlo.reshape %7 : (tensor<1x3x2048x4xbf16>) -> tensor<1x3x2048x1x4xbf16>
+  %9 = stablehlo.slice %8 [0:1, 1:3, 0:1024, 0:1, 0:4] : (tensor<1x3x2048x1x4xbf16>) -> tensor<1x2x1024x1x4xbf16>
+  %10 = stablehlo.slice %8 [0:1, 2:3, 1024:2048, 0:1, 0:4] : (tensor<1x3x2048x1x4xbf16>) -> tensor<1x1x1024x1x4xbf16>
+  return %9, %10 : tensor<1x2x1024x1x4xbf16>, tensor<1x1x1024x1x4xbf16>
+}