diff --git a/include/xtensor/xstrided_view.hpp b/include/xtensor/xstrided_view.hpp index 35689774c..f16ace56d 100644 --- a/include/xtensor/xstrided_view.hpp +++ b/include/xtensor/xstrided_view.hpp @@ -807,6 +807,53 @@ namespace xt ); } + namespace detail + { + template + struct rebind_shape; + + template + struct rebind_shape> + { + using type = xt::fixed_shape; + }; + + template + struct rebind_shape + { + using type = rebind_container_t; + }; + + template < + class S, + std::enable_if_t::type>>::value, bool> = true> + inline void recalculate_shape_impl(S& shape, size_t size) + { + using value_type = get_value_type_t>; + XTENSOR_ASSERT(std::count(shape.cbegin(), shape.cend(), -1) <= 1); + auto iter = std::find(shape.begin(), shape.end(), -1); + if (iter != std::end(shape)) + { + const auto total = std::accumulate(shape.cbegin(), shape.cend(), -1, std::multiplies{}); + const auto missing_dimension = size / total; + (*iter) = static_cast(missing_dimension); + } + } + + template < + class S, + std::enable_if_t::type>>::value, bool> = true> + inline void recalculate_shape_impl(S&, size_t) + { + } + + template + inline auto recalculate_shape(S&& shape, size_t size) + { + return recalculate_shape_impl(shape, size); + } + } + template inline auto reshape_view(E&& e, S&& shape) { @@ -815,18 +862,26 @@ namespace xt "traversal has to be row or column major" ); - using shape_type = std::decay_t; - get_strides_t strides; + using shape_type = std::decay_t; + using unsigned_shape_type = typename detail::rebind_shape::type; + get_strides_t strides; + detail::recalculate_shape(shape, e.size()); xt::resize_container(strides, shape.size()); compute_strides(shape, L, strides); constexpr auto computed_layout = std::decay_t::static_layout == L ? L : layout_type::dynamic; using view_type = xstrided_view< xclosure_t, - shape_type, + unsigned_shape_type, computed_layout, detail::flat_adaptor_getter, L>>; - return view_type(std::forward(e), std::forward(shape), std::move(strides), 0, e.layout()); + return view_type( + std::forward(e), + xtl::forward_sequence(shape), + std::move(strides), + 0, + e.layout() + ); } /** @@ -858,7 +913,7 @@ namespace xt template inline auto reshape_view(E&& e, const I (&shape)[N]) { - using shape_type = std::array; + using shape_type = std::array; return reshape_view(std::forward(e), xtl::forward_sequence(shape)); } } diff --git a/test/test_xstrided_view.cpp b/test/test_xstrided_view.cpp index d327ba703..157a29a89 100644 --- a/test/test_xstrided_view.cpp +++ b/test/test_xstrided_view.cpp @@ -696,17 +696,20 @@ namespace xt EXPECT_EQ(av, e); EXPECT_EQ(av, a); - bool truthy; - truthy = std::is_same< - typename decltype(xv)::temporary_type, - xtensor_fixed, XTENSOR_DEFAULT_LAYOUT>>(); - EXPECT_TRUE(truthy); - - truthy = std::is_same>( + static_assert( + std::is_same< + typename decltype(xv)::temporary_type, + xtensor_fixed, XTENSOR_DEFAULT_LAYOUT>>::value, + "Container types do not match" + ); + static_assert( + std::is_same>::value, + "Container types do not match" + ); + static_assert( + std::is_same::value, + "Shape types do not match" ); - EXPECT_TRUE(truthy); - truthy = std::is_same::value; - EXPECT_TRUE(truthy); xarray xa = {{1, 2, 3}, {4, 5, 6}}; std::vector new_shape = {3, 2}; @@ -714,6 +717,10 @@ namespace xt xarray xres = {{1, 2}, {3, 4}, {5, 6}}; EXPECT_EQ(xrv, xres); + + auto nv = xt::reshape_view(a, {-1, 3}); + std::vector expected_shape({3, 3}); + EXPECT_TRUE(std::equal(nv.shape().begin(), nv.shape().end(), expected_shape.begin())); } TEST(xstrided_view, reshape_view_assign)