-
-
Notifications
You must be signed in to change notification settings - Fork 183
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
operator+ for var matrices and matrix of vars #2115
Conversation
I have a branch below that does this same thing for subtract as well. I'll post some speed tests tonight https://github.com/stan-dev/math/tree/feature/varmat-operatorminus |
…4.1 (tags/RELEASE_600/final)
…4.1 (tags/RELEASE_600/final)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This all looks great! Couple minor changes in there.
I might be a bit of a pain though. I tried this reverse_pass_callback
approach with the unary functions and performance got worse, since multiple passes over the same inputs to pull out values and adjoints 'cost' more than was gained by the callback. I wonder if this will have the same problem?
Would you mind benchmarking these against against an apply_scalar_binary
implementation?
Something like:
template <typename T1, typename T2, require_any_container_t<T1, T2>* = nullptr>
inline auto operator+(const T1& a, const T2& b) {
return apply_scalar_binary(
a, b, [&](const auto& c, const auto& d) { return add(c, d); });
}
Where there are only scalar overloads for add
(i.e., comment out the container specialisations or something).
Let me know if you don't have the time and I can work them up over the weekend as well.
stan/math/prim/err/is_equal.hpp
Outdated
|
||
/** | ||
* Return <code>true</code> if <code>y</code> is less or equal to | ||
* <code>high</code>. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The doc and variable names don't quite line up with the function - since it's testing for equality rather than less or equal
stan/math/prim/err/is_equal.hpp
Outdated
return (y.array() == x).all(); | ||
} | ||
|
||
template <typename T_y, typename T_high, require_std_vector_t<T_y>* = nullptr> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Will need to add doc for these other specialisations as well
stan/math/prim/err/is_equal.hpp
Outdated
* to low and if and element of y or high is NaN | ||
*/ | ||
template <typename T_y, typename T_high, require_eigen_t<T_y>* = nullptr> | ||
inline bool is_equal(const T_y& y, const T_high& x) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You'll probably also need to add is_equal(container, container)
specialisations, as well as is_equal(scalar, container)
.
At least for the second one you can just call the existing specialisations with the arguments reversed, so not all bad
stan/math/prim/err/is_equal.hpp
Outdated
} | ||
|
||
template <typename T_y, typename T_high, require_std_vector_t<T_y>* = nullptr> | ||
inline bool is_equal(const T_y& y, const T_high& x) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could always just have the Eigen specialisations defined and then Map
any std::vectors and pass those to the Eigen versions. Will help cut down on the code for the other combinations
stan/math/prim/fun/is_nan.hpp
Outdated
@@ -20,6 +20,11 @@ inline bool is_nan(T x) { | |||
return std::isnan(x); | |||
} | |||
|
|||
template <typename T, typename = require_eigen_t<T>> | |||
inline bool is_nan(const T& x) { | |||
return Eigen::isnan(x.array()).any(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Eigen has a neat little member function just for this:
return Eigen::isnan(x.array()).any(); | |
return x.hasNaN(); |
*/ | ||
template <typename Var, typename EigMat, | ||
require_eigen_vt<std::is_arithmetic, EigMat>* = nullptr, | ||
require_var_vt<std::is_arithmetic, Var>* = nullptr> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not related to this pull but I didn't realise we had this require_var_vt
, that will simplify some of the wacky templating I've been resorting to. These require generics have been such a good addition!
return {new internal::add_vv_vari(a.vi_, b.vi_)}; | ||
var ret(a.val() + b.val()); | ||
if (unlikely(is_any_nan(a.val(), b.val()))) { | ||
reverse_pass_callback([a, b]() mutable { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think special handling of nans benefits either correctness of performance. The same propagation of nans happens in general branch as well. Maybe benchmark it?
Just realised that my |
Jenkins Console Log Machine informationProductName: Mac OS X ProductVersion: 10.11.6 BuildVersion: 15G22010CPU: G++: Clang: |
If anyone has a minute this should be ready for review! |
Sorry about the delay, will take a look at this today |
Jenkins Console Log Machine informationProductName: Mac OS X ProductVersion: 10.11.6 BuildVersion: 15G22010CPU: G++: Clang: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Couple of very minor queries, but otherwise looks great! Thanks for going through the apply_scalar_binary
benchmarking rigmarole.
stan/math/prim/eigen_plugins.h
Outdated
template<typename T = Scalar> | ||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE | ||
std::enable_if_t<std::is_pointer<T>::value, reverse_return_t<T>> | ||
coeffRef(T &v) { return v->adj_; } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this needed? Shouldn't mat.adj().coeffRef(i,j)
'just work'?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh hm maybe I did something stupid let me look at this
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Totally unnecessary, deleted!
@SteveBronder yoyo Steve can you do the polishing on this? I want my |
Yes! Trying to finish up the slicing stuff today but I will clean up based on the review comments (thanks @andrjohns !) tmrw |
Gogogogogo (edit: that's my generic cheer squad) |
Ooof yeah sorry I'll try to fix this up tonight or tomorrow. The assign and subset tests are way more than I thought they were |
The cheer squad appreciates the effort |
+1 Cheer Let me know when I can take a gander at this again |
@andrjohns should be ready to rock! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
All good here! Woohoo!
@andrjohns can you reclick the approve thing? Something happened in jenkins and I had to kick this off again |
@serban-nicusor-toptal is there something up with jenkins? The normal button to view jenkins stuff is gone from the pr and this pr has been waiting since Friday. Is there a backlog of PRs? |
I had to reboot Jenkins and it may have corrupted the job url here. This is the last build that ran, I've also restarted it, button should be back too. |
cool much appreciated! |
Jenkins Console Log Machine informationProductName: Mac OS X ProductVersion: 10.11.6 BuildVersion: 15G22010CPU: G++: Clang: |
Summary
This adds
add()
andoperator+
for var matrices and adds overloaded functions foradd()
usingreverse_pass_callback()
for matrices of vars. I can add subtraction here as well as the code will look nearly the sameTests
Tests were added for
operator+
for all the mixed types it can take in. The tests useadd()
because for var typesadd()
now has a specialization for var types that divert tooperator+
Side Effects
None
Release notes
Adds addition overloads for var matrices and matrices of vars
Checklist
Math issue Make functions with custom autodiff
var<mat>
friendly #2101Copyright holder: (fill in copyright holder information)
The copyright holder is typically you or your assignee, such as a university or company. By submitting this pull request, the copyright holder is agreeing to the license the submitted work under the following licenses:
- Code: BSD 3-clause (https://opensource.org/licenses/BSD-3-Clause)
- Documentation: CC-BY 4.0 (https://creativecommons.org/licenses/by/4.0/)
the basic tests are passing
./runTests.py test/unit
)make test-headers
)make test-math-dependencies
)make doxygen
)make cpplint
)the code is written in idiomatic C++ and changes are documented in the doxygen
the new changes are tested