[dtensor][6/N] change to a better/safer op registration (pytorch#90735)
This PR changes the op registration to a better mechanism: we now
require registering the OpOverload directly instead of an op key
string. This has several benefits (a minimal before/after sketch
follows below):
1. We ensure that each registration targets the correct op, i.e. the
  registration fails if it refers to a wrong or nonexistent op (this PR
  already fixes several op registration errors found while switching to
  direct OpOverload registration).
2. If an overload name is changed or deleted, we know immediately at
  the source-code level because the reference fails to resolve, which
  is safer.
3. It also keeps the op registration mechanism consistent with the
  other tensor subclasses within PyTorch.
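
To make the change concrete, here is a minimal sketch of the new registration style. It uses `register_prop_rule`, `reduction_rule`, and the `aten = torch.ops.aten` alias exactly as they appear in the math_ops.py diff below; the module paths are the internal `torch.distributed._tensor` ones as of this PR and may differ in later releases.

```python
import torch

from torch.distributed._tensor.op_schema import OpSchema, OutputSharding
from torch.distributed._tensor.ops.common_rules import reduction_rule
from torch.distributed._tensor.ops.utils import register_prop_rule

# Resolve the aten namespace once; attribute access such as aten.all.default
# returns the corresponding OpOverload object.
aten = torch.ops.aten

# Before: the op was identified by a plain string, so a typo such as
# "aten.all.defualt" would silently register a rule for a nonexistent op:
#     @register_prop_rule("aten.all.default")
#
# After: the OpOverload is resolved when this module is imported, so a wrong
# op or overload name fails immediately with an attribute-lookup error.
@register_prop_rule(aten.all.default)
def default_reduction_rule(op_schema: OpSchema) -> OutputSharding:
    return reduction_rule(op_schema, reduction_linear=True)
```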

Differential Revision: [D42876250](https://our.internmc.facebook.com/intern/diff/D42876250)
Pull Request resolved: pytorch#90735
Approved by: https://github.com/XilunWu, https://github.com/fduwjj
wanchaol authored and pytorchmergebot committed Feb 1, 2023
1 parent 42633cf commit 60e503d
Showing 8 changed files with 425 additions and 429 deletions.
2 changes: 0 additions & 2 deletions test/distributed/_tensor/test_dtensor_ops.py
@@ -130,7 +130,6 @@ def wrapped(fn):
     xfail("combinations"),
     xfail("complex"),
     xfail("constant_pad_nd"),
-    xfail("copysign"),
     xfail("corrcoef"),
     xfail("count_nonzero"),
     xfail("cov"),
@@ -401,7 +400,6 @@ def wrapped(fn):
     xfail("put"),
     xfail("qr"),
     xfail("quantile"),
-    xfail("rad2deg"),
     xfail("rand_like"),
     xfail("randint_like"),
     xfail("randint"),
56 changes: 23 additions & 33 deletions torch/distributed/_tensor/ops/math_ops.py
@@ -1,6 +1,8 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates
 from typing import cast, Optional, Sequence
 
+import torch
+
 from torch.distributed._tensor.op_schema import OpSchema, OutputSharding
 from torch.distributed._tensor.ops.common_rules import pointwise_rule, reduction_rule
 from torch.distributed._tensor.ops.utils import (
@@ -11,6 +13,9 @@
 from torch.distributed._tensor.placement_types import DTensorSpec
 
 
+aten = torch.ops.aten
+
+
 def _infer_reduction_dims(dims_arg: object, ndim: int) -> Optional[Sequence[int]]:
     if dims_arg is None:
         return None
@@ -22,11 +27,17 @@ def _infer_reduction_dims(dims_arg: object, ndim: int) -> Optional[Sequence[int]
     return dims
 
 
-@register_prop_rule("aten.all.default")
+@register_prop_rule(aten.all.default)
 def default_reduction_rule(op_schema: OpSchema) -> OutputSharding:
     return reduction_rule(op_schema, reduction_linear=True)
 
 
+@register_prop_rule(
+    [
+        aten.sum.default,
+        aten.sum.dim_IntList,
+    ]
+)
 def sum_rule(op_schema: OpSchema) -> OutputSharding:
     args_schema = op_schema.args_schema
     input_spec = cast(DTensorSpec, args_schema[0])
@@ -40,15 +51,7 @@ def sum_rule(op_schema: OpSchema) -> OutputSharding:
     )
 
 
-sum_ops = [
-    "aten.sum.default",
-    "aten.sum.dim_IntList",
-]
-for sum_op in sum_ops:
-    register_prop_rule(sum_op)(sum_rule)
-
-
-@register_prop_rule("aten._softmax.default")
+@register_prop_rule(aten._softmax.default)
 def softmax_rule(op_schema: OpSchema) -> OutputSharding:
     input_spec, softmax_dim, _ = op_schema.args_schema
     input_spec = cast(DTensorSpec, input_spec)
@@ -59,7 +62,7 @@ def softmax_rule(op_schema: OpSchema) -> OutputSharding:
     return OutputSharding(input_spec)
 
 
-@register_prop_rule("aten._softmax_backward_data.default")
+@register_prop_rule(aten._softmax_backward_data.default)
 def softmax_bwd_rule(op_schema: OpSchema) -> OutputSharding:
     grad_out_spec, out_spec, softmax_dim, _ = op_schema.args_schema
     grad_out_spec = cast(DTensorSpec, grad_out_spec)
@@ -74,6 +77,7 @@ def softmax_bwd_rule(op_schema: OpSchema) -> OutputSharding:
     return pointwise_rule(op_schema)
 
 
+@register_prop_rule([aten.mean.default, aten.mean.dim, aten.mean.out])
 def mean_rule(op_schema: OpSchema) -> OutputSharding:
     args_schema = op_schema.args_schema
     input_spec = cast(DTensorSpec, args_schema[0])
@@ -88,16 +92,13 @@ def mean_rule(op_schema: OpSchema) -> OutputSharding:
     )
 
 
-mean_ops = [
-    "aten.mean.default",
-    "aten.mean.dim",
-    "aten.mean.out",
-]
-
-for mean_op in mean_ops:
-    register_prop_rule(mean_op)(mean_rule)
-
-
+@register_prop_rule(
+    [
+        aten.var.default,
+        aten.var.dim,
+        aten.var.out,
+    ]
+)
 def var_rule(op_schema: OpSchema) -> OutputSharding:
     args_schema = op_schema.args_schema
     input_spec = cast(DTensorSpec, args_schema[0])
@@ -114,18 +115,7 @@ def var_rule(op_schema: OpSchema) -> OutputSharding:
     )
 
 
-var_ops = [
-    "aten.var.default",
-    "aten.var.dim",
-    "aten.var.out",
-]
-
-for var_op in var_ops:
-    register_prop_rule(var_op)(var_rule)
-
-
-@register_prop_rule("aten.var.correction")
-@register_prop_rule("aten.var.correction_out")
+@register_prop_rule([aten.var.correction, aten.var.correction_out])
 def var_correction_rule(op_schema: OpSchema) -> OutputSharding:
     args_schema = op_schema.args_schema
     input_spec = cast(DTensorSpec, args_schema[0])
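The math_ops.py hunks above also fold the old per-string registration loops (`sum_ops`, `mean_ops`, `var_ops`) into a single decorator call that takes a list of OpOverloads. A rough sketch of that pattern, mirroring the `sum_rule` change and with the rule body elided:

```python
import torch

from torch.distributed._tensor.op_schema import OpSchema, OutputSharding
from torch.distributed._tensor.ops.utils import register_prop_rule

aten = torch.ops.aten

# Old pattern: attach one rule to several overloads via a loop over op-key strings:
#     for sum_op in ["aten.sum.default", "aten.sum.dim_IntList"]:
#         register_prop_rule(sum_op)(sum_rule)
#
# New pattern: the decorator accepts a list of OpOverload objects, so every
# overload is resolved (and therefore validated) at registration time.
@register_prop_rule([aten.sum.default, aten.sum.dim_IntList])
def sum_rule(op_schema: OpSchema) -> OutputSharding:
    ...  # sharding propagation logic elided; see the full rule in math_ops.py
```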
15 changes: 10 additions & 5 deletions torch/distributed/_tensor/ops/matrix_ops.py
@@ -1,9 +1,14 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates
 # implement matrix related ops for distributed tensor
 
+import torch
+
 from torch.distributed._tensor.op_schema import OpSchema, OutputSharding
 from torch.distributed._tensor.ops.common_rules import einop_rule, pointwise_rule
 from torch.distributed._tensor.ops.utils import register_prop_rule
 
+aten = torch.ops.aten
+
+
 def _update_schema_suggestion_for_addmm(
     output_sharding: OutputSharding,
@@ -41,12 +46,12 @@ def _update_schema_suggestion_for_addmm(
     return output_sharding
 
 
-@register_prop_rule("aten.mm.default")
+@register_prop_rule(aten.mm.default)
 def mm_rules(op_schema: OpSchema) -> OutputSharding:
     return einop_rule("mk,kn->mn", op_schema, linearity=False)
 
 
-@register_prop_rule("aten.addmm.default")
+@register_prop_rule(aten.addmm.default)
 def addmm_rules(op_schema: OpSchema) -> OutputSharding:
     input_spec, mat1_spec, mat2_spec = op_schema.args_spec
     mm_out_sharding = mm_rules(
@@ -80,17 +85,17 @@ def addmm_rules(op_schema: OpSchema) -> OutputSharding:
     return output_sharding
 
 
-@register_prop_rule("aten.t.default")
+@register_prop_rule(aten.t.default)
 def transpose_rule(op_schema: OpSchema) -> OutputSharding:
     return einop_rule("ij->ji", op_schema, linearity=True)
 
 
-@register_prop_rule("aten.bmm.default")
+@register_prop_rule(aten.bmm.default)
 def bmm_rules(op_schema: OpSchema) -> OutputSharding:
     return einop_rule("bmk,bkn->bmn", op_schema, linearity=False)
 
 
-@register_prop_rule("aten.baddbmm.default")
+@register_prop_rule(aten.baddbmm.default)
 def baddbmm_rules(op_schema: OpSchema) -> OutputSharding:
     input_spec, mat1_spec, mat2_spec = op_schema.args_spec
     bmm_output_sharding = bmm_rules(
(The diffs for the remaining 5 changed files are not shown here.)
