trunk,nnet1:
- fixing a bug in _diff_sigmoid, _diff_tanh. The matrix arguments of a kernel now have individual strides
- better diagnostic log-prints of nested networks in parallel components



git-svn-id: https://svn.code.sf.net/p/kaldi/code/trunk@4445 5e6a8d80-dfce-4ca6-a32a-6e07a63d50c8
KarelVesely84 committed Sep 18, 2014
1 parent a6c5dd4 commit 178b032
Showing 6 changed files with 79 additions and 75 deletions.
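A note for context (not part of the diff): a Kaldi CuMatrix is stored row-major with a row stride that can exceed its number of columns — strides are padded for alignment, and a CuSubMatrix view into a wider parent keeps the parent's stride — so the matrices passed to one kernel need not share a stride. The old _diff_sigmoid/_diff_tanh kernels indexed both e (the incoming derivative) and y (the forward output) through a single src_stride, so one of them was read at wrong offsets whenever the two strides differed. A minimal sketch of the corrected addressing (illustrative CUDA, hypothetical kernel name, not the committed code):

// Element (row j, col i) of a strided row-major matrix lives at
// data[i + j*stride], so eout, e and y each need their own stride.
__global__ static void diff_sigmoid_sketch(float *eout, const float *e,
                                           const float *y, int rows, int cols,
                                           int out_stride, int e_stride,
                                           int y_stride) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;  // column
  int j = blockIdx.y * blockDim.y + threadIdx.y;  // row
  if (i < cols && j < rows) {
    float yv = y[i + j * y_stride];               // y at y's own layout
    eout[i + j * out_stride] = yv * (1.0f - yv) * e[i + j * e_stride];
  }
}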
8 changes: 4 additions & 4 deletions src/cudamatrix/cu-kernels-ansi.h
@@ -134,9 +134,9 @@ void cudaF_softmax_part(dim3 Gr, dim3 Bl, const float *X, const int32_cuda *vec_
void cudaF_soft_hinge(dim3 Gr, dim3 Bl, float *y, const float *x, MatrixDim d, int src_stride);
void cudaF_group_pnorm(dim3 Gr, dim3 Bl, float *y, const float *x, MatrixDim d, int src_stride, int group_size, float power);
void cudaF_sigmoid(dim3 Gr, dim3 Bl, float *y, const float *x, MatrixDim d, int src_stride);
-void cudaF_diff_sigmoid(dim3 Gr, dim3 Bl, float *eout, const float *e, const float *y, MatrixDim d, int src_stride);
+void cudaF_diff_sigmoid(dim3 Gr, dim3 Bl, float *eout, const float *e, const float *y, MatrixDim d, int e_stride, int y_stride);
void cudaF_tanh(dim3 Gr, dim3 Bl, float *y, const float *x, MatrixDim d, int src_stride);
-void cudaF_diff_tanh(dim3 Gr, dim3 Bl, float *eout, const float *e, const float *y, MatrixDim d);
+void cudaF_diff_tanh(dim3 Gr, dim3 Bl, float *eout, const float *e, const float *y, MatrixDim d, int e_stride, int y_stride);

void cudaF_regularize_l1(dim3 Gr, dim3 Bl, float *wei, float *grad, float l1, float lr, MatrixDim d);
void cudaF_find_row_max_id(dim3 Gr, dim3 Bl, const float *mat, float *vec_val, int32_cuda *vec_id, int32_cuda voff, MatrixDim d);
@@ -264,9 +264,9 @@ void cudaD_softmax_part(dim3 Gr, dim3 Bl, const double *X, const int32_cuda *vec
void cudaD_soft_hinge(dim3 Gr, dim3 Bl, double *y, const double *x, MatrixDim d, int src_stride);
void cudaD_group_pnorm(dim3 Gr, dim3 Bl, double *y, const double *x, MatrixDim d, int src_stride, int group_size, double power);
void cudaD_sigmoid(dim3 Gr, dim3 Bl, double *y, const double *x, MatrixDim d, int src_stride);
-void cudaD_diff_sigmoid(dim3 Gr, dim3 Bl, double *eout, const double *e, const double *y, MatrixDim d, int src_stride);
+void cudaD_diff_sigmoid(dim3 Gr, dim3 Bl, double *eout, const double *e, const double *y, MatrixDim d, int e_stride, int y_stride);
void cudaD_tanh(dim3 Gr, dim3 Bl, double *y, const double *x, MatrixDim d, int src_stride);
-void cudaD_diff_tanh(dim3 Gr, dim3 Bl, double *eout, const double *e, const double *y, MatrixDim d);
+void cudaD_diff_tanh(dim3 Gr, dim3 Bl, double *eout, const double *e, const double *y, MatrixDim d, int e_stride, int y_stride);

void cudaD_regularize_l1(dim3 Gr, dim3 Bl, double *wei, double *grad, double l1, double lr, MatrixDim d);
void cudaD_find_row_max_id(dim3 Gr, dim3 Bl, const double *mat, double *vec_val, int32_cuda *vec_id, int32_cuda voff, MatrixDim d);
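The new stride arguments have to be threaded through four layers, which is what the remaining cudamatrix hunks do (a sketch of the call chain, signatures abridged):

// CuMatrixBase<Real>::DiffSigmoid(value, diff)                    cu-matrix.cc
//   -> cuda_diff_sigmoid(..., diff.Stride(), value.Stride())      cu-kernels.h, inline float/double overloads
//     -> cudaF_diff_sigmoid / cudaD_diff_sigmoid                  cu-kernels-ansi.h, C-linkage wrappers
//       -> _diff_sigmoid<Real><<<Gr,Bl>>>(eout, e, y, d, ...)     cu-kernels.cu, templated kernel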
32 changes: 18 additions & 14 deletions src/cudamatrix/cu-kernels.cu
@@ -1552,12 +1552,14 @@ static void _sigmoid(Real*y, const Real*x, MatrixDim d, int src_stride) {

template<typename Real>
__global__
-static void _diff_sigmoid(Real*eout, const Real*e, const Real*y, MatrixDim d, int src_stride) {
+static void _diff_sigmoid(Real*eout, const Real*e, const Real*y, MatrixDim d, int e_stride, int y_stride) {
int i = blockIdx.x * blockDim.x + threadIdx.x;
int j = blockIdx.y * blockDim.y + threadIdx.y;
-int dst_index = i + j*d.stride, src_index = i + j*src_stride;
+int dst_index = i + j*d.stride;
+int e_index = i + j*e_stride;
+int y_index = i + j*y_stride;
if (i < d.cols && j < d.rows )
-eout[dst_index] = y[src_index]*(1.0-y[src_index]) * e[src_index];
+eout[dst_index] = y[y_index]*(1.0-y[y_index]) * e[e_index];
}


@@ -1582,12 +1584,14 @@ static void _tanh(Real*y, const Real*x, MatrixDim d, int src_stride) {

template<typename Real>
__global__
-static void _diff_tanh(Real*eout, const Real*e, const Real*y, MatrixDim d) {
+static void _diff_tanh(Real*eout, const Real*e, const Real*y, MatrixDim d, int e_stride, int y_stride) {
int i = blockIdx.x * blockDim.x + threadIdx.x;
int j = blockIdx.y * blockDim.y + threadIdx.y;
-int index = i + j*d.stride;
+int dst_index = i + j*d.stride;
+int e_index = i + j*e_stride;
+int y_index = i + j*y_stride;
if (i < d.cols && j < d.rows )
-eout[index] = (1.0 - y[index]*y[index]) * e[index];
+eout[dst_index] = (1.0 - y[y_index]*y[y_index]) * e[e_index];
}
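For reference, the backprop rules the two kernels implement, elementwise over (row j, col i) and written in terms of the forward output y:

// sigmoid: y = 1/(1+exp(-x)),  dy/dx = y*(1-y)   =>  eout = y*(1-y) * e
// tanh:    y = tanh(x),        dy/dx = 1 - y^2   =>  eout = (1-y*y) * e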

template<typename Real>
@@ -2212,16 +2216,16 @@ void cudaF_sigmoid (dim3 Gr, dim3 Bl, float* y, const float* x, MatrixDim d, int
_sigmoid<<<Gr,Bl>>>(y, x, d, src_stride);
}

-void cudaF_diff_sigmoid (dim3 Gr, dim3 Bl, float* eout, const float* e, const float* y, MatrixDim d, int src_stride) {
-_diff_sigmoid<<<Gr,Bl>>>(eout, e, y, d, src_stride);
+void cudaF_diff_sigmoid (dim3 Gr, dim3 Bl, float* eout, const float* e, const float* y, MatrixDim d, int e_stride, int y_stride) {
+_diff_sigmoid<<<Gr,Bl>>>(eout, e, y, d, e_stride, y_stride);
}

void cudaF_tanh (dim3 Gr, dim3 Bl, float* y, const float* x, MatrixDim d, int src_stride) {
_tanh<<<Gr,Bl>>>(y, x, d, src_stride);
}

-void cudaF_diff_tanh (dim3 Gr, dim3 Bl, float* eout, const float* e, const float* y, MatrixDim d) {
-_diff_tanh<<<Gr,Bl>>>(eout, e, y, d);
+void cudaF_diff_tanh (dim3 Gr, dim3 Bl, float* eout, const float* e, const float* y, MatrixDim d, int e_stride, int y_stride) {
+_diff_tanh<<<Gr,Bl>>>(eout, e, y, d, e_stride, y_stride);
}

void cudaF_softmax (size_t Gr, size_t Bl, float* y, const float* x, MatrixDim d) {
@@ -2628,16 +2632,16 @@ void cudaD_sigmoid (dim3 Gr, dim3 Bl, double* y, const double* x, MatrixDim d, i
_sigmoid<<<Gr,Bl>>>(y, x, d, src_stride);
}

-void cudaD_diff_sigmoid (dim3 Gr, dim3 Bl, double* eout, const double* e, const double* y, MatrixDim d, int src_stride) {
-_diff_sigmoid<<<Gr,Bl>>>(eout, e, y, d, src_stride);
+void cudaD_diff_sigmoid (dim3 Gr, dim3 Bl, double* eout, const double* e, const double* y, MatrixDim d, int e_stride, int y_stride) {
+_diff_sigmoid<<<Gr,Bl>>>(eout, e, y, d, e_stride, y_stride);
}

void cudaD_tanh (dim3 Gr, dim3 Bl, double* y, const double* x, MatrixDim d, int src_stride) {
_tanh<<<Gr,Bl>>>(y, x, d, src_stride);
}

-void cudaD_diff_tanh (dim3 Gr, dim3 Bl, double* eout, const double* e, const double* y, MatrixDim d) {
-_diff_tanh<<<Gr,Bl>>>(eout, e, y, d);
+void cudaD_diff_tanh (dim3 Gr, dim3 Bl, double* eout, const double* e, const double* y, MatrixDim d, int e_stride, int y_stride) {
+_diff_tanh<<<Gr,Bl>>>(eout, e, y, d, e_stride, y_stride);
}


8 changes: 4 additions & 4 deletions src/cudamatrix/cu-kernels.h
@@ -185,9 +185,9 @@ inline void cuda_block_add_mat_mat(dim3 Gr, dim3 Bl, CuBlockMatrixData *B_cu_dat
inline void cuda_soft_hinge(dim3 Gr, dim3 Bl, float *y, const float *x, MatrixDim d, int src_stride) { cudaF_soft_hinge(Gr,Bl,y,x,d,src_stride); }
inline void cuda_group_pnorm(dim3 Gr, dim3 Bl, float *y, const float *x, MatrixDim d, int src_stride, int group_size, float power) { cudaF_group_pnorm(Gr, Bl, y, x, d, src_stride, group_size, power);}
inline void cuda_sigmoid(dim3 Gr, dim3 Bl, float *y, const float *x, MatrixDim d, int src_stride) { cudaF_sigmoid(Gr,Bl,y,x,d,src_stride); }
-inline void cuda_diff_sigmoid(dim3 Gr, dim3 Bl, float *eout, const float *e, const float *y, MatrixDim d, int src_stride) { cudaF_diff_sigmoid(Gr,Bl,eout,e,y,d,src_stride); }
+inline void cuda_diff_sigmoid(dim3 Gr, dim3 Bl, float *eout, const float *e, const float *y, MatrixDim d, int e_stride, int y_stride) { cudaF_diff_sigmoid(Gr,Bl,eout,e,y,d,e_stride,y_stride); }
inline void cuda_tanh(dim3 Gr, dim3 Bl, float *y, const float *x, MatrixDim d, int src_stride) { cudaF_tanh(Gr,Bl,y,x,d,src_stride); }
-inline void cuda_diff_tanh(dim3 Gr, dim3 Bl, float *eout, const float *e, const float *y, MatrixDim d) { cudaF_diff_tanh(Gr,Bl,eout,e,y,d); }
+inline void cuda_diff_tanh(dim3 Gr, dim3 Bl, float *eout, const float *e, const float *y, MatrixDim d, int e_stride, int y_stride) { cudaF_diff_tanh(Gr,Bl,eout,e,y,d,e_stride,y_stride); }
inline void cuda_softmax(size_t Gr, size_t Bl, float *y, const float *x, MatrixDim d) { cudaF_softmax(Gr,Bl,y,x,d); }
/*
Bl: dimBlock value is fixed min(d.col, CU1DBLOCK), represent CU1DBLOCK threads reduce a row at the same time.
@@ -351,9 +351,9 @@ inline void cuda_block_add_mat_mat(dim3 Gr, dim3 Bl, CuBlockMatrixData *B_cu_dat
inline void cuda_soft_hinge(dim3 Gr, dim3 Bl, double *y, const double *x, MatrixDim d, int src_stride) { cudaD_soft_hinge(Gr,Bl,y,x,d,src_stride); }
inline void cuda_group_pnorm(dim3 Gr, dim3 Bl, double *y, const double *x, MatrixDim d, int src_stride, int group_size, double power) { cudaD_group_pnorm(Gr, Bl, y, x, d, src_stride, group_size, power); }
inline void cuda_sigmoid(dim3 Gr, dim3 Bl, double *y, const double *x, MatrixDim d, int src_stride) { cudaD_sigmoid(Gr,Bl,y,x,d,src_stride); }
-inline void cuda_diff_sigmoid(dim3 Gr, dim3 Bl, double *eout, const double *e, const double *y, MatrixDim d, int src_stride) { cudaD_diff_sigmoid(Gr,Bl,eout,e,y,d,src_stride); }
+inline void cuda_diff_sigmoid(dim3 Gr, dim3 Bl, double *eout, const double *e, const double *y, MatrixDim d, int e_stride, int y_stride) { cudaD_diff_sigmoid(Gr,Bl,eout,e,y,d,e_stride,y_stride); }
inline void cuda_tanh(dim3 Gr, dim3 Bl, double *y, const double *x, MatrixDim d, int src_stride) { cudaD_tanh(Gr,Bl,y,x,d,src_stride); }
-inline void cuda_diff_tanh(dim3 Gr, dim3 Bl, double *eout, const double *e, const double *y, MatrixDim d) { cudaD_diff_tanh(Gr,Bl,eout,e,y,d); }
+inline void cuda_diff_tanh(dim3 Gr, dim3 Bl, double *eout, const double *e, const double *y, MatrixDim d, int e_stride, int y_stride) { cudaD_diff_tanh(Gr,Bl,eout,e,y,d,e_stride,y_stride); }
inline void cuda_softmax(size_t Gr, size_t Bl, double *y, const double *x, MatrixDim d) { cudaD_softmax(Gr,Bl,y,x,d); }
inline void cuda_softmax_reduce(size_t Gr, size_t Bl, double *y, const double *x, MatrixDim d, int src_stride) { cudaD_softmax_reduce(Gr,Bl,y,x,d,src_stride); }
inline void cuda_softmax_part(dim3 Gr, dim3 Bl, const double *X, const int32_cuda *vec_ids, double* Y, MatrixDim d) { cudaD_softmax_part(Gr,Bl,X,vec_ids,Y,d); }
4 changes: 2 additions & 2 deletions src/cudamatrix/cu-matrix.cc
@@ -1167,7 +1167,7 @@ void CuMatrixBase<Real>::DiffSigmoid(const CuMatrixBase<Real> &value,
dim3 dimBlock(CU2DBLOCK, CU2DBLOCK);
dim3 dimGrid(n_blocks(num_cols_, CU2DBLOCK), n_blocks(num_rows_, CU2DBLOCK));

-cuda_diff_sigmoid(dimGrid, dimBlock, data_, diff.data_, value.data_, Dim(), diff.Stride());
+cuda_diff_sigmoid(dimGrid, dimBlock, data_, diff.data_, value.data_, Dim(), diff.Stride(), value.Stride());
CU_SAFE_CALL(cudaGetLastError());

CuDevice::Instantiate().AccuProfile(__func__, tim.Elapsed());
@@ -1212,7 +1212,7 @@ void CuMatrixBase<Real>::DiffTanh(const CuMatrixBase<Real> &value,
dim3 dimBlock(CU2DBLOCK, CU2DBLOCK);
dim3 dimGrid(n_blocks(num_cols_, CU2DBLOCK), n_blocks(num_rows_, CU2DBLOCK));

-cuda_diff_tanh(dimGrid, dimBlock, data_, diff.data_, value.data_, Dim());
+cuda_diff_tanh(dimGrid, dimBlock, data_, diff.data_, value.data_, Dim(), diff.Stride(), value.Stride());
CU_SAFE_CALL(cudaGetLastError());

CuDevice::Instantiate().AccuProfile(__func__, tim.Elapsed());
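The call sites now pass each operand's own stride. This matters because DiffSigmoid/DiffTanh can legally be invoked on views whose strides all differ (a sketch with hypothetical dimensions; actual stride values depend on CuMatrix's alignment padding):

CuMatrix<BaseFloat> value_parent(64, 200), diff_parent(64, 300);
// Column-range views keep their parent's stride (>= 200 resp. >= 300):
CuSubMatrix<BaseFloat> value(value_parent, 0, 64, 0, 100);
CuSubMatrix<BaseFloat> diff(diff_parent, 0, 64, 0, 100);
CuMatrix<BaseFloat> out(64, 100);  // its own, third stride
out.DiffSigmoid(value, diff);      // out = diff .* value .* (1 - value)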
80 changes: 30 additions & 50 deletions src/nnet/nnet-nnet.cc
@@ -19,6 +19,7 @@

#include "nnet/nnet-nnet.h"
#include "nnet/nnet-component.h"
#include "nnet/nnet-parallel-component.h"
#include "nnet/nnet-activation.h"
#include "nnet/nnet-affine-transform.h"
#include "nnet/nnet-various.h"
@@ -35,7 +36,7 @@ Nnet::Nnet(const Nnet& other) {
}
// create empty buffers
propagate_buf_.resize(NumComponents()+1);
-backpropagate_buf_.resize(NumComponents()-1);
+backpropagate_buf_.resize(NumComponents()+1);
// copy train opts
SetTrainOptions(other.opts_);
Check();
@@ -49,7 +50,7 @@ Nnet & Nnet::operator = (const Nnet& other) {
}
// create empty buffers
propagate_buf_.resize(NumComponents()+1);
-backpropagate_buf_.resize(NumComponents()-1);
+backpropagate_buf_.resize(NumComponents()+1);
// copy train opts
SetTrainOptions(other.opts_);
Check();
@@ -91,55 +92,24 @@ void Nnet::Backpropagate(const CuMatrixBase<BaseFloat> &out_diff, CuMatrix<BaseF
//

// 0 layers
-if(NumComponents() == 0) {
-(*in_diff) = out_diff; //copy
-return;
-}
-
-// we need at least L+1 input bufers
-KALDI_ASSERT((int32)propagate_buf_.size() >= NumComponents()+1);
-// we need at least L-1 error derivative bufers
-KALDI_ASSERT((int32)backpropagate_buf_.size() >= NumComponents()-1);
-
-// 1 layer
-if(NumComponents() == 1) {
-components_[0]->Backpropagate(propagate_buf_[0], propagate_buf_[1], out_diff, in_diff);
-if (components_[0]->IsUpdatable()) {
-UpdatableComponent *uc = dynamic_cast<UpdatableComponent*>(components_[0]);
-uc->Update(propagate_buf_[0], out_diff);
-}
-return;
-}
+if (NumComponents() == 0) { (*in_diff) = out_diff; return; }

-// >1 layers
-// we don't copy the out_diff to buffers, we use it as it is...
-int32 i = components_.size()-1;
-components_.back()->Backpropagate(propagate_buf_[i], propagate_buf_[i+1],
-out_diff, &backpropagate_buf_[i-1]);
-if (components_[i]->IsUpdatable()) {
-UpdatableComponent *uc = dynamic_cast<UpdatableComponent*>(components_[i]);
-uc->Update(propagate_buf_[i], out_diff);
-}
+KALDI_ASSERT((int32)propagate_buf_.size() == NumComponents()+1);
+KALDI_ASSERT((int32)backpropagate_buf_.size() == NumComponents()+1);

-// backpropagate by using buffers
-for (i--; i >= 1; i--) {
+// copy out_diff to last buffer
+backpropagate_buf_[NumComponents()] = out_diff;
+// backpropagate using buffers
+for (int32 i = NumComponents()-1; i >= 0; i--) {
components_[i]->Backpropagate(propagate_buf_[i], propagate_buf_[i+1],
backpropagate_buf_[i], &backpropagate_buf_[i-1]);
backpropagate_buf_[i+1], &backpropagate_buf_[i]);
if (components_[i]->IsUpdatable()) {
UpdatableComponent *uc = dynamic_cast<UpdatableComponent*>(components_[i]);
-uc->Update(propagate_buf_[i], backpropagate_buf_[i]);
+uc->Update(propagate_buf_[i], backpropagate_buf_[i+1]);
}
}

-// now backpropagate through first layer,
-components_[0]->Backpropagate(propagate_buf_[0], propagate_buf_[1],
-backpropagate_buf_[0], in_diff);
-
-// update the first layer
-if (components_[0]->IsUpdatable()) {
-UpdatableComponent *uc = dynamic_cast<UpdatableComponent*>(components_[0]);
-uc->Update(propagate_buf_[0], backpropagate_buf_[0]);
-}
+// eventually export the derivative
+if (NULL != in_diff) (*in_diff) = backpropagate_buf_[0];

//
// End of Backpropagation
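The rewrite replaces three special cases (zero, one, and many layers) with a single loop: both buffer vectors now hold NumComponents()+1 entries, out_diff is seeded into the last error buffer, and slot i consistently means "at the input of component i". The invariant for a 3-component network (illustrative comments only):

// propagate_buf_[0] = network input      backpropagate_buf_[0] = dE/d(input)
// propagate_buf_[1] = output of comp 0   backpropagate_buf_[1] = dE/d(output of comp 0)
// propagate_buf_[2] = output of comp 1   backpropagate_buf_[2] = dE/d(output of comp 1)
// propagate_buf_[3] = network output     backpropagate_buf_[3] = out_diff (seeded copy)
// Component i reads backpropagate_buf_[i+1] and writes backpropagate_buf_[i],
// so Update() consumes exactly the error term that arrived at its output.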
@@ -413,7 +383,7 @@ void Nnet::Read(std::istream &is, bool binary) {
}
// create empty buffers
propagate_buf_.resize(NumComponents()+1);
-backpropagate_buf_.resize(NumComponents()-1);
+backpropagate_buf_.resize(NumComponents()+1);
// reset learn rate
opts_.learn_rate = 0.0;

@@ -475,10 +445,15 @@ std::string Nnet::InfoPropagate() const {
std::ostringstream ostr;
// forward-pass buffer stats
ostr << "### Forward propagation buffer content :\n";
for (int32 i=0; i<propagate_buf_.size(); i++) {
ostr << "[0] output of <Input> " << MomentStatistics(propagate_buf_[0]) << std::endl;
for (int32 i=0; i<NumComponents(); i++) {
ostr << "["<<1+i<< "] output of "
<< (i==0 ? "<Input>" : Component::TypeToMarker(components_[i-1]->GetType()))
<< MomentStatistics(propagate_buf_[i]) << std::endl;
<< Component::TypeToMarker(components_[i]->GetType())
<< MomentStatistics(propagate_buf_[i+1]) << std::endl;
+// nested networks too...
+if (Component::kParallelComponent == components_[i]->GetType()) {
+ostr << dynamic_cast<ParallelComponent*>(components_[i])->InfoPropagate();
+}
}
return ostr.str();
}
@@ -487,10 +462,15 @@ std::string Nnet::InfoBackPropagate() const {
std::ostringstream ostr;
// forward-pass buffer stats
ostr << "### Backward propagation buffer content :\n";
for (int32 i=0; i<backpropagate_buf_.size(); i++) {
ostr << "[0] diff of <Input> " << MomentStatistics(backpropagate_buf_[0]) << std::endl;
for (int32 i=0; i<NumComponents(); i++) {
ostr << "["<<1+i<< "] diff-output of "
<< Component::TypeToMarker(components_[i]->GetType())
-<< MomentStatistics(backpropagate_buf_[i]) << std::endl;
+<< MomentStatistics(backpropagate_buf_[i+1]) << std::endl;
+// nested networks too...
+if (Component::kParallelComponent == components_[i]->GetType()) {
+ostr << dynamic_cast<ParallelComponent*>(components_[i])->InfoBackPropagate();
+}
}
return ostr.str();
}
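The diagnostics now recurse into nested networks. The dynamic_cast in the loops above is the usual check-then-downcast idiom, safe because the branch is guarded by GetType() (a sketch; the KALDI_ASSERT guard is an added illustration, not in the commit):

if (Component::kParallelComponent == components_[i]->GetType()) {
  ParallelComponent *pc = dynamic_cast<ParallelComponent*>(components_[i]);
  KALDI_ASSERT(pc != NULL);  // static marker and dynamic type must agree
  ostr << pc->InfoPropagate();
}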
22 changes: 21 additions & 1 deletion src/nnet/nnet-parallel-component.h
@@ -159,14 +159,34 @@ class ParallelComponent : public UpdatableComponent {
for (int32 i=0; i<nnet_.size(); i++) {
os << "nested_network #" << i+1 << "{\n" << nnet_[i].Info() << "}\n";
}
-return os.str();
+std::string s(os.str());
+s.erase(s.end() -1); // removing last '\n'
+return s;
}

std::string InfoGradient() const {
std::ostringstream os;
for (int32 i=0; i<nnet_.size(); i++) {
os << "nested_gradient #" << i+1 << "{\n" << nnet_[i].InfoGradient() << "}\n";
}
+std::string s(os.str());
+s.erase(s.end() -1); // removing last '\n'
+return s;
}

+std::string InfoPropagate() const {
+std::ostringstream os;
+for (int32 i=0; i<nnet_.size(); i++) {
+os << "nested_propagate #" << i+1 << "{\n" << nnet_[i].InfoPropagate() << "}\n";
+}
+return os.str();
+}

+std::string InfoBackPropagate() const {
+std::ostringstream os;
+for (int32 i=0; i<nnet_.size(); i++) {
+os << "nested_backpropagate #" << i+1 << "{\n" << nnet_[i].InfoBackPropagate() << "}\n";
+}
+return os.str();
+}
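One detail in Info() and InfoGradient() above: s.erase(s.end() - 1) trims the final '\n' so nested blocks embed into the caller's line-oriented report without double blank lines; it assumes the string is non-empty, i.e. at least one nested network. A defensive standalone equivalent (sketch; the emptiness guard is an addition, not in the commit):

#include <string>

std::string TrimTrailingNewline(std::string s) {
  if (!s.empty() && *s.rbegin() == '\n') s.erase(s.end() - 1);
  return s;
}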

Expand Down
