From e64628c7c3ea8125faa045d91da8823f12f7e398 Mon Sep 17 00:00:00 2001 From: Steve Bronder Date: Fri, 21 Jun 2024 11:40:53 -0400 Subject: [PATCH] adds opencl constructor for taking in values and adjoints --- stan/math/opencl/rev/vari.hpp | 21 ++++++++++++++++++++- test/unit/math/opencl/rev/vari_test.cpp | 5 +++++ 2 files changed, 25 insertions(+), 1 deletion(-) diff --git a/stan/math/opencl/rev/vari.hpp b/stan/math/opencl/rev/vari.hpp index cd35d14c0f0..0fbda385097 100644 --- a/stan/math/opencl/rev/vari.hpp +++ b/stan/math/opencl/rev/vari.hpp @@ -227,7 +227,7 @@ class vari_value> : public chainable_alloc, require_vt_same* = nullptr> explicit vari_value(const S& x) : chainable_alloc(), vari_cl_base(x, constant(0, x.rows(), x.cols())) { - ChainableStack::instance_->var_stack_.push_back(this); + ChainableStack::instance_->var_nochain_stack_.push_back(this); } /** @@ -259,6 +259,25 @@ class vari_value> : public chainable_alloc, } } + /** + * Construct a dense Eigen variable implementation from a + * preconstructed values and adjoints. + * + * All constructed variables are not added to the stack. Variables + * should be constructed before variables on which they depend + * to insure proper partial derivative propagation. + * @tparam S A dense Eigen type that is convertible to `value_type` + * @tparam K A dense Eigen type that is convertible to `value_type` + * @param val Matrix of values + * @param adj Matrix of adjoints + */ + template * = nullptr, + require_convertible_t* = nullptr> + explicit vari_value(S&& val, K&& adj) : chainable_alloc(), + vari_cl_base(std::forward(val), std::forward(adj)) { + ChainableStack::instance_->var_nochain_stack_.push_back(this); + } + /** * Set the adjoint value of this variable to 0. This is used to * reset adjoints before propagating derivatives again (for diff --git a/test/unit/math/opencl/rev/vari_test.cpp b/test/unit/math/opencl/rev/vari_test.cpp index 4960f887324..5800f8875da 100644 --- a/test/unit/math/opencl/rev/vari_test.cpp +++ b/test/unit/math/opencl/rev/vari_test.cpp @@ -20,6 +20,11 @@ TEST(AgradRev, matrix_cl_vari_block) { stan::math::from_matrix_cl(B.block(0, 1, 2, 2).val_)); EXPECT_MATRIX_EQ(b.block(0, 1, 2, 2), stan::math::from_matrix_cl(B.block(0, 1, 2, 2).adj_)); + vari_value> C(a_cl, a_cl); + EXPECT_MATRIX_EQ(a.block(0, 1, 2, 2), + stan::math::from_matrix_cl(C.block(0, 1, 2, 2).val_)); + EXPECT_MATRIX_EQ(a.block(0, 1, 2, 2), + stan::math::from_matrix_cl(C.block(0, 1, 2, 2).adj_)); } #endif