
hmm_marginal_lpdf for discrete latent states #1778

Merged: 84 commits, merged May 4, 2020

Commits
829c641
initial draft code.
charlesm93 Jan 17, 2020
cc777f4
intermediate vari class.
charlesm93 Jan 17, 2020
c08c103
prototype functions.
charlesm93 Jan 22, 2020
f157f1d
debug to get code to compile.
charlesm93 Jan 22, 2020
eeaa18c
simulate data for unit tests.
charlesm93 Jan 24, 2020
519eb32
Test and compile dbl case.
charlesm93 Jan 24, 2020
6e5abcd
Get var version to compile.
charlesm93 Jan 24, 2020
f9f5e82
corrections, but still broken.
charlesm93 Jan 25, 2020
e895f9a
Prototype ad unit test.
charlesm93 Jan 26, 2020
5d9767b
Wrap up diff unit test.
charlesm93 Jan 28, 2020
a85bfa0
Remove superfluous comments and improve descriptions."
charlesm93 Jan 28, 2020
e3e7ae8
Revise unit tests to make finite diff tests not break simplex.
charlesm93 Jan 28, 2020
b9a12e7
Add and test exceptions.
charlesm93 Jan 28, 2020
66bf591
Add unit test for log density evaluation.
charlesm93 Feb 29, 2020
9e2722b
Update Gamma such that the rows are now simplexes.
charlesm93 Feb 29, 2020
b055197
Adjust error message for new Gamma configuration.
charlesm93 Feb 29, 2020
04a02f8
Remove superfluous code in unit test.
charlesm93 Feb 29, 2020
2a54f2b
Fix and test case where we have 0 transitions.
charlesm93 Mar 12, 2020
2ddcb6f
Refactor unit test.
charlesm93 Mar 13, 2020
06527a5
Unit test for the edge case where we only have one state.
charlesm93 Mar 13, 2020
14d490c
Add doxygen doc and expose file in prob.hpp
charlesm93 Mar 13, 2020
89b1339
correct errors detected by cpplint.
charlesm93 Mar 13, 2020
abcb7df
Merge commit 'fe3a41c3e854fe604841b237fa0475f26d29fe98' into HEAD
yashikno Mar 13, 2020
54be7f1
[Jenkins] auto-formatting by clang-format version 5.0.0-3~16.04.1 (ta…
stan-buildbot Mar 13, 2020
a6e0f02
Remove rev file from hmm_marginal_lpdf file.
charlesm93 Mar 13, 2020
6c90efb
Resolve conflict.
charlesm93 Mar 13, 2020
9c236aa
Fix tolerance for diff tests.
charlesm93 Mar 25, 2020
b02a0f5
Adress minor review comments.
charlesm93 Mar 25, 2020
bb46c07
Merge commit '6c432ff239a1a2ea15eaaebc89e1423c51c6a088' into HEAD
yashikno Mar 25, 2020
ec7412e
[Jenkins] auto-formatting by clang-format version 6.0.0 (tags/google/…
stan-buildbot Mar 25, 2020
61de1e4
Merge branch 'try-efficient_hmm_gradient' of https://github.com/stan-…
charlesm93 Mar 25, 2020
aad9377
[Jenkins] auto-formatting by clang-format version 5.0.0-3~16.04.1 (ta…
stan-buildbot Mar 25, 2020
feea473
Reorganize code for the n_transitions = 0 case.
charlesm93 Mar 25, 2020
01c2ab1
Merge branch 'try-efficient_hmm_gradient' of https://github.com/stan-…
charlesm93 Mar 25, 2020
c141eaa
[Jenkins] auto-formatting by clang-format version 5.0.0-3~16.04.1 (ta…
stan-buildbot Mar 25, 2020
6775bef
implement Steves feedback
charlesm93 Mar 25, 2020
2465f86
Merge branch 'try-efficient_hmm_gradient' of https://github.com/stan-…
charlesm93 Mar 25, 2020
a2775cc
[Jenkins] auto-formatting by clang-format version 6.0.0 (tags/google/…
stan-buildbot Mar 25, 2020
581c2ee
Add inline statement.
charlesm93 Mar 27, 2020
fe0b3bc
Merge branch 'try-efficient_hmm_gradient' of https://github.com/stan-…
charlesm93 Mar 27, 2020
7cdf8bb
Generalize check_simplex to handle row_vectors.
charlesm93 Mar 27, 2020
0d7d01b
Merge commit 'b6134fbf1a75d9bfa4716bafc8ced948b794f4b3' into HEAD
yashikno Mar 27, 2020
4d7d9b2
[Jenkins] auto-formatting by clang-format version 6.0.0 (tags/google/…
stan-buildbot Mar 27, 2020
d30f856
Merge branch 'try-efficient_hmm_gradient' of https://github.com/stan-…
charlesm93 Apr 7, 2020
aeb3e12
Attempt at overload for fvar.
charlesm93 Apr 7, 2020
5e04b48
Merge commit 'fee10a80ad1d0d8c72f0990dc9652e5c6296d4e9' into HEAD
yashikno Apr 7, 2020
d257eca
[Jenkins] auto-formatting by clang-format version 6.0.0
stan-buildbot Apr 7, 2020
cc76a31
overpromote so that hmm-marginal works with fwd mode
SteveBronder Apr 8, 2020
215638f
add auto in some places to help with deduction for hmm
SteveBronder Apr 8, 2020
d69fe1b
use partial_type instead of full return type
SteveBronder Apr 8, 2020
d490911
Merge branch 'try-efficient_hmm_gradient' of https://github.com/stan-…
charlesm93 Apr 8, 2020
8767d1b
revert back the initialization of kappa for hmm
SteveBronder Apr 8, 2020
64b46a7
Merge branch 'hmm-marginal-fwd' of https://github.com/stan-dev/math i…
charlesm93 Apr 10, 2020
d6b2b36
Uncomment unit tests.
charlesm93 Apr 10, 2020
543b68b
Merge branch 'hmm-marginal-fwd' into try-efficient_hmm_gradient
charlesm93 Apr 10, 2020
56c7c49
Remove unnesecessary includes from the fwd directory.
charlesm93 Apr 10, 2020
950458c
use log_marginal_density instead of value_of(log_marginal_density)
SteveBronder Apr 21, 2020
863f0ef
Merge commit '7fb33f751581333ea577ad5a5b6e62bc9d04c1de' into HEAD
yashikno Apr 21, 2020
45f1bf9
[Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2 (tag…
stan-buildbot Apr 21, 2020
fce60f3
Remove auto from C in hmm so it's only computed once
SteveBronder Apr 21, 2020
b1a0e60
Merge branch 'try-efficient_hmm_gradient' of github.com:stan-dev/math…
SteveBronder Apr 21, 2020
d94988e
Move constant check to if so compiler has easier time deducing during…
SteveBronder Apr 21, 2020
f03ba32
Move constant check to if so compiler has easier time deducing during…
SteveBronder Apr 21, 2020
82c7ee7
[Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.0…
stan-buildbot Apr 21, 2020
1f0aeb7
revert generic templates for hmm
SteveBronder Apr 21, 2020
37da7af
merge to remote
SteveBronder Apr 21, 2020
be8b507
Merge branch 'try-efficient_hmm_gradient' of https://github.com/stan-…
charlesm93 Apr 21, 2020
ede1e0b
First pass at addressing Bobs review.
charlesm93 Apr 23, 2020
59cf78e
Merge commit '2133af2e116aa9618996e6a4664bdd8cb001b61a' into HEAD
yashikno Apr 23, 2020
1b7dfa4
[Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2 (tag…
stan-buildbot Apr 23, 2020
f457562
Fix documentation.
charlesm93 Apr 23, 2020
4b965dc
Merge branch 'try-efficient_hmm_gradient' of https://github.com/stan-…
charlesm93 Apr 23, 2020
79ff5e7
Revert edits to compiler flags.
charlesm93 Apr 24, 2020
4147b89
Update documentation.
charlesm93 Apr 24, 2020
be9e90b
add comment to test.
charlesm93 Apr 24, 2020
5a0786b
[Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.0…
stan-buildbot Apr 24, 2020
bfd1d7a
Merge branch 'try-efficient_hmm_gradient' of https://github.com/stan-…
charlesm93 Apr 30, 2020
d12113b
update doc on throw
charlesm93 Apr 30, 2020
e8bc649
use check_multiplicable.
charlesm93 Apr 30, 2020
4532594
move norm_norm
charlesm93 Apr 30, 2020
0adeadf
Merge commit '569fa36fc529b5bcc9f3e3a8236e681d09e99364' into HEAD
yashikno Apr 30, 2020
3ee0a62
[Jenkins] auto-formatting by clang-format version 6.0.0
stan-buildbot Apr 30, 2020
cea2cae
Merge branch 'try-efficient_hmm_gradient' of https://github.com/stan-…
charlesm93 May 1, 2020
7861bd7
remove conditional for test.
charlesm93 May 1, 2020
228 changes: 109 additions & 119 deletions in stan/math/prim/prob/hmm_marginal_lpdf.hpp

@@ -16,106 +16,101 @@
namespace stan {
namespace math {

/**
 * For a Hidden Markov Model with observation y, hidden state x,
 * and parameters theta, return the log marginal density, log
 * pi(y | theta). In this setting, the hidden states are discrete
 * and take values over the finite space {1, ..., K}.
 * The marginal lpdf is obtained via a forward pass.
 * The [in, out] arguments are saved so that we can use them when
 * calculating the derivatives.
 *
 * @param[in] log_omegas log matrix of observational densities.
 *                       The (i, j)th entry corresponds to the
 *                       density of the ith observation, y_i,
 *                       given x_i = j.
 * @param[in] Gamma transition density between hidden states.
 *                  The (i, j)th entry is the probability that x_n = j,
 *                  given x_{n - 1} = i. The rows of Gamma are simplexes.
 * @param[in] rho initial state
 * @param[in, out] alphas unnormalized partial marginal density.
 *                        The jth column is the joint density over all
 *                        observations y and the hidden state j.
 * @param[in, out] alpha_log_norms max coefficient for each column of alphas,
 *                                 used to normalize alphas.
 * @param[in, out] omegas term-wise exponential of log_omegas.
 * @return log marginal density.
 */
double hmm_marginal_lpdf(const Eigen::MatrixXd& log_omegas,
                         const Eigen::MatrixXd& Gamma,
                         const Eigen::VectorXd& rho, Eigen::MatrixXd& alphas,
                         Eigen::VectorXd& alpha_log_norms,
                         Eigen::MatrixXd& omegas) {
  omegas = log_omegas.array().exp();  // CHECK -- why the .array()?

Collaborator: Arrays in Eigen perform operations element by element, so the matrix has to be "cast" to an array before doing the elementwise exp().

  int n_states = log_omegas.rows();
  int n_transitions = log_omegas.cols() - 1;
Collaborator (suggested change):

    -  int n_states = log_omegas.rows();
    -  int n_transitions = log_omegas.cols() - 1;
    +  const int n_states = log_omegas.rows();
    +  const int n_transitions = log_omegas.cols() - 1;

Collaborator: Generally make things here const that are not changed.

Contributor Author: hmmm... overkilled, no?

Collaborator: No, it's good to let the compiler know things are const.

Collaborator: It also lets readers know "the value of this object will not change through the rest of its life."

Contributor Author: hmmm.... I'm going to need a high charisma roll to be convinced.

Collaborator: My dude, you can chase your bliss, and if that doesn't involve const then sure, that's fine. But if you don't mutate something later, why not let the reader know it won't change by declaring it const?

Contributor: Mainly because it clutters the code. The compiler can't act just on const, because it's possible to cast away the const modifier. I try to follow the lead of mature packages like Boost, which don't declare all their local variables (or arithmetic function args) const.

Contributor Author: Ok, that's settled then.


  alphas.col(0) = omegas.col(0).cwiseProduct(rho);

  double norm = alphas.col(0).maxCoeff();
  alphas.col(0) /= norm;
  alpha_log_norms(0) = std::log(norm);

  for (int n = 0; n < n_transitions; ++n) {
    alphas.col(n + 1)
        = omegas.col(n + 1).cwiseProduct(Gamma.transpose() * alphas.col(n));

    double norm = alphas.col(n + 1).maxCoeff();
    alphas.col(n + 1) /= norm;
    alpha_log_norms(n + 1) = std::log(norm) + alpha_log_norms(n);
  }

  return log(alphas.col(n_transitions).sum()) + alpha_log_norms(n_transitions);
}
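The forward pass above can be sketched without Eigen. This is a minimal illustration of the same recursion with the same max-coefficient normalization (a toy `std::vector`-based helper of my own, not the library implementation; names and values are hypothetical):

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <vector>

// Minimal HMM forward pass mirroring hmm_marginal_lpdf above:
// alpha_0 = omega_0 .* rho; alpha_{n+1} = omega_{n+1} .* (Gamma^T alpha_n),
// normalizing each step by the max coefficient and accumulating the log norms
// so small probabilities do not underflow.
double hmm_forward_log_marginal(
    const std::vector<std::vector<double>>& omega,  // omega[n][j] = p(y_n | x_n = j)
    const std::vector<std::vector<double>>& Gamma,  // Gamma[i][j] = p(x_n = j | x_{n-1} = i)
    const std::vector<double>& rho) {               // rho[j] = p(x_0 = j)
  const int K = rho.size();
  const int N = omega.size();
  std::vector<double> alpha(K);
  double log_norm = 0.0;
  for (int j = 0; j < K; ++j) alpha[j] = omega[0][j] * rho[j];
  for (int n = 1; n < N; ++n) {
    // Normalize by the max coefficient, tracking the log of the norm.
    double m = 0.0;
    for (double a : alpha) m = std::max(m, a);
    log_norm += std::log(m);
    std::vector<double> next(K, 0.0);
    for (int j = 0; j < K; ++j)
      for (int i = 0; i < K; ++i)
        next[j] += Gamma[i][j] * (alpha[i] / m);
    for (int j = 0; j < K; ++j) alpha[j] = omega[n][j] * next[j];
  }
  double sum = 0.0;
  for (double a : alpha) sum += a;
  return std::log(sum) + log_norm;
}
```

For a single hidden state the marginal is just the product of the observation densities, which makes the normalization bookkeeping easy to sanity-check by hand.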

/**
 * For a Hidden Markov Model with observation y, hidden state x,
 * and parameters theta, return the log marginal density, log
 * pi(y | theta). In this setting, the hidden states are discrete
 * and take values over the finite space {1, ..., K}.
 * The marginal lpdf is obtained via a forward pass, and
 * the derivative is calculated with an adjoint method,
 * see (Betancourt, Margossian, & Leos-Barajas, 2020).
 *
 * @tparam T_omega type of the log likelihood matrix
 * @tparam T_Gamma type of the transition matrix
 * @tparam T_rho type of the initial guess vector
 *
 * @param[in] log_omegas log matrix of observational densities.
 *                       The (i, j)th entry corresponds to the
 *                       density of the ith observation, y_i,
 *                       given x_i = j.
 * @param[in] Gamma transition density between hidden states.
 *                  The (i, j)th entry is the probability that x_n = j,
 *                  given x_{n - 1} = i. The rows of Gamma are simplexes.
 * @param[in] rho initial state
 * @throw <code>std::domain_error</code> if the rows of Gamma are
 *        not simplexes.
 * @throw <code>std::invalid_argument</code> if the size of rho is not
 *        the number of states.
 * @throw <code>std::domain_error</code> if rho is not a simplex.
 * @return log marginal density.
 */

Contributor: This should not have "marginal" in its name, because we think of the latent states as parameters and thus they are naturally marginalized out. In the traditional literature, they distinguished between the "data density" (for y) and the "complete data distribution" (for y and the latent states z), because strict frequentists have to pretend the mixture responsibility parameters are data so that they don't have a philosophical meltdown. You see that usage all over the EM literature, where the latent states z get marginalized out.

Contributor Author: I'll ask for a second opinion on this (@vianeylb, @betanalpha). As someone familiar with a Bayesian framework and a candid approach to the problem, this is clearly a marginal distribution. Naturally, I'm willing to accommodate the user.

Reviewer: In the Cappé et al. book, the distribution of the observations is referred to as the "marginal distribution of the observations only", obtained through the process of marginalization of the state variables. In the Zucchini et al. book, they have a section on marginal distributions of a single observation, Y_t, and discuss higher-order marginal distributions (e.g. [Y_t, Y_t+k]). Extending this, the distribution of the data only also falls under what they refer to as marginal distributions. The likelihood in both contexts comes from what they refer to as the marginal distribution of the data only, and they emphasize that this quantity is obtained through marginalization over the state variables. Given that inference for HMMs via the likelihood (marginal distribution of the data) vs the complete-data likelihood seems to be a topic that people are highly opinionated about, emphasizing that inference for the parameters of an HMM in Stan is done through marginalization over the state variables and use of the marginal distribution of the data directly seems preferable to me. And as the Zucchini book is clearly more on the frequentist side, I don't think "marginal_lpdf" will be divisive to use. Just my opinion.

Contributor: I agree with Vianey!

Contributor: We all agree that what's at play is a joint Bayesian model p(y, z, theta) where y is observed data, z is unobserved, and theta are parameters. The terminology I'm familiar with calls p(y | theta) the "likelihood" and p(y, z | theta) the "full data likelihood", because frequentist philosophy melts down if z is considered a parameter. I looked up Zucchini et al., and in section 2.3 they use likelihood in exactly this way for p(y | theta), with the states z marginalized out of p(y, z | theta). They then talk about marginalizing p(y | theta) to subsequences of y. In general, if we have a joint distribution p(y), we don't talk about p(y) as a marginal distribution, even though it is the degenerate case with nothing marginalized out. That's why I find calling p(y | theta) a marginal distribution confusing in this context. I'd also strongly prefer not to have verbose names for things if we're not going to be distinguishing among many alternatives.
template <typename T_omega, typename T_Gamma, typename T_rho>
inline return_type_t<T_omega, T_Gamma, T_rho> hmm_marginal_lpdf(
    const Eigen::Matrix<T_omega, Eigen::Dynamic, Eigen::Dynamic>& log_omegas,
    const Eigen::Matrix<T_Gamma, Eigen::Dynamic, Eigen::Dynamic>& Gamma,
    const Eigen::Matrix<T_rho, Eigen::Dynamic, 1>& rho) {
  int n_states = log_omegas.rows();
  int n_transitions = log_omegas.cols() - 1;

  check_square("hmm_marginal_lpdf", "Gamma", Gamma);
  check_consistent_size("hmm_marginal_lpdf", "Gamma", row(Gamma, 1), n_states);
  {
    // Temporary matrix to use check_simplex, which only works on
    // column vectors.
    Eigen::Matrix<T_Gamma, Eigen::Dynamic, Eigen::Dynamic> Gamma_transpose
        = Gamma.transpose();
    for (int i = 0; i < Gamma.cols(); i++) {
      // CHECK -- does check_simplex not work on row-vectors?
      check_simplex("hmm_marginal_lpdf", "Gamma[i, ]",

@@ -126,22 +121,19 @@ inline return_type_t<T_omega, T_Gamma, T_rho> hmm_marginal_lpdf(
  check_simplex("hmm_marginal_lpdf", "rho", rho);
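The validation above requires each row of Gamma (and rho) to be a simplex: non-negative entries summing to one. A minimal stand-in for such a check, written as a hypothetical plain-C++ helper rather than Stan's actual `check_simplex`, looks like:

```cpp
#include <cmath>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical sketch of a simplex check: every entry must be non-negative
// and the entries must sum to 1 within a tolerance; otherwise throw
// std::domain_error, matching the @throw contract documented above.
void check_simplex_sketch(const std::string& name,
                          const std::vector<double>& v, double tol = 1e-8) {
  double sum = 0.0;
  for (double x : v) {
    if (x < 0.0)
      throw std::domain_error(name + " has a negative entry");
    sum += x;
  }
  if (std::fabs(sum - 1.0) > tol)
    throw std::domain_error(name + " does not sum to 1");
}
```

Stan's real check additionally reports the offending index and value in the error message; the sketch only captures the acceptance condition.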

  using T_partials_return = partials_return_t<T_omega, T_Gamma, T_rho>;
  operands_and_partials<Eigen::Matrix<T_omega, Eigen::Dynamic, Eigen::Dynamic>,
                        Eigen::Matrix<T_Gamma, Eigen::Dynamic, Eigen::Dynamic>,
                        Eigen::Matrix<T_rho, Eigen::Dynamic, 1> >
      ops_partials(log_omegas, Gamma, rho);

  Eigen::MatrixXd alphas(n_states, n_transitions + 1);
  Eigen::VectorXd alpha_log_norms(n_transitions + 1);
  Eigen::MatrixXd omegas;
  Eigen::MatrixXd Gamma_dbl = value_of_rec(Gamma);

  T_partials_return log_marginal_density
      = hmm_marginal_lpdf(value_of_rec(log_omegas), Gamma_dbl,
                          value_of_rec(rho), alphas, alpha_log_norms, omegas);

// Variables required for all three Jacobian-adjoint products.
double norm_norm = alpha_log_norms(n_transitions);
@@ -155,18 +147,17 @@ inline return_type_t<T_omega, T_Gamma, T_rho> hmm_marginal_lpdf(
    kappa[n_transitions - 1] = Eigen::VectorXd::Ones(n_states);
    kappa_log_norms(n_transitions - 1) = 0;
    grad_corr[n_transitions - 1]
        = std::exp(alpha_log_norms(n_transitions - 1) - norm_norm);
  }

  for (int n = n_transitions - 2; n >= 0; --n) {
    kappa[n] = Gamma_dbl * (omegas.col(n + 2).cwiseProduct(kappa[n + 1]));

    double norm = kappa[n].maxCoeff();
    kappa[n] /= norm;
    kappa_log_norms(n) = std::log(norm) + kappa_log_norms(n + 1);
    grad_corr[n]
        = std::exp(alpha_log_norms(n) + kappa_log_norms(n) - norm_norm);
  }
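The backward recursion above, kappa[n] = Gamma * (omega_{n+2} .* kappa[n+1]) with the same per-step max normalization as the forward pass, can be sketched in plain C++. This mirrors only the structure of the loop (toy `std::vector` types, hypothetical helper name), not the Stan implementation:

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <vector>

// Sketch of the adjoint-method backward (kappa) pass: kappa[n_transitions-1]
// starts at all ones, and each earlier kappa[n] is
//   Gamma * (omegas[n+2] .* kappa[n+1]),
// normalized by its max coefficient, with the log norms accumulated so the
// gradient corrections can be rescaled later (as grad_corr does above).
std::vector<std::vector<double>> backward_kappa(
    const std::vector<std::vector<double>>& Gamma,   // Gamma[i][j]
    const std::vector<std::vector<double>>& omegas,  // omegas[n][j], n_transitions + 1 rows
    std::vector<double>& kappa_log_norms) {
  const int n_transitions = omegas.size() - 1;
  const int K = Gamma.size();
  std::vector<std::vector<double>> kappa(n_transitions,
                                         std::vector<double>(K, 1.0));
  kappa_log_norms.assign(n_transitions, 0.0);
  for (int n = n_transitions - 2; n >= 0; --n) {
    for (int i = 0; i < K; ++i) {
      double acc = 0.0;
      for (int j = 0; j < K; ++j)
        acc += Gamma[i][j] * omegas[n + 2][j] * kappa[n + 1][j];
      kappa[n][i] = acc;
    }
    double norm = *std::max_element(kappa[n].begin(), kappa[n].end());
    for (double& x : kappa[n]) x /= norm;
    kappa_log_norms[n] = std::log(norm) + kappa_log_norms[n + 1];
  }
  return kappa;
}
```

Keeping the norms in log space is what lets the later grad_corr terms combine forward and backward normalizations without underflow.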

  if (!is_constant_all<T_Gamma>::value) {
@@ -175,9 +166,9 @@ inline return_type_t<T_omega, T_Gamma, T_rho> hmm_marginal_lpdf(

    if (n_transitions != 0) {
      for (int n = n_transitions - 1; n >= 0; --n) {
        Gamma_jacad += (grad_corr[n] * kappa[n].cwiseProduct(omegas.col(n + 1))
                        * alphas.col(n).transpose())
                           .transpose();
      }
    }

Contributor: The idiom here is the following:

    for (int n = n_transitions; n_transitions-- > 0; )

Mainly because it'll work if the type is an unsized type, but also because it reduces the number of operations.

Contributor Author: I don't follow. The given code doesn't work (and nor do small tweaks around it). Can you be more specific about the wanted change?
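For what it's worth, the countdown idiom the reviewer seems to be reaching for decrements the loop variable itself, not the bound. A compiling sketch (my reading of the suggestion; the PR kept the explicit `for (int n = n_transitions - 1; n >= 0; --n)` form):

```cpp
#include <vector>

// Countdown idiom: `n-- > 0` tests first, then decrements, so the body
// observes n_transitions - 1 down to 0 and the loop is safe for a zero bound.
// (Hypothetical helper written only to demonstrate the iteration order.)
std::vector<int> countdown_indices(int n_transitions) {
  std::vector<int> visited;
  for (int n = n_transitions; n-- > 0;)
    visited.push_back(n);
  return visited;
}
```

Unlike the `n >= 0` form, this also works when the index type is unsigned, since the loop exits before the variable would wrap below zero inside the body.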

@@ -186,8 +177,7 @@ inline return_type_t<T_omega, T_Gamma, T_rho> hmm_marginal_lpdf(
  }

  bool sensitivities_for_omega_or_rho
      = (!is_constant_all<T_omega>::value) || (!is_constant_all<T_rho>::value);

  // boundary terms
  if (sensitivities_for_omega_or_rho) {
@@ -196,8 +186,9 @@ inline return_type_t<T_omega, T_Gamma, T_rho> hmm_marginal_lpdf(

    if (!is_constant_all<T_omega>::value) {
      for (int n = n_transitions - 1; n >= 0; --n)
        log_omega_jacad.col(n + 1)
            = grad_corr[n]
              * kappa[n].cwiseProduct(Gamma_dbl.transpose() * alphas.col(n));
    }

    // Boundary terms
@@ -212,26 +203,25 @@ inline return_type_t<T_omega, T_Gamma, T_rho> hmm_marginal_lpdf(

    if (!is_constant_all<T_omega>::value) {
      if (n_transitions != 0) {
        log_omega_jacad.col(0)
            = grad_corr_boundary * c.cwiseProduct(value_of_rec(rho));
        log_omega_jacad
            = log_omega_jacad.cwiseProduct(omegas / unnormed_marginal);
      } else {
        log_omega_jacad.col(0) = omegas.col(0).cwiseProduct(value_of_rec(rho))
                                 / exp(value_of_rec(log_marginal_density));
      }
      ops_partials.edge1_.partials_ = log_omega_jacad;
    }

Contributor: [optional] I'd prefer something at the very top handling the n_transitions == 0 case, so it's not in the way of understanding all the code. (Overall the code's really understandable.)

Contributor Author: I agree. I'll work on it.

Collaborator: +1. Also, Charles, note that the if statements that use type traits like is_constant_all are resolved at compile time (i.e. each just becomes if (true/false)), so the alternative code path gets deleted as dead code; you don't need to worry about those ifs in terms of performance.

Contributor Author: I think I found a fairly good solution. The n_transitions == 0 case only matters when handling the boundary terms, so it's now handled at the very top, when we start treating the boundary terms.

    if (!is_constant_all<T_rho>::value) {
      if (n_transitions != 0) {
        ops_partials.edge3_.partials_ = grad_corr_boundary
                                        * c.cwiseProduct(omegas.col(0))
                                        / unnormed_marginal;
      } else {
        ops_partials.edge3_.partials_
            = omegas.col(0) / exp(value_of_rec(log_marginal_density));
      }
    }
  }

Contributor: If n_transitions == 0 isn't just split out at the top, push branching as far down as possible so that the parallel assignment structure is clear:

    ops_partials.edge3_.partials_
        = (n_transitions == 0)
        ? omegas.col(0) / exp(value_of_rec(log_marginal_density))
        : grad_corr_boundary * c.cwiseProduct(omegas.col(0)) / unnormed_marginal;

Contributor Author: Handling the n_transitions == 0 at the top now.