Skip to content

Commit

Permalink
Fix RMSNorm symbolic schedule (PaddlePaddle#61083)
Browse files Browse the repository at this point in the history
* [Dynamic Shape] Split ReduceMean to ReduceSum+Scale

* fix int32 bug

* fix kernel output args

* fix symbolic schedule on rms_norm

* cinn(dim_expr): adapt static pir to dim expr api

---------

Co-authored-by: jiahongyu <jiahongyu@baidu.com>
Co-authored-by: 6clc <chaoliu.lc@foxmail.com>
  • Loading branch information
3 people committed Jan 26, 2024
1 parent 1532abc commit 9c0bf71
Show file tree
Hide file tree
Showing 11 changed files with 66 additions and 21 deletions.
3 changes: 2 additions & 1 deletion paddle/cinn/ast_gen_ius/ast_gen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include "paddle/cinn/optim/replace_var_with_expr.h"

PD_DECLARE_bool(cinn_new_group_scheduler);
PD_DECLARE_bool(cinn_bucket_compile);

namespace cinn {
namespace ast_gen_ius {
Expand Down Expand Up @@ -184,7 +185,7 @@ ir::Expr AstGen::Build(const ir::Tensor& tensor, TensorGroup* tensor_group) {
// Put the two parts together
ir::Expr body = ir::Block::Make({init_body, reduce_body});
for (int i = static_cast<int>(axis_len) - 1; i >= 0; --i) {
if (shape[i] == Expr(1)) {
if (!FLAGS_cinn_bucket_compile && shape[i] == Expr(1)) {
continue;
}
ir::Var loop_var = axis[i];
Expand Down
4 changes: 2 additions & 2 deletions paddle/cinn/common/cas.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,8 @@ struct CasInterval {
if (expr_l.is_constant() && expr_r.is_constant()) {
CHECK(expr_l->type().is_integer());
CHECK(expr_r->type().is_integer());
l = expr_l.as_int32();
r = expr_r.as_int32();
l = expr_l.as_int64();
r = expr_r.as_int64();
return;
}
e_l = expr_l;
Expand Down
34 changes: 32 additions & 2 deletions paddle/cinn/hlir/framework/pir/op_lowering_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ BucketLoweredFuncsWrapper OpLowererImpl::BucketLower(const GroupPtr& group,
std::vector<ir::Argument> group_func_args;
std::vector<ir::LoweredFunc> funcs = PostProcess(group,
tensor_map,
apply_op_schedule,
apply_group_schedule,
{scheduled_func_bodies},
&group_func_arg_tensors_copy,
&group_func_args);
Expand Down Expand Up @@ -737,6 +737,31 @@ ir::Tensor OpLowererImpl::GetTensor(const GroupPtr& group,
}
}

// Builds a placeholder tensor for `value`. When the group carries symbolic
// shape information, each axis extent comes from the recorded DimExpr;
// otherwise we fall back to the static dims stored on the tensor type and
// wrap each one in a DimExpr. The placeholder is named after the value so
// later passes can match it back to its ::pir::Value.
ir::Tensor OpLowererImpl::GetTensorSymbolic(const GroupPtr& group,
                                            const ::pir::Value& value) {
  const auto tensor_type =
      value.type().dyn_cast<paddle::dialect::DenseTensorType>();
  const auto dtype = tensor_type.dtype();
  const std::string tensor_name = ValueName(value);

  // Collect one ir::Dim per axis, preferring symbolic shape expressions.
  std::vector<ir::Dim> dims;
  if (!group->value_to_shape_or_data_exprs.empty()) {
    for (const auto& dim_expr : group->GetShapeOrDataExprs(value).shape()) {
      dims.emplace_back(tensor_name, dim_expr);
    }
  } else {
    // No symbolic info recorded: lift the static int64 extents to DimExprs.
    for (int64_t extent : ::common::vectorize<int64_t>(tensor_type.dims())) {
      dims.emplace_back(tensor_name, ::symbol::DimExpr{extent});
    }
  }
  return lang::CreatePlaceHolder(
      dims, CompatibleInfo::ConvertIRType(dtype), tensor_name);
}

std::vector<ir::Tensor> OpLowererImpl::CollectInputTensor(
const GroupPtr& group,
const ::pir::Operation* op,
Expand All @@ -745,7 +770,12 @@ std::vector<ir::Tensor> OpLowererImpl::CollectInputTensor(
std::vector<ir::Tensor> tensors;
for (auto in_value : CompatibleInfo::RealOperandSources(*op)) {
VLOG(4) << "input tensor name: " << ValueName(in_value);
ir::Tensor tensor = GetTensor(group, in_value);
ir::Tensor tensor;
if (FLAGS_cinn_bucket_compile) {
tensor = GetTensorSymbolic(group, in_value);
} else {
tensor = GetTensor(group, in_value);
}
VLOG(4) << "shape: " << tensor->shape;
VLOG(4) << "sym_shape: " << tensor->sym_shape;

Expand Down
2 changes: 2 additions & 0 deletions paddle/cinn/hlir/framework/pir/op_lowering_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,8 @@ class OpLowererImpl : public OpLowererImplBase<GroupPtr> {
std::unordered_map<::pir::Value, ir::Tensor>* tensor_map);

ir::Tensor GetTensor(const GroupPtr& group, const ::pir::Value& value);
ir::Tensor GetTensorSymbolic(const GroupPtr& group,
const ::pir::Value& value);

void CollectOutputInfo(::pir::Operation* op,
std::vector<Type>* out_types,
Expand Down
17 changes: 13 additions & 4 deletions paddle/cinn/ir/group_schedule/dy_shape_group_scheduler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@ namespace cinn {
namespace ir {

void DynamicShapeGroupScheduler::Init() {
VLOG(4) << "=============================Start group "
"schedule==============================";
VLOG(4) << "original group func body: \n"
<< ir_sch_->GetModule().GetExprs()[0];
InitBuckets();
tactics_.emplace_back(new AlignIterSpaceTactic());
tactics_.emplace_back(new ComputeInlineTactic());
Expand Down Expand Up @@ -154,6 +158,7 @@ DynamicShapeGroupScheduler::GetIRs() {

IterativeSpaceInfo DynamicShapeGroupScheduler::ConstructIterSpaceInfo(
ScheduleBlockNode* node) {
VLOG(5) << "global master: " << node->id();
IterativeSpaceInfo info;
std::vector<int> sp_iter_indices;
std::vector<int> rb_iter_indices;
Expand Down Expand Up @@ -199,12 +204,16 @@ IterativeSpaceInfo DynamicShapeGroupScheduler::ConstructIterSpaceInfo(
CHECK_NOTNULL(index.as_var());
ir::Var iter_var = index.as_var_ref();
ir::Expr iter_value = iter_var2value.at(iter_var);
CHECK_NOTNULL(iter_value.as_var());
CHECK(iter_value.as_var() || iter_value.is_constant());
ir::For* for_node;
for (ir::Expr& loop : loops) {
if (loop.As<ir::For>()->loop_var == iter_value.as_var_ref()) {
for_node = loop.As<ir::For>();
if (iter_value.as_var()) {
for (ir::Expr& loop : loops) {
if (loop.As<ir::For>()->loop_var == iter_value.as_var_ref()) {
for_node = loop.As<ir::For>();
}
}
} else if (iter_value.is_constant()) {
for_node = loops.at(loop_idx).As<ir::For>();
}
CHECK_NOTNULL(for_node);
bool is_reduce_iter_var = reduce_iter_vars.count(iter_var) > 0;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,8 @@ void ArrangeStorageTactic::Apply(ir::IRSchedule* sch,
LOG(FATAL) << "Fusion requires synchronization across blocks, but "
"currently we do not support it.";
break;
} else {
LOG(FATAL) << "dead code";
}
}

Expand Down
3 changes: 0 additions & 3 deletions paddle/cinn/ir/group_schedule/tactic/compute_inline_tactic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,6 @@ void ComputeInlineTactic::Init(ScheduleContext* context) {

void ComputeInlineTactic::Apply(ir::IRSchedule* sch,
const std::string& block_id) {
VLOG(5) << "[Start DoComputeInline] func body: "
<< sch->GetModule().GetExprs().front();

// TODO(LiuYang): Compute of ops will be rewrited so that we
// don't use it in dynamic group_schedule rules temporarily.
// if (IsProhibitScheduleExternCallBlock(node->Block())) {
Expand Down
3 changes: 2 additions & 1 deletion paddle/cinn/ir/ir_base.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include "paddle/cinn/ir/ir_visitor.h"
#include "paddle/cinn/ir/module.h"
#include "paddle/cinn/ir/tensor.h"
#include "paddle/cinn/utils/error.h"

namespace cinn {
namespace ir {
Expand Down Expand Up @@ -115,7 +116,7 @@ int16_t Expr::as_int16() const {
return As<IntImm>()->value;
}
int32_t Expr::as_int32() const {
CHECK(type().is_int(32));
CHECK(type().is_int(32)) << utils::enforce::GetCurrentTraceBackString();
return As<IntImm>()->value;
}
int64_t Expr::as_int64() const {
Expand Down
13 changes: 7 additions & 6 deletions paddle/cinn/ir/schedule/impl/compute_location.cc
Original file line number Diff line number Diff line change
Expand Up @@ -152,13 +152,14 @@ void DyScheduleImpl::SimpleComputeAt(const Expr& block, const Expr& loop) {
// collect if
auto if_checker = [](const Expr* x) { return x->As<ir::IfThenElse>(); };
auto if_set = ir::ir_utils::CollectIRNodesWithoutTensor(body, if_checker);
auto checker = [block_name](const Expr* x) {
return x->As<ir::ScheduleBlockRealize>() &&
x->As<ir::ScheduleBlockRealize>()
->schedule_block.As<ScheduleBlock>()
->name == block_name;
};
for (auto if_expr : if_set) {
auto checker = [block_name](const Expr* x) {
return x->As<ir::ScheduleBlockRealize>() &&
x->As<ir::ScheduleBlockRealize>()
->schedule_block.As<ScheduleBlock>()
->name == block_name;
};
if (Contains(result, if_expr)) continue;
if (ir::ir_utils::CollectIRNodesWithoutTensor(if_expr, checker, true)
.size() > 0) {
result =
Expand Down
4 changes: 3 additions & 1 deletion paddle/cinn/ir/tensor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -557,7 +557,9 @@ std::vector<Expr> _Tensor_::domain_with_reduce_axis() const {
if (reduce_axis.empty()) return domain;
auto res = domain;
for (const Var &axis : reduce_axis) {
CHECK(axis->upper_bound.type().is_int(32)) << axis->upper_bound;
CHECK(axis->upper_bound.type().is_int(32) ||
axis->upper_bound.type().is_int(64))
<< axis->upper_bound;
res.push_back(axis->upper_bound);
}
return res;
Expand Down
2 changes: 1 addition & 1 deletion test/ir/pir/cinn/test_rms_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def __init__(self):
self.variance_epsilon = 1e-6

def forward(self, hidden_states):
variance = hidden_states.pow(2).mean(-1, keepdim=True)
variance = hidden_states.pow(2).sum(-1, keepdim=True) / 768
hidden_states = (
paddle.rsqrt(variance + self.variance_epsilon) * hidden_states
)
Expand Down

0 comments on commit 9c0bf71

Please sign in to comment.