diff --git a/aten/src/ATen/core/function_schema.h b/aten/src/ATen/core/function_schema.h deleted file mode 100644 index 9d5403f4bd8ca..0000000000000 --- a/aten/src/ATen/core/function_schema.h +++ /dev/null @@ -1,140 +0,0 @@ -#pragma once - -#include -#include - -namespace torch { namespace jit { - -// schema as used in the compiler for resolving function calls and reporting -// errors. These objects should be constructed from C10 schema once those -// are available. -struct Argument { - Argument( - std::string name = "", - TypePtr type = nullptr, - at::optional N = at::nullopt, - at::optional default_value = at::nullopt, - bool kwarg_only = false) - : name(std::move(name)), - type(type? type : DynamicType::get()), - N(std::move(N)), - default_value(std::move(default_value)), - kwarg_only(kwarg_only) {} - std::string name; - TypePtr type; - - // for list types, an optional statically known length for the list - // e.g. for int[3]: type = ListType::ofInts(), N = 3 - // If present, this will allow scalars to be broadcast to this length to - // become a list. - at::optional N; - - at::optional default_value; - // is this only specifyable as a keyword argument? - bool kwarg_only; -}; - -struct FunctionSchema { - FunctionSchema( - std::string name, - std::vector arguments, - std::vector returns, - bool is_vararg = false, - bool is_varret = false) - : name(std::move(name)), - arguments(std::move(arguments)), - returns(std::move(returns)), - is_vararg(is_vararg), - is_varret(is_varret), - is_mutable(isMutable()) { - validate(); - } - FunctionSchema( - Symbol name, - std::vector arguments, - std::vector returns, - bool is_vararg = false, - bool is_varret = false) - : FunctionSchema( - name.toQualString(), - std::move(std::move(arguments)), - std::move(std::move(returns)), - is_vararg, - is_varret) { - validate(); - } - - const std::string name; - const std::vector arguments; - const std::vector returns; - // if true then this schema takes an arbitrary number of additional arguments - // after the argument specified in arguments - // currently this is used primarily to represent 'primtive' operators whose - // arguments are not checked by schema - const bool is_vararg; - const bool is_varret; - const bool is_mutable; - - at::optional argumentIndexWithName(const std::string& name) const { - for(size_t i = 0; i < arguments.size(); ++i) { - if(name == arguments[i].name) - return i; - } - return at::nullopt; - } - - private: - bool isMutable() const { - return std::any_of( - arguments.cbegin(), arguments.cend(), [](const Argument& arg) { - return arg.type == WorldType::get(); - }); - } - - void validate() const { - if (is_mutable) { - // Mutable schemas should have a world token as the first argument - // and return. - AT_ASSERT(arguments.at(0).type == WorldType::get()); - AT_ASSERT(returns.at(0).type == WorldType::get()); - } - } -}; - -// for debugging, make sure we can describe the call site -inline std::ostream& operator<<(std::ostream& out, const Argument& arg) { - return out << arg.type->str() << " " << arg.name << (arg.default_value ? "=" : ""); -} - -inline std::ostream& operator<<(std::ostream& out, const FunctionSchema& schema) { - // eventually this should look almost identical to python arg parser, but - // it is simpler for now to work directly on this schema - - out << schema.name; - out << "("; - - bool seen_kwarg_only = false; - for(size_t i = 0; i < schema.arguments.size(); ++i) { - if (i > 0) out << ", "; - if (schema.arguments[i].kwarg_only && !seen_kwarg_only) { - out << "*, "; - seen_kwarg_only = true; - } - out << schema.arguments[i]; - } - - out << ") -> "; - if (schema.returns.size() == 1) { - out << schema.returns.at(0).type->str(); - } else if (schema.returns.size() > 1) { - out << "("; - for (size_t i = 0; i < schema.returns.size(); ++i) { - if (i > 0) out << ", "; - out << schema.returns[i].type->str(); - } - out << ")"; - } - return out; -} - -}} diff --git a/aten/src/ATen/core/functional.h b/aten/src/ATen/core/functional.h deleted file mode 100644 index f3b39bcdeb4ed..0000000000000 --- a/aten/src/ATen/core/functional.h +++ /dev/null @@ -1,63 +0,0 @@ -#pragma once - -#include -#include - -namespace torch { - -// The passed in function must take T by value (T), or by -// const reference (const T&); taking T by non-const reference -// will result in an error like: -// -// error: no type named 'type' in 'class std::result_of' -// -// No explicit template parameters are required. - -// Overload for explicit function and ArrayRef -template -inline auto fmap(const T& inputs, const F& fn) -> std::vector { - std::vector r; - r.reserve(inputs.size()); - for(const auto & input : inputs) - r.push_back(fn(input)); - return r; -} - -template -inline auto fmap(T& inputs, const F& fn) -> std::vector { - std::vector r; - r.reserve(inputs.size()); - for(auto & input : inputs) - r.push_back(fn(input)); - return r; -} - -// C++ forbids taking an address of a constructor, so here's a workaround... -// Overload for constructor (R) application -template -inline std::vector fmap(const T& inputs) { - std::vector r; - r.reserve(inputs.size()); - for(auto & input : inputs) - r.push_back(R(input)); - return r; -} - -template -inline std::vector filter(at::ArrayRef inputs, const F& fn) { - std::vector r; - r.reserve(inputs.size()); - for(auto & input : inputs) { - if (fn(input)) { - r.push_back(input); - } - } - return r; -} - -template -inline std::vector filter(const std::vector& inputs, const F& fn) { - return filter(static_cast>(inputs), fn); -} - -} diff --git a/aten/src/ATen/core/jit_type.h b/aten/src/ATen/core/jit_type.h deleted file mode 100644 index 7aba3bcb0a9a3..0000000000000 --- a/aten/src/ATen/core/jit_type.h +++ /dev/null @@ -1,814 +0,0 @@ -#pragma once - -#include -#include -#include -#include -#include - -#include - -#include -#include -#include - -namespace torch { namespace jit { - -#define TH_FORALL_TYPES(_) \ -_(DynamicType) \ -_(TensorType) \ -_(CompleteTensorType) \ -_(UndefinedTensorType) \ -_(TupleType) \ -_(ListType) \ -_(NumberType) \ -_(FloatType) \ -_(IntType) \ -_(NoneType) \ -_(StringType) \ -_(GeneratorType) \ -_(BoolType) \ -_(VarType) \ -_(WorldType) \ - -enum class TypeKind { -#define DEFINE_TYPE(T) T, - TH_FORALL_TYPES(DEFINE_TYPE) -#undef DEFINE_TYPE -}; - -struct Type; -using TypePtr = std::shared_ptr; - -template -struct cloneType {}; - -template -struct cloneType { - std::shared_ptr operator()(std::shared_ptr ptr) const { - return ptr; - } - std::shared_ptr operator()(std::shared_ptr ptr) const { - return ptr; - } -}; - -template -struct cloneType { - std::shared_ptr operator()(std::shared_ptr ptr) const { - auto result = std::make_shared::type>(*ptr); - // XXX: the line above will correctly slice the struct, and make its runtype - // type exactly equal to T. However, kind_ is a field of Type, so it will simply - // be copied, and we need to fix it in here to match the dynamic type. - result->kind_ = T::Kind; - return result; - } -}; - -struct CAFFE2_API Type : std::enable_shared_from_this { -private: - TypeKind kind_; - template - friend struct cloneType; - -protected: - Type(TypeKind kind) - : kind_(kind) {} - -public: - virtual bool operator==(const Type& rhs) const = 0; - - // subtyping relation. By default, we return true for the case - // when the type is exactly equal - virtual bool isSubtypeOf(const TypePtr rhs) const { - return *this == *rhs; - } - - // How this type will appear in FunctionSchema declarations - virtual std::string str() const = 0; - - // How this type will appear as if it were a type annotation in Python - // which is sometimes different than how it appears in declarations (e.g. int[] vs List[int]) - virtual std::string python_str() const { - return str(); - } - - TypeKind kind() const { - return kind_; - } - - virtual bool requires_grad() const { return false; } - - // Dynamically cast this object to the subclass indicated by the - // template variable, returning nullptr if the cast is invalid. - // NOTE: if the cast succeeds, but the casted kind is not the - // run-time kind of the type, we also slice the structure, so - // that assignments of those types to values don't accidentally - // inherit more detailed information from subclasses. - template - std::shared_ptr cast() { - auto r = caffe2::dynamic_pointer_cast_if_rtti(shared_from_this()); - if (!r || T::Kind == kind()) { - return r; - } else { - return cloneType{}(r); - } - } - template - std::shared_ptr cast() const { - auto r = caffe2::dynamic_pointer_cast_if_rtti(shared_from_this()); - if (!r || T::Kind == kind()) { - return r; - } else { - return cloneType{}(r); - } - } - template - std::shared_ptr expect() { - auto r = cast(); - AT_ASSERT(r); - return r; - } - template - std::shared_ptr expect() const { - auto r = cast(); - AT_ASSERT(r); - return r; - } - virtual ~Type() = default; - virtual bool hasFreeVariables() const { - return false; - } -}; - -inline bool operator!=(const Type & lhs, const Type & rhs) { - return !(lhs == rhs); -} - -struct DynamicType; -using DynamicTypePtr = std::shared_ptr; -// This node represents a single Tensor value, with an unknown shape. -struct CAFFE2_API DynamicType : public Type { - static constexpr bool is_singleton = true; - template - static DynamicTypePtr create( T&& ... all ) { - return DynamicTypePtr(new DynamicType( std::forward(all)... )); // NOLINT(modernize-make-shared) - } - - bool requires_grad() const override { return true; } - - bool operator==(const Type& rhs) const override { - return rhs.kind() == kind(); - } - std::string str() const override { - return "Tensor"; - } - static const TypeKind Kind = TypeKind::DynamicType; - // global singleton - static DynamicTypePtr get(); -private: - DynamicType() - : Type(TypeKind::DynamicType) {} -}; - -struct UndefinedTensorType; -using UndefinedTensorTypePtr = std::shared_ptr; -struct CAFFE2_API UndefinedTensorType : public Type { - static constexpr bool is_singleton = true; - friend struct Type; - static const TypeKind Kind = TypeKind::UndefinedTensorType; - - template - static UndefinedTensorTypePtr create( T&& ... all ) { - return UndefinedTensorTypePtr(new UndefinedTensorType( std::forward(all)... )); // NOLINT(modernize-make-shared) - } - - bool operator==(const Type& rhs) const override { - return rhs.kind() == kind(); - } - bool isSubtypeOf(const TypePtr rhs) const override { - return rhs->kind() == TypeKind::DynamicType || - rhs->kind() == TypeKind::UndefinedTensorType; - } - std::string str() const override { - return "UndefinedTensor"; - } - static UndefinedTensorTypePtr get(); -protected: - UndefinedTensorType(): Type(TypeKind::UndefinedTensorType) {} -}; - -struct TensorType; -using TensorTypePtr = std::shared_ptr; -// This node represents a single Tensor value with a specific size -struct CAFFE2_API TensorType : public Type { - static constexpr bool is_singleton = false; - friend struct Type; - static const TypeKind Kind = TypeKind::TensorType; - - template - static TensorTypePtr create( T&& ... all ) { - return TensorTypePtr(new TensorType( std::forward(all)... )); // NOLINT(modernize-make-shared) - } - - at::ScalarType scalarType() const { return scalar_type_; } - int device() const { return device_; } - int dim() const { return dim_; } - bool requires_grad() const override { return requires_grad_; } - - TensorTypePtr toScalarType(at::ScalarType type){ - auto t = TensorType::create(*this); - t->scalar_type_ = type; - return t; - } - TensorTypePtr withDim(int new_dim) { - auto t = TensorType::create(*this); - t->dim_ = new_dim; - return t; - } - TensorTypePtr withRequiresGrad(bool req) { - auto t = TensorType::create(*this); - t->requires_grad_ = req; - return t; - } - - bool operator==(const Type& rhs) const override { - if (rhs.kind() != TypeKind::TensorType) - return false; - auto rt = rhs.expect(); - return scalarType() == rt->scalarType() && - device() == rt->device() && - dim() == rt->dim(); - } - bool isSubtypeOf(const TypePtr rhs) const override { - if (rhs->kind() == TypeKind::DynamicType) - return true; - return rhs->kind() == TypeKind::TensorType && *this == *rhs; - } - std::string str() const override { - // str is used for user-facing error messages, where we - // don't want to reveal underlying size information. - return "Tensor"; - } - -protected: - TensorType(const at::Tensor& tensor, TypeKind kind=TypeKind::TensorType) - : TensorType(tensor.type().scalarType(), - tensor.type().is_cuda() ? tensor.get_device() : -1, - tensor.dim(), - tensor.is_variable() && tensor.requires_grad(), - kind) {} - TensorType(at::ScalarType scalar_type, int device, int dim, bool requires_grad=true, TypeKind kind=TypeKind::TensorType) - : Type(kind) - , scalar_type_(scalar_type) - , requires_grad_(at::isFloatingType(scalar_type) && requires_grad) - , device_(device) - , dim_(dim) {} - - at::ScalarType scalar_type_; - bool requires_grad_; - int device_; - int dim_; -}; - -struct CompleteTensorType; -using CompleteTensorTypePtr = std::shared_ptr; -// This node represents a single Tensor value with a specific size -struct CAFFE2_API CompleteTensorType : public TensorType { - static constexpr bool is_singleton = false; - friend struct Type; - template - static CompleteTensorTypePtr create( T&& ... all ) { - return CompleteTensorTypePtr(new CompleteTensorType( std::forward(all)... )); // NOLINT(modernize-make-shared) - } - - // overloaded create variadic template argument as it could not distinguish initializer list - static CompleteTensorTypePtr create(at::ScalarType scalar_type, int device, at::IntList sizes) { - return CompleteTensorTypePtr(new CompleteTensorType(scalar_type, device, sizes)); // NOLINT(modernize-make-shared) - } - static CompleteTensorTypePtr create(at::ScalarType scalar_type, int device, at::IntList sizes, at::IntList strides) { - return CompleteTensorTypePtr(new CompleteTensorType(scalar_type, device, sizes, strides)); // NOLINT(modernize-make-shared) - } - - static const TypeKind Kind = TypeKind::CompleteTensorType; - - const std::vector& sizes() const { return sizes_; } - const std::vector& strides() const { return strides_; } - - TypePtr withSizesStrides(at::IntList sizes, at::IntList strides) const { - return CompleteTensorType::create(scalar_type_, device_, sizes, strides); - } - - TypePtr withSizes(at::IntList sizes) const { - return withSizesStrides(sizes, CompleteTensorType::contiguousStridesOf(sizes)); - } - - CompleteTensorTypePtr contiguous() const { - auto t = CompleteTensorType::create(*this); - t->strides_ = CompleteTensorType::contiguousStridesOf(sizes_); - return t; - } - - CompleteTensorTypePtr toScalarType(at::ScalarType type){ - auto t = CompleteTensorType::create(*this); - t->scalar_type_ = type; - return t; - } - - bool operator==(const Type& rhs) const override { - if(rhs.kind() != kind()) - return false; - auto rt = rhs.expect(); - return scalarType() == rt->scalarType() && - sizes() == rt->sizes() && - strides() == rt->strides() && - device() == rt->device(); - } - bool isSubtypeOf(const TypePtr rhs) const override { - if (rhs->kind() == TypeKind::DynamicType) - return true; - if (rhs->kind() == TypeKind::TensorType) - return *expect() == *rhs; - return *this == *rhs; - } - std::string str() const override { - // str is used for user-facing error messages, where we - // don't want to reveal underlying size information. - return "Tensor"; - } - bool numel() const { - size_t prod = 1; - for(auto s : sizes()) { - prod *= s; - } - return prod; - } - static TypePtr fromNumberType(TypePtr typ); - static TypePtr fromBoolType(); - -private: - CompleteTensorType(const at::Tensor& tensor) - : TensorType(tensor, TypeKind::CompleteTensorType) - , sizes_(tensor.sizes().vec()) - , strides_(tensor.strides().vec()) {} - CompleteTensorType(at::ScalarType scalar_type, int device, at::IntList sizes, bool requires_grad=true) - : CompleteTensorType(scalar_type, device, sizes, CompleteTensorType::contiguousStridesOf(sizes), requires_grad) {} - CompleteTensorType(at::ScalarType scalar_type, int device, at::IntList sizes, at::IntList strides, bool requires_grad=true) - : TensorType(scalar_type, device, sizes.size(), requires_grad, TypeKind::CompleteTensorType) - , sizes_(sizes.vec()) - , strides_(strides.vec()) {} - - static std::vector contiguousStridesOf(at::IntList sizes) { - std::vector strides(sizes.size()); - if(sizes.empty()) // zero-dim case - return strides; - strides.back() = 1; - for(size_t i = strides.size() - 1; i > 0; i--) { - strides[i-1] = strides[i] * sizes[i]; - } - return strides; - } - - std::vector sizes_; - std::vector strides_; -}; - -// This type is a token used to represent effectful computation in the IR. -// See the AnnotateEffects pass for how it is used. -struct WorldType; -using WorldTypePtr = std::shared_ptr; -struct CAFFE2_API WorldType : public Type { - template - static WorldTypePtr create(T&&... all) { - return WorldTypePtr(new WorldType(std::forward(all)...)); - } - bool operator==(const Type& rhs) const override { - return rhs.kind() == kind(); - } - std::string str() const override { - return "world"; - } - bool isSubtypeOf(const TypePtr rhs) const override { - return *this == *rhs; - } - static const TypeKind Kind = TypeKind::WorldType; - // global singleton - static WorldTypePtr get(); - - private: - WorldType() : Type(TypeKind::WorldType) {} -}; - -struct ListType; -using ListTypePtr = std::shared_ptr; - -struct CAFFE2_API ListType : public Type { - // It's not exactly a singleton, but there should be exactly once instance of - // List[T] for every T - static constexpr bool is_singleton = true; - friend struct Type; - template - static ListTypePtr create( T&& ... all ) { - return ListTypePtr(new ListType( std::forward(all)... )); // NOLINT(modernize-make-shared) - } - bool operator==(const Type& rhs) const override { - if(auto rhs_ = rhs.cast()) { - return *getElementType() == *rhs_->getElementType(); - } - return false; - } - bool requires_grad() const override { - return elem->requires_grad(); - } - std::string str() const override { - std::stringstream ss; - ss << getElementType()->str() << "[]"; - return ss.str(); - } - std::string python_str() const override { - std::stringstream ss; - ss << "List[" << getElementType()->python_str() << "]"; - return ss.str(); - } - TypePtr getElementType() const { - return elem; - } - bool hasFreeVariables() const override { - return has_free_variables_; - } - // common cast List[Tensor] - static ListTypePtr ofTensors(); - static ListTypePtr ofInts(); - static ListTypePtr ofFloats(); - static ListTypePtr ofBools(); - - static const TypeKind Kind = TypeKind::ListType; -private: - ListType(TypePtr elem) - : Type(TypeKind::ListType) - , elem(std::move(elem)) - , has_free_variables_(getElementType()->hasFreeVariables()) {} - TypePtr elem; - bool has_free_variables_; -}; - -struct TupleType; -using TupleTypePtr = std::shared_ptr; - -struct CAFFE2_API TupleType : public Type { - static constexpr bool is_singleton = false; - friend struct Type; - static TupleTypePtr create(std::vector types) { - return TupleTypePtr(new TupleType( std::move(types) )); // NOLINT(modernize-make-shared) - } - at::ArrayRef elements() const { - return elements_; - } - bool operator==(const Type& rhs) const override { - return compare(rhs, [](const TypePtr a, const TypePtr b) { - return *a == *b; - }); - } - bool isSubtypeOf(const TypePtr rhs) const override { - // co-variant rules for tuples - return compare(*rhs, [](const TypePtr a, const TypePtr b) { - return a->isSubtypeOf(b); - }); - } - bool requires_grad() const override { - return std::any_of(elements_.begin(), elements_.end(), - [](const TypePtr& ptr) { return ptr->requires_grad(); }); - } - std::string str() const override { - std::stringstream ss; - ss << "("; - for(size_t i = 0; i < elements().size(); ++i) { - if(i > 0) - ss << ", "; - ss << elements()[i]->str(); - } - ss << ")"; - return ss.str(); - } - std::string python_str() const override { - std::stringstream ss; - ss << "Tuple["; - for(size_t i = 0; i < elements().size(); ++i) { - if(i > 0) - ss << ", "; - ss << elements()[i]->python_str(); - } - ss << "]"; - return ss.str(); - } - bool hasFreeVariables() const override { - return has_free_variables_; - } - - static const TypeKind Kind = TypeKind::TupleType; -private: - TupleType(std::vector elements_) - : Type(TypeKind::TupleType) - , elements_(std::move(elements_)) { - has_free_variables_ = - std::any_of(elements_.begin(), elements_.end(), [](TypePtr v) { - return v->hasFreeVariables(); - }); - } - - bool compare(const Type& rhs, std::function fn) const { - if(rhs.kind() != kind()) - return false; - const auto & l_elements = elements(); - const auto & r_elements = rhs.cast()->elements(); - if(l_elements.size() != r_elements.size()) - return false; - for(size_t i = 0; i < l_elements.size(); ++i) { - if(!fn(l_elements[i], r_elements[i])) - return false; - } - return true; - } - std::vector elements_; - bool has_free_variables_; -}; - -struct NumberType; -using NumberTypePtr = std::shared_ptr; -// This node represents a Python number value -struct CAFFE2_API NumberType : public Type { - static constexpr bool is_singleton = true; - template - static NumberTypePtr create( T&& ... all ) { - return NumberTypePtr(new NumberType( std::forward(all)... )); // NOLINT(modernize-make-shared) - } - bool operator==(const Type& rhs) const override { - return rhs.kind() == kind(); - } - std::string str() const override { - return "Scalar"; // match what PythonArgParser says for clarity - } - static const TypeKind Kind = TypeKind::NumberType; - // global singleton - static NumberTypePtr get(); -private: - NumberType() - : Type(TypeKind::NumberType) {} -}; - -struct FloatType; -using FloatTypePtr = std::shared_ptr; -// This node represents a Python float number value -struct CAFFE2_API FloatType : public Type { - static constexpr bool is_singleton = true; - template - static FloatTypePtr create( T&& ... all ) { - return FloatTypePtr(new FloatType( std::forward(all)... )); // NOLINT(modernize-make-shared) - } - bool operator==(const Type& rhs) const override { - return rhs.kind() == kind(); - } - std::string str() const override { - return "float"; - } - bool isSubtypeOf(const TypePtr rhs) const override { - return *this == *rhs || rhs->kind() == TypeKind::NumberType; - } - static const TypeKind Kind = TypeKind::FloatType; - // global singleton - static FloatTypePtr get(); -private: - FloatType() - : Type(TypeKind::FloatType) {} -}; - -struct IntType; -using IntTypePtr = std::shared_ptr; -// This node represents a Python int number value -struct CAFFE2_API IntType : public Type { - static constexpr bool is_singleton = true; - template - static IntTypePtr create( T&& ... all ) { - return IntTypePtr(new IntType( std::forward(all)... )); // NOLINT(modernize-make-shared) - } - bool operator==(const Type& rhs) const override { - return rhs.kind() == kind(); - } - std::string str() const override { - return "int"; - } - bool isSubtypeOf(const TypePtr rhs) const override { - return *this == *rhs || rhs->kind() == TypeKind::NumberType; - } - static const TypeKind Kind = TypeKind::IntType; - // global singleton - static IntTypePtr get(); -private: - IntType() - : Type(TypeKind::IntType) {} -}; - -struct BoolType; -using BoolTypePtr = std::shared_ptr; -// This node represents a Python bool value -struct CAFFE2_API BoolType : public Type { - template - static BoolTypePtr create( T&& ... all ) { - return BoolTypePtr(new BoolType(std::forward(all)... )); - } - bool operator==(const Type& rhs) const override { - return rhs.kind() == kind(); - } - std::string str() const override { - return "bool"; - } - bool isSubtypeOf(const TypePtr rhs) const override { - return *this == *rhs || rhs->kind() == TypeKind::BoolType; - } - static const TypeKind Kind = TypeKind::BoolType; - // global singleton - static BoolTypePtr get(); -private: - BoolType() - : Type(TypeKind::BoolType) {} -}; - -struct StringType; -using StringTypePtr = std::shared_ptr; -// This node represents a Python string value -struct CAFFE2_API StringType : public Type { - static constexpr bool is_singleton = true; - template - static StringTypePtr create( T&& ... all ) { - return StringTypePtr(new StringType( std::forward(all)... )); // NOLINT(modernize-make-shared) - } - bool operator==(const Type& rhs) const override { - return rhs.kind() == kind(); - } - std::string str() const override { - return "string"; - } - bool isSubtypeOf(const TypePtr rhs) const override { - return *this == *rhs; - } - static const TypeKind Kind = TypeKind::StringType; - // global singleton - static StringTypePtr get(); -private: - StringType() - : Type(TypeKind::StringType) {} -}; - -struct NoneType; -using NoneTypePtr = std::shared_ptr; -// This node represents a Python int number value -struct CAFFE2_API NoneType : public Type { - static constexpr bool is_singleton = true; - template - static NoneTypePtr create( T&& ... all ) { - return NoneTypePtr(new NoneType( std::forward(all)... )); // NOLINT(modernize-make-shared) - } - bool operator==(const Type& rhs) const override { - return rhs.kind() == kind(); - } - std::string str() const override { - return "None"; - } - static const TypeKind Kind = TypeKind::NoneType; - // global singleton - static NoneTypePtr get(); -private: - NoneType() - : Type(TypeKind::NoneType) {} -}; - -struct GeneratorType; -using GeneratorTypePtr = std::shared_ptr; -struct CAFFE2_API GeneratorType : public Type { - static constexpr bool is_singleton = true; - template - static GeneratorTypePtr create( T&& ... all) { - return GeneratorTypePtr(new GeneratorType( std::forward(all)... )); // NOLINT(modernize-make-shared) - } - bool operator==(const Type& rhs) const override { - return rhs.kind() == kind(); - } - std::string str() const override { - return "Generator"; - } - static const TypeKind Kind = TypeKind::GeneratorType; - // global singleton - static GeneratorTypePtr get(); -private: - GeneratorType() - : Type(TypeKind::GeneratorType) {} -}; - - -// a type variable, used in FunctionSchema -struct VarType; -using VarTypePtr = std::shared_ptr; -struct CAFFE2_API VarType : public Type { - static constexpr bool is_singleton = false; - template - static VarTypePtr create(std::string name_) { - return VarTypePtr(new VarType(std::move(name_))); - } - bool operator==(const Type& rhs) const override { - return rhs.kind() == kind(); - } - std::string str() const override { - return name(); - } - static const TypeKind Kind = TypeKind::VarType; - const std::string& name() const { - return name_; - } - bool hasFreeVariables() const override { - return true; - } -private: - VarType(std::string name_) - : Type(TypeKind::VarType), name_(name_) {} - std::string name_; -}; - -CAFFE2_API std::ostream& operator<<(std::ostream & out, const Type & t); -// what is the type, ignoring extra size/shape information? -// e.g. Tensor(2x3) -> Dynamic, and Tuple(Tensor(2x3),...) -> Tuple(Dynamic,...) - -inline TypePtr unshapedType(const TypePtr& type) { - if (TupleTypePtr t = type->cast()) { - return TupleType::create(fmap(t->elements(), unshapedType)); - } else if (ListTypePtr t = type->cast()) { - return ListType::create(unshapedType(t->getElementType())); - } else if (type->kind() == TypeKind::TensorType || - type->kind() == TypeKind::CompleteTensorType) { - return DynamicType::get(); - } else { - return type; - } -} - -inline TypePtr CompleteTensorType::fromNumberType(TypePtr typ) { - AT_ASSERT(typ->isSubtypeOf(NumberType::get())); - if (typ->isSubtypeOf(IntType::get())) { - return CompleteTensorType::create(at::kLong, -1, {}); - } else if (typ->isSubtypeOf(FloatType::get())) { - return CompleteTensorType::create(at::kFloat, -1, {}); - } else if (typ->isSubtypeOf(BoolType::get())) { - return CompleteTensorType::create(at::kLong, -1, {}); - } - AT_ERROR("unknown number type", typ->str()); -} - -inline TypePtr CompleteTensorType::fromBoolType() { - return CompleteTensorType::create(at::kLong, -1, {}); -} - - -// Attempt to find the correct supertype of t1 and t2. If none is found then -// nullopt will be returned. If t1 == t2, or t1 is a type refinement of t2, -// then t2 will be returned (and vice versa). -// Two different tensortypes will return dynamic. -// Currently we chose not to support returning a NumberType for a float & int -// input because of a lack of operator support for NumberType -CAFFE2_API at::optional unifyTypes(const TypePtr& t1, const TypePtr& t2); - -template -TypePtr getTypePtr() { -#define TYPE_STR(Type) #Type, " ", - AT_ERROR( - "Type ", - at::demangle_type(), - " could not be converted to any of the known types { ", - TH_FORALL_TYPES(TYPE_STR) "}"); -#undef TYPE_STR - return nullptr; -} - -template<> inline TypePtr getTypePtr() { return DynamicType::get(); } -template<> inline TypePtr getTypePtr() { return FloatType::get(); } -template<> inline TypePtr getTypePtr() { return IntType::get(); } -template<> inline TypePtr getTypePtr() { return BoolType::get(); } -template<> inline TypePtr getTypePtr() { return NumberType::get(); } -template<> inline TypePtr getTypePtr>() { return ListType::ofTensors(); } -template<> inline TypePtr getTypePtr>() { return ListType::ofFloats(); } -template<> inline TypePtr getTypePtr>() { return ListType::ofInts(); } - -CAFFE2_API TypePtr inferTypeFrom(const IValue& value); - -struct CAFFE2_API TypeMatchError : public std::exception { - TypeMatchError(std::string msg_) - : msg_(std::move(msg_)) {} - const char * what() const noexcept override { - return msg_.c_str(); - } -private: - std::string msg_; -}; -using TypeEnv = std::unordered_map; -CAFFE2_API TypePtr matchTypeVariables(TypePtr formal, TypePtr actual, TypeEnv & type_env); -CAFFE2_API TypePtr evalTypeVariables(TypePtr type, TypeEnv & type_env); - -}} // namespace torch::jit diff --git a/caffe2/core/common.h b/caffe2/core/common.h index 3ef3906d447a1..63f1c9e752357 100644 --- a/caffe2/core/common.h +++ b/caffe2/core/common.h @@ -183,16 +183,6 @@ inline Dst dynamic_cast_if_rtti(Src ptr) { #endif } -template< class T, class U > -std::shared_ptr dynamic_pointer_cast_if_rtti( const std::shared_ptr& r ) noexcept -{ - if (auto p = dynamic_cast_if_rtti::element_type*>(r.get())) { - return std::shared_ptr(r, p); - } else { - return std::shared_ptr(); - } -} - // SkipIndices are used in operator_fallback_gpu.h and operator_fallback_mkl.h // as utilty functions that marks input / output indices to skip when we use a // CPU operator as the fallback of GPU/MKL operator option. diff --git a/torch/CMakeLists.txt b/torch/CMakeLists.txt index 7a6743d0dbffa..ae2b0b0429608 100644 --- a/torch/CMakeLists.txt +++ b/torch/CMakeLists.txt @@ -180,6 +180,7 @@ set(TORCH_SRCS ${TORCH_SRC_DIR}/csrc/jit/script/lexer.cpp ${TORCH_SRC_DIR}/csrc/jit/script/module.cpp ${TORCH_SRC_DIR}/csrc/jit/tracer.cpp + ${TORCH_SRC_DIR}/csrc/jit/type.cpp ${TORCH_SRC_DIR}/csrc/torch.cpp ${TORCH_SRC_DIR}/csrc/utils/tensor_flatten.cpp ${TORCH_SRC_DIR}/csrc/utils/variadic.cpp diff --git a/torch/csrc/jit/function_schema.h b/torch/csrc/jit/function_schema.h index 160eef3340470..dcaaf766e18c0 100644 --- a/torch/csrc/jit/function_schema.h +++ b/torch/csrc/jit/function_schema.h @@ -1 +1,141 @@ -#include +#pragma once +#include "ATen/ATen.h" + +#include "torch/csrc/jit/type.h" +#include "torch/csrc/jit/ivalue.h" + +namespace torch { namespace jit { + +// schema as used in the compiler for resolving function calls and reporting +// errors. These objects should be constructed from C10 schema once those +// are available. +struct Argument { + Argument( + std::string name = "", + TypePtr type = nullptr, + at::optional N = at::nullopt, + at::optional default_value = at::nullopt, + bool kwarg_only = false) + : name(std::move(name)), + type(type? type : DynamicType::get()), + N(std::move(N)), + default_value(std::move(default_value)), + kwarg_only(kwarg_only) {} + std::string name; + TypePtr type; + + // for list types, an optional statically known length for the list + // e.g. for int[3]: type = ListType::ofInts(), N = 3 + // If present, this will allow scalars to be broadcast to this length to + // become a list. + at::optional N; + + at::optional default_value; + // is this only specifyable as a keyword argument? + bool kwarg_only; +}; + +struct FunctionSchema { + FunctionSchema( + std::string name, + std::vector arguments, + std::vector returns, + bool is_vararg = false, + bool is_varret = false) + : name(std::move(name)), + arguments(std::move(arguments)), + returns(std::move(returns)), + is_vararg(is_vararg), + is_varret(is_varret), + is_mutable(isMutable()) { + validate(); + } + FunctionSchema( + Symbol name, + std::vector arguments, + std::vector returns, + bool is_vararg = false, + bool is_varret = false) + : FunctionSchema( + name.toQualString(), + std::move(std::move(arguments)), + std::move(std::move(returns)), + is_vararg, + is_varret) { + validate(); + } + + const std::string name; + const std::vector arguments; + const std::vector returns; + // if true then this schema takes an arbitrary number of additional arguments + // after the argument specified in arguments + // currently this is used primarily to represent 'primtive' operators whose + // arguments are not checked by schema + const bool is_vararg; + const bool is_varret; + const bool is_mutable; + + at::optional argumentIndexWithName(const std::string& name) const { + for(size_t i = 0; i < arguments.size(); ++i) { + if(name == arguments[i].name) + return i; + } + return at::nullopt; + } + + private: + bool isMutable() const { + return std::any_of( + arguments.cbegin(), arguments.cend(), [](const Argument& arg) { + return arg.type == WorldType::get(); + }); + } + + void validate() const { + if (is_mutable) { + // Mutable schemas should have a world token as the first argument + // and return. + JIT_ASSERT(arguments.at(0).type == WorldType::get()); + JIT_ASSERT(returns.at(0).type == WorldType::get()); + } + } +}; + +// for debugging, make sure we can describe the call site +inline std::ostream& operator<<(std::ostream& out, const Argument& arg) { + return out << arg.type->str() << " " << arg.name << (arg.default_value ? "=" : ""); +} + +inline std::ostream& operator<<(std::ostream& out, const FunctionSchema& schema) { + // eventually this should look almost identical to python arg parser, but + // it is simpler for now to work directly on this schema + + out << schema.name; + out << "("; + + bool seen_kwarg_only = false; + for(size_t i = 0; i < schema.arguments.size(); ++i) { + if (i > 0) out << ", "; + if (schema.arguments[i].kwarg_only && !seen_kwarg_only) { + out << "*, "; + seen_kwarg_only = true; + } + out << schema.arguments[i]; + } + + out << ") -> "; + if (schema.returns.size() == 1) { + out << schema.returns.at(0).type->str(); + } else if (schema.returns.size() > 1) { + out << "("; + for (size_t i = 0; i < schema.returns.size(); ++i) { + if (i > 0) out << ", "; + out << schema.returns[i].type->str(); + } + out << ")"; + } + return out; +} + +}} diff --git a/aten/src/ATen/core/type.cpp b/torch/csrc/jit/type.cpp similarity index 97% rename from aten/src/ATen/core/type.cpp rename to torch/csrc/jit/type.cpp index a56173b24b8b5..ce3cd8c1f225f 100644 --- a/aten/src/ATen/core/type.cpp +++ b/torch/csrc/jit/type.cpp @@ -1,4 +1,6 @@ -#include +#include "torch/csrc/jit/type.h" + +#include "torch/csrc/jit/assertions.h" #include @@ -9,7 +11,7 @@ std::ostream& operator<<(std::ostream & out, const Type & t) { out << at::toString(value->scalarType()) << "("; auto& sizes = value->sizes(); auto& strides = value->strides(); - AT_ASSERT(sizes.size() == strides.size()); + JIT_ASSERT(sizes.size() == strides.size()); for (size_t i = 0; i < sizes.size(); i++) { if (i > 0) { out << ", "; @@ -239,7 +241,7 @@ TypePtr matchTypeVariables(TypePtr formal, TypePtr actual, TypeEnv& type_env) { } // change return types like List[List[t]] into List[List[int]] -CAFFE2_API TypePtr evalTypeVariables(TypePtr type, std::unordered_map& type_env) { +TORCH_API TypePtr evalTypeVariables(TypePtr type, std::unordered_map& type_env) { if(!type->hasFreeVariables()) return type; diff --git a/torch/csrc/jit/type.h b/torch/csrc/jit/type.h index 171d61be5aa78..69133f6be9c3a 100644 --- a/torch/csrc/jit/type.h +++ b/torch/csrc/jit/type.h @@ -1 +1,814 @@ -#include +#pragma once + +#include "torch/csrc/jit/ivalue.h" +#include "torch/csrc/jit/assertions.h" +#include "torch/csrc/jit/interned_strings.h" +#include "torch/csrc/WindowsTorchApiMacro.h" +#include "torch/csrc/utils/functional.h" + +#include + +#include +#include +#include + +namespace torch { namespace jit { + +#define TH_FORALL_TYPES(_) \ +_(DynamicType) \ +_(TensorType) \ +_(CompleteTensorType) \ +_(UndefinedTensorType) \ +_(TupleType) \ +_(ListType) \ +_(NumberType) \ +_(FloatType) \ +_(IntType) \ +_(NoneType) \ +_(StringType) \ +_(GeneratorType) \ +_(BoolType) \ +_(VarType) \ +_(WorldType) \ + +enum class TypeKind { +#define DEFINE_TYPE(T) T, + TH_FORALL_TYPES(DEFINE_TYPE) +#undef DEFINE_TYPE +}; + +struct Type; +using TypePtr = std::shared_ptr; + +template +struct cloneType {}; + +template +struct cloneType { + std::shared_ptr operator()(std::shared_ptr ptr) const { + return ptr; + } + std::shared_ptr operator()(std::shared_ptr ptr) const { + return ptr; + } +}; + +template +struct cloneType { + std::shared_ptr operator()(std::shared_ptr ptr) const { + auto result = std::make_shared::type>(*ptr); + // XXX: the line above will correctly slice the struct, and make its runtype + // type exactly equal to T. However, kind_ is a field of Type, so it will simply + // be copied, and we need to fix it in here to match the dynamic type. + result->kind_ = T::Kind; + return result; + } +}; + +struct TORCH_API Type : std::enable_shared_from_this { +private: + TypeKind kind_; + template + friend struct cloneType; + +protected: + Type(TypeKind kind) + : kind_(kind) {} + +public: + virtual bool operator==(const Type& rhs) const = 0; + + // subtyping relation. By default, we return true for the case + // when the type is exactly equal + virtual bool isSubtypeOf(const TypePtr rhs) const { + return *this == *rhs; + } + + // How this type will appear in FunctionSchema declarations + virtual std::string str() const = 0; + + // How this type will appear as if it were a type annotation in Python + // which is sometimes different than how it appears in declarations (e.g. int[] vs List[int]) + virtual std::string python_str() const { + return str(); + } + + TypeKind kind() const { + return kind_; + } + + virtual bool requires_grad() const { return false; } + + // Dynamically cast this object to the subclass indicated by the + // template variable, returning nullptr if the cast is invalid. + // NOTE: if the cast succeeds, but the casted kind is not the + // run-time kind of the type, we also slice the structure, so + // that assignments of those types to values don't accidentally + // inherit more detailed information from subclasses. + template + std::shared_ptr cast() { + auto r = std::dynamic_pointer_cast(shared_from_this()); + if (!r || T::Kind == kind()) { + return r; + } else { + return cloneType{}(r); + } + } + template + std::shared_ptr cast() const { + auto r = std::dynamic_pointer_cast(shared_from_this()); + if (!r || T::Kind == kind()) { + return r; + } else { + return cloneType{}(r); + } + } + template + std::shared_ptr expect() { + auto r = cast(); + JIT_ASSERT(r); + return r; + } + template + std::shared_ptr expect() const { + auto r = cast(); + JIT_ASSERT(r); + return r; + } + virtual ~Type() = default; + virtual bool hasFreeVariables() const { + return false; + } +}; + +inline bool operator!=(const Type & lhs, const Type & rhs) { + return !(lhs == rhs); +} + +struct DynamicType; +using DynamicTypePtr = std::shared_ptr; +// This node represents a single Tensor value, with an unknown shape. +struct TORCH_API DynamicType : public Type { + static constexpr bool is_singleton = true; + template + static DynamicTypePtr create( T&& ... all ) { + return DynamicTypePtr(new DynamicType( std::forward(all)... )); // NOLINT(modernize-make-shared) + } + + bool requires_grad() const override { return true; } + + bool operator==(const Type& rhs) const override { + return rhs.kind() == kind(); + } + std::string str() const override { + return "Tensor"; + } + static const TypeKind Kind = TypeKind::DynamicType; + // global singleton + static DynamicTypePtr get(); +private: + DynamicType() + : Type(TypeKind::DynamicType) {} +}; + +struct UndefinedTensorType; +using UndefinedTensorTypePtr = std::shared_ptr; +struct TORCH_API UndefinedTensorType : public Type { + static constexpr bool is_singleton = true; + friend struct Type; + static const TypeKind Kind = TypeKind::UndefinedTensorType; + + template + static UndefinedTensorTypePtr create( T&& ... all ) { + return UndefinedTensorTypePtr(new UndefinedTensorType( std::forward(all)... )); // NOLINT(modernize-make-shared) + } + + bool operator==(const Type& rhs) const override { + return rhs.kind() == kind(); + } + bool isSubtypeOf(const TypePtr rhs) const override { + return rhs->kind() == TypeKind::DynamicType || + rhs->kind() == TypeKind::UndefinedTensorType; + } + std::string str() const override { + return "UndefinedTensor"; + } + static UndefinedTensorTypePtr get(); +protected: + UndefinedTensorType(): Type(TypeKind::UndefinedTensorType) {} +}; + +struct TensorType; +using TensorTypePtr = std::shared_ptr; +// This node represents a single Tensor value with a specific size +struct TORCH_API TensorType : public Type { + static constexpr bool is_singleton = false; + friend struct Type; + static const TypeKind Kind = TypeKind::TensorType; + + template + static TensorTypePtr create( T&& ... all ) { + return TensorTypePtr(new TensorType( std::forward(all)... )); // NOLINT(modernize-make-shared) + } + + at::ScalarType scalarType() const { return scalar_type_; } + int device() const { return device_; } + int dim() const { return dim_; } + bool requires_grad() const override { return requires_grad_; } + + TensorTypePtr toScalarType(at::ScalarType type){ + auto t = TensorType::create(*this); + t->scalar_type_ = type; + return t; + } + TensorTypePtr withDim(int new_dim) { + auto t = TensorType::create(*this); + t->dim_ = new_dim; + return t; + } + TensorTypePtr withRequiresGrad(bool req) { + auto t = TensorType::create(*this); + t->requires_grad_ = req; + return t; + } + + bool operator==(const Type& rhs) const override { + if (rhs.kind() != TypeKind::TensorType) + return false; + auto rt = rhs.expect(); + return scalarType() == rt->scalarType() && + device() == rt->device() && + dim() == rt->dim(); + } + bool isSubtypeOf(const TypePtr rhs) const override { + if (rhs->kind() == TypeKind::DynamicType) + return true; + return rhs->kind() == TypeKind::TensorType && *this == *rhs; + } + std::string str() const override { + // str is used for user-facing error messages, where we + // don't want to reveal underlying size information. + return "Tensor"; + } + +protected: + TensorType(const at::Tensor& tensor, TypeKind kind=TypeKind::TensorType) + : TensorType(tensor.type().scalarType(), + tensor.type().is_cuda() ? tensor.get_device() : -1, + tensor.dim(), + tensor.is_variable() && tensor.requires_grad(), + kind) {} + TensorType(at::ScalarType scalar_type, int device, int dim, bool requires_grad=true, TypeKind kind=TypeKind::TensorType) + : Type(kind) + , scalar_type_(scalar_type) + , requires_grad_(at::isFloatingType(scalar_type) && requires_grad) + , device_(device) + , dim_(dim) {} + + at::ScalarType scalar_type_; + bool requires_grad_; + int device_; + int dim_; +}; + +struct CompleteTensorType; +using CompleteTensorTypePtr = std::shared_ptr; +// This node represents a single Tensor value with a specific size +struct TORCH_API CompleteTensorType : public TensorType { + static constexpr bool is_singleton = false; + friend struct Type; + template + static CompleteTensorTypePtr create( T&& ... all ) { + return CompleteTensorTypePtr(new CompleteTensorType( std::forward(all)... )); // NOLINT(modernize-make-shared) + } + + // overloaded create variadic template argument as it could not distinguish initializer list + static CompleteTensorTypePtr create(at::ScalarType scalar_type, int device, at::IntList sizes) { + return CompleteTensorTypePtr(new CompleteTensorType(scalar_type, device, sizes)); // NOLINT(modernize-make-shared) + } + static CompleteTensorTypePtr create(at::ScalarType scalar_type, int device, at::IntList sizes, at::IntList strides) { + return CompleteTensorTypePtr(new CompleteTensorType(scalar_type, device, sizes, strides)); // NOLINT(modernize-make-shared) + } + + static const TypeKind Kind = TypeKind::CompleteTensorType; + + const std::vector& sizes() const { return sizes_; } + const std::vector& strides() const { return strides_; } + + TypePtr withSizesStrides(at::IntList sizes, at::IntList strides) const { + return CompleteTensorType::create(scalar_type_, device_, sizes, strides); + } + + TypePtr withSizes(at::IntList sizes) const { + return withSizesStrides(sizes, CompleteTensorType::contiguousStridesOf(sizes)); + } + + CompleteTensorTypePtr contiguous() const { + auto t = CompleteTensorType::create(*this); + t->strides_ = CompleteTensorType::contiguousStridesOf(sizes_); + return t; + } + + CompleteTensorTypePtr toScalarType(at::ScalarType type){ + auto t = CompleteTensorType::create(*this); + t->scalar_type_ = type; + return t; + } + + bool operator==(const Type& rhs) const override { + if(rhs.kind() != kind()) + return false; + auto rt = rhs.expect(); + return scalarType() == rt->scalarType() && + sizes() == rt->sizes() && + strides() == rt->strides() && + device() == rt->device(); + } + bool isSubtypeOf(const TypePtr rhs) const override { + if (rhs->kind() == TypeKind::DynamicType) + return true; + if (rhs->kind() == TypeKind::TensorType) + return *expect() == *rhs; + return *this == *rhs; + } + std::string str() const override { + // str is used for user-facing error messages, where we + // don't want to reveal underlying size information. + return "Tensor"; + } + bool numel() const { + size_t prod = 1; + for(auto s : sizes()) { + prod *= s; + } + return prod; + } + static TypePtr fromNumberType(TypePtr typ); + static TypePtr fromBoolType(); + +private: + CompleteTensorType(const at::Tensor& tensor) + : TensorType(tensor, TypeKind::CompleteTensorType) + , sizes_(tensor.sizes().vec()) + , strides_(tensor.strides().vec()) {} + CompleteTensorType(at::ScalarType scalar_type, int device, at::IntList sizes, bool requires_grad=true) + : CompleteTensorType(scalar_type, device, sizes, CompleteTensorType::contiguousStridesOf(sizes), requires_grad) {} + CompleteTensorType(at::ScalarType scalar_type, int device, at::IntList sizes, at::IntList strides, bool requires_grad=true) + : TensorType(scalar_type, device, sizes.size(), requires_grad, TypeKind::CompleteTensorType) + , sizes_(sizes.vec()) + , strides_(strides.vec()) {} + + static std::vector contiguousStridesOf(at::IntList sizes) { + std::vector strides(sizes.size()); + if(sizes.empty()) // zero-dim case + return strides; + strides.back() = 1; + for(size_t i = strides.size() - 1; i > 0; i--) { + strides[i-1] = strides[i] * sizes[i]; + } + return strides; + } + + std::vector sizes_; + std::vector strides_; +}; + +// This type is a token used to represent effectful computation in the IR. +// See the AnnotateEffects pass for how it is used. +struct WorldType; +using WorldTypePtr = std::shared_ptr; +struct TORCH_API WorldType : public Type { + template + static WorldTypePtr create(T&&... all) { + return WorldTypePtr(new WorldType(std::forward(all)...)); + } + bool operator==(const Type& rhs) const override { + return rhs.kind() == kind(); + } + std::string str() const override { + return "world"; + } + bool isSubtypeOf(const TypePtr rhs) const override { + return *this == *rhs; + } + static const TypeKind Kind = TypeKind::WorldType; + // global singleton + static WorldTypePtr get(); + + private: + WorldType() : Type(TypeKind::WorldType) {} +}; + +struct ListType; +using ListTypePtr = std::shared_ptr; + +struct TORCH_API ListType : public Type { + // It's not exactly a singleton, but there should be exactly once instance of + // List[T] for every T + static constexpr bool is_singleton = true; + friend struct Type; + template + static ListTypePtr create( T&& ... all ) { + return ListTypePtr(new ListType( std::forward(all)... )); // NOLINT(modernize-make-shared) + } + bool operator==(const Type& rhs) const override { + if(auto rhs_ = rhs.cast()) { + return *getElementType() == *rhs_->getElementType(); + } + return false; + } + bool requires_grad() const override { + return elem->requires_grad(); + } + std::string str() const override { + std::stringstream ss; + ss << getElementType()->str() << "[]"; + return ss.str(); + } + std::string python_str() const override { + std::stringstream ss; + ss << "List[" << getElementType()->python_str() << "]"; + return ss.str(); + } + TypePtr getElementType() const { + return elem; + } + bool hasFreeVariables() const override { + return has_free_variables_; + } + // common cast List[Tensor] + static ListTypePtr ofTensors(); + static ListTypePtr ofInts(); + static ListTypePtr ofFloats(); + static ListTypePtr ofBools(); + + static const TypeKind Kind = TypeKind::ListType; +private: + ListType(TypePtr elem) + : Type(TypeKind::ListType) + , elem(std::move(elem)) + , has_free_variables_(getElementType()->hasFreeVariables()) {} + TypePtr elem; + bool has_free_variables_; +}; + +struct TupleType; +using TupleTypePtr = std::shared_ptr; + +struct TORCH_API TupleType : public Type { + static constexpr bool is_singleton = false; + friend struct Type; + static TupleTypePtr create(std::vector types) { + return TupleTypePtr(new TupleType( std::move(types) )); // NOLINT(modernize-make-shared) + } + at::ArrayRef elements() const { + return elements_; + } + bool operator==(const Type& rhs) const override { + return compare(rhs, [](const TypePtr a, const TypePtr b) { + return *a == *b; + }); + } + bool isSubtypeOf(const TypePtr rhs) const override { + // co-variant rules for tuples + return compare(*rhs, [](const TypePtr a, const TypePtr b) { + return a->isSubtypeOf(b); + }); + } + bool requires_grad() const override { + return std::any_of(elements_.begin(), elements_.end(), + [](const TypePtr& ptr) { return ptr->requires_grad(); }); + } + std::string str() const override { + std::stringstream ss; + ss << "("; + for(size_t i = 0; i < elements().size(); ++i) { + if(i > 0) + ss << ", "; + ss << elements()[i]->str(); + } + ss << ")"; + return ss.str(); + } + std::string python_str() const override { + std::stringstream ss; + ss << "Tuple["; + for(size_t i = 0; i < elements().size(); ++i) { + if(i > 0) + ss << ", "; + ss << elements()[i]->python_str(); + } + ss << "]"; + return ss.str(); + } + bool hasFreeVariables() const override { + return has_free_variables_; + } + + static const TypeKind Kind = TypeKind::TupleType; +private: + TupleType(std::vector elements_) + : Type(TypeKind::TupleType) + , elements_(std::move(elements_)) { + has_free_variables_ = + std::any_of(elements_.begin(), elements_.end(), [](TypePtr v) { + return v->hasFreeVariables(); + }); + } + + bool compare(const Type& rhs, std::function fn) const { + if(rhs.kind() != kind()) + return false; + const auto & l_elements = elements(); + const auto & r_elements = rhs.cast()->elements(); + if(l_elements.size() != r_elements.size()) + return false; + for(size_t i = 0; i < l_elements.size(); ++i) { + if(!fn(l_elements[i], r_elements[i])) + return false; + } + return true; + } + std::vector elements_; + bool has_free_variables_; +}; + +struct NumberType; +using NumberTypePtr = std::shared_ptr; +// This node represents a Python number value +struct TORCH_API NumberType : public Type { + static constexpr bool is_singleton = true; + template + static NumberTypePtr create( T&& ... all ) { + return NumberTypePtr(new NumberType( std::forward(all)... )); // NOLINT(modernize-make-shared) + } + bool operator==(const Type& rhs) const override { + return rhs.kind() == kind(); + } + std::string str() const override { + return "Scalar"; // match what PythonArgParser says for clarity + } + static const TypeKind Kind = TypeKind::NumberType; + // global singleton + static NumberTypePtr get(); +private: + NumberType() + : Type(TypeKind::NumberType) {} +}; + +struct FloatType; +using FloatTypePtr = std::shared_ptr; +// This node represents a Python float number value +struct TORCH_API FloatType : public Type { + static constexpr bool is_singleton = true; + template + static FloatTypePtr create( T&& ... all ) { + return FloatTypePtr(new FloatType( std::forward(all)... )); // NOLINT(modernize-make-shared) + } + bool operator==(const Type& rhs) const override { + return rhs.kind() == kind(); + } + std::string str() const override { + return "float"; + } + bool isSubtypeOf(const TypePtr rhs) const override { + return *this == *rhs || rhs->kind() == TypeKind::NumberType; + } + static const TypeKind Kind = TypeKind::FloatType; + // global singleton + static FloatTypePtr get(); +private: + FloatType() + : Type(TypeKind::FloatType) {} +}; + +struct IntType; +using IntTypePtr = std::shared_ptr; +// This node represents a Python int number value +struct TORCH_API IntType : public Type { + static constexpr bool is_singleton = true; + template + static IntTypePtr create( T&& ... all ) { + return IntTypePtr(new IntType( std::forward(all)... )); // NOLINT(modernize-make-shared) + } + bool operator==(const Type& rhs) const override { + return rhs.kind() == kind(); + } + std::string str() const override { + return "int"; + } + bool isSubtypeOf(const TypePtr rhs) const override { + return *this == *rhs || rhs->kind() == TypeKind::NumberType; + } + static const TypeKind Kind = TypeKind::IntType; + // global singleton + static IntTypePtr get(); +private: + IntType() + : Type(TypeKind::IntType) {} +}; + +struct BoolType; +using BoolTypePtr = std::shared_ptr; +// This node represents a Python bool value +struct TORCH_API BoolType : public Type { + template + static BoolTypePtr create( T&& ... all ) { + return BoolTypePtr(new BoolType(std::forward(all)... )); + } + bool operator==(const Type& rhs) const override { + return rhs.kind() == kind(); + } + std::string str() const override { + return "bool"; + } + bool isSubtypeOf(const TypePtr rhs) const override { + return *this == *rhs || rhs->kind() == TypeKind::BoolType; + } + static const TypeKind Kind = TypeKind::BoolType; + // global singleton + static BoolTypePtr get(); +private: + BoolType() + : Type(TypeKind::BoolType) {} +}; + +struct StringType; +using StringTypePtr = std::shared_ptr; +// This node represents a Python string value +struct TORCH_API StringType : public Type { + static constexpr bool is_singleton = true; + template + static StringTypePtr create( T&& ... all ) { + return StringTypePtr(new StringType( std::forward(all)... )); // NOLINT(modernize-make-shared) + } + bool operator==(const Type& rhs) const override { + return rhs.kind() == kind(); + } + std::string str() const override { + return "string"; + } + bool isSubtypeOf(const TypePtr rhs) const override { + return *this == *rhs; + } + static const TypeKind Kind = TypeKind::StringType; + // global singleton + static StringTypePtr get(); +private: + StringType() + : Type(TypeKind::StringType) {} +}; + +struct NoneType; +using NoneTypePtr = std::shared_ptr; +// This node represents a Python int number value +struct NoneType : public Type { + static constexpr bool is_singleton = true; + template + static NoneTypePtr create( T&& ... all ) { + return NoneTypePtr(new NoneType( std::forward(all)... )); // NOLINT(modernize-make-shared) + } + bool operator==(const Type& rhs) const override { + return rhs.kind() == kind(); + } + std::string str() const override { + return "None"; + } + static const TypeKind Kind = TypeKind::NoneType; + // global singleton + static NoneTypePtr get(); +private: + NoneType() + : Type(TypeKind::NoneType) {} +}; + +struct GeneratorType; +using GeneratorTypePtr = std::shared_ptr; +struct GeneratorType : public Type { + static constexpr bool is_singleton = true; + template + static GeneratorTypePtr create( T&& ... all) { + return GeneratorTypePtr(new GeneratorType( std::forward(all)... )); // NOLINT(modernize-make-shared) + } + bool operator==(const Type& rhs) const override { + return rhs.kind() == kind(); + } + std::string str() const override { + return "Generator"; + } + static const TypeKind Kind = TypeKind::GeneratorType; + // global singleton + static GeneratorTypePtr get(); +private: + GeneratorType() + : Type(TypeKind::GeneratorType) {} +}; + + +// a type variable, used in FunctionSchema +struct VarType; +using VarTypePtr = std::shared_ptr; +struct VarType : public Type { + static constexpr bool is_singleton = false; + template + static VarTypePtr create(std::string name_) { + return VarTypePtr(new VarType(std::move(name_))); + } + bool operator==(const Type& rhs) const override { + return rhs.kind() == kind(); + } + std::string str() const override { + return name(); + } + static const TypeKind Kind = TypeKind::VarType; + const std::string& name() const { + return name_; + } + bool hasFreeVariables() const override { + return true; + } +private: + VarType(std::string name_) + : Type(TypeKind::VarType), name_(name_) {} + std::string name_; +}; + +TORCH_API std::ostream& operator<<(std::ostream & out, const Type & t); +// what is the type, ignoring extra size/shape information? +// e.g. Tensor(2x3) -> Dynamic, and Tuple(Tensor(2x3),...) -> Tuple(Dynamic,...) + +inline TypePtr unshapedType(const TypePtr& type) { + if (TupleTypePtr t = type->cast()) { + return TupleType::create(fmap(t->elements(), unshapedType)); + } else if (ListTypePtr t = type->cast()) { + return ListType::create(unshapedType(t->getElementType())); + } else if (type->kind() == TypeKind::TensorType || + type->kind() == TypeKind::CompleteTensorType) { + return DynamicType::get(); + } else { + return type; + } +} + +inline TypePtr CompleteTensorType::fromNumberType(TypePtr typ) { + JIT_ASSERT(typ->isSubtypeOf(NumberType::get())); + if (typ->isSubtypeOf(IntType::get())) { + return CompleteTensorType::create(at::kLong, -1, {}); + } else if (typ->isSubtypeOf(FloatType::get())) { + return CompleteTensorType::create(at::kFloat, -1, {}); + } else if (typ->isSubtypeOf(BoolType::get())) { + return CompleteTensorType::create(at::kLong, -1, {}); + } + AT_ERROR("unknown number type", typ->str()); +} + +inline TypePtr CompleteTensorType::fromBoolType() { + return CompleteTensorType::create(at::kLong, -1, {}); +} + + +// Attempt to find the correct supertype of t1 and t2. If none is found then +// nullopt will be returned. If t1 == t2, or t1 is a type refinement of t2, +// then t2 will be returned (and vice versa). +// Two different tensortypes will return dynamic. +// Currently we chose not to support returning a NumberType for a float & int +// input because of a lack of operator support for NumberType +TORCH_API at::optional unifyTypes(const TypePtr& t1, const TypePtr& t2); + +template +TypePtr getTypePtr() { +#define TYPE_STR(Type) #Type, " ", + AT_ERROR( + "Type ", + at::demangle_type(), + " could not be converted to any of the known types { ", + TH_FORALL_TYPES(TYPE_STR) "}"); +#undef TYPE_STR + return nullptr; +} + +template<> inline TypePtr getTypePtr() { return DynamicType::get(); } +template<> inline TypePtr getTypePtr() { return FloatType::get(); } +template<> inline TypePtr getTypePtr() { return IntType::get(); } +template<> inline TypePtr getTypePtr() { return BoolType::get(); } +template<> inline TypePtr getTypePtr() { return NumberType::get(); } +template<> inline TypePtr getTypePtr>() { return ListType::ofTensors(); } +template<> inline TypePtr getTypePtr>() { return ListType::ofFloats(); } +template<> inline TypePtr getTypePtr>() { return ListType::ofInts(); } + +TORCH_API TypePtr inferTypeFrom(const IValue& value); + +struct TORCH_API TypeMatchError : public std::exception { + TypeMatchError(std::string msg_) + : msg_(std::move(msg_)) {} + const char * what() const noexcept override { + return msg_.c_str(); + } +private: + std::string msg_; +}; +using TypeEnv = std::unordered_map; +TORCH_API TypePtr matchTypeVariables(TypePtr formal, TypePtr actual, TypeEnv & type_env); +TORCH_API TypePtr evalTypeVariables(TypePtr type, TypeEnv & type_env); + +}} // namespace torch::jit diff --git a/torch/csrc/utils/functional.h b/torch/csrc/utils/functional.h index 81a0316978bd3..af5099e7ce4e8 100644 --- a/torch/csrc/utils/functional.h +++ b/torch/csrc/utils/functional.h @@ -1 +1,63 @@ -#include +#pragma once + +#include +#include + +namespace torch { + +// The passed in function must take T by value (T), or by +// const reference (const T&); taking T by non-const reference +// will result in an error like: +// +// error: no type named 'type' in 'class std::result_of' +// +// No explicit template parameters are required. + +// Overload for explicit function and ArrayRef +template +inline auto fmap(const T& inputs, const F& fn) -> std::vector { + std::vector r; + r.reserve(inputs.size()); + for(const auto & input : inputs) + r.push_back(fn(input)); + return r; +} + +template +inline auto fmap(T& inputs, const F& fn) -> std::vector { + std::vector r; + r.reserve(inputs.size()); + for(auto & input : inputs) + r.push_back(fn(input)); + return r; +} + +// C++ forbids taking an address of a constructor, so here's a workaround... +// Overload for constructor (R) application +template +inline std::vector fmap(const T& inputs) { + std::vector r; + r.reserve(inputs.size()); + for(auto & input : inputs) + r.push_back(R(input)); + return r; +} + +template +inline std::vector filter(at::ArrayRef inputs, const F& fn) { + std::vector r; + r.reserve(inputs.size()); + for(auto & input : inputs) { + if (fn(input)) { + r.push_back(input); + } + } + return r; +} + +template +inline std::vector filter(const std::vector& inputs, const F& fn) { + return filter(static_cast>(inputs), fn); +} + +}