Skip to content

Commit

Permalink
Manually unroll multi-vector casting conversions
Browse files Browse the repository at this point in the history
  • Loading branch information
p12tic committed Jun 11, 2024
1 parent 44c41af commit 4970d73
Showing 1 changed file with 34 additions and 28 deletions.
62 changes: 34 additions & 28 deletions simdpp/detail/cast_bitwise.h
Original file line number Diff line number Diff line change
Expand Up @@ -309,19 +309,17 @@ template<> struct native_cast_combine<64, __m256i, __m512d> {
};
#endif

template<unsigned CastType>
struct cast_bitwise_vector_impl;

template<class T>
struct is_vararray : std::false_type {};

template<class T, unsigned N>
struct is_vararray<vararray<T, N>> : std::true_type {};

template<>
struct cast_bitwise_vector_impl<VECTOR_CAST_TYPE_1_TO_1> {
template<unsigned I, unsigned MaxI>
struct cast_bitwise_vector_impl {
template<class T, class R> SIMDPP_INL static
void cast(const T& t, R& r)
void cast(const T& t, R& r, std::integral_constant<unsigned, VECTOR_CAST_TYPE_1_TO_1> c)
{
using NativeT = typename T::base_vector_type::native_type;
using NativeR = typename R::base_vector_type::native_type;
Expand All @@ -330,57 +328,65 @@ struct cast_bitwise_vector_impl<VECTOR_CAST_TYPE_1_TO_1> {
using CastImpl = native_cast<sizeof(NativeT), NativeT,
NativeR, is_arg_vararray>;

for (unsigned i = 0; i < T::vec_length; ++i) {
r.vec(i) = CastImpl::cast(t.vec(i).native());
}
r.template vec<I>() = CastImpl::cast(t.template vec<I>().native());
cast_bitwise_vector_impl<I + 1, MaxI>::cast(t, r, c);
}
};

template<>
struct cast_bitwise_vector_impl<VECTOR_CAST_TYPE_SPLIT2> {
template<class T, class R> SIMDPP_INL static
void cast(const T& t, R& r)
void cast(const T& t, R& r, std::integral_constant<unsigned, VECTOR_CAST_TYPE_SPLIT2> c)
{
using NativeT = typename T::base_vector_type::native_type;
using NativeR = typename R::base_vector_type::native_type;
using CastImpl = native_cast_split<sizeof(NativeT), NativeT, NativeR>;

for (unsigned i = 0; i < T::vec_length; ++i) {
NativeR r0, r1;
CastImpl::cast(t.vec(i).native(), r0, r1);
r.vec(i*2) = r0;
r.vec(i*2+1) = r1;
}
NativeR r0, r1;
CastImpl::cast(t.template vec<I>().native(), r0, r1);
r.template vec<I*2>() = r0;
r.template vec<I*2+1>() = r1;

cast_bitwise_vector_impl<I + 1, MaxI>::cast(t, r, c);
}
};

template<>
struct cast_bitwise_vector_impl<VECTOR_CAST_TYPE_COMBINE2> {
template<class T, class R> SIMDPP_INL static
void cast(const T& t, R& r)
void cast(const T& t, R& r, std::integral_constant<unsigned, VECTOR_CAST_TYPE_COMBINE2> c)
{
using NativeT = typename T::base_vector_type::native_type;
using NativeR = typename R::base_vector_type::native_type;
using CastImpl = native_cast_combine<sizeof(NativeR), NativeT, NativeR>;

for (unsigned i = 0; i < R::vec_length; ++i) {
r.vec(i) = CastImpl::cast(t.vec(i*2).native(),
t.vec(i*2+1).native());
}
r.template vec<I>() = CastImpl::cast(t.template vec<I*2>().native(),
t.template vec<I*2+1>().native());
cast_bitwise_vector_impl<I + 1, MaxI>::cast(t, r, c);
}
};

template<unsigned I>
struct cast_bitwise_vector_impl<I, I> {
template<class T, class R> SIMDPP_INL static
void cast(const T&, R&, std::integral_constant<unsigned, VECTOR_CAST_TYPE_1_TO_1>) {}
template<class T, class R> SIMDPP_INL static
void cast(const T&, R&, std::integral_constant<unsigned, VECTOR_CAST_TYPE_SPLIT2>) {}
template<class T, class R> SIMDPP_INL static
void cast(const T&, R&, std::integral_constant<unsigned, VECTOR_CAST_TYPE_COMBINE2>) {}
};


template<class T, class R> SIMDPP_INL
void cast_bitwise_vector(const T& t, R& r)
{
static_assert(sizeof(R) == sizeof(T), "Size mismatch");
const unsigned vector_cast_type =
constexpr unsigned vector_cast_type =
T::vec_length == R::vec_length ? VECTOR_CAST_TYPE_1_TO_1 :
T::vec_length == R::vec_length*2 ? VECTOR_CAST_TYPE_COMBINE2 :
T::vec_length*2 == R::vec_length ? VECTOR_CAST_TYPE_SPLIT2 :
VECTOR_CAST_TYPE_INVALID;

cast_bitwise_vector_impl<vector_cast_type>::cast(t, r);
constexpr unsigned iteration_count =
vector_cast_type == VECTOR_CAST_TYPE_COMBINE2 ? R::vec_length : T::vec_length;

using selector = std::integral_constant<unsigned, vector_cast_type>;

cast_bitwise_vector_impl<0, iteration_count>::cast(t, r, selector());
}

#if (__GNUC__ >= 6) && !defined(__INTEL_COMPILER) && !defined(__clang__)
Expand Down

0 comments on commit 4970d73

Please sign in to comment.