Skip to content

Commit

Permalink
Adding tests (Issue #1805)
Browse files Browse the repository at this point in the history
  • Loading branch information
bbbales2 committed Mar 3, 2021
1 parent efc64cb commit ee7568a
Show file tree
Hide file tree
Showing 7 changed files with 223 additions and 24 deletions.
18 changes: 10 additions & 8 deletions stan/math/prim/fun/offset_multiplier_free.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ namespace stan {
namespace math {

/**
* Return the unconstrained scalar that transforms to the
* specified offset and multiplier constrained scalar given the specified
* Return the unconstrained variable that transforms to the
* specified offset and multiplier constrained variable given the specified
* offset and multiplier.
*
* <p>The transform in <code>locmultiplier_constrain(T, double, double)</code>,
Expand All @@ -26,24 +26,26 @@ namespace math {
* <p>If the offset is zero and multiplier is one,
* this function reduces to <code>identity_free(y)</code>.
*
* @tparam T type of scalar
* @tparam T type of constrained variable
* @tparam L type of offset
* @tparam S type of multiplier
* @param y constrained value
* @param[in] mu offset of constrained output
* @param[in] sigma multiplier of constrained output
* @return the free scalar that transforms to the input scalar
* given the offset and multiplier
* @return the unconstrained variable that transforms to the given constrained
* variable given the offset and multiplier
* @throw std::domain_error if sigma <= 0
* @throw std::domain_error if mu is not finite
* @throw std::invalid_argument if non-scalar arguments don't match in size
*/
template <typename T, typename L, typename S>
inline auto offset_multiplier_free(const T& y, const L& mu, const S& sigma) {
const char* function = "offset_multiplier_free";
auto&& mu_ref = to_ref(mu);
auto&& sigma_ref = to_ref(sigma);
check_finite("offset_multiplier_free", "offset", value_of(mu_ref));
check_positive_finite("offset_multiplier_free", "multiplier",
value_of(sigma_ref));
check_consistent_sizes(function, "offset", mu, "multiplier", sigma, "parameter", y);
check_finite(function, "offset", value_of(mu_ref));
check_positive_finite(function, "multiplier", value_of(sigma_ref));
return divide(subtract(y, mu_ref), sigma_ref);
}

Expand Down
File renamed without changes.
36 changes: 36 additions & 0 deletions test/unit/math/mix/fun/fma_2_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
#include <test/unit/math/test_ad.hpp>
#include <limits>

TEST(mathMixScalFun, fma_vector) {
auto f = [](const auto& x1, const auto& x2, const auto& x3) {
return stan::math::fma(x1, x2, x3);
};

double xd = 1.0;
Eigen::VectorXd xv(2);
xv << 1.0, 2.0;

double yd = 2.0;
Eigen::VectorXd yv(2);
yv << 2.0, -3.0;

double zd = 3.0;
Eigen::VectorXd zv(2);
zv << -1.0, 2.0;

stan::test::expect_ad(f, xd, yd, zv);
stan::test::expect_ad(f, xd, yv, zd);
stan::test::expect_ad(f, xd, yv, zv);
stan::test::expect_ad(f, xv, yd, zd);
stan::test::expect_ad(f, xv, yd, zv);
stan::test::expect_ad(f, xv, yv, zd);
stan::test::expect_ad(f, xv, yv, zv);

stan::test::expect_ad_matvar(f, xd, yd, zv);
stan::test::expect_ad_matvar(f, xd, yv, zd);
stan::test::expect_ad_matvar(f, xd, yv, zv);
stan::test::expect_ad_matvar(f, xv, yd, zd);
stan::test::expect_ad_matvar(f, xv, yd, zv);
stan::test::expect_ad_matvar(f, xv, yv, zd);
stan::test::expect_ad_matvar(f, xv, yv, zv);
}
36 changes: 36 additions & 0 deletions test/unit/math/mix/fun/fma_3_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
#include <test/unit/math/test_ad.hpp>
#include <limits>

TEST(mathMixScalFun, fma_vector) {
auto f = [](const auto& x1, const auto& x2, const auto& x3) {
return stan::math::fma(x1, x2, x3);
};

double xd = 1.0;
Eigen::RowVectorXd xr(2);
xr << 1.0, 2.0;

double yd = 2.0;
Eigen::RowVectorXd yr(2);
yr << 2.0, -3.0;

double zd = 3.0;
Eigen::RowVectorXd zr(2);
zr << -1.0, 2.0;

stan::test::expect_ad(f, xd, yd, zr);
stan::test::expect_ad(f, xd, yr, zd);
stan::test::expect_ad(f, xd, yr, zr);
stan::test::expect_ad(f, xr, yd, zd);
stan::test::expect_ad(f, xr, yd, zr);
stan::test::expect_ad(f, xr, yr, zd);
stan::test::expect_ad(f, xr, yr, zr);

stan::test::expect_ad_matvar(f, xd, yd, zr);
stan::test::expect_ad_matvar(f, xd, yr, zd);
stan::test::expect_ad_matvar(f, xd, yr, zr);
stan::test::expect_ad_matvar(f, xr, yd, zd);
stan::test::expect_ad_matvar(f, xr, yd, zr);
stan::test::expect_ad_matvar(f, xr, yr, zd);
stan::test::expect_ad_matvar(f, xr, yr, zr);
}
39 changes: 39 additions & 0 deletions test/unit/math/mix/fun/fma_4_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
#include <test/unit/math/test_ad.hpp>
#include <limits>

TEST(mathMixScalFun, fma_vector) {
auto f = [](const auto& x1, const auto& x2, const auto& x3) {
return stan::math::fma(x1, x2, x3);
};

double xd = 1.0;
Eigen::MatrixXd xm(2, 2);
xm << 1.0, 2.0,
-1.0, 1.1;

double yd = 2.0;
Eigen::MatrixXd ym(2, 2);
xm << 1.0, 2.0,
-1.0, 1.1;

double zd = 3.0;
Eigen::MatrixXd zm(2, 2);
xm << 1.0, 2.0,
-1.0, 1.1;

stan::test::expect_ad(f, xd, yd, zm);
stan::test::expect_ad(f, xd, ym, zd);
stan::test::expect_ad(f, xd, ym, zm);
stan::test::expect_ad(f, xm, yd, zd);
stan::test::expect_ad(f, xm, yd, zm);
stan::test::expect_ad(f, xm, ym, zd);
stan::test::expect_ad(f, xm, ym, zm);

stan::test::expect_ad_matvar(f, xd, yd, zm);
stan::test::expect_ad_matvar(f, xd, ym, zd);
stan::test::expect_ad_matvar(f, xd, ym, zm);
stan::test::expect_ad_matvar(f, xm, yd, zd);
stan::test::expect_ad_matvar(f, xm, yd, zm);
stan::test::expect_ad_matvar(f, xm, ym, zd);
stan::test::expect_ad_matvar(f, xm, ym, zm);
}
58 changes: 58 additions & 0 deletions test/unit/math/prim/fun/fma_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#include <gtest/gtest.h>
#include <cmath>
#include <limits>
#include <test/unit/util.hpp>

// this is just testing the nan behavior of the built-in fma
// there is no longer a stan::math::fma, just the agrad versions
Expand Down Expand Up @@ -33,3 +34,60 @@ TEST(MathFunctions, fma_nan) {

EXPECT_TRUE(std::isnan(fma(nan, nan, nan)));
}

TEST(MathFunctions, fma_matrix) {
using stan::math::fma;
using stan::math::elt_multiply;
using stan::math::add;

double xd = 1.0;
Eigen::VectorXd xv(2);
xv << 1.0, 2.0;
Eigen::RowVectorXd xr(2);
xr << 1.0, 2.0;
Eigen::MatrixXd xm(2, 2);
xm << 1.0, 2.0,
-1.0, 1.1;

double yd = 2.0;
Eigen::VectorXd yv(2);
yv << 2.0, -3.0;
Eigen::RowVectorXd yr(2);
yr << 2.0, -3.0;
Eigen::MatrixXd ym(2, 2);
xm << 1.0, 2.0,
-1.0, 1.1;

double zd = 3.0;
Eigen::VectorXd zv(2);
zv << -3.0, 4.0;
Eigen::RowVectorXd zr(2);
zr << -3.0, 4.0;
Eigen::MatrixXd zm(2, 2);
xm << 3.0, 4.0,
-1.0, 1.1;

EXPECT_MATRIX_EQ(add(elt_multiply(xd, yd), zv), fma(xd, yd, zv));
EXPECT_MATRIX_EQ(add(elt_multiply(xd, yv), zd), fma(xd, yv, zd));
EXPECT_MATRIX_EQ(add(elt_multiply(xd, yv), zv), fma(xd, yv, zv));
EXPECT_MATRIX_EQ(add(elt_multiply(xv, yd), zd), fma(xv, yd, zd));
EXPECT_MATRIX_EQ(add(elt_multiply(xv, yd), zv), fma(xv, yd, zv));
EXPECT_MATRIX_EQ(add(elt_multiply(xv, yv), zd), fma(xv, yv, zd));
EXPECT_MATRIX_EQ(add(elt_multiply(xv, yv), zv), fma(xv, yv, zv));

EXPECT_MATRIX_EQ(add(elt_multiply(xd, yd), zr), fma(xd, yd, zr));
EXPECT_MATRIX_EQ(add(elt_multiply(xd, yr), zd), fma(xd, yr, zd));
EXPECT_MATRIX_EQ(add(elt_multiply(xd, yr), zr), fma(xd, yr, zr));
EXPECT_MATRIX_EQ(add(elt_multiply(xr, yd), zd), fma(xr, yd, zd));
EXPECT_MATRIX_EQ(add(elt_multiply(xr, yd), zr), fma(xr, yd, zr));
EXPECT_MATRIX_EQ(add(elt_multiply(xr, yr), zd), fma(xr, yr, zd));
EXPECT_MATRIX_EQ(add(elt_multiply(xr, yr), zr), fma(xr, yr, zr));

EXPECT_MATRIX_EQ(add(elt_multiply(xd, yd), zm), fma(xd, yd, zm));
EXPECT_MATRIX_EQ(add(elt_multiply(xd, ym), zd), fma(xd, ym, zd));
EXPECT_MATRIX_EQ(add(elt_multiply(xd, ym), zm), fma(xd, ym, zm));
EXPECT_MATRIX_EQ(add(elt_multiply(xm, yd), zd), fma(xm, yd, zd));
EXPECT_MATRIX_EQ(add(elt_multiply(xm, yd), zm), fma(xm, yd, zm));
EXPECT_MATRIX_EQ(add(elt_multiply(xm, ym), zd), fma(xm, ym, zd));
EXPECT_MATRIX_EQ(add(elt_multiply(xm, ym), zm), fma(xm, ym, zm));
}
Loading

0 comments on commit ee7568a

Please sign in to comment.