Skip to content

Commit

Permalink
[DimExpr] SymbolicDimOp->symbol::DimExpr (#60734)
Browse files Browse the repository at this point in the history
* [DimExpr] SymbolicDimOp->symbol::DimExpr

* Change interface
  • Loading branch information
jiahy0825 committed Jan 16, 2024
1 parent 5599dbd commit d54c834
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 39 deletions.
73 changes: 45 additions & 28 deletions paddle/cinn/adt/graph_symbolic_dim_infer_ctx.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include "paddle/cinn/adt/print.h"
#include "paddle/cinn/adt/symbolic_dim.h"
#include "paddle/cinn/adt/unique_id.h"
#include "paddle/cinn/common/dim_expr_simplify.h"
#include "paddle/cinn/hlir/framework/pir/group.h"
#include "paddle/cinn/hlir/framework/pir/utils.h"
#include "paddle/cinn/runtime/flags.h"
Expand Down Expand Up @@ -300,12 +301,14 @@ void GenerateProductEqualConstraints(const ::pir::Value& lhs_tensor,
}
}

std::vector<::pir::shape::SymbolicDimOp> CreateSymbolicDimsFromValue(
std::vector<symbol::DimExpr> CreateSymbolicDimsFromValue(
const ::pir::Value& tensor,
const std::shared_ptr<::pir::ShapeConstraintIRAnalysis>& shape_analysis) {
CHECK_NOTNULL(shape_analysis.get());
std::vector<::pir::shape::SymbolicDimOp> dims =
shape_analysis->GetOrCreateSymbolicDimsForRankedValue(tensor);
const auto& shape_or_data = shape_analysis->GetShapeOrDataForValue(tensor);
const auto& dims = shape_or_data.data().has_value()
? shape_or_data.data().value()
: shape_or_data.shape();
CHECK_EQ(dims.size(),
hlir::framework::pir::CompatibleInfo::ValueShape(tensor).size());
return dims;
Expand All @@ -331,19 +334,17 @@ std::string ToTxtString(const DimVar& dim_var) {
dim_var.variant());
}

void GenerateDimEqualConstraints(
const std::vector<::pir::shape::SymbolicDimOp>& lhs_dims,
const std::vector<::pir::shape::SymbolicDimOp>& rhs_dims,
const ::pir::Value& lhs_tensor,
const ::pir::Value& rhs_tensor,
const ::pir::SymbolicDimMgr* symbolic_dim_mgr,
DimFunctions* ret) {
void GenerateDimEqualConstraints(const std::vector<symbol::DimExpr>& lhs_dims,
const std::vector<symbol::DimExpr>& rhs_dims,
const ::pir::Value& lhs_tensor,
const ::pir::Value& rhs_tensor,
const ::pir::SymbolicDimMgr* symbolic_dim_mgr,
DimFunctions* ret) {
VisitEachIdxPairOfTwoVectors(
lhs_dims, rhs_dims, [&](std::size_t lhs_idx, std::size_t rhs_idx) {
const ::pir::shape::SymbolicDimOp& lhs_dim = lhs_dims.at(lhs_idx);
const ::pir::shape::SymbolicDimOp& rhs_dim = rhs_dims.at(rhs_idx);
if (const_cast<::pir::SymbolicDimMgr*>(symbolic_dim_mgr)
->IsSymbolicDimEqual(lhs_dim, rhs_dim)) {
const symbol::DimExpr& lhs_dim = lhs_dims.at(lhs_idx);
const symbol::DimExpr& rhs_dim = rhs_dims.at(rhs_idx);
if (lhs_dim == rhs_dim) {
ShapeDialectTensorDim lhs_adt_dim{lhs_tensor, lhs_idx};
ShapeDialectTensorDim rhs_adt_dim{rhs_tensor, rhs_idx};
VLOG(4) << "Dim Equal: " << ToTxtString(lhs_adt_dim)
Expand All @@ -358,14 +359,30 @@ void GenerateDimEqualConstraints(
});
}

bool IsSymbolicDimProductEqual(const std::vector<symbol::DimExpr>& lhs,
const std::vector<symbol::DimExpr>& rhs) {
const auto& MakeListDimExpr = [](const std::vector<symbol::DimExpr>& exprs)
-> symbol::List<symbol::DimExpr> {
symbol::List<symbol::DimExpr> ret{};
for (const auto& expr : exprs) {
ret->emplace_back(expr);
}
return ret;
};
symbol::DimExpr lhs_expr{symbol::Mul<symbol::DimExpr>{MakeListDimExpr(lhs)}};
symbol::DimExpr rhs_expr{symbol::Mul<symbol::DimExpr>{MakeListDimExpr(rhs)}};
return cinn::common::SimplifyDimExpr(lhs_expr) ==
cinn::common::SimplifyDimExpr(rhs_expr);
}

void BuildTensorShapeDialectConstraints(
const ::pir::Value& lhs_tensor,
const ::pir::Value& rhs_tensor,
const std::shared_ptr<::pir::ShapeConstraintIRAnalysis>& shape_analysis,
DimFunctions* ret) {
std::vector<::pir::shape::SymbolicDimOp> lhs_dims =
std::vector<symbol::DimExpr> lhs_dims =
CreateSymbolicDimsFromValue(lhs_tensor, shape_analysis);
std::vector<::pir::shape::SymbolicDimOp> rhs_dims =
std::vector<symbol::DimExpr> rhs_dims =
CreateSymbolicDimsFromValue(lhs_tensor, shape_analysis);

GenerateDimEqualConstraints(lhs_dims,
Expand All @@ -375,9 +392,7 @@ void BuildTensorShapeDialectConstraints(
&shape_analysis->symbolicDimMgr(),
ret);

if (shape_analysis->symbolicDimMgr().IsSymbolicDimProductEqual(
::pir::SymbolicDimProduct{lhs_dims},
::pir::SymbolicDimProduct{rhs_dims})) {
if (IsSymbolicDimProductEqual(lhs_dims, rhs_dims)) {
GenerateProductEqualConstraints(lhs_tensor, rhs_tensor, ret);
}
}
Expand Down Expand Up @@ -439,14 +454,17 @@ void VisitEachTensor(const List<::pir::Value>& tensors, const DoEachT& DoEach) {
}
}

::pir::shape::SymbolicDimOp GetSymbolicDimOp4TensorDim(
symbol::DimExpr GetSymbolicDimOp4TensorDim(
const ShapeDialectTensorDim& tensor_dim,
const std::shared_ptr<::pir::ShapeConstraintIRAnalysis>& shape_analysis) {
const auto& [tensor, axis] = tensor_dim;
const auto& symbolic_dim_ops =
shape_analysis->GetOrCreateSymbolicDimsForRankedValue(tensor);
CHECK_LT(axis, symbolic_dim_ops.size());
return symbolic_dim_ops.at(axis);
auto [tensor, axis] = tensor_dim;
const auto& symbolic_shape_or_data =
shape_analysis->GetShapeOrDataForValue(tensor);
const auto& dim_exprs = symbolic_shape_or_data.data().has_value()
? symbolic_shape_or_data.data().value()
: symbolic_shape_or_data.shape();
CHECK_LT(axis, dim_exprs.size());
return dim_exprs.at(axis);
}

SymbolicDim GetOrNewSymbolicDim(
Expand All @@ -459,8 +477,7 @@ SymbolicDim GetOrNewSymbolicDim(
for (const auto& [tensor_dim, symbolic_dim] : tensor_dim2symbolic_Dim) {
const auto& cur_symblic_dim_op =
GetSymbolicDimOp4TensorDim(tensor_dim, shape_analysis);
if (shape_analysis->symbolicDimMgr().IsSymbolicDimEqual(
target_symbolic_dim_op, cur_symblic_dim_op)) {
if (target_symbolic_dim_op == cur_symblic_dim_op) {
return symbolic_dim;
}
}
Expand All @@ -470,7 +487,7 @@ SymbolicDim GetOrNewSymbolicDim(
std::unordered_map<DimVar, const DimExpr> MakeEquationStartExpr(
const cinn::hlir::framework::pir::Group* group,
const std::shared_ptr<::pir::ShapeConstraintIRAnalysis>& shape_analysis,
std::unordered_map<SymbolicDim, ::pir::shape::SymbolicDimOp>*
std::unordered_map<SymbolicDim, symbol::DimExpr>*
map_expr_symbolic2dialect_symbolic) {
std::unordered_map<DimVar, const DimExpr> ret{};
std::unordered_set<std::string> output_names = GetAllOutputNames(group->ops);
Expand Down
5 changes: 3 additions & 2 deletions paddle/cinn/adt/graph_symbolic_dim_infer_ctx.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include "paddle/cinn/utils/type_defs.h"
#include "paddle/pir/core/value.h"
#include "paddle/pir/dialect/shape/ir/shape_op.h"
#include "paddle/pir/dialect/shape/utils/dim_expr.h"

namespace pir {
class Operation;
Expand Down Expand Up @@ -54,7 +55,7 @@ class GraphSymbolicDimInferCtx {
return iter->second;
}

const std::unordered_map<SymbolicDim, ::pir::shape::SymbolicDimOp>&
const std::unordered_map<SymbolicDim, ::symbol::DimExpr>&
map_expr_symbolic2dialect_symbolic() const {
return map_expr_symbolic2dialect_symbolic_;
}
Expand All @@ -65,7 +66,7 @@ class GraphSymbolicDimInferCtx {
const cinn::hlir::framework::pir::Group* group_;
std::unordered_map<::pir::Value, std::vector<std::optional<DimExpr>>>
tensor2dim_exprs_;
std::unordered_map<SymbolicDim, ::pir::shape::SymbolicDimOp>
std::unordered_map<SymbolicDim, ::symbol::DimExpr>
map_expr_symbolic2dialect_symbolic_;
};

Expand Down
11 changes: 5 additions & 6 deletions paddle/cinn/adt/map_expr_ctx.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,9 @@ class MapExprCtx final {
MapExprCtx(const MapExprCtx&) = delete;
MapExprCtx(MapExprCtx&&) = delete;

explicit MapExprCtx(
const MapExpr& map_expr,
const std::unordered_map<SymbolicDim, ::pir::shape::SymbolicDimOp>&
map_expr_symbolic2dialect_symbolic)
explicit MapExprCtx(const MapExpr& map_expr,
const std::unordered_map<SymbolicDim, symbol::DimExpr>&
map_expr_symbolic2dialect_symbolic)
: map_expr_(map_expr),
map_expr_symbolic2dialect_symbolic_(
map_expr_symbolic2dialect_symbolic) {}
Expand All @@ -54,15 +53,15 @@ class MapExprCtx final {
return node2lowered_funcs_;
}

const std::unordered_map<SymbolicDim, ::pir::shape::SymbolicDimOp>&
const std::unordered_map<SymbolicDim, symbol::DimExpr>&
map_expr_symbolic2dialect_symbolic() const {
return map_expr_symbolic2dialect_symbolic_;
}

private:
const MapExpr map_expr_;
Node2LoweredFuncs node2lowered_funcs_;
std::unordered_map<SymbolicDim, ::pir::shape::SymbolicDimOp>
std::unordered_map<SymbolicDim, symbol::DimExpr>
map_expr_symbolic2dialect_symbolic_;
};

Expand Down
7 changes: 4 additions & 3 deletions paddle/cinn/hlir/pe/map_expr_to_ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ class MapExprToIrTranslator {
explicit MapExprToIrTranslator(
const MapExpr& map_expr,
const Node2LoweredFuncs& node2lowered_funcs,
const std::unordered_map<SymbolicDim, ::pir::shape::SymbolicDimOp>&
const std::unordered_map<SymbolicDim, symbol::DimExpr>&
map_expr_symbolic2dialect_symbolic,
const cinn::common::Target& target)
: map_expr_(map_expr),
Expand Down Expand Up @@ -790,8 +790,9 @@ class MapExprToIrTranslator {

ir::Expr TranslateDimExprImpl(const SymbolicDim& dim_expr) const {
CHECK_GT(map_expr_symbolic2dialect_symbolic_.count(dim_expr), 0);
CHECK(map_expr_symbolic2dialect_symbolic_.at(dim_expr).Has<std::string>());
return ir::Var{
map_expr_symbolic2dialect_symbolic_.at(dim_expr).GetSymName()};
map_expr_symbolic2dialect_symbolic_.at(dim_expr).Get<std::string>()};
}

ir::Expr TranslateDimExprImpl(const Negative<DimExpr>& dim_expr) const {
Expand Down Expand Up @@ -873,7 +874,7 @@ class MapExprToIrTranslator {
const cinn::common::Target target_;
TensorIteratorExpr4TensorT TensorIteratorExpr4Tensor;
LoopDescriptor4LoopIteratorT LoopDescriptor4LoopIterator;
std::unordered_map<SymbolicDim, ::pir::shape::SymbolicDimOp>
std::unordered_map<SymbolicDim, symbol::DimExpr>
map_expr_symbolic2dialect_symbolic_;
};

Expand Down

0 comments on commit d54c834

Please sign in to comment.