From 0d8a6c599e4681d78ef17c77997e440dc6f8c246 Mon Sep 17 00:00:00 2001 From: Andrew Johnson Date: Wed, 19 Jun 2024 14:25:36 +0200 Subject: [PATCH] Add overloads for tuples of bounds to lub_free and lub_constrain --- stan/math/prim/fun/lub_constrain.hpp | 27 +++++++++++++++++++ stan/math/prim/fun/lub_free.hpp | 8 ++++++ .../math/mix/fun/lub_constrain_helpers.hpp | 20 ++++++++++++++ 3 files changed, 55 insertions(+) diff --git a/stan/math/prim/fun/lub_constrain.hpp b/stan/math/prim/fun/lub_constrain.hpp index a3e91883f51..bd358911c1e 100644 --- a/stan/math/prim/fun/lub_constrain.hpp +++ b/stan/math/prim/fun/lub_constrain.hpp @@ -400,6 +400,33 @@ inline auto lub_constrain(const T& x, const L& lb, const U& ub, } } +/** + * Wrapper for tuple of bounds, simply delegates to the appropriate overload + */ +template +inline auto lub_constrain(const T& x, const std::tuple& bounds) { + return lub_constrain(x, std::get<0>(bounds), std::get<1>(bounds)); +} + +/** + * Wrapper for tuple of bounds, simply delegates to the appropriate overload + */ +template +inline auto lub_constrain(const T& x, const std::tuple& bounds, + return_type_t& lp) { + return lub_constrain(x, std::get<0>(bounds), std::get<1>(bounds), lp); +} + +/** + * Wrapper for tuple of bounds, simply delegates to the appropriate overload + */ +template +inline auto lub_constrain(const T& x, const std::tuple& bounds, + return_type_t& lp) { + return lub_constrain(x, std::get<0>(bounds), std::get<1>(bounds), + lp); +} + } // namespace math } // namespace stan diff --git a/stan/math/prim/fun/lub_free.hpp b/stan/math/prim/fun/lub_free.hpp index c4876cbcd6d..db40ac73d82 100644 --- a/stan/math/prim/fun/lub_free.hpp +++ b/stan/math/prim/fun/lub_free.hpp @@ -180,6 +180,14 @@ inline auto lub_free(const std::vector y, const std::vector& lb, } return ret; } + +/** + * Wrapper for tuple of bounds, simply delegates to the appropriate overload + */ +template +inline auto lub_free(T&& y, const std::tuple& bounds) { + return lub_free(std::forward(y), std::get<0>(bounds), std::get<1>(bounds)); +} ///@} } // namespace math diff --git a/test/unit/math/mix/fun/lub_constrain_helpers.hpp b/test/unit/math/mix/fun/lub_constrain_helpers.hpp index c12c99fcc9e..80ba021c8ec 100644 --- a/test/unit/math/mix/fun/lub_constrain_helpers.hpp +++ b/test/unit/math/mix/fun/lub_constrain_helpers.hpp @@ -22,11 +22,21 @@ void expect(const T1& x, const T2& lb, const T3& ub) { auto xx = stan::math::lub_constrain(x, lb, ub, lp); return stan::math::add(lp, stan::math::sum(xx)); }; + auto f5 = [](const auto& x, const auto& lb, const auto& ub) { + stan::return_type_t lp = 0; + return stan::math::lub_constrain(x, std::make_tuple(lb, ub), lp); + }; + auto f6 = [](const auto& x, const auto& lb, const auto& ub) { + stan::return_type_t lp = 0; + return stan::math::lub_constrain(x, std::make_tuple(lb, ub), lp); + }; stan::test::expect_ad(f1, x, lb, ub); stan::test::expect_ad(f2, x, lb, ub); stan::test::expect_ad(f3, x, lb, ub); stan::test::expect_ad(f4, x, lb, ub); + stan::test::expect_ad(f5, x, lb, ub); + stan::test::expect_ad(f6, x, lb, ub); } template void expect_vec(const T1& x, const T2& lb, const T3& ub) { @@ -52,11 +62,21 @@ void expect_vec(const T1& x, const T2& lb, const T3& ub) { } return stan::math::add(lp, xx_acc); }; + auto f5 = [](const auto& x, const auto& lb, const auto& ub) { + stan::return_type_t lp = 0; + return stan::math::lub_constrain(x, std::make_tuple(lb, ub), lp); + }; + auto f6 = [](const auto& x, const auto& lb, const auto& ub) { + stan::return_type_t lp = 0; + return stan::math::lub_constrain(x, std::make_tuple(lb, ub), lp); + }; stan::test::expect_ad(f1, x, lb, ub); stan::test::expect_ad(f2, x, lb, ub); stan::test::expect_ad(f3, x, lb, ub); stan::test::expect_ad(f4, x, lb, ub); + stan::test::expect_ad(f5, x, lb, ub); + stan::test::expect_ad(f6, x, lb, ub); } } // namespace lub_constrain_tests