Skip to content

Commit

Permalink
Move SingletonSymNodeImpl from c10 to aten (pytorch#114895)
Browse files Browse the repository at this point in the history
Pull Request resolved: pytorch#114895
Approved by: https://github.com/jbschlosser
  • Loading branch information
soulitzer authored and dmenig committed Dec 21, 2023
1 parent fac26c2 commit e385d93
Show file tree
Hide file tree
Showing 7 changed files with 111 additions and 100 deletions.
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#include <c10/core/SingletonSymNodeImpl.h>
#include <ATen/core/SingletonSymNodeImpl.h>
#include <c10/core/SymNodeImpl.h>
#include <c10/util/Exception.h>

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ namespace c10 {
// During tracing the strides of the outputs need to be a function of the size
// and strides of the inputs so it is important that SingletonSymNode itself is
// able to express this.
class C10_API SingletonSymNodeImpl : public SymNodeImpl {
class TORCH_API SingletonSymNodeImpl : public SymNodeImpl {
public:
// CAUTION: you should probably not be constructing these directly; please
// the higher-level API in python instead (TODO: actually introduce that).
Expand Down
1 change: 1 addition & 0 deletions build_variables.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -1021,6 +1021,7 @@ aten_cpu_source_non_codegen_list = [
"aten/src/ATen/core/operator_name.cpp",
"aten/src/ATen/core/TorchDispatchUtils.cpp",
"aten/src/ATen/core/register_symbols.cpp",
"aten/src/ATen/core/SingletonSymNodeImpl.cpp",
"aten/src/ATen/core/class_type.cpp",
"aten/src/ATen/core/type.cpp",
"aten/src/ATen/core/type_factory.cpp",
Expand Down
97 changes: 0 additions & 97 deletions c10/test/core/SymInt_test.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
#include <gtest/gtest.h>

#include <c10/core/SingletonSymNodeImpl.h>
#include <c10/core/SymInt.h>
#include <c10/core/SymNodeImpl.h>

Expand All @@ -23,100 +22,4 @@ TEST(SymIntTest, CheckRange) {
EXPECT_FALSE(SymInt::check_range(INT64_MIN));
}

TEST(SymIntTest, SingletonSymNode) {
auto a = c10::SymInt(
c10::SymNode(c10::make_intrusive<c10::SingletonSymNodeImpl>(1, 1)));
auto b = c10::SymInt(
c10::SymNode(c10::make_intrusive<c10::SingletonSymNodeImpl>(1, 1)));
auto c = c10::SymInt(
c10::SymNode(c10::make_intrusive<c10::SingletonSymNodeImpl>(2, 1)));
auto d = c10::SymInt(3);

ASSERT_TRUE(a == a);
ASSERT_TRUE(a == b);
ASSERT_FALSE(a != a);
ASSERT_FALSE(a != b);
ASSERT_FALSE(a == c);
ASSERT_TRUE(a != c);

ASSERT_FALSE(a == d);
ASSERT_TRUE(a != d);
ASSERT_FALSE(d == a);
ASSERT_TRUE(d != a);

// ge
ASSERT_TRUE(a >= a);
ASSERT_TRUE(a >= b);
ASSERT_TRUE(b >= a);
// NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
EXPECT_THROW((void)(a >= c), c10::Error);
// NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
EXPECT_THROW((void)(c >= a), c10::Error);
// NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
EXPECT_THROW((void)(c >= 3), c10::Error);
ASSERT_TRUE(c >= 2);
ASSERT_TRUE(c >= 1);
ASSERT_FALSE(1 >= c);

// lt
ASSERT_FALSE(a < a);
ASSERT_FALSE(a < b);
ASSERT_FALSE(b < a);
// NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
EXPECT_THROW((void)(a < c), c10::Error);
// NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
EXPECT_THROW((void)(c < a), c10::Error);
// NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
EXPECT_THROW((void)(3 < a), c10::Error);
// NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
EXPECT_THROW((void)(2 < a), c10::Error);
ASSERT_TRUE(1 < a);

// le
ASSERT_TRUE(a <= a);
ASSERT_TRUE(b <= a);
ASSERT_TRUE(a <= b);
// NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
EXPECT_THROW((void)(a <= c), c10::Error);
// NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
EXPECT_THROW((void)(c <= a), c10::Error);
// NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
EXPECT_THROW((void)(3 <= c), c10::Error);
ASSERT_TRUE(2 <= c);
ASSERT_TRUE(1 <= c);
ASSERT_FALSE(c <= 1);

// gt
ASSERT_FALSE(a > a);
ASSERT_FALSE(b > a);
ASSERT_FALSE(a > b);
// NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
EXPECT_THROW((void)(a > c), c10::Error);
// NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
EXPECT_THROW((void)(c > a), c10::Error);
// NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
EXPECT_THROW((void)(a > 3), c10::Error);
// NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
EXPECT_THROW((void)(a > 2), c10::Error);
ASSERT_TRUE(a > 1);
}

TEST(SymIntTest, SingletonSymNodeWithFactor) {
auto a = c10::SymInt(
c10::SymNode(c10::make_intrusive<c10::SingletonSymNodeImpl>(1, 5)));
auto b = c10::SymInt(
c10::SymNode(c10::make_intrusive<c10::SingletonSymNodeImpl>(1, 10)));
// eq
ASSERT_FALSE(a == b);
ASSERT_FALSE(a >= b);
ASSERT_TRUE(b >= a);
ASSERT_TRUE(a <= b);
ASSERT_FALSE(b <= a);
// ne
ASSERT_TRUE(a != b);
// mul
ASSERT_TRUE(a * 2 == b);
ASSERT_TRUE(a * 3 >= b);
ASSERT_TRUE(a * 2 == 2 * a);
}
#endif
1 change: 1 addition & 0 deletions test/cpp/api/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ set(TORCH_API_TEST_SOURCES
${TORCH_API_TEST_DIR}/inference_mode.cpp
${TORCH_API_TEST_DIR}/grad_mode.cpp
${TORCH_API_TEST_DIR}/operations.cpp
${TORCH_API_TEST_DIR}/singleton_int.cpp
)
if(USE_CUDA OR USE_ROCM)
list(APPEND TORCH_API_TEST_SOURCES ${TORCH_API_TEST_DIR}/parallel.cpp)
Expand Down
105 changes: 105 additions & 0 deletions test/cpp/api/singleton_int.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
#include <gtest/gtest.h>

#include <ATen/core/SingletonSymNodeImpl.h>
#include <c10/core/SymInt.h>
#include <c10/core/SymNodeImpl.h>
#include <torch/torch.h>

#include <test/cpp/api/support.h>

TEST(SingletonIntTest, Comparisons) {
auto a = c10::SymInt(
c10::SymNode(c10::make_intrusive<c10::SingletonSymNodeImpl>(1, 1)));
auto b = c10::SymInt(
c10::SymNode(c10::make_intrusive<c10::SingletonSymNodeImpl>(1, 1)));
auto c = c10::SymInt(
c10::SymNode(c10::make_intrusive<c10::SingletonSymNodeImpl>(2, 1)));
auto d = c10::SymInt(3);

ASSERT_TRUE(a == a);
ASSERT_TRUE(a == b);
ASSERT_FALSE(a != a);
ASSERT_FALSE(a != b);
ASSERT_FALSE(a == c);
ASSERT_TRUE(a != c);

ASSERT_FALSE(a == d);
ASSERT_TRUE(a != d);
ASSERT_FALSE(d == a);
ASSERT_TRUE(d != a);

// ge
ASSERT_TRUE(a >= a);
ASSERT_TRUE(a >= b);
ASSERT_TRUE(b >= a);
// NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
EXPECT_THROW((void)(a >= c), c10::Error);
// NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
EXPECT_THROW((void)(c >= a), c10::Error);
// NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
EXPECT_THROW((void)(c >= 3), c10::Error);
ASSERT_TRUE(c >= 2);
ASSERT_TRUE(c >= 1);
ASSERT_FALSE(1 >= c);

// lt
ASSERT_FALSE(a < a);
ASSERT_FALSE(a < b);
ASSERT_FALSE(b < a);
// NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
EXPECT_THROW((void)(a < c), c10::Error);
// NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
EXPECT_THROW((void)(c < a), c10::Error);
// NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
EXPECT_THROW((void)(3 < a), c10::Error);
// NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
EXPECT_THROW((void)(2 < a), c10::Error);
ASSERT_TRUE(1 < a);

// le
ASSERT_TRUE(a <= a);
ASSERT_TRUE(b <= a);
ASSERT_TRUE(a <= b);
// NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
EXPECT_THROW((void)(a <= c), c10::Error);
// NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
EXPECT_THROW((void)(c <= a), c10::Error);
// NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
EXPECT_THROW((void)(3 <= c), c10::Error);
ASSERT_TRUE(2 <= c);
ASSERT_TRUE(1 <= c);
ASSERT_FALSE(c <= 1);

// gt
ASSERT_FALSE(a > a);
ASSERT_FALSE(b > a);
ASSERT_FALSE(a > b);
// NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
EXPECT_THROW((void)(a > c), c10::Error);
// NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
EXPECT_THROW((void)(c > a), c10::Error);
// NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
EXPECT_THROW((void)(a > 3), c10::Error);
// NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
EXPECT_THROW((void)(a > 2), c10::Error);
ASSERT_TRUE(a > 1);
}

TEST(SingletonIntTest, WiithFactor) {
auto a = c10::SymInt(
c10::SymNode(c10::make_intrusive<c10::SingletonSymNodeImpl>(1, 5)));
auto b = c10::SymInt(
c10::SymNode(c10::make_intrusive<c10::SingletonSymNodeImpl>(1, 10)));
// eq
ASSERT_FALSE(a == b);
ASSERT_FALSE(a >= b);
ASSERT_TRUE(b >= a);
ASSERT_TRUE(a <= b);
ASSERT_FALSE(b <= a);
// ne
ASSERT_TRUE(a != b);
// mul
ASSERT_TRUE(a * 2 == b);
ASSERT_TRUE(a * 3 >= b);
ASSERT_TRUE(a * 2 == 2 * a);
}
3 changes: 2 additions & 1 deletion torch/csrc/utils/python_dispatch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
#include <ATen/FunctionalTensorWrapper.h>
#include <ATen/TensorSubclassLikeUtils.h>
#include <ATen/core/PythonOpRegistrationTrampoline.h>
#include <ATen/core/SingletonSymNodeImpl.h>
#include <ATen/core/dispatch/Dispatcher.h>

#include <ATen/functorch/BatchedTensorImpl.h>
#include <torch/library.h>

Expand All @@ -15,7 +17,6 @@
#include <torch/csrc/autograd/python_variable.h>
#include <torch/csrc/jit/python/pybind_utils.h>

#include <c10/core/SingletonSymNodeImpl.h>
#include <c10/util/flat_hash_map.h>
#include <pybind11/operators.h>
#include <pybind11/stl.h>
Expand Down

0 comments on commit e385d93

Please sign in to comment.