
[WIP] Reverse Mode For Static Matrix Multiplication #1884

Closed
wants to merge 116 commits

Conversation

SteveBronder
Collaborator

@SteveBronder commented May 13, 2020

Summary

If all these WIPs are getting annoying I can close all the others since this one has everything in it.

This adds reverse mode matrix multiplication for static matrices. It uses a trick in the chain() method to call either the standard multiplication chain method or the matrix chain method; the chain_impl() function has a doc going over how this works. It also adds multiply_vari specializations for Arith * eigen_var vs. eigen_var * Arith.

This leaks memory right now because op_vari holds the matrix, which is not allocated on our stack. I've been having some trouble writing the code to allocate that memory with op_vari; if @t4c1 or @bbbales2 know some tuple magic to write that, it would be very appreciated! Another option is just to remove op_vari and template dv_vari, vd_vari, etc. to take in and allocate the memory for Eigen matrices. It's a code density vs. maintenance tradeoff: if we can find a nice solution then op_vari is fine, but if it's too confusing then it may be better to go back to the old op code. I have the start of the code for op_vari with stack-allocated memory for the Eigen matrices here.
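For context on the stack-allocation idea, here is a minimal sketch of one possible direction (this is not code from the PR; arena_copy is a hypothetical helper, and the exact arena calls are my assumption based on the existing stack_alloc API). The idea is to copy the Eigen value into memory owned by the reverse-mode arena and keep only a map to it in the vari, so it gets reclaimed along with the rest of the AD stack:

#include <stan/math/rev/core.hpp>
#include <Eigen/Dense>
#include <type_traits>

namespace example {
// Hypothetical helper: copy an Eigen matrix into memory owned by the
// reverse-mode arena so a vari can hold a view of it without leaking.
template <typename EigMat>
inline auto arena_copy(const EigMat& x) {
  using Scalar = typename std::decay_t<EigMat>::Scalar;
  // alloc_array<> hands back arena memory that is freed on recover_memory()
  Scalar* mem
      = stan::math::ChainableStack::instance_->memalloc_.alloc_array<Scalar>(x.size());
  Eigen::Map<Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic>> map(
      mem, x.rows(), x.cols());
  map = x;  // deep copy of the value into the arena
  return map;
}
}  // namespace example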

Tests

Right now I just have an informal test that checks whether static vs. dynamic matrices return the same adjoint calculations after calling .grad():

./runTests.py -j18 ./test/unit/math/rev/core/operator_multiplication_test.cpp
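For reference, a minimal sketch of the kind of comparison that test makes. It assumes the var_value<Eigen::MatrixXd> type from this branch, an adj() accessor on it, operator* between two of them, and a sum() that accepts the result; those pieces are assumptions about the WIP API rather than settled names:

#include <stan/math/rev.hpp>
#include <gtest/gtest.h>

TEST(AgradRev, static_vs_dynamic_multiply_adjoints) {
  Eigen::MatrixXd a_val = Eigen::MatrixXd::Random(3, 3);
  Eigen::MatrixXd b_val = Eigen::MatrixXd::Random(3, 3);

  // "Dynamic" matrices: every element is its own var.
  Eigen::Matrix<stan::math::var, -1, -1> a_dyn = a_val.cast<stan::math::var>();
  Eigen::Matrix<stan::math::var, -1, -1> b_dyn = b_val.cast<stan::math::var>();
  stan::math::var lp_dyn = stan::math::sum(stan::math::multiply(a_dyn, b_dyn));
  lp_dyn.grad();

  // Record the element-wise adjoints before running the second gradient pass.
  Eigen::MatrixXd a_dyn_adj(3, 3), b_dyn_adj(3, 3);
  for (Eigen::Index j = 0; j < 3; ++j) {
    for (Eigen::Index i = 0; i < 3; ++i) {
      a_dyn_adj(i, j) = a_dyn(i, j).adj();
      b_dyn_adj(i, j) = b_dyn(i, j).adj();
    }
  }
  stan::math::set_zero_all_adjoints();

  // "Static" matrices: one vari holds the whole value and adjoint matrix.
  // var_value<Eigen::MatrixXd>, operator*, sum(), and adj() here are the
  // WIP API from this branch (assumed), not released stan-math names.
  stan::math::var_value<Eigen::MatrixXd> a_stat(a_val);
  stan::math::var_value<Eigen::MatrixXd> b_stat(b_val);
  auto lp_stat = stan::math::sum(a_stat * b_stat);
  lp_stat.grad();

  // Both representations should produce the same adjoints.
  for (Eigen::Index j = 0; j < 3; ++j) {
    for (Eigen::Index i = 0; i < 3; ++i) {
      EXPECT_FLOAT_EQ(a_dyn_adj(i, j), a_stat.adj()(i, j));
      EXPECT_FLOAT_EQ(b_dyn_adj(i, j), b_stat.adj()(i, j));
    }
  }
}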

Side Effects

Release notes

Checklist

  • Math issue #21

  • Copyright holder: Steve Bronder

    The copyright holder is typically you or your assignee, such as a university or company. By submitting this pull request, the copyright holder is agreeing to license the submitted work under the following licenses:
    - Code: BSD 3-clause (https://opensource.org/licenses/BSD-3-Clause)
    - Documentation: CC-BY 4.0 (https://creativecommons.org/licenses/by/4.0/)

  • the basic tests are passing

    • unit tests pass (to run, use: ./runTests.py test/unit)
    • header checks pass (make test-headers)
    • dependencies checks pass (make test-math-dependencies)
    • docs build (make doxygen)
    • code passes the built in C++ standards checks (make cpplint)
  • the code is written in idiomatic C++ and changes are documented in the doxygen

  • the new changes are tested

SteveBronder and others added 30 commits May 8, 2020 16:56
@bbbales2
Member

bbbales2 commented Jun 6, 2020

@SteveBronder Yo, my hope in merging feature/vari-base-templates into this and pointing the pull there was that it would simplify the diff. There are still 78 things there, including https://github.com/stan-dev/math/pull/1884/files#diff-320e0518dfbfe51abbc870b3ae852b13, which doesn't look like part of either pull.

Any advice on what I should be doing to get these things more in sync?

@bbbales2
Member

bbbales2 commented Jun 6, 2020

Hmm, and now there are all these merge conflicts! Not sure I'm doing my gits correctly.

@SteveBronder
Collaborator Author

tbh this is kind of why I wanted to do this in pieces. There are enough dramatic changes here that when a bunch of stuff at one level changes it causes a bunch of conflicts in the larger branch. I think we should focus efforts on #1915 so that we can then start the Eigen var PR and then the adj_jac_apply PR.

@bbbales2
Member

bbbales2 commented Jun 6, 2020

Well I just know I'm not gonna understand #1915 until I know how it filters up to matrices and stuff.

You think there might be a difference between rebase and merge here?

@bbbales2
Member

bbbales2 commented Jun 6, 2020

Eh I'll just try it lol.

@SteveBronder
Collaborator Author

Lemme look at this rq

@SteveBronder
Collaborator Author

Huh seems fine, tbh just pulled it down and git merge feature/vari-base-templats

@bbbales2
Member

bbbales2 commented Jun 6, 2020

Huh seems fine, tbh just pulled it down and git merge feature/vari-base-templats

I wonder if I was merging in my local version of your branch or something? But it seems like that wouldn't have given me conflicts twice. Oh well.

@SteveBronder
Collaborator Author

¯\_(ツ)_/¯

@SteveBronder
Collaborator Author

@bbbales2 Moving this convo over here from #1915: I think we could do something like the below, where T1 and T2 are var_values. But the problem is that op_vari needs to know what the adj_ type should be, so we need to deduce that, which requires a lot of weird-looking duplicated code. IMO it's more confusing than having the boilerplate in the function.

/**
 * Deduces the return type for matrix multiplication of two types
 */
template <typename T1, typename T2, typename = void>
struct mat_mul_return_type {};

// arithmetic is just double
template <typename T1, typename T2>
struct mat_mul_return_type<T1, T2, require_all_arithmetic_t<T1, T2>> {
  using type = double;
};

template <typename T1, typename T2>
struct mat_mul_return_type<T1, T2, require_any_eigen_t<T1, T2>> {
  using type = decltype((std::declval<T1>() * std::declval<T2>()).eval());
};

// helper alias
template <typename T1, typename T2>
using mat_mul_return_t = typename mat_mul_return_type<T1, T2>::type;

template <typename T1, typename T2>
class multiply_vari<T1, T2, require_all_var_t<T1, T2>>
    final : public op_vari<mat_mul_return_t<value_type_t<T1>, value_type_t<T2>>,
                           vari_type_t<T1>*, vari_type_t<T2>*> {
 public:
  using lhs_type = vari_type_t<T1>;
  using rhs_type = vari_type_t<T2>;
  using return_t = mat_mul_return_t<value_type_t<T1>, value_type_t<T2>>;

 private:
  // accessors for the operand varis provided by the op_vari base
  using op_vari<return_t, lhs_type*, rhs_type*>::avi;
  using op_vari<return_t, lhs_type*, rhs_type*>::bvi;

 public:
  multiply_vari(lhs_type* avi, rhs_type* bvi)
      : op_vari<return_t, lhs_type*, rhs_type*>(avi->val_ * bvi->val_, avi, bvi) {}

  template <typename TT1 = T1, typename TT2 = T2,
            require_all_var_vt<std::is_arithmetic, TT1, TT2>* = nullptr>
  inline void chain_impl() {
    avi()->adj_ += bvi()->val_ * this->adj_;
    bvi()->adj_ += avi()->val_ * this->adj_;
  }

  template <typename TT1 = T1, typename TT2 = T2,
            require_all_var_vt<is_eigen, TT1, TT2>* = nullptr>
  inline void chain_impl() {
    avi()->adj_ += this->adj_ * bvi()->val_.transpose();
    bvi()->adj_ += avi()->val_.transpose() * this->adj_;
  }

  void chain() {
    if (unlikely(is_any_nan(avi()->val_, bvi()->val_))) {
      fill(avi()->adj_, NOT_A_NUMBER);
      fill(bvi()->adj_, NOT_A_NUMBER);
    } else {
      chain_impl();
    }
  }
};

template <typename T1, typename T2, require_all_var_t<T1, T2>* = nullptr>
inline auto operator*(const T1& a, const T2& b) {
  using multiply_type = internal::multiply_vari<T1, T2>;
  // grab the return type deduced inside multiply_vari
  using mat_return = typename multiply_type::return_t;
  return var_value<mat_return>(new multiply_type(a.vi_, b.vi_));
}

We could try to simplify this with something like a default template with a VarOpTraits class that stores the return type, value_type, and vari_type of everything, but I'm really not sure that's going to clean things up in a meaningful way.
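If we went that route, a rough sketch of the idea (VarOpTraits is hypothetical; it just reuses the mat_mul_return_t, value_type_t, and vari_type_t aliases from the snippet above):

// Hypothetical: collect the deduced types for a pair of operands in one place
// so multiply_vari (and friends) only ever have to name VarOpTraits<T1, T2>.
template <typename T1, typename T2>
struct VarOpTraits {
  using lhs_value_t = value_type_t<T1>;
  using rhs_value_t = value_type_t<T2>;
  using lhs_vari_t = vari_type_t<T1>;
  using rhs_vari_t = vari_type_t<T2>;
  using mul_return_t = mat_mul_return_t<lhs_value_t, rhs_value_t>;
  using op_vari_t = op_vari<mul_return_t, lhs_vari_t*, rhs_vari_t*>;
};

multiply_vari could then inherit from typename VarOpTraits<T1, T2>::op_vari_t and drop most of its usings, though each new operation would still need its own *_return_t trait.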

@SteveBronder mentioned this pull request Jun 7, 2020
@SteveBronder
Collaborator Author

From @bbbales2's comment here, with regard to:

template <typename T1, typename T2, require_all_var_t<T1, T2>* = nullptr>
inline auto operator*(const T1& a, const T2& b) {
  using multiply_type = internal::multiply_vari<T1, T2>;
  // grab the return type deduced inside multiply_vari
  using mat_return = typename multiply_type::return_t;
  return var_value<mat_return>(new multiply_type(a.vi_, b.vi_));
}
   var_value<mat_return>
   internal::multiply_vari<T1, T2>

Both of these are constructors so you gotta pass the template arguments.

We can deduce the constructor's parameters from T1 and T2

It'd be nice to not do that, cause the types should be deducible in the arguments. Presumably we have to define this somewhere, but doing it in every function would not be great.

Yeah, we need the boilerplate somewhere; IMO it's simpler to have it in the function than in the class (though it is annoying). One thing we could do is put the type traits stuff into op_vari. Then multiply etc. could look like:

template <typename T1, typename T2>
class multiply_vari<T1, T2, require_all_var_t<T1, T2>>
    final : public op_vari<mat_mul_return_t<T1, T2>, T1, T2> {
  using return_t = mat_mul_return_t<T1, T2>;
  using op_vari<return_t, T1, T2>::avi;
  using op_vari<return_t, T1, T2>::bvi;
  using lhs_vari = vari_type_t<T1>;
  using rhs_vari = vari_type_t<T2>;

 public:
  multiply_vari(lhs_vari* avi, rhs_vari* bvi)
      : op_vari<return_t, T1, T2>(avi->val_ * bvi->val_, avi, bvi) {}
  // yada yada

Then op_vari would do all of the deduction stuff.
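A rough sketch of what an op_vari that owns the deduction could look like (this assumes a vari_value<T> base class and the vari_type_t trait from this branch; none of the names here are final):

// Hypothetical: op_vari takes the var operand types directly and deduces the
// vari pointer types itself, so derived classes never have to spell them out.
template <typename ReturnT, typename T1, typename T2>
class op_vari : public vari_value<ReturnT> {
 public:
  using lhs_vari_t = vari_type_t<T1>;
  using rhs_vari_t = vari_type_t<T2>;

  op_vari(const ReturnT& val, lhs_vari_t* avi, rhs_vari_t* bvi)
      : vari_value<ReturnT>(val), avi_(avi), bvi_(bvi) {}

  lhs_vari_t*& avi() { return avi_; }
  rhs_vari_t*& bvi() { return bvi_; }

 protected:
  lhs_vari_t* avi_;
  rhs_vari_t* bvi_;
};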

This is connected to the multiple definitions of operator* and the multiple multiply_vari types.

I at least want fewer usings here.

Yeah, it would be good to have fewer; the question is just where they go.

@SteveBronder reopened this Jun 7, 2020
@SteveBronder
Collaborator Author

clicked close by accident

@bbbales2
Member

bbbales2 commented Jun 7, 2020

Yeah, it would be good to have fewer; the question is just where they go.

Yes, and how much of this will need to be repeated in other places.

@SteveBronder
Collaborator Author

I got something almost working here.

https://github.com/stan-dev/math/tree/feature/eigen-vari-ops-vari-deduction

I'm doing something dumb with inheritance, though, and getting an error that the compiler doesn't understand it can use the var_value(vari_value<T>* vi) constructor for multiply_vari.

@SteveBronder
Collaborator Author

^That works now. Some parts are a little icky, but I'd clean that up once we're at this stage.
