diff --git a/aten/src/ATen/core/TensorMethods.h b/aten/src/ATen/core/TensorMethods.h
index 4b35abff69081..f78b02b5b4633 100644
--- a/aten/src/ATen/core/TensorMethods.h
+++ b/aten/src/ATen/core/TensorMethods.h
@@ -1810,10 +1810,10 @@ inline bool is_quantized(Tensor self) {
   return self.is_quantized();
 }
 
-#define DEFINE_CAST(T, name, _)                  \
+#define DEFINE_CAST(T, name)                     \
   template <>                                    \
   inline T* Tensor::data() const {               \
-    TORCH_CHECK(                                 \
+    TORCH_CHECK(                              \
         scalar_type() == ScalarType::name,       \
         "expected scalar type ",                 \
         #name,                                   \
@@ -1826,7 +1826,7 @@ AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_COMPLEX_HALF(DEFINE_CAST)
 AT_FORALL_QINT_TYPES(DEFINE_CAST)
 #undef DEFINE_CAST
 
-#define DEFINE_ITEM(T, name, _)   \
+#define DEFINE_ITEM(T, name)      \
   template <>                     \
   inline T Tensor::item() const { \
     return item().to##name();     \
diff --git a/aten/src/ATen/native/TensorFactories.cpp b/aten/src/ATen/native/TensorFactories.cpp
index 58ad3445c7ecf..3f68f2d545427 100644
--- a/aten/src/ATen/native/TensorFactories.cpp
+++ b/aten/src/ATen/native/TensorFactories.cpp
@@ -162,7 +162,7 @@ Tensor& empty_out(
 // specialized operators for each datatype.
 // TODO: remove when we have Type support in the IR
 
-#define DEFINE_CAST_OP(_1, n, _2)                            \
+#define DEFINE_CAST_OP(_1, n)                                \
   Tensor _cast_##n(const Tensor& self, bool non_blocking) {  \
     if (self.scalar_type() == ScalarType::n)                 \
       return self;                                           \
@@ -798,7 +798,7 @@ Tensor tensor_cuda(ArrayRef<T> values, const TensorOptions& options) {
   return cpu_tensor.to(options.device());
 }
 
-#define TENSOR(T, _1, _2)                                            \
+#define TENSOR(T, _1)                                                \
   Tensor tensor(ArrayRef<T> values, const TensorOptions& options) {  \
     if (options.device().is_cuda()) {                                \
       return tensor_cuda(values, options);                           \
diff --git a/aten/src/ATen/templates/NativeFunctions.h b/aten/src/ATen/templates/NativeFunctions.h
index 654bb1667ebb1..75b3757e3bead 100644
--- a/aten/src/ATen/templates/NativeFunctions.h
+++ b/aten/src/ATen/templates/NativeFunctions.h
@@ -26,7 +26,7 @@ namespace at { namespace native {
 
 // These functions are defined in native/TensorFactories.cpp.
 
-#define TENSOR(T, S, _1)                                                        \
+#define TENSOR(T, S)                                                            \
   CAFFE2_API Tensor tensor(ArrayRef<T> values, const TensorOptions& options);   \
   inline Tensor tensor(                                                         \
       std::initializer_list<T> values, const TensorOptions& options) {          \
diff --git a/aten/src/ATen/templates/TensorMethods.h b/aten/src/ATen/templates/TensorMethods.h
index 974d7a4be97ab..7b242854f9ba5 100644
--- a/aten/src/ATen/templates/TensorMethods.h
+++ b/aten/src/ATen/templates/TensorMethods.h
@@ -143,10 +143,10 @@ inline bool is_quantized(Tensor self) {
   return self.is_quantized();
 }
 
-#define DEFINE_CAST(T, name, _)                  \
+#define DEFINE_CAST(T, name)                     \
   template <>                                    \
   inline T* Tensor::data() const {               \
-    TORCH_CHECK(                                 \
+    TORCH_CHECK(                              \
         scalar_type() == ScalarType::name,       \
         "expected scalar type ",                 \
         #name,                                   \
@@ -159,7 +159,7 @@ AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_COMPLEX_HALF(DEFINE_CAST)
 AT_FORALL_QINT_TYPES(DEFINE_CAST)
 #undef DEFINE_CAST
 
-#define DEFINE_ITEM(T, name, _)   \
+#define DEFINE_ITEM(T, name)      \
   template <>                     \
   inline T Tensor::item() const { \
     return item().to##name();     \
diff --git a/aten/src/TH/THBlasUtils.h b/aten/src/TH/THBlasUtils.h
index 47b74f85ac824..b7edb5b4fa6c6 100644
--- a/aten/src/TH/THBlasUtils.h
+++ b/aten/src/TH/THBlasUtils.h
@@ -9,7 +9,7 @@
 template<typename T>
 inline void THBlas_axpy(int64_t n, T a, T *x, int64_t incx, T *y, int64_t incy);
 
-#define AXPY_SPECIALIZATION(ctype,name,_1)                                    \
+#define AXPY_SPECIALIZATION(ctype,name)                                       \
   template<>                                                                  \
   inline void THBlas_axpy<ctype>(int64_t n, ctype a, ctype *x, int64_t incx,  \
                                  ctype *y, int64_t incy) {                    \
@@ -22,7 +22,7 @@ AT_FORALL_SCALAR_TYPES(AXPY_SPECIALIZATION)
 template<typename T>
 inline void THBlas_copy(int64_t n, T *x, int64_t incx, T *y, int64_t incy);
 
-#define COPY_SPECIALIZATION(ctype,name,_1)                           \
+#define COPY_SPECIALIZATION(ctype,name)                              \
   template<>                                                         \
   inline void THBlas_copy<ctype>(int64_t n, ctype *x, int64_t incx,  \
                                  ctype *y, int64_t incy) {           \
@@ -34,7 +34,7 @@ AT_FORALL_SCALAR_TYPES(COPY_SPECIALIZATION)
 template<typename T>
 inline T THBlas_dot(int64_t n, T *x, int64_t incx, T *y, int64_t incy);
 
-#define DOT_SPECIALIZATION(ctype,name,_1)                                                      \
+#define DOT_SPECIALIZATION(ctype,name)                                                         \
   template<>                                                                                   \
   inline ctype THBlas_dot<ctype>(int64_t n, ctype *x, int64_t incx, ctype *y, int64_t incy) {  \
     return TH ## name ## Blas_dot(n, x, incx, y, incy);                                        \
@@ -58,7 +58,7 @@ inline void THBlas_gemm(
     T *c,
     int64_t ldc);
 
-#define GEMM_SPECIALIZATION(ctype,name,_1)  \
+#define GEMM_SPECIALIZATION(ctype,name)     \
   template<>                                \
   inline void THBlas_gemm<ctype>(           \
       char transa,                          \
@@ -94,7 +94,7 @@ inline void THBlas_gemv(
     T* y,
     int64_t incy);
 
-#define GEMV_SPECIALIZATION(ctype, name, _1)  \
+#define GEMV_SPECIALIZATION(ctype, name)      \
   template <>                                 \
   inline void THBlas_gemv<ctype>(             \
       char transa,                            \
diff --git a/c10/core/Scalar.h b/c10/core/Scalar.h
index afe36b5da98bd..984821af11962 100644
--- a/c10/core/Scalar.h
+++ b/c10/core/Scalar.h
@@ -24,9 +24,10 @@ class C10_API Scalar {
  public:
   Scalar() : Scalar(int64_t(0)) {}
 
-#define DEFINE_IMPLICIT_CTOR(type, name, member) \
+#define DEFINE_IMPLICIT_CTOR(type, name) \
   Scalar(type vv) : Scalar(vv, true) { }
-  AT_FORALL_SCALAR_TYPES_AND(c10::ScalarType::BFloat16, DEFINE_IMPLICIT_CTOR)
+
+  AT_FORALL_SCALAR_TYPES_AND(BFloat16, DEFINE_IMPLICIT_CTOR)
 
 #undef DEFINE_IMPLICIT_CTOR
 
@@ -53,7 +54,7 @@ class C10_API Scalar {
 
 #undef DEFINE_IMPLICIT_COMPLEX_CTOR
 
-#define DEFINE_ACCESSOR(type, name, member)              \
+#define DEFINE_ACCESSOR(type, name)                      \
   type to##name() const {                                \
     if (Tag::HAS_d == tag) {                             \
       return checked_convert<type, double>(v.d, #type);  \
@@ -122,7 +123,7 @@ inline T Scalar::to() {
   throw std::runtime_error("to() cast to unexpected type.");
 }
 
-#define DEFINE_TO(T, name, _) \
+#define DEFINE_TO(T, name)    \
   template <>                 \
   inline T Scalar::to<T>() {  \
     return to##name();        \
diff --git a/c10/core/ScalarType.h b/c10/core/ScalarType.h
index 4f85d19de552d..8539489c7487a 100644
--- a/c10/core/ScalarType.h
+++ b/c10/core/ScalarType.h
@@ -12,61 +12,54 @@ namespace c10 {
 
 // For the macros below:
 
-// NB: QInt ScalarTypes are referred to as "STUBS" here since they do not
-// contain complete information to determine the tensor value of the data,
-// they are just stubs for dispatch / quantization.
-// NB: If you want to macro some code for all non-stub scalar types, you
-// probably want one of the AT_FORALL_SCALAR_TYPES / AT_FORALL_SCALAR_TYPES_AND
+// NB: If you want to macro some code for all non-QInt scalar types (i.e. types
+// with complete information), you probably want one of the
+// AT_FORALL_SCALAR_TYPES / AT_FORALL_SCALAR_TYPES_AND
 // macros below, which are designed to behave similarly to the Dispatch macros
 // with the same name.
 
 // NB: Order matters for this macro; it is relied upon in
 // _promoteTypesLookup and the serialization format.
-#define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_STUBS(_) \
-  _(uint8_t, Byte, __) /* 0 */                           \
-  _(int8_t, Char, __) /* 1 */                            \
-  _(int16_t, Short, __) /* 2 */                          \
-  _(int, Int, __) /* 3 */                                \
-  _(int64_t, Long, __) /* 4 */                           \
-  _(at::Half, Half, __) /* 5 */                          \
-  _(float, Float, __) /* 6 */                            \
-  _(double, Double, __) /* 7 */                          \
-  _(at::ComplexHalf, ComplexHalf, __) /* 8 */            \
-  _(std::complex<float>, ComplexFloat, __) /* 9 */       \
-  _(std::complex<double>, ComplexDouble, __) /* 10 */    \
-  _(bool, Bool, __) /* 11 */                             \
-  _(c10::qint8, QInt8, __) /* 12 */                      \
-  _(c10::quint8, QUInt8, __) /* 13 */                    \
-  _(c10::qint32, QInt32, __) /* 14 */                    \
-  _(at::BFloat16, BFloat16, __) /* 15 */
+#define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(_) \
+  _(uint8_t, Byte) /* 0 */                               \
+  _(int8_t, Char) /* 1 */                                \
+  _(int16_t, Short) /* 2 */                              \
+  _(int, Int) /* 3 */                                    \
+  _(int64_t, Long) /* 4 */                               \
+  _(at::Half, Half) /* 5 */                              \
+  _(float, Float) /* 6 */                                \
+  _(double, Double) /* 7 */                              \
+  _(at::ComplexHalf, ComplexHalf) /* 8 */                \
+  _(std::complex<float>, ComplexFloat) /* 9 */           \
+  _(std::complex<double>, ComplexDouble) /* 10 */        \
+  _(bool, Bool) /* 11 */                                 \
+  _(c10::qint8, QInt8) /* 12 */                          \
+  _(c10::quint8, QUInt8) /* 13 */                        \
+  _(c10::qint32, QInt32) /* 14 */                        \
+  _(at::BFloat16, BFloat16) /* 15 */
 
 // If you want to support ComplexHalf for real, add ComplexHalf
 // into this macro (and change the name). But beware: convert()
 // doesn't work for all the conversions you need...
 #define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_COMPLEX_HALF(_) \
-  _(uint8_t, Byte, __)                                             \
-  _(int8_t, Char, __)                                              \
-  _(int16_t, Short, __)                                            \
-  _(int, Int, __)                                                  \
-  _(int64_t, Long, __)                                             \
-  _(at::Half, Half, __)                                            \
-  _(float, Float, __)                                              \
-  _(double, Double, __)                                            \
-  _(std::complex<float>, ComplexFloat, __)                         \
-  _(std::complex<double>, ComplexDouble, __)                       \
-  _(bool, Bool, __)                                                \
-  _(at::BFloat16, BFloat16, __)
+  _(uint8_t, Byte)                                                 \
+  _(int8_t, Char)                                                  \
+  _(int16_t, Short)                                                \
+  _(int, Int)                                                      \
+  _(int64_t, Long)                                                 \
+  _(at::Half, Half)                                                \
+  _(float, Float)                                                  \
+  _(double, Double)                                                \
+  _(std::complex<float>, ComplexFloat)                             \
+  _(std::complex<double>, ComplexDouble)                           \
+  _(bool, Bool)                                                    \
+  _(at::BFloat16, BFloat16)
 
-#define AT_FORALL_QINT_TYPES(_) \
-  _(c10::qint8, QInt8, i)       \
-  _(c10::quint8, QUInt8, i)     \
-  _(c10::qint32, QInt32, i)
-
 enum class ScalarType : int8_t {
-#define DEFINE_ENUM(_1, n, _2) n,
-  AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_STUBS(DEFINE_ENUM)
+#define DEFINE_ENUM(_1, n) n,
+  AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_ENUM)
 #undef DEFINE_ENUM
   Undefined,
   NumOptions
@@ -127,44 +120,49 @@ struct ScalarTypeToCPPType {
 }
 
 #define AT_FORALL_SCALAR_TYPES(_) \
-  _(uint8_t, Byte, __)            \
-  _(int8_t, Char, __)             \
-  _(int16_t, Short, __)           \
-  _(int, Int, __)                 \
-  _(int64_t, Long, __)            \
-  _(float, Float, __)             \
-  _(double, Double, __)
-
-#define AT_FORALL_SCALAR_TYPES_AND(SCALARTYPE, _) \
-  _(uint8_t, Byte, __)                            \
-  _(int8_t, Char, __)                             \
-  _(int16_t, Short, __)                           \
-  _(int, Int, __)                                 \
-  _(int64_t, Long, __)                            \
-  _(at::Half, Half, __)                           \
-  _(float, Float, __)                             \
-  _(double, Double, __)                           \
-  _(decltype(::c10::impl::ScalarTypeToCPPType<SCALARTYPE>::t), SCALARTYPE, __)
-
-#define AT_FORALL_SCALAR_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, _) \
-  _(uint8_t, Byte, __)                                           \
-  _(int8_t, Char, __)                                            \
-  _(int16_t, Short, __)                                          \
-  _(int, Int, __)                                                \
-  _(int64_t, Long, __)                                           \
-  _(at::Half, Half, __)                                          \
-  _(float, Float, __)                                            \
-  _(double, Double, __)                                          \
-  _(decltype(::c10::impl::ScalarTypeToCPPType<SCALARTYPE1>::t), SCALARTYPE1, __) \
-  _(decltype(::c10::impl::ScalarTypeToCPPType<SCALARTYPE2>::t), SCALARTYPE2, __)
+  _(uint8_t, Byte)                \
+  _(int8_t, Char)                 \
+  _(int16_t, Short)               \
+  _(int, Int)                     \
+  _(int64_t, Long)                \
+  _(float, Float)                 \
+  _(double, Double)
+
+#define AT_FORALL_SCALAR_TYPES_AND(SCALARTYPE, _) \
+  _(uint8_t, Byte)                                \
+  _(int8_t, Char)                                 \
+  _(int16_t, Short)                               \
+  _(int, Int)                                     \
+  _(int64_t, Long)                                \
+  _(at::Half, Half)                               \
+  _(float, Float)                                 \
+  _(double, Double)                               \
+  _(decltype(::c10::impl::ScalarTypeToCPPType<::c10::ScalarType::SCALARTYPE>::t), SCALARTYPE)
+
+#define AT_FORALL_SCALAR_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, _) \
+  _(uint8_t, Byte)                                               \
+  _(int8_t, Char)                                                \
+  _(int16_t, Short)                                              \
+  _(int, Int)                                                    \
+  _(int64_t, Long)                                               \
+  _(at::Half, Half)                                              \
+  _(float, Float)                                                \
+  _(double, Double)                                              \
+  _(decltype(::c10::impl::ScalarTypeToCPPType<::c10::ScalarType::SCALARTYPE1>::t), SCALARTYPE1) \
+  _(decltype(::c10::impl::ScalarTypeToCPPType<::c10::ScalarType::SCALARTYPE2>::t), SCALARTYPE2)
+
+#define AT_FORALL_QINT_TYPES(_) \
+  _(c10::qint8, QInt8)          \
+  _(c10::quint8, QUInt8)        \
+  _(c10::qint32, QInt32)
 
 static inline caffe2::TypeMeta scalarTypeToTypeMeta(ScalarType scalar_type) {
-#define DEFINE_CASE(ctype, name, _) \
-  case ScalarType::name:            \
+#define DEFINE_CASE(ctype, name) \
+  case ScalarType::name:         \
     return caffe2::TypeMeta::Make<ctype>();
 
   switch (scalar_type) {
-    AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_STUBS(DEFINE_CASE)
+    AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CASE)
     case ScalarType::Undefined:
       return caffe2::TypeMeta();
     default:
@@ -178,11 +176,11 @@ static inline caffe2::TypeMeta scalarTypeToTypeMeta(ScalarType scalar_type) {
 static inline c10::optional<ScalarType> tryTypeMetaToScalarType(
     caffe2::TypeMeta dtype) {
-#define DEFINE_IF(ctype, name, _)                 \
+#define DEFINE_IF(ctype, name)                    \
   if (dtype == caffe2::TypeMeta::Make<ctype>()) { \
     return {ScalarType::name};                    \
   }
-  AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_STUBS(DEFINE_IF)
+  AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_IF)
 #undef DEFINE_IF
   if (dtype == caffe2::TypeMeta()) {
     return {ScalarType::Undefined};
   }
@@ -209,19 +207,19 @@ static inline bool operator==(caffe2::TypeMeta m, ScalarType t) {
   return t == m;
 }
 
-#define DEFINE_CONSTANT(_, name, _2) \
+#define DEFINE_CONSTANT(_, name) \
   constexpr ScalarType k##name = ScalarType::name;
 
-AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_STUBS(DEFINE_CONSTANT)
+AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CONSTANT)
 #undef DEFINE_CONSTANT
 
 static inline const char* toString(ScalarType t) {
-#define DEFINE_CASE(_, name, _2) \
-  case ScalarType::name:         \
+#define DEFINE_CASE(_, name) \
+  case ScalarType::name:     \
     return #name;
 
   switch (t) {
-    AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_STUBS(DEFINE_CASE)
+    AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CASE)
     default:
       return "UNKNOWN_SCALAR";
   }
@@ -229,12 +227,12 @@ static inline const char* toString(ScalarType t) {
 }
 
 static inline size_t elementSize(ScalarType t) {
-#define CASE_ELEMENTSIZE_CASE(ctype, name, _2) \
-  case ScalarType::name:                       \
+#define CASE_ELEMENTSIZE_CASE(ctype, name) \
+  case ScalarType::name:                   \
     return sizeof(ctype);
 
   switch (t) {
-    AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_STUBS(CASE_ELEMENTSIZE_CASE)
+    AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(CASE_ELEMENTSIZE_CASE)
     default:
       AT_ERROR("Unknown ScalarType");
   }
diff --git a/caffe2/contrib/aten/aten_op_template.h b/caffe2/contrib/aten/aten_op_template.h
index 8eb3352c0685f..40ba58df48e9c 100644
--- a/caffe2/contrib/aten/aten_op_template.h
+++ b/caffe2/contrib/aten/aten_op_template.h
@@ -43,13 +43,13 @@ class ATenOp : public Operator<Context> {
     return typeMetaFor(t.scalar_type());
   }
   TypeMeta typeMetaFor(at::ScalarType st) {
-  #define DEFINE_CASE(ctype,aten_name,_) \
+  #define DEFINE_CASE(ctype,aten_name) \
     case at::k##aten_name: \
       return TypeMeta::Make<ctype>();
     switch(st) {
       AT_FORALL_SCALAR_TYPES_AND2(Bool, BFloat16, DEFINE_CASE)
-    default:
-      CAFFE_THROW("Unknown ATen Type");
+      default:
+        CAFFE_THROW("Unknown ATen Type");
     }
   #undef DEFINE_CASE
   }
@@ -134,7 +134,7 @@ class ATenOp : public Operator<Context> {
 
   void assignTo(Tensor* dst, at::ScalarType scalar_type, at::Scalar scalar) {
     switch(scalar_type) {
-      #define DEFINE_CASE(ctype,aten_name,_1) \
+      #define DEFINE_CASE(ctype,aten_name) \
         case at::k##aten_name: { \
           auto value = extract<ctype>(scalar); \
          assignToValue<ctype>(dst, at::convert<ctype>(value)); \
diff --git a/caffe2/operators/experimental/c10/cpu/cast_cpu.cc b/caffe2/operators/experimental/c10/cpu/cast_cpu.cc
index ccd69e2c76b6d..49544c3d4be2b 100644
--- a/caffe2/operators/experimental/c10/cpu/cast_cpu.cc
+++ b/caffe2/operators/experimental/c10/cpu/cast_cpu.cc
@@ -79,7 +79,7 @@ void cast_op_cpu(
     const at::Tensor& output,
     int64_t to) {
   switch (input.scalar_type()) {
-#define CASE(ctype,name,_2) case ScalarType:: name : return cast_op_cpu_impl<ctype>(input, output, to);
+#define CASE(ctype,name) case ScalarType:: name : return cast_op_cpu_impl<ctype>(input, output, to);
     AT_FORALL_SCALAR_TYPES_AND2(Bool, BFloat16, CASE)
 #undef CASE
     default: throw std::runtime_error(string() + "Unsupported scalar type " + toString(input.scalar_type()));
diff --git a/tools/autograd/templates/variable_factories.h b/tools/autograd/templates/variable_factories.h
index 97d266858f4ce..4c2d69ae95b91 100644
--- a/tools/autograd/templates/variable_factories.h
+++ b/tools/autograd/templates/variable_factories.h
@@ -19,7 +19,7 @@ using at::DimnameList;
 
 namespace torch {
 
-#define TENSOR(T, S, _1)                                            \
+#define TENSOR(T, S)                                                \
   inline at::Tensor tensor(                                         \
       at::ArrayRef<T> values, const at::TensorOptions& options) {   \
     at::Tensor result = ([&]() {                                    \
diff --git a/torch/csrc/jit/fuser/codegen.cpp b/torch/csrc/jit/fuser/codegen.cpp
index cb555594b03f4..a0579dc9fbb9d 100644
--- a/torch/csrc/jit/fuser/codegen.cpp
+++ b/torch/csrc/jit/fuser/codegen.cpp
@@ -67,10 +67,10 @@ static const char* scalarTypeName(const at::ScalarType type) {
   }
 
   switch (type) {
-#define DEFINE_CASE(ctype, name, _) \
-  case at::ScalarType::name:        \
+#define DEFINE_CASE(ctype, name) \
+  case at::ScalarType::name:     \
     return #ctype;
-    AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_STUBS(DEFINE_CASE)
+    AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CASE)
 #undef DEFINE_CASE
     default:
       throw std::runtime_error("unknown scalar type");
diff --git a/torch/csrc/jit/script/schema_type_parser.cpp b/torch/csrc/jit/script/schema_type_parser.cpp
index 898b57f697beb..268a9c856ad4f 100644
--- a/torch/csrc/jit/script/schema_type_parser.cpp
+++ b/torch/csrc/jit/script/schema_type_parser.cpp
@@ -124,10 +124,10 @@ c10::optional<AliasInfo> SchemaTypeParser::parseAliasAnnotation() {
 
 c10::optional<at::ScalarType> SchemaTypeParser::parseTensorDType(
     const std::string& dtype) {
-#define DEFINE_SCALAR_TYPE(_1, n, _2) {#n, at::ScalarType::n},
+#define DEFINE_SCALAR_TYPE(_1, n) {#n, at::ScalarType::n},
   static std::unordered_map<std::string, at::ScalarType> type_map = {
-      AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_STUBS(DEFINE_SCALAR_TYPE)};
+      AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_SCALAR_TYPE)};
 
   auto type = type_map.find(dtype);
   if (type != type_map.end()) {
diff --git a/torch/csrc/utils/tensor_dtypes.cpp b/torch/csrc/utils/tensor_dtypes.cpp
index 7ed9a0d1dc09d..bc399563c1833 100644
--- a/torch/csrc/utils/tensor_dtypes.cpp
+++ b/torch/csrc/utils/tensor_dtypes.cpp
@@ -59,10 +59,10 @@ void initializeDtypes() {
   if (!torch_module)
     throw python_error();
 
-#define DEFINE_SCALAR_TYPE(_1, n, _2) at::ScalarType::n,
+#define DEFINE_SCALAR_TYPE(_1, n) at::ScalarType::n,
 
   at::ScalarType all_scalar_types[] = {
-      AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_STUBS(DEFINE_SCALAR_TYPE)};
+      AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_SCALAR_TYPE)};
 
   for (at::ScalarType scalarType : all_scalar_types) {
     std::string primary_name, legacy_name;
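
Editorial note, not part of the patch: the sketch below is a minimal, self-contained illustration of the two-argument X-macro shape this diff converts everything to, mirroring the DEFINE_ENUM and CASE_ELEMENTSIZE_CASE callbacks in c10/core/ScalarType.h above. MY_FORALL_SCALAR_TYPES, SimpleTag, and element_size are invented names for the example only; they are not part of ATen/c10.

#include <cstddef>
#include <cstdint>

// Hypothetical list macro: each callback invocation now carries exactly two
// arguments, the C++ type and the enum name -- the unused third placeholder
// ("_", "_1", "member") removed by this PR is gone.
#define MY_FORALL_SCALAR_TYPES(_) \
  _(uint8_t, Byte)                \
  _(int32_t, Int)                 \
  _(double, Double)

enum class SimpleTag : int8_t {
#define DEFINE_ENUM(_1, n) n, /* uses only the second argument */
  MY_FORALL_SCALAR_TYPES(DEFINE_ENUM)
#undef DEFINE_ENUM
  Undefined
};

static inline std::size_t element_size(SimpleTag t) {
#define CASE_SIZE(ctype, name) \
  case SimpleTag::name:        \
    return sizeof(ctype);      /* uses only the first argument */
  switch (t) {
    MY_FORALL_SCALAR_TYPES(CASE_SIZE)
    default:
      return 0;
  }
#undef CASE_SIZE
}

Because no call site ever consumed the third argument, the change is mechanical; the one behavioural nicety is in AT_FORALL_SCALAR_TYPES_AND / _AND2, which now qualify the extra type as ::c10::ScalarType::SCALARTYPE internally, so callers such as c10/core/Scalar.h can pass just BFloat16 instead of c10::ScalarType::BFloat16.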