Use custom sqrt if stdc++ does not fall back to C99 csqrt (pytorch#54820)

Summary:
Pull Request resolved: pytorch#54820

The template implementation of std::sqrt() in libstdc++ yields incorrect results for `std::complex(-std::abs(x), -0.0)`; see https://gcc.gnu.org/bugzilla/show_bug.cgi?id=89991
For example:
```
#include <iostream>
#include <complex>
int main() {
  std::cout << std::sqrt(std::complex<float>(-1.0f, -0.0f)) << std::endl;
}
```
prints `(0, -1)` if libstdc++ is compiled to use the C99 csqrt/csqrtf fallback, but `(0, 1)` if it is configured not to use it.
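
For reference, a minimal sketch of a square root that keeps the sign of a zero imaginary part on the negative real axis. This is not the implementation added by this commit; `sqrt_signed_zero` is a hypothetical helper written for illustration, and it assumes finite inputs (no inf/NaN handling).

```
#include <cmath>
#include <complex>

// Hypothetical helper (not part of c10): principal square root that
// preserves the sign of a zero imaginary part. Finite inputs only.
template <typename T>
std::complex<T> sqrt_signed_zero(const std::complex<T>& z) {
  const T re = z.real();
  const T im = z.imag();
  if (re == T(0) && im == T(0)) {
    // sqrt(+-0 +- 0i): real part is +0, imaginary part keeps its sign
    return std::complex<T>(T(0), im);
  }
  const T mag = std::hypot(re, im);
  // t = sqrt((|z| + |re|) / 2), always > 0 here
  const T t = std::sqrt((mag + std::abs(re)) / T(2));
  if (re >= T(0)) {
    return std::complex<T>(t, im / (T(2) * t));
  }
  // re < 0: copysign propagates the sign of im, including -0.0
  return std::complex<T>(std::abs(im) / (T(2) * t), std::copysign(t, im));
}
```

With this sketch, `sqrt_signed_zero(std::complex<float>(-1.0f, -0.0f))` has real part 0 and imaginary part -1, while passing `+0.0f` as the imaginary part gives imaginary part +1.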

Test Plan: CI

Reviewed By: luciang

Differential Revision: D27379302

fbshipit-source-id: 03f614fdb7ff734139736a2a5f6872cee0173bee
malfet authored and facebook-github-bot committed Mar 29, 2021
1 parent 717e70a commit 68af6d9
Showing 2 changed files with 5 additions and 3 deletions.
4 changes: 3 additions & 1 deletion c10/util/complex_math.cpp
@@ -9,8 +9,10 @@
// numerical errors when arg is close to 0, pi/2, pi, or 3pi/4
// In that case provide a more conservative implementation which is
// slower but less prone to those kinds of errors
+// In libstdc++ complex square root yield invalid results
+// for -x-0.0j unless C99 csqrt/csqrtf fallbacks are used

-#ifdef _LIBCPP_VERSION
+#if defined(_LIBCPP_VERSION) || (defined(_GLIBCXX_USE_C99_COMPLEX) && !_GLIBCXX_USE_C99_COMPLEX)

namespace {
template <typename T>
4 changes: 2 additions & 2 deletions c10/util/complex_math.h
@@ -42,7 +42,7 @@ C10_HOST_DEVICE inline c10::complex<T> log2(const c10::complex<T> &x) {

// Power functions
//
-#ifdef _LIBCPP_VERSION
+#if defined(_LIBCPP_VERSION) || (defined(_GLIBCXX_USE_C99_COMPLEX) && !_GLIBCXX_USE_C99_COMPLEX)
namespace _detail {
TORCH_API c10::complex<float> sqrt(const c10::complex<float>& in);
TORCH_API c10::complex<double> sqrt(const c10::complex<double>& in);
@@ -55,7 +55,7 @@ template<typename T>
C10_HOST_DEVICE inline c10::complex<T> sqrt(const c10::complex<T> &x) {
#if defined(__CUDACC__) || defined(__HIPCC__)
return static_cast<c10::complex<T>>(thrust::sqrt(c10_internal::cuda101bug_cast_c10_complex_to_thrust_complex(x)));
-#elif !defined(_LIBCPP_VERSION)
+#elif !(defined(_LIBCPP_VERSION) || (defined(_GLIBCXX_USE_C99_COMPLEX) && !_GLIBCXX_USE_C99_COMPLEX))
return static_cast<c10::complex<T>>(std::sqrt(static_cast<std::complex<T>>(x)));
#else
return _detail::sqrt(x);
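
To see which branch the guard above selects on a given toolchain, something like the following sketch may help. `EXAMPLE_USES_CUSTOM_SQRT` and `sqrt_path` are hypothetical names for illustration only. The sketch assumes that including `<complex>` is enough to make the standard library's configuration macros (`_LIBCPP_VERSION`, `_GLIBCXX_USE_C99_COMPLEX`) visible.

```
#include <complex>  // assumed to pull in the library's configuration macros

// Hypothetical macro mirroring the condition used in this commit.
#if defined(_LIBCPP_VERSION) || \
    (defined(_GLIBCXX_USE_C99_COMPLEX) && !_GLIBCXX_USE_C99_COMPLEX)
#define EXAMPLE_USES_CUSTOM_SQRT 1
#else
#define EXAMPLE_USES_CUSTOM_SQRT 0
#endif

// Hypothetical helper: names the sqrt path the guard would pick.
inline const char* sqrt_path() {
  return EXAMPLE_USES_CUSTOM_SQRT ? "custom sqrt" : "std::sqrt";
}
```

Printing `sqrt_path()` from a small `main` shows whether the build would take the `_detail::sqrt` path or call `std::sqrt` directly. Note that the `defined(...) && !...` form only selects the custom path for libstdc++ when the macro is defined and evaluates to zero.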