Skip to content

Commit

Permalink
Added dimension deduction for -1 shape parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
spectre-ns committed Nov 15, 2023
1 parent 4beafb6 commit e823ea6
Show file tree
Hide file tree
Showing 6 changed files with 94 additions and 19 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ jobs:
ninja
- name: Configure using CMake
run: cmake -Bbuild -DCMAKE_INSTALL_PREFIX=$CONDA_PREFIX -DBUILD_TESTS=ON -G Ninja
run: cmake -Bbuild -DCMAKE_BUILD_TYPE:STRING=Release -DCMAKE_INSTALL_PREFIX=$CONDA_PREFIX -DBUILD_TESTS=ON -G Ninja

- name: Install
working-directory: build
Expand Down
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
# The full license is in the file LICENSE, distributed with this software. #
############################################################################

cmake_minimum_required(VERSION 3.1)
cmake_minimum_required(VERSION 3.5)
project(xtensor CXX)

set(XTENSOR_INCLUDE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/include)
Expand Down
2 changes: 1 addition & 1 deletion benchmark/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
# The full license is in the file LICENSE, distributed with this software. #
############################################################################

cmake_minimum_required(VERSION 3.1)
cmake_minimum_required(VERSION 3.5)

if (CMAKE_CURRENT_SOURCE_DIR STREQUAL CMAKE_SOURCE_DIR)
project(xtensor-benchmark)
Expand Down
78 changes: 73 additions & 5 deletions include/xtensor/xstrided_view.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -807,6 +807,66 @@ namespace xt
);
}

namespace detail
{
// if shape is signed do the check
template <class S, class R>
using do_shape_recalculation = std::
enable_if_t<std::is_signed<get_value_type_t<typename std::decay<S>::type>>::value, R>;

// if shape is unsigned pass through
template <class S, class R>
using no_shape_recalculation = std::
enable_if_t<!std::is_signed<get_value_type_t<typename std::decay<S>::type>>::value, R>;

template <typename T>
inline no_shape_recalculation<T, T> make_unsigned_shape(T shape)
{
return shape;
}

template <typename S, typename Enable = void>
struct rebind_shape;

template <class S>
struct rebind_shape<S, std::enable_if_t<!std::is_signed<get_value_type_t<typename std::decay_t<S>>>::value>>
{
using Shape = S;
};

template <class S>
struct rebind_shape<S, std::enable_if_t<std::is_signed<get_value_type_t<typename std::decay_t<S>>>::value>>
{
using Shape = rebind_container_t<size_t, S>;
};

template <class S, do_shape_recalculation<S, bool> = true>
inline auto recalculate_shape_impl(S& shape, size_t size)
{
using value_type = get_value_type_t<typename std::decay_t<S>>;
auto iter = std::find(shape.begin(), shape.end(), -1);
if (iter != std::end(shape))
{
const auto total = std::accumulate(shape.cbegin(), shape.cend(), -1, std::multiplies<int>{});
const auto missing_dimension = size / total;
(*iter) = static_cast<value_type>(missing_dimension);
}
return shape;
}

template <class S, no_shape_recalculation<S, bool> = true>
inline auto recalculate_shape_impl(S& shape, size_t)
{
return shape;
}

template <class S>
inline auto recalculate_shape(S&& shape, size_t size)
{
return recalculate_shape_impl(shape, size);
}
}

template <layout_type L = XTENSOR_DEFAULT_TRAVERSAL, class E, class S>
inline auto reshape_view(E&& e, S&& shape)
{
Expand All @@ -815,18 +875,26 @@ namespace xt
"traversal has to be row or column major"
);

using shape_type = std::decay_t<S>;
get_strides_t<shape_type> strides;
using shape_type = std::decay_t<decltype(shape)>;
using unsigned_shape_type = typename detail::rebind_shape<shape_type>::Shape;
get_strides_t<unsigned_shape_type> strides;

detail::recalculate_shape(shape, e.size());
xt::resize_container(strides, shape.size());
compute_strides(shape, L, strides);
constexpr auto computed_layout = std::decay_t<E>::static_layout == L ? L : layout_type::dynamic;
using view_type = xstrided_view<
xclosure_t<E>,
shape_type,
unsigned_shape_type,
computed_layout,
detail::flat_adaptor_getter<xclosure_t<E>, L>>;
return view_type(std::forward<E>(e), std::forward<S>(shape), std::move(strides), 0, e.layout());
return view_type(
std::forward<E>(e),
xtl::forward_sequence<unsigned_shape_type, S>(shape),
std::move(strides),
0,
e.layout()
);
}

/**
Expand Down Expand Up @@ -858,7 +926,7 @@ namespace xt
template <layout_type L = XTENSOR_DEFAULT_TRAVERSAL, class E, class I, std::size_t N>
inline auto reshape_view(E&& e, const I (&shape)[N])
{
using shape_type = std::array<std::size_t, N>;
using shape_type = std::array<I, N>;
return reshape_view<L>(std::forward<E>(e), xtl::forward_sequence<shape_type, decltype(shape)>(shape));
}
}
Expand Down
2 changes: 1 addition & 1 deletion test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
# The full license is in the file LICENSE, distributed with this software. #
############################################################################

cmake_minimum_required(VERSION 3.1)
cmake_minimum_required(VERSION 3.5)

if (CMAKE_CURRENT_SOURCE_DIR STREQUAL CMAKE_SOURCE_DIR)
project(xtensor-test CXX)
Expand Down
27 changes: 17 additions & 10 deletions test/test_xstrided_view.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -696,24 +696,31 @@ namespace xt
EXPECT_EQ(av, e);
EXPECT_EQ(av, a);

bool truthy;
truthy = std::is_same<
typename decltype(xv)::temporary_type,
xtensor_fixed<double, xshape<3, 3>, XTENSOR_DEFAULT_LAYOUT>>();
EXPECT_TRUE(truthy);

truthy = std::is_same<typename decltype(av)::temporary_type, xtensor<double, 2, XTENSOR_DEFAULT_LAYOUT>>(
static_assert(
std::is_same<
typename decltype(xv)::temporary_type,
xtensor_fixed<double, xshape<3, 3>, XTENSOR_DEFAULT_LAYOUT>>::value,
"Container types do not match"
);
static_assert(
std::is_same<typename decltype(av)::temporary_type, xtensor<double, 2, XTENSOR_DEFAULT_LAYOUT>>::value,
"Container types do not match"
);
static_assert(
std::is_same<typename decltype(av)::shape_type, typename decltype(e)::shape_type>::value,
"Shape types do not match"
);
EXPECT_TRUE(truthy);
truthy = std::is_same<typename decltype(av)::shape_type, typename decltype(e)::shape_type>::value;
EXPECT_TRUE(truthy);

xarray<int> xa = {{1, 2, 3}, {4, 5, 6}};
std::vector<std::size_t> new_shape = {3, 2};
auto xrv = reshape_view(xa, new_shape);

xarray<int> xres = {{1, 2}, {3, 4}, {5, 6}};
EXPECT_EQ(xrv, xres);

auto nv = xt::reshape_view<XTENSOR_DEFAULT_LAYOUT>(a, {-1, 3});
std::vector<size_t> expected_shape({3, 3});
EXPECT_TRUE(std::equal(nv.shape().begin(), nv.shape().end(), expected_shape.begin()));
}

TEST(xstrided_view, reshape_view_assign)
Expand Down

0 comments on commit e823ea6

Please sign in to comment.