Skip to content

Commit

Permalink
support fused vjp gen in pir (PaddlePaddle#60893)
Browse files Browse the repository at this point in the history
  • Loading branch information
Charles-hit committed Jan 25, 2024
1 parent 89d19f5 commit 7faf94c
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,4 @@

vjp_interface_black_list = [
'silu_grad',
'fused_dropout_add',
'fused_rotary_position_embedding',
'fused_bias_dropout_residual_layer_norm',
'fused_dot_product_attention',
'max_pool2d_v2',
]
11 changes: 9 additions & 2 deletions paddle/fluid/primitive/codegen/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,12 @@ set(update_fwd_pd_op_path
set(rev_pd_op_path
${PADDLE_BINARY_DIR}/paddle/fluid/pir/dialect/operator/ir/generated/ops_backward.parsed.yaml
)
set(fused_op_path
${PADDLE_SOURCE_DIR}/paddle/fluid/operators/generator/parsed_ops/fused_ops.parsed.yaml
)
set(fused_rev_op_path
${PADDLE_SOURCE_DIR}/paddle/fluid/operators/generator/parsed_ops/fused_backward.parsed.yaml
)
set(prim_path "${PADDLE_SOURCE_DIR}/paddle/fluid/primitive/primitive.yaml")
set(templates_dir
"${PADDLE_SOURCE_DIR}/paddle/fluid/primitive/codegen/templates/")
Expand All @@ -25,8 +31,9 @@ execute_process(
${PYTHON_EXECUTABLE} ${scripts} --fwd_path ${fwd_path} --rev_path
${rev_path} --fwd_pd_op_path ${fwd_pd_op_path} --update_fwd_pd_op_path
${update_fwd_pd_op_path} --rev_pd_op_path ${rev_pd_op_path} --prim_path
${prim_path} --templates_dir ${templates_dir} --compat_path ${compat_path}
--destination_dir ${destination_dir}
${prim_path} --fused_op_path ${fused_op_path} --fused_rev_op_path
${fused_rev_op_path} --templates_dir ${templates_dir} --compat_path
${compat_path} --destination_dir ${destination_dir}
RESULT_VARIABLE _result)
if(${_result})
message(
Expand Down
26 changes: 24 additions & 2 deletions paddle/fluid/primitive/codegen/gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,8 @@ def gen(
fwd_pd_op_path: pathlib.Path,
update_fwd_pd_op_path: pathlib.Path,
rev_pd_op_path: pathlib.Path,
fused_op_path: pathlib.Path,
fused_rev_path: pathlib.Path,
templates_dir: pathlib.Path,
destination_dir: pathlib.Path,
):
Expand All @@ -368,6 +370,8 @@ def gen(
fwd_pd_op_path (pathlib.Path): The YAML file path of the ir forward API.
update_fwd_pd_op_path (pathlib.Path): The YAML file path of the ir update_ops.
rev_pd_op_path (pathlib.Path): The YAML file path of the ir backward API.
fused_op_path (pathlib.Path): The YAML file path of the fused API.
fused_rev_path (pathlib.Path): The YAML file path of the fused backward API.
templates_dir (pathlib.Path): The directory of the templates.
destination_dir (pathlib.Path): The Directory of the generated file.
Expand All @@ -382,6 +386,8 @@ def gen(
ir_fwds,
ir_revs,
ir_update_fwds,
fused_fwds,
fused_revs,
) = (
load(prim_path),
load(fwd_path),
Expand All @@ -390,13 +396,17 @@ def gen(
load(fwd_pd_op_path),
load(rev_pd_op_path),
load(update_fwd_pd_op_path),
load(fused_op_path),
load(fused_rev_path),
)
filter_compat_info(compats)

fwd_apis = fwds + ir_fwds + ir_update_fwds
fwd_apis = fwds + ir_fwds + ir_update_fwds + fused_fwds

apis = [{**api, **{'is_fwd': True}} for api in fwd_apis]
apis = apis + [{**api, **{'is_fwd': False}} for api in revs + ir_revs]
apis = apis + [
{**api, **{'is_fwd': False}} for api in revs + ir_revs + fused_revs
]
apis = [
{**api, **{'is_prim': True}}
if api['name'] in prims
Expand Down Expand Up @@ -452,6 +462,16 @@ def gen(
type=str,
help='The ir backward ops parsed yaml file.',
)
parser.add_argument(
'--fused_op_path',
type=str,
help='The parsed fused forward ops yaml file.',
)
parser.add_argument(
'--fused_rev_op_path',
type=str,
help='The parsed fused backward ops yaml file.',
)
parser.add_argument(
'--templates_dir',
type=str,
Expand All @@ -472,6 +492,8 @@ def gen(
pathlib.Path(args.fwd_pd_op_path),
pathlib.Path(args.update_fwd_pd_op_path),
pathlib.Path(args.rev_pd_op_path),
pathlib.Path(args.fused_op_path),
pathlib.Path(args.fused_rev_op_path),
pathlib.Path(args.templates_dir),
pathlib.Path(args.destination_dir),
)

0 comments on commit 7faf94c

Please sign in to comment.