Skip to content

Commit

Permalink
Merge pull request #3048 from stan-dev/fix/csr-matrix-times-vector
Browse files Browse the repository at this point in the history
Fix csr_matrix_times_vector linker error
  • Loading branch information
SteveBronder authored Apr 19, 2024
2 parents 11663a2 + b0815c4 commit 86a3e83
Show file tree
Hide file tree
Showing 12 changed files with 362 additions and 86 deletions.
24 changes: 23 additions & 1 deletion stan/math/prim/fun/value_of.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ inline auto value_of(const T& x) {
* @param[in] M Matrix to be converted
* @return Matrix of values
**/
template <typename EigMat, require_eigen_t<EigMat>* = nullptr,
template <typename EigMat, require_eigen_dense_base_t<EigMat>* = nullptr,
require_not_st_arithmetic<EigMat>* = nullptr>
inline auto value_of(EigMat&& M) {
return make_holder(
Expand All @@ -77,6 +77,28 @@ inline auto value_of(EigMat&& M) {
std::forward<EigMat>(M));
}

template <typename EigMat, require_eigen_sparse_base_t<EigMat>* = nullptr,
require_not_st_arithmetic<EigMat>* = nullptr>
inline auto value_of(EigMat&& M) {
auto&& M_ref = to_ref(M);
using scalar_t = decltype(value_of(std::declval<value_type_t<EigMat>>()));
promote_scalar_t<scalar_t, plain_type_t<EigMat>> ret(M_ref.rows(),
M_ref.cols());
ret.reserve(M_ref.nonZeros());
for (int k = 0; k < M_ref.outerSize(); ++k) {
for (typename std::decay_t<EigMat>::InnerIterator it(M_ref, k); it; ++it) {
ret.insert(it.row(), it.col()) = value_of(it.valueRef());
}
}
ret.makeCompressed();
return ret;
}
/**
 * Overload of `value_of` for sparse matrices whose scalar type is already
 * arithmetic. No conversion is performed: the argument is forwarded straight
 * into the return value.
 *
 * @tparam EigMat type accepted by `require_eigen_sparse_base_t` whose scalar
 * type is accepted by `require_st_arithmetic`
 * @param[in] M Sparse matrix to be returned
 * @return The forwarded input. The return type is deduced with plain `auto`,
 * so the result is returned by value.
 **/
template <typename EigMat, require_eigen_sparse_base_t<EigMat>* = nullptr,
require_st_arithmetic<EigMat>* = nullptr>
inline auto value_of(EigMat&& M) {
return std::forward<EigMat>(M);
}

} // namespace math
} // namespace stan

Expand Down
16 changes: 16 additions & 0 deletions stan/math/prim/meta/is_eigen_dense_base.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,22 @@ using require_eigen_dense_base_t
= require_t<is_eigen_dense_base<std::decay_t<T>>>;
/*! @} */

/*! \ingroup require_eigens_types */
/*! \defgroup eigen_dense_base_types eigen_dense_base_types */
/*! \addtogroup eigen_dense_base_types */
/*! @{ */

/*! \brief Require type(s) satisfy @ref is_eigen_dense_base */
/*! and that the corresponding `value_type_t` satisfies `TypeCheck` */
/*! @tparam TypeCheck The type trait template to check the value type against */
/*! @tparam Check The type(s) to test @ref is_eigen_dense_base for and whose
* `value_type_t` is checked with `TypeCheck` */
template <template <class...> class TypeCheck, class... Check>
using require_eigen_dense_base_vt
= require_t<container_type_check_base<is_eigen_dense_base, value_type_t,
TypeCheck, Check...>>;
/*! @} */

} // namespace stan

#endif
14 changes: 13 additions & 1 deletion stan/math/prim/meta/promote_scalar_type.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
#include <stan/math/prim/fun/Eigen.hpp>
#include <stan/math/prim/meta/is_eigen.hpp>
#include <stan/math/prim/meta/is_var.hpp>
#include <stan/math/prim/meta/is_eigen_dense_base.hpp>
#include <stan/math/prim/meta/is_eigen_sparse_base.hpp>
#include <vector>

namespace stan {
Expand Down Expand Up @@ -80,7 +82,7 @@ struct promote_scalar_type<T, S,
* @tparam S input matrix type
*/
template <typename T, typename S>
struct promote_scalar_type<T, S, require_eigen_t<S>> {
struct promote_scalar_type<T, S, require_eigen_dense_base_t<S>> {
/**
* The promoted type.
*/
Expand All @@ -93,6 +95,16 @@ struct promote_scalar_type<T, S, require_eigen_t<S>> {
S::RowsAtCompileTime, S::ColsAtCompileTime>>::type;
};

/**
 * Template metaprogram to calculate a type for a sparse matrix whose
 * underlying scalar is converted from the second template parameter
 * type to the first.
 *
 * The result is an `Eigen::SparseMatrix` whose scalar is the promotion of
 * `S::Scalar` to `T` and whose `Options` and `StorageIndex` are taken from
 * `S`.
 *
 * @tparam T result scalar type.
 * @tparam S input sparse matrix type
 */
template <typename T, typename S>
struct promote_scalar_type<T, S, require_eigen_sparse_base_t<S>> {
/**
* The promoted type.
*/
using type = Eigen::SparseMatrix<
typename promote_scalar_type<T, typename S::Scalar>::type, S::Options,
typename S::StorageIndex>;
};

template <typename... PromotionScalars, typename... UnPromotedTypes>
struct promote_scalar_type<std::tuple<PromotionScalars...>,
std::tuple<UnPromotedTypes...>> {
Expand Down
24 changes: 15 additions & 9 deletions stan/math/rev/core/arena_matrix.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
#include <stan/math/rev/core/chainablestack.hpp>
#include <stan/math/rev/core/var_value_fwd_declare.hpp>
#include <stan/math/prim/fun/to_ref.hpp>

namespace stan {
namespace math {

Expand Down Expand Up @@ -225,8 +224,10 @@ class arena_matrix<MatrixType, require_eigen_sparse_base_t<MatrixType>>
*/
arena_matrix(const arena_matrix<MatrixType>& other)
: Base::Map(other.rows(), other.cols(), other.nonZeros(),
other.outerIndexPtr(), other.innerIndexPtr(),
other.valuePtr(), other.innernonZeroPtr()) {}
const_cast<StorageIndex*>(other.outerIndexPtr()),
const_cast<StorageIndex*>(other.innerIndexPtr()),
const_cast<Scalar*>(other.valuePtr()),
const_cast<StorageIndex*>(other.innerNonZeroPtr())) {}
/**
* Move constructor.
* @note Since the memory for the arena matrix sits in Stan's memory arena all
Expand All @@ -235,8 +236,10 @@ class arena_matrix<MatrixType, require_eigen_sparse_base_t<MatrixType>>
*/
arena_matrix(arena_matrix<MatrixType>&& other)
: Base::Map(other.rows(), other.cols(), other.nonZeros(),
other.outerIndexPtr(), other.innerIndexPtr(),
other.valuePtr(), other.innerNonZeroPtr()) {}
const_cast<StorageIndex*>(other.outerIndexPtr()),
const_cast<StorageIndex*>(other.innerIndexPtr()),
const_cast<Scalar*>(other.valuePtr()),
const_cast<StorageIndex*>(other.innerNonZeroPtr())) {}
/**
* Copy constructor. No actual copy is performed
* @note Since the memory for the arena matrix sits in Stan's memory arena all
Expand All @@ -245,8 +248,10 @@ class arena_matrix<MatrixType, require_eigen_sparse_base_t<MatrixType>>
*/
arena_matrix(arena_matrix<MatrixType>& other)
: Base::Map(other.rows(), other.cols(), other.nonZeros(),
other.outerIndexPtr(), other.innerIndexPtr(),
other.valuePtr(), other.innerNonZeroPtr()) {}
const_cast<StorageIndex*>(other.outerIndexPtr()),
const_cast<StorageIndex*>(other.innerIndexPtr()),
const_cast<Scalar*>(other.valuePtr()),
const_cast<StorageIndex*>(other.innerNonZeroPtr())) {}

// without this using, compiler prefers combination of implicit construction
// and copy assignment to the inherited operator when assigned an expression
Expand All @@ -259,7 +264,8 @@ class arena_matrix<MatrixType, require_eigen_sparse_base_t<MatrixType>>
* @return `*this`
*/
template <typename ArenaMatrix,
require_same_t<ArenaMatrix, arena_matrix<MatrixType>>* = nullptr>
require_same_t<std::decay_t<ArenaMatrix>,
arena_matrix<MatrixType>>* = nullptr>
arena_matrix& operator=(ArenaMatrix&& other) {
// placement new changes what data map points to - there is no allocation
new (this) Base(other.rows(), other.cols(), other.nonZeros(),
Expand All @@ -280,7 +286,7 @@ class arena_matrix<MatrixType, require_eigen_sparse_base_t<MatrixType>>
template <typename Expr,
require_not_same_t<Expr, arena_matrix<MatrixType>>* = nullptr>
arena_matrix& operator=(Expr&& expr) {
*this = arena_matrix(std::forward<Expr>(expr));
new (this) arena_matrix(std::forward<Expr>(expr));
return *this;
}

Expand Down
12 changes: 12 additions & 0 deletions stan/math/rev/core/var.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,18 @@ class var_value<T, internal::require_matrix_var_value<T>> {
});
}

/**
* Construct a `var_value` with premade @ref arena_matrix types.
* The values and adjoint matrices passed here will be shallow copied.
* A new `vari_type` is allocated from `val` and `adj` and stored in `vi_`.
* @tparam S type of the value in the `var_value` to assign; must satisfy
* `require_assignable_t<value_type, S>` and `require_arena_matrix_t<S>`
* @tparam T_ defaults to `T`; not otherwise used in this declaration
* @param val The value matrix to go into the vari
* @param adj the adjoint matrix to go into the vari
*/
template <typename S, typename T_ = T,
require_assignable_t<value_type, S>* = nullptr,
require_arena_matrix_t<S>* = nullptr>
var_value(const S& val, const S& adj) : vi_(new vari_type(val, adj)) {}

/**
* Construct a variable from a pointer to a variable implementation.
* @param vi A vari_value pointer.
Expand Down
32 changes: 23 additions & 9 deletions stan/math/rev/core/vari.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -821,17 +821,16 @@ class vari_value<T, require_eigen_sparse_base_t<T>> : public vari_base {
*/
static constexpr int ColsAtCompileTime = T::ColsAtCompileTime;

/**
* The value of this variable.
*/
arena_matrix<PlainObject> val_;
/**
* The adjoint of this variable, which is the partial derivative
* of this variable with respect to the root variable.
*/
arena_matrix<PlainObject> adj_;

/**
* The value of this variable.
*/
arena_matrix<PlainObject> val_;

/**
* Construct a variable implementation from a value. The
* adjoint is initialized to zero.
Expand All @@ -847,10 +846,21 @@ class vari_value<T, require_eigen_sparse_base_t<T>> : public vari_base {
* @param x Value of the constructed variable.
*/
template <typename S, require_convertible_t<S&, T>* = nullptr>
explicit vari_value(S&& x) : adj_(x), val_(std::forward<S>(x)) {
this->set_zero_adjoint();
explicit vari_value(S&& x)
: val_(std::forward<S>(x)),
adj_(val_.rows(), val_.cols(), val_.nonZeros(), val_.outerIndexPtr(),
val_.innerIndexPtr(),
arena_matrix<Eigen::VectorXd>(val_.nonZeros()).setZero().data(),
val_.innerNonZeroPtr()) {
ChainableStack::instance_->var_stack_.push_back(this);
}

/**
* Construct a sparse variable implementation from premade value and
* adjoint arena matrices.
*
* `val_` and `adj_` are copy-constructed from the arguments; the adjoint is
* taken as given and is not reset to zero here. The new vari is pushed onto
* `ChainableStack::instance_->var_stack_`.
*
* @param val The arena matrix used to initialize `val_`
* @param adj The arena matrix used to initialize `adj_`
*/
vari_value(const arena_matrix<PlainObject>& val,
const arena_matrix<PlainObject>& adj)
: val_(val), adj_(adj) {
ChainableStack::instance_->var_stack_.push_back(this);
}

/**
* Construct a sparse Eigen variable implementation from a value. The
* adjoint is initialized to zero and if `stacked` is `false` this vari
Expand All @@ -869,8 +879,12 @@ class vari_value<T, require_eigen_sparse_base_t<T>> : public vari_base {
* that its `chain()` method is not called.
*/
template <typename S, require_convertible_t<S&, T>* = nullptr>
vari_value(S&& x, bool stacked) : adj_(x), val_(std::forward<S>(x)) {
this->set_zero_adjoint();
vari_value(S&& x, bool stacked)
: val_(std::forward<S>(x)),
adj_(val_.rows(), val_.cols(), val_.nonZeros(), val_.outerIndexPtr(),
val_.innerIndexPtr(),
arena_matrix<Eigen::VectorXd>(val_.nonZeros()).setZero().data(),
val_.innerNonZeroPtr()) {
if (stacked) {
ChainableStack::instance_->var_stack_.push_back(this);
} else {
Expand Down
1 change: 1 addition & 0 deletions stan/math/rev/fun.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,7 @@
#include <stan/math/rev/fun/tgamma.hpp>
#include <stan/math/rev/fun/to_var.hpp>
#include <stan/math/rev/fun/to_arena.hpp>
#include <stan/math/rev/fun/to_soa_sparse_matrix.hpp>
#include <stan/math/rev/fun/to_var_value.hpp>
#include <stan/math/rev/fun/to_vector.hpp>
#include <stan/math/rev/fun/trace.hpp>
Expand Down
81 changes: 19 additions & 62 deletions stan/math/rev/fun/csr_matrix_times_vector.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#define STAN_MATH_REV_FUN_CSR_MATRIX_TIMES_VECTOR_HPP

#include <stan/math/prim/fun/Eigen.hpp>
#include <stan/math/rev/fun/to_soa_sparse_matrix.hpp>
#include <stan/math/rev/core.hpp>
#include <stan/math/prim/err.hpp>
#include <stan/math/prim/fun/csr_u_to_z.hpp>
Expand All @@ -10,40 +11,6 @@
namespace stan {
namespace math {

namespace internal {
/**
 * Add the reverse-mode contribution of `res` to the adjoints of the
 * non-zero values `w` when `w` is an Eigen container of `var`.
 *
 * The buffers `w`, `u` and `v` are viewed in place as an `m` by `n`
 * row-major sparse matrix of `var`. For each stored entry at (row, col) the
 * entry's adjoint is incremented by
 * `res.adj().coeff(row) * value_of(b).coeff(col)`.
 *
 * @tparam T1 type accepted by `require_eigen_t`
 * @tparam T2 type of `b`
 * @tparam Res type of `res`
 * @param w non-zero values; passed as the value buffer of the sparse view
 * @param m number of rows of the sparse view
 * @param n number of columns of the sparse view
 * @param u buffer passed as the outer index array of the sparse view
 * @param v buffer passed as the inner index array of the sparse view
 * @param b operand whose values are read through `value_of(b)`
 * @param res result whose adjoints are read through `res.adj()`
 */
template <typename T1, typename T2, typename Res,
          require_eigen_t<T1>* = nullptr>
void update_w(T1& w, int m, int n, std::vector<int, arena_allocator<int>>& u,
              std::vector<int, arena_allocator<int>>& v, T2&& b, Res&& res) {
  using sparse_var_map_t
      = Eigen::Map<Eigen::SparseMatrix<var, Eigen::RowMajor>>;
  using entry_iterator_t = sparse_var_map_t::InnerIterator;

  sparse_var_map_t w_view(m, n, w.size(), u.data(), v.data(), w.data());
  for (int outer = 0; outer < w_view.outerSize(); ++outer) {
    for (entry_iterator_t entry(w_view, outer); entry; ++entry) {
      entry.valueRef().adj()
          += res.adj().coeff(entry.row()) * value_of(b).coeff(entry.col());
    }
  }
}

/**
 * Add the reverse-mode contribution of `res` to the adjoints of the
 * non-zero values `w` when `w` satisfies `require_var_matrix_t`.
 *
 * The adjoint buffer `w.adj().data()`, together with `u` and `v`, is viewed
 * in place as an `m` by `n` row-major sparse matrix of `double`. For each
 * stored entry at (row, col) the entry is incremented by
 * `res.adj().coeff(row) * value_of(b).coeff(col)`.
 *
 * @tparam T1 type accepted by `require_var_matrix_t`
 * @tparam T2 type of `b`
 * @tparam Res type of `res`
 * @param w non-zero values; `w.adj()` supplies the value buffer of the view
 * @param m number of rows of the sparse view
 * @param n number of columns of the sparse view
 * @param u buffer passed as the outer index array of the sparse view
 * @param v buffer passed as the inner index array of the sparse view
 * @param b operand whose values are read through `value_of(b)`
 * @param res result whose adjoints are read through `res.adj()`
 */
template <typename T1, typename T2, typename Res,
          require_var_matrix_t<T1>* = nullptr>
void update_w(T1& w, int m, int n, std::vector<int, arena_allocator<int>>& u,
              std::vector<int, arena_allocator<int>>& v, T2&& b, Res&& res) {
  using sparse_adj_map_t
      = Eigen::Map<Eigen::SparseMatrix<double, Eigen::RowMajor>>;
  using entry_iterator_t = sparse_adj_map_t::InnerIterator;

  sparse_adj_map_t w_adj_view(m, n, w.size(), u.data(), v.data(),
                              w.adj().data());
  for (int outer = 0; outer < w_adj_view.outerSize(); ++outer) {
    for (entry_iterator_t entry(w_adj_view, outer); entry; ++entry) {
      entry.valueRef()
          += res.adj().coeff(entry.row()) * value_of(b).coeff(entry.col());
    }
  }
}

} // namespace internal

/**
* \addtogroup csr_format
* Return the multiplication of the sparse matrix (specified by
Expand Down Expand Up @@ -100,46 +67,36 @@ inline auto csr_matrix_times_vector(int m, int n, const T1& w,
std::vector<int, arena_allocator<int>> u_arena(u.size());
std::transform(u.begin(), u.end(), u_arena.begin(),
[](auto&& x) { return x - 1; });
using sparse_var_value_t
= var_value<Eigen::SparseMatrix<double, Eigen::RowMajor>>;
if (!is_constant<T2>::value && !is_constant<T1>::value) {
arena_t<promote_scalar_t<var, T2>> b_arena = b;
arena_t<promote_scalar_t<var, T1>> w_arena = to_arena(w);
auto w_val_arena = to_arena(value_of(w_arena));
sparse_val_mat w_val_mat(m, n, w_val_arena.size(), u_arena.data(),
v_arena.data(), w_val_arena.data());
arena_t<return_t> res = w_val_mat * value_of(b_arena);
reverse_pass_callback(
[m, n, w_arena, w_val_arena, v_arena, u_arena, res, b_arena]() mutable {
sparse_val_mat w_val_mat(m, n, w_val_arena.size(), u_arena.data(),
v_arena.data(), w_val_arena.data());
internal::update_w(w_arena, m, n, u_arena, v_arena, b_arena, res);
b_arena.adj() += w_val_mat.transpose() * res.adj();
});
sparse_var_value_t w_mat_arena
= to_soa_sparse_matrix<Eigen::RowMajor>(m, n, w, u_arena, v_arena);
arena_t<return_t> res = w_mat_arena.val() * value_of(b_arena);
reverse_pass_callback([res, w_mat_arena, b_arena]() mutable {
w_mat_arena.adj() += res.adj() * b_arena.val().transpose();
b_arena.adj() += w_mat_arena.val().transpose() * res.adj();
});
return return_t(res);
} else if (!is_constant<T2>::value) {
arena_t<promote_scalar_t<var, T2>> b_arena = b;
auto w_val_arena = to_arena(value_of(w));
sparse_val_mat w_val_mat(m, n, w_val_arena.size(), u_arena.data(),
v_arena.data(), w_val_arena.data());

arena_t<return_t> res = w_val_mat * value_of(b_arena);
reverse_pass_callback(
[m, n, w_val_arena, v_arena, u_arena, res, b_arena]() mutable {
sparse_val_mat w_val_mat(m, n, w_val_arena.size(), u_arena.data(),
v_arena.data(), w_val_arena.data());
b_arena.adj() += w_val_mat.transpose() * res.adj();
});
reverse_pass_callback([w_val_mat, res, b_arena]() mutable {
b_arena.adj() += w_val_mat.transpose() * res.adj();
});
return return_t(res);
} else {
arena_t<promote_scalar_t<var, T1>> w_arena = to_arena(w);
auto&& w_val = eval(value_of(w_arena));
sparse_val_mat w_val_mat(m, n, w_val.size(), u_arena.data(), v_arena.data(),
w_val.data());
sparse_var_value_t w_mat_arena
= to_soa_sparse_matrix<Eigen::RowMajor>(m, n, w, u_arena, v_arena);
auto b_arena = to_arena(value_of(b));
arena_t<return_t> res = w_val_mat * b_arena;
reverse_pass_callback(
[m, n, w_arena, v_arena, u_arena, res, b_arena]() mutable {
internal::update_w(w_arena, m, n, u_arena, v_arena, b_arena, res);
});
arena_t<return_t> res = w_mat_arena.val() * b_arena;
reverse_pass_callback([res, w_mat_arena, b_arena]() mutable {
w_mat_arena.adj() += res.adj() * b_arena.transpose();
});
return return_t(res);
}
}
Expand Down
Loading

0 comments on commit 86a3e83

Please sign in to comment.