From e385d93166ae5f5c50999b4b66057480c3ac402d Mon Sep 17 00:00:00 2001 From: soulitzer Date: Tue, 12 Dec 2023 18:07:50 -0500 Subject: [PATCH] Move SingletonSymNodeImpl from c10 to aten (#114895) Pull Request resolved: https://github.com/pytorch/pytorch/pull/114895 Approved by: https://github.com/jbschlosser --- .../src/ATen}/core/SingletonSymNodeImpl.cpp | 2 +- .../src/ATen}/core/SingletonSymNodeImpl.h | 2 +- build_variables.bzl | 1 + c10/test/core/SymInt_test.cpp | 97 ---------------- test/cpp/api/CMakeLists.txt | 1 + test/cpp/api/singleton_int.cpp | 105 ++++++++++++++++++ torch/csrc/utils/python_dispatch.cpp | 3 +- 7 files changed, 111 insertions(+), 100 deletions(-) rename {c10 => aten/src/ATen}/core/SingletonSymNodeImpl.cpp (98%) rename {c10 => aten/src/ATen}/core/SingletonSymNodeImpl.h (98%) create mode 100644 test/cpp/api/singleton_int.cpp diff --git a/c10/core/SingletonSymNodeImpl.cpp b/aten/src/ATen/core/SingletonSymNodeImpl.cpp similarity index 98% rename from c10/core/SingletonSymNodeImpl.cpp rename to aten/src/ATen/core/SingletonSymNodeImpl.cpp index 601f62943ad2a..3ac668d987825 100644 --- a/c10/core/SingletonSymNodeImpl.cpp +++ b/aten/src/ATen/core/SingletonSymNodeImpl.cpp @@ -1,4 +1,4 @@ -#include +#include #include #include diff --git a/c10/core/SingletonSymNodeImpl.h b/aten/src/ATen/core/SingletonSymNodeImpl.h similarity index 98% rename from c10/core/SingletonSymNodeImpl.h rename to aten/src/ATen/core/SingletonSymNodeImpl.h index b71cb46cd8f4a..2b7764c66b40f 100644 --- a/c10/core/SingletonSymNodeImpl.h +++ b/aten/src/ATen/core/SingletonSymNodeImpl.h @@ -28,7 +28,7 @@ namespace c10 { // During tracing the strides of the outputs need to be a function of the size // and strides of the inputs so it is important that SingletonSymNode itself is // able to express this. -class C10_API SingletonSymNodeImpl : public SymNodeImpl { +class TORCH_API SingletonSymNodeImpl : public SymNodeImpl { public: // CAUTION: you should probably not be constructing these directly; please // the higher-level API in python instead (TODO: actually introduce that). diff --git a/build_variables.bzl b/build_variables.bzl index a634f640e8cbc..9d61861e3a33b 100644 --- a/build_variables.bzl +++ b/build_variables.bzl @@ -1021,6 +1021,7 @@ aten_cpu_source_non_codegen_list = [ "aten/src/ATen/core/operator_name.cpp", "aten/src/ATen/core/TorchDispatchUtils.cpp", "aten/src/ATen/core/register_symbols.cpp", + "aten/src/ATen/core/SingletonSymNodeImpl.cpp", "aten/src/ATen/core/class_type.cpp", "aten/src/ATen/core/type.cpp", "aten/src/ATen/core/type_factory.cpp", diff --git a/c10/test/core/SymInt_test.cpp b/c10/test/core/SymInt_test.cpp index c3a72614d0994..e4d54d1f80971 100644 --- a/c10/test/core/SymInt_test.cpp +++ b/c10/test/core/SymInt_test.cpp @@ -1,6 +1,5 @@ #include -#include #include #include @@ -23,100 +22,4 @@ TEST(SymIntTest, CheckRange) { EXPECT_FALSE(SymInt::check_range(INT64_MIN)); } -TEST(SymIntTest, SingletonSymNode) { - auto a = c10::SymInt( - c10::SymNode(c10::make_intrusive(1, 1))); - auto b = c10::SymInt( - c10::SymNode(c10::make_intrusive(1, 1))); - auto c = c10::SymInt( - c10::SymNode(c10::make_intrusive(2, 1))); - auto d = c10::SymInt(3); - - ASSERT_TRUE(a == a); - ASSERT_TRUE(a == b); - ASSERT_FALSE(a != a); - ASSERT_FALSE(a != b); - ASSERT_FALSE(a == c); - ASSERT_TRUE(a != c); - - ASSERT_FALSE(a == d); - ASSERT_TRUE(a != d); - ASSERT_FALSE(d == a); - ASSERT_TRUE(d != a); - - // ge - ASSERT_TRUE(a >= a); - ASSERT_TRUE(a >= b); - ASSERT_TRUE(b >= a); - // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto) - EXPECT_THROW((void)(a >= c), c10::Error); - // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto) - EXPECT_THROW((void)(c >= a), c10::Error); - // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto) - EXPECT_THROW((void)(c >= 3), c10::Error); - ASSERT_TRUE(c >= 2); - ASSERT_TRUE(c >= 1); - ASSERT_FALSE(1 >= c); - - // lt - ASSERT_FALSE(a < a); - ASSERT_FALSE(a < b); - ASSERT_FALSE(b < a); - // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto) - EXPECT_THROW((void)(a < c), c10::Error); - // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto) - EXPECT_THROW((void)(c < a), c10::Error); - // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto) - EXPECT_THROW((void)(3 < a), c10::Error); - // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto) - EXPECT_THROW((void)(2 < a), c10::Error); - ASSERT_TRUE(1 < a); - - // le - ASSERT_TRUE(a <= a); - ASSERT_TRUE(b <= a); - ASSERT_TRUE(a <= b); - // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto) - EXPECT_THROW((void)(a <= c), c10::Error); - // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto) - EXPECT_THROW((void)(c <= a), c10::Error); - // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto) - EXPECT_THROW((void)(3 <= c), c10::Error); - ASSERT_TRUE(2 <= c); - ASSERT_TRUE(1 <= c); - ASSERT_FALSE(c <= 1); - - // gt - ASSERT_FALSE(a > a); - ASSERT_FALSE(b > a); - ASSERT_FALSE(a > b); - // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto) - EXPECT_THROW((void)(a > c), c10::Error); - // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto) - EXPECT_THROW((void)(c > a), c10::Error); - // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto) - EXPECT_THROW((void)(a > 3), c10::Error); - // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto) - EXPECT_THROW((void)(a > 2), c10::Error); - ASSERT_TRUE(a > 1); -} - -TEST(SymIntTest, SingletonSymNodeWithFactor) { - auto a = c10::SymInt( - c10::SymNode(c10::make_intrusive(1, 5))); - auto b = c10::SymInt( - c10::SymNode(c10::make_intrusive(1, 10))); - // eq - ASSERT_FALSE(a == b); - ASSERT_FALSE(a >= b); - ASSERT_TRUE(b >= a); - ASSERT_TRUE(a <= b); - ASSERT_FALSE(b <= a); - // ne - ASSERT_TRUE(a != b); - // mul - ASSERT_TRUE(a * 2 == b); - ASSERT_TRUE(a * 3 >= b); - ASSERT_TRUE(a * 2 == 2 * a); -} #endif diff --git a/test/cpp/api/CMakeLists.txt b/test/cpp/api/CMakeLists.txt index a980a0de60199..2e89ec2e7c839 100644 --- a/test/cpp/api/CMakeLists.txt +++ b/test/cpp/api/CMakeLists.txt @@ -41,6 +41,7 @@ set(TORCH_API_TEST_SOURCES ${TORCH_API_TEST_DIR}/inference_mode.cpp ${TORCH_API_TEST_DIR}/grad_mode.cpp ${TORCH_API_TEST_DIR}/operations.cpp + ${TORCH_API_TEST_DIR}/singleton_int.cpp ) if(USE_CUDA OR USE_ROCM) list(APPEND TORCH_API_TEST_SOURCES ${TORCH_API_TEST_DIR}/parallel.cpp) diff --git a/test/cpp/api/singleton_int.cpp b/test/cpp/api/singleton_int.cpp new file mode 100644 index 0000000000000..152d7a72f28f2 --- /dev/null +++ b/test/cpp/api/singleton_int.cpp @@ -0,0 +1,105 @@ +#include + +#include +#include +#include +#include + +#include + +TEST(SingletonIntTest, Comparisons) { + auto a = c10::SymInt( + c10::SymNode(c10::make_intrusive(1, 1))); + auto b = c10::SymInt( + c10::SymNode(c10::make_intrusive(1, 1))); + auto c = c10::SymInt( + c10::SymNode(c10::make_intrusive(2, 1))); + auto d = c10::SymInt(3); + + ASSERT_TRUE(a == a); + ASSERT_TRUE(a == b); + ASSERT_FALSE(a != a); + ASSERT_FALSE(a != b); + ASSERT_FALSE(a == c); + ASSERT_TRUE(a != c); + + ASSERT_FALSE(a == d); + ASSERT_TRUE(a != d); + ASSERT_FALSE(d == a); + ASSERT_TRUE(d != a); + + // ge + ASSERT_TRUE(a >= a); + ASSERT_TRUE(a >= b); + ASSERT_TRUE(b >= a); + // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto) + EXPECT_THROW((void)(a >= c), c10::Error); + // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto) + EXPECT_THROW((void)(c >= a), c10::Error); + // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto) + EXPECT_THROW((void)(c >= 3), c10::Error); + ASSERT_TRUE(c >= 2); + ASSERT_TRUE(c >= 1); + ASSERT_FALSE(1 >= c); + + // lt + ASSERT_FALSE(a < a); + ASSERT_FALSE(a < b); + ASSERT_FALSE(b < a); + // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto) + EXPECT_THROW((void)(a < c), c10::Error); + // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto) + EXPECT_THROW((void)(c < a), c10::Error); + // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto) + EXPECT_THROW((void)(3 < a), c10::Error); + // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto) + EXPECT_THROW((void)(2 < a), c10::Error); + ASSERT_TRUE(1 < a); + + // le + ASSERT_TRUE(a <= a); + ASSERT_TRUE(b <= a); + ASSERT_TRUE(a <= b); + // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto) + EXPECT_THROW((void)(a <= c), c10::Error); + // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto) + EXPECT_THROW((void)(c <= a), c10::Error); + // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto) + EXPECT_THROW((void)(3 <= c), c10::Error); + ASSERT_TRUE(2 <= c); + ASSERT_TRUE(1 <= c); + ASSERT_FALSE(c <= 1); + + // gt + ASSERT_FALSE(a > a); + ASSERT_FALSE(b > a); + ASSERT_FALSE(a > b); + // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto) + EXPECT_THROW((void)(a > c), c10::Error); + // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto) + EXPECT_THROW((void)(c > a), c10::Error); + // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto) + EXPECT_THROW((void)(a > 3), c10::Error); + // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto) + EXPECT_THROW((void)(a > 2), c10::Error); + ASSERT_TRUE(a > 1); +} + +TEST(SingletonIntTest, WiithFactor) { + auto a = c10::SymInt( + c10::SymNode(c10::make_intrusive(1, 5))); + auto b = c10::SymInt( + c10::SymNode(c10::make_intrusive(1, 10))); + // eq + ASSERT_FALSE(a == b); + ASSERT_FALSE(a >= b); + ASSERT_TRUE(b >= a); + ASSERT_TRUE(a <= b); + ASSERT_FALSE(b <= a); + // ne + ASSERT_TRUE(a != b); + // mul + ASSERT_TRUE(a * 2 == b); + ASSERT_TRUE(a * 3 >= b); + ASSERT_TRUE(a * 2 == 2 * a); +} diff --git a/torch/csrc/utils/python_dispatch.cpp b/torch/csrc/utils/python_dispatch.cpp index 7084d15b31724..c887c822334bc 100644 --- a/torch/csrc/utils/python_dispatch.cpp +++ b/torch/csrc/utils/python_dispatch.cpp @@ -6,7 +6,9 @@ #include #include #include +#include #include + #include #include @@ -15,7 +17,6 @@ #include #include -#include #include #include #include