[Prim][PIR] Add composite rule of tile for both static and dynamic shape (PaddlePaddle#61571)

* support dynamic shape of tile and full_like

* fix conflict bug

* add squeeze and unsqueeze as primitive op

* fix dy shape

* fix code

* fix code

* debug

* debug info

* fix code

* fix bug

* remove unused code

* skip rank4 case

* fix check op
cyber-pioneer committed Feb 20, 2024
1 parent 5b8f331 commit d7a32fa
Showing 5 changed files with 178 additions and 15 deletions.
@@ -47,6 +47,7 @@
"squeeze",
"stack",
"unsqueeze",
"tile",
]

# come into effect in generated file op_decomp.cc
@@ -77,6 +78,7 @@
"squeeze",
"stack",
"unsqueeze",
"tile",
]


13 changes: 12 additions & 1 deletion paddle/fluid/primitive/base/decomp_trans.cc
@@ -38,6 +38,9 @@ std::unordered_set<std::string> decomp_op_contain_none = {"pd_op.squeeze",
"pd_op.flatten",
"pd_op.batch_norm",
"pd_op.batch_norm_"};
// Ops whose decomposition is skipped when dynamic shape is detected.
std::unordered_set<std::string> dynamic_shape_blacklist = {"pd_op.squeeze",
"pd_op.unsqueeze"};

static bool find_value(const std::vector<int64_t>& vec, int64_t value) {
if (std::find(vec.begin(), vec.end(), value) != vec.end()) {
@@ -48,6 +51,9 @@ static bool find_value(const std::vector<int64_t>& vec, int64_t value) {
}

static const phi::DDim& GetValueDims(pir::Value value) {
if (!value.type()) {
PADDLE_THROW(phi::errors::InvalidArgument("The type of value is nullptr."));
}
if (value.type().isa<DenseTensorType>()) {
return value.type().dyn_cast<DenseTensorType>().dims();
} else if (value.type().isa<SelectedRowsType>()) {
@@ -101,7 +107,7 @@ bool DecompProgram::check_decomp_dynamic_shape(pir::Operation* op) {
// check if initialized in case of optional input.
if (!paddle::dialect::IsEmptyValue(value)) {
pir::Operation* prev_op = value.defining_op();
if (prev_op->name() == "builtin.combine") {
if (prev_op && prev_op->name() == "builtin.combine") {
for (pir::OpOperand& sub_item : prev_op->operands()) {
if (check_dynamic_shape(sub_item, *op)) {
return true;
@@ -336,6 +342,11 @@ void DecompProgram::decomp_block(
check_decomp_dynamic_shape(op)) {
enable_prim = false;
}
if (enable_prim && check_decomp_dynamic_shape(op) &&
dynamic_shape_blacklist.find(op->name()) !=
dynamic_shape_blacklist.end()) {
enable_prim = false;
}
if (enable_prim) {
VLOG(4) << "[Prim] decomp op name " << op->name();
check_decomp_dynamic_shape(op);
93 changes: 93 additions & 0 deletions paddle/fluid/primitive/composite/composite.h
@@ -14,6 +14,7 @@

#pragma once

#include <numeric>
#include "paddle/fluid/primitive/primitive/primitive.h"
#include "paddle/fluid/primitive/type/lazy_tensor.h"
#include "paddle/fluid/primitive/utils/utils.h"
@@ -25,6 +26,11 @@ namespace details {
// empty_shape means x.shape=[]
static std::vector<int64_t> empty_shape;

template <typename T>
static Tensor get_slice(const Tensor& x, int64_t idx) {
return slice<T>(x, {0}, {idx}, {idx + 1}, {1}, {});
}

template <typename T>
Tensor any_decomp(const Tensor& x, const IntArray& axis, bool keepdim) {
auto org_dtype = x.dtype();
@@ -825,6 +831,93 @@ std::tuple<Tensor, Tensor, Tensor> group_norm_decomp(
return std::make_tuple(out, mean_out, var_out);
}

template <typename T>
Tensor tile_decomp(const Tensor& x, const IntArray& repeat_times) {
// Static-shape example: x.shape = [3, 4], repeat_times = (a, b, c)
// shape1      = [1, 3, 4]
// shape2      = [1, 1, 1, 3, 1, 4]
// shape3      = [a, 1, b, 3, c, 4]
// final_shape = [a, 3*b, 4*c]
// t1  = x.reshape(shape1)
// t2  = t1.reshape(shape2)
// t3  = t2.expand(shape3)
// res = t3.reshape(final_shape)
std::vector<int64_t> repeat_times_ = repeat_times.GetData();
std::vector<int64_t> shape1 = common::vectorize<int64_t>(x.dims());
auto diff = int64_t(repeat_times_.size()) - int64_t(shape1.size());
Tensor t1;
if (find_value(shape1, -1)) {
size_t repeat_time_length = repeat_times_.size();
std::vector<int64_t> unsqueeze_idx2;
if (diff > 0) {
std::vector<int64_t> unsqueeze_idx1(diff);
std::iota(unsqueeze_idx1.begin(), unsqueeze_idx1.end(), 0);
t1 = unsqueeze<T>(x, unsqueeze_idx1);
} else {
t1 = x;
}
auto length2 = t1.dims().size();
for (size_t i = 0; i < repeat_times_.size(); i++) {
unsqueeze_idx2.push_back(length2 - repeat_times_.size() + i * 2);
}

Tensor t2 = unsqueeze<T>(t1, unsqueeze_idx2);
std::vector<int64_t> ref_shape(t2.dims().size(), 1);
for (size_t i = 0; i < unsqueeze_idx2.size(); i++) {
ref_shape[unsqueeze_idx2[i]] = repeat_times_[i];
}
Tensor ref_t = full<T>(ref_shape, 1.0, t2.dtype());
Tensor t3 = t2 * ref_t;
Tensor origin_shape_t = shape<T>(t1);
std::vector<int64_t> t1_shape = common::vectorize<int64_t>(t1.dims());
std::vector<Tensor> res_s;
for (int64_t i = int64_t(length2) - 1; i >= 0; i--) {
auto relative_idx =
int64_t(repeat_time_length) - 1 - int64_t(length2 - i - 1);

if (relative_idx >= 0) {
res_s.insert(
res_s.begin(),
get_slice<T>(origin_shape_t, i) * repeat_times_[relative_idx]);
} else {
res_s.insert(res_s.begin(), get_slice<T>(origin_shape_t, i));
}
}
Tensor s4 = concat<T>(res_s, 0);
return backend::reshape_with_tensor<T>(t3, s4);

} else {
if (diff > 0) {
for (int64_t i = 0; i < diff; i++) {
shape1.insert(shape1.begin(), 1);
}
}

auto length = int64_t(shape1.size());
std::vector<int64_t> shape2 = shape1;
std::vector<int64_t> shape3 = shape1;
std::vector<int64_t> final_shape = shape1;
auto r_length = repeat_times_.size();
for (size_t j = 0; j < repeat_times_.size(); j++) {
int64_t i = int64_t(j);

shape2.insert(shape2.begin() + (length - 1 - i), 1);
shape3.insert(shape3.begin() + (length - 1 - i),
repeat_times_[r_length - i - 1]);

final_shape[length - i - 1] =
final_shape[length - i - 1] * repeat_times_[r_length - i - 1];
}

t1 = reshape<T>(x, shape1);

auto t2 = reshape<T>(t1, shape2);
auto t3 = t2.expand(shape3);
auto res = reshape<T>(t3, final_shape);
return res;
}
}

template <typename T>
Tensor square_decomp(const Tensor& x) {
auto org_dtype = x.dtype();
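The reshape → expand → reshape scheme sketched in the comments of tile_decomp above can be checked with a small NumPy model of its static-shape branch. This is an illustration only, not part of the commit: the helper name tile_decomp_static is made up here, and np.broadcast_to stands in for Paddle's expand.

import numpy as np

def tile_decomp_static(x, repeat_times):
    # Pad x's shape with leading 1s, interleave the repeat factors as new
    # axes, broadcast over those axes, then collapse back to the tiled shape.
    x_shape = list(x.shape)
    diff = len(repeat_times) - len(x_shape)
    shape1 = [1] * max(diff, 0) + x_shape            # e.g. [3, 4] -> [1, 3, 4]
    length = len(shape1)
    r_length = len(repeat_times)
    shape2, shape3, final_shape = list(shape1), list(shape1), list(shape1)
    for i in range(r_length):
        shape2.insert(length - 1 - i, 1)
        shape3.insert(length - 1 - i, repeat_times[r_length - i - 1])
        final_shape[length - i - 1] *= repeat_times[r_length - i - 1]
    t2 = x.reshape(shape1).reshape(shape2)           # [1, 1, 1, 3, 1, 4]
    t3 = np.broadcast_to(t2, shape3)                 # [a, 1, b, 3, c, 4]
    return t3.reshape(final_shape)                   # [a, 3*b, 4*c]

x = np.arange(12).reshape(3, 4)
assert np.array_equal(tile_decomp_static(x, [2, 3, 5]), np.tile(x, [2, 3, 5]))

The dynamic-shape branch (taken when x has a -1 dimension) follows the same idea but cannot materialize shape3 at compile time: it multiplies t2 by a ones tensor whose unsqueezed axes carry the repeat factors to trigger broadcasting, and assembles the final shape at runtime by concatenating slices of shape(t1) scaled by the corresponding repeat factors.
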
25 changes: 22 additions & 3 deletions test/legacy_test/test_tile_op.py
@@ -48,7 +48,9 @@ def init_data(self):
self.repeat_times = [2]

def test_check_output(self):
self.check_output(check_cinn=self.check_cinn, check_pir=True)
self.check_output(
check_cinn=self.check_cinn, check_pir=True, check_prim_pir=True
)

def test_check_grad(self):
self.check_grad(
@@ -144,6 +146,18 @@ def init_data(self):
def if_enable_cinn(self):
self.check_cinn = True

def test_check_output(self):
# todo: enable check_prim_pir
self.check_output(check_cinn=self.check_cinn, check_pir=True)

def test_check_grad(self):
self.check_grad(
['X'],
'Out',
check_prim=True,
check_pir=True,
)


# Situation 2: repeat_times is a list (with tensor)
# CINN not support repeat_times is a tensor now
@@ -269,7 +283,9 @@ def init_data(self):
self.repeat_times = [2, 1, 4]

def test_check_output(self):
self.check_output(check_cinn=self.check_cinn, check_pir=True)
self.check_output(
check_cinn=self.check_cinn, check_pir=True, check_prim_pir=True
)

def test_check_grad(self):
self.check_grad(
@@ -307,7 +323,10 @@ def if_enable_cinn(self):
def test_check_output(self):
place = core.CUDAPlace(0)
self.check_output_with_place(
place, check_cinn=self.check_cinn, check_pir=True
place,
check_cinn=self.check_cinn,
check_pir=True,
check_prim_pir=True,
)

def init_data(self):
60 changes: 49 additions & 11 deletions test/prim/pir_prim/test_prim_sub_graph_dynamic_shape.py
@@ -70,12 +70,23 @@ def stack_net(x):
return paddle.stack([x, y], axis=0)


def tile_net1(x):
y = paddle.tile(x, repeat_times=[2, 5])
return y


def tile_net2(x):
y = paddle.tile(x, repeat_times=[3, 2, 5])
return y


class TestPrimOne(unittest.TestCase):
def setUp(self):
np.random.seed(2023)
self.dtype = "float32"
self.shape_x = [1, 300, 4096]
self.x = np.random.random(self.shape_x).astype(self.dtype)
self.x_shape = [1, 300, 4096]
self.init_x_shape = [None, None, 4096]
self.x = np.random.random(self.x_shape).astype(self.dtype)
self.net = log_softmax_net
self.necessary_ops = "pd_op.log_softmax"
self.enable_cinn = False
@@ -89,7 +100,7 @@ def base_net(self, flag=None):
self.net,
use_cinn=self.enable_cinn,
input_spec=[
InputSpec(shape=[None, None, 4096], dtype='float32'),
InputSpec(shape=self.init_x_shape, dtype='float32'),
],
)
fn.eval()
@@ -119,8 +130,9 @@ class TestPrimOne2(TestPrimOne):
def setUp(self):
np.random.seed(2023)
self.dtype = "bool"
self.shape_x = [1, 300, 4096]
self.x = np.random.random(self.shape_x).astype(self.dtype)
self.x_shape = [1, 300, 4096]
self.init_x_shape = [None, None, 4096]
self.x = np.random.random(self.x_shape).astype(self.dtype)
self.net = any_net
self.necessary_ops = "pd_op.any"
self.enable_cinn = False
@@ -131,8 +143,8 @@ def setUp(self):
# def setUp(self):
# np.random.seed(2023)
# self.dtype = "int"
# self.shape_x = [1, 300, 4096]
# self.x = np.random.randint(0, 10, size=self.shape_x)
# self.x_shape = [1, 300, 4096]
# self.x = np.random.randint(0, 10, size=self.x_shape)
# self.net = embedding_net
# self.necessary_ops = "pd_op.embedding"
# self.enable_cinn = False
@@ -142,8 +154,9 @@ class TestPrimOne3(TestPrimOne):
def setUp(self):
np.random.seed(2023)
self.dtype = "float32"
self.shape_x = [1, 300, 4096]
self.x = np.random.random(self.shape_x).astype(self.dtype)
self.x_shape = [1, 300, 4096]
self.init_x_shape = [None, None, 4096]
self.x = np.random.random(self.x_shape).astype(self.dtype)
self.net = full_like_net
self.necessary_ops = "pd_op.full_like"
self.enable_cinn = False
@@ -153,12 +166,37 @@ class TestPrimOne4(TestPrimOne):
def setUp(self):
np.random.seed(2023)
self.dtype = "float32"
self.shape_x = [1, 300, 4096]
self.x = np.random.random(self.shape_x).astype(self.dtype)
self.x_shape = [1, 300, 4096]
self.init_x_shape = [None, None, 4096]
self.x = np.random.random(self.x_shape).astype(self.dtype)
self.net = stack_net
self.necessary_ops = "pd_op.stack"
self.enable_cinn = False


class TestPrimOne5(TestPrimOne):
def setUp(self):
np.random.seed(2023)
self.dtype = "float32"
self.x_shape = [1, 300, 4096]
self.init_x_shape = [None, None, 4096]
self.x = np.random.random(self.x_shape).astype(self.dtype)
self.net = tile_net1
self.necessary_ops = "pd_op.tile"
self.enable_cinn = False


class TestPrimOne6(TestPrimOne):
def setUp(self):
np.random.seed(2023)
self.dtype = "float32"
self.x_shape = [300, 4096]
self.init_x_shape = [None, 4096]
self.x = np.random.random(self.x_shape).astype(self.dtype)
self.net = tile_net2
self.necessary_ops = "pd_op.tile"
self.enable_cinn = False


if __name__ == "__main__":
unittest.main()
