Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Aliasing issue in OpenCL #2943

Merged
merged 11 commits into from
Sep 28, 2023
Next Next commit
Fixes adjoint accumulation for reverse mode where aliasing can occur.…
… Creates a assignment op tag that is used by adjoint_results to do a += instead of a = into the adjoint matrix
  • Loading branch information
SteveBronder committed Sep 12, 2023
commit fdb6d34d0800e841de71215622cc59dae76671c1
2 changes: 1 addition & 1 deletion stan/math/opencl/kernel_generator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@
#include <stan/math/opencl/kernel_generator/as_operation_cl.hpp>
#include <stan/math/opencl/kernel_generator/name_generator.hpp>
#include <stan/math/opencl/kernel_generator/type_str.hpp>

#include <stan/math/opencl/kernel_generator/assignment_ops.hpp>
#include <stan/math/opencl/kernel_generator/as_column_vector_or_scalar.hpp>
#include <stan/math/opencl/kernel_generator/load.hpp>
#include <stan/math/opencl/kernel_generator/scalar.hpp>
Expand Down
18 changes: 10 additions & 8 deletions stan/math/opencl/kernel_generator/as_operation_cl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#define STAN_MATH_OPENCL_KERNEL_GENERATOR_AS_OPERATION_CL_HPP
#ifdef STAN_OPENCL

#include <stan/math/opencl/kernel_generator/assignment_ops.hpp>
#include <stan/math/opencl/kernel_generator/operation_cl.hpp>
#include <stan/math/opencl/kernel_generator/load.hpp>
#include <stan/math/opencl/kernel_generator/scalar.hpp>
Expand All @@ -23,7 +24,7 @@ namespace math {
* @param a an operation
* @return operation
*/
template <typename T_operation,
template <assignment_ops_cl AssignOp = assignment_ops_cl::equals, typename T_operation,
typename = std::enable_if_t<std::is_base_of<
operation_cl_base, std::remove_reference_t<T_operation>>::value>>
inline T_operation&& as_operation_cl(T_operation&& a) {
Expand All @@ -37,7 +38,7 @@ inline T_operation&& as_operation_cl(T_operation&& a) {
* @param a scalar
* @return \c scalar_ wrapping the input
*/
template <typename T_scalar, typename = require_arithmetic_t<T_scalar>,
template <assignment_ops_cl AssignOp = assignment_ops_cl::equals, typename T_scalar, typename = require_arithmetic_t<T_scalar>,
require_not_same_t<T_scalar, bool>* = nullptr>
inline scalar_<T_scalar> as_operation_cl(const T_scalar a) {
return scalar_<T_scalar>(a);
Expand All @@ -50,6 +51,7 @@ inline scalar_<T_scalar> as_operation_cl(const T_scalar a) {
* @param a scalar
* @return \c scalar_<char> wrapping the input
*/
template <assignment_ops_cl AssignOp = assignment_ops_cl::equals>
inline scalar_<char> as_operation_cl(const bool a) { return scalar_<char>(a); }

/**
Expand All @@ -59,11 +61,11 @@ inline scalar_<char> as_operation_cl(const bool a) { return scalar_<char>(a); }
* @param a \c matrix_cl
* @return \c load_ wrapping the input
*/
template <typename T_matrix_cl,
template <assignment_ops_cl AssignOp = assignment_ops_cl::equals, typename T_matrix_cl,
typename = require_any_t<is_matrix_cl<T_matrix_cl>,
is_arena_matrix_cl<T_matrix_cl>>>
inline load_<T_matrix_cl> as_operation_cl(T_matrix_cl&& a) {
return load_<T_matrix_cl>(std::forward<T_matrix_cl>(a));
inline load_<T_matrix_cl, AssignOp> as_operation_cl(T_matrix_cl&& a) {
return load_<T_matrix_cl, AssignOp>(std::forward<T_matrix_cl>(a));
}

/**
Expand All @@ -74,11 +76,11 @@ inline load_<T_matrix_cl> as_operation_cl(T_matrix_cl&& a) {
* rvalue reference, the reference is removed, so that a variable of this type
* actually stores the value.
*/
template <typename T>
template <typename T, assignment_ops_cl AssignOp = assignment_ops_cl::equals>
using as_operation_cl_t = std::conditional_t<
std::is_lvalue_reference<T>::value,
decltype(as_operation_cl(std::declval<T>())),
std::remove_reference_t<decltype(as_operation_cl(std::declval<T>()))>>;
decltype(as_operation_cl<AssignOp>(std::declval<T>())),
std::remove_reference_t<decltype(as_operation_cl<AssignOp>(std::declval<T>()))>>;

/** @}*/
} // namespace math
Expand Down
71 changes: 71 additions & 0 deletions stan/math/opencl/kernel_generator/assignment_ops.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
#ifndef STAN_MATH_OPENCL_KERNEL_GENERATOR_ASSIGNMENT_OPS
#define STAN_MATH_OPENCL_KERNEL_GENERATOR_ASSIGNMENT_OPS
#ifdef STAN_OPENCL
#include <stan/math/prim/meta/is_detected.hpp>

namespace stan {
namespace math {

/**
* Ops that decide the type of assignment for LHS operations
*/
enum class assignment_ops_cl {equals, plus_equals, minus_equals, divide_equals};

/**
* @param value A static constexpr const char* member for printing assignment ops
*/
template <assignment_ops_cl assign_op>
struct assignment_op_str;

template <>
struct assignment_op_str<assignment_ops_cl::equals> {
static constexpr const char* value = " = ";
};

template <>
struct assignment_op_str<assignment_ops_cl::plus_equals> {
static constexpr const char* value = " += ";
};

template <>
struct assignment_op_str<assignment_ops_cl::minus_equals> {
static constexpr const char* value = " *= ";
};

template <>
struct assignment_op_str<assignment_ops_cl::divide_equals> {
static constexpr const char* value = " /= ";
};


namespace internal {
template <typename, typename = void>
struct has_assignment_op_str : std::false_type {};

template <typename T>
struct has_assignment_op_str<T, void_t<decltype(T::assignment_op)>> : std::true_type {};

} // namespace internal

/**
* @tparam T A type that does not have an `assignment_op` static constexpr member type
* @return A constexpr const char* equal to `" = "`
*/
template <typename T, std::enable_if_t<!internal::has_assignment_op_str<std::decay_t<T>>::value>* = nullptr>
inline constexpr const char* assignment_op() noexcept {
return " = ";
}

/**
* @tparam T A type that has an `assignment_op` static constexpr member type
* @return The types assignment op as a constexpr const char*
*/
template <typename T, std::enable_if_t<internal::has_assignment_op_str<T>::value>* = nullptr>
inline constexpr const char* assignment_op() noexcept {
return assignment_op_str<std::decay_t<T>::assignment_op>::value;
}

}
}
#endif
#endif
18 changes: 12 additions & 6 deletions stan/math/opencl/kernel_generator/load.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

#include <stan/math/opencl/matrix_cl.hpp>
#include <stan/math/opencl/matrix_cl_view.hpp>
#include <stan/math/opencl/kernel_generator/assignment_ops.hpp>

#include <stan/math/opencl/kernel_generator/type_str.hpp>
#include <stan/math/opencl/kernel_generator/name_generator.hpp>
#include <stan/math/opencl/kernel_generator/operation_cl.hpp>
Expand All @@ -23,17 +25,20 @@ namespace math {
/**
* Represents an access to a \c matrix_cl in kernel generator expressions
* @tparam T \c matrix_cl
* @tparam AssignOp tells higher level operations whether the final operation should be an assignment or a type of compound assignment.
*/
template <typename T>
template <typename T, assignment_ops_cl AssignOp = assignment_ops_cl::equals>
class load_
: public operation_cl_lhs<load_<T>,
: public operation_cl_lhs<load_<T, AssignOp>,
typename std::remove_reference_t<T>::type> {
protected:
T a_;

public:

static constexpr assignment_ops_cl assignment_op = AssignOp;
using Scalar = typename std::remove_reference_t<T>::type;
using base = operation_cl<load_<T>, Scalar>;
using base = operation_cl<load_<T, AssignOp>, Scalar>;
using base::var_name_;
static_assert(disjunction<is_matrix_cl<T>, is_arena_matrix_cl<T>>::value,
"load_: argument a must be a matrix_cl<T>!");
Expand All @@ -51,9 +56,9 @@ class load_
* Creates a deep copy of this expression.
* @return copy of \c *this
*/
inline load_<T&> deep_copy() & { return load_<T&>(a_); }
inline load_<const T&> deep_copy() const& { return load_<const T&>(a_); }
inline load_<T> deep_copy() && { return load_<T>(std::forward<T>(a_)); }
inline load_<T&, AssignOp> deep_copy() & { return load_<T&, AssignOp>(a_); }
inline load_<const T&, AssignOp> deep_copy() const& { return load_<const T&, AssignOp>(a_); }
inline load_<T, AssignOp> deep_copy() && { return load_<T, AssignOp>(std::forward<T>(a_)); }

/**
* Generates kernel code for this expression.
Expand Down Expand Up @@ -327,6 +332,7 @@ class load_
}
}
};

/** @}*/
} // namespace math
} // namespace stan
Expand Down
40 changes: 31 additions & 9 deletions stan/math/opencl/kernel_generator/multi_result_kernel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include <stan/math/prim/err/check_size_match.hpp>
#include <stan/math/prim/meta/is_kernel_expression.hpp>
#include <stan/math/opencl/kernel_generator/name_generator.hpp>
#include <stan/math/opencl/kernel_generator/assignment_ops.hpp>
#include <stan/math/opencl/kernel_generator/as_operation_cl.hpp>
#include <stan/math/opencl/kernel_generator/calc_if.hpp>
#include <stan/math/opencl/kernel_generator/check_cl.hpp>
Expand Down Expand Up @@ -334,13 +335,34 @@ class results_cl {
== sizeof...(T_expressions)>>
void operator+=(const expressions_cl<T_expressions...>& exprs) {
index_apply<sizeof...(T_expressions)>([this, &exprs](auto... Is) {
auto tmp = std::tuple_cat(make_assignment_pair(
auto tmp = std::tuple_cat(make_assignment_pair<assignment_ops_cl::plus_equals>(
std::get<Is>(results_), std::get<Is>(exprs.expressions_))...);
index_apply<std::tuple_size<decltype(tmp)>::value>(
[this, &tmp](auto... Is2) {
assignment_impl(std::make_tuple(std::make_pair(
std::get<Is2>(tmp).first,
std::get<Is2>(tmp).first + std::get<Is2>(tmp).second)...));
std::get<Is2>(tmp).first, std::get<Is2>(tmp).second)...));
});
});
}

/**
* Incrementing \c results_ object by \c expressions_cl object
* executes the kernel that evaluates expressions and increments results by
* those expressions.
* @tparam T_expressions types of expressions
* @param exprs expressions
*/
template <typename... T_expressions,
typename = std::enable_if_t<sizeof...(T_results)
== sizeof...(T_expressions)>>
void operator-=(const expressions_cl<T_expressions...>& exprs) {
index_apply<sizeof...(T_expressions)>([this, &exprs](auto... Is) {
auto tmp = std::tuple_cat(make_assignment_pair<assignment_ops_cl::minus_equals>(
std::get<Is>(results_), std::get<Is>(exprs.expressions_))...);
index_apply<std::tuple_size<decltype(tmp)>::value>(
[this, &tmp](auto... Is2) {
assignment_impl(std::make_tuple(std::make_pair(
std::get<Is2>(tmp).first, std::get<Is2>(tmp).second)...));
});
});
}
Expand Down Expand Up @@ -426,7 +448,7 @@ class results_cl {
+ parts.reduction_2d +
"}\n";
}
return src;
return src;
}

/**
Expand Down Expand Up @@ -529,16 +551,16 @@ class results_cl {
* @param expression expression
* @return a tuple of pair of result and expression
*/
template <typename T_result, typename T_expression,
template <assignment_ops_cl AssignOp = assignment_ops_cl::equals, typename T_result, typename T_expression,
require_all_not_t<is_without_output<T_expression>,
conjunction<internal::is_scalar_check<T_result>,
std::is_arithmetic<std::decay_t<
T_expression>>>>* = nullptr>
static auto make_assignment_pair(T_result&& result,
T_expression&& expression) {
return std::make_tuple(
std::pair<as_operation_cl_t<T_result>, as_operation_cl_t<T_expression>>(
as_operation_cl(std::forward<T_result>(result)),
std::pair<as_operation_cl_t<T_result, AssignOp>, as_operation_cl_t<T_expression>>(
as_operation_cl<AssignOp>(std::forward<T_result>(result)),
as_operation_cl(std::forward<T_expression>(expression))));
}

Expand All @@ -548,7 +570,7 @@ class results_cl {
* @param expression expression
* @return a tuple of pair of result and expression
*/
template <typename T_result, typename T_expression,
template <assignment_ops_cl AssignOp = assignment_ops_cl::equals, typename T_result, typename T_expression,
require_t<is_without_output<T_expression>>* = nullptr>
static auto make_assignment_pair(T_result&& result,
T_expression&& expression) {
Expand All @@ -562,7 +584,7 @@ class results_cl {
* @param pass bool scalar
* @return an empty tuple
*/
template <typename T_check, typename T_pass,
template <assignment_ops_cl AssignOp = assignment_ops_cl::equals, typename T_check, typename T_pass,
require_t<internal::is_scalar_check<T_check>>* = nullptr,
require_integral_t<T_pass>* = nullptr>
static std::tuple<> make_assignment_pair(T_check&& result, T_pass&& pass) {
Expand Down
21 changes: 20 additions & 1 deletion stan/math/opencl/kernel_generator/operation_cl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

#include <stan/math/prim/meta.hpp>
#include <stan/math/prim/err/check_nonnegative.hpp>
#include <stan/math/opencl/kernel_generator/assignment_ops.hpp>
#include <stan/math/opencl/kernel_generator/type_str.hpp>
#include <stan/math/opencl/kernel_generator/name_generator.hpp>
#include <stan/math/opencl/matrix_cl_view.hpp>
Expand Down Expand Up @@ -74,6 +75,24 @@ struct kernel_parts {
}
};

std::ostream& operator<<(std::ostream& os, kernel_parts& parts) {
os << "args:" << std::endl;
os << parts.args.substr(0, parts.args.size() - 2) << std::endl;
os << "Decl:" << std::endl;
os << parts.declarations << std::endl;
os << "Init:" << std::endl;
os << parts.initialization << std::endl;
os << "body:" << std::endl;
os << parts.body << std::endl;
os << "body_suffix:" << std::endl;
os << parts.body_suffix << std::endl;
os << "reduction_1d:" << std::endl;
os << parts.reduction_1d << std::endl;
os << "reduction_2d:" << std::endl;
os << parts.reduction_2d << std::endl;
return os;
}

/**
* Base for all kernel generator operations.
* @tparam Derived derived type
Expand Down Expand Up @@ -201,7 +220,7 @@ class operation_cl : public operation_cl_base {
generated, generated_all, ng, row_index_name, col_index_name, false);
kernel_parts out_parts = result.get_kernel_parts_lhs(
generated, generated_all, ng, row_index_name, col_index_name);
out_parts.body += " = " + derived().var_name_ + ";\n";
out_parts.body += assignment_op<T_result>() + derived().var_name_ + ";\n";
parts += out_parts;
return parts;
}
Expand Down
9 changes: 4 additions & 5 deletions stan/math/opencl/prim/normal_lccdf.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,13 +82,12 @@ return_type_t<T_y_cl, T_loc_cl, T_scale_cl> normal_lccdf(
matrix_cl<double> mu_deriv_cl;
matrix_cl<double> sigma_deriv_cl;

results(check_y_not_nan, check_mu_finite, check_sigma_positive, lccdf_cl,
y_deriv_cl, mu_deriv_cl, sigma_deriv_cl)
= expressions(y_not_nan_expr, mu_finite_expr, sigma_positive_expr,
lccdf_expr, calc_if<!is_constant<T_y_cl>::value>(y_deriv),
results(check_y_not_nan, check_mu_finite, check_sigma_positive)
= expressions(y_not_nan_expr, mu_finite_expr, sigma_positive_expr);
results(lccdf_cl, y_deriv_cl, mu_deriv_cl, sigma_deriv_cl)
= expressions(lccdf_expr, calc_if<!is_constant<T_y_cl>::value>(y_deriv),
calc_if<!is_constant<T_loc_cl>::value>(mu_deriv),
calc_if<!is_constant<T_scale_cl>::value>(sigma_deriv));

T_partials_return lccdf = LOG_HALF + sum(from_matrix_cl(lccdf_cl));

auto ops_partials = make_partials_propagator(y_col, mu_col, sigma_col);
Expand Down
1 change: 1 addition & 0 deletions stan/math/opencl/rev.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
#include <stan/math/opencl/rev/fmax.hpp>
#include <stan/math/opencl/rev/fmin.hpp>
#include <stan/math/opencl/rev/fmod.hpp>
#include <stan/math/opencl/rev/grad.hpp>
#include <stan/math/opencl/rev/hypot.hpp>
#include <stan/math/opencl/rev/inv.hpp>
#include <stan/math/opencl/rev/inv_cloglog.hpp>
Expand Down
Loading