Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DimExpr] SymbolicDimOp->symbol::DimExpr #60734

Merged
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
[DimExpr] SymbolicDimOp->symbol::DimExpr
  • Loading branch information
jiahy0825 committed Jan 10, 2024
commit 6bb890164e01654b34aee1bbe1b999992bf427d7
74 changes: 46 additions & 28 deletions paddle/cinn/adt/graph_symbolic_dim_infer_ctx.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include "paddle/cinn/adt/print.h"
#include "paddle/cinn/adt/symbolic_dim.h"
#include "paddle/cinn/adt/unique_id.h"
#include "paddle/cinn/common/dim_expr_simplify.h"
#include "paddle/cinn/hlir/framework/pir/group.h"
#include "paddle/cinn/hlir/framework/pir/utils.h"
#include "paddle/cinn/runtime/flags.h"
Expand Down Expand Up @@ -300,12 +301,15 @@ void GenerateProductEqualConstraints(const ::pir::Value& lhs_tensor,
}
}

std::vector<::pir::shape::SymbolicDimOp> CreateSymbolicDimsFromValue(
std::vector<symbol::DimExpr> CreateSymbolicDimsFromValue(
const ::pir::Value& tensor,
const std::shared_ptr<::pir::ShapeConstraintIRAnalysis>& shape_analysis) {
CHECK_NOTNULL(shape_analysis.get());
std::vector<::pir::shape::SymbolicDimOp> dims =
shape_analysis->GetOrCreateSymbolicDimsForRankedValue(tensor);
const auto& shape_or_data = shape_analysis->GetShapeOrDataForValue(
const_cast<::pir::Value*>(&tensor));
const auto& dims = shape_or_data.data().has_value()
? shape_or_data.data().value()
: shape_or_data.shape();
CHECK_EQ(dims.size(),
hlir::framework::pir::CompatibleInfo::ValueShape(tensor).size());
return dims;
Expand All @@ -331,19 +335,17 @@ std::string ToTxtString(const DimVar& dim_var) {
dim_var.variant());
}

void GenerateDimEqualConstraints(
const std::vector<::pir::shape::SymbolicDimOp>& lhs_dims,
const std::vector<::pir::shape::SymbolicDimOp>& rhs_dims,
const ::pir::Value& lhs_tensor,
const ::pir::Value& rhs_tensor,
const ::pir::SymbolicDimMgr* symbolic_dim_mgr,
DimFunctions* ret) {
void GenerateDimEqualConstraints(const std::vector<symbol::DimExpr>& lhs_dims,
const std::vector<symbol::DimExpr>& rhs_dims,
const ::pir::Value& lhs_tensor,
const ::pir::Value& rhs_tensor,
const ::pir::SymbolicDimMgr* symbolic_dim_mgr,
DimFunctions* ret) {
VisitEachIdxPairOfTwoVectors(
lhs_dims, rhs_dims, [&](std::size_t lhs_idx, std::size_t rhs_idx) {
const ::pir::shape::SymbolicDimOp& lhs_dim = lhs_dims.at(lhs_idx);
const ::pir::shape::SymbolicDimOp& rhs_dim = rhs_dims.at(rhs_idx);
if (const_cast<::pir::SymbolicDimMgr*>(symbolic_dim_mgr)
->IsSymbolicDimEqual(lhs_dim, rhs_dim)) {
const symbol::DimExpr& lhs_dim = lhs_dims.at(lhs_idx);
const symbol::DimExpr& rhs_dim = rhs_dims.at(rhs_idx);
if (lhs_dim == rhs_dim) {
ShapeDialectTensorDim lhs_adt_dim{lhs_tensor, lhs_idx};
ShapeDialectTensorDim rhs_adt_dim{rhs_tensor, rhs_idx};
VLOG(4) << "Dim Equal: " << ToTxtString(lhs_adt_dim)
Expand All @@ -358,14 +360,30 @@ void GenerateDimEqualConstraints(
});
}

bool IsSymbolicDimProductEqual(const std::vector<symbol::DimExpr>& lhs,
const std::vector<symbol::DimExpr>& rhs) {
const auto& MakeListDimExpr = [](const std::vector<symbol::DimExpr>& exprs)
-> symbol::List<symbol::DimExpr> {
symbol::List<symbol::DimExpr> ret{};
for (const auto& expr : exprs) {
ret->emplace_back(expr);
}
return ret;
};
symbol::DimExpr lhs_expr{symbol::Mul<symbol::DimExpr>{MakeListDimExpr(lhs)}};
symbol::DimExpr rhs_expr{symbol::Mul<symbol::DimExpr>{MakeListDimExpr(rhs)}};
return cinn::common::SimplifyDimExpr(lhs_expr) ==
cinn::common::SimplifyDimExpr(rhs_expr);
}

void BuildTensorShapeDialectConstraints(
const ::pir::Value& lhs_tensor,
const ::pir::Value& rhs_tensor,
const std::shared_ptr<::pir::ShapeConstraintIRAnalysis>& shape_analysis,
DimFunctions* ret) {
std::vector<::pir::shape::SymbolicDimOp> lhs_dims =
std::vector<symbol::DimExpr> lhs_dims =
CreateSymbolicDimsFromValue(lhs_tensor, shape_analysis);
std::vector<::pir::shape::SymbolicDimOp> rhs_dims =
std::vector<symbol::DimExpr> rhs_dims =
CreateSymbolicDimsFromValue(lhs_tensor, shape_analysis);

GenerateDimEqualConstraints(lhs_dims,
Expand All @@ -375,9 +393,7 @@ void BuildTensorShapeDialectConstraints(
&shape_analysis->symbolicDimMgr(),
ret);

if (shape_analysis->symbolicDimMgr().IsSymbolicDimProductEqual(
::pir::SymbolicDimProduct{lhs_dims},
::pir::SymbolicDimProduct{rhs_dims})) {
if (IsSymbolicDimProductEqual(lhs_dims, rhs_dims)) {
GenerateProductEqualConstraints(lhs_tensor, rhs_tensor, ret);
}
}
Expand Down Expand Up @@ -439,14 +455,17 @@ void VisitEachTensor(const List<::pir::Value>& tensors, const DoEachT& DoEach) {
}
}

::pir::shape::SymbolicDimOp GetSymbolicDimOp4TensorDim(
symbol::DimExpr GetSymbolicDimOp4TensorDim(
const ShapeDialectTensorDim& tensor_dim,
const std::shared_ptr<::pir::ShapeConstraintIRAnalysis>& shape_analysis) {
const auto& [tensor, axis] = tensor_dim;
const auto& symbolic_dim_ops =
shape_analysis->GetOrCreateSymbolicDimsForRankedValue(tensor);
CHECK_LT(axis, symbolic_dim_ops.size());
return symbolic_dim_ops.at(axis);
auto [tensor, axis] = tensor_dim;
const auto& symbolic_shape_or_data =
shape_analysis->GetShapeOrDataForValue(&tensor);
const auto& dim_exprs = symbolic_shape_or_data.data().has_value()
? symbolic_shape_or_data.data().value()
: symbolic_shape_or_data.shape();
CHECK_LT(axis, dim_exprs.size());
return dim_exprs.at(axis);
}

SymbolicDim GetOrNewSymbolicDim(
Expand All @@ -459,8 +478,7 @@ SymbolicDim GetOrNewSymbolicDim(
for (const auto& [tensor_dim, symbolic_dim] : tensor_dim2symbolic_Dim) {
const auto& cur_symblic_dim_op =
GetSymbolicDimOp4TensorDim(tensor_dim, shape_analysis);
if (shape_analysis->symbolicDimMgr().IsSymbolicDimEqual(
target_symbolic_dim_op, cur_symblic_dim_op)) {
if (target_symbolic_dim_op == cur_symblic_dim_op) {
return symbolic_dim;
}
}
Expand All @@ -470,7 +488,7 @@ SymbolicDim GetOrNewSymbolicDim(
std::unordered_map<DimVar, const DimExpr> MakeEquationStartExpr(
const cinn::hlir::framework::pir::Group* group,
const std::shared_ptr<::pir::ShapeConstraintIRAnalysis>& shape_analysis,
std::unordered_map<SymbolicDim, ::pir::shape::SymbolicDimOp>*
std::unordered_map<SymbolicDim, symbol::DimExpr>*
map_expr_symbolic2dialect_symbolic) {
std::unordered_map<DimVar, const DimExpr> ret{};
std::unordered_set<std::string> output_names = GetAllOutputNames(group->ops);
Expand Down
5 changes: 3 additions & 2 deletions paddle/cinn/adt/graph_symbolic_dim_infer_ctx.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include "paddle/cinn/utils/type_defs.h"
#include "paddle/pir/core/value.h"
#include "paddle/pir/dialect/shape/ir/shape_op.h"
#include "paddle/pir/dialect/shape/utils/dim_expr.h"

namespace pir {
class Operation;
Expand Down Expand Up @@ -54,7 +55,7 @@ class GraphSymbolicDimInferCtx {
return iter->second;
}

const std::unordered_map<SymbolicDim, ::pir::shape::SymbolicDimOp>&
const std::unordered_map<SymbolicDim, ::symbol::DimExpr>&
map_expr_symbolic2dialect_symbolic() const {
return map_expr_symbolic2dialect_symbolic_;
}
Expand All @@ -65,7 +66,7 @@ class GraphSymbolicDimInferCtx {
const cinn::hlir::framework::pir::Group* group_;
std::unordered_map<::pir::Value, std::vector<std::optional<DimExpr>>>
tensor2dim_exprs_;
std::unordered_map<SymbolicDim, ::pir::shape::SymbolicDimOp>
std::unordered_map<SymbolicDim, ::symbol::DimExpr>
map_expr_symbolic2dialect_symbolic_;
};

Expand Down
11 changes: 5 additions & 6 deletions paddle/cinn/adt/map_expr_ctx.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,9 @@ class MapExprCtx final {
MapExprCtx(const MapExprCtx&) = delete;
MapExprCtx(MapExprCtx&&) = delete;

explicit MapExprCtx(
const MapExpr& map_expr,
const std::unordered_map<SymbolicDim, ::pir::shape::SymbolicDimOp>&
map_expr_symbolic2dialect_symbolic)
explicit MapExprCtx(const MapExpr& map_expr,
const std::unordered_map<SymbolicDim, symbol::DimExpr>&
map_expr_symbolic2dialect_symbolic)
: map_expr_(map_expr),
map_expr_symbolic2dialect_symbolic_(
map_expr_symbolic2dialect_symbolic) {}
Expand All @@ -54,15 +53,15 @@ class MapExprCtx final {
return node2lowered_funcs_;
}

const std::unordered_map<SymbolicDim, ::pir::shape::SymbolicDimOp>&
const std::unordered_map<SymbolicDim, symbol::DimExpr>&
map_expr_symbolic2dialect_symbolic() const {
return map_expr_symbolic2dialect_symbolic_;
}

private:
const MapExpr map_expr_;
Node2LoweredFuncs node2lowered_funcs_;
std::unordered_map<SymbolicDim, ::pir::shape::SymbolicDimOp>
std::unordered_map<SymbolicDim, symbol::DimExpr>
map_expr_symbolic2dialect_symbolic_;
};

Expand Down
7 changes: 4 additions & 3 deletions paddle/cinn/hlir/pe/map_expr_to_ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ class MapExprToIrTranslator {
explicit MapExprToIrTranslator(
const MapExpr& map_expr,
const Node2LoweredFuncs& node2lowered_funcs,
const std::unordered_map<SymbolicDim, ::pir::shape::SymbolicDimOp>&
const std::unordered_map<SymbolicDim, symbol::DimExpr>&
map_expr_symbolic2dialect_symbolic,
const cinn::common::Target& target)
: map_expr_(map_expr),
Expand Down Expand Up @@ -790,8 +790,9 @@ class MapExprToIrTranslator {

ir::Expr TranslateDimExprImpl(const SymbolicDim& dim_expr) const {
CHECK_GT(map_expr_symbolic2dialect_symbolic_.count(dim_expr), 0);
CHECK(map_expr_symbolic2dialect_symbolic_.at(dim_expr).Has<std::string>());
return ir::Var{
map_expr_symbolic2dialect_symbolic_.at(dim_expr).GetSymName()};
map_expr_symbolic2dialect_symbolic_.at(dim_expr).Get<std::string>()};
}

ir::Expr TranslateDimExprImpl(const Negative<DimExpr>& dim_expr) const {
Expand Down Expand Up @@ -873,7 +874,7 @@ class MapExprToIrTranslator {
const cinn::common::Target target_;
TensorIteratorExpr4TensorT TensorIteratorExpr4Tensor;
LoopDescriptor4LoopIteratorT LoopDescriptor4LoopIterator;
std::unordered_map<SymbolicDim, ::pir::shape::SymbolicDimOp>
std::unordered_map<SymbolicDim, symbol::DimExpr>
map_expr_symbolic2dialect_symbolic_;
};

Expand Down