Skip to content

Commit

Permalink
Merge pull request #740 from LLNL/feature/yang39/arrayview_const
Browse files Browse the repository at this point in the history
Allow modifying underlying data from non-const ArrayView
  • Loading branch information
publixsubfan committed Dec 3, 2021
2 parents b344511 + 738956f commit 9d9ff2b
Show file tree
Hide file tree
Showing 8 changed files with 85 additions and 49 deletions.
11 changes: 11 additions & 0 deletions src/axom/core/Array.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,17 @@ struct Uninitialized
template <typename T, int DIM, MemorySpace SPACE>
class Array;

namespace detail
{
// Static information to pass to ArrayBase
template <typename T, int DIM, MemorySpace SPACE>
struct ArrayTraits<Array<T, DIM, SPACE>>
{
constexpr static bool is_view = false;
};

} // namespace detail

/*!
* \class Array
*
Expand Down
40 changes: 37 additions & 3 deletions src/axom/core/ArrayBase.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,12 @@ namespace axom
template <typename T, int DIM, typename ArrayType>
class ArrayBase;

namespace detail
{
template <typename ArrayType>
struct ArrayTraits;
}

/// \name Overloaded ArrayBase Operator(s)
/// @{

Expand Down Expand Up @@ -83,11 +89,27 @@ bool operator!=(const ArrayBase<T1, DIM, LArrayType>& lhs,
* const T* data() const;
* int getAllocatorID() const;
* \endcode
*
* \pre A specialization of ArrayTraits for all ArrayTypes must also be provided
* with the boolean value IsView, which affects the const-ness of returned
* references.
*/
template <typename T, int DIM, typename ArrayType>
class ArrayBase
{
private:
constexpr static bool is_array_view = detail::ArrayTraits<ArrayType>::is_view;

public:
/* If ArrayType is an ArrayView, we use shallow-const semantics, akin to
* std::span; a const ArrayView will still allow for mutating the underlying
* pointed-to data.
*
* If ArrayType is an Array, we use deep-const semantics, akin to std::vector;
* a const Array will prevent modifications of the underlying Array data.
*/
using RealConstT = typename std::conditional<is_array_view, T, const T>::type;

/*!
* \brief Parameterized constructor that sets up the default strides
*
Expand Down Expand Up @@ -143,7 +165,7 @@ class ArrayBase
/// \overload
template <typename... Args,
typename SFINAE = typename std::enable_if<sizeof...(Args) == DIM>::type>
AXOM_HOST_DEVICE const T& operator()(Args... args) const
AXOM_HOST_DEVICE RealConstT& operator()(Args... args) const
{
const IndexType indices[] = {static_cast<IndexType>(args)...};
const IndexType idx = numerics::dot_product(indices, m_strides.begin(), DIM);
Expand All @@ -169,7 +191,7 @@ class ArrayBase
return asDerived().data()[idx];
}
/// \overload
AXOM_HOST_DEVICE const T& operator[](const IndexType idx) const
AXOM_HOST_DEVICE RealConstT& operator[](const IndexType idx) const
{
assert(inBounds(idx));
return asDerived().data()[idx];
Expand Down Expand Up @@ -283,7 +305,19 @@ class ArrayBase
template <typename T, typename ArrayType>
class ArrayBase<T, 1, ArrayType>
{
private:
constexpr static bool is_array_view = detail::ArrayTraits<ArrayType>::is_view;

public:
/* If ArrayType is an ArrayView, we use shallow-const semantics, akin to
* std::span; a const ArrayView will still allow for mutating the underlying
* pointed-to data.
*
* If ArrayType is an Array, we use deep-const semantics, akin to std::vector;
* a const Array will prevent modifications of the underlying Array data.
*/
using RealConstT = typename std::conditional<is_array_view, T, const T>::type;

ArrayBase(IndexType = 0) { }

// Empy implementation because no member data
Expand Down Expand Up @@ -321,7 +355,7 @@ class ArrayBase<T, 1, ArrayType>
return asDerived().data()[idx];
}
/// \overload
AXOM_HOST_DEVICE const T& operator[](const IndexType idx) const
AXOM_HOST_DEVICE RealConstT& operator[](const IndexType idx) const
{
assert(inBounds(idx));
return asDerived().data()[idx];
Expand Down
7 changes: 6 additions & 1 deletion src/axom/core/ArrayIteratorBase.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,14 @@ template <typename ArrayType, typename ValueType>
class ArrayIteratorBase
: public IteratorBase<ArrayIteratorBase<ArrayType, ValueType>, IndexType>
{
private:
constexpr static bool ReturnConstRef = std::is_const<ValueType>::value;
constexpr static bool FromArrayView =
!std::is_const<typename ArrayType::RealConstT>::value;

public:
using ArrayPointerType =
typename std::conditional<std::is_const<ValueType>::value,
typename std::conditional<ReturnConstRef || FromArrayView,
const ArrayType*,
ArrayType*>::type;
// FIXME: Define the iterator_traits types (or possibly in IteratorBase)
Expand Down
45 changes: 18 additions & 27 deletions src/axom/core/ArrayView.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,17 @@ namespace axom
template <typename T, int DIM, MemorySpace SPACE>
class ArrayView;

namespace detail
{
// Static information to pass to ArrayBase
template <typename T, int DIM, MemorySpace SPACE>
struct ArrayTraits<ArrayView<T, DIM, SPACE>>
{
constexpr static bool is_view = true;
};

} // namespace detail

/// \name ArrayView to wrap a pointer and provide indexing semantics
/// @{

Expand All @@ -38,8 +49,6 @@ class ArrayView : public ArrayBase<T, DIM, ArrayView<T, DIM, SPACE>>
static constexpr int dimension = DIM;
static constexpr MemorySpace space = SPACE;
using ArrayViewIterator = ArrayIteratorBase<ArrayView<T, DIM, SPACE>, T>;
using ConstArrayViewIterator =
ArrayIteratorBase<ArrayView<T, DIM, SPACE>, const T>;

/// \brief Default constructor
ArrayView() : m_allocator_id(axom::detail::getAllocatorID<SPACE>()) { }
Expand Down Expand Up @@ -83,50 +92,28 @@ class ArrayView : public ArrayBase<T, DIM, ArrayView<T, DIM, SPACE>>
/*!
* \brief Returns an ArrayViewIterator to the first element of the Array
*/
ArrayViewIterator begin()
ArrayViewIterator begin() const
{
assert(m_data != nullptr);
return ArrayViewIterator(0, this);
}

/// \overload
ConstArrayViewIterator begin() const
{
assert(m_data != nullptr);
return ConstArrayViewIterator(0, this);
}

/*!
* \brief Returns an ArrayViewIterator to the element following the last
* element of the Array.
*/
ArrayViewIterator end()
ArrayViewIterator end() const
{
assert(m_data != nullptr);
return ArrayViewIterator(size(), this);
}

/// \overload
ConstArrayViewIterator end() const
{
assert(m_data != nullptr);
return ConstArrayViewIterator(size(), this);
}

/*!
* \brief Return a pointer to the array of data.
*/
/// @{

AXOM_HOST_DEVICE inline T* data()
{
#ifdef AXOM_DEVICE_CODE
static_assert(SPACE != MemorySpace::Constant,
"Cannot modify Constant memory from device code");
#endif
return m_data;
}
AXOM_HOST_DEVICE inline const T* data() const { return m_data; }
AXOM_HOST_DEVICE inline T* data() const { return m_data; }

/// @}

Expand Down Expand Up @@ -162,6 +149,10 @@ ArrayView<T, DIM, SPACE>::ArrayView(T* data, Args... args)
{
static_assert(sizeof...(Args) == DIM,
"Array size must match number of dimensions");
#ifdef AXOM_DEVICE_CODE
static_assert((SPACE != MemorySpace::Constant) || std::is_const<T>::value,
"T must be const if memory space is Constant memory");
#endif
// Intel hits internal compiler error when casting as part of function call
IndexType tmp_args[] = {args...};
m_num_elements = detail::packProduct(tmp_args);
Expand Down
2 changes: 0 additions & 2 deletions src/axom/core/docs/sphinx/core_containers.rst
Original file line number Diff line number Diff line change
Expand Up @@ -241,8 +241,6 @@ via a lambda:
:end-before: _array_w_raja_end
:language: C++

.. note:: We need to mark the lambda as `mutable` if we want to modify the array (array view) data.

##########
StackArray
##########
Expand Down
11 changes: 4 additions & 7 deletions src/axom/core/examples/core_containers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -287,13 +287,10 @@ void demoArrayDevice()
axom::Array<int, 1, axom::MemorySpace::Device> C_device_raja(N);
DeviceIntArrayView C_view = C_device_raja;

// Declare the lambda mutable so our copy of C_view (captured by value) is mutable
axom::for_all<axom::CUDA_EXEC<1>>(
0,
N,
[=] AXOM_HOST_DEVICE(axom::IndexType i) mutable {
C_view[i] = A_view[i] + B_view[i] + 1;
});
// Write to the underlying array through C_view, which is captured by value
axom::for_all<axom::CUDA_EXEC<1>>(0, N, [=] AXOM_HOST_DEVICE(axom::IndexType i) {
C_view[i] = A_view[i] + B_view[i] + 1;
});

// Finally, copy things over to host memory so we can display the data
axom::Array<int, 1, axom::MemorySpace::Host> C_host_raja = C_view;
Expand Down
12 changes: 6 additions & 6 deletions src/axom/core/tests/core_array_for_all.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,11 +68,11 @@ AXOM_TYPED_TEST(core_array_for_all, explicit_ArrayView)
constexpr int N = 374;
KernelArray arr(N);

// Modify array using mutable lambda and ArrayView
// Modify array using lambda and ArrayView
KernelArrayView arr_view(arr);
axom::for_all<ExecSpace>(
N,
AXOM_LAMBDA(axom::IndexType idx) mutable { arr_view[idx] = N - idx; });
AXOM_LAMBDA(axom::IndexType idx) { arr_view[idx] = N - idx; });

// handles synchronization, if necessary
if(axom::execution_space<ExecSpace>::async())
Expand All @@ -99,11 +99,11 @@ AXOM_TYPED_TEST(core_array_for_all, auto_ArrayView)
constexpr int N = 374;
KernelArray arr(N);

// Modify array using mutable lambda and ArrayView
// Modify array using lambda and ArrayView
auto arr_view = arr.view();
axom::for_all<ExecSpace>(
N,
AXOM_LAMBDA(axom::IndexType idx) mutable { arr_view[idx] = N - idx; });
AXOM_LAMBDA(axom::IndexType idx) { arr_view[idx] = N - idx; });

// handles synchronization, if necessary
if(axom::execution_space<ExecSpace>::async())
Expand Down Expand Up @@ -203,11 +203,11 @@ AXOM_TYPED_TEST(core_array_for_all, dynamic_array)
constexpr axom::IndexType N = 374;
DynamicArray arr(N, N, kernelAllocID);

// Modify array using mutable lambda and ArrayView
// Modify array using lambda and ArrayView
auto arr_view = arr.view();
axom::for_all<ExecSpace>(
N,
AXOM_LAMBDA(axom::IndexType idx) mutable { arr_view[idx] = N - idx; });
AXOM_LAMBDA(axom::IndexType idx) { arr_view[idx] = N - idx; });

// handles synchronization, if necessary
if(axom::execution_space<ExecSpace>::async())
Expand Down
6 changes: 3 additions & 3 deletions src/axom/quest/detail/PointFinder.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ class PointFinder
// Step 1: count number of candidate intersections for each point
for_all<ExecSpace>(
npts,
AXOM_LAMBDA(IndexType i) mutable {
AXOM_LAMBDA(IndexType i) {
countsPtr[i] = gridQuery.countCandidates(pts[i]);
totalCountReduce += countsPtr[i];
});
Expand All @@ -191,7 +191,7 @@ class PointFinder
// Step 4: fill candidate array for each query box
for_all<ExecSpace>(
npts,
AXOM_LAMBDA(IndexType i) mutable {
AXOM_LAMBDA(IndexType i) {
int startIdx = offsetsPtr[i];
int currCount = 0;
auto onCandidate = [&](int candidateIdx) -> bool {
Expand Down Expand Up @@ -237,7 +237,7 @@ class PointFinder
// don't build MFEM in a thread-safe manner.
for_all<SEQ_EXEC>(
npts,
AXOM_HOST_LAMBDA(IndexType i) mutable {
AXOM_HOST_LAMBDA(IndexType i) {
outCellIdsPtr[i] = PointInCellTraits<mesh_tag>::NO_CELL;
SpacePoint pt = pts[i];
SpacePoint isopar;
Expand Down

0 comments on commit 9d9ff2b

Please sign in to comment.