From 5afe883dedc1290a283651ad1c1cc0857039c304 Mon Sep 17 00:00:00 2001
From: Pegita
Date: Fri, 30 Sep 2016 17:42:24 -0400
Subject: [PATCH] Added DropoutComponent in nnet3

---
 .../s5c/local/chain/tuning/run_tdnn_7d.sh |  2 -
 src/cudamatrix/cu-kernels-ansi.h          |  4 +-
 src/cudamatrix/cu-kernels.cu              | 10 +--
 src/cudamatrix/cu-kernels.h               |  8 +-
 src/cudamatrix/cu-matrix-test.cc          |  8 +-
 src/cudamatrix/cu-matrix.cc               |  6 +-
 src/cudamatrix/cu-matrix.h                |  2 +-
 src/matrix/kaldi-matrix.cc                |  6 +-
 src/matrix/kaldi-matrix.h                 |  4 +-
 src/nnet2/nnet-component.cc               |  2 +-
 src/nnet3/nnet-component-itf.cc           |  2 +
 src/nnet3/nnet-component-itf.h            | 10 +++
 src/nnet3/nnet-component-test.cc          | 31 ++++++--
 src/nnet3/nnet-nnet.cc                    | 10 ++-
 src/nnet3/nnet-nnet.h                     |  6 +-
 src/nnet3/nnet-simple-component.cc        | 79 +++++++++++++++++++
 src/nnet3/nnet-simple-component.h         | 53 +++++++++++++
 src/nnet3/nnet-test-utils.cc              |  8 +-
 src/nnet3/nnet-utils.cc                   |  9 +++
 src/nnet3/nnet-utils.h                    |  3 +
 src/nnet3bin/nnet3-copy.cc                |  8 +-
 21 files changed, 234 insertions(+), 37 deletions(-)

diff --git a/egs/swbd/s5c/local/chain/tuning/run_tdnn_7d.sh b/egs/swbd/s5c/local/chain/tuning/run_tdnn_7d.sh
index 0768bd786ae..5bcfea82ec3 100644
--- a/egs/swbd/s5c/local/chain/tuning/run_tdnn_7d.sh
+++ b/egs/swbd/s5c/local/chain/tuning/run_tdnn_7d.sh
@@ -60,7 +60,6 @@ If you want to use GPUs (and have them), go to src/, and configure and make on a machine
 where "nvcc" is installed.
 EOF
 fi
-
 # The iVector-extraction and feature-dumping parts are the same as the standard
 # nnet3 setup, and you can skip them by setting "--stage 8" if you have already
 # run those things.
@@ -76,7 +75,6 @@ ali_dir=exp/tri4_ali_nodup$suffix
 treedir=exp/chain/tri5_7d_tree$suffix
 lang=data/lang_chain_2y
-
 # if we are using the speed-perturbed data we need to generate
 # alignments for it.
 local/nnet3/run_ivector_common.sh --stage $stage \
diff --git a/src/cudamatrix/cu-kernels-ansi.h b/src/cudamatrix/cu-kernels-ansi.h
index 03dd91b793f..4642048989e 100644
--- a/src/cudamatrix/cu-kernels-ansi.h
+++ b/src/cudamatrix/cu-kernels-ansi.h
@@ -125,7 +125,7 @@ void cudaF_add_mat(dim3 Gr, dim3 Bl, float alpha, const float *src, float *dst,
 void cudaF_add_mat_blocks(dim3 Gr, dim3 Bl, float alpha, const float *src,
                           int32_cuda num_row_blocks, int32_cuda num_col_blocks,
                           float *dst, MatrixDim d, int src_stride, int A_trans);
-void cudaF_add_mat_mat_div_mat(dim3 Gr, dim3 Bl, const float *A, const float *B,
+void cudaF_set_mat_mat_div_mat(dim3 Gr, dim3 Bl, const float *A, const float *B,
                                const float *C, float *dst, MatrixDim d,
                                int stride_a, int stride_b, int stride_c);
 void cudaF_add_vec_to_cols(dim3 Gr, dim3 Bl, float alpha, const float *col,
@@ -391,7 +391,7 @@ void cudaD_add_mat_blocks(dim3 Gr, dim3 Bl, double alpha, const double *src,
                           int32_cuda num_row_blocks, int32_cuda num_col_blocks,
                           double *dst, MatrixDim d, int src_stride, int A_trans);
-void cudaD_add_mat_mat_div_mat(dim3 Gr, dim3 Bl, const double *A,
+void cudaD_set_mat_mat_div_mat(dim3 Gr, dim3 Bl, const double *A,
                                const double *B, const double *C, double *dst,
                                MatrixDim d, int stride_a, int stride_b,
                                int stride_c);
diff --git a/src/cudamatrix/cu-kernels.cu b/src/cudamatrix/cu-kernels.cu
index 6f098c87fb5..ba8688fe2be 100644
--- a/src/cudamatrix/cu-kernels.cu
+++ b/src/cudamatrix/cu-kernels.cu
@@ -584,7 +584,7 @@ static void _add_mat_blocks_trans(Real alpha, const Real* src,
 template<typename Real>
 __global__
-static void _add_mat_mat_div_mat(const Real* A, const Real* B, const Real* C,
+static void _set_mat_mat_div_mat(const Real* A, const Real* B, const Real* C,
                                  Real* dst, MatrixDim d, int stride_a,
                                  int stride_b, int stride_c) {
   int32_cuda i = blockIdx.x * blockDim.x + threadIdx.x;
@@ -2863,10 +2863,10 @@ void cudaF_add_mat_blocks(dim3 Gr, dim3 Bl, float alpha, const float* src,
   }
 }
-void cudaF_add_mat_mat_div_mat(dim3 Gr, dim3 Bl, const float *A, const float *B,
+void cudaF_set_mat_mat_div_mat(dim3 Gr, dim3 Bl, const float *A, const float *B,
                                const float *C, float *dst, MatrixDim d,
                                int stride_a, int stride_b, int stride_c) {
-  _add_mat_mat_div_mat<<<Gr,Bl>>>(A,B,C,dst,d, stride_a, stride_b, stride_c);
+  _set_mat_mat_div_mat<<<Gr,Bl>>>(A,B,C,dst,d, stride_a, stride_b, stride_c);
 }
 void cudaF_sy_add_tr2(dim3 Gr, dim3 Bl, float alpha, float beta, const float* T,
@@ -3505,11 +3505,11 @@ void cudaD_add_mat_blocks(dim3 Gr, dim3 Bl, double alpha, const double* src,
   }
 }
-void cudaD_add_mat_mat_div_mat(dim3 Gr, dim3 Bl, const double *A,
+void cudaD_set_mat_mat_div_mat(dim3 Gr, dim3 Bl, const double *A,
                                const double *B, const double *C, double *dst,
                                MatrixDim d, int stride_a, int stride_b,
                                int stride_c) {
-  _add_mat_mat_div_mat<<<Gr,Bl>>>(A,B,C,dst,d,stride_a,stride_b,stride_c);
+  _set_mat_mat_div_mat<<<Gr,Bl>>>(A,B,C,dst,d,stride_a,stride_b,stride_c);
 }
 void cudaD_sy_add_tr2(dim3 Gr, dim3 Bl, double alpha, double beta,
diff --git a/src/cudamatrix/cu-kernels.h b/src/cudamatrix/cu-kernels.h
index 748418e5f2f..a6e81db5d6c 100644
--- a/src/cudamatrix/cu-kernels.h
+++ b/src/cudamatrix/cu-kernels.h
@@ -337,11 +337,11 @@ inline void cuda_add_mat_blocks(dim3 Gr, dim3 Bl, float alpha, const float *src,
   cudaF_add_mat_blocks(Gr, Bl, alpha, src, num_row_blocks, num_col_blocks, dst,
                        d, src_stride, A_trans);
 }
-inline void cuda_add_mat_mat_div_mat(dim3 Gr, dim3 Bl, const float *A,
+inline void cuda_set_mat_mat_div_mat(dim3 Gr, dim3 Bl, const float *A,
                                      const float *B, const float *C,
                                     float *dst, MatrixDim d, int stride_a,
                                     int stride_b, int stride_c) {
-  cudaF_add_mat_mat_div_mat(Gr, Bl, A, B, C, dst, d, stride_a, stride_b,
+  cudaF_set_mat_mat_div_mat(Gr, Bl, A, B, C, dst, d, stride_a, stride_b,
                             stride_c);
 }
 inline void cuda_add_vec_to_cols(dim3 Gr, dim3 Bl, float alpha,
@@ -872,11 +872,11 @@ inline void cuda_add_mat_blocks(dim3 Gr, dim3 Bl, double alpha,
   cudaD_add_mat_blocks(Gr, Bl, alpha, src, num_row_blocks, num_col_blocks, dst,
                        d, src_stride, A_trans);
 }
-inline void cuda_add_mat_mat_div_mat(dim3 Gr, dim3 Bl, const double *A,
+inline void cuda_set_mat_mat_div_mat(dim3 Gr, dim3 Bl, const double *A,
                                      const double *B, const double *C,
                                      double *dst, MatrixDim d, int stride_a,
                                      int stride_b, int stride_c) {
-  cudaD_add_mat_mat_div_mat(Gr, Bl, A, B, C, dst, d, stride_a, stride_b,
+  cudaD_set_mat_mat_div_mat(Gr, Bl, A, B, C, dst, d, stride_a, stride_b,
                             stride_c);
 }
 inline void cuda_add_vec_to_cols(dim3 Gr, dim3 Bl, double alpha,
diff --git a/src/cudamatrix/cu-matrix-test.cc b/src/cudamatrix/cu-matrix-test.cc
index 739cd8eeb22..da587e450e3 100644
--- a/src/cudamatrix/cu-matrix-test.cc
+++ b/src/cudamatrix/cu-matrix-test.cc
@@ -1077,7 +1077,7 @@ template<typename Real> static void UnitTestCuMatrixAddMatMatElements() {
   KALDI_ASSERT(M.Sum() != 0.0);
 }
-template<typename Real> static void UnitTestCuMatrixAddMatMatDivMat() {
+template<typename Real> static void UnitTestCuMatrixSetMatMatDivMat() {
   // M = a * b / c (by element; when c = 0, M = a)
   MatrixIndexT dimM = 100 + Rand() % 255, dimN = 100 + Rand() % 255;
   CuMatrix<Real> M(dimM, dimN), A(dimM, dimN), B(dimM, dimN), C(dimM, dimN);
@@ -1087,13 +1087,13 @@ template<typename Real> static void UnitTestCuMatrixAddMatMatDivMat() {
   B.SetRandn();
   C.SetRandn();
-  M.AddMatMatDivMat(A,B,C);
+  M.SetMatMatDivMat(A,B,C);
   ref.AddMatMatElements(1.0, A, B, 0.0);
   ref.DivElements(C);
   AssertEqual(M, ref);
   C.SetZero();
-  M.AddMatMatDivMat(A,B,C);
+  M.SetMatMatDivMat(A,B,C);
   AssertEqual(M, A);
 }
@@ -2665,7 +2665,7 @@ template<typename Real> void CudaMatrixUnitTest() {
   UnitTestCuMatrixAddDiagVecMat<Real>();
   UnitTestCuMatrixAddMatDiagVec<Real>();
   UnitTestCuMatrixAddMatMatElements<Real>();
-  UnitTestCuMatrixAddMatMatDivMat<Real>();
+  UnitTestCuMatrixSetMatMatDivMat<Real>();
   UnitTestCuTanh<Real>();
   UnitTestCuCholesky<Real>();
   UnitTestCuDiffTanh<Real>();
diff --git a/src/cudamatrix/cu-matrix.cc b/src/cudamatrix/cu-matrix.cc
index 8d0ad950f2f..afe884b2b76 100644
--- a/src/cudamatrix/cu-matrix.cc
+++ b/src/cudamatrix/cu-matrix.cc
@@ -989,7 +989,7 @@ void CuMatrixBase<Real>::AddMatBlocks(Real alpha, const CuMatrixBase<Real> &A,
 /// dst = a * b / c (by element; when c = 0, dst = a)
 /// dst can be an alias of a, b or c safely and get expected result.
 template<typename Real>
-void CuMatrixBase<Real>::AddMatMatDivMat(const CuMatrixBase<Real> &A,
+void CuMatrixBase<Real>::SetMatMatDivMat(const CuMatrixBase<Real> &A,
                                          const CuMatrixBase<Real> &B,
                                          const CuMatrixBase<Real> &C) {
 #if HAVE_CUDA == 1
   if (CuDevice::Instantiate().Enabled()) {
@@ -1002,7 +1002,7 @@ void CuMatrixBase<Real>::AddMatMatDivMat(const CuMatrixBase<Real> &A,
     dim3 dimGrid, dimBlock;
     GetBlockSizesForSimpleMatrixOperation(NumRows(), NumCols(),
                                           &dimGrid, &dimBlock);
-    cuda_add_mat_mat_div_mat(dimGrid, dimBlock, A.data_, B.data_, C.data_,
+    cuda_set_mat_mat_div_mat(dimGrid, dimBlock, A.data_, B.data_, C.data_,
                              data_, Dim(), A.Stride(), B.Stride(), C.Stride());
     CU_SAFE_CALL(cudaGetLastError());
   } else
 #endif
   {
-    Mat().AddMatMatDivMat(A.Mat(), B.Mat(), C.Mat());
+    Mat().SetMatMatDivMat(A.Mat(), B.Mat(), C.Mat());
   }
 }
diff --git a/src/cudamatrix/cu-matrix.h b/src/cudamatrix/cu-matrix.h
index f72484f18e7..38a6c25071b 100644
--- a/src/cudamatrix/cu-matrix.h
+++ b/src/cudamatrix/cu-matrix.h
@@ -429,7 +429,7 @@ class CuMatrixBase {
   void AddVecVec(Real alpha, const CuVectorBase<Real> &x, const CuVectorBase<Real> &y);
   /// *this = a * b / c (by element; when c = 0, *this = a)
   /// *this can be an alias of a, b or c safely and get expected result.
-  void AddMatMatDivMat(const CuMatrixBase<Real> &A, const CuMatrixBase<Real> &B, const CuMatrixBase<Real> &C);
+  void SetMatMatDivMat(const CuMatrixBase<Real> &A, const CuMatrixBase<Real> &B, const CuMatrixBase<Real> &C);
   /// *this = beta * *this + alpha * M M^T, for symmetric matrices.  It only
   /// updates the lower triangle of *this.  It will leave the matrix asymmetric;
diff --git a/src/matrix/kaldi-matrix.cc b/src/matrix/kaldi-matrix.cc
index 817de100656..cb7d3be0cee 100644
--- a/src/matrix/kaldi-matrix.cc
+++ b/src/matrix/kaldi-matrix.cc
@@ -179,9 +179,9 @@ void MatrixBase<Real>::AddMatMat(const Real alpha,
 }
 template<typename Real>
-void MatrixBase<Real>::AddMatMatDivMat(const MatrixBase<Real>& A,
-                                       const MatrixBase<Real>& B,
-                                       const MatrixBase<Real>& C) {
+void MatrixBase<Real>::SetMatMatDivMat(const MatrixBase<Real>& A,
+                                       const MatrixBase<Real>& B,
+                                       const MatrixBase<Real>& C) {
   KALDI_ASSERT(A.NumRows() == B.NumRows() && A.NumCols() == B.NumCols());
   KALDI_ASSERT(A.NumRows() == C.NumRows() && A.NumCols() == C.NumCols());
   for (int32 r = 0; r < A.NumRows(); r++) {  // each frame...
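As a point of reference for the renamed operation: its element-wise contract is dst = a * b / c, with dst = a wherever c == 0, and the "Set" prefix (versus the old "Add") reflects that dst is overwritten rather than accumulated into. A minimal standalone C++ sketch of that contract, on plain arrays rather than Kaldi's MatrixBase (illustrative only, not the library implementation):

#include <cstdio>

// Reference semantics of SetMatMatDivMat:
// dst[i] = a[i] * b[i] / c[i], except dst[i] = a[i] where c[i] == 0.
int main() {
  float a[4] = {1, 2, 3, 4}, b[4] = {5, 6, 7, 8},
        c[4] = {1, 0, 2, 4}, dst[4];
  for (int i = 0; i < 4; i++)
    dst[i] = (c[i] == 0.0f ? a[i] : a[i] * b[i] / c[i]);
  for (int i = 0; i < 4; i++)
    printf("%g ", dst[i]);  // prints: 5 2 10.5 8
  return 0;
}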
diff --git a/src/matrix/kaldi-matrix.h b/src/matrix/kaldi-matrix.h
index 5b4216002fb..e254fcad118 100644
--- a/src/matrix/kaldi-matrix.h
+++ b/src/matrix/kaldi-matrix.h
@@ -579,8 +579,8 @@ class MatrixBase {
                  const Real beta);
   /// *this = a * b / c (by element; when c = 0, *this = a)
-  void AddMatMatDivMat(const MatrixBase<Real>& A,
-                       const MatrixBase<Real>& B,
+  void SetMatMatDivMat(const MatrixBase<Real>& A,
+                       const MatrixBase<Real>& B,
                        const MatrixBase<Real>& C);
   /// A version of AddMatMat specialized for when the second argument
diff --git a/src/nnet2/nnet-component.cc b/src/nnet2/nnet-component.cc
index 498cc809e5f..9608a5475e0 100644
--- a/src/nnet2/nnet-component.cc
+++ b/src/nnet2/nnet-component.cc
@@ -3593,7 +3593,7 @@ void DropoutComponent::Backprop(const ChunkInfo &, //in_info,
                                 CuMatrix<BaseFloat> *in_deriv) const {
   KALDI_ASSERT(SameDim(in_value, out_value) && SameDim(in_value, out_deriv));
   in_deriv->Resize(out_deriv.NumRows(), out_deriv.NumCols());
-  in_deriv->AddMatMatDivMat(out_deriv, out_value, in_value);
+  in_deriv->SetMatMatDivMat(out_deriv, out_value, in_value);
 }
 Component* DropoutComponent::Copy() const {
diff --git a/src/nnet3/nnet-component-itf.cc b/src/nnet3/nnet-component-itf.cc
index cdb43473090..168a2a5350a 100644
--- a/src/nnet3/nnet-component-itf.cc
+++ b/src/nnet3/nnet-component-itf.cc
@@ -141,6 +141,8 @@ Component* Component::NewComponentOfType(const std::string &component_type) {
     ans = new StatisticsPoolingComponent();
   } else if (component_type == "ConstantFunctionComponent") {
     ans = new ConstantFunctionComponent();
+  } else if (component_type == "DropoutComponent") {
+    ans = new DropoutComponent();
   }
   if (ans != NULL) {
     KALDI_ASSERT(component_type == ans->Type());
diff --git a/src/nnet3/nnet-component-itf.h b/src/nnet3/nnet-component-itf.h
index 90b463ec578..164f9d056e7 100644
--- a/src/nnet3/nnet-component-itf.h
+++ b/src/nnet3/nnet-component-itf.h
@@ -350,6 +350,16 @@ class Component {
 };
+class RandomComponent: public Component {
+ public:
+  // This function is required in testing code and in other places where we
+  // need consistency in the random-number generation (e.g. when optimizing
+  // validation-set performance); but check where else we call srand().  You
+  // will need to call srand() as well as making this call.
+  void ResetGenerator() { random_generator_.SeedGpu(); }
+ protected:
+  CuRand<BaseFloat> random_generator_;
+};
 /**
  * Class UpdatableComponent is a Component which has trainable parameters; it
diff --git a/src/nnet3/nnet-component-test.cc b/src/nnet3/nnet-component-test.cc
index 51760d67557..3cc6af1c70d 100644
--- a/src/nnet3/nnet-component-test.cc
+++ b/src/nnet3/nnet-component-test.cc
@@ -23,7 +23,16 @@
 namespace kaldi {
 namespace nnet3 {
-
+// Resets the random-number generator seed for a RandomComponent, so that
+// its random output is repeatable across calls in the tests.
+static void ResetSeed(int32 rand_seed, const Component &c) {
+  RandomComponent *rand_component =
+      const_cast<RandomComponent*>(dynamic_cast<const RandomComponent*>(&c));
+
+  if (rand_component != NULL) {
+    srand(rand_seed);
+    rand_component->ResetGenerator();
+  }
+}
 // returns true if the two strings are equal except for what looks like it
 // might be a difference in the last digit of a floating-point number, e.g.
 // it accepts 1.234 as equal to 1.235.  Not very rigorous.
@@ -188,6 +197,8 @@ void TestNnetComponentUpdatable(Component *c) {
 void TestSimpleComponentPropagateProperties(const Component &c) {
   int32 properties = c.Properties();
   Component *c_copy = NULL, *c_copy_scaled = NULL;
+  int32 rand_seed = Rand();
+
   if (RandInt(0, 1) == 0)
     c_copy = c.Copy();  // This will test backprop with an updatable component.
   if (RandInt(0, 1) == 0 &&
@@ -223,10 +234,14 @@ void TestSimpleComponentPropagateProperties(const Component &c) {
   if ((properties & kPropagateAdds) && (properties & kPropagateInPlace)) {
     KALDI_ERR << "kPropagateAdds and kPropagateInPlace flags are incompatible.";
   }
-
+
+  ResetSeed(rand_seed, c);
   c.Propagate(NULL, input_data, &output_data1);
+
+  ResetSeed(rand_seed, c);
   c.Propagate(NULL, input_data, &output_data2);
   if (properties & kPropagateInPlace) {
+    ResetSeed(rand_seed, c);
     c.Propagate(NULL, output_data3, &output_data3);
     if (!output_data1.ApproxEqual(output_data3)) {
       KALDI_ERR << "Test of kPropagateInPlace flag for component of type "
@@ -238,12 +253,14 @@ void TestSimpleComponentPropagateProperties(const Component &c) {
   AssertEqual(output_data1, output_data2);
   if (c_copy_scaled) {
+    ResetSeed(rand_seed, *c_copy_scaled);
     c_copy_scaled->Propagate(NULL, input_data, &output_data4);
     output_data4.Scale(2.0);  // we scaled the parameters by 0.5 above, and the
                               // output is supposed to be linear in the parameter value.
     AssertEqual(output_data1, output_data4);
   }
   if (properties & kLinearInInput) {
+    ResetSeed(rand_seed, c);
     c.Propagate(NULL, input_data_scaled, &output_data5);
     output_data5.Scale(0.5);
     AssertEqual(output_data1, output_data5);
@@ -302,14 +319,16 @@ bool TestSimpleComponentDataDerivative(const Component &c,
   int32 input_dim = c.InputDim(),
       output_dim = c.OutputDim(),
-      num_rows = RandInt(1, 100);
+      num_rows = RandInt(1, 100),
+      rand_seed = Rand();
   int32 properties = c.Properties();
   CuMatrix<BaseFloat> input_data(num_rows, input_dim, kSetZero, input_stride_type),
       output_data(num_rows, output_dim, kSetZero, output_stride_type),
       output_deriv(num_rows, output_dim, kSetZero, output_stride_type);
   input_data.SetRandn();
   output_deriv.SetRandn();
-
+
+  ResetSeed(rand_seed, c);
   c.Propagate(NULL, input_data, &output_data);
   CuMatrix<BaseFloat> input_deriv(num_rows, input_dim, kSetZero, input_stride_type),
@@ -334,6 +353,8 @@ bool TestSimpleComponentDataDerivative(const Component &c,
     predicted_objf_change(i) = TraceMatMat(perturbed_input_data,
                                            input_deriv, kTrans);
     perturbed_input_data.AddMat(1.0, input_data);
+
+    ResetSeed(rand_seed, c);
     c.Propagate(NULL, perturbed_input_data, &perturbed_output_data);
     measured_objf_change(i) = TraceMatMat(output_deriv, perturbed_output_data,
                                           kTrans) - original_objf;
@@ -503,7 +524,7 @@ int main() {
   TestStringsApproxEqual();
   for (kaldi::int32 loop = 0; loop < 2; loop++) {
 #if HAVE_CUDA == 1
-    CuDevice::Instantiate().SetDebugStrideMode(true);
+    //CuDevice::Instantiate().SetDebugStrideMode(true);
     if (loop == 0)
       CuDevice::Instantiate().SelectGpuId("no");
     else
diff --git a/src/nnet3/nnet-nnet.cc b/src/nnet3/nnet-nnet.cc
index acd322eb515..c84df89177d 100644
--- a/src/nnet3/nnet-nnet.cc
+++ b/src/nnet3/nnet-nnet.cc
@@ -897,7 +897,15 @@ void Nnet::RemoveOrphanNodes(bool remove_orphan_inputs) {
   RemoveSomeNodes(orphan_nodes);
 }
-
+void Nnet::ResetGenerators() {
+  // resets random-number generators for all random components.
+  for (int32 c = 0; c < NumComponents(); c++) {
+    RandomComponent *rc = dynamic_cast<RandomComponent*>(GetComponent(c));
+    if (rc != NULL)
+      rc->ResetGenerator();
+  }
+}
 } // namespace nnet3
 } // namespace kaldi
diff --git a/src/nnet3/nnet-nnet.h b/src/nnet3/nnet-nnet.h
index fc10a8bb09f..16e8333d5b1 100644
--- a/src/nnet3/nnet-nnet.h
+++ b/src/nnet3/nnet-nnet.h
@@ -236,7 +236,6 @@ class Nnet {
   // Assignment operator
   Nnet& operator =(const Nnet &nnet);
-
   // Removes nodes that are never needed to compute any output.
  void RemoveOrphanNodes(bool remove_orphan_inputs = false);
@@ -247,6 +246,10 @@ class Nnet {
   // as it could ruin the graph structure if done carelessly.
   void RemoveSomeNodes(const std::vector<int32> &nodes_to_remove);
+  void ResetGenerators();  // resets random-number generators for all
+  // random components.  You must also call srand() for this to be
+  // effective.
+
  private:
   void Destroy();
@@ -323,7 +326,6 @@ class Nnet {
 };
-
 } // namespace nnet3
 } // namespace kaldi
diff --git a/src/nnet3/nnet-simple-component.cc b/src/nnet3/nnet-simple-component.cc
index ec9f226cf9f..6940ba8302a 100644
--- a/src/nnet3/nnet-simple-component.cc
+++ b/src/nnet3/nnet-simple-component.cc
@@ -86,6 +86,85 @@ void PnormComponent::Write(std::ostream &os, bool binary) const {
 }
+
+void DropoutComponent::Init(int32 dim, BaseFloat dropout_proportion) {
+  dropout_proportion_ = dropout_proportion;
+  dim_ = dim;
+}
+
+void DropoutComponent::InitFromConfig(ConfigLine *cfl) {
+  int32 dim = 0;
+  BaseFloat dropout_proportion = 0.0;
+  bool ok = cfl->GetValue("dim", &dim) &&
+            cfl->GetValue("dropout-proportion", &dropout_proportion);
+  if (!ok || cfl->HasUnusedValues() || dim <= 0 ||
+      dropout_proportion < 0.0 || dropout_proportion > 1.0)
+    KALDI_ERR << "Invalid initializer for layer of type "
+              << Type() << ": \"" << cfl->WholeLine() << "\"";
+  Init(dim, dropout_proportion);
+}
+
+std::string DropoutComponent::Info() const {
+  std::ostringstream stream;
+  stream << Type() << ", dim = " << dim_
+         << ", dropout-proportion = " << dropout_proportion_;
+  return stream.str();
+}
+
+void DropoutComponent::Propagate(const ComponentPrecomputedIndexes *indexes,
+                                 const CuMatrixBase<BaseFloat> &in,
+                                 CuMatrixBase<BaseFloat> *out) const {
+  KALDI_ASSERT(out->NumRows() == in.NumRows() && out->NumCols() == in.NumCols()
+               && in.NumCols() == dim_);
+
+  BaseFloat dropout = dropout_proportion_;
+  KALDI_ASSERT(dropout >= 0.0 && dropout <= 1.0);
+
+  // This const_cast is only safe assuming you don't attempt
+  // to use multi-threaded code with the GPU.
+  const_cast<CuRand<BaseFloat>&>(random_generator_).RandUniform(out);
+
+  out->Add(-dropout);  // now, a proportion "dropout" will be < 0.0
+  out->ApplyHeaviside();  // apply the function (x>0?1:0).  Now, a proportion
+                          // "dropout" will be zero and (1 - dropout) will be 1.0.
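+  // At this point each element of *out is 1.0 with probability (1 - dropout)
+  // and 0.0 otherwise; note that no 1/(1 - dropout) rescaling is applied, so
+  // the expected value of the output is (1 - dropout) times the input.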
+
+  out->MulElements(in);
+}
+
+
+void DropoutComponent::Backprop(const std::string &debug_info,
+                                const ComponentPrecomputedIndexes *indexes,
+                                const CuMatrixBase<BaseFloat> &in_value,
+                                const CuMatrixBase<BaseFloat> &out_value,
+                                const CuMatrixBase<BaseFloat> &out_deriv,
+                                Component *to_update,
+                                CuMatrixBase<BaseFloat> *in_deriv) const {
+  KALDI_ASSERT(in_value.NumRows() == out_value.NumRows() &&
+               in_value.NumCols() == out_value.NumCols());
+
+  KALDI_ASSERT(in_value.NumRows() == out_deriv.NumRows() &&
+               in_value.NumCols() == out_deriv.NumCols());
+  // Dividing out_value by in_value (elementwise) recovers the 0/1 dropout
+  // mask, so this sets in_deriv = mask .* out_deriv.
+  in_deriv->SetMatMatDivMat(out_deriv, out_value, in_value);
+}
+
+
+void DropoutComponent::Read(std::istream &is, bool binary) {
+  ExpectOneOrTwoTokens(is, binary, "<DropoutComponent>", "<Dim>");
+  ReadBasicType(is, binary, &dim_);
+  ExpectToken(is, binary, "<DropoutProportion>");
+  ReadBasicType(is, binary, &dropout_proportion_);
+  ExpectToken(is, binary, "</DropoutComponent>");
+}
+
+void DropoutComponent::Write(std::ostream &os, bool binary) const {
+  WriteToken(os, binary, "<DropoutComponent>");
+  WriteToken(os, binary, "<Dim>");
+  WriteBasicType(os, binary, dim_);
+  WriteToken(os, binary, "<DropoutProportion>");
+  WriteBasicType(os, binary, dropout_proportion_);
+  WriteToken(os, binary, "</DropoutComponent>");
+}
+
 void SumReduceComponent::Init(int32 input_dim, int32 output_dim) {
   input_dim_ = input_dim;
   output_dim_ = output_dim;
diff --git a/src/nnet3/nnet-simple-component.h b/src/nnet3/nnet-simple-component.h
index c3101582aac..8cd1539a02a 100644
--- a/src/nnet3/nnet-simple-component.h
+++ b/src/nnet3/nnet-simple-component.h
@@ -79,6 +79,59 @@ class PnormComponent: public Component {
   int32 output_dim_;
 };
+
+// This component randomly zeros a proportion (dropout_proportion) of the
+// input, and derivatives are backpropagated only through the elements that
+// were kept.  Typically this component is used during training but not at
+// test time.  The idea is described under the name Dropout in the paper
+// "Dropout: A Simple Way to Prevent Neural Networks from Overfitting".
+class DropoutComponent : public RandomComponent {
+ public:
+  void Init(int32 dim, BaseFloat dropout_proportion = 0.0);
+
+  DropoutComponent(int32 dim, BaseFloat dropout = 0.0) { Init(dim, dropout); }
+
+  DropoutComponent(): dim_(0), dropout_proportion_(0.0) { }
+
+  virtual int32 Properties() const {
+    return kLinearInInput|kBackpropInPlace|kSimpleComponent|
+           kBackpropNeedsInput|kBackpropNeedsOutput;
+  }
+  virtual std::string Type() const { return "DropoutComponent"; }
+
+  virtual void InitFromConfig(ConfigLine *cfl);
+
+  virtual int32 InputDim() const { return dim_; }
+
+  virtual int32 OutputDim() const { return dim_; }
+
+  virtual void Read(std::istream &is, bool binary);
+
+  // Write component to stream
+  virtual void Write(std::ostream &os, bool binary) const;
+
+  virtual void Propagate(const ComponentPrecomputedIndexes *indexes,
+                         const CuMatrixBase<BaseFloat> &in,
+                         CuMatrixBase<BaseFloat> *out) const;
+  virtual void Backprop(const std::string &debug_info,
+                        const ComponentPrecomputedIndexes *indexes,
+                        const CuMatrixBase<BaseFloat> &in_value,
+                        const CuMatrixBase<BaseFloat> &out_value,
+                        const CuMatrixBase<BaseFloat> &out_deriv,
+                        Component *to_update,
+                        CuMatrixBase<BaseFloat> *in_deriv) const;
+  virtual Component* Copy() const {
+    return new DropoutComponent(dim_, dropout_proportion_);
+  }
+  virtual std::string Info() const;
+
+  void SetDropoutProportion(BaseFloat dropout_proportion) {
+    dropout_proportion_ = dropout_proportion;
+  }
+
+ private:
+  int32 dim_;
+  /// dropout-proportion is the proportion that is dropped out,
+  /// e.g. if 0.1, we zero 10% of the values.
+  BaseFloat dropout_proportion_;
+};
+
 class ElementwiseProductComponent: public Component {
  public:
   void Init(int32 input_dim, int32 output_dim);
diff --git a/src/nnet3/nnet-test-utils.cc b/src/nnet3/nnet-test-utils.cc
index dc2696e4e12..e02ae4974c9 100644
--- a/src/nnet3/nnet-test-utils.cc
+++ b/src/nnet3/nnet-test-utils.cc
@@ -926,7 +926,7 @@ void ComputeExampleComputationRequestSimple(
 static void GenerateRandomComponentConfig(std::string *component_type,
                                           std::string *config) {
-  int32 n = RandInt(0, 28);
+  int32 n = RandInt(0, 29);
   BaseFloat learning_rate = 0.001 * RandInt(1, 3);
   std::ostringstream os;
@@ -1219,6 +1219,12 @@ static void GenerateRandomComponentConfig(std::string *component_type,
       os << " self-repair-target=" << RandUniform();
       break;
     }
+    case 29: {
+      *component_type = "DropoutComponent";
+      os << "dim=" << RandInt(1, 200)
+         << " dropout-proportion=" << RandUniform();
+      break;
+    }
     default:
       KALDI_ERR << "Error generating random component";
   }
diff --git a/src/nnet3/nnet-utils.cc b/src/nnet3/nnet-utils.cc
index 0cb7d1fe9b3..955e200d072 100644
--- a/src/nnet3/nnet-utils.cc
+++ b/src/nnet3/nnet-utils.cc
@@ -496,6 +496,15 @@ std::string NnetInfo(const Nnet &nnet) {
   return ostr.str();
 }
+void SetDropoutProportion(BaseFloat dropout_proportion,
+                          Nnet *nnet) {
+  for (int32 c = 0; c < nnet->NumComponents(); c++) {
+    Component *comp = nnet->GetComponent(c);
+    DropoutComponent *dc = dynamic_cast<DropoutComponent*>(comp);
+    if (dc != NULL)
+      dc->SetDropoutProportion(dropout_proportion);
+  }
+}
 void FindOrphanComponents(const Nnet &nnet, std::vector<int32> *components) {
   int32 num_components = nnet.NumComponents(), num_nodes = nnet.NumNodes();
diff --git a/src/nnet3/nnet-utils.h b/src/nnet3/nnet-utils.h
index 22e1d9bdf5c..8bb10ff16e7 100644
--- a/src/nnet3/nnet-utils.h
+++ b/src/nnet3/nnet-utils.h
@@ -174,6 +174,9 @@ void ConvertRepeatedToBlockAffine(Nnet *nnet);
 /// Info() function (we need this in the CTC code).
 std::string NnetInfo(const Nnet &nnet);
+/// This function sets the dropout proportion in all dropout components to
+/// the value dropout_proportion.
+void SetDropoutProportion(BaseFloat dropout_proportion, Nnet *nnet);
 /// This function finds a list of components that are never used, and outputs
 /// the integer component indexes (you can use these to index
diff --git a/src/nnet3bin/nnet3-copy.cc b/src/nnet3bin/nnet3-copy.cc
index d2196758594..c419e0e0f91 100644
--- a/src/nnet3bin/nnet3-copy.cc
+++ b/src/nnet3bin/nnet3-copy.cc
@@ -41,7 +41,8 @@ int main(int argc, char *argv[]) {
         " nnet3-copy --binary=false 0.raw text.raw\n";
     bool binary_write = true;
-    BaseFloat learning_rate = -1;
+    BaseFloat learning_rate = -1,
+        dropout = -1;  // negative means "leave dropout proportions unchanged".
     std::string nnet_config, edits_config, edits_str;
     ParseOptions po(usage);
@@ -61,6 +62,8 @@ int main(int argc, char *argv[]) {
                 "Can be used as an inline alternative to edits-config; semicolons "
                 "will be converted to newlines before parsing.  E.g. "
                 "'--edits=remove-orphans'.");
+    po.Register("set-dropout-proportion", &dropout, "Set dropout proportion "
+                "in all DropoutComponents to this value (negative: leave as is).");
     po.Read(argc, argv);
     if (po.NumArgs() != 2) {
@@ -81,6 +84,9 @@ int main(int argc, char *argv[]) {
     if (learning_rate >= 0)
       SetLearningRate(learning_rate, &nnet);
+
+    if (dropout >= 0)
+      SetDropoutProportion(dropout, &nnet);
     if (!edits_config.empty()) {
       Input ki(edits_config);
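The intended use of the new component can be sketched as follows. This is a minimal illustration in the style of the nnet3 test code; it assumes only the declarations added by this patch, and the dimensions, seed and proportion are arbitrary:

#include <cstdlib>  // for srand()
#include "nnet3/nnet-simple-component.h"

int main() {
  using namespace kaldi;
  using namespace kaldi::nnet3;
  // A dropout layer of dimension 512 that zeroes 20% of the activations.
  DropoutComponent dropout(512, 0.2);
  CuMatrix<BaseFloat> in(10, 512), out(10, 512);
  in.SetRandn();
  srand(0);                   // srand() must be called as well as
  dropout.ResetGenerator();   // ResetGenerator(), for a repeatable mask.
  dropout.Propagate(NULL, in, &out);
  // Each element of 'out' is now either 0.0 or the corresponding element
  // of 'in'; in expectation 20% of the elements are zeroed.
  return 0;
}

At test time the proportion can then be changed from the command line with the option added above, e.g. nnet3-copy --set-dropout-proportion=0.0 final.raw test.raw.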