Skip to content

Commit

Permalink
Fixed select_rows and unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
neubig committed Feb 27, 2017
1 parent fcc5d11 commit e915aab
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 6 deletions.
20 changes: 14 additions & 6 deletions dynet/nodes.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1667,8 +1667,13 @@ template<class MyDevice>
void SelectCols::forward_dev_impl(const MyDevice & dev, const vector<const Tensor*>& xs, Tensor& fx) const {
assert(xs.size() == 1);
auto& rm = *pcols;
for (unsigned i = 0; i < rm.size(); ++i)
for (unsigned i = 0; i < rm.size(); ++i) {
if(rm[i] >= xs[0]->d.cols()) {
ostringstream oss; oss << "Out-of-bounds index " << rm[i] << " in SelectCols over expression of dimensions " << xs[0]->d;
throw std::invalid_argument(oss.str());
}
fx.t<2>().chip<1>(i).device(*dev.edevice) = xs[0]->t<2>().chip<1>(rm[i]);
}
}

template<class MyDevice>
Expand All @@ -1689,9 +1694,13 @@ template<class MyDevice>
void SelectRows::forward_dev_impl(const MyDevice & dev, const vector<const Tensor*>& xs, Tensor& fx) const {
assert(xs.size() == 1);
auto& rm = *prows;
for (unsigned i = 0; i < rm.size(); ++i)
fx.t<2>().chip<0>(i) = xs[0]->t<2>().chip<0>(rm[i]);
// fx.t<2>().device(*dev.edevice).chip<0>(i) = xs[0]->t<2>().chip<0>(rm[i]);
for (unsigned i = 0; i < rm.size(); ++i) {
if(rm[i] >= xs[0]->d.rows()) {
ostringstream oss; oss << "Out-of-bounds index " << rm[i] << " in SelectRows over expression of dimensions " << xs[0]->d;
throw std::invalid_argument(oss.str());
}
fx.t<2>().chip<0>(i).device(*dev.edevice) = xs[0]->t<2>().chip<0>(rm[i]);
}
}

template<class MyDevice>
Expand All @@ -1704,8 +1713,7 @@ void SelectRows::backward_dev_impl(const MyDevice & dev,
assert(xs.size() == 1);
auto& rm = *prows;
for (unsigned i = 0; i < rm.size(); ++i)
dEdxi.t<2>().chip<0>(rm[i]) = dEdf.t<2>().chip<0>(i);
// dEdxi.t<2>().device(*dev.edevice).chip<0>(rm[i]) = dEdf.t<2>().chip<0>(i);
dEdxi.t<2>().chip<0>(rm[i]).device(*dev.edevice) = dEdf.t<2>().chip<0>(i);
}
DYNET_NODE_INST_DEV_IMPL(SelectRows)

Expand Down
18 changes: 18 additions & 0 deletions tests/test-nodes.cc
Original file line number Diff line number Diff line change
Expand Up @@ -935,6 +935,15 @@ BOOST_AUTO_TEST_CASE( select_rows_gradient ) {
BOOST_CHECK(check_grad(mod, z, 0));
}

// Expression select_rows(const Expression& x, vector<unsigned>& rows);
BOOST_AUTO_TEST_CASE( select_rows_oob ) {
dynet::ComputationGraph cg;
vector<unsigned> rows = {3};
Expression x1 = parameter(cg, param_square1);
Expression y = select_rows(x1, rows);
BOOST_CHECK_THROW(y.value(), std::invalid_argument);
}

// Expression select_cols(const Expression& x, vector<unsigned>& rows);
BOOST_AUTO_TEST_CASE( select_cols_gradient ) {
dynet::ComputationGraph cg;
Expand All @@ -945,6 +954,15 @@ BOOST_AUTO_TEST_CASE( select_cols_gradient ) {
BOOST_CHECK(check_grad(mod, z, 0));
}

// Expression select_cols(const Expression& x, vector<unsigned>& rows);
BOOST_AUTO_TEST_CASE( select_cols_oob ) {
dynet::ComputationGraph cg;
vector<unsigned> cols = {3};
Expression x1 = parameter(cg, param_square1);
Expression y = select_cols(x1, cols);
BOOST_CHECK_THROW(y.value(), std::invalid_argument);
}

// Expression pickneglogsoftmax(const Expression& x, unsigned v);
BOOST_AUTO_TEST_CASE( pickneglogsoftmax_gradient ) {
unsigned idx = 1;
Expand Down

0 comments on commit e915aab

Please sign in to comment.