Skip to content

Commit

Permalink
pir support fused_conv3d (PaddlePaddle#61160)
Browse files Browse the repository at this point in the history
  • Loading branch information
wanghuancoder committed Jan 25, 2024
1 parent 712b21d commit 98c5d5a
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 1 deletion.
10 changes: 10 additions & 0 deletions paddle/fluid/pir/dialect/operator/ir/onednn.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,16 @@
data_type : input
optional : bias, residual_param

- op : fused_conv3d
args : (Tensor input, Tensor filter, Tensor bias, Tensor residual_param, int[] strides={1, 1}, int[] paddings={0, 0}, str padding_algorithm="EXPLICIT", int[] dilations={1, 1}, int groups=1, str data_format="NCHW", str mkldnn_data_type="float32", str fuse_activation="", bool fuse_residual_connection=false, bool force_fp32_output=false)
output : Tensor(output)
infer_meta :
func : FusedConvInferMeta
kernel :
func : fused_conv3d
data_type : input
optional : bias, residual_param

- op : quantize
args : (Tensor input, bool is_negative_input=false, float scale=1.0, float shift=0.0, str output_format="NHWC", bool bfloat16=false)
output : Tensor(output)
Expand Down
4 changes: 4 additions & 0 deletions paddle/fluid/pir/dialect/operator/ir/ops_onednn_extra.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@
extra_args : float fuse_alpha = 0.0, float fuse_beta = 0.0, float scale_in=1.0, float scale_out=1.0, float scale_in_eltwise=1.0, float[] scale_weights={1.0f}
data_format_tensors : input

- op : fused_conv3d
extra_args : float fuse_alpha = 0.0, float fuse_beta = 0.0, float scale_in=1.0, float scale_out=1.0, float scale_in_eltwise=1.0, float[] scale_weights={1.0f}
data_format_tensors : input

- op : lrn
extra_args : bool is_test=false
data_format_tensors : x
Expand Down
9 changes: 8 additions & 1 deletion paddle/phi/api/yaml/op_compat.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1336,8 +1336,15 @@
fused_conv2d_add_act : GetConvExpectedKernelType

- op : fused_conv3d
inputs :
{input : Input, filter : Filter, bias : Bias, residual_param : ResidualData}
outputs :
{output : Output}
attrs :
{scale_in : Scale_in, scale_out : Scale_out, scale_in_eltwise : Scale_in_eltwise, scale_weights : Scale_weights}
extra :
attrs : [bool use_mkldnn = true, str mkldnn_data_type = "float32"]
attrs : [bool use_cudnn = false, float fuse_alpha = 0.0f, float fuse_beta = 0.0f, float Scale_in = 1.0f,
float Scale_out = 1.0f, float Scale_in_eltwise = 1.0f, 'float[] Scale_weights = {1.0f}', bool use_mkldnn = true, str mkldnn_data_type = "float32"]

- op : fused_embedding_eltwise_layernorm
inputs :
Expand Down

0 comments on commit 98c5d5a

Please sign in to comment.