diff --git a/paddle/cinn/adt/graph_symbolic_dim_infer_ctx.cc b/paddle/cinn/adt/graph_symbolic_dim_infer_ctx.cc index c0c189ac8ec93..c1451c68f553c 100644 --- a/paddle/cinn/adt/graph_symbolic_dim_infer_ctx.cc +++ b/paddle/cinn/adt/graph_symbolic_dim_infer_ctx.cc @@ -22,6 +22,7 @@ #include "paddle/cinn/adt/print.h" #include "paddle/cinn/adt/symbolic_dim.h" #include "paddle/cinn/adt/unique_id.h" +#include "paddle/cinn/common/dim_expr_simplify.h" #include "paddle/cinn/hlir/framework/pir/group.h" #include "paddle/cinn/hlir/framework/pir/utils.h" #include "paddle/cinn/runtime/flags.h" @@ -300,12 +301,14 @@ void GenerateProductEqualConstraints(const ::pir::Value& lhs_tensor, } } -std::vector<::pir::shape::SymbolicDimOp> CreateSymbolicDimsFromValue( +std::vector CreateSymbolicDimsFromValue( const ::pir::Value& tensor, const std::shared_ptr<::pir::ShapeConstraintIRAnalysis>& shape_analysis) { CHECK_NOTNULL(shape_analysis.get()); - std::vector<::pir::shape::SymbolicDimOp> dims = - shape_analysis->GetOrCreateSymbolicDimsForRankedValue(tensor); + const auto& shape_or_data = shape_analysis->GetShapeOrDataForValue(tensor); + const auto& dims = shape_or_data.data().has_value() + ? shape_or_data.data().value() + : shape_or_data.shape(); CHECK_EQ(dims.size(), hlir::framework::pir::CompatibleInfo::ValueShape(tensor).size()); return dims; @@ -331,19 +334,17 @@ std::string ToTxtString(const DimVar& dim_var) { dim_var.variant()); } -void GenerateDimEqualConstraints( - const std::vector<::pir::shape::SymbolicDimOp>& lhs_dims, - const std::vector<::pir::shape::SymbolicDimOp>& rhs_dims, - const ::pir::Value& lhs_tensor, - const ::pir::Value& rhs_tensor, - const ::pir::SymbolicDimMgr* symbolic_dim_mgr, - DimFunctions* ret) { +void GenerateDimEqualConstraints(const std::vector& lhs_dims, + const std::vector& rhs_dims, + const ::pir::Value& lhs_tensor, + const ::pir::Value& rhs_tensor, + const ::pir::SymbolicDimMgr* symbolic_dim_mgr, + DimFunctions* ret) { VisitEachIdxPairOfTwoVectors( lhs_dims, rhs_dims, [&](std::size_t lhs_idx, std::size_t rhs_idx) { - const ::pir::shape::SymbolicDimOp& lhs_dim = lhs_dims.at(lhs_idx); - const ::pir::shape::SymbolicDimOp& rhs_dim = rhs_dims.at(rhs_idx); - if (const_cast<::pir::SymbolicDimMgr*>(symbolic_dim_mgr) - ->IsSymbolicDimEqual(lhs_dim, rhs_dim)) { + const symbol::DimExpr& lhs_dim = lhs_dims.at(lhs_idx); + const symbol::DimExpr& rhs_dim = rhs_dims.at(rhs_idx); + if (lhs_dim == rhs_dim) { ShapeDialectTensorDim lhs_adt_dim{lhs_tensor, lhs_idx}; ShapeDialectTensorDim rhs_adt_dim{rhs_tensor, rhs_idx}; VLOG(4) << "Dim Equal: " << ToTxtString(lhs_adt_dim) @@ -358,14 +359,30 @@ void GenerateDimEqualConstraints( }); } +bool IsSymbolicDimProductEqual(const std::vector& lhs, + const std::vector& rhs) { + const auto& MakeListDimExpr = [](const std::vector& exprs) + -> symbol::List { + symbol::List ret{}; + for (const auto& expr : exprs) { + ret->emplace_back(expr); + } + return ret; + }; + symbol::DimExpr lhs_expr{symbol::Mul{MakeListDimExpr(lhs)}}; + symbol::DimExpr rhs_expr{symbol::Mul{MakeListDimExpr(rhs)}}; + return cinn::common::SimplifyDimExpr(lhs_expr) == + cinn::common::SimplifyDimExpr(rhs_expr); +} + void BuildTensorShapeDialectConstraints( const ::pir::Value& lhs_tensor, const ::pir::Value& rhs_tensor, const std::shared_ptr<::pir::ShapeConstraintIRAnalysis>& shape_analysis, DimFunctions* ret) { - std::vector<::pir::shape::SymbolicDimOp> lhs_dims = + std::vector lhs_dims = CreateSymbolicDimsFromValue(lhs_tensor, shape_analysis); - std::vector<::pir::shape::SymbolicDimOp> rhs_dims = + std::vector rhs_dims = CreateSymbolicDimsFromValue(lhs_tensor, shape_analysis); GenerateDimEqualConstraints(lhs_dims, @@ -375,9 +392,7 @@ void BuildTensorShapeDialectConstraints( &shape_analysis->symbolicDimMgr(), ret); - if (shape_analysis->symbolicDimMgr().IsSymbolicDimProductEqual( - ::pir::SymbolicDimProduct{lhs_dims}, - ::pir::SymbolicDimProduct{rhs_dims})) { + if (IsSymbolicDimProductEqual(lhs_dims, rhs_dims)) { GenerateProductEqualConstraints(lhs_tensor, rhs_tensor, ret); } } @@ -439,14 +454,17 @@ void VisitEachTensor(const List<::pir::Value>& tensors, const DoEachT& DoEach) { } } -::pir::shape::SymbolicDimOp GetSymbolicDimOp4TensorDim( +symbol::DimExpr GetSymbolicDimOp4TensorDim( const ShapeDialectTensorDim& tensor_dim, const std::shared_ptr<::pir::ShapeConstraintIRAnalysis>& shape_analysis) { - const auto& [tensor, axis] = tensor_dim; - const auto& symbolic_dim_ops = - shape_analysis->GetOrCreateSymbolicDimsForRankedValue(tensor); - CHECK_LT(axis, symbolic_dim_ops.size()); - return symbolic_dim_ops.at(axis); + auto [tensor, axis] = tensor_dim; + const auto& symbolic_shape_or_data = + shape_analysis->GetShapeOrDataForValue(tensor); + const auto& dim_exprs = symbolic_shape_or_data.data().has_value() + ? symbolic_shape_or_data.data().value() + : symbolic_shape_or_data.shape(); + CHECK_LT(axis, dim_exprs.size()); + return dim_exprs.at(axis); } SymbolicDim GetOrNewSymbolicDim( @@ -459,8 +477,7 @@ SymbolicDim GetOrNewSymbolicDim( for (const auto& [tensor_dim, symbolic_dim] : tensor_dim2symbolic_Dim) { const auto& cur_symblic_dim_op = GetSymbolicDimOp4TensorDim(tensor_dim, shape_analysis); - if (shape_analysis->symbolicDimMgr().IsSymbolicDimEqual( - target_symbolic_dim_op, cur_symblic_dim_op)) { + if (target_symbolic_dim_op == cur_symblic_dim_op) { return symbolic_dim; } } @@ -470,7 +487,7 @@ SymbolicDim GetOrNewSymbolicDim( std::unordered_map MakeEquationStartExpr( const cinn::hlir::framework::pir::Group* group, const std::shared_ptr<::pir::ShapeConstraintIRAnalysis>& shape_analysis, - std::unordered_map* + std::unordered_map* map_expr_symbolic2dialect_symbolic) { std::unordered_map ret{}; std::unordered_set output_names = GetAllOutputNames(group->ops); diff --git a/paddle/cinn/adt/graph_symbolic_dim_infer_ctx.h b/paddle/cinn/adt/graph_symbolic_dim_infer_ctx.h index fe3bfa3ffc9ff..08948296525bc 100644 --- a/paddle/cinn/adt/graph_symbolic_dim_infer_ctx.h +++ b/paddle/cinn/adt/graph_symbolic_dim_infer_ctx.h @@ -22,6 +22,7 @@ #include "paddle/cinn/utils/type_defs.h" #include "paddle/pir/core/value.h" #include "paddle/pir/dialect/shape/ir/shape_op.h" +#include "paddle/pir/dialect/shape/utils/dim_expr.h" namespace pir { class Operation; @@ -54,7 +55,7 @@ class GraphSymbolicDimInferCtx { return iter->second; } - const std::unordered_map& + const std::unordered_map& map_expr_symbolic2dialect_symbolic() const { return map_expr_symbolic2dialect_symbolic_; } @@ -65,7 +66,7 @@ class GraphSymbolicDimInferCtx { const cinn::hlir::framework::pir::Group* group_; std::unordered_map<::pir::Value, std::vector>> tensor2dim_exprs_; - std::unordered_map + std::unordered_map map_expr_symbolic2dialect_symbolic_; }; diff --git a/paddle/cinn/adt/map_expr_ctx.h b/paddle/cinn/adt/map_expr_ctx.h index 8fa36b0cdfdb4..b12c473b65b8f 100644 --- a/paddle/cinn/adt/map_expr_ctx.h +++ b/paddle/cinn/adt/map_expr_ctx.h @@ -33,10 +33,9 @@ class MapExprCtx final { MapExprCtx(const MapExprCtx&) = delete; MapExprCtx(MapExprCtx&&) = delete; - explicit MapExprCtx( - const MapExpr& map_expr, - const std::unordered_map& - map_expr_symbolic2dialect_symbolic) + explicit MapExprCtx(const MapExpr& map_expr, + const std::unordered_map& + map_expr_symbolic2dialect_symbolic) : map_expr_(map_expr), map_expr_symbolic2dialect_symbolic_( map_expr_symbolic2dialect_symbolic) {} @@ -54,7 +53,7 @@ class MapExprCtx final { return node2lowered_funcs_; } - const std::unordered_map& + const std::unordered_map& map_expr_symbolic2dialect_symbolic() const { return map_expr_symbolic2dialect_symbolic_; } @@ -62,7 +61,7 @@ class MapExprCtx final { private: const MapExpr map_expr_; Node2LoweredFuncs node2lowered_funcs_; - std::unordered_map + std::unordered_map map_expr_symbolic2dialect_symbolic_; }; diff --git a/paddle/cinn/hlir/pe/map_expr_to_ir.cc b/paddle/cinn/hlir/pe/map_expr_to_ir.cc index ad3b3469a604e..05bb3674db091 100644 --- a/paddle/cinn/hlir/pe/map_expr_to_ir.cc +++ b/paddle/cinn/hlir/pe/map_expr_to_ir.cc @@ -55,7 +55,7 @@ class MapExprToIrTranslator { explicit MapExprToIrTranslator( const MapExpr& map_expr, const Node2LoweredFuncs& node2lowered_funcs, - const std::unordered_map& + const std::unordered_map& map_expr_symbolic2dialect_symbolic, const cinn::common::Target& target) : map_expr_(map_expr), @@ -790,8 +790,9 @@ class MapExprToIrTranslator { ir::Expr TranslateDimExprImpl(const SymbolicDim& dim_expr) const { CHECK_GT(map_expr_symbolic2dialect_symbolic_.count(dim_expr), 0); + CHECK(map_expr_symbolic2dialect_symbolic_.at(dim_expr).Has()); return ir::Var{ - map_expr_symbolic2dialect_symbolic_.at(dim_expr).GetSymName()}; + map_expr_symbolic2dialect_symbolic_.at(dim_expr).Get()}; } ir::Expr TranslateDimExprImpl(const Negative& dim_expr) const { @@ -873,7 +874,7 @@ class MapExprToIrTranslator { const cinn::common::Target target_; TensorIteratorExpr4TensorT TensorIteratorExpr4Tensor; LoopDescriptor4LoopIteratorT LoopDescriptor4LoopIterator; - std::unordered_map + std::unordered_map map_expr_symbolic2dialect_symbolic_; };