Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/develop' into feature/reverse-mo…
Browse files Browse the repository at this point in the history
…de-move-semantics
  • Loading branch information
SteveBronder committed Apr 26, 2024
2 parents c5f983a + e73651b commit 7a9601d
Show file tree
Hide file tree
Showing 17 changed files with 500 additions and 87 deletions.
7 changes: 7 additions & 0 deletions lib/tbb_2020.3/STAN_CHANGES
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,10 @@ This file documents changes done for the stan-math project
- build/windows.inc patches for RTools make:
- L15 changed setting to use '?=', allowing override
- L25,L113,L114 added additional '/' to each cmd flag

- Support for Windows ARM64 with RTools:
- build/Makefile.tbb
- L94 Wrapped the use of `--version-script` export in conditional on non-WINARM64
- build/windows.gcc.inc
- L84 Wrapped the use of `-flifetime-dse` flag in conditional on non-WINARM64

6 changes: 5 additions & 1 deletion lib/tbb_2020.3/build/Makefile.tbb
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,11 @@ ifneq (,$(TBB.DEF))
tbb.def: $(TBB.DEF) $(TBB.LST)
$(CPLUS) $(PREPROC_ONLY) $< $(CPLUS_FLAGS) $(INCLUDES) > $@

LIB_LINK_FLAGS += $(EXPORT_KEY)tbb.def
# LLVM on Windows doesn't need --version-script export
# https://reviews.llvm.org/D63743
ifeq (, $(WINARM64))
LIB_LINK_FLAGS += $(EXPORT_KEY)tbb.def
endif
$(TBB.DLL): tbb.def
endif

Expand Down
7 changes: 5 additions & 2 deletions lib/tbb_2020.3/build/windows.gcc.inc
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,11 @@ endif
# gcc 6.0 and later have -flifetime-dse option that controls
# elimination of stores done outside the object lifetime
ifeq (ok,$(call detect_js,/minversion gcc 6.0))
# keep pre-construction stores for zero initialization
DSE_KEY = -flifetime-dse=1
# Clang does not support -flifetime-dse
ifeq (, $(WINARM64))
# keep pre-construction stores for zero initialization
DSE_KEY = -flifetime-dse=1
endif
endif

ifeq ($(cfg), release)
Expand Down
15 changes: 13 additions & 2 deletions make/compiler_flags
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ endif

## Set OS specific library filename extensions
ifeq ($(OS),Windows_NT)
WINARM64 := $(shell echo | $(CXX) -E -dM - | findstr __aarch64__)
LIBRARY_SUFFIX ?= .dll
endif

Expand Down Expand Up @@ -271,8 +272,13 @@ CXXFLAGS_TBB ?= -I $(TBB_INC)
else
CXXFLAGS_TBB ?= -I $(TBB)/include
endif
LDFLAGS_TBB ?= -Wl,-L,"$(TBB_LIB)" -Wl,--disable-new-dtags

# Windows LLVM/Clang does not support -rpath, and it is not needed on Windows anyway
ifeq ($(WINARM64),)
LDFLAGS_TBB += -Wl,-rpath,"$(TBB_LIB)"
endif

LDFLAGS_TBB ?= -Wl,-L,"$(TBB_LIB)" -Wl,-rpath,"$(TBB_LIB)" -Wl,--disable-new-dtags
LDLIBS_TBB ?= -ltbb

else
Expand All @@ -290,7 +296,12 @@ ifeq ($(OS),Linux)
endif

CXXFLAGS_TBB ?= -I $(TBB)/include
LDFLAGS_TBB ?= -Wl,-L,"$(TBB_BIN_ABSOLUTE_PATH)" -Wl,-rpath,"$(TBB_BIN_ABSOLUTE_PATH)" $(LDFLAGS_FLTO_FLTO) $(LDFLAGS_OPTIM_TBB)
LDFLAGS_TBB ?= -Wl,-L,"$(TBB_BIN_ABSOLUTE_PATH)" $(LDFLAGS_FLTO_FLTO) $(LDFLAGS_OPTIM_TBB)

# Windows LLVM/Clang does not support -rpath, and it is not needed on Windows anyway
ifeq ($(WINARM64),)
LDFLAGS_TBB += -Wl,-rpath,"$(TBB_BIN_ABSOLUTE_PATH)"
endif
LDLIBS_TBB ?= -ltbb

endif
Expand Down
12 changes: 9 additions & 3 deletions make/libraries
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ CPPLINT ?= $(MATH)lib/cpplint_1.4.5
# Fortran bindings which we do not need for stan-math. Thus these targets
# are ignored here. This convention was introduced with 4.0.
##
ifndef SUNDIALS_TARGETS

SUNDIALS_CVODES := $(patsubst %.c,%.o,\
$(wildcard $(SUNDIALS)/src/cvodes/*.c) \
Expand Down Expand Up @@ -87,7 +88,7 @@ $(STAN_SUNDIALS_HEADERS) : $(SUNDIALS_TARGETS)
clean-sundials:
@echo ' cleaning sundials targets'
$(RM) $(wildcard $(sort $(SUNDIALS_CVODES) $(SUNDIALS_IDAS) $(SUNDIALS_KINSOL) $(SUNDIALS_NVECSERIAL) $(SUNDIALS_TARGETS)))

endif

############################################################
# TBB build rules
Expand Down Expand Up @@ -138,6 +139,11 @@ endif
ifeq (Windows_NT, $(OS))
ifeq ($(IS_UCRT),true)
TBB_CXXFLAGS += -D_UCRT
endif
# TBB does not have assembly code for Windows ARM64, so we need to use GCC builtins
ifneq ($(WINARM64),)
TBB_CXXFLAGS += -DTBB_USE_GCC_BUILTINS
CXXFLAGS_TBB += -DTBB_USE_GCC_BUILTINS
endif
SH_CHECK := $(shell command -v sh 2>/dev/null)
ifdef SH_CHECK
Expand Down Expand Up @@ -169,11 +175,11 @@ endif
$(TBB_BIN)/tbb.def: $(TBB_BIN)/tbb-make-check
@mkdir -p $(TBB_BIN)
touch $(TBB_BIN)/version_$(notdir $(TBB))
tbb_root="$(TBB_RELATIVE_PATH)" CXX="$(CXX)" CC="$(TBB_CC)" LDFLAGS='$(LDFLAGS_TBB)' '$(MAKE)' -C "$(TBB_BIN)" -r -f "$(TBB_ABSOLUTE_PATH)/build/Makefile.tbb" compiler=$(TBB_CXX_TYPE) cfg=release stdver=c++1y CXXFLAGS="$(TBB_CXXFLAGS)"
tbb_root="$(TBB_RELATIVE_PATH)" WINARM64="$(WINARM64)" CXX="$(CXX)" CC="$(TBB_CC)" LDFLAGS='$(LDFLAGS_TBB)' '$(MAKE)' -C "$(TBB_BIN)" -r -f "$(TBB_ABSOLUTE_PATH)/build/Makefile.tbb" compiler=$(TBB_CXX_TYPE) cfg=release stdver=c++1y CXXFLAGS="$(TBB_CXXFLAGS)"

$(TBB_BIN)/tbbmalloc.def: $(TBB_BIN)/tbb-make-check
@mkdir -p $(TBB_BIN)
tbb_root="$(TBB_RELATIVE_PATH)" CXX="$(CXX)" CC="$(TBB_CC)" LDFLAGS='$(LDFLAGS_TBB)' '$(MAKE)' -C "$(TBB_BIN)" -r -f "$(TBB_ABSOLUTE_PATH)/build/Makefile.tbbmalloc" compiler=$(TBB_CXX_TYPE) cfg=release stdver=c++1y malloc CXXFLAGS="$(TBB_CXXFLAGS)"
tbb_root="$(TBB_RELATIVE_PATH)" WINARM64="$(WINARM64)" CXX="$(CXX)" CC="$(TBB_CC)" LDFLAGS='$(LDFLAGS_TBB)' '$(MAKE)' -C "$(TBB_BIN)" -r -f "$(TBB_ABSOLUTE_PATH)/build/Makefile.tbbmalloc" compiler=$(TBB_CXX_TYPE) cfg=release stdver=c++1y malloc CXXFLAGS="$(TBB_CXXFLAGS)"

$(TBB_BIN)/libtbb.dylib: $(TBB_BIN)/tbb.def
$(TBB_BIN)/libtbbmalloc.dylib: $(TBB_BIN)/tbbmalloc.def
Expand Down
24 changes: 23 additions & 1 deletion stan/math/prim/fun/value_of.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ inline auto value_of(const T& x) {
* @param[in] M Matrix to be converted
* @return Matrix of values
**/
template <typename EigMat, require_eigen_t<EigMat>* = nullptr,
template <typename EigMat, require_eigen_dense_base_t<EigMat>* = nullptr,
require_not_st_arithmetic<EigMat>* = nullptr>
inline auto value_of(EigMat&& M) {
return make_holder(
Expand All @@ -77,6 +77,28 @@ inline auto value_of(EigMat&& M) {
std::forward<EigMat>(M));
}

template <typename EigMat, require_eigen_sparse_base_t<EigMat>* = nullptr,
require_not_st_arithmetic<EigMat>* = nullptr>
inline auto value_of(EigMat&& M) {
auto&& M_ref = to_ref(M);
using scalar_t = decltype(value_of(std::declval<value_type_t<EigMat>>()));
promote_scalar_t<scalar_t, plain_type_t<EigMat>> ret(M_ref.rows(),
M_ref.cols());
ret.reserve(M_ref.nonZeros());
for (int k = 0; k < M_ref.outerSize(); ++k) {
for (typename std::decay_t<EigMat>::InnerIterator it(M_ref, k); it; ++it) {
ret.insert(it.row(), it.col()) = value_of(it.valueRef());
}
}
ret.makeCompressed();
return ret;
}
/**
 * Overload of `value_of` for arguments accepted by both
 * `require_eigen_sparse_base_t` and `require_st_arithmetic`. The argument is
 * forwarded back as is; no per-entry conversion is performed.
 *
 * @tparam EigMat type of the argument
 * @param[in] M argument to pass through
 * @return `M`, forwarded. The return type is deduced with plain `auto`, so
 * the result is returned by value.
 **/
template <typename EigMat, require_eigen_sparse_base_t<EigMat>* = nullptr,
require_st_arithmetic<EigMat>* = nullptr>
inline auto value_of(EigMat&& M) {
return std::forward<EigMat>(M);
}

} // namespace math
} // namespace stan

Expand Down
16 changes: 16 additions & 0 deletions stan/math/prim/meta/is_eigen_dense_base.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,22 @@ using require_eigen_dense_base_t
= require_t<is_eigen_dense_base<std::decay_t<T>>>;
/*! @} */

/*! \ingroup require_eigens_types */
/*! \defgroup eigen_dense_base_types eigen_dense_base_types */
/*! \addtogroup eigen_dense_base_types */
/*! @{ */

/*! \brief Require type satisfies @ref is_eigen_dense_base
 * and value type satisfies `TypeCheck` */
/*! @tparam TypeCheck The type trait to check the value type against */
/*! @tparam Check The type(s) to test @ref is_eigen_dense_base for and whose
* @ref value_type is checked with `TypeCheck`; this is a parameter pack */
template <template <class...> class TypeCheck, class... Check>
using require_eigen_dense_base_vt
= require_t<container_type_check_base<is_eigen_dense_base, value_type_t,
TypeCheck, Check...>>;
/*! @} */

} // namespace stan

#endif
14 changes: 13 additions & 1 deletion stan/math/prim/meta/promote_scalar_type.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
#include <stan/math/prim/fun/Eigen.hpp>
#include <stan/math/prim/meta/is_eigen.hpp>
#include <stan/math/prim/meta/is_var.hpp>
#include <stan/math/prim/meta/is_eigen_dense_base.hpp>
#include <stan/math/prim/meta/is_eigen_sparse_base.hpp>
#include <vector>

namespace stan {
Expand Down Expand Up @@ -80,7 +82,7 @@ struct promote_scalar_type<T, S,
* @tparam S input matrix type
*/
template <typename T, typename S>
struct promote_scalar_type<T, S, require_eigen_t<S>> {
struct promote_scalar_type<T, S, require_eigen_dense_base_t<S>> {
/**
* The promoted type.
*/
Expand All @@ -93,6 +95,16 @@ struct promote_scalar_type<T, S, require_eigen_t<S>> {
S::RowsAtCompileTime, S::ColsAtCompileTime>>::type;
};

/**
 * Specialization of `promote_scalar_type` selected through
 * `require_eigen_sparse_base_t<S>`.
 *
 * The member `type` is an `Eigen::SparseMatrix` whose scalar is
 * `promote_scalar_type<T, typename S::Scalar>::type` and whose options and
 * storage index type are taken unchanged from `S`.
 *
 * @tparam T type passed on to the scalar-level `promote_scalar_type`
 * @tparam S input sparse matrix type
 */
template <typename T, typename S>
struct promote_scalar_type<T, S, require_eigen_sparse_base_t<S>> {
/**
* The promoted type.
*/
using type = Eigen::SparseMatrix<
typename promote_scalar_type<T, typename S::Scalar>::type, S::Options,
typename S::StorageIndex>;
};

template <typename... PromotionScalars, typename... UnPromotedTypes>
struct promote_scalar_type<std::tuple<PromotionScalars...>,
std::tuple<UnPromotedTypes...>> {
Expand Down
24 changes: 15 additions & 9 deletions stan/math/rev/core/arena_matrix.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
#include <stan/math/rev/core/chainable_object.hpp>
#include <stan/math/rev/core/var_value_fwd_declare.hpp>
#include <stan/math/prim/fun/to_ref.hpp>

namespace stan {
namespace math {

Expand Down Expand Up @@ -269,8 +268,10 @@ class arena_matrix<MatrixType, require_eigen_sparse_base_t<MatrixType>>
*/
arena_matrix(const arena_matrix<MatrixType>& other)
: Base::Map(other.rows(), other.cols(), other.nonZeros(),
other.outerIndexPtr(), other.innerIndexPtr(),
other.valuePtr(), other.innernonZeroPtr()) {}
const_cast<StorageIndex*>(other.outerIndexPtr()),
const_cast<StorageIndex*>(other.innerIndexPtr()),
const_cast<Scalar*>(other.valuePtr()),
const_cast<StorageIndex*>(other.innerNonZeroPtr())) {}
/**
* Move constructor.
* @note Since the memory for the arena matrix sits in Stan's memory arena all
Expand All @@ -279,8 +280,10 @@ class arena_matrix<MatrixType, require_eigen_sparse_base_t<MatrixType>>
*/
arena_matrix(arena_matrix<MatrixType>&& other)
: Base::Map(other.rows(), other.cols(), other.nonZeros(),
other.outerIndexPtr(), other.innerIndexPtr(),
other.valuePtr(), other.innerNonZeroPtr()) {}
const_cast<StorageIndex*>(other.outerIndexPtr()),
const_cast<StorageIndex*>(other.innerIndexPtr()),
const_cast<Scalar*>(other.valuePtr()),
const_cast<StorageIndex*>(other.innerNonZeroPtr())) {}
/**
* Copy constructor. No actual copy is performed
* @note Since the memory for the arena matrix sits in Stan's memory arena all
Expand All @@ -289,8 +292,10 @@ class arena_matrix<MatrixType, require_eigen_sparse_base_t<MatrixType>>
*/
arena_matrix(arena_matrix<MatrixType>& other)
: Base::Map(other.rows(), other.cols(), other.nonZeros(),
other.outerIndexPtr(), other.innerIndexPtr(),
other.valuePtr(), other.innerNonZeroPtr()) {}
const_cast<StorageIndex*>(other.outerIndexPtr()),
const_cast<StorageIndex*>(other.innerIndexPtr()),
const_cast<Scalar*>(other.valuePtr()),
const_cast<StorageIndex*>(other.innerNonZeroPtr())) {}

// without this using, compiler prefers combination of implicit construction
// and copy assignment to the inherited operator when assigned an expression
Expand All @@ -303,7 +308,8 @@ class arena_matrix<MatrixType, require_eigen_sparse_base_t<MatrixType>>
* @return `*this`
*/
template <typename ArenaMatrix,
require_same_t<ArenaMatrix, arena_matrix<MatrixType>>* = nullptr>
require_same_t<std::decay_t<ArenaMatrix>,
arena_matrix<MatrixType>>* = nullptr>
arena_matrix& operator=(ArenaMatrix&& other) {
// placement new changes what data map points to - there is no allocation
new (this) Base(other.rows(), other.cols(), other.nonZeros(),
Expand All @@ -324,7 +330,7 @@ class arena_matrix<MatrixType, require_eigen_sparse_base_t<MatrixType>>
template <typename Expr,
require_not_same_t<Expr, arena_matrix<MatrixType>>* = nullptr>
arena_matrix& operator=(Expr&& expr) {
*this = arena_matrix(std::forward<Expr>(expr));
new (this) arena_matrix(std::forward<Expr>(expr));
return *this;
}

Expand Down
12 changes: 12 additions & 0 deletions stan/math/rev/core/var.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,18 @@ class var_value<T, internal::require_matrix_var_value<T>> {
});
}

/**
* Construct a `var_value` with premade @ref arena_matrix types.
* The values and adjoint matrices passed here will be shallow copied.
* A new `vari_type` is created from `val` and `adj` and stored in `vi_`.
* @tparam S type of the value and adjoint matrices; must satisfy
* `require_arena_matrix_t` and be assignable to `value_type`
* @tparam T_ defaults to `T`; not otherwise used in this signature
* @param val The value matrix to go into the vari
* @param adj The adjoint matrix to go into the vari
*/
template <typename S, typename T_ = T,
require_assignable_t<value_type, S>* = nullptr,
require_arena_matrix_t<S>* = nullptr>
var_value(const S& val, const S& adj) : vi_(new vari_type(val, adj)) {}

/**
* Construct a variable from a pointer to a variable implementation.
* @param vi A vari_value pointer.
Expand Down
32 changes: 23 additions & 9 deletions stan/math/rev/core/vari.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -821,17 +821,16 @@ class vari_value<T, require_eigen_sparse_base_t<T>> : public vari_base {
*/
static constexpr int ColsAtCompileTime = T::ColsAtCompileTime;

/**
* The value of this variable.
*/
arena_matrix<PlainObject> val_;
/**
* The adjoint of this variable, which is the partial derivative
* of this variable with respect to the root variable.
*/
arena_matrix<PlainObject> adj_;

/**
* The value of this variable.
*/
arena_matrix<PlainObject> val_;

/**
* Construct a variable implementation from a value. The
* adjoint is initialized to zero.
Expand All @@ -847,10 +846,21 @@ class vari_value<T, require_eigen_sparse_base_t<T>> : public vari_base {
* @param x Value of the constructed variable.
*/
template <typename S, require_convertible_t<S&, T>* = nullptr>
explicit vari_value(S&& x) : adj_(x), val_(std::forward<S>(x)) {
this->set_zero_adjoint();
explicit vari_value(S&& x)
: val_(std::forward<S>(x)),
adj_(val_.rows(), val_.cols(), val_.nonZeros(), val_.outerIndexPtr(),
val_.innerIndexPtr(),
arena_matrix<Eigen::VectorXd>(val_.nonZeros()).setZero().data(),
val_.innerNonZeroPtr()) {
ChainableStack::instance_->var_stack_.push_back(this);
}

/**
* Construct a sparse variable implementation from an existing value matrix
* and an existing adjoint matrix. `val_` is initialized from `val` and
* `adj_` from `adj`; the adjoint is not reset here. The new vari is pushed
* onto `ChainableStack::instance_->var_stack_`.
*
* @param val arena matrix used to initialize `val_`
* @param adj arena matrix used to initialize `adj_`
*/
vari_value(const arena_matrix<PlainObject>& val,
const arena_matrix<PlainObject>& adj)
: val_(val), adj_(adj) {
ChainableStack::instance_->var_stack_.push_back(this);
}

/**
* Construct a sparse Eigen variable implementation from a value. The
* adjoint is initialized to zero and if `stacked` is `false` this vari
Expand All @@ -869,8 +879,12 @@ class vari_value<T, require_eigen_sparse_base_t<T>> : public vari_base {
* that its `chain()` method is not called.
*/
template <typename S, require_convertible_t<S&, T>* = nullptr>
vari_value(S&& x, bool stacked) : adj_(x), val_(std::forward<S>(x)) {
this->set_zero_adjoint();
vari_value(S&& x, bool stacked)
: val_(std::forward<S>(x)),
adj_(val_.rows(), val_.cols(), val_.nonZeros(), val_.outerIndexPtr(),
val_.innerIndexPtr(),
arena_matrix<Eigen::VectorXd>(val_.nonZeros()).setZero().data(),
val_.innerNonZeroPtr()) {
if (stacked) {
ChainableStack::instance_->var_stack_.push_back(this);
} else {
Expand Down
1 change: 1 addition & 0 deletions stan/math/rev/fun.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,7 @@
#include <stan/math/rev/fun/tgamma.hpp>
#include <stan/math/rev/fun/to_var.hpp>
#include <stan/math/rev/fun/to_arena.hpp>
#include <stan/math/rev/fun/to_soa_sparse_matrix.hpp>
#include <stan/math/rev/fun/to_var_value.hpp>
#include <stan/math/rev/fun/to_vector.hpp>
#include <stan/math/rev/fun/trace.hpp>
Expand Down
Loading

0 comments on commit 7a9601d

Please sign in to comment.