Skip to content

Commit

Permalink
Add a base class for abstract sets
Browse files Browse the repository at this point in the history
Summary: Same concept as D50019979, but this implements a base class for all set implementations.

Reviewed By: arnaudvenet

Differential Revision: D50054827

fbshipit-source-id: b733bd3e82bb6bf42ecebffad25d561baf04e8a0
  • Loading branch information
arthaud authored and facebook-github-bot committed Oct 9, 2023
1 parent 483d9b8 commit 028d60b
Show file tree
Hide file tree
Showing 3 changed files with 195 additions and 36 deletions.
177 changes: 177 additions & 0 deletions include/sparta/AbstractSet.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/

#pragma once

#include <type_traits>
#include <utility>

namespace sparta {

/*
* This describes the API for a generic set container.
*/
template <typename Derived>
class AbstractSet {
public:
~AbstractSet() {
// The destructor is the only method that is guaranteed to be created when a
// class template is instantiated. This is a good place to perform all the
// sanity checks on the template parameters.
static_assert(std::is_base_of<AbstractSet<Derived>, Derived>::value,
"Derived doesn't inherit from AbstractSet");
static_assert(std::is_final<Derived>::value, "Derived is not final");

using Element = typename Derived::value_type;
using Iterator = typename Derived::iterator;

// Derived();
static_assert(std::is_default_constructible<Derived>::value,
"Derived is not default constructible");

// Derived(const Derived&);
static_assert(std::is_copy_constructible<Derived>::value,
"Derived is not copy constructible");

// Derived& operator=(const Derived&);
static_assert(std::is_copy_assignable<Derived>::value,
"Derived is not copy assignable");

// bool empty() const;
static_assert(std::is_same<decltype(std::declval<const Derived>().empty()),
bool>::value,
"Derived::empty() does not exist");

// std::size_t size() const;
static_assert(std::is_same<decltype(std::declval<const Derived>().size()),
std::size_t>::value,
"Derived::size() does not exist");

// std::size_t max_size() const;
static_assert(
std::is_same<decltype(std::declval<const Derived>().max_size()),
std::size_t>::value,
"Derived::max_size() does not exist");

// iterator begin() const;
static_assert(std::is_same<decltype(std::declval<const Derived>().begin()),
Iterator>::value,
"Derived::begin() does not exist");

// iterator end() const;
static_assert(std::is_same<decltype(std::declval<const Derived>().end()),
Iterator>::value,
"Derived::begin() does not exist");

// Derived& insert(const Element& element);
static_assert(std::is_same<decltype(std::declval<Derived>().insert(
std::declval<const Element>())),
Derived&>::value,
"Derived::insert(const Element&) does not exist");

// Derived& remove(const Element& element);
static_assert(std::is_same<decltype(std::declval<Derived>().remove(
std::declval<const Element>())),
Derived&>::value,
"Derived::remove(const Element&) does not exist");

// void clear();
static_assert(std::is_void_v<decltype(std::declval<Derived>().clear())>,
"Derived::clear() does not exist");
/*
* If the set is a singleton, returns a pointer to the element.
* Otherwise, returns nullptr.
*/
// const Element* singleton() const;
static_assert(
std::is_same<decltype(std::declval<const Derived>().singleton()),
const Element*>::value,
"Derived::singleton() does not exist");

// bool contains(const Element& element) const;
static_assert(std::is_same<decltype(std::declval<const Derived>().contains(
std::declval<const Element>())),
bool>::value,
"Derived::contains(const Element&) does not exist");

// bool is_subset_of(const Derived& other) const;
static_assert(
std::is_same<decltype(std::declval<const Derived>().is_subset_of(
std::declval<const Derived>())),
bool>::value,
"Derived::is_subset_of(const Derived&) does not exist");

// bool equals(const Derived& other) const;
static_assert(std::is_same<decltype(std::declval<const Derived>().equals(
std::declval<const Derived>())),
bool>::value,
"Derived::equals(const Derived&) does not exist");

// void visit(Visitor&& visitor) const;
static_assert(std::is_same<decltype(std::declval<const Derived>().visit(
std::declval<void(const Element&)>())),
void>::value,
"Derived::visit(Visitor&&) does not exist");

// Derived& filter(Predicate&& predicate);
static_assert(std::is_same<decltype(std::declval<Derived>().filter(
std::declval<bool(const Element&)>())),
Derived&>::value,
"Derived::filter(Predicate&&) does not exist");

// Derived& union_with(const Derived& other);
static_assert(std::is_same<decltype(std::declval<Derived>().union_with(
std::declval<const Derived>())),
Derived&>::value,
"Derived::union_with(const Derived&) does not exist");

// Derived& intersection_with(const Derived& other);
static_assert(
std::is_same<decltype(std::declval<Derived>().intersection_with(
std::declval<const Derived>())),
Derived&>::value,
"Derived::intersection_with(const Derived&) does not exist");

// Derived& difference_with(const Derived& other);
static_assert(std::is_same<decltype(std::declval<Derived>().difference_with(
std::declval<const Derived>())),
Derived&>::value,
"Derived::difference_with(const Derived&) does not exist");
}

/*
* Many C++ libraries default to using operator== to check for equality,
* so we define it here as an alias of equals().
*/
friend bool operator==(const Derived& self, const Derived& other) {
return self.equals(other);
}

friend bool operator!=(const Derived& self, const Derived& other) {
return !self.equals(other);
}

Derived get_union_with(const Derived& other) const {
Derived result(static_cast<const Derived&>(*this));
result.union_with(other);
return result;
}

Derived get_intersection_with(const Derived& other) const {
Derived result(static_cast<const Derived&>(*this));
result.intersection_with(other);
return result;
}

Derived get_difference_with(const Derived& other) const {
Derived result(static_cast<const Derived&>(*this));
result.difference_with(other);
return result;
}
};

} // namespace sparta
25 changes: 16 additions & 9 deletions include/sparta/FlatSet.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

#include <boost/container/flat_set.hpp>

#include <sparta/AbstractSet.h>
#include <sparta/PatriciaTreeUtil.h>

namespace sparta {
Expand All @@ -30,7 +31,9 @@ template <typename Element,
typename Equal = std::equal_to<Element>,
typename AllocatorOrContainer =
boost::container::new_allocator<Element>>
class FlatSet final {
class FlatSet final
: public AbstractSet<
FlatSet<Element, Compare, Equal, AllocatorOrContainer>> {
private:
using BoostFlatSet =
boost::container::flat_set<Element, Compare, AllocatorOrContainer>;
Expand Down Expand Up @@ -95,14 +98,6 @@ class FlatSet final {
other.m_set.end(), Equal());
}

friend bool operator==(const FlatSet& s1, const FlatSet& s2) {
return s1.equals(s2);
}

friend bool operator!=(const FlatSet& s1, const FlatSet& s2) {
return !s1.equals(s2);
}

FlatSet& insert(Element key) {
m_set.insert(key);
return *this;
Expand All @@ -120,6 +115,18 @@ class FlatSet final {
}
}

/*
* If the set is a singleton, returns a pointer to the element.
* Otherwise, returns nullptr.
*/
const Element* singleton() const {
if (m_set.size() == 1) {
return &*m_set.begin();
} else {
return nullptr;
}
}

template <typename Predicate> // bool(const Element&)
FlatSet& filter(Predicate&& predicate) {
auto container = m_set.extract_sequence();
Expand Down
29 changes: 2 additions & 27 deletions include/sparta/PatriciaTreeSet.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include <type_traits>
#include <utility>

#include <sparta/AbstractSet.h>
#include <sparta/Exceptions.h>
#include <sparta/PatriciaTreeCore.h>
#include <sparta/PatriciaTreeUtil.h>
Expand Down Expand Up @@ -48,7 +49,7 @@ namespace sparta {
* or pointers to objects.
*/
template <typename Element>
class PatriciaTreeSet final {
class PatriciaTreeSet final : public AbstractSet<PatriciaTreeSet<Element>> {
using Empty = pt_core::EmptyValue;
using Core = pt_core::PatriciaTreeCore<Element, Empty>;
using Codec = typename Core::Codec;
Expand Down Expand Up @@ -102,14 +103,6 @@ class PatriciaTreeSet final {
return m_core.equals(other.m_core);
}

friend bool operator==(const PatriciaTreeSet& s1, const PatriciaTreeSet& s2) {
return s1.equals(s2);
}

friend bool operator!=(const PatriciaTreeSet& s1, const PatriciaTreeSet& s2) {
return !s1.equals(s2);
}

/*
* This faster equality predicate can be used to check whether a sequence of
* in-place modifications leaves a Patricia-tree set unchanged. For comparing
Expand Down Expand Up @@ -186,24 +179,6 @@ class PatriciaTreeSet final {
return *this;
}

PatriciaTreeSet get_union_with(const PatriciaTreeSet& other) const {
auto result = *this;
result.union_with(other);
return result;
}

PatriciaTreeSet get_intersection_with(const PatriciaTreeSet& other) const {
auto result = *this;
result.intersection_with(other);
return result;
}

PatriciaTreeSet get_difference_with(const PatriciaTreeSet& other) const {
auto result = *this;
result.difference_with(other);
return result;
}

/*
* The hash codes are computed incrementally when the Patricia trees are
* constructed. Hence, this method has complexity O(1).
Expand Down

0 comments on commit 028d60b

Please sign in to comment.