Skip to content

Commit

Permalink
run through inplace KV cache compilation pass pipeline in torch_blade (
Browse files Browse the repository at this point in the history
…alibaba#1225)

run throigh kv cache compilation with re-inplace pass and dynamic_update_slice codegen
  • Loading branch information
Yancey1989 committed Jul 26, 2023
1 parent fcae1c4 commit 738fa33
Show file tree
Hide file tree
Showing 13 changed files with 267 additions and 62 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -173,15 +173,11 @@ const std::unordered_set<std::string> &GetTorchMlirWhiteList() {

std::call_once(white, [&]() {
auto list = StrSplit(env::ReadStringFromEnvVar("TORCH_MHLO_OP_WHITE_LIST", ""), ';');
for (auto s : list) {
white_list.insert(s);
}
for (auto s : list) white_list.insert(s);
});
std::call_once(black, [&]() {
auto list = StrSplit(env::ReadStringFromEnvVar("TORCH_MHLO_OP_BLACK_LIST", ""), ';');
for (auto s : list) {
white_list.erase(s);
}
for (auto s : list) white_list.erase(s);
});

std::ostringstream ostr;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -161,8 +161,7 @@ at::List<at::Tensor> RalContext::PreProcessInputs(
at::List<at::Tensor> contiguous_inputs;
for (at::Tensor inp_tensor : inputs) {
// make sure the input is in contiguous layout
auto contiguous_tensor = inp_tensor.contiguous();
contiguous_inputs.push_back(contiguous_tensor);
contiguous_inputs.push_back(inp_tensor.contiguous());
}
return contiguous_inputs;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1466,6 +1466,29 @@ LogicalResult ConvertAtenOp<AtenConstantPadNdOp>::matchAndRewrite(
} // namespace

namespace {

Value getNormalizedDimSizeInternal(
PatternRewriter& rewriter,
Operation* op,
Value index,
Value dimSize) {
auto loc = op->getLoc();
Value zero = rewriter.create<arith::ConstantOp>(
loc, rewriter.getIntegerAttr(rewriter.getI32Type(), 0));

// To normalize index into range [-dimSize, dimSize]
// index = min(max(-dimSize, index), dimSize)
auto negDimSize = rewriter.create<arith::SubIOp>(loc, zero, dimSize);
index = rewriter.create<arith::MaxSIOp>(loc, negDimSize, index);
index = rewriter.create<arith::MinSIOp>(loc, dimSize, index);

auto dimSizePlusIndex = rewriter.create<arith::AddIOp>(loc, dimSize, index);
auto indexPositive = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::sge, index, zero);
// get positive index: (index >=0) ? index: index + dimSize
return rewriter.create<arith::SelectOp>(
loc, indexPositive, index, dimSizePlusIndex);
}
template <>
LogicalResult ConvertAtenOp<AtenSliceScatterOp>::matchAndRewrite(
AtenSliceScatterOp op,
Expand All @@ -1479,23 +1502,27 @@ LogicalResult ConvertAtenOp<AtenSliceScatterOp>::matchAndRewrite(
rewriter, op, adaptor.getSelf(), kMhloDimSizeBits);
auto inputShape = selfTy.getShape();

int64_t start, end, dim;
if (!matchPattern(op.getStart(), m_TorchConstantInt(&start)))
return rewriter.notifyMatchFailure(
op, "only constant start is currently supported");
if (!matchPattern(op.getEnd(), m_TorchConstantInt(&end)))
return rewriter.notifyMatchFailure(
op, "only constant end is currently supported");
int64_t dim;
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
return rewriter.notifyMatchFailure(
op, "only constant dim is currently supported");

Value dimSize = rewriter.create<arith::IndexCastOp>(
op.getLoc(),
rewriter.getI32Type(),
rewriter.create<tensor::DimOp>(op.getLoc(), self, dim));
auto startIndexValue = rewriter.create<arith::TruncIOp>(
op.getLoc(), rewriter.getI32Type(), adaptor.getStart());

SmallVector<Value> start_indices(selfTy.getRank());
for (auto i = 0; i < inputShape.size(); ++i) {
if (i == dim) {
start = toPositiveDim(start, inputShape[i]);
start_indices[i] = rewriter.create<mhlo::ConstantOp>(
auto normalizedStartIndex =
getNormalizedDimSizeInternal(rewriter, op, startIndexValue, dimSize);
start_indices[i] = rewriter.create<tensor::FromElementsOp>(
op.getLoc(),
rewriter.getIntegerAttr(rewriter.getIntegerType(32), start));
RankedTensorType::get({}, rewriter.getI32Type()),
normalizedStartIndex);
} else {
start_indices[i] = rewriter.create<mhlo::ConstantOp>(
op.getLoc(), rewriter.getIntegerAttr(rewriter.getIntegerType(32), 0));
Expand Down Expand Up @@ -1531,6 +1558,9 @@ LogicalResult ConvertAtenOp<OverwriteTensorContentsOp>::matchAndRewrite(
auto operands = op.getOperands();
auto value = operands[0];
auto overwriten = operands[1];
// torch-mlir lowering aten::copy to broadcast_to operator, in this case,
// we can skip this operator and use the input directly.
value = backtraceOperand<AtenBroadcastToOp>(value);
overwriten = backtraceOperand<CopyToNonValueTensorOp>(overwriten);
overwriten = rewriter.create<ToBuiltinTensorOp>(loc, overwriten);
value = rewriter.create<ToBuiltinTensorOp>(loc, value);
Expand Down
39 changes: 23 additions & 16 deletions pytorch_blade/tests/disc/ops/test_input_mutation.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,35 +11,42 @@

import os
import torch
import torch.nn as nn
from typing import List, Optional, Tuple
from torch import Tensor
import torch_blade
import unittest
from tests.disc.testing_base import skipTorchLE
import torch_blade.clustering.support_fusion_group as fusion
from tests.disc.testing_base import DiscTestCase
from tests.disc.testing_base import skipTorchLE, DiscTestCase

class KVCacheModule(nn.Module):
def forward(self, k_cache: Tensor, k: Tensor, step : Tensor):
k_cache[..., step - k.shape[-2]: step , :].add_(k)
return k_cache

class TestInputMutation(DiscTestCase):
def setUp(self):
super().setUp()
os.environ["TORCH_MHLO_OP_WHITE_LIST"] = "aten::copy_;aten::add;aten::slice_scatter;"
os.environ["TORCH_BLADE_EXPERIMENTAL_MERGE_HORIZONTAL_GROUPS"] = "true"

def tearDown(self):
del os.environ["TORCH_MHLO_OP_WHITE_LIST"]

@skipTorchLE("2.0.0")
def test_inplace_kv_cache(self):
def func(k_cache: Tensor, k: Tensor) -> Tensor:
k_cache[...,k.shape[-2] :, :].add_(k)
return k_cache

with fusion.min_group_nodes(1):
opt_func = torch.compile(backend='aot_disc')(func)
add = torch.zeros(2, 8, 32, 2, device=self.device)
value = torch.ones(2, 8, 1, 2, device=self.device)
actual = opt_func(add.clone(), value.clone())
expect = func(add.clone(), value.clone())
self.assertTrue(torch.allclose(actual.cpu(), expect.cpu()))
del os.environ["TORCH_BLADE_EXPERIMENTAL_MERGE_HORIZONTAL_GROUPS"]

@skipTorchLE("1.10.0")
def test_inplace_kv(self):
k_cache = torch.zeros(2, 32, 8, device=self.device)
k = torch.ones(2, 1, 8, device=self.device)

m = KVCacheModule()
m.train(False)
step = torch.tensor(1)
opt_func = torch_blade.optimize(m, allow_tracing=True, model_inputs=(k_cache.clone(), k.clone(), step))
expect = m(k_cache.clone(), k.clone(), step)
actual = opt_func(k_cache.clone(), k.clone(), step)
self.assertTrue(torch.allclose(expect.cpu(), actual.cpu()))


if __name__ == "__main__":
unittest.main()
105 changes: 103 additions & 2 deletions pytorch_blade/torch_blade/pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,6 @@ def _optimize_common(c_module):
logger.error("If do quantization, the model must in eval mode ")
if cfg.enable_int8:
_jit_pass_quantization_preprocess(c_module)

if not is_training:
# optimization passes only work in eval mode
if cfg.freeze_module:
Expand All @@ -382,6 +381,8 @@ def _optimize_common(c_module):
# we can't do the type promotion correctly.
_jit_pass_replace_inplace_name(graph)
torch._C._jit_pass_remove_mutation(graph)
_jit_pass_reinplace(graph)


# TODO: if dynamic rank exists, this pass maybe leads to error
if IGNORE_DYNAMIC_RANK:
Expand Down Expand Up @@ -419,12 +420,112 @@ def _replace_inplace_name(graph):
new_op.addInput(inp)
graph.appendNode(new_op)
new_op.moveBefore(node)
new_op.output().setType(value.type())
new_op.output().setType(node.output().type())
value.replaceAllUsesWith(new_op.output())
node.destroy()

_replace_inplace_name(graph)


def _jit_pass_reinplace(graph):
"""
before:
%slice = slice(%arg0, %start, %end, %step)
%slice.1 = slice(%slice, %start.1, %end.1, %step)
%output = add_(%slice.1, %arg1)
after:
%slice = slice(%arg0, %start, %end, %step)
%slice.1 = slice(%slice, %start.1, %end.1, %step)
%output = add(%slice.1, %arg1)
%slice.2 = slice(%arg0, %start, %end, %step)
%slice_scatter.1 = slice_scatter(%slice.2, %output, %start, %end, %step)
%slice_scatter.2 = slice_scatter(%slice_scatter.1, %output, %start.1, %end.1, %step)
%copy = copy_(%arg0, %slice_scatter.2, %false)
"""
def _collect_all_inplace_nodes(block):
all_nodes = []
for node in block.nodes():
if "inplace" in node.kind():
inp_value = next(node.inputs())
if inp_value.node().kind() == "aten::slice":
all_nodes.append(node)
for node in block.nodes():
for inner_blk in node.blocks():
all_nodes += _collect_all_inplace_nodes(inner_blk)
return all_nodes

inplace_nodes = _collect_all_inplace_nodes(graph)
cst_false = graph.create("prim::Constant")
cst_false.i_("value", 0)
cst_false.output().setType(torch._C.BoolType.get())
graph.appendNode(cst_false)
for node in inplace_nodes:
# create a outplace op
new_op = graph.create(node.kind().rstrip("_inplace_"))
value = node.output()
for inp in node.inputs():
new_op.addInput(inp)
graph.appendNode(new_op)
new_op.moveBefore(node)
new_op.output().setType(value.type())
value.replaceAllUsesWith(new_op.output())

# find the first slice op
slice_ops = []
prv_node = None
inp_v = next(node.inputs())
cur_node = inp_v.node()
first_slice = cur_node
while cur_node.kind() == "aten::slice":
slice_ops.append(cur_node)
first_slice = cur_node
cur_node = next(cur_node.inputs()).node()

# create aten::slice
slice_0 = graph.create("aten::slice")
for inp in first_slice.inputs():
slice_0.addInput(inp)
slice_0.output().setType(first_slice.output().type())
graph.appendNode(slice_0)
slice_0.moveAfter(new_op)

# create aten::slice_scatter
inp1_v = slice_0.output()
inp2_v = new_op.output()
prev_node = slice_0
for op_idx, op in enumerate(slice_ops):
slice_scatter = graph.create("aten::slice_scatter")
input_list = [v for v in op.inputs()]
if op_idx == 0:
slice_scatter.addInput(slice_0.output())
slice_scatter.addInput(new_op.output())
else:
slice_scatter.addInput(input_list[0])
slice_scatter.addInput(prev_node.output())
for idx, inp in enumerate(op.inputs()):
if idx < 1:
continue
slice_scatter.addInput(inp)
slice_scatter.output().setType(next(slice_scatter.inputs()).type())
out = slice_scatter.output()
graph.appendNode(slice_scatter)
slice_scatter.moveAfter(prev_node)
prev_node = slice_scatter

# create aten::copy_
copy_op = graph.create("aten::copy_")
copy_op.addInput(list(slice_scatter.inputs())[0])
copy_op.addInput(slice_scatter.output())
copy_op.addInput(cst_false.output())
copy_op.output().setType(list(slice_scatter.inputs())[0].type())
graph.appendNode(copy_op)
if list(graph.return_node().inputs())[0].node().kind() == "prim::TupleConstruct":
copy_op.moveBefore(list(graph.return_node().inputs())[0].node())
list(copy_op.inputs())[0].replaceAllUsesAfterNodeWith(copy_op, slice_scatter.output())
node.destroy()

def _jit_pass_hack_cpu_device(graph):
cfg = Config.get_current_context_or_new()
if not cfg.enable_force_to_cuda:
Expand Down
4 changes: 4 additions & 0 deletions tao_compiler/mlir/disc/transforms/codegen_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,10 @@ Value emitNumElementsComputation(OpBuilder& b, Location loc, Operation* op) {
// only const rank is supported for now
assert(op->getDialect()->getNamespace() == "lmhlo");
int num_operands = op->getNumOperands();
if (isa<lmhlo::DynamicUpdateSliceOp>(op) &&
op->getOperand(0) == op->getOperand(num_operands - 1)) {
return emitNumElementsComputation(b, loc, op->getOperand(1));
}
Value result_memref = op->getOperand(num_operands - 1);
return emitNumElementsComputation(b, loc, result_memref);
}
Expand Down
12 changes: 11 additions & 1 deletion tao_compiler/mlir/disc/transforms/disc_lhlo_rewriter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,15 @@ namespace {
template <typename T>
using BaseOpConversion = OpConversionPattern<T>;

template <typename T>
Value backtraceOperand(Value operand) {
auto op = operand.getDefiningOp();
if (op && mlir::isa<T>(op)) {
return op->getOperand(0);
}
return operand;
}

struct LhloArgsMutationOpRewriter
: public OpRewritePattern<lmhlo_disc::ArgsMutationOp> {
explicit LhloArgsMutationOpRewriter(MLIRContext* context)
Expand All @@ -65,7 +74,8 @@ struct LhloArgsMutationOpRewriter
PatternRewriter& rewriter) const override {
auto op = lhloOp.getOperation();
auto operands = op->getOperands();
operands[0].replaceAllUsesWith(operands[1]);
Value value = backtraceOperand<memref::ReinterpretCastOp>(operands[0]);
value.replaceAllUsesWith(operands[1]);
rewriter.eraseOp(op);
return success();
}
Expand Down
12 changes: 10 additions & 2 deletions tao_compiler/mlir/disc/transforms/fusion_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ limitations under the License.
#include "mlir/disc/transforms/disc_shape_optimization_utils.h"
#include "mlir/disc/transforms/lhlo_elemental_utils.h"
#include "mlir/disc/transforms/placement_utils.h"
#include "tensorflow/tsl/platform/str_util.h"
#include "utils/placement_utils.h"

// This file implements some helper functions and classes used to do fusion
Expand Down Expand Up @@ -554,7 +555,8 @@ bool isFusible(Operation* op) {
lmhlo::ReverseOp,
lmhlo::SelectOp,
lmhlo::SliceOp,
lmhlo::TransposeOp
lmhlo::TransposeOp,
lmhlo::DynamicUpdateSliceOp
>(op);
// clang-format on
}
Expand Down Expand Up @@ -1481,7 +1483,13 @@ Value BaseGpuFusionStrategy::getEffectiveShape(FusionPattern& target, Value v) {
Operation* result_op = target.findLastWriter(v);
assert(result_op);
// effective shape of reduce op is its operand's shape.
return isa<lmhlo::ReduceOp>(result_op) ? result_op->getOperand(0) : v;
if (isa<lmhlo::ReduceOp>(result_op)) {
return result_op->getOperand(0);
} else if (isa<lmhlo::DynamicUpdateSliceOp>(result_op) &&
isInplaceOperator(result_op)) {
return result_op->getOperand(1);
}
return v;
}

bool BaseGpuFusionStrategy::tryFuse(ShapeAnalysis& shapeAnalysis,
Expand Down
2 changes: 2 additions & 0 deletions tao_compiler/mlir/disc/transforms/fusion_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@ DenseSet<Operation*> getValueUsersInFusionLike(Value memref, Operation* op);

bool isOnGpu(Operation* op);

bool isInplaceOperator(Operation* op);

// Attributes used to annotate the fusion type, fusion name and tags.
constexpr const char* kDiscFusionTypeAttrName = "disc.fusion_type";
constexpr StringRef kFusionOpNameAttr = "disc.fusion.name";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -288,8 +288,8 @@ LogicalResult InputInlineFusionPattern::inlineFuseLhloOp(
lower_config) ||
miscFuseHelper<SliceOp>(b, user, producer, load_op, load_ops,
lower_config) ||
// miscFuseHelper<DynamicUpdateSliceOp>(b, user, producer, load_op,
// load_ops, lower_config) ||
miscFuseHelper<DynamicUpdateSliceOp>(b, user, producer, load_op, load_ops,
lower_config) ||
miscFuseHelper<TransposeOp>(b, user, producer, load_op, load_ops,
lower_config)) {
return success();
Expand Down
Loading

0 comments on commit 738fa33

Please sign in to comment.