Skip to content

Commit

Permalink
port copysign to structured kernel (pytorch#55040)
Browse files Browse the repository at this point in the history
Summary:
Related pytorch#54945

This PR ports `copysign` to structured, and the `copysign.Scalar` overloads are re-dispatched to the structured kernel.

Pull Request resolved: pytorch#55040

Reviewed By: glaringlee

Differential Revision: D27465501

Pulled By: ezyang

fbshipit-source-id: 5cbabfeaaaa7ca184ae0b701b9692a918a90b117
  • Loading branch information
RockingJavaBean authored and facebook-github-bot committed Apr 1, 2021
1 parent 8b02d12 commit b880854
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 30 deletions.
36 changes: 19 additions & 17 deletions aten/src/ATen/native/BinaryOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,12 @@ TORCH_META_FUNC2(div, Tensor_mode) (const Tensor& self, const Tensor& other, std
}
}

// Structured-kernel meta function for copysign.Tensor.
// Only sets up output metadata (shape broadcasting, dtype promotion) on the
// TensorIterator base; the actual computation runs in TORCH_IMPL_FUNC below.
TORCH_META_FUNC2(copysign, Tensor) (
const Tensor& self, const Tensor& other
) {
// build_binary_float_op configures a binary op whose result is floating
// point (presumably promoting integral inputs to float — matches the
// pre-structured TensorIterator::binary_float_op path used below).
build_binary_float_op(maybe_get_output(), self, other);
}

} // namespace meta


Expand Down Expand Up @@ -188,29 +194,25 @@ Tensor& add_relu_(Tensor& self, const Tensor& other, const Scalar& alpha) {
return add_relu_impl(self, self, other, alpha);
}

// Writes copysign(self, other) elementwise into `result` and returns it.
// (Pre-structured implementation: builds the iterator by hand and invokes
// the per-device stub directly.)
Tensor& copysign_out(const Tensor& self, const Tensor& other, Tensor& result) {
  auto binary_iter = TensorIterator::binary_float_op(result, self, other);
  copysign_stub(binary_iter.device_type(), binary_iter);
  return result;
}

// Functional copysign(self, other): allocates the output via the iterator
// (an undefined `out` Tensor lets binary_float_op create it) and returns
// the iterator's materialized output.
Tensor copysign(const Tensor& self, const Tensor& other) {
  Tensor out;
  auto binary_iter = TensorIterator::binary_float_op(out, self, other);
  copysign_stub(binary_iter.device_type(), binary_iter);
  return binary_iter.output();
}

Tensor& copysign_(Tensor& self, const Tensor& other) {
return native::copysign_out(self, other, self);
TORCH_IMPL_FUNC(copysign_out) (
const Tensor& self, const Tensor& other, const Tensor& result
) {
copysign_stub(device_type(), *this);
}

// copysign.Scalar: wrap the scalar in a 0-dim tensor and redispatch to the
// structured copysign.Tensor kernel.
//
// NOTE(review): the original span contained two return statements — the
// stale pre-commit `return native::copysign(...)` made the committed
// redispatch unreachable (diff-rendering artifact). Only the redispatch is kept.
Tensor copysign(const Tensor& self, const Scalar& other) {
  // redispatch!
  return at::copysign(self, wrapped_scalar_tensor(other));
}

// copysign_.Scalar: wrap the scalar and redispatch to the in-place Tensor
// overload (which delegates to the structured kernel).
//
// NOTE(review): the original span contained two return statements — the
// stale pre-commit `return native::copysign_(...)` made the committed
// redispatch unreachable (diff-rendering artifact). Only the redispatch is kept.
Tensor& copysign_(Tensor& self, const Scalar& other) {
  // redispatch!
  return self.copysign_(wrapped_scalar_tensor(other));
}

// copysign.Scalar_out: lift the scalar to a 0-dim tensor, then redispatch to
// the structured copysign.out overload.
Tensor& copysign_out(const Tensor& self, const Scalar& other, Tensor& result) {
  // redispatch!
  auto other_tensor = wrapped_scalar_tensor(other);
  return at::copysign_out(result, self, other_tensor);
}

// WARNING: There doesn't appear to be any testing for this function
Expand Down
2 changes: 1 addition & 1 deletion aten/src/ATen/native/BinaryOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ DECLARE_DISPATCH(binary_fn, igamma_stub);
DECLARE_DISPATCH(binary_fn, igammac_stub);
DECLARE_DISPATCH(binary_fn, nextafter_stub);
DECLARE_DISPATCH(binary_fn, heaviside_stub);
// copysign is ported to a structured kernel: its stub now receives a
// TensorIteratorBase (structured_binary_fn) rather than a TensorIterator.
// NOTE(review): the original span declared copysign_stub twice with
// conflicting types (stale pre-commit binary_fn line left by the diff
// rendering); only the structured declaration is kept.
DECLARE_DISPATCH(structured_binary_fn, copysign_stub);
DECLARE_DISPATCH(binary_fn, xlogy_stub);

}} // namespace at::native
2 changes: 1 addition & 1 deletion aten/src/ATen/native/cpu/BinaryOpsKernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -932,7 +932,7 @@ void heaviside_kernel(TensorIterator& iter) {
});
}

void copysign_kernel(TensorIterator& iter) {
void copysign_kernel(TensorIteratorBase& iter) {
AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, iter.common_dtype(), "copysign_cpu", [&]() {
cpu_kernel(iter, [](scalar_t a, scalar_t b) -> scalar_t {
return copysign(a, b);
Expand Down
2 changes: 1 addition & 1 deletion aten/src/ATen/native/cuda/CopysignKernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

namespace at { namespace native {

void copysign_kernel_cuda(TensorIterator& iter) {
void copysign_kernel_cuda(TensorIteratorBase& iter) {
AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, iter.common_dtype(), "copysign_cuda", [&]() {
gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
return c10::cuda::compat::copysign(a, b);
Expand Down
24 changes: 14 additions & 10 deletions aten/src/ATen/native/native_functions.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -750,29 +750,33 @@
dispatch:
CPU, CUDA: bitwise_not_out

# copysign is implemented as a structured kernel: the out= overload carries
# the meta/impl pair, and the functional / in-place Tensor overloads delegate
# to it via structured_delegate.
#
# NOTE(review): the original span interleaved the pre-commit entries (per-
# backend `CPU, CUDA` dispatch lines and a second, stale `copysign.out`
# entry) with the committed ones — a diff-rendering artifact producing
# duplicate keys. Only the committed entries are kept below.
- func: copysign.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
  structured: True
  structured_inherits: TensorIteratorBase
  dispatch:
    CPU, CUDA: copysign_out

- func: copysign.Tensor(Tensor self, Tensor other) -> Tensor
  variants: function, method
  structured_delegate: copysign.out

- func: copysign_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
  variants: method
  structured_delegate: copysign.out

# The Scalar overloads wrap the scalar and redispatch to the Tensor
# overloads, so one backend-agnostic CompositeExplicitAutograd kernel
# suffices for each.
- func: copysign.Scalar(Tensor self, Scalar other) -> Tensor
  variants: function, method
  dispatch:
    CompositeExplicitAutograd: copysign

- func: copysign_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
  variants: method
  dispatch:
    CompositeExplicitAutograd: copysign_

- func: copysign.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)
  dispatch:
    CompositeExplicitAutograd: copysign_out

- func: logical_not(Tensor self) -> Tensor
variants: function, method
Expand Down

0 comments on commit b880854

Please sign in to comment.