From e823ea6c75c3f031402c63d6fc2cff3af42274a0 Mon Sep 17 00:00:00 2001 From: Drew Hubley Date: Tue, 14 Nov 2023 21:03:23 -0400 Subject: [PATCH 1/6] Added dimension deduction for -1 shape parameters --- .github/workflows/ci.yml | 2 +- CMakeLists.txt | 2 +- benchmark/CMakeLists.txt | 2 +- include/xtensor/xstrided_view.hpp | 78 +++++++++++++++++++++++++++++-- test/CMakeLists.txt | 2 +- test/test_xstrided_view.cpp | 27 +++++++---- 6 files changed, 94 insertions(+), 19 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index c87df8d70..ae11f8c5c 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -170,7 +170,7 @@ jobs: ninja - name: Configure using CMake - run: cmake -Bbuild -DCMAKE_INSTALL_PREFIX=$CONDA_PREFIX -DBUILD_TESTS=ON -G Ninja + run: cmake -Bbuild -DCMAKE_BUILD_TYPE:STRING=Release -DCMAKE_INSTALL_PREFIX=$CONDA_PREFIX -DBUILD_TESTS=ON -G Ninja - name: Install working-directory: build diff --git a/CMakeLists.txt b/CMakeLists.txt index 693be0142..92c886dfc 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -7,7 +7,7 @@ # The full license is in the file LICENSE, distributed with this software. # ############################################################################ -cmake_minimum_required(VERSION 3.1) +cmake_minimum_required(VERSION 3.5) project(xtensor CXX) set(XTENSOR_INCLUDE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/include) diff --git a/benchmark/CMakeLists.txt b/benchmark/CMakeLists.txt index 412d03644..c5fa3c407 100644 --- a/benchmark/CMakeLists.txt +++ b/benchmark/CMakeLists.txt @@ -6,7 +6,7 @@ # The full license is in the file LICENSE, distributed with this software. # ############################################################################ -cmake_minimum_required(VERSION 3.1) +cmake_minimum_required(VERSION 3.5) if (CMAKE_CURRENT_SOURCE_DIR STREQUAL CMAKE_SOURCE_DIR) project(xtensor-benchmark) diff --git a/include/xtensor/xstrided_view.hpp b/include/xtensor/xstrided_view.hpp index 35689774c..2b003dea3 100644 --- a/include/xtensor/xstrided_view.hpp +++ b/include/xtensor/xstrided_view.hpp @@ -807,6 +807,66 @@ namespace xt ); } + namespace detail + { + // if shape is signed do the check + template + using do_shape_recalculation = std:: + enable_if_t::type>>::value, R>; + + // if shape is unsigned pass through + template + using no_shape_recalculation = std:: + enable_if_t::type>>::value, R>; + + template + inline no_shape_recalculation make_unsigned_shape(T shape) + { + return shape; + } + + template + struct rebind_shape; + + template + struct rebind_shape>>::value>> + { + using Shape = S; + }; + + template + struct rebind_shape>>::value>> + { + using Shape = rebind_container_t; + }; + + template = true> + inline auto recalculate_shape_impl(S& shape, size_t size) + { + using value_type = get_value_type_t>; + auto iter = std::find(shape.begin(), shape.end(), -1); + if (iter != std::end(shape)) + { + const auto total = std::accumulate(shape.cbegin(), shape.cend(), -1, std::multiplies{}); + const auto missing_dimension = size / total; + (*iter) = static_cast(missing_dimension); + } + return shape; + } + + template = true> + inline auto recalculate_shape_impl(S& shape, size_t) + { + return shape; + } + + template + inline auto recalculate_shape(S&& shape, size_t size) + { + return recalculate_shape_impl(shape, size); + } + } + template inline auto reshape_view(E&& e, S&& shape) { @@ -815,18 +875,26 @@ namespace xt "traversal has to be row or column major" ); - using shape_type = std::decay_t; - get_strides_t strides; + using shape_type = std::decay_t; + using unsigned_shape_type = typename detail::rebind_shape::Shape; + get_strides_t strides; + detail::recalculate_shape(shape, e.size()); xt::resize_container(strides, shape.size()); compute_strides(shape, L, strides); constexpr auto computed_layout = std::decay_t::static_layout == L ? L : layout_type::dynamic; using view_type = xstrided_view< xclosure_t, - shape_type, + unsigned_shape_type, computed_layout, detail::flat_adaptor_getter, L>>; - return view_type(std::forward(e), std::forward(shape), std::move(strides), 0, e.layout()); + return view_type( + std::forward(e), + xtl::forward_sequence(shape), + std::move(strides), + 0, + e.layout() + ); } /** @@ -858,7 +926,7 @@ namespace xt template inline auto reshape_view(E&& e, const I (&shape)[N]) { - using shape_type = std::array; + using shape_type = std::array; return reshape_view(std::forward(e), xtl::forward_sequence(shape)); } } diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index e6331f759..46127597e 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -7,7 +7,7 @@ # The full license is in the file LICENSE, distributed with this software. # ############################################################################ -cmake_minimum_required(VERSION 3.1) +cmake_minimum_required(VERSION 3.5) if (CMAKE_CURRENT_SOURCE_DIR STREQUAL CMAKE_SOURCE_DIR) project(xtensor-test CXX) diff --git a/test/test_xstrided_view.cpp b/test/test_xstrided_view.cpp index d327ba703..157a29a89 100644 --- a/test/test_xstrided_view.cpp +++ b/test/test_xstrided_view.cpp @@ -696,17 +696,20 @@ namespace xt EXPECT_EQ(av, e); EXPECT_EQ(av, a); - bool truthy; - truthy = std::is_same< - typename decltype(xv)::temporary_type, - xtensor_fixed, XTENSOR_DEFAULT_LAYOUT>>(); - EXPECT_TRUE(truthy); - - truthy = std::is_same>( + static_assert( + std::is_same< + typename decltype(xv)::temporary_type, + xtensor_fixed, XTENSOR_DEFAULT_LAYOUT>>::value, + "Container types do not match" + ); + static_assert( + std::is_same>::value, + "Container types do not match" + ); + static_assert( + std::is_same::value, + "Shape types do not match" ); - EXPECT_TRUE(truthy); - truthy = std::is_same::value; - EXPECT_TRUE(truthy); xarray xa = {{1, 2, 3}, {4, 5, 6}}; std::vector new_shape = {3, 2}; @@ -714,6 +717,10 @@ namespace xt xarray xres = {{1, 2}, {3, 4}, {5, 6}}; EXPECT_EQ(xrv, xres); + + auto nv = xt::reshape_view(a, {-1, 3}); + std::vector expected_shape({3, 3}); + EXPECT_TRUE(std::equal(nv.shape().begin(), nv.shape().end(), expected_shape.begin())); } TEST(xstrided_view, reshape_view_assign) From a35b0a2051b6232d427205b537829ff4f17e7280 Mon Sep 17 00:00:00 2001 From: Drew Hubley <96780897+spectre-ns@users.noreply.github.com> Date: Wed, 15 Nov 2023 08:29:45 -0400 Subject: [PATCH 2/6] Update include/xtensor/xstrided_view.hpp Co-authored-by: Tom de Geus --- include/xtensor/xstrided_view.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/xtensor/xstrided_view.hpp b/include/xtensor/xstrided_view.hpp index 2b003dea3..ab122a50f 100644 --- a/include/xtensor/xstrided_view.hpp +++ b/include/xtensor/xstrided_view.hpp @@ -831,7 +831,7 @@ namespace xt template struct rebind_shape>>::value>> { - using Shape = S; + using type = S; }; template From 4e67be40f9a43947a2b292fad60579106ca3d803 Mon Sep 17 00:00:00 2001 From: Drew Hubley Date: Wed, 15 Nov 2023 09:49:28 -0400 Subject: [PATCH 3/6] Addressed comments --- include/xtensor/xstrided_view.hpp | 28 ++++++++-------------------- 1 file changed, 8 insertions(+), 20 deletions(-) diff --git a/include/xtensor/xstrided_view.hpp b/include/xtensor/xstrided_view.hpp index ab122a50f..f688a4138 100644 --- a/include/xtensor/xstrided_view.hpp +++ b/include/xtensor/xstrided_view.hpp @@ -809,22 +809,6 @@ namespace xt namespace detail { - // if shape is signed do the check - template - using do_shape_recalculation = std:: - enable_if_t::type>>::value, R>; - - // if shape is unsigned pass through - template - using no_shape_recalculation = std:: - enable_if_t::type>>::value, R>; - - template - inline no_shape_recalculation make_unsigned_shape(T shape) - { - return shape; - } - template struct rebind_shape; @@ -837,13 +821,16 @@ namespace xt template struct rebind_shape>>::value>> { - using Shape = rebind_container_t; + using type = rebind_container_t; }; - template = true> + template ::type>>::value, bool> = true> inline auto recalculate_shape_impl(S& shape, size_t size) { using value_type = get_value_type_t>; + const auto num_auto_dims = std::count(shape.cbegin(), shape.cend(), -1); + XTENSOR_ASSERT(num_auto_dims <= 1); auto iter = std::find(shape.begin(), shape.end(), -1); if (iter != std::end(shape)) { @@ -854,7 +841,8 @@ namespace xt return shape; } - template = true> + template ::type>>::value, bool> = true> inline auto recalculate_shape_impl(S& shape, size_t) { return shape; @@ -876,7 +864,7 @@ namespace xt ); using shape_type = std::decay_t; - using unsigned_shape_type = typename detail::rebind_shape::Shape; + using unsigned_shape_type = typename detail::rebind_shape::type; get_strides_t strides; detail::recalculate_shape(shape, e.size()); From d0506d86a597e3d1ef8895b499bb60cfcdc3888d Mon Sep 17 00:00:00 2001 From: Drew Hubley Date: Wed, 15 Nov 2023 09:50:45 -0400 Subject: [PATCH 4/6] Ran pre-commit --- include/xtensor/xstrided_view.hpp | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/include/xtensor/xstrided_view.hpp b/include/xtensor/xstrided_view.hpp index f688a4138..ef276e23f 100644 --- a/include/xtensor/xstrided_view.hpp +++ b/include/xtensor/xstrided_view.hpp @@ -824,8 +824,9 @@ namespace xt using type = rebind_container_t; }; - template ::type>>::value, bool> = true> + template < + class S, + std::enable_if_t::type>>::value, bool> = true> inline auto recalculate_shape_impl(S& shape, size_t size) { using value_type = get_value_type_t>; @@ -841,8 +842,9 @@ namespace xt return shape; } - template ::type>>::value, bool> = true> + template < + class S, + std::enable_if_t::type>>::value, bool> = true> inline auto recalculate_shape_impl(S& shape, size_t) { return shape; From 444dbe5c31efe1a133491e2bf77ba587cc7dd81e Mon Sep 17 00:00:00 2001 From: Drew Hubley Date: Thu, 16 Nov 2023 08:17:51 -0400 Subject: [PATCH 5/6] Addressed comments --- include/xtensor/xstrided_view.hpp | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/include/xtensor/xstrided_view.hpp b/include/xtensor/xstrided_view.hpp index ef276e23f..c4a187a80 100644 --- a/include/xtensor/xstrided_view.hpp +++ b/include/xtensor/xstrided_view.hpp @@ -827,11 +827,10 @@ namespace xt template < class S, std::enable_if_t::type>>::value, bool> = true> - inline auto recalculate_shape_impl(S& shape, size_t size) + inline void recalculate_shape_impl(S& shape, size_t size) { using value_type = get_value_type_t>; - const auto num_auto_dims = std::count(shape.cbegin(), shape.cend(), -1); - XTENSOR_ASSERT(num_auto_dims <= 1); + XTENSOR_ASSERT(std::count(shape.cbegin(), shape.cend(), -1) <= 1); auto iter = std::find(shape.begin(), shape.end(), -1); if (iter != std::end(shape)) { @@ -839,15 +838,13 @@ namespace xt const auto missing_dimension = size / total; (*iter) = static_cast(missing_dimension); } - return shape; } template < class S, std::enable_if_t::type>>::value, bool> = true> - inline auto recalculate_shape_impl(S& shape, size_t) + inline void recalculate_shape_impl(S&, size_t) { - return shape; } template From aaa819e418f30e27d5eb3ee2084fc41c58c7f28b Mon Sep 17 00:00:00 2001 From: Drew Hubley Date: Thu, 16 Nov 2023 14:05:42 -0400 Subject: [PATCH 6/6] Specialized for fixed_shape --- include/xtensor/xstrided_view.hpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/include/xtensor/xstrided_view.hpp b/include/xtensor/xstrided_view.hpp index c4a187a80..f16ace56d 100644 --- a/include/xtensor/xstrided_view.hpp +++ b/include/xtensor/xstrided_view.hpp @@ -809,17 +809,17 @@ namespace xt namespace detail { - template + template struct rebind_shape; - template - struct rebind_shape>>::value>> + template + struct rebind_shape> { - using type = S; + using type = xt::fixed_shape; }; template - struct rebind_shape>>::value>> + struct rebind_shape { using type = rebind_container_t; };