Land "Make ceil,floor,round,trunc handle integers" (pytorch#85144)
PR to land pytorch#78480, as Rohit does
not work in the PyTorch project anymore
Pull Request resolved: pytorch#85144
Approved by: https://github.com/ngimel, https://github.com/mruberry
lezcano authored and pytorchmergebot committed Sep 21, 2022
1 parent 6f2b390 commit 2a88f1b
Showing 7 changed files with 111 additions and 28 deletions.
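For context, a minimal sketch of the user-visible behavior this commit lands (assuming a build that includes it): torch.ceil, torch.floor, torch.round, and torch.trunc become dtype-preserving no-ops on integer tensors, returning a copy of the input instead of rejecting integral dtypes (see gh-70918).

```python
import torch

t = torch.tensor([1, 2, 3], dtype=torch.int64)

# The four rounding ops now accept integer tensors and return a copy of the
# input with the same dtype (the array-api convention).
for op in (torch.ceil, torch.floor, torch.round, torch.trunc):
    out = op(t)
    assert out.dtype == torch.int64
    assert torch.equal(out, t)
    assert out.data_ptr() != t.data_ptr()  # a fresh copy, not an alias

# Bool is deliberately excluded from the fast path (isIntegralType is called
# with includeBool=false), so boolean inputs still go through the regular kernel.
```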
2 changes: 1 addition & 1 deletion .github/ci_commit_pins/xla.txt
@@ -1 +1 @@
307af4313d2b0b0236618ef837959a41068cc272
revert-4019-revert-3913-ceil_forr_round_trunc_int
19 changes: 15 additions & 4 deletions aten/src/ATen/native/UnaryOps.cpp
@@ -157,14 +157,28 @@ TORCH_IMPL_FUNC(func_out) (const Tensor& self, const Tensor& result) { \
func_stub(device_type(), *this); \
}

// This macro is as optional as the one above. torch.(ceil|floor|round|trunc) are no-ops for integers
// See gh-70918
#define CREATE_UNARY_TORCH_IMPL_INTEGER_NO_OP_FUNC(func_out, func_stub) \
TORCH_IMPL_FUNC(func_out) (const Tensor& self, const Tensor& result) { \
if (c10::isIntegralType(self.scalar_type(), /*includeBool=*/false)) { \
result.copy_(self); \
} else { \
func_stub(device_type(), *this); \
} \
}
CREATE_UNARY_TORCH_IMPL_INTEGER_NO_OP_FUNC(ceil_out, ceil_stub)
CREATE_UNARY_TORCH_IMPL_INTEGER_NO_OP_FUNC(floor_out, floor_stub)
CREATE_UNARY_TORCH_IMPL_INTEGER_NO_OP_FUNC(round_out, round_stub)
CREATE_UNARY_TORCH_IMPL_INTEGER_NO_OP_FUNC(trunc_out, trunc_stub)

CREATE_UNARY_TORCH_IMPL_FUNC(acos_out, acos_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(acosh_out, acosh_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(asin_out, asin_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(asinh_out, asinh_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(atan_out, atan_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(atanh_out, atanh_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(bitwise_not_out, bitwise_not_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(ceil_out, ceil_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(cos_out, cos_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(cosh_out, cosh_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(digamma_out, digamma_stub)
@@ -174,7 +188,6 @@ CREATE_UNARY_TORCH_IMPL_FUNC(erfinv_out, erfinv_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(exp_out, exp_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(exp2_out, exp2_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(expm1_out, expm1_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(floor_out, floor_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(frac_out, frac_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(i0_out, i0_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(lgamma_out, lgamma_stub)
@@ -184,7 +197,6 @@ CREATE_UNARY_TORCH_IMPL_FUNC(log1p_out, log1p_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(log2_out, log2_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(neg_out, neg_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(reciprocal_out, reciprocal_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(round_out, round_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(rsqrt_out, rsqrt_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(sigmoid_out, sigmoid_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(sign_out, sign_stub)
@@ -201,7 +213,6 @@ CREATE_UNARY_TORCH_IMPL_FUNC(special_log_ndtr_out, special_log_ndtr_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(sqrt_out, sqrt_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(tan_out, tan_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(tanh_out, tanh_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(trunc_out, trunc_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(special_airy_ai_out, special_airy_ai_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(special_bessel_j0_out, special_bessel_j0_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(special_bessel_j1_out, special_bessel_j1_stub)
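For readers less familiar with the TORCH_IMPL_FUNC machinery, here is a rough Python analogue of what CREATE_UNARY_TORCH_IMPL_INTEGER_NO_OP_FUNC expands to. The `stub` argument and the `ceil_out` wrapper below are illustrative stand-ins, not PyTorch APIs.

```python
import torch

def _integer_no_op_impl(self: torch.Tensor, result: torch.Tensor, stub) -> None:
    """Sketch of the macro body: integer (non-bool) inputs short-circuit to a
    copy; every other dtype falls through to the device-specific stub."""
    integral = (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64)
    if self.dtype in integral:       # isIntegralType(..., includeBool=false)
        result.copy_(self)           # no-op for integers: result <- input
    else:
        stub(result, self)           # stands in for func_stub(device_type(), *this)

# Illustrative wiring for one op, i.e. what
# CREATE_UNARY_TORCH_IMPL_INTEGER_NO_OP_FUNC(ceil_out, ceil_stub) registers:
def ceil_out(self: torch.Tensor, result: torch.Tensor) -> None:
    _integer_no_op_impl(self, result, lambda r, s: torch.ceil(s, out=r))
```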
11 changes: 7 additions & 4 deletions test/test_jit_fuser_te.py
@@ -1393,10 +1393,13 @@ def apply(fn):
torch.sqrt,
torch.rsqrt,
torch.abs,
torch.ceil,
torch.floor,
torch.round,
torch.trunc,
# TODO broken on int8 since
# https://github.com/pytorch/pytorch/pull/85144
# RuntimeError: Invalid integral op_type: 23
# torch.ceil,
# torch.floor,
# torch.round,
# torch.trunc,
torch.frac,
# TODO: broken on ROCm?
# F.hardshrink,
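The commented-out entries above point at a fuser regression on integer inputs. A hypothetical repro sketch is shown below; the exact trigger conditions are an assumption, and the error text is quoted from the TODO comment above.

```python
import torch

torch._C._jit_override_can_fuse_on_cpu(True)  # let the TE/NNC fuser run on CPU

@torch.jit.script
def f(x):
    # a small pointwise chain so the TensorExpr fuser can form a fusion group
    return torch.ceil(x) + 1

x = torch.zeros(8, dtype=torch.int8)
for _ in range(3):   # profiling runs; fusion kicks in after warm-up
    out = f(x)       # on affected builds: "RuntimeError: Invalid integral op_type: 23"
```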
5 changes: 3 additions & 2 deletions test/test_type_promotion.py
@@ -1063,7 +1063,7 @@ def test_unary_op_out_casting(self, device, dtypes):
out = torch.empty(0, dtype=dtypes[1], device=device)

ops = (torch.neg, torch.floor, torch.ceil)
float_only_ops = {torch.floor, torch.ceil}
float_and_int_only_ops = {torch.floor, torch.ceil}
real_only_ops = {torch.floor, torch.ceil}
for op in ops:
if dtypes[0] is not dtypes[1]:
@@ -1073,8 +1073,9 @@ def test_unary_op_out_casting(self, device, dtypes):
with self.assertRaises(RuntimeError):
op(t, out=out)
elif (
op in float_only_ops
op in float_and_int_only_ops
and (not dtypes[0].is_floating_point and not dtypes[0].is_complex)
and (not (dtypes[0] == torch.int64 and dtypes[1] == torch.int64))
and device != "meta"
):
with self.assertRaises(RuntimeError):
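Concretely, the carve-out added above corresponds to a case that no longer raises. A sketch, assuming the op-level change in this commit:

```python
import torch

t = torch.arange(4, dtype=torch.int64)
out = torch.empty(4, dtype=torch.int64)

# int64 input with an int64 `out` no longer raises for floor/ceil; the values
# are simply copied through, so the test excludes this pair from the
# expected-RuntimeError branch.
torch.floor(t, out=out)
assert torch.equal(out, t)
```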
2 changes: 1 addition & 1 deletion torch/_refs/__init__.py
@@ -814,7 +814,7 @@ def tanh(a):
return prims.tanh(a)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT)
def trunc(a):
return prims.trunc(a)

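Switching the reference from INT_TO_FLOAT to DEFAULT type promotion means _refs.trunc keeps integer dtypes instead of promoting them to float, matching the new eager behavior. A sketch, assuming execution through the default ATen path:

```python
import torch
import torch._refs as refs

t = torch.tensor([1, 2, 3], dtype=torch.int32)

# DEFAULT promotion: the reference preserves the integer dtype, as eager now does.
assert refs.trunc(t).dtype == torch.int32
assert torch.trunc(t).dtype == torch.int32
assert torch.equal(refs.trunc(t), torch.trunc(t))
```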
12 changes: 12 additions & 0 deletions torch/_torch_docs.py
@@ -2327,6 +2327,9 @@ def merge_dicts(*dicts):
Returns a new tensor with the ceil of the elements of :attr:`input`,
the smallest integer greater than or equal to each element.
For integer inputs, follows the array-api convention of returning a
copy of the input tensor.
.. math::
\text{out}_{i} = \left\lceil \text{input}_{i} \right\rceil
"""
@@ -4164,6 +4167,9 @@ def merge_dicts(*dicts):
Returns a new tensor with the floor of the elements of :attr:`input`,
the largest integer less than or equal to each element.
For integer inputs, follows the array-api convention of returning a
copy of the input tensor.
.. math::
\text{out}_{i} = \left\lfloor \text{input}_{i} \right\rfloor
"""
@@ -9631,6 +9637,9 @@ def merge_dicts(*dicts):
Rounds elements of :attr:`input` to the nearest integer.
For integer inputs, follows the array-api convention of returning a
copy of the input tensor.
.. note::
This function implements the "round half to even" to
break ties when a number is equidistant from two
@@ -12010,6 +12019,9 @@ def merge_dicts(*dicts):
Returns a new tensor with the truncated integer values of
the elements of :attr:`input`.
For integer inputs, follows the array-api convention of returning a
copy of the input tensor.
Args:
{input}
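The convention the added doc lines describe differs from classic NumPy, whose ceil/floor/trunc ufuncs promote integer inputs to float64. A short comparison sketch of the documented behavior:

```python
import numpy as np
import torch

i = torch.tensor([1, 2, 3])                    # int64

# Documented above: integer inputs yield a same-dtype copy of the input.
r = torch.ceil(i)
assert r.dtype == torch.int64 and torch.equal(r, i) and r is not i

# Classic NumPy, by contrast, promotes integer inputs to float64 here.
assert np.ceil(np.array([1, 2, 3])).dtype == np.float64
```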
88 changes: 72 additions & 16 deletions torch/testing/_internal/common_methods_invocations.py
@@ -7523,8 +7523,8 @@ def sample_inputs_max_unpool_grad(op_info, device, dtype, requires_grad, **kwarg

ForeachFuncInfo(
'ceil',
dtypes=floating_types_and(torch.bfloat16),
dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
dtypes=all_types_and(torch.bfloat16),
dtypesIfCUDA=all_types_and(torch.half, torch.bfloat16),
),

ForeachFuncInfo(
@@ -7547,8 +7547,8 @@ def sample_inputs_max_unpool_grad(op_info, device, dtype, requires_grad, **kwarg

ForeachFuncInfo(
'floor',
dtypes=floating_types_and(torch.bfloat16),
dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
dtypes=all_types_and(torch.bfloat16),
dtypesIfCUDA=all_types_and(torch.half, torch.bfloat16),
),

ForeachFuncInfo(
@@ -7559,8 +7559,8 @@ def sample_inputs_max_unpool_grad(op_info, device, dtype, requires_grad, **kwarg

ForeachFuncInfo(
'round',
dtypes=floating_types_and(torch.bfloat16),
dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
dtypes=all_types_and(torch.bfloat16),
dtypesIfCUDA=all_types_and(torch.half, torch.bfloat16),
),

ForeachFuncInfo(
@@ -7583,8 +7583,8 @@ def sample_inputs_max_unpool_grad(op_info, device, dtype, requires_grad, **kwarg

ForeachFuncInfo(
'trunc',
dtypes=floating_types_and(torch.bfloat16),
dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
dtypes=all_types_and(torch.bfloat16),
dtypesIfCUDA=all_types_and(torch.half, torch.bfloat16),
),

ForeachFuncInfo(
@@ -8755,10 +8755,21 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
sample_inputs_func=sample_inputs_cdist),
UnaryUfuncInfo('ceil',
ref=np.ceil,
dtypes=floating_types_and(torch.bfloat16),
dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
dtypes=all_types_and(torch.bfloat16),
dtypesIfCUDA=all_types_and(torch.half, torch.bfloat16),
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
skips=(
DecorateInfo(unittest.expectedFailure,
'TestNNCOpInfo',
'test_nnc_correctness',
dtypes=tuple(t for t in integral_types() if t != torch.uint8)),
DecorateInfo(unittest.expectedFailure,
'TestCudaFuserOpInfo',
'test_nvfuser_correctness',
dtypes=(torch.int32, torch.int64),
active_if=not TEST_WITH_ROCM),
),
supports_sparse=True,
supports_sparse_csr=True,
supports_sparse_csc=True,
@@ -9432,10 +9443,21 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
)),
UnaryUfuncInfo('floor',
ref=np.floor,
dtypes=floating_types_and(torch.bfloat16),
dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
dtypes=all_types_and(torch.bfloat16),
dtypesIfCUDA=all_types_and(torch.half, torch.bfloat16),
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
skips=(
DecorateInfo(unittest.expectedFailure,
'TestNNCOpInfo',
'test_nnc_correctness',
dtypes=tuple(t for t in integral_types() if t != torch.uint8)),
DecorateInfo(unittest.expectedFailure,
'TestCudaFuserOpInfo',
'test_nvfuser_correctness',
dtypes=(torch.int32, torch.int64),
active_if=not TEST_WITH_ROCM),
),
supports_sparse=True,
supports_sparse_csr=True,
supports_sparse_csc=True,
@@ -12407,10 +12429,21 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
UnaryUfuncInfo('round',
ref=np.round,
aliases=('special.round',),
dtypes=floating_types_and(torch.bfloat16),
dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
dtypes=all_types_and(torch.bfloat16),
dtypesIfCUDA=all_types_and(torch.half, torch.bfloat16),
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
skips=(
DecorateInfo(unittest.expectedFailure,
'TestNNCOpInfo',
'test_nnc_correctness',
dtypes=tuple(t for t in integral_types() if t != torch.uint8)),
DecorateInfo(unittest.expectedFailure,
'TestCudaFuserOpInfo',
'test_nvfuser_correctness',
dtypes=(torch.int32, torch.int64),
active_if=not TEST_WITH_ROCM),
),
supports_sparse=True,
supports_sparse_csr=True,
supports_sparse_csc=True,
@@ -12936,11 +12969,22 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
UnaryUfuncInfo('trunc',
aliases=('fix', ),
ref=np.trunc,
dtypes=floating_types_and(torch.bfloat16),
dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
dtypes=all_types_and(torch.bfloat16),
dtypesIfCUDA=all_types_and(torch.float16, torch.bfloat16),
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
supports_sparse=True,
skips=(
DecorateInfo(unittest.expectedFailure,
'TestNNCOpInfo',
'test_nnc_correctness',
dtypes=tuple(t for t in integral_types() if t != torch.uint8)),
DecorateInfo(unittest.expectedFailure,
'TestCudaFuserOpInfo',
'test_nvfuser_correctness',
dtypes=(torch.int32, torch.int64),
active_if=not TEST_WITH_ROCM),
),
supports_sparse_csr=True,
supports_sparse_csc=True,
supports_sparse_bsr=True,
@@ -16380,6 +16424,9 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
ElementwiseUnaryPythonRefInfo(
"_refs.ceil",
torch_opinfo_name="ceil",
# Fails on int32
# https://github.com/pytorch/pytorch/issues/85258
supports_nvfuser=False,
),
ElementwiseUnaryPythonRefInfo(
"_refs.conj_physical",
@@ -16434,6 +16481,9 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
ElementwiseUnaryPythonRefInfo(
"_refs.floor",
torch_opinfo_name="floor",
# Fails on int32
# https://github.com/pytorch/pytorch/issues/85258
supports_nvfuser=False,
),
ElementwiseUnaryPythonRefInfo(
"_refs.frac",
@@ -16554,6 +16604,9 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
ElementwiseUnaryPythonRefInfo(
"_refs.round",
torch_opinfo_name="round",
# Fails on int32
# https://github.com/pytorch/pytorch/issues/85258
supports_nvfuser=False,
),
ElementwiseUnaryPythonRefInfo(
"_refs.rsqrt",
@@ -16621,6 +16674,9 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
ElementwiseUnaryPythonRefInfo(
"_refs.trunc",
torch_opinfo_name="trunc",
# Fails on int32
# https://github.com/pytorch/pytorch/issues/85258
supports_nvfuser=False,
),
#
# Elementwise Unary Special OpInfos
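Widening the ForeachFuncInfo and UnaryUfuncInfo dtypes to all_types_and(...) means the foreach variants are now exercised on integer inputs as well. A sketch of what that coverage is expected to allow, inferred from the dtype lists above:

```python
import torch

ints = [torch.tensor([1, 2], dtype=torch.int64),
        torch.tensor([3, 4], dtype=torch.int64)]

# _foreach_ceil (and floor/round/trunc) are now tested on integer dtypes;
# each result should be an unchanged, same-dtype copy of its input.
res = torch._foreach_ceil(ints)
assert all(torch.equal(r, t) and r.dtype == t.dtype for r, t in zip(res, ints))
```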
