Skip to content

Commit

Permalink
nnet3+ctc: finish more binaries, and some code required for that; fix…
Browse files Browse the repository at this point in the history
… some #ifdefs in headers via script.
  • Loading branch information
danpovey committed Sep 26, 2015
1 parent 0f6a15b commit ff080db
Show file tree
Hide file tree
Showing 24 changed files with 496 additions and 36 deletions.
9 changes: 9 additions & 0 deletions misc/maintenance/fix_include_guards.sh
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,12 @@ for x in */*.h ; do
done


# Second pass over the headers: make the guard name in the comment on the
# "#endif" line agree with the guard name derived from the header's path.
for x in */*.h ; do
  # Expected guard, e.g. nnet3/nnet-utils.h -> KALDI_NNET3_NNET_UTILS_H_
  # (lowercase -> uppercase; '/', '.' and '-' -> '_').
  name=$(echo "$x" | tr '[a-z]/.-' '[A-Z]___')
  m=KALDI_${name}_
  # Guard name currently written after the first "#endif ... _H_" line,
  # with the "//" comment marker stripped out.
  n=$(grep endif "$x" | grep _H_ | sed s://:: | awk '{print $2}' | head -n 1)
  # Use -n (non-empty string).  The previous "! -s $n" tested whether a
  # non-empty *file* named $n exists, which only worked by accident.
  if [ -n "$n" ] && [ "$m" != "$n" ]; then
    echo "#endif: $m != $n";
    cp "$x" tmp; sed "s/$n/$m/" <tmp > "$x";
  fi
done
2 changes: 1 addition & 1 deletion src/ctc/cctc-supervision.h
Original file line number Diff line number Diff line change
Expand Up @@ -407,4 +407,4 @@ typedef RandomAccessTableReader<KaldiObjectHolder<CctcSupervision > > RandomAcce
} // namespace ctc
} // namespace kaldi

#endif // KALDI_CTC_CTC_SUPERVISION_H_
#endif // KALDI_CTC_CCTC_SUPERVISION_H_
6 changes: 3 additions & 3 deletions src/ctc/cctc-training.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,16 +48,16 @@ namespace ctc {
struct CctcTrainingOptions {
BaseFloat normalizing_weight;

CctcTrainingOptions(): normalizing_weight(0.0001) { }
CctcTrainingOptions(): normalizing_weight(0.0) { }

void Register(OptionsItf *opts) {
opts->Register("normalizing-weight", &normalizing_weight, "Weight on a "
"term in the objective function that's a negative squared "
"log of the numerator in the CCTC likelihood; it "
"exists to keep the network outputs in a reasonable "
"range so we can exp() them without overflow.");
"range so we can exp() them without overflow. "
"Warning: not supported yet.");
}

};


Expand Down
87 changes: 87 additions & 0 deletions src/ctcbin/nnet3-ctc-compute-prob.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
// ctcbin/nnet3-ctc-compute-prob.cc

// Copyright 2015 Johns Hopkins University (author: Daniel Povey)

// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.

#include "base/kaldi-common.h"
#include "util/common-utils.h"
#include "nnet3/nnet-cctc-diagnostics.h"


// Entry point of nnet3-ctc-compute-prob.
//
// Reads a CCTC transition model followed by an nnet3 network from the first
// positional argument, runs every example from the second positional argument
// (an rspecifier) through NnetCctcComputeProb, and prints the accumulated
// statistics via PrintTotalStats().
//
// Exit status: 0 if PrintTotalStats() returns true, 1 if it returns false or
// the wrong number of arguments was given, -1 if a std::exception was caught.
int main(int argc, char *argv[]) {
  try {
    using namespace kaldi;
    using namespace kaldi::nnet3;

    const char *usage =
        "Computes and prints in logging messages the average log-prob per frame of\n"
        "the given data with an nnet3+ctc neural net.  The input of this is the output of\n"
        "e.g. nnet3-ctc-get-egs | nnet3-ctc-merge-egs.\n"
        "\n"
        "Usage:  nnet3-ctc-compute-prob [options] <nnet3-ctc-model-in> <training-examples-in>\n"
        "e.g.: nnet3-ctc-compute-prob 0.mdl ark:valid.egs\n";

    // This program doesn't support using a GPU, because these probabilities are
    // used for diagnostics, and you can just compute them with a small enough
    // amount of data that a CPU can do it within reasonable time.
    // It wouldn't be hard to make it support GPU, though.

    NnetCctcComputeProbOptions opts;

    ParseOptions po(usage);
    opts.Register(&po);
    po.Read(argc, argv);

    if (po.NumArgs() != 2) {
      po.PrintUsage();
      exit(1);
    }

    std::string ctc_nnet_rxfilename = po.GetArg(1),
        examples_rspecifier = po.GetArg(2);

    // The model file holds the transition model first, then the network,
    // both read from the same stream.
    ctc::CctcTransitionModel trans_model;
    Nnet nnet;
    {
      bool binary;
      Input input(ctc_nnet_rxfilename, &binary);
      trans_model.Read(input.Stream(), binary);
      nnet.Read(input.Stream(), binary);
    }

    NnetCctcComputeProb cctc_prob_computer(opts, trans_model, nnet);

    SequentialNnetCctcExampleReader example_reader(examples_rspecifier);

    for (; !example_reader.Done(); example_reader.Next())
      cctc_prob_computer.Compute(example_reader.Value());

    bool ok = cctc_prob_computer.PrintTotalStats();

    return (ok ? 0 : 1);
  } catch(const std::exception &e) {
    std::cerr << e.what() << '\n';
    return -1;
  }
}


2 changes: 1 addition & 1 deletion src/nnet2/online-nnet2-decodable.h
Original file line number Diff line number Diff line change
Expand Up @@ -119,4 +119,4 @@ class DecodableNnet2Online: public DecodableInterface {
} // namespace nnet2
} // namespace kaldi

#endif // KALDI_NNET2_ONLINE_GMM_DECODABLE_H_
#endif // KALDI_NNET2_ONLINE_NNET2_DECODABLE_H_
2 changes: 1 addition & 1 deletion src/nnet3/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ OBJFILES = nnet-common.o nnet-compile.o nnet-component-itf.o \
nnet-utils.o nnet-compute.o nnet-test-utils.o nnet-analyze.o \
nnet-example-utils.o nnet-training.o \
nnet-diagnostics.o nnet-combine.o nnet-am-decodable-simple.o \
nnet-optimize-utils.o nnet-cctc-example.o
nnet-optimize-utils.o nnet-cctc-example.o nnet-cctc-diagnostics.o

LIBNAME = kaldi-nnet3

Expand Down
2 changes: 1 addition & 1 deletion src/nnet3/nnet-am-decodable-simple.h
Original file line number Diff line number Diff line change
Expand Up @@ -171,4 +171,4 @@ class DecodableAmNnetSimple: public DecodableInterface {
} // namespace nnet3
} // namespace kaldi

#endif // KALDI_NNET2_DECODABLE_AM_NNET_H_
#endif // KALDI_NNET3_NNET_AM_DECODABLE_SIMPLE_H_
152 changes: 152 additions & 0 deletions src/nnet3/nnet-cctc-diagnostics.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
// nnet3/nnet-cctc-diagnostics.cc

// Copyright 2015 Johns Hopkins University (author: Daniel Povey)

// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.

#include "nnet3/nnet-cctc-diagnostics.h"
#include "nnet3/nnet-utils.h"

namespace kaldi {
namespace nnet3 {

// Constructs the object from the options, the transition model and the
// network.  If config.compute_deriv is set, a copy of the network is allocated
// and zeroed so that derivatives can be accumulated into it; otherwise
// deriv_nnet_ stays NULL.  Finally the transition model's weights are computed
// into cu_weights_.
NnetCctcComputeProb::NnetCctcComputeProb(
    const NnetCctcComputeProbOptions &config,
    const ctc::CctcTransitionModel &trans_model,
    const Nnet &nnet):
    config_(config), trans_model_(trans_model), nnet_(nnet),
    deriv_nnet_(NULL), compiler_(nnet), num_minibatches_processed_(0) {
  if (config_.compute_deriv) {
    deriv_nnet_ = new Nnet(nnet_);
    // First argument is is_gradient == true: force simple update.
    SetZero(true, deriv_nnet_);
  }
  trans_model_.ComputeWeights(&cu_weights_);
}

// Returns the network in which derivatives are accumulated.  Raises an error
// via KALDI_ERR if no derivative network was allocated.
const Nnet &NnetCctcComputeProb::GetDeriv() const {
  if (!deriv_nnet_) {
    KALDI_ERR << "GetDeriv() called when no derivatives were requested.";
  }
  return *deriv_nnet_;
}

// Releases the derivative network, if one was allocated.
NnetCctcComputeProb::~NnetCctcComputeProb() {
  // Deleting a NULL pointer is a no-op, so no check is needed here.
  delete deriv_nnet_;
}

void NnetCctcComputeProb::Reset() {
num_minibatches_processed_ = 0;
objf_info_.clear();
if (deriv_nnet_) {
bool is_gradient = true;
SetZero(is_gradient, deriv_nnet_);
}
}

// Processes one example: builds and compiles the computation request, runs the
// forward pass, accumulates the objective for each output, and (only when
// derivatives were requested) runs the backward pass.
void NnetCctcComputeProb::Compute(const NnetCctcExample &cctc_eg) {
  const bool want_deriv = config_.compute_deriv;
  const bool store_stats = false;  // component stats are never stored here.

  ComputationRequest request;
  GetCctcComputationRequest(nnet_, cctc_eg, want_deriv, store_stats, &request);

  const NnetComputation *computation = compiler_.Compile(request);
  NnetComputer computer(config_.compute_config, *computation,
                        nnet_, deriv_nnet_);

  // Feed the example's inputs to the computer, then propagate.
  computer.AcceptInputs(nnet_, cctc_eg.inputs);
  computer.Forward();

  // Output derivatives (if any) are handed to the computer in here, so this
  // has to come before Backward().
  this->ProcessOutputs(cctc_eg, &computer);

  if (want_deriv) {
    computer.Backward();
  }
}

void NnetCctcComputeProb::ProcessOutputs(const NnetCctcExample &eg,
NnetComputer *computer) {
// There will normally be just one output here, named 'output',
// but the code is more general than this.
std::vector<NnetCctcSupervision>::const_iterator iter = eg.outputs.begin(),
end = eg.outputs.end();
for (; iter != end; ++iter) {
const NnetCctcSupervision &sup = *iter;
int32 node_index = nnet_.GetNodeIndex(sup.name);
if (node_index < 0 ||
!nnet_.IsOutputNode(node_index))
KALDI_ERR << "Network has no output named " << sup.name;

const CuMatrixBase<BaseFloat> &nnet_output = computer->GetOutput(sup.name);
CuMatrix<BaseFloat> nnet_output_deriv;
if (config_.compute_deriv)
nnet_output_deriv.Resize(nnet_output.NumRows(), nnet_output.NumCols(),
kUndefined);

BaseFloat tot_weight, tot_objf;
sup.ComputeObjfAndDerivs(config_.cctc_training_config,
trans_model_,
cu_weights_, nnet_output,
&tot_weight, &tot_objf,
(config_.compute_deriv ?
&nnet_output_deriv : NULL));

SimpleObjectiveInfo &totals = objf_info_[sup.name];
totals.tot_weight += tot_weight;
totals.tot_objective += tot_objf;

if (config_.compute_deriv)
computer->AcceptOutputDeriv(sup.name, &nnet_output_deriv);

num_minibatches_processed_++;
}
}

bool NnetCctcComputeProb::PrintTotalStats() const {
bool ans = false;
unordered_map<std::string, SimpleObjectiveInfo, StringHasher>::const_iterator
iter, end;
iter = objf_info_.begin();
end = objf_info_.end();
for (; iter != end; ++iter) {
const std::string &name = iter->first;
int32 node_index = nnet_.GetNodeIndex(name);
KALDI_ASSERT(node_index >= 0);
const SimpleObjectiveInfo &info = iter->second;
KALDI_LOG << "Overall log-probability for '"
<< name << "' is "
<< (info.tot_objective / info.tot_weight) << " per frame"
<< ", over " << info.tot_weight << " frames.";
if (info.tot_weight > 0)
ans = true;
}
return ans;
}


// Looks up the accumulated objective totals for the named output.  Returns a
// pointer to the stored entry, or NULL if nothing has been recorded under that
// name.
const SimpleObjectiveInfo* NnetCctcComputeProb::GetObjective(
    const std::string &output_name) const {
  typedef unordered_map<std::string, SimpleObjectiveInfo,
                        StringHasher>::const_iterator ObjfIter;
  ObjfIter pos = objf_info_.find(output_name);
  if (pos == objf_info_.end()) {
    return NULL;
  }
  return &(pos->second);
}

} // namespace nnet3
} // namespace kaldi
Loading

0 comments on commit ff080db

Please sign in to comment.