diff --git a/velox/type/Type.cpp b/velox/type/Type.cpp index 17b88074f172..9f4f43c93870 100644 --- a/velox/type/Type.cpp +++ b/velox/type/Type.cpp @@ -244,6 +244,17 @@ bool ArrayType::equivalent(const Type& other) const { return child_->equivalent(*otherArray.child_); } +bool ArrayType::equals(const Type& other) const { + if (&other == this) { + return true; + } + if (!Type::hasSameTypeId(other)) { + return false; + } + auto& otherArray = other.asArray(); + return *child_ == *otherArray.child_; +} + folly::dynamic ArrayType::serialize() const { folly::dynamic obj = folly::dynamic::object; obj["name"] = "Type"; @@ -476,14 +487,6 @@ bool RowType::equals(const Type& other) const { return true; } -bool RowType::operator==(const Type& other) const { - return this->equals(other); -} - -bool RowType::operator==(const RowType& other) const { - return this->equals(other); -} - void RowType::printChildren(std::stringstream& ss, std::string_view delimiter) const { bool any = false; @@ -569,6 +572,17 @@ bool MapType::equivalent(const Type& other) const { valueType_->equivalent(*otherMap.valueType_); } +bool MapType::equals(const Type& other) const { + if (&other == this) { + return true; + } + if (!Type::hasSameTypeId(other)) { + return false; + } + auto& otherMap = other.asMap(); + return *keyType_ == *otherMap.keyType_ && *valueType_ == *otherMap.valueType_; +} + FunctionType::FunctionType( std::vector>&& argumentTypes, std::shared_ptr returnType) @@ -598,6 +612,29 @@ bool FunctionType::equivalent(const Type& other) const { return true; } +bool FunctionType::equals(const Type& other) const { + if (&other == this) { + return true; + } + + if (!Type::hasSameTypeId(other)) { + return false; + } + + auto& otherTyped = *reinterpret_cast(&other); + if (children_.size() != otherTyped.size()) { + return false; + } + + for (auto i = 0; i < children_.size(); ++i) { + if (*children_.at(i) != *otherTyped.children_.at(i)) { + return false; + } + } + + return true; +} + std::string FunctionType::toString() const { std::stringstream out; out << "FUNCTION<"; @@ -629,10 +666,7 @@ bool OpaqueType::equivalent(const Type& other) const { return true; } -bool OpaqueType::operator==(const Type& other) const { - if (&other == this) { - return true; - } +bool OpaqueType::equals(const Type& other) const { if (!this->equivalent(other)) { return false; } diff --git a/velox/type/Type.h b/velox/type/Type.h index 15ede35c463c..40e00ecd2511 100644 --- a/velox/type/Type.h +++ b/velox/type/Type.h @@ -32,6 +32,7 @@ #include #include "velox/common/base/ClassName.h" +#include "velox/common/base/Exceptions.h" #include "velox/common/serialization/Serializable.h" #include "velox/type/HugeInt.h" #include "velox/type/StringView.h" @@ -492,12 +493,10 @@ class Type : public Tree, public velox::ISerializable { /// equivalent if the typeKind matches, but the typeIndex could be different. virtual bool equivalent(const Type& other) const = 0; - /// Types are strongly matched. - /// Examples: Two RowTypes are == if the children types and the children names - /// are same. Two OpaqueTypes are == if the typeKind and the typeIndex are - /// same. Same as equivalent for most types except for Row, Opaque types. + /// For Complex types (Row, Array, Map, Opaque): types are strongly matched. + /// For primitive types: same as equivalent. virtual bool operator==(const Type& other) const { - return this->equivalent(other); + return this->equals(other); } inline bool operator!=(const Type& other) const { @@ -566,6 +565,16 @@ class Type : public Tree, public velox::ISerializable { return typeid(*this) == typeid(other); } + /// For Complex types (Row, Array, Map, Opaque): types are strongly matched. + /// Examples: Two RowTypes are == if the children types and the children names + /// are same. Two OpaqueTypes are == if the typeKind and the typeIndex are + /// same. + /// For primitive types: same as equivalent. + virtual bool equals(const Type& other) const { + VELOX_CHECK(this->isPrimitiveType()); + return this->equivalent(other); + } + private: const TypeKind kind_; const bool providesCustomComparison_; @@ -914,6 +923,8 @@ class ArrayType : public TypeBase { } protected: + bool equals(const Type& other) const override; + TypePtr child_; const std::vector parameters_; }; @@ -965,6 +976,9 @@ class MapType : public TypeBase { return parameters_; } + protected: + bool equals(const Type& other) const override; + private: TypePtr keyType_; TypePtr valueType_; @@ -1005,10 +1019,6 @@ class RowType : public TypeBase { bool equivalent(const Type& other) const override; - bool equals(const Type& other) const; - bool operator==(const Type& other) const override; - bool operator==(const RowType& other) const; - std::string toString() const override; /// Print child names and types separated by 'delimiter'. @@ -1037,6 +1047,9 @@ class RowType : public TypeBase { return *parameters; } + protected: + bool equals(const Type& other) const override; + private: std::unique_ptr> makeParameters() const; @@ -1090,6 +1103,9 @@ class FunctionType : public TypeBase { return parameters_; } + protected: + bool equals(const Type& other) const override; + private: static std::vector> allChildren( std::vector>&& argumentTypes, @@ -1124,8 +1140,6 @@ class OpaqueType : public TypeBase { bool equivalent(const Type& other) const override; - bool operator==(const Type& other) const override; - const std::type_index& typeIndex() const { return typeIndex_; } @@ -1186,6 +1200,9 @@ class OpaqueType : public TypeBase { deserializeTypeErased); } + protected: + bool equals(const Type& other) const override; + private: const std::type_index typeIndex_; diff --git a/velox/vector/tests/VectorMakerTest.cpp b/velox/vector/tests/VectorMakerTest.cpp index c2887fb4044e..d464f15ab364 100644 --- a/velox/vector/tests/VectorMakerTest.cpp +++ b/velox/vector/tests/VectorMakerTest.cpp @@ -596,7 +596,9 @@ TEST_F(VectorMakerTest, arrayOfRowVectorFromTuples) { auto expected = maker_.arrayVector(offsets, elements); ASSERT_EQ(expected->size(), arrayVector->size()); - ASSERT_EQ(*expected->type(), *arrayVector->type()); + // check equivalent because arrayVector's row type doesn't have name for each + // column ('', '' ..) whereas expected's row type have names ('c0', 'c1' ..) + ASSERT_TRUE((*expected->type()).equivalent((*arrayVector->type()))); for (auto i = 0; i < expected->size(); i++) { ASSERT_TRUE(expected->equalValueAt(arrayVector.get(), i, i)); }