Skip to content

Commit

Permalink
adds opencl constructor for taking in values and adjoints
Browse files Browse the repository at this point in the history
  • Loading branch information
SteveBronder committed Jun 21, 2024
1 parent 4ab7914 commit e64628c
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 1 deletion.
21 changes: 20 additions & 1 deletion stan/math/opencl/rev/vari.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ class vari_value<T, require_matrix_cl_t<T>> : public chainable_alloc,
require_vt_same<T, S>* = nullptr>
explicit vari_value(const S& x)
: chainable_alloc(), vari_cl_base<T>(x, constant(0, x.rows(), x.cols())) {
ChainableStack::instance_->var_stack_.push_back(this);
ChainableStack::instance_->var_nochain_stack_.push_back(this);
}

/**
Expand Down Expand Up @@ -259,6 +259,25 @@ class vari_value<T, require_matrix_cl_t<T>> : public chainable_alloc,
}
}

/**
* Construct a dense Eigen variable implementation from a
* preconstructed values and adjoints.
*
* All constructed variables are not added to the stack. Variables
* should be constructed before variables on which they depend
* to insure proper partial derivative propagation.
* @tparam S A dense Eigen type that is convertible to `value_type`
* @tparam K A dense Eigen type that is convertible to `value_type`
* @param val Matrix of values
* @param adj Matrix of adjoints
*/
template <typename S, typename K, require_convertible_t<T, S>* = nullptr,
require_convertible_t<T, K>* = nullptr>
explicit vari_value(S&& val, K&& adj) : chainable_alloc(),
vari_cl_base<T>(std::forward<S>(val), std::forward<K>(adj)) {
ChainableStack::instance_->var_nochain_stack_.push_back(this);
}

/**
* Set the adjoint value of this variable to 0. This is used to
* reset adjoints before propagating derivatives again (for
Expand Down
5 changes: 5 additions & 0 deletions test/unit/math/opencl/rev/vari_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,11 @@ TEST(AgradRev, matrix_cl_vari_block) {
stan::math::from_matrix_cl(B.block(0, 1, 2, 2).val_));
EXPECT_MATRIX_EQ(b.block(0, 1, 2, 2),
stan::math::from_matrix_cl(B.block(0, 1, 2, 2).adj_));
vari_value<stan::math::matrix_cl<double>> C(a_cl, a_cl);
EXPECT_MATRIX_EQ(a.block(0, 1, 2, 2),
stan::math::from_matrix_cl(C.block(0, 1, 2, 2).val_));
EXPECT_MATRIX_EQ(a.block(0, 1, 2, 2),
stan::math::from_matrix_cl(C.block(0, 1, 2, 2).adj_));
}

#endif

0 comments on commit e64628c

Please sign in to comment.