Skip to content

Commit

Permalink
[src] Fix to nnet3 compilation issue affecting BLSTMP. Thanks: Alim M…
Browse files Browse the repository at this point in the history
…isbullah (kaldi-asr#2286)

WARNING: may invalidate previously-trained BLSTMP's.
  • Loading branch information
danpovey committed Mar 19, 2018
1 parent 9654a7c commit c6b3588
Show file tree
Hide file tree
Showing 5 changed files with 92 additions and 64 deletions.
23 changes: 19 additions & 4 deletions src/nnet3/nnet-compile.cc
Original file line number Diff line number Diff line change
Expand Up @@ -322,10 +322,23 @@ void Compiler::CreateStepInfo(
stride_type);
} else {
// kDimRange. Will just be a sub-matrix of a Component or Input node.
int32 cindex_id = this_info.output_cindex_ids.front(),
input_cindex_id = graph_.dependencies[cindex_id][0],
input_step = cindex_id_to_location_[input_cindex_id].first;
KALDI_ASSERT(input_step != -1 && input_step < step);
std::vector<int32>::const_iterator
iter = this_info.output_cindex_ids.begin(),
end = this_info.output_cindex_ids.end();
int32 source_cindex_id = -1;
for (; iter != end; ++iter) {
int32 cindex_id = *iter;
if (!graph_.dependencies[cindex_id].empty()) {
KALDI_ASSERT(graph_.dependencies[cindex_id].size() == 1);
source_cindex_id = graph_.dependencies[cindex_id][0];
break;
}
}
KALDI_ASSERT(source_cindex_id >= 0);
int32 input_step = cindex_id_to_location_[source_cindex_id].first;
KALDI_ASSERT(this_info.output_cindex_ids.size() ==
steps_[input_step].output_cindex_ids.size());
KALDI_ASSERT(input_step >= 0 && input_step < step);
KALDI_PARANOID_ASSERT(this_info.output_indexes ==
steps_[input_step].output_indexes);
this_info.value = computation->NewSubMatrix(steps_[input_step].value,
Expand Down Expand Up @@ -376,6 +389,8 @@ void Compiler::CreateStepInfo(
KALDI_ASSERT(cur_dim_offset == desc.Dim(nnet_));
}
}
KALDI_ASSERT(static_cast<int32>(this_info.output_cindex_ids.size()) ==
computation->submatrices[this_info.value].num_rows);
}
}

Expand Down
4 changes: 3 additions & 1 deletion src/nnet3/nnet-compile.h
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,9 @@ class Compiler {
std::vector<bool> *deriv_needed);

// this sets up steps_, destroying the input "by_step" in the process. It
// also sets various matrix and sub-matrix sizes in "computation".
// also sets various matrix and sub-matrix sizes in "computation". The input
// 'by_step' is elsewhere referred to as just 'step'; it is a vector of steps,
// and each step is a vector of cindex_ids that are computed by that step.
void CreateStepInfo(const std::vector<bool> &deriv_needed,
const std::vector<int32> &step_to_segment,
std::vector<std::vector<int32> > *by_step,
Expand Down
82 changes: 33 additions & 49 deletions src/nnet3/nnet-computation-graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1643,7 +1643,7 @@ int32 ComputationStepsComputer::AddStep(const std::vector<Cindex> &cindexes,
*out_iter = cindex_id;
if (added) {
KALDI_ASSERT(cindex_id == static_cast<int32>(locations_->size()));
locations_->resize(cindex_id + 1);
locations_->resize(cindex_id + 1, std::pair<int32, int32>(-1, -1));
locations_->back().first = step_index;
locations_->back().second = row_index;
locations = &((*locations_)[0]); // in case it was reallocated
Expand Down Expand Up @@ -1867,68 +1867,52 @@ void ComputationStepsComputer::ProcessDimRangeSubPhase(
ConvertToCindexIds(input_cindexes, &input_cindex_ids);
std::vector<std::pair<int32, int32> > locations;
ConvertToLocations(input_cindex_ids, &locations);
std::sort(locations.begin(), locations.end());

// get a list of the source step indexes (corresponding to computations for the
// source component-node)
std::unordered_set<int32> source_step_indexes;
KALDI_ASSERT(!locations.empty());
std::vector<std::pair<int32, int32> >::const_iterator
locations_iter = locations.begin(),
locations_end = locations.end();
// Each unique .first number in locations (i.e. each source step, and they
// will all correspond to component-output or input steps) will generate one
// 'step' of type kDimRange. Because dim-range nodes must be contiguous
// ranges of a source step (since they are represented as sub-matrices), for
// each source step we work out the first and last row-index (i.e. first and
// last .second member of locations) and use that to reconstruct the range.

// each element of 'steps' will be (source_step, (begin_row, end_row)) so that
// the source of the dim-range node is indexes begin_row ... end_row-1 in that
// source step.
std::vector<std::pair<int32, std::pair<int32, int32> > > steps;

int32 cur_source_step = locations_iter->first,
cur_row_begin = locations_iter->second,
cur_row_end = cur_row_begin + 1;
while (1) {
++locations_iter;
if (locations_iter == locations_end ||
locations_iter->first != cur_source_step) {
// we reached the end of a run of the same step.
std::pair<int32, std::pair<int32, int32> > this_step;
this_step.first = cur_source_step;
this_step.second.first = cur_row_begin;
this_step.second.second = cur_row_end;
steps.push_back(this_step);
if (locations_iter != locations_end) {
cur_source_step = locations_iter->first;
cur_row_begin = locations_iter->second;
cur_row_end = cur_row_begin + 1;
} else {
break;
}
} else {
cur_row_end = locations_iter->second + 1;

// 'cur_source_step_index' is just an optimization to prevent unnecessary
// unordered_set inserts.
int32 cur_source_step_index = -1;
for (; locations_iter != locations_end; ++locations_iter) {
int32 source_step_index = locations_iter->first;
if (source_step_index != cur_source_step_index) {
cur_source_step_index = source_step_index;
source_step_indexes.insert(cur_source_step_index);
}
}

for (size_t i = 0; i < steps.size(); i++) {
// iterating over different source steps, although normally
// there will be just one.
int32 source_step = steps[i].first,
row_begin = steps[i].second.first,
row_end = steps[i].second.second;
// 'source' is just the elements of the source step that we're consuming.
std::vector<int32> source((*steps_)[source_step].begin() + row_begin,
(*steps_)[source_step].begin() + row_end);
std::unordered_set<int32>::const_iterator
source_step_iter = source_step_indexes.begin(),
source_step_end = source_step_indexes.end();
// iterating over the indexes of the source steps.
for (; source_step_iter != source_step_end; ++source_step_iter) {
int32 source_step_index = *source_step_iter;
std::pair<int32, int32> p(source_step_index, dim_range_node);
if (dim_range_nodes_.count(p) > 0) {
// We don't need to do anything; a dim-range node already exists for this
// step and this node index.
continue;
}
dim_range_nodes_.insert(p);
const std::vector<int32> &source_step = (*steps_)[source_step_index];
// 'cindexes' will be the cindexes of the new step that we're going to add.
std::vector<Cindex> cindexes;
ConvertToCindexes(source, &cindexes);
ConvertToCindexes(source_step, &cindexes);
std::vector<Cindex>::iterator iter = cindexes.begin(),
end = cindexes.end();
for (; iter != end; ++iter)
iter->first = dim_range_node;
bool add_if_absent = true;
// this add_if_absent says, even if cindexes were not in the graph,
// add them. This is possible in principle; it's to satisfy the
// requirement that DimRangeNodes be implemented as contiguous ranges
// of rows of component nodes or input nodes.
// add them. This is possible; the step will contain all cindexes for the
// input step, even if they won't be needed. (This is costless; it's just
// setting up a sub-matrix).
AddStep(cindexes, add_if_absent);
}
}
Expand Down
11 changes: 10 additions & 1 deletion src/nnet3/nnet-computation-graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -439,7 +439,7 @@ class ComputationStepsComputer {
/// (step-index, index-into-step), so that for any cindex_id c,
/// (*steps)[locations[c].first][locations[c].second] == c.
/// It's possible in principle if there are non-simple
/// Components, that for node corresponding to component-input
/// Components, that for nodes corresponding to component-input
/// descriptors, a cindex might be present in more than one step,
/// so it doesn't follow that if (*steps)[i][j] == c, then
/// locations[c] == (i,j).
Expand Down Expand Up @@ -547,6 +547,15 @@ class ComputationStepsComputer {
/// (*steps_)[i][j] == c. This is also an output (we get the pointer in
/// the constructor).
std::vector<std::pair<int32, int32> > *locations_;


/// dim_range_nodes_ is used when allocating steps for nodes of type kDimRangeNode.
/// This is a set of (source_step, dim_range_node_index),
/// where source_step is the step in which we computed of the input
/// of the dim-range node (this step will be for a node of type kComponentNode).
/// This just tells us whether we've already added a particular dim-range node
/// for this step, so we know whether we need to add it again.
std::unordered_set<std::pair<int32, int32>, PairHasher<int32> > dim_range_nodes_;
};


Expand Down
36 changes: 27 additions & 9 deletions src/nnet3/nnet-test-utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -679,26 +679,36 @@ void GenerateConfigSequenceLstmWithTruncation(
}
std::string spliced_input = temp_string_stream.str();

std::string c_tminus1 = "IfDefined(Offset(c_t, -1))";
int32 offset = RandInt(-3, 3);
if (offset == 0)
offset = -1;


std::string c_tminus1;
{
std::ostringstream os_temp;
os_temp << "IfDefined(Offset(c_t, " << offset << "))";
c_tminus1 = os_temp.str();
}
os << "component-node name=c_t component=c input=Sum(c1_t, c2_t)\n";

// i_t
os << "component-node name=i1 component=Wi-xr input=Append("
<< spliced_input << ", IfDefined(Offset(r_t, -1)))\n";
<< spliced_input << ", IfDefined(Offset(r_t, " << offset << ")))\n";
os << "component-node name=i2 component=Wic "
<< " input=" << c_tminus1 << std::endl;
os << "component-node name=i_t component=i input=Sum(i1, i2)\n";

// f_t
os << "component-node name=f1 component=Wf-xr input=Append("
<< spliced_input << ", IfDefined(Offset(r_t, -1)))\n";
<< spliced_input << ", IfDefined(Offset(r_t, " << offset << ")))\n";
os << "component-node name=f2 component=Wfc "
<< " input=" << c_tminus1 << std::endl;
os << "component-node name=f_t component=f input=Sum(f1, f2)\n";

// o_t
os << "component-node name=o1 component=Wo-xr input=Append("
<< spliced_input << ", IfDefined(Offset(r_t, -1)))\n";
<< spliced_input << ", IfDefined(Offset(r_t, " << offset << ")))\n";
os << "component-node name=o2 component=Woc input=Sum(c1_t, c2_t)\n";
os << "component-node name=o_t component=o input=Sum(o1, o2)\n";

Expand All @@ -707,7 +717,7 @@ void GenerateConfigSequenceLstmWithTruncation(

// g_t
os << "component-node name=g1 component=Wc-xr input=Append("
<< spliced_input << ", IfDefined(Offset(r_t, -1)))\n";
<< spliced_input << ", IfDefined(Offset(r_t, " << offset << ")))\n";
os << "component-node name=g_t component=g input=g1\n";

// parts of c_t
Expand Down Expand Up @@ -758,6 +768,10 @@ void GenerateConfigSequenceLstmType2(
cell_dim = 40 + Rand() % 50,
projection_dim = std::ceil(cell_dim / (Rand() % 10 + 2));

int32 offset = RandInt(-3, 3);
if (offset == 0)
offset = -1;

os << "input-node name=input dim=" << input_dim << std::endl;
// Parameter Definitions W*(* replaced by - to have valid names)
os << "component name=W-x type=NaturalGradientAffineComponent input-dim="
Expand Down Expand Up @@ -819,10 +833,13 @@ void GenerateConfigSequenceLstmType2(
}
os << ")\n";

os << "component-node name=W-r component=W-r input=IfDefined(Offset(r_t, -1))\n";
os << "component-node name=W-r component=W-r input=IfDefined(Offset(r_t"
<< offset << "))\n";
os << "component-node name=W-m component=W-m input=m_t \n";
os << "component-node name=Wic component=Wic input=IfDefined(Offset(c_t, -1))\n";
os << "component-node name=Wfc component=Wfc input=IfDefined(Offset(c_t, -1))\n";
os << "component-node name=Wic component=Wic input=IfDefined(Offset(c_t"
<< offset << "))\n";
os << "component-node name=Wfc component=Wfc input=IfDefined(Offset(c_t"
<< offset << "))\n";
os << "component-node name=Woc component=Woc input=c_t\n";

// Splitting the outputs of W*m node
Expand Down Expand Up @@ -857,7 +874,8 @@ void GenerateConfigSequenceLstmType2(
os << "component-node name=i_t component=i_t input=Sum(W_ix-x_t, Sum(W_ir-r_tminus1, Wic))\n";
os << "component-node name=f_t component=f_t input=Sum(W_fx-x_t, Sum(W_fr-r_tminus1, Wfc))\n";
os << "component-node name=o_t component=o_t input=Sum(W_ox-x_t, Sum(W_or-r_tminus1, Woc))\n";
os << "component-node name=f_t-c_tminus1 component=f_t-c_tminus1 input=Append(f_t, Offset(c_t, -1))\n";
os << "component-node name=f_t-c_tminus1 component=f_t-c_tminus1 input=Append(f_t, Offset(c_t"
<< offset << "))\n";
os << "component-node name=i_t-g component=i_t-g input=Append(i_t, g)\n";
os << "component-node name=m_t component=m_t input=Append(o_t, h)\n";

Expand Down

0 comments on commit c6b3588

Please sign in to comment.