Skip to content

Commit

Permalink
Remove unused parameter from FORALL macros and rename STUBS to QINTS.
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: pytorch#23340

Test Plan: Imported from OSS

Differential Revision: D16467981

Pulled By: gchanan

fbshipit-source-id: f4535c21ea54838d2086b2887a73e02e28b783d9
  • Loading branch information
gchanan authored and facebook-github-bot committed Aug 12, 2019
1 parent f5fefd6 commit 497bc3f
Show file tree
Hide file tree
Showing 13 changed files with 114 additions and 115 deletions.
6 changes: 3 additions & 3 deletions aten/src/ATen/core/TensorMethods.h
Original file line number Diff line number Diff line change
Expand Up @@ -1810,10 +1810,10 @@ inline bool is_quantized(Tensor self) {
return self.is_quantized();
}

#define DEFINE_CAST(T, name, _) \
#define DEFINE_CAST(T, name) \
template <> \
inline T* Tensor::data() const { \
TORCH_CHECK( \
TORCH_CHECK( \
scalar_type() == ScalarType::name, \
"expected scalar type ", \
#name, \
Expand All @@ -1826,7 +1826,7 @@ AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_COMPLEX_HALF(DEFINE_CAST)
AT_FORALL_QINT_TYPES(DEFINE_CAST)
#undef DEFINE_CAST

#define DEFINE_ITEM(T, name, _) \
#define DEFINE_ITEM(T, name) \
template <> \
inline T Tensor::item() const { \
return item().to##name(); \
Expand Down
4 changes: 2 additions & 2 deletions aten/src/ATen/native/TensorFactories.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ Tensor& empty_out(
// specialized operators for each datatype.
// TODO: remove when we have Type support in the IR

#define DEFINE_CAST_OP(_1, n, _2) \
#define DEFINE_CAST_OP(_1, n) \
Tensor _cast_##n(const Tensor& self, bool non_blocking) { \
if (self.scalar_type() == ScalarType::n) \
return self; \
Expand Down Expand Up @@ -798,7 +798,7 @@ Tensor tensor_cuda(ArrayRef<T> values, const TensorOptions& options) {
return cpu_tensor.to(options.device());
}

#define TENSOR(T, _1, _2) \
#define TENSOR(T, _1) \
Tensor tensor(ArrayRef<T> values, const TensorOptions& options) { \
if (options.device().is_cuda()) { \
return tensor_cuda(values, options); \
Expand Down
2 changes: 1 addition & 1 deletion aten/src/ATen/templates/NativeFunctions.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ namespace at {
namespace native {

// These functions are defined in native/TensorFactories.cpp.
#define TENSOR(T, S, _1) \
#define TENSOR(T, S) \
CAFFE2_API Tensor tensor(ArrayRef<T> values, const TensorOptions& options); \
inline Tensor tensor( \
std::initializer_list<T> values, const TensorOptions& options) { \
Expand Down
6 changes: 3 additions & 3 deletions aten/src/ATen/templates/TensorMethods.h
Original file line number Diff line number Diff line change
Expand Up @@ -143,10 +143,10 @@ inline bool is_quantized(Tensor self) {
return self.is_quantized();
}

#define DEFINE_CAST(T, name, _) \
#define DEFINE_CAST(T, name) \
template <> \
inline T* Tensor::data() const { \
TORCH_CHECK( \
TORCH_CHECK( \
scalar_type() == ScalarType::name, \
"expected scalar type ", \
#name, \
Expand All @@ -159,7 +159,7 @@ AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_COMPLEX_HALF(DEFINE_CAST)
AT_FORALL_QINT_TYPES(DEFINE_CAST)
#undef DEFINE_CAST

#define DEFINE_ITEM(T, name, _) \
#define DEFINE_ITEM(T, name) \
template <> \
inline T Tensor::item() const { \
return item().to##name(); \
Expand Down
10 changes: 5 additions & 5 deletions aten/src/TH/THBlasUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
template<typename T>
inline void THBlas_axpy(int64_t n, T a, T *x, int64_t incx, T *y, int64_t incy);

#define AXPY_SPECIALIZATION(ctype,name,_1) \
#define AXPY_SPECIALIZATION(ctype,name) \
template<> \
inline void THBlas_axpy<ctype>(int64_t n, ctype a, ctype *x, int64_t incx, \
ctype *y, int64_t incy) { \
Expand All @@ -22,7 +22,7 @@ AT_FORALL_SCALAR_TYPES(AXPY_SPECIALIZATION)
template<typename T>
inline void THBlas_copy(int64_t n, T *x, int64_t incx, T *y, int64_t incy);

#define COPY_SPECIALIZATION(ctype,name,_1) \
#define COPY_SPECIALIZATION(ctype,name) \
template<> \
inline void THBlas_copy<ctype>(int64_t n, ctype *x, int64_t incx, \
ctype *y, int64_t incy) { \
Expand All @@ -34,7 +34,7 @@ AT_FORALL_SCALAR_TYPES(COPY_SPECIALIZATION)
template<typename T>
inline T THBlas_dot(int64_t n, T *x, int64_t incx, T *y, int64_t incy);

#define DOT_SPECIALIZATION(ctype,name,_1) \
#define DOT_SPECIALIZATION(ctype,name) \
template<> \
inline ctype THBlas_dot<ctype>(int64_t n, ctype *x, int64_t incx, ctype *y, int64_t incy) { \
return TH ## name ## Blas_dot(n, x, incx, y, incy); \
Expand All @@ -58,7 +58,7 @@ inline void THBlas_gemm(
T *c,
int64_t ldc);

#define GEMM_SPECIALIZATION(ctype,name,_1) \
#define GEMM_SPECIALIZATION(ctype,name) \
template<> \
inline void THBlas_gemm<ctype>( \
char transa, \
Expand Down Expand Up @@ -94,7 +94,7 @@ inline void THBlas_gemv(
T* y,
int64_t incy);

#define GEMV_SPECIALIZATION(ctype, name, _1) \
#define GEMV_SPECIALIZATION(ctype, name) \
template <> \
inline void THBlas_gemv<ctype>( \
char transa, \
Expand Down
9 changes: 5 additions & 4 deletions c10/core/Scalar.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,10 @@ class C10_API Scalar {
public:
Scalar() : Scalar(int64_t(0)) {}

#define DEFINE_IMPLICIT_CTOR(type, name, member) \
#define DEFINE_IMPLICIT_CTOR(type, name) \
Scalar(type vv) : Scalar(vv, true) { }
AT_FORALL_SCALAR_TYPES_AND(c10::ScalarType::BFloat16, DEFINE_IMPLICIT_CTOR)

AT_FORALL_SCALAR_TYPES_AND(BFloat16, DEFINE_IMPLICIT_CTOR)

#undef DEFINE_IMPLICIT_CTOR

Expand All @@ -53,7 +54,7 @@ class C10_API Scalar {

#undef DEFINE_IMPLICIT_COMPLEX_CTOR

#define DEFINE_ACCESSOR(type, name, member) \
#define DEFINE_ACCESSOR(type, name) \
type to##name() const { \
if (Tag::HAS_d == tag) { \
return checked_convert<type, double>(v.d, #type); \
Expand Down Expand Up @@ -122,7 +123,7 @@ inline T Scalar::to() {
throw std::runtime_error("to() cast to unexpected type.");
}

#define DEFINE_TO(T, name, _) \
#define DEFINE_TO(T, name) \
template <> \
inline T Scalar::to<T>() { \
return to##name(); \
Expand Down
166 changes: 82 additions & 84 deletions c10/core/ScalarType.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,61 +12,54 @@
namespace c10 {

// For the macros below:
// NB: QInt ScalarTypes are referred to as "STUBS" here since they do not
// contain complete information to determine the tensor value of the data,
// they are just stubs for dispatch / quantization.
// NB: If you want to macro some code for all non-stub scalar types, you
// probably want one of the AT_FORALL_SCALAR_TYPES / AT_FORALL_SCALAR_TYPES_AND
// NB: If you want to macro some code for all non-QInt scalar types (i.e. types
// with complete information, you probably want one of the
// AT_FORALL_SCALAR_TYPES / AT_FORALL_SCALAR_TYPES_AND
// macros below, which are designed to behave similarly to the Dispatch macros
// with the same name.

// NB: Order matters for this macro; it is relied upon in
// _promoteTypesLookup and the serialization format.
#define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_STUBS(_) \
_(uint8_t, Byte, __) /* 0 */ \
_(int8_t, Char, __) /* 1 */ \
_(int16_t, Short, __) /* 2 */ \
_(int, Int, __) /* 3 */ \
_(int64_t, Long, __) /* 4 */ \
_(at::Half, Half, __) /* 5 */ \
_(float, Float, __) /* 6 */ \
_(double, Double, __) /* 7 */ \
_(at::ComplexHalf, ComplexHalf, __) /* 8 */ \
_(std::complex<float>, ComplexFloat, __) /* 9 */ \
_(std::complex<double>, ComplexDouble, __) /* 10 */ \
_(bool, Bool, __) /* 11 */ \
_(c10::qint8, QInt8, __) /* 12 */ \
_(c10::quint8, QUInt8, __) /* 13 */ \
_(c10::qint32, QInt32, __) /* 14 */ \
_(at::BFloat16, BFloat16, __) /* 15 */
#define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(_) \
_(uint8_t, Byte) /* 0 */ \
_(int8_t, Char) /* 1 */ \
_(int16_t, Short) /* 2 */ \
_(int, Int) /* 3 */ \
_(int64_t, Long) /* 4 */ \
_(at::Half, Half) /* 5 */ \
_(float, Float) /* 6 */ \
_(double, Double) /* 7 */ \
_(at::ComplexHalf, ComplexHalf) /* 8 */ \
_(std::complex<float>, ComplexFloat) /* 9 */ \
_(std::complex<double>, ComplexDouble) /* 10 */ \
_(bool, Bool) /* 11 */ \
_(c10::qint8, QInt8) /* 12 */ \
_(c10::quint8, QUInt8) /* 13 */ \
_(c10::qint32, QInt32) /* 14 */ \
_(at::BFloat16, BFloat16) /* 15 */


// If you want to support ComplexHalf for real, add ComplexHalf
// into this macro (and change the name). But beware: convert()
// doesn't work for all the conversions you need...
#define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_COMPLEX_HALF(_) \
_(uint8_t, Byte, __) \
_(int8_t, Char, __) \
_(int16_t, Short, __) \
_(int, Int, __) \
_(int64_t, Long, __) \
_(at::Half, Half, __) \
_(float, Float, __) \
_(double, Double, __) \
_(std::complex<float>, ComplexFloat, __) \
_(std::complex<double>, ComplexDouble, __) \
_(bool, Bool, __) \
_(at::BFloat16, BFloat16, __)
_(uint8_t, Byte) \
_(int8_t, Char) \
_(int16_t, Short) \
_(int, Int) \
_(int64_t, Long) \
_(at::Half, Half) \
_(float, Float) \
_(double, Double) \
_(std::complex<float>, ComplexFloat) \
_(std::complex<double>, ComplexDouble) \
_(bool, Bool) \
_(at::BFloat16, BFloat16)


#define AT_FORALL_QINT_TYPES(_) \
_(c10::qint8, QInt8, i) \
_(c10::quint8, QUInt8, i) \
_(c10::qint32, QInt32, i)

enum class ScalarType : int8_t {
#define DEFINE_ENUM(_1, n, _2) n,
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_STUBS(DEFINE_ENUM)
#define DEFINE_ENUM(_1, n) n,
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_ENUM)
#undef DEFINE_ENUM
Undefined,
NumOptions
Expand Down Expand Up @@ -127,44 +120,49 @@ struct ScalarTypeToCPPType<c10::ScalarType::Long> {
}

#define AT_FORALL_SCALAR_TYPES(_) \
_(uint8_t, Byte, __) \
_(int8_t, Char, __) \
_(int16_t, Short, __) \
_(int, Int, __) \
_(int64_t, Long, __) \
_(float, Float, __) \
_(double, Double, __)

#define AT_FORALL_SCALAR_TYPES_AND(SCALARTYPE, _) \
_(uint8_t, Byte, __) \
_(int8_t, Char, __) \
_(int16_t, Short, __) \
_(int, Int, __) \
_(int64_t, Long, __) \
_(at::Half, Half, __) \
_(float, Float, __) \
_(double, Double, __) \
_(decltype(::c10::impl::ScalarTypeToCPPType<SCALARTYPE>::t), SCALARTYPE, __)

#define AT_FORALL_SCALAR_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, _) \
_(uint8_t, Byte, __) \
_(int8_t, Char, __) \
_(int16_t, Short, __) \
_(int, Int, __) \
_(int64_t, Long, __) \
_(at::Half, Half, __) \
_(float, Float, __) \
_(double, Double, __) \
_(decltype(::c10::impl::ScalarTypeToCPPType<c10::ScalarType::SCALARTYPE1>::t), SCALARTYPE1, __) \
_(decltype(::c10::impl::ScalarTypeToCPPType<c10::ScalarType::SCALARTYPE2>::t), SCALARTYPE2, __)
_(uint8_t, Byte) \
_(int8_t, Char) \
_(int16_t, Short) \
_(int, Int) \
_(int64_t, Long) \
_(float, Float) \
_(double, Double)

#define AT_FORALL_SCALAR_TYPES_AND(SCALARTYPE, _) \
_(uint8_t, Byte) \
_(int8_t, Char) \
_(int16_t, Short) \
_(int, Int) \
_(int64_t, Long) \
_(at::Half, Half) \
_(float, Float) \
_(double, Double) \
_(decltype(::c10::impl::ScalarTypeToCPPType<::c10::ScalarType::SCALARTYPE>::t), SCALARTYPE)

#define AT_FORALL_SCALAR_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, _) \
_(uint8_t, Byte) \
_(int8_t, Char) \
_(int16_t, Short) \
_(int, Int) \
_(int64_t, Long) \
_(at::Half, Half) \
_(float, Float) \
_(double, Double) \
_(decltype(::c10::impl::ScalarTypeToCPPType<::c10::ScalarType::SCALARTYPE1>::t), SCALARTYPE1) \
_(decltype(::c10::impl::ScalarTypeToCPPType<::c10::ScalarType::SCALARTYPE2>::t), SCALARTYPE2)

#define AT_FORALL_QINT_TYPES(_) \
_(c10::qint8, QInt8) \
_(c10::quint8, QUInt8) \
_(c10::qint32, QInt32)

static inline caffe2::TypeMeta scalarTypeToTypeMeta(ScalarType scalar_type) {
#define DEFINE_CASE(ctype, name, _) \
case ScalarType::name: \
#define DEFINE_CASE(ctype, name) \
case ScalarType::name: \
return caffe2::TypeMeta::Make<ctype>();

switch (scalar_type) {
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_STUBS(DEFINE_CASE)
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CASE)
case ScalarType::Undefined:
return caffe2::TypeMeta();
default:
Expand All @@ -178,11 +176,11 @@ static inline caffe2::TypeMeta scalarTypeToTypeMeta(ScalarType scalar_type) {

static inline c10::optional<ScalarType> tryTypeMetaToScalarType(
caffe2::TypeMeta dtype) {
#define DEFINE_IF(ctype, name, _) \
#define DEFINE_IF(ctype, name) \
if (dtype == caffe2::TypeMeta::Make<ctype>()) { \
return {ScalarType::name}; \
}
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_STUBS(DEFINE_IF)
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_IF)
#undef DEFINE_IF
if (dtype == caffe2::TypeMeta()) {
return {ScalarType::Undefined};
Expand All @@ -209,32 +207,32 @@ static inline bool operator==(caffe2::TypeMeta m, ScalarType t) {
return t == m;
}

#define DEFINE_CONSTANT(_, name, _2) \
#define DEFINE_CONSTANT(_, name) \
constexpr ScalarType k##name = ScalarType::name;

AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_STUBS(DEFINE_CONSTANT)
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CONSTANT)
#undef DEFINE_CONSTANT

static inline const char* toString(ScalarType t) {
#define DEFINE_CASE(_, name, _2) \
case ScalarType::name: \
#define DEFINE_CASE(_, name) \
case ScalarType::name: \
return #name;

switch (t) {
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_STUBS(DEFINE_CASE)
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CASE)
default:
return "UNKNOWN_SCALAR";
}
#undef DEFINE_CASE
}

static inline size_t elementSize(ScalarType t) {
#define CASE_ELEMENTSIZE_CASE(ctype, name, _2) \
case ScalarType::name: \
#define CASE_ELEMENTSIZE_CASE(ctype, name) \
case ScalarType::name: \
return sizeof(ctype);

switch (t) {
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_STUBS(CASE_ELEMENTSIZE_CASE)
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(CASE_ELEMENTSIZE_CASE)
default:
AT_ERROR("Unknown ScalarType");
}
Expand Down
Loading

0 comments on commit 497bc3f

Please sign in to comment.